
StandardModel: per-step training loss logging is not configurable #882

@sevmag


Problem

In EasySyntax (src/graphnet/models/easy_model.py), the self.log("train_loss", ...) call inside training_step hard-codes on_step=False, on_epoch=True:

https://github.com/graphnet-team/graphnet/blob/main/src/graphnet/models/easy_model.py#L246-L254

As a result, train_loss is only emitted once per epoch. There is no way to obtain a per-batch training-loss trace through the standard PyTorch Lightning logging path without subclassing or monkey-patching training_step, so users who want to monitor convergence within an epoch (long epochs, debugging instabilities, profiling LR schedules, etc.) currently have no supported option.

The same hard-coding exists for val_loss and test_loss, but those are arguably less important since validation/test typically run only once per epoch.
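To make the lost information concrete, here is a simplified pure-Python model (an illustration only, not Lightning code) of what a logger would receive for one epoch under each flag combination. It assumes Lightning's default mean reduction for on_epoch, weighted by batch size, and ignores the Trainer's log_every_n_steps throttling; emitted_values is a hypothetical helper.

```python
# Hypothetical, simplified model of what a logger receives for one epoch.
# It does not call Lightning; it only mimics the on_step / on_epoch flags
# with a batch-size-weighted mean for the epoch-level value.
from typing import Dict, List


def emitted_values(
    batch_losses: List[float],
    batch_sizes: List[int],
    on_step: bool,
    on_epoch: bool,
) -> Dict[str, List[float]]:
    """Return the step-level and epoch-level values logged for one epoch."""
    out: Dict[str, List[float]] = {"step": [], "epoch": []}
    if on_step:
        # One value per batch (log_every_n_steps throttling is ignored here).
        out["step"] = list(batch_losses)
    if on_epoch:
        # A single batch-size-weighted mean over the whole epoch.
        total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
        out["epoch"] = [total / sum(batch_sizes)]
    return out
```

For three batches with losses 4.0, 2.0, 1.0 and batch sizes 10, 10, 20, the current setting (on_step=False, on_epoch=True) yields only [2.0], so the decreasing within-epoch trend is not visible anywhere.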

Proposed fix

Add a single boolean constructor argument to EasySyntax:

log_train_loss_on_step: bool = False

When True, training_step additionally logs the per-batch loss under a separate key (train_loss_step) using on_step=True, on_epoch=False. The existing epoch-aggregated train_loss key is preserved unchanged, so:

  • default behavior is identical to today (backwards-compatible),
  • EarlyStopping(monitor="val_loss") and the default ModelCheckpoint filename template {train_loss:.2f} keep working,
  • the per-step trace is opt-in, avoiding accidental DDP sync_dist overhead per batch.
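A minimal sketch of the proposed logging branch follows. It is not the actual EasySyntax code: _RecordingModule is a hypothetical stand-in for pytorch_lightning.LightningModule whose log method only records its calls, and SketchModel and _log_train_loss are hypothetical names. The keyword arguments passed to log (batch_size, on_step, on_epoch, sync_dist) are real self.log parameters, but which of them the existing call uses is an assumption.

```python
# Sketch of the proposed opt-in flag. _RecordingModule is a hypothetical
# stand-in for LightningModule: its log() just records calls so the logic
# can be exercised without Lightning installed.
from typing import Any, Dict, List, Tuple


class _RecordingModule:
    def __init__(self) -> None:
        self.logged: List[Tuple[str, float, Dict[str, Any]]] = []

    def log(self, name: str, value: float, **kwargs: Any) -> None:
        self.logged.append((name, value, kwargs))


class SketchModel(_RecordingModule):
    """Hypothetical excerpt of EasySyntax with the proposed argument."""

    def __init__(self, log_train_loss_on_step: bool = False) -> None:
        super().__init__()
        self._log_train_loss_on_step = log_train_loss_on_step

    def _log_train_loss(self, loss: float, batch_size: int) -> None:
        # Existing behaviour, unchanged: one epoch-aggregated value.
        # sync_dist=True is assumed to mirror the current call.
        self.log(
            "train_loss",
            loss,
            batch_size=batch_size,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        if self._log_train_loss_on_step:
            # Opt-in per-batch trace under a separate key with no epoch
            # aggregation, so "train_loss" keeps its current meaning.
            self.log(
                "train_loss_step",
                loss,
                batch_size=batch_size,
                on_step=True,
                on_epoch=False,
                sync_dist=True,
            )
```

One practical caveat: Lightning forwards on_step values to the logger only every log_every_n_steps training steps (Trainer default: 50), so a literally per-batch trace would also need Trainer(log_every_n_steps=1).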
