5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

### Added
- LiGR transformer layers from "From Features to Transformers: Redefining Ranking for Scalable Impact" ([#295](https://github.com/MobileTeleSystems/RecTools/pull/295))

## [0.16.0] - 27.07.2025

### Added
11 changes: 10 additions & 1 deletion README.md
@@ -33,6 +33,15 @@ faster than ever before.
- In [HSTU tutorial](examples/tutorials/transformers_HSTU_tutorial.ipynb) we show that original metrics reported for HSTU on public Movielens datasets may actually be **underestimated**
- Configurable, customizable, callback-friendly, checkpoints-included, logs-out-of-the-box, custom-validation-ready, multi-gpu-compatible! See [Transformers Advanced Training User Guide](examples/tutorials/transformers_advanced_training_guide.ipynb) and [Transformers Customization Guide](examples/tutorials/transformers_customization_guide.ipynb)


## ✨ Highlights: RecTools framework at ACM RecSys'25 ✨

**RecTools implementations are featured in ACM RecSys'25: ["eSASRec: Enhancing Transformer-based Recommendations in a Modular Fashion"](https://www.arxiv.org/abs/2508.06450):**
- The article presents a systematic benchmark of Transformer modifications using RecTools models. It offers a detailed evaluation of training objectives, Transformer architectures, loss functions, and negative sampling strategies in realistic, production-like settings
- We introduce a new SOTA baseline, **eSASRec**, which combines SASRec’s training objective with LiGR Transformer layers and Sampled Softmax loss, forming a simple yet powerful recipe
- **eSASRec** shows a 23% boost over SOTA models, such as ActionPiece, on academic benchmarks
- [LiGR](https://arxiv.org/pdf/2502.03417) Transformer layers used in **eSASRec** are now in RecTools

Please note that we always compare the quality of our implementations to results reported in academic papers. [Public benchmarks for the transformer models SASRec and BERT4Rec](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results) show that RecTools implementations achieve the highest scores on multiple datasets compared to other published results.


@@ -107,7 +116,7 @@ The table below lists recommender models that are available in RecTools.
| Model | Type | Description (🎏 for user/item features, 🔆 for warm inference, ❄️ for cold inference support) | Tutorials & Benchmarks |
|---------------------|----|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|
| HSTU | Neural Network | `rectools.models.HSTUModel` - Sequential model with unidirectional pointwise aggregated attention mechanism, incorporating relative attention bias from positional and temporal information, introduced in ["Actions speak louder than words..."](https://arxiv.org/pdf/2402.17152), combined with "Shifted Sequence" training objective as in original public benchmarks<br>🎏 | 📓 [HSTU Theory & Practice](examples/tutorials/transformers_HSTU_tutorial.ipynb) <br> 📕 [Transformers Theory & Practice](examples/tutorials/transformers_tutorial.ipynb)<br> 📗 [Advanced training guide](examples/tutorials/transformers_advanced_training_guide.ipynb) <br> 🚀 [Top performance on public datasets](examples/tutorials/transformers_HSTU_tutorial.ipynb)
| SASRec | Neural Network | `rectools.models.SASRecModel` - Transformer-based sequential model with unidirectional attention mechanism and "Shifted Sequence" training objective <br>🎏 | 📕 [Transformers Theory & Practice](examples/tutorials/transformers_tutorial.ipynb)<br> 📗 [Advanced training guide](examples/tutorials/transformers_advanced_training_guide.ipynb) <br> 📘 [Customization guide](examples/tutorials/transformers_customization_guide.ipynb) <br> 🚀 [Top performance on public benchmarks](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results) |
| SASRec | Neural Network | `rectools.models.SASRecModel` - Transformer-based sequential model with unidirectional attention mechanism and "Shifted Sequence" training objective. <br> For the eSASRec variant, specify `rectools.models.nn.transformers.ligr.LiGRLayers` as `transformer_layers_type` and `sampled_softmax` as `loss` (see the usage sketch below the table) <br>🎏 | 📕 [Transformers Theory & Practice](examples/tutorials/transformers_tutorial.ipynb)<br> 📗 [Advanced training guide](examples/tutorials/transformers_advanced_training_guide.ipynb) <br> 📘 [Customization guide](examples/tutorials/transformers_customization_guide.ipynb) <br> 🚀 [Top performance on public benchmarks](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results) |
| BERT4Rec | Neural Network | `rectools.models.BERT4RecModel` - Transformer-based sequential model with bidirectional attention mechanism and "MLM" (masked item) training objective <br>🎏 | 📕 [Transformers Theory & Practice](examples/tutorials/transformers_tutorial.ipynb)<br> 📗 [Advanced training guide](examples/tutorials/transformers_advanced_training_guide.ipynb) <br> 📘 [Customization guide](examples/tutorials/transformers_customization_guide.ipynb) <br> 🚀 [Top performance on public benchmarks](https://github.com/blondered/bert4rec_repro?tab=readme-ov-file#rectools-transformers-benchmark-results) |
| [implicit](https://github.com/benfred/implicit) ALS Wrapper | Matrix Factorization | `rectools.models.ImplicitALSWrapperModel` - Alternating Least Squares Matrix Factorization algorithm for implicit feedback. <br>🎏 | 📙 [Theory & Practice](https://rectools.readthedocs.io/en/latest/examples/tutorials/baselines_extended_tutorial.html#Implicit-ALS)<br> 🚀 [50% boost to metrics with user & item features](examples/5_benchmark_iALS_with_features.ipynb) |
| [implicit](https://github.com/benfred/implicit) BPR-MF Wrapper | Matrix Factorization | `rectools.models.ImplicitBPRWrapperModel` - Bayesian Personalized Ranking Matrix Factorization algorithm. | 📙 [Theory & Practice](https://rectools.readthedocs.io/en/latest/examples/tutorials/baselines_extended_tutorial.html#Bayesian-Personalized-Ranking-Matrix-Factorization-(BPR-MF)) |
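
Below is a minimal configuration sketch of the eSASRec recipe referenced in the SASRec row above. Only the two parameters named in the table are set explicitly; all other hyperparameters are left at their defaults, and the dataset construction is only indicated with a hypothetical `interactions_df` DataFrame.

```python
from rectools.dataset import Dataset
from rectools.models import SASRecModel
from rectools.models.nn.transformers.ligr import LiGRLayers

# eSASRec recipe: SASRec "Shifted Sequence" training objective
# + LiGR transformer layers + sampled softmax loss.
model = SASRecModel(
    transformer_layers_type=LiGRLayers,
    loss="sampled_softmax",
)

# `interactions_df` is a hypothetical pandas DataFrame of user-item interactions.
# dataset = Dataset.construct(interactions_df)
# model.fit(dataset)
# recos = model.recommend(users=users_to_recommend, dataset=dataset, k=10, filter_viewed=True)
```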
177 changes: 177 additions & 0 deletions rectools/models/nn/transformers/ligr.py
@@ -0,0 +1,177 @@
import typing as tp

import torch
from torch import nn

from rectools.models.nn.transformers.net_blocks import TransformerLayersBase

from .net_blocks import init_feed_forward


class LiGRLayer(nn.Module):
"""
Transformer Layer as described in "From Features to Transformers:
Redefining Ranking for Scalable Impact" https://arxiv.org/pdf/2502.03417

Parameters
----------
n_factors: int
Latent embeddings size.
n_heads: int
Number of attention heads.
dropout_rate: float
Probability of a hidden unit to be zeroed.
ff_factors_multiplier: int, default 4
Feed-forward layers latent embedding size multiplier.
bias_in_ff: bool, default ``False``
If ``True``, add bias to linear layers of the feed-forward network.
ff_activation: {"swiglu", "relu", "gelu"}, default "swiglu"
Activation function to use.
"""

def __init__(
self,
n_factors: int,
n_heads: int,
dropout_rate: float,
ff_factors_multiplier: int = 4,
bias_in_ff: bool = False,
ff_activation: str = "swiglu",
):
super().__init__()
self.multi_head_attn = nn.MultiheadAttention(n_factors, n_heads, dropout_rate, batch_first=True)
self.layer_norm_1 = nn.LayerNorm(n_factors)
self.dropout_1 = nn.Dropout(dropout_rate)
self.layer_norm_2 = nn.LayerNorm(n_factors)
self.feed_forward = init_feed_forward(n_factors, ff_factors_multiplier, dropout_rate, ff_activation, bias_in_ff)
self.dropout_2 = nn.Dropout(dropout_rate)

self.gating_linear_1 = nn.Linear(n_factors, n_factors)
self.gating_linear_2 = nn.Linear(n_factors, n_factors)

def forward(
self,
seqs: torch.Tensor,
attn_mask: tp.Optional[torch.Tensor],
key_padding_mask: tp.Optional[torch.Tensor],
) -> torch.Tensor:
"""
Forward pass through transformer block.

Parameters
----------
seqs: torch.Tensor
User sequences of item embeddings.
attn_mask: torch.Tensor, optional
Optional mask to use in forward pass of multi-head attention as `attn_mask`.
key_padding_mask: torch.Tensor, optional
Optional mask to use in forward pass of multi-head attention as `key_padding_mask`.


Returns
-------
torch.Tensor
User sequences passed through transformer layers.
"""
mha_input = self.layer_norm_1(seqs)
mha_output, _ = self.multi_head_attn(
mha_input,
mha_input,
mha_input,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)
gated_skip = torch.nn.functional.sigmoid(self.gating_linear_1(seqs))
seqs = seqs + torch.mul(gated_skip, self.dropout_1(mha_output))

ff_input = self.layer_norm_2(seqs)
ff_output = self.feed_forward(ff_input)
gated_skip = torch.nn.functional.sigmoid(self.gating_linear_2(seqs))
seqs = seqs + torch.mul(gated_skip, self.dropout_2(ff_output))
return seqs


class LiGRLayers(TransformerLayersBase):
"""
LiGR Transformer blocks.

Parameters
----------
n_blocks: int
Number of transformer blocks.
n_factors: int
Latent embeddings size.
n_heads: int
Number of attention heads.
dropout_rate: float
Probability of a hidden unit to be zeroed.
ff_factors_multiplier: int, default 4
Feed-forward layers latent embedding size multiplier. Pass in ``transformer_layers_kwargs`` to override.
ff_activation: {"swiglu", "relu", "gelu"}, default "swiglu"
Activation function to use. Pass in ``transformer_layers_kwargs`` to override.
bias_in_ff: bool, default ``False``
If ``True``, add bias to linear layers of the feed-forward network. Pass in ``transformer_layers_kwargs`` to override.
"""

def __init__(
self,
n_blocks: int,
n_factors: int,
n_heads: int,
dropout_rate: float,
ff_factors_multiplier: int = 4,
ff_activation: str = "swiglu",
bias_in_ff: bool = False,
):
super().__init__()
self.n_blocks = n_blocks
self.n_factors = n_factors
self.n_heads = n_heads
self.dropout_rate = dropout_rate
self.ff_factors_multiplier = ff_factors_multiplier
self.ff_activation = ff_activation
self.bias_in_ff = bias_in_ff
self.transformer_blocks = nn.ModuleList([self._init_transformer_block() for _ in range(self.n_blocks)])

def _init_transformer_block(self) -> nn.Module:
return LiGRLayer(
self.n_factors,
self.n_heads,
self.dropout_rate,
self.ff_factors_multiplier,
bias_in_ff=self.bias_in_ff,
ff_activation=self.ff_activation,
)

def forward(
self,
seqs: torch.Tensor,
timeline_mask: torch.Tensor,
attn_mask: tp.Optional[torch.Tensor],
key_padding_mask: tp.Optional[torch.Tensor],
**kwargs: tp.Any,
) -> torch.Tensor:
"""
Forward pass through transformer blocks.

Parameters
----------
seqs: torch.Tensor
User sequences of item embeddings.
timeline_mask: torch.Tensor
Mask indicating padding elements.
attn_mask: torch.Tensor, optional
Optional mask to use in forward pass of multi-head attention as `attn_mask`.
key_padding_mask: torch.Tensor, optional
Optional mask to use in forward pass of multi-head attention as `key_padding_mask`.


Returns
-------
torch.Tensor
User sequences passed through transformer layers.
"""
for block_idx in range(self.n_blocks):
seqs = self.transformer_blocks[block_idx](seqs, attn_mask, key_padding_mask)
return seqs
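
For illustration, here is a minimal sketch of running `LiGRLayers` on random embeddings outside of a full model. Shapes and hyperparameters are arbitrary; `timeline_mask` is part of the `TransformerLayersBase` interface and is accepted but not used by this implementation.

```python
import torch

from rectools.models.nn.transformers.ligr import LiGRLayers

layers = LiGRLayers(n_blocks=2, n_factors=64, n_heads=4, dropout_rate=0.1)

seqs = torch.randn(8, 20, 64)  # (batch size, sequence length, n_factors)
timeline_mask = torch.ones(8, 20, dtype=torch.bool)  # padding mask, unused by LiGRLayers itself

out = layers(seqs, timeline_mask=timeline_mask, attn_mask=None, key_padding_mask=None)
assert out.shape == seqs.shape  # (8, 20, 64)
```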
96 changes: 93 additions & 3 deletions rectools/models/nn/transformers/net_blocks.py
@@ -33,14 +33,18 @@ class PointWiseFeedForward(nn.Module):
Probability of a hidden unit to be zeroed.
activation: torch.nn.Module
Activation function module.
bias: bool, default ``True``
If ``True``, add bias to linear layers.
"""

def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float, activation: torch.nn.Module) -> None:
def __init__(
self, n_factors: int, n_factors_ff: int, dropout_rate: float, activation: torch.nn.Module, bias: bool = True
) -> None:
super().__init__()
self.ff_linear_1 = nn.Linear(n_factors, n_factors_ff)
self.ff_linear_1 = nn.Linear(n_factors, n_factors_ff, bias)
self.ff_dropout_1 = torch.nn.Dropout(dropout_rate)
self.ff_activation = activation
self.ff_linear_2 = nn.Linear(n_factors_ff, n_factors)
self.ff_linear_2 = nn.Linear(n_factors_ff, n_factors, bias)

def forward(self, seqs: torch.Tensor) -> torch.Tensor:
"""
@@ -61,6 +65,92 @@ def forward(self, seqs: torch.Tensor) -> torch.Tensor:
return fin


class SwigluFeedForward(nn.Module):
"""
Feed-Forward network to introduce nonlinearity into the transformer model.
This implementation is based on the SwiGLU feed-forward used in FuXi and LLaMA (https://arxiv.org/pdf/2502.03036)
and in LiGR (https://arxiv.org/pdf/2502.03417).

Parameters
----------
n_factors : int
Latent embeddings size.
n_factors_ff : int
How many hidden units to use in the network.
dropout_rate : float
Probability of a hidden unit to be zeroed.
bias: bool, default ``True``
If ``True``, add bias to linear layers.
"""

def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float, bias: bool = True) -> None:
super().__init__()
self.ff_linear_1 = nn.Linear(n_factors, n_factors_ff, bias=bias)
self.ff_dropout_1 = torch.nn.Dropout(dropout_rate)
self.ff_activation = torch.nn.SiLU()
self.ff_linear_2 = nn.Linear(n_factors_ff, n_factors, bias=bias)
self.ff_linear_3 = nn.Linear(n_factors, n_factors_ff, bias=bias)

def forward(self, seqs: torch.Tensor) -> torch.Tensor:
"""
Forward pass.

Parameters
----------
seqs : torch.Tensor
User sequences of item embeddings.

Returns
-------
torch.Tensor
User sequence that passed through all layers.
"""
output = self.ff_activation(self.ff_linear_1(seqs)) * self.ff_linear_3(seqs)
fin = self.ff_linear_2(self.ff_dropout_1(output))
return fin


def init_feed_forward(
n_factors: int, ff_factors_multiplier: int, dropout_rate: float, ff_activation: str, bias: bool = True
) -> nn.Module:
"""
Initialise a feed-forward network with one of the supported activation functions: "swiglu", "relu" or "gelu".

Parameters
----------
n_factors : int
Latent embeddings size.
ff_factors_multiplier : int
Multiplier for the feed-forward hidden size (``n_factors * ff_factors_multiplier`` hidden units).
dropout_rate : float
Probability of a hidden unit to be zeroed.
ff_activation : {"swiglu", "relu", "gelu"}
Activation function to use.
bias: bool, default ``True``
If ``True``, add bias to linear layers.

Returns
-------
nn.Module
Feed-Forward network.
"""
if ff_activation == "swiglu":
return SwigluFeedForward(n_factors, n_factors * ff_factors_multiplier, dropout_rate, bias=bias)
if ff_activation == "gelu":
return PointWiseFeedForward(
n_factors, n_factors * ff_factors_multiplier, dropout_rate, activation=torch.nn.GELU(), bias=bias
)
if ff_activation == "relu":
return PointWiseFeedForward(
n_factors,
n_factors * ff_factors_multiplier,
dropout_rate,
activation=torch.nn.ReLU(),
bias=bias,
)
raise ValueError(f"Unsupported ff_activation: {ff_activation}")
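
A short illustrative call of `init_feed_forward` with arbitrary sizes; `bias=False` matches the LiGR layer defaults above.

```python
import torch

from rectools.models.nn.transformers.net_blocks import init_feed_forward

# SwiGLU feed-forward with hidden size n_factors * ff_factors_multiplier = 256
ff = init_feed_forward(n_factors=64, ff_factors_multiplier=4, dropout_rate=0.1, ff_activation="swiglu", bias=False)
out = ff(torch.randn(8, 20, 64))  # output keeps the input shape: (batch, seq_len, n_factors)
```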


class TransformerLayersBase(nn.Module):
"""Base class for transformer layers."""
