2 changes: 1 addition & 1 deletion cebra/distributions/index.py
@@ -215,7 +215,7 @@ def search(self, continuous, discrete=None):
Samples from the continuous index
discrete:
Optionally matching samples from the discrete index,
used to pre-select matching indices.
used to preselect matching indices.
"""
if continuous.shape[1] != self.continuous.shape[1]:
raise ValueError(f"Shape of continuous index does not match along "
186 changes: 34 additions & 152 deletions cebra/models/model.py
@@ -260,59 +260,61 @@ def num_trainable_parameters(self) -> int:
param.numel() for param in self.parameters() if param.requires_grad)


@register("offset10-model")
class Offset10Model(_OffsetModel, ConvolutionalModelMixin):
"""CEBRA model with a 10 sample receptive field."""
@parametrize("offset{n_offset}-model",
n_offset=(5, 10, 15, 18, 20, 31, 36, 40, 50))
class OffsetNModel(_OffsetModel, ConvolutionalModelMixin):
"""CEBRA model with a `n_offset` sample receptive field.

def __init__(self, num_neurons, num_units, num_output, normalize=True):
n_offset: The size of the receptive field.
"""

def __init__(self,
num_neurons,
num_units,
num_output,
n_offset,
normalize=True):
if num_units < 1:
raise ValueError(
f"Hidden dimension needs to be at least 1, but got {num_units}."
)

self.n_offset = n_offset

def _compute_num_layers():
"""Compute the number of layers to add on top of the first and last conv layers."""
return (self.n_offset - 4) // 2 + self.n_offset % 2

last_layer_kernel = 3 if (self.n_offset % 2) == 0 else 2
super().__init__(
nn.Conv1d(num_neurons, num_units, 2),
nn.GELU(),
*self._make_layers(num_units, num_layers=3),
nn.Conv1d(num_units, num_output, 3),
*self._make_layers(num_units, num_layers=_compute_num_layers()),
nn.Conv1d(num_units, num_output, last_layer_kernel),
num_input=num_neurons,
num_output=num_output,
normalize=normalize,
)

def get_offset(self) -> cebra.data.datatypes.Offset:
"""See :py:meth:`~.Model.get_offset`"""
return cebra.data.Offset(5, 5)
return cebra.data.Offset(self.n_offset // 2,
self.n_offset // 2 + self.n_offset % 2)


@register("offset10-model-mse")
class Offset10ModelMSE(Offset10Model):
class Offset10ModelMSE(OffsetNModel):
"""Symmetric model with 10 sample receptive field, without normalization.

Suitable for use with InfoNCE metrics for Euclidean space.
"""

def __init__(self, num_neurons, num_units, num_output, normalize=False):
super().__init__(num_neurons, num_units, num_output, normalize)


@register("offset5-model")
class Offset5Model(_OffsetModel, ConvolutionalModelMixin):
"""CEBRA model with a 5 sample receptive field and output normalization."""

def __init__(self, num_neurons, num_units, num_output, normalize=True):
super().__init__(
nn.Conv1d(num_neurons, num_units, 2),
nn.GELU(),
cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()),
nn.Conv1d(num_units, num_output, 2),
num_input=num_neurons,
num_output=num_output,
normalize=normalize,
)

def get_offset(self) -> cebra.data.datatypes.Offset:
"""See :py:meth:`~.Model.get_offset`"""
return cebra.data.Offset(2, 3)
super().__init__(num_neurons,
num_units,
num_output,
n_offset=10,
normalize=normalize)


@register("offset1-model-mse")
@@ -666,30 +668,6 @@ def get_offset(self) -> cebra.data.datatypes.Offset:
return cebra.data.Offset(0, 1)


@register("offset36-model")
class Offset36(_OffsetModel, ConvolutionalModelMixin):
"""CEBRA model with a 10 sample receptive field."""

def __init__(self, num_neurons, num_units, num_output, normalize=True):
if num_units < 1:
raise ValueError(
f"Hidden dimension needs to be at least 1, but got {num_units}."
)
super().__init__(
nn.Conv1d(num_neurons, num_units, 2),
nn.GELU(),
*self._make_layers(num_units, num_layers=16),
nn.Conv1d(num_units, num_output, 3),
num_input=num_neurons,
num_output=num_output,
normalize=normalize,
)

def get_offset(self) -> cebra.data.datatypes.Offset:
"""See `:py:meth:Model.get_offset`"""
return cebra.data.Offset(18, 18)


@_register_conditionally("offset36-model-dropout")
class Offset36Dropout(_OffsetModel, ConvolutionalModelMixin):
"""CEBRA model with a 10 sample receptive field.
@@ -721,7 +699,7 @@ def __init__(self,
)

def get_offset(self) -> cebra.data.datatypes.Offset:
"""See `:py:meth:Model.get_offset`"""
"""See :py:meth:`~.Model.get_offset`"""
return cebra.data.Offset(18, 18)


@@ -763,108 +741,12 @@ def __init__(self,
)

def get_offset(self) -> cebra.data.datatypes.Offset:
"""See `:py:meth:Model.get_offset`"""
"""See :py:meth:`~.Model.get_offset`"""
return cebra.data.Offset(18, 18)


@register("offset40-model")
class Offset40(_OffsetModel, ConvolutionalModelMixin):
"""CEBRA model with a 40 samples receptive field."""

def __init__(self, num_neurons, num_units, num_output, normalize=True):
if num_units < 1:
raise ValueError(
f"Hidden dimension needs to be at least 1, but got {num_units}."
)
super().__init__(
nn.Conv1d(num_neurons, num_units, 2),
nn.GELU(),
*self._make_layers(num_units, 18),
nn.Conv1d(num_units, num_output, 3),
num_input=num_neurons,
num_output=num_output,
normalize=normalize,
)

def get_offset(self) -> cebra.data.datatypes.Offset:
"""See `:py:meth:Model.get_offset`"""
return cebra.data.Offset(20, 20)


@register("offset50-model")
class Offset50(_OffsetModel, ConvolutionalModelMixin):
"""CEBRA model with a sample receptive field."""

def __init__(self, num_neurons, num_units, num_output, normalize=True):
if num_units < 1:
raise ValueError(
f"Hidden dimension needs to be at least 1, but got {num_units}."
)
super().__init__(
nn.Conv1d(num_neurons, num_units, 2),
nn.GELU(),
*self._make_layers(num_units, 23),
nn.Conv1d(num_units, num_output, 3),
num_input=num_neurons,
num_output=num_output,
normalize=normalize,
)

def get_offset(self) -> cebra.data.datatypes.Offset:
"""See `:py:meth:Model.get_offset`"""
return cebra.data.Offset(25, 25)


@register("offset15-model")
class Offset15Model(_OffsetModel, ConvolutionalModelMixin):
"""CEBRA model with a 15 sample receptive field."""

def __init__(self, num_neurons, num_units, num_output, normalize=True):
if num_units < 1:
raise ValueError(
f"Hidden dimension needs to be at least 1, but got {num_units}."
)
super().__init__(
nn.Conv1d(num_neurons, num_units, 2),
nn.GELU(),
*self._make_layers(num_units, num_layers=6),
nn.Conv1d(num_units, num_output, 2),
num_input=num_neurons,
num_output=num_output,
normalize=normalize,
)

def get_offset(self) -> cebra.data.datatypes.Offset:
"""See `:py:meth:Model.get_offset`"""
return cebra.data.Offset(7, 8)


@register("offset20-model")
class Offset20Model(_OffsetModel, ConvolutionalModelMixin):
"""CEBRA model with a 15 sample receptive field."""

def __init__(self, num_neurons, num_units, num_output, normalize=True):
if num_units < 1:
raise ValueError(
f"Hidden dimension needs to be at least 1, but got {num_units}."
)
super().__init__(
nn.Conv1d(num_neurons, num_units, 2),
nn.GELU(),
*self._make_layers(num_units, num_layers=8),
nn.Conv1d(num_units, num_output, 3),
num_input=num_neurons,
num_output=num_output,
normalize=normalize,
)

def get_offset(self) -> cebra.data.datatypes.Offset:
"""See `:py:meth:Model.get_offset`"""
return cebra.data.Offset(10, 10)


@register("offset10-model-mse-tanh")
class Offset10Model(_OffsetModel, ConvolutionalModelMixin):
class Offset10ModelMSETanh(_OffsetModel, ConvolutionalModelMixin):
"""CEBRA model with a 10 sample receptive field."""

def __init__(self, num_neurons, num_units, num_output, normalize=False):
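The parametrized `OffsetNModel` above replaces the individually registered offset models (offset5, offset10, offset15, offset20, offset36, offset40, offset50) with a single architecture whose depth and final kernel size are derived from `n_offset`. The sketch below is a minimal, standalone check of that arithmetic, not CEBRA code; it assumes, as in the legacy classes deleted in this diff, that the first convolution has kernel size 2, every layer built by `_make_layers` is a kernel-3 convolution, the final convolution uses the computed kernel, and all convolutions have stride 1, so the receptive field is 1 plus the sum of (kernel - 1) over the stack.

    # Standalone sketch: check that OffsetNModel's layer count and final kernel
    # reproduce an n_offset-sample receptive field, and that get_offset() splits
    # it into matching left/right padding.

    def receptive_field(n_offset: int) -> int:
        num_layers = (n_offset - 4) // 2 + n_offset % 2    # _compute_num_layers()
        last_kernel = 3 if n_offset % 2 == 0 else 2
        kernels = [2] + [3] * num_layers + [last_kernel]
        return 1 + sum(k - 1 for k in kernels)

    def offset(n_offset: int) -> tuple:
        # Mirrors the new get_offset(): the extra sample goes to the right for odd n.
        return n_offset // 2, n_offset // 2 + n_offset % 2

    # Legacy values removed in this diff, for comparison: offset5 -> (2, 3),
    # offset10 -> (5, 5), offset15 -> (7, 8), offset36 -> (18, 18),
    # offset40 -> (20, 20), offset50 -> (25, 25).
    for n in (5, 10, 15, 18, 20, 31, 36, 40, 50):
        left, right = offset(n)
        assert receptive_field(n) == n == left + right
        print(f"offset{n}-model: receptive field {n}, offset ({left}, {right})")

For the even sizes this reproduces the deleted classes exactly (for example, offset36 keeps 16 intermediate layers and a final kernel of 3); for the odd sizes, the asymmetric Offset(n // 2, n // 2 + n % 2) matches the old Offset5Model and Offset15Model behaviour.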
8 changes: 4 additions & 4 deletions cebra/registry.py
@@ -192,7 +192,7 @@ def _register(cls):

def parametrize(pattern: str,
*,
kwargs: List[Dict[str, Any]] = [],
kwargs: List[Dict[str, Any]] = None,
**all_kwargs):
"""Decorator to add parametrizations of a new class to the registry.

@@ -221,8 +221,8 @@ def _create_class(cls, **default_kwargs):
class _ParametrizedClass(cls):

def __init__(self, *args, **kwargs):
default_kwargs.update(kwargs)
super().__init__(*args, **default_kwargs)
merged_kwargs = {**default_kwargs, **kwargs}
super().__init__(*args, **merged_kwargs)

# Make the class pickleable by copying metadata from the base class
# and registering it in the module namespace
@@ -239,7 +239,7 @@ def __init__(self, *args, **kwargs):
setattr(parent_module, unique_name, _ParametrizedClass)

def _parametrize(cls):
for _default_kwargs in kwargs:
for _default_kwargs in (kwargs or []):
_create_class(cls, **_default_kwargs)
if len(all_kwargs) > 0:
for _default_kwargs in _product_dict(all_kwargs):
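The registry.py hunk makes two related fixes: the mutable default `kwargs=[]` becomes `None` (resolved to an empty list at call time), and the per-instance overrides are merged into a fresh dict instead of being written into the shared `default_kwargs` via `update()`. The second point is the important one: every instance of a parametrized class closes over the same `default_kwargs`, so an in-place update lets one constructor call leak its overrides into all later ones. A minimal sketch of the difference, independent of the CEBRA registry:

    # Toy illustration (not the CEBRA registry): mutating the shared defaults vs.
    # building a fresh merged dict.

    def make_factory_buggy(default_kwargs):
        def create(**kwargs):
            default_kwargs.update(kwargs)          # mutates the shared dict
            return dict(default_kwargs)
        return create

    def make_factory_fixed(default_kwargs):
        def create(**kwargs):
            return {**default_kwargs, **kwargs}    # defaults stay untouched
        return create

    buggy = make_factory_buggy({"n_offset": 10})
    print(buggy(n_offset=36))   # {'n_offset': 36}
    print(buggy())              # {'n_offset': 36}  <- default silently overwritten

    fixed = make_factory_fixed({"n_offset": 10})
    print(fixed(n_offset=36))   # {'n_offset': 36}
    print(fixed())              # {'n_offset': 10}  <- default preserved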
2 changes: 1 addition & 1 deletion docs/source/usage.rst
@@ -175,7 +175,7 @@ We provide a set of pre-defined models. You can access (and search) a list of av

.. testoutput::

['offset10-model', 'offset10-model-mse', 'offset5-model', 'offset1-model-mse']
['offset5-model', 'offset10-model', 'offset15-model', 'offset18-model']

Then, you can choose the one that fits best with your needs and provide it to the CEBRA model as the :py:attr:`~.CEBRA.model_architecture` parameter.

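The updated ``testoutput`` reflects that the offset models are now generated by the ``parametrize`` decorator shown above. A short sketch of the documented workflow follows; it assumes the ``cebra.models.get_options`` helper and the ``model_architecture`` argument behave as described in this section of the docs, and the exact list of names returned depends on the installed version.

    # Sketch of the usage described in usage.rst (names shown are illustrative).
    import cebra
    import cebra.models

    # List registered architectures, optionally filtered by a glob pattern.
    print(cebra.models.get_options("offset*")[:4])
    # e.g. ['offset5-model', 'offset10-model', 'offset15-model', 'offset18-model']

    # Pass the chosen name to the estimator.
    model = cebra.CEBRA(model_architecture="offset10-model", max_iterations=100)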
32 changes: 32 additions & 0 deletions tests/_reference_implementations/__init__.py
@@ -0,0 +1,32 @@
"""Reference implementations for testing consistency and backward compatibility.

This package contains reference implementations of previously deprecated or
parametrized model components, used for testing consistency and backward compatibility
in the test suite.
"""

from .deprecated_transforms import (
cebra_transform_deprecated,
multiobjective_transform_deprecated,
)
from .reference_offset_models import (
Offset5ModelReference,
Offset10ModelReference,
Offset15ModelReference,
Offset20ModelReference,
Offset36Reference,
Offset40Reference,
Offset50Reference,
)

__all__ = [
"cebra_transform_deprecated",
"multiobjective_transform_deprecated",
"Offset5ModelReference",
"Offset10ModelReference",
"Offset15ModelReference",
"Offset20ModelReference",
"Offset36Reference",
"Offset40Reference",
"Offset50Reference",
]
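A hypothetical example of how these exports could be used in a consistency test, assuming the reference classes mirror the legacy `_OffsetModel` constructors removed from `cebra/models/model.py` (same `(num_neurons, num_units, num_output)` signature and `get_offset()` method, both ordinary torch modules), and that `cebra.models.init` resolves registered names as elsewhere in the code base. This is a pytest-style sketch, not a test from this PR.

    # Hypothetical consistency check (not part of this diff).
    import cebra.models
    from tests._reference_implementations import Offset36Reference

    def test_offset36_matches_reference():
        new = cebra.models.init("offset36-model",
                                num_neurons=10,
                                num_units=16,
                                num_output=4)
        ref = Offset36Reference(num_neurons=10, num_units=16, num_output=4)

        # Same receptive field / padding behaviour ...
        assert new.get_offset().left == ref.get_offset().left
        assert new.get_offset().right == ref.get_offset().right
        # ... and the same number of learnable weights.
        assert (sum(p.numel() for p in new.parameters()) ==
                sum(p.numel() for p in ref.parameters()))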
tests/_reference_implementations/deprecated_transforms.py
@@ -11,8 +11,8 @@
import cebra.models


#NOTE: Deprecated: transform is now handled in the solver but the original
# method is kept here for testing.
#NOTE(celia): Deprecated: transform is now handled in the solver but the original
# method is kept here for testing consistency.
def cebra_transform_deprecated(cebra_model,
X: Union[npt.NDArray, torch.Tensor],
session_id: Optional[int] = None) -> npt.NDArray:
@@ -72,9 +72,9 @@ def cebra_transform_deprecated(cebra_model,
return output


# NOTE: Deprecated: batched transform can now be performed (more memory efficient)
# NOTE(celia): Deprecated: batched transform can now be performed (more memory efficient)
# using the transform method of the model, and handling padding is implemented
# directly in the base Solver. This method is kept for testing purposes.
# directly in the base Solver. This method is kept for testing consistency.
@torch.no_grad()
def multiobjective_transform_deprecated(solver: "cebra.solvers.Solver",
inputs: torch.Tensor) -> torch.Tensor:
@@ -90,7 +90,7 @@ def multiobjective_transform_deprecated(solver: "cebra.solvers.Solver",

warnings.warn(
"The method is deprecated "
"but kept for testing puroposes."
"but kept for testing purposes."
"We recommend using `transform` instead.",
DeprecationWarning,
stacklevel=2)
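Both deprecation notes point at the same replacement: the padding-aware, batched `transform` now lives on the fitted estimator/solver rather than in these helpers. A minimal sketch of the recommended call, assuming a scikit-learn-style `cebra.CEBRA` estimator with `fit` and `transform` as used in the CEBRA documentation:

    # Sketch only: the deprecated helpers above stay in tests/_reference_implementations
    # for consistency checks; new code calls transform() on the fitted estimator,
    # which handles padding and batching internally (per the notes above).
    import numpy as np
    import cebra

    X = np.random.normal(size=(1000, 30)).astype("float32")

    cebra_model = cebra.CEBRA(model_architecture="offset10-model", max_iterations=10)
    cebra_model.fit(X)

    # Instead of cebra_transform_deprecated(cebra_model, X):
    embedding = cebra_model.transform(X)    # shape: (1000, output_dimension)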