Skip to content

Commit 9ed7c8b

Browse files
[Weight Compression] Rework WC Algorithm to Return WC Params (#3636)
### Changes Re-worked the `apply` method in WC algorithm to use an extra method to return weights compression params such that the apply method is more concise and only contains algorithm, quantization logic etc. ### Reason for changes This is done so that OpenVINO quantizer can obtain the final collection of weights compression parameters for all the nodes so that are to be compressed. --------- Co-authored-by: Daniil Lyakhov <daniil.lyakhov@intel.com>
1 parent c463052 commit 9ed7c8b

File tree

1 file changed

+33
-4
lines changed
  • src/nncf/quantization/algorithms/weight_compression

1 file changed

+33
-4
lines changed

src/nncf/quantization/algorithms/weight_compression/algorithm.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -769,15 +769,28 @@ def is_weight_compression_supported(
769769

770770
return is_supported_dtype and not no_bit_reduction
771771

772-
def apply(
772+
def get_weight_compression_parameters(
773773
self,
774774
model: TModel,
775775
graph: NNCFGraph,
776776
statistic_points: Optional[StatisticPointsContainer] = None,
777777
dataset: Optional[Dataset] = None,
778-
) -> TModel:
779-
self.set_backend_entity(model)
778+
) -> tuple[list[WeightCompressionParameters], Optional[dict[str, WCTensorStatistic]]]:
779+
"""
780+
Generates a list of weight compression parameters based on the Weight Compression algorithm
781+
configuration. Determines the appropriate quantization parameters for each node eligible for
782+
weight compression. Also, Generates a mapping of target node names to the collected statistics
783+
based on the provided statistic_points. If statistic_points is None, collects required
784+
compression statistics on the given dataset.
780785
786+
:param model: Backend-specific input model.
787+
:param graph: NNCFGraph instance.
788+
:param statistic_points: Optional pre-collected statistic points.
789+
:param dataset: Optional dataset for statistics collection.
790+
:return: A tuple consisting of a list of weight compression parameters, based on the Weight
791+
Compression algorithm configuration, and a mapping of target node names to the
792+
collected statistics.
793+
"""
781794
nodes_to_compress = self.get_nodes_to_compress(graph)
782795

783796
all_weight_params: list[WeightCompressionParameters] = []
@@ -787,12 +800,13 @@ def apply(
787800
is_last_layer_skipped = False
788801
n = len(nodes_to_compress)
789802
ignored_names = self.get_ignored_node_names(graph)
803+
790804
for i, node in enumerate(nodes_to_compress):
791805
is_target_node = should_consider_scope(node.node_name, ignored_names)
792806
for weight_name, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph):
793807
is_last_layer = i == n - 1
794808
if weight_name in weight_names:
795-
# If the last layer has shared weights then skiped
809+
# If the last layer has shared weights then skip it
796810
# to avoid processing the same weight more than once
797811
is_last_layer_skipped = is_last_layer
798812
continue
@@ -828,6 +842,7 @@ def apply(
828842
)
829843
if self.is_weight_compression_supported(weight_dtype, mode):
830844
wc_config = WeightCompressionConfig(mode=mode)
845+
831846
weight_params = WeightCompressionParameters(
832847
weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, wc_config
833848
)
@@ -884,6 +899,20 @@ def apply(
884899
# Filter all_weight_params and by excluding nodes that should remain in their original floating-point precision
885900
all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))
886901

902+
return all_weight_params, statistics
903+
904+
def apply(
905+
self,
906+
model: TModel,
907+
graph: NNCFGraph,
908+
statistic_points: Optional[StatisticPointsContainer] = None,
909+
dataset: Optional[Dataset] = None,
910+
) -> TModel:
911+
self.set_backend_entity(model)
912+
913+
# Get processed weight compression parameters ready for compression
914+
all_weight_params, statistics = self.get_weight_compression_parameters(model, graph, statistic_points, dataset)
915+
887916
if self._awq:
888917
model = self.awq_algo.apply(model, graph, all_weight_params, statistics, self._backend_entity)
889918
# After applying AWQ we need to update statistics since AWQ alters the activations

0 commit comments

Comments
 (0)