@@ -769,15 +769,28 @@ def is_weight_compression_supported(

         return is_supported_dtype and not no_bit_reduction

-    def apply(
+    def get_weight_compression_parameters(
         self,
         model: TModel,
         graph: NNCFGraph,
         statistic_points: Optional[StatisticPointsContainer] = None,
         dataset: Optional[Dataset] = None,
-    ) -> TModel:
-        self.set_backend_entity(model)
+    ) -> tuple[list[WeightCompressionParameters], Optional[dict[str, WCTensorStatistic]]]:
+        """
+        Generates a list of weight compression parameters based on the Weight Compression algorithm
+        configuration. Determines the appropriate quantization parameters for each node eligible for
+        weight compression. Also generates a mapping of target node names to the collected statistics
+        based on the provided statistic_points. If statistic_points is None, collects the required
+        compression statistics on the given dataset.

+        :param model: Backend-specific input model.
+        :param graph: NNCFGraph instance.
+        :param statistic_points: Optional pre-collected statistic points.
+        :param dataset: Optional dataset for statistics collection.
+        :return: A tuple consisting of a list of weight compression parameters, based on the Weight
+            Compression algorithm configuration, and a mapping of target node names to the
+            collected statistics.
+        """
         nodes_to_compress = self.get_nodes_to_compress(graph)

         all_weight_params: list[WeightCompressionParameters] = []
@@ -787,12 +800,13 @@ def apply(
         is_last_layer_skipped = False
         n = len(nodes_to_compress)
         ignored_names = self.get_ignored_node_names(graph)
+
         for i, node in enumerate(nodes_to_compress):
             is_target_node = should_consider_scope(node.node_name, ignored_names)
             for weight_name, weight_port_id in self._backend_entity.get_weight_names_and_port_ids(node, graph):
                 is_last_layer = i == n - 1
                 if weight_name in weight_names:
-                    # If the last layer has shared weights then skiped
+                    # If the last layer has shared weights then skip it
                     # to avoid processing the same weight more than once
                     is_last_layer_skipped = is_last_layer
                     continue
@@ -828,6 +842,7 @@ def apply(
                 )
                 if self.is_weight_compression_supported(weight_dtype, mode):
                     wc_config = WeightCompressionConfig(mode=mode)
+
                     weight_params = WeightCompressionParameters(
                         weight_name, node, weight_port_id, weight_dtype, weight_shape, reduction_axes, wc_config
                     )
@@ -884,6 +899,20 @@ def apply(
         # Filter all_weight_params by excluding nodes that should remain in their original floating-point precision
         all_weight_params = list(filter(lambda w_params: w_params.compression_config is not None, all_weight_params))

+        return all_weight_params, statistics
+
+    def apply(
+        self,
+        model: TModel,
+        graph: NNCFGraph,
+        statistic_points: Optional[StatisticPointsContainer] = None,
+        dataset: Optional[Dataset] = None,
+    ) -> TModel:
+        self.set_backend_entity(model)
+
+        # Get processed weight compression parameters ready for compression
+        all_weight_params, statistics = self.get_weight_compression_parameters(model, graph, statistic_points, dataset)
+
         if self._awq:
             model = self.awq_algo.apply(model, graph, all_weight_params, statistics, self._backend_entity)
             # After applying AWQ we need to update statistics since AWQ alters the activations
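
Since the change splits parameter selection out of apply(), the selection step can now be driven on its own. Below is a minimal usage sketch, not part of the diff: `algo` is assumed to be an already-configured WeightCompression instance, and `model`, `graph`, and `dataset` stand for the backend-specific model, its NNCFGraph, and a calibration Dataset, matching the method signatures above.

# Hypothetical driver code; only the method and attribute names shown in
# the diff (set_backend_entity, get_weight_compression_parameters, apply,
# weight_name, compression_config) are taken from it.
algo.set_backend_entity(model)  # apply() does this itself; a direct caller must too

# Step 1: pick compression parameters and gather statistics without
# modifying any weights.
all_weight_params, statistics = algo.get_weight_compression_parameters(
    model, graph, statistic_points=None, dataset=dataset
)

# Step 2: inspect what would be compressed before committing to it.
for wp in all_weight_params:
    print(wp.weight_name, wp.compression_config.mode)

# Step 3: run the full pipeline; the refactored apply() calls the same
# helper internally and then performs the actual compression (AWQ, etc.).
compressed_model = algo.apply(model, graph, dataset=dataset)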