To be able to compare aggregators to simple baselines, it would be great to have a `torchjd.scalarization` package.
This package would provide all sorts of `Scalarizer`s, to combine multiple losses into a single scalar loss.
Proposed usage example:
```python
from torchjd.scalarization import Mean
...
scalarizer = Mean()
losses = criterion(output, target)
loss = scalarizer(losses)
loss.backward()
...
```
To implement this, we would need:
- A new public `scalarization` package in `torchjd`
- An abstract base class `Scalarizer`
- To start, a few trivial scalarizers, e.g. `Mean` (aka Equal Weights, or EW), `Sum`, `Constant` (aka Linear Scalarization, or LS), `Random` (aka RLW). A rough sketch of what these could look like is given after this list.
- In future pull requests, we could add some interesting scalarization methods one by one. I'm thinking of STCH, maybe FAMO, but I'll make a dedicated issue to track those.
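To make this a bit more concrete, here is a minimal sketch of what the base class and the trivial scalarizers could look like. This is a plain-Python version (whether to inherit from `nn.Module` instead is one of the questions below), and the exact signatures are just a suggestion:

```python
from abc import ABC, abstractmethod

import torch
from torch import Tensor


class Scalarizer(ABC):
    """Combines a tensor of losses into a single scalar loss."""

    @abstractmethod
    def __call__(self, losses: Tensor) -> Tensor:
        """Return a 0-dim tensor combining the given losses."""


class Mean(Scalarizer):
    """Equal Weights (EW): averages all losses."""

    def __call__(self, losses: Tensor) -> Tensor:
        return losses.mean()


class Sum(Scalarizer):
    """Sums all losses."""

    def __call__(self, losses: Tensor) -> Tensor:
        return losses.sum()


class Constant(Scalarizer):
    """Linear Scalarization (LS): weighted sum with fixed weights."""

    def __init__(self, weights: Tensor):
        self.weights = weights

    def __call__(self, losses: Tensor) -> Tensor:
        return (self.weights * losses).sum()


class Random(Scalarizer):
    """Random Loss Weighting (RLW): re-samples the weights at each call."""

    def __call__(self, losses: Tensor) -> Tensor:
        weights = torch.softmax(torch.randn_like(losses), dim=-1)
        return (weights * losses).sum()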
A few questions:
- Is it fine to have some names in common between scalarizers and aggregators (e.g. `torchjd.scalarization.Mean` and `torchjd.aggregation.Mean`)? Maybe we want to name them `MeanScalarizer`, `SumScalarizer`, `ConstantScalarizer`, `STCHScalarizer`, etc. But at the same time, the package name should be what indicates this, and people can always import with `from torchjd.scalarization import Mean as MeanScalarizer`. So I'm not sure.
- Should `Scalarizer` inherit from `nn.Module` (like `Aggregator` and `Weighting` do)? It makes it much easier to add hooks, but I'm not sure hooks will even be needed for scalarizers, and `nn.Module` may lead to some typing issues. Another advantage of `nn.Module` is that it can have trainable parameters, which I think would be necessary for GradNorm (see the sketch after these questions).
- There is a slightly more general concept than scalarization, which is to group some losses. For example, you could start with 128 losses, group them into 32 groups of 4 losses, and average each group, ending up with 32 losses. We could thus have a `Combiner` that takes a loss tensor and returns another loss tensor, not necessarily scalar. In my opinion, we should develop `Scalarizer` without thinking too much about this, and maybe one day `Scalarizer` will become a special case of `Combiner`.
- Should the input to the `__call__` method of a `Scalarizer` necessarily be a vector (tensor of ndim=1), or could it be a tensor of any shape (scalar, vector, matrix, etc.)? I think it would be better if it could be any shape, to be consistent with the interface of `autojac.backward`, which works on tensors of any shape. Also, some methods may make use of the shape of the loss tensor.