Skip to content

Commit 89d7eaf

Browse files
committed
make activation pruning a continuous method instead of a single one-shot
1 parent d98d53c commit 89d7eaf

1 file changed

Lines changed: 2 additions & 3 deletions

File tree

src/pquant/pruning_methods/activation_pruning.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ def __init__(self, config, layer_type, *args, **kwargs):
1818
self.activations = None
1919
self.total = 0.0
2020
self.is_pretraining = True
21-
self.done = False
2221
self.threshold = ops.convert_to_tensor(config.pruning_parameters.threshold)
2322
self.t_start_collecting_batch = self.config.pruning_parameters.t_start_collecting_batch
2423

@@ -37,7 +36,7 @@ def collect_output(self, output, training):
3736
linear/convolution layer are over 0. Every t_delta steps, uses these values to update
3837
the mask to prune those channels and neurons that are active less than a given threshold
3938
"""
40-
if self.done or not training or self.is_pretraining:
39+
if not training or self.is_pretraining:
4140
# Don't collect during validation
4241
return
4342
if self.activations is None:
@@ -54,6 +53,7 @@ def collect_output(self, output, training):
5453
pct_active = self.activations / self.total
5554
self.t = 0
5655
self.total = 0
56+
self.batches_collected = 0
5757
if self.layer_type == "linear":
5858
self.mask = ops.expand_dims(ops.cast((pct_active > self.threshold), pct_active.dtype), 1)
5959
else:
@@ -65,7 +65,6 @@ def collect_output(self, output, training):
6565
else:
6666
self.mask = ops.reshape(pct_active_above_threshold, list(pct_active_above_threshold.shape) + [1, 1, 1])
6767
self.activations *= 0.0
68-
self.done = True
6968

7069
def call(self, weight): # Mask is only updated every t_delta step, using collect_output
7170
if self.is_pretraining:

0 commit comments

Comments
 (0)