Commit c8a87e4

NXP backend: Add QAT training to AOT examples

1 parent: 8503477
File tree

2 files changed (+73, -41 lines)

examples/nxp/aot_neutron_compile.py (33 additions, 12 deletions)
```diff
@@ -1,4 +1,4 @@
-# Copyright 2024-2025 NXP
+# Copyright 2024-2026 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -39,9 +39,18 @@
     to_edge_transform_and_lower,
 )
 from executorch.extension.export_util import save_pte_program
+from torch.ao.quantization import (
+    move_exported_model_to_eval,
+    move_exported_model_to_train,
+)
 from torch.export import export
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
 
-from .experimental.cifar_net.cifar_net import CifarNet, test_cifarnet_model
+from .experimental.cifar_net.cifar_net import (
+    CifarNet,
+    test_cifarnet_model,
+    train_cifarnet_model,
+)
 from .models.mobilenet_v2 import MobilenetV2
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -154,7 +163,7 @@ def get_model_and_inputs_from_name(model_name: str):
         action="store_true",
         required=False,
         default=False,
-        help="Use QAT mode for quantization (does not include QAT training)",
+        help="Use QAT mode for quantization (performs two QAT training epochs)",
     )
     parser.add_argument(
         "-s",
@@ -220,15 +229,27 @@ def get_model_and_inputs_from_name(model_name: str):
 
     # 3. Quantize if required
     if args.quantize:
-        if calibration_inputs is None:
-            logging.warning(
-                "No calibration inputs available, using the example inputs instead"
-            )
-            calibration_inputs = example_inputs
-        quantizer = NeutronQuantizer(neutron_target_spec, args.use_qat)
-        module = calibrate_and_quantize(
-            module, calibration_inputs, quantizer, is_qat=args.use_qat
-        )
+        quantizer = NeutronQuantizer(neutron_target_spec, is_qat=args.use_qat)
+        if args.use_qat:
+            match args.model_name:
+                case "cifar10":
+                    print("Starting two epochs of QAT training with CifarNet model...")
+                    module = prepare_qat_pt2e(module, quantizer)
+                    module = move_exported_model_to_train(module)
+                    module = train_cifarnet_model(module, num_epochs=2)
+                    module = move_exported_model_to_eval(module)
+                    module = convert_pt2e(module)
+                case _:
+                    raise ValueError(
+                        f"QAT training is not supported for model '{args.model_name}'"
+                    )
+        else:
+            if calibration_inputs is None:
+                logging.warning(
+                    "No calibration inputs available, using the example inputs instead"
+                )
+                calibration_inputs = example_inputs
+            module = calibrate_and_quantize(module, calibration_inputs, quantizer)
 
     if args.so_library is not None:
         logging.debug(f"Loading libraries: {args.so_library}")
```
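
The QAT branch above follows the standard PT2E quantization-aware-training sequence: export the model, insert fake-quantize observers, train with them active, then fold them into quantize/dequantize ops. Below is a minimal sketch of that sequence using the same imports this diff adds; the `qat_quantize` wrapper and its `train_fn` callback are hypothetical stand-ins (the commit inlines these steps and trains with `train_cifarnet_model`), and any PT2E quantizer such as the `NeutronQuantizer` used here can be passed in.

```python
# Minimal sketch of the PT2E QAT flow shown in the diff above. The
# qat_quantize wrapper and train_fn callback are hypothetical; the commit
# itself inlines these steps and uses train_cifarnet_model.
import torch
from torch.ao.quantization import (
    move_exported_model_to_eval,
    move_exported_model_to_train,
)
from torch.export import export
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_qat_pt2e


def qat_quantize(model: torch.nn.Module, example_inputs: tuple, quantizer, train_fn):
    # Export the eager model to an FX graph, the PT2E entry point.
    module = export(model, example_inputs).module()
    # Insert fake-quantize observers according to the quantizer's annotations.
    module = prepare_qat_pt2e(module, quantizer)
    # Exported graphs ignore the usual model.train()/model.eval() toggles,
    # so the explicit helpers switch batch norm / dropout behavior instead.
    module = move_exported_model_to_train(module)
    module = train_fn(module)
    module = move_exported_model_to_eval(module)
    # Fold the fake-quantize observers into real quantize/dequantize ops.
    return convert_pt2e(module)
```

The `move_exported_model_to_*` calls matter because an exported `GraphModule` does not respond to `.train()`/`.eval()`; without them, batch norm statistics and dropout would keep inference behavior during the QAT epochs.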

examples/nxp/experimental/cifar_net/cifar_net.py (40 additions, 29 deletions)
```diff
@@ -1,4 +1,4 @@
-# Copyright 2024-2025 NXP
+# Copyright 2024-2026 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -145,34 +145,11 @@ def get_model(
     cifar_net.load_state_dict(torch.load(state_dict_file, weights_only=True))
 
     if train:
-        # Train the model.
-        criterion = nn.CrossEntropyLoss()
-        optimizer = optim.SGD(cifar_net.parameters(), lr=0.0005, momentum=0.6)
-        train_loader = get_train_loader(batch_size)
-
-        for epoch in range(num_epochs):
-            running_loss = 0.0
-            for i, data in enumerate(train_loader, 0):
-                # get the inputs; data is a list of [inputs, labels]
-                inputs, labels = data
-
-                # zero the parameter gradients
-                optimizer.zero_grad()
-
-                # forward + backward + optimize
-                outputs = cifar_net(inputs)
-                loss = criterion(outputs, labels)
-                loss.backward()
-                optimizer.step()
-
-                # print statistics
-                running_loss += loss.item()
-                if i % 2000 == 1999:  # print every 2000 mini-batches
-                    print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
-                    running_loss = 0.0
-
-        logger.info("Finished training.")
-    if state_dict_file is not None and train:
+        cifar_net = train_cifarnet_model(
+            cifar_net=cifar_net, batch_size=batch_size, num_epochs=num_epochs
+        )
+
+    if state_dict_file is not None:
         logger.info(f"Saving the trained weights in `{state_dict_file}`.")
         torch.save(cifar_net.state_dict(), state_dict_file)
 
@@ -189,6 +166,40 @@ def get_cifarnet_calibration_data(num_images: int = 100) -> tuple[torch.Tensor]:
     return (tensor,)
 
 
+def train_cifarnet_model(
+    cifar_net: nn.Module | torch.fx.GraphModule,
+    batch_size: int = 1,
+    num_epochs: int = 1,
+) -> nn.Module:
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(cifar_net.parameters(), lr=0.0001, momentum=0.6)
+    train_loader = get_train_loader(batch_size)
+
+    for epoch in range(num_epochs):
+        running_loss = 0.0
+        for i, data in enumerate(train_loader, 0):
+            # get the inputs; data is a list of [inputs, labels]
+            inputs, labels = data
+
+            # zero the parameter gradients
+            optimizer.zero_grad()
+
+            # forward + backward + optimize
+            outputs = cifar_net(inputs)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            # print statistics
+            running_loss += loss.item()
+            if i % 2000 == 1999:  # print every 2000 mini-batches
+                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
+                running_loss = 0.0
+
+    logger.info("Finished training.")
+    return cifar_net
+
+
 def test_cifarnet_model(cifar_net: nn.Module, batch_size: int = 1) -> float:
     """Test the CifarNet model on the CifarNet10 testing dataset and return the accuracy.
```

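With the training loop factored out, `train_cifarnet_model` can also be exercised on its own. A hypothetical usage sketch follows; the no-argument `CifarNet()` constructor and the absolute import path are assumptions inferred from this example's file layout, not part of the commit.

```python
# Hypothetical standalone use of the refactored training helper. The import
# path and the no-argument CifarNet() constructor are assumed from the file
# layout above, not taken from the commit itself.
from examples.nxp.experimental.cifar_net.cifar_net import (
    CifarNet,
    test_cifarnet_model,
    train_cifarnet_model,
)

net = CifarNet()
# Two quick epochs on the CIFAR-10 training set, mirroring the AOT example.
net = train_cifarnet_model(net, batch_size=4, num_epochs=2)
# Evaluate on the test split; test_cifarnet_model returns the accuracy.
print(f"Accuracy after training: {test_cifarnet_model(net):.3f}")
```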