@@ -18,7 +18,6 @@ def __init__(self, config, layer_type, *args, **kwargs):
1818 self .activations = None
1919 self .total = 0.0
2020 self .is_pretraining = True
21- self .done = False
2221 self .threshold = ops .convert_to_tensor (config .pruning_parameters .threshold )
2322 self .t_start_collecting_batch = self .config .pruning_parameters .t_start_collecting_batch
2423
@@ -37,7 +36,7 @@ def collect_output(self, output, training):
3736 linear/convolution layer are over 0. Every t_delta steps, uses these values to update
3837 the mask to prune those channels and neurons that are active less than a given threshold
3938 """
40- if self . done or not training or self .is_pretraining :
39+ if not training or self .is_pretraining :
4140 # Don't collect during validation or pretraining
4241 return
4342 if self .activations is None :
@@ -54,6 +53,7 @@ def collect_output(self, output, training):
5453 pct_active = self .activations / self .total
5554 self .t = 0
5655 self .total = 0
56+ self .batches_collected = 0
5757 if self .layer_type == "linear" :
5858 self .mask = ops .expand_dims (ops .cast ((pct_active > self .threshold ), pct_active .dtype ), 1 )
5959 else :
@@ -65,7 +65,6 @@ def collect_output(self, output, training):
6565 else :
6666 self .mask = ops .reshape (pct_active_above_threshold , list (pct_active_above_threshold .shape ) + [1 , 1 , 1 ])
6767 self .activations *= 0.0
68- self .done = True
6968
7069 def call (self , weight ): # Mask is only updated every t_delta step, using collect_output
7170 if self .is_pretraining :
0 commit comments