Problem
In EasySyntax (src/graphnet/models/easy_model.py), the self.log("train_loss", ...) call inside training_step hard-codes on_step=False, on_epoch=True:
https://github.com/graphnet-team/graphnet/blob/main/src/graphnet/models/easy_model.py#L246-L254
As a result, train_loss is emitted only once per epoch. The standard PyTorch Lightning logging path cannot produce a per-batch training-loss trace unless the user subclasses EasySyntax or monkey-patches training_step. Users who want to monitor convergence within an epoch (long epochs, debugging instabilities, profiling LR schedules, etc.) currently have no supported way to do so.
The same hard-coding exists for val_loss and test_loss, but those are arguably less important since validation/test typically run only once per epoch.
Proposed fix
Add a single boolean constructor argument to EasySyntax:
log_train_loss_on_step: bool = False
When True, training_step additionally logs the per-batch loss under a separate key (train_loss_step) using on_step=True, on_epoch=False. The existing epoch-aggregated train_loss key is preserved unchanged, so:
- default behavior is identical to today (backwards-compatible),
- EarlyStopping(monitor="val_loss") and the default ModelCheckpoint filename template {train_loss:.2f} keep working,
- the per-step trace is opt-in, avoiding accidental DDP sync_dist overhead per batch.
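A minimal sketch of the proposed control flow, for discussion only. To keep it runnable without PyTorch Lightning or graphnet installed, it uses a hypothetical LogRecorder class as a stand-in for LightningModule (it only records the arguments passed to self.log), a hypothetical EasySyntaxSketch class in place of the real EasySyntax, and a placeholder shared_step that returns the batch mean. The real training_step passes further arguments to self.log (not reproduced here), and whether the per-step key should set sync_dist is left open.

```python
from typing import Any, Dict, List, Tuple


class LogRecorder:
    """Hypothetical stand-in for LightningModule: records self.log calls."""

    def __init__(self) -> None:
        self.logged: List[Tuple[str, float, Dict[str, Any]]] = []

    def log(self, name: str, value: float, **kwargs: Any) -> None:
        # Lightning's self.log would dispatch to the configured loggers;
        # here we only remember what was requested.
        self.logged.append((name, value, kwargs))


class EasySyntaxSketch(LogRecorder):
    """Illustrative only; not the actual graphnet EasySyntax class."""

    def __init__(self, log_train_loss_on_step: bool = False) -> None:
        super().__init__()
        # Proposed opt-in flag; defaults to today's behavior.
        self._log_train_loss_on_step = log_train_loss_on_step

    def shared_step(self, batch: List[float], batch_idx: int) -> float:
        # Placeholder loss (assumption): mean of the batch values.
        return sum(batch) / len(batch)

    def training_step(self, batch: List[float], batch_idx: int) -> float:
        loss = self.shared_step(batch, batch_idx)
        # Existing epoch-aggregated key, preserved unchanged.
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        if self._log_train_loss_on_step:
            # New per-batch key, emitted only when the flag is set.
            self.log("train_loss_step", loss, on_step=True, on_epoch=False)
        return loss


default_model = EasySyntaxSketch()
default_model.training_step([1.0, 3.0], batch_idx=0)

opt_in_model = EasySyntaxSketch(log_train_loss_on_step=True)
opt_in_model.training_step([1.0, 3.0], batch_idx=0)
```

With the flag left at its default, only train_loss is logged; with log_train_loss_on_step=True, train_loss_step is logged alongside it, so callbacks that monitor train_loss or val_loss see the same keys as before.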