Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions src/pquant/core/activations_quantizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import keras
import torch
import torch.nn as nn
from hgq.quantizer import Quantizer
from keras import ops
from keras.ops import convert_to_tensor, maximum, minimum, tanh
Expand Down Expand Up @@ -136,6 +138,75 @@ def call(self, x):
return x


class QuantizedPooling(nn.Module):
    """Wraps a pooling layer and quantizes its output activations.

    Depending on the config, uses either a high-granularity (HGQ) trainable
    quantizer or a fixed-point quantizer with configurable integer/fractional
    bit-widths. The wrapped pooling layer itself is unchanged; only its
    output is quantized.
    """

    def __init__(self, config, layer):
        """
        Args:
            config: configuration object; reads ``config.quantization_parameters``
                (default bit-widths, symmetric-quantization flag, HGQ options).
            layer: the pooling module (callable) whose output is quantized.
        """
        super().__init__()
        self.f = torch.tensor(config.quantization_parameters.default_fractional_bits)
        self.i = torch.tensor(config.quantization_parameters.default_integer_bits)
        # SAT_SYM saturates symmetrically around zero; SAT is plain saturation.
        self.overflow = "SAT_SYM" if config.quantization_parameters.use_symmetric_quantization else "SAT"
        self.config = config
        self.hgq_heterogeneous = config.quantization_parameters.hgq_heterogeneous
        self.is_pretraining = True
        self.use_high_granularity_quantization = config.quantization_parameters.use_high_granularity_quantization
        self.pooling = layer
        # Regularization strength for the HGQ bit-width loss (see hgq_loss).
        self.hgq_gamma = config.quantization_parameters.hgq_gamma

    def build(self, input_shape):
        """Create the quantizer for the given input shape (called lazily on first use)."""
        if self.use_high_granularity_quantization:
            if self.hgq_heterogeneous:
                # Heterogeneous bit-widths per element; batch axis (0) stays homogeneous.
                self.hgq = Quantizer(
                    k0=1.0,
                    i0=self.i,
                    f0=self.f,
                    round_mode="RND",
                    overflow_mode=self.overflow,
                    q_type="kif",
                    homogeneous_axis=(0,),
                )
            else:
                # Single shared bit-width across the whole tensor.
                self.hgq = Quantizer(
                    k0=1.0,
                    i0=self.i,
                    f0=self.f,
                    round_mode="RND",
                    overflow_mode=self.overflow,
                    q_type="kif",
                    heterogeneous_axis=(),
                )
            self.hgq.build(input_shape)
        else:
            self.quantizer = get_fixed_quantizer(round_mode="RND", overflow_mode=self.overflow)

    def set_activation_bits(self, i, f):
        """Override the integer (i) and fractional (f) activation bit-widths."""
        self.i = torch.tensor(i)
        self.f = torch.tensor(f)

    def post_pre_train_function(self):
        """Mark the end of pretraining; enables the HGQ regularization loss."""
        self.is_pretraining = False

    def hgq_loss(self):
        """Return the HGQ bit-width regularization loss (0 during pretraining)."""
        if self.is_pretraining:
            return 0.0
        # Penalize total allocated bits, scaled by the configured gamma.
        return (torch.sum(self.hgq.quantizer.i) + torch.sum(self.hgq.quantizer.f)) * self.hgq_gamma

    def quantize(self, x):
        """Quantize tensor x, lazily building the quantizer on first call.

        Bug fix: the original condition used ``or``, which was always true
        (only one of ``hgq``/``quantizer`` ever exists), so build() re-created
        the quantizer on every call, discarding learned HGQ parameters.
        """
        if not hasattr(self, "hgq") and not hasattr(self, "quantizer"):
            self.build(x.shape)
        if self.use_high_granularity_quantization:
            x = self.hgq(x)
        else:
            x = self.quantizer(x, k=torch.tensor(1.0), i=self.i, f=self.f, training=True)
        return x

    def forward(self, x):
        """Apply the wrapped pooling layer, then quantize its output."""
        x = self.pooling(x)
        return self.quantize(x)


def hard_sigmoid(x):
"""Computes hard_sigmoid function that saturates between 0 and 1."""
x = 0.5 * x + 0.5
Expand Down
55 changes: 55 additions & 0 deletions src/pquant/core/backend_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from abc import ABC, abstractmethod


class BackendInterface(ABC):
    """Abstract interface a compression backend must implement.

    Concrete backends provide the pruning/quantization lifecycle hooks below;
    ``call_post_round_functions`` implements the shared rewind-vs-post-round
    dispatch on top of them.
    """

    @abstractmethod
    def add_default_layer_quantization_pruning_to_config(self, model, config):
        """Populate config with per-layer default quantization/pruning settings."""

    @abstractmethod
    def iterative_train(self, model, config, train_func, valid_func, **kwargs):
        """Run the iterative train/prune loop using the given train/valid callables."""

    @abstractmethod
    def remove_pruning_from_model(self, model, config):
        """Strip pruning wrappers from the model, making sparsity permanent."""

    @abstractmethod
    def add_compression_layers(self, model, config, input_shape=None):
        """Wrap model layers with compression (pruning/quantization) layers."""

    @abstractmethod
    def post_epoch_functions(self, model, epoch, total_epochs, **kwargs):
        """Hook executed after each training epoch."""

    @abstractmethod
    def post_pretrain_functions(self, model, config):
        """Hook executed once pretraining has finished."""

    @abstractmethod
    def pre_epoch_functions(self, model, epoch, total_epochs):
        """Hook executed before each training epoch."""

    @abstractmethod
    def pre_finetune_functions(self, model):
        """Hook executed before the finetuning phase begins."""

    @abstractmethod
    def save_weights_functions(self, model):
        """Snapshot model weights (e.g. for later rewinding)."""

    @abstractmethod
    def get_layer_keep_ratio(self, model):
        """Return the fraction of weights kept (unpruned) per layer."""

    @abstractmethod
    def get_model_losses(self, model, losses):
        """Collect model-specific auxiliary losses into ``losses``."""

    def call_post_round_functions(self, model, rewind, rounds, r):
        """Dispatch after a pruning round: rewind weights or run post-round hooks.

        Rewinds when ``rewind == "round"`` (every round) or when
        ``rewind == "post-ticket-search"`` and this is the final round;
        otherwise runs the regular post-round functions.
        """
        is_final_round = r == rounds - 1
        should_rewind = rewind == "round" or (rewind == "post-ticket-search" and is_final_round)
        if should_rewind:
            self.rewind_weights_functions(model)
        else:
            self.post_round_functions(model)
Loading