-# Copyright 2024-2025 NXP
+# Copyright 2024-2026 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -145,34 +145,11 @@ def get_model(
         cifar_net.load_state_dict(torch.load(state_dict_file, weights_only=True))
 
     if train:
-        # Train the model.
-        criterion = nn.CrossEntropyLoss()
-        optimizer = optim.SGD(cifar_net.parameters(), lr=0.0005, momentum=0.6)
-        train_loader = get_train_loader(batch_size)
-
-        for epoch in range(num_epochs):
-            running_loss = 0.0
-            for i, data in enumerate(train_loader, 0):
-                # get the inputs; data is a list of [inputs, labels]
-                inputs, labels = data
-
-                # zero the parameter gradients
-                optimizer.zero_grad()
-
-                # forward + backward + optimize
-                outputs = cifar_net(inputs)
-                loss = criterion(outputs, labels)
-                loss.backward()
-                optimizer.step()
-
-                # print statistics
-                running_loss += loss.item()
-                if i % 2000 == 1999:  # print every 2000 mini-batches
-                    print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
-                    running_loss = 0.0
-
-        logger.info("Finished training.")
-    if state_dict_file is not None and train:
+        cifar_net = train_cifarnet_model(
+            cifar_net=cifar_net, batch_size=batch_size, num_epochs=num_epochs
+        )
+
+    if state_dict_file is not None:
         logger.info(f"Saving the trained weights in `{state_dict_file}`.")
         torch.save(cifar_net.state_dict(), state_dict_file)
 
@@ -189,6 +166,40 @@ def get_cifarnet_calibration_data(num_images: int = 100) -> tuple[torch.Tensor]:
     return (tensor,)
 
 
+def train_cifarnet_model(
+    cifar_net: nn.Module | torch.fx.GraphModule,
+    batch_size: int = 1,
+    num_epochs: int = 1,
+) -> nn.Module:
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(cifar_net.parameters(), lr=0.0001, momentum=0.6)
+    train_loader = get_train_loader(batch_size)
+
+    for epoch in range(num_epochs):
+        running_loss = 0.0
+        for i, data in enumerate(train_loader, 0):
+            # get the inputs; data is a list of [inputs, labels]
+            inputs, labels = data
+
+            # zero the parameter gradients
+            optimizer.zero_grad()
+
+            # forward + backward + optimize
+            outputs = cifar_net(inputs)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            # print statistics
+            running_loss += loss.item()
+            if i % 2000 == 1999:  # print every 2000 mini-batches
+                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
+                running_loss = 0.0
+
+    logger.info("Finished training.")
+    return cifar_net
+
+
 def test_cifarnet_model(cifar_net: nn.Module, batch_size: int = 1) -> float:
     """Test the CifarNet model on the CifarNet10 testing dataset and return the accuracy.
 
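
With the training loop factored out, train_cifarnet_model can be called on its own, independent of get_model's load/save logic. A minimal usage sketch under assumptions: the CifarNet constructor name and the output path below do not appear in this diff and are hypothetical; only the helper's signature is taken from the commit.

    import torch

    # CifarNet is an assumed constructor name; only the instance name
    # `cifar_net` is visible in this diff.
    net = CifarNet()
    # train_cifarnet_model is the helper extracted by this commit; it trains
    # in place and returns the same module.
    net = train_cifarnet_model(net, batch_size=4, num_epochs=2)
    torch.save(net.state_dict(), "cifar_net_trained.pth")  # hypothetical path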