Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions src/graphnet/models/easy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,25 @@ def __init__(
scheduler_class: Optional[type] = None,
scheduler_kwargs: Optional[Dict] = None,
scheduler_config: Optional[Dict] = None,
log_on_epoch: bool = True,
log_on_step: bool = False,
) -> None:
"""Construct `StandardModel`."""
"""Construct `StandardModel`.

Args:
tasks: Task(s) appended as the head(s) of the model, defining
the prediction target(s) and loss(es).
optimizer_class: Optimizer class used during training.
optimizer_kwargs: Keyword arguments passed to `optimizer_class`.
scheduler_class: Learning-rate scheduler class. If `None`, no
scheduler is used.
scheduler_kwargs: Keyword arguments passed to `scheduler_class`.
scheduler_config: Additional configuration for how the scheduler
is invoked by PyTorch Lightning (e.g. `interval`, `frequency`).
log_on_epoch: If `True`, logs the training loss on epoch end.
log_on_step: If `True`, logs the training loss on step end.
per-batch training loss under `train_loss_step`.
"""
# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)

Expand All @@ -52,6 +69,8 @@ def __init__(
self._scheduler_class = scheduler_class
self._scheduler_kwargs = scheduler_kwargs or dict()
self._scheduler_config = scheduler_config or dict()
self._log_on_step = log_on_step
self._log_on_epoch = log_on_epoch

self.validate_tasks()

Expand Down Expand Up @@ -243,13 +262,14 @@ def training_step(
if isinstance(train_batch, Data):
train_batch = [train_batch]
loss = self.shared_step(train_batch, batch_idx)
batch_size = self._get_batch_size(train_batch)
self.log(
"train_loss",
loss,
batch_size=self._get_batch_size(train_batch),
batch_size=batch_size,
prog_bar=True,
on_epoch=True,
on_step=False,
on_epoch=self._log_on_epoch,
on_step=self._log_on_step,
sync_dist=True,
)

Expand All @@ -269,8 +289,8 @@ def validation_step(
loss,
batch_size=self._get_batch_size(val_batch),
prog_bar=True,
on_epoch=True,
on_step=False,
on_epoch=self._log_on_epoch,
on_step=self._log_on_step,
sync_dist=True,
)
return loss
Expand Down
Loading