3 changes: 2 additions & 1 deletion README.md
@@ -11,7 +11,7 @@ A unified, extensible framework for text classification with categorical variables
- **Unified yet highly customizable**:
- Use any tokenizer from HuggingFace or the original fastText n-gram tokenizer.
- Text embedding is split into two composable stages: **`TokenEmbedder`** (token → per-token vectors, with optional self-attention) and **`SentenceEmbedder`** (aggregation: mean / first / last / label attention). Combine them with `CategoricalVariableNet` and `ClassificationHead` — all are `torch.nn.Module`.
- The `TextClassificationModel` class assembles these components and can be extended for custom behavior.
- **Two architecture paths**: use `ModelConfig` + the `torchTextClassifiers` constructor for the standard `TextClassificationModel` (zero boilerplate), or build any `nn.Module` you like and pass it to `torchTextClassifiers.from_model()` for full control. The `contrib` sub-package ships ready-made custom architectures (e.g. `MultiLevelTextClassificationModel` for multi-task classification) as reference implementations.
- **Multiclass / multilabel classification**: Supports both multiclass (exactly one label is true) and multi-label (several labels can be true) tasks.
- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging.
- **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
@@ -56,6 +56,7 @@ See the [examples/](examples/) directory for:
- Mixed features (text + categorical)
- Advanced training configurations
- Prediction and explainability
- [Multi-level classification](examples/multilevel_example.py) — custom architecture via `from_model` and `contrib`

## 📄 License

92 changes: 92 additions & 0 deletions docs/source/architecture/overview.md
@@ -599,6 +599,98 @@ predictions = classifier.predict(new_texts)
- Don't need custom architecture
- Want simplicity over control

## Two Architecture Paths

torchTextClassifiers offers two ways to build a classifier, covering different
levels of customisation:

### Path 1 — Standard architecture (ModelConfig)

The `torchTextClassifiers` constructor accepts a `ModelConfig` and builds a
`TextClassificationModel` for you automatically. This covers the vast majority
of use cases: single-task binary, multi-class, or multi-label classification,
with or without categorical variables, with or without self-attention and label
attention.

```python
from torchTextClassifiers import torchTextClassifiers, ModelConfig

classifier = torchTextClassifiers(
    tokenizer=tokenizer,
    model_config=ModelConfig(embedding_dim=128, num_classes=5),
)
classifier.train(texts, labels, training_config)
predictions = classifier.predict(new_texts)
```

You never instantiate `TextClassificationModel` directly; `ModelConfig` is the
only knob you need.

### Path 2 — Custom architecture (from_model)

When `TextClassificationModel` cannot express what you need — multiple
classification heads, shared encoders across tasks, or any other topology —
build your own `nn.Module` and wrap it with `torchTextClassifiers.from_model`.
The wrapper then provides the same `predict` / `save` / `load` interface around
your model.

```python
import torch.nn as nn
from torchTextClassifiers import torchTextClassifiers

class MyModel(nn.Module):
    num_classes = 3
    categorical_variable_net = None  # or a CategoricalVariableNet instance

    def forward(self, input_ids, attention_mask, categorical_vars=None, **kwargs):
        ...  # compute logits from the inputs
        return logits  # (batch, num_classes) — raw logits, not softmaxed

classifier = torchTextClassifiers.from_model(
    tokenizer=tokenizer,
    pytorch_model=MyModel(),
)
```

**Required interface for custom models:**

| Requirement | Details |
|---|---|
| `forward(input_ids, attention_mask, categorical_vars=None, **kwargs)` | Argument names must match exactly (the wrapper passes them as keywords); extra kwargs are ignored |
| Returns raw logits | `torch.Tensor` of shape `(batch, num_classes)`, or `list[torch.Tensor]` for multi-task |
| `num_classes` attribute | `int` for single-task; `list[int]` for multi-task |
| `categorical_variable_net` attribute | A `CategoricalVariableNet` instance, or `None` |
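
For a multi-task model, `num_classes` becomes a list (a sketch; the values are illustrative):

```python
model.num_classes = [3, 10]            # one entry per task head
model.categorical_variable_net = None  # or a shared CategoricalVariableNet
```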

### contrib — reference custom architectures

The `torchTextClassifiers.contrib` sub-package ships example architectures that
follow the `from_model` interface and can be used directly or as starting points:

| Class | Purpose |
|---|---|
| `MultiLevelTextClassificationModel` | Multi-task classifier: one shared `TokenEmbedder`, one `SentenceEmbedder` + `ClassificationHead` per task |
| `MultiLevelCrossEntropyLoss` | Weighted cross-entropy averaged across tasks |

```python
from torchTextClassifiers.contrib import (
    MultiLevelTextClassificationModel,
    MultiLevelCrossEntropyLoss,
)

model = MultiLevelTextClassificationModel(
    token_embedder=token_embedder,
    sentence_embedders=[se_level1, se_level2, se_level3],
    classification_heads=[head1, head2, head3],
    categorical_variable_net=cat_net,
)
classifier = torchTextClassifiers.from_model(tokenizer=tokenizer, pytorch_model=model)
```

See [examples/multilevel_example.py](https://github.com/InseeFrLab/torchTextClassifiers/blob/main/examples/multilevel_example.py)
for a complete working script.

---

## For Advanced Users

### Direct PyTorch Usage
Expand Down
215 changes: 215 additions & 0 deletions docs/source/tutorials/custom_model.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# Custom Architectures with `from_model`

**Difficulty:** Advanced | **Time:** 30 minutes

## When to use this

The standard `torchTextClassifiers` constructor + `ModelConfig` covers most
single-task classification needs. Use `from_model` when you need something the
standard architecture cannot express:

- **Multiple classification heads** (multi-task / hierarchical labels)
- **Shared encoders** across several outputs
- **Custom combination logic** between text and categorical embeddings
- **Any other topology** that does not fit a single linear pipeline

---

## Required interface

Your custom model must satisfy three contracts so that the wrapper's `predict`,
`save`, and `load` methods work correctly.

### 1. `forward` signature

```python
def forward(
    self,
    input_ids: torch.Tensor,                       # (batch, seq_len) — Long
    attention_mask: torch.Tensor,                  # (batch, seq_len) — int
    categorical_vars: torch.Tensor | None = None,  # (batch, n_cats) — Long
    **kwargs,                                      # ignored by the wrapper
) -> torch.Tensor | list[torch.Tensor]:
    ...
```

- The argument **names must match exactly** — the wrapper calls the model with
keyword arguments from the dataloader collate function.
- The return value must be **raw logits** (not softmaxed).
- Single task → `torch.Tensor` of shape `(batch, num_classes)`
  - Multi-task → `list[torch.Tensor]`, one tensor per task (see the sketch below)
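
A minimal multi-task `forward` sketch (`self.encoder`, `self.head_a`, and `self.head_b` are illustrative names, not library components):

```python
def forward(self, input_ids, attention_mask, categorical_vars=None, **kwargs):
    # One shared encoder, one linear head per task; each list entry is raw logits
    hidden = self.encoder(input_ids, attention_mask)   # (batch, hidden_dim)
    return [self.head_a(hidden), self.head_b(hidden)]  # one (batch, n_classes_i) tensor per task
```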

### 2. `num_classes` attribute

```python
model.num_classes # int (single task)
model.num_classes # list[int] (multi-task — one entry per task head)
```

### 3. `categorical_variable_net` attribute

```python
model.categorical_variable_net # CategoricalVariableNet | None
```

Set this to `None` if your model does not use categorical features. When it is
not `None`, the wrapper reads
`categorical_variable_net.categorical_vocabulary_sizes` to configure data
encoding.
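
As a sketch, inside your model's `__init__` (the constructor arguments mirror the multi-task example later in this tutorial; the vocabulary sizes are illustrative):

```python
from torchTextClassifiers.model.components import CategoricalVariableNet

# Two categorical features with illustrative vocabulary sizes of 10 and 4
self.categorical_variable_net = CategoricalVariableNet(
    categorical_vocabulary_sizes=[10, 4],
    categorical_embedding_dims=8,
    text_embedding_dim=64,
)
```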

---

## Minimal example — single-task custom model

```python
import torch
import torch.nn as nn
from torchTextClassifiers import torchTextClassifiers
from torchTextClassifiers.model.components import TokenEmbedder, TokenEmbedderConfig
from torchTextClassifiers.tokenizers import WordPieceTokenizer

class MyClassifier(nn.Module):
    def __init__(self, vocab_size: int, num_classes: int):
        super().__init__()
        self.token_embedder = TokenEmbedder(TokenEmbedderConfig(
            vocab_size=vocab_size, embedding_dim=64, padding_idx=0,
        ))
        self.head = nn.Linear(64, num_classes)

        # Required attributes
        self.num_classes = num_classes
        self.categorical_variable_net = None  # no categorical features

    @staticmethod
    def mean_pool(x, mask):
        # Masked mean over the sequence axis; clamp guards against all-padding rows.
        # A method (rather than a lambda attribute) keeps the model picklable for save().
        return (x * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True).clamp(min=1.0)

    def forward(self, input_ids, attention_mask, categorical_vars=None, **kwargs):
        out = self.token_embedder(input_ids, attention_mask)
        sentence = self.mean_pool(out["token_embeddings"], attention_mask.float())
        return self.head(sentence)  # (batch, num_classes) — raw logits

tokenizer = WordPieceTokenizer(vocab_size=5000)
tokenizer.train(texts)

model = MyClassifier(vocab_size=tokenizer.vocab_size, num_classes=3)

classifier = torchTextClassifiers.from_model(
    tokenizer=tokenizer,
    pytorch_model=model,
)
classifier.train(texts, labels, training_config)  # training_config: a TrainingConfig instance
predictions = classifier.predict(new_texts)
```

---

## Multi-task example — contrib architecture

For multi-task classification the `contrib` sub-package provides ready-made
classes that follow the interface above.

```python
from torchTextClassifiers import torchTextClassifiers, TrainingConfig
from torchTextClassifiers.contrib import (
    MultiLevelTextClassificationModel,
    MultiLevelCrossEntropyLoss,
)
from torchTextClassifiers.model.components import (
    CategoricalVariableNet,
    ClassificationHead,
    LabelAttentionConfig,
    SentenceEmbedder, SentenceEmbedderConfig,
    TokenEmbedder, TokenEmbedderConfig,
)
from torchTextClassifiers.value_encoder import ValueEncoder

# Assume tokenizer, value_encoder, and model_config are already built.
# value_encoder.num_classes is a list[int] — one count per task level.

token_embedder = TokenEmbedder(TokenEmbedderConfig(
    vocab_size=tokenizer.vocab_size,
    embedding_dim=64,
    padding_idx=tokenizer.padding_idx,
))
cat_net = CategoricalVariableNet(
    categorical_vocabulary_sizes=value_encoder.vocabulary_sizes,
    categorical_embedding_dims=8,
    text_embedding_dim=64,
)

sentence_embedders = []
classification_heads = []
for n_cls in value_encoder.num_classes:
    sentence_embedders.append(SentenceEmbedder(SentenceEmbedderConfig(
        aggregation_method=None,
        label_attention_config=LabelAttentionConfig(n_head=2, num_classes=n_cls, embedding_dim=64),
    )))
    classification_heads.append(ClassificationHead(input_dim=64 + cat_net.output_dim, num_classes=1))

model = MultiLevelTextClassificationModel(
    token_embedder=token_embedder,
    sentence_embedders=sentence_embedders,
    classification_heads=classification_heads,
    categorical_variable_net=cat_net,
)

classifier = torchTextClassifiers.from_model(
    tokenizer=tokenizer,
    pytorch_model=model,
    value_encoder=value_encoder,
)

training_config = TrainingConfig(
    num_epochs=10,
    batch_size=32,
    lr=1e-3,
    raw_categorical_inputs=True,
    loss=MultiLevelCrossEntropyLoss(num_classes=list(value_encoder.num_classes)),
)
classifier.train(X_train, y_train, training_config)
predictions = classifier.predict(X_test)
```

`predictions` is a dict with one key per task level.
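
For instance (iterating generically, since the exact key naming comes from the wrapper):

```python
# One entry per task level; each value holds that level's per-sample predictions
for level, preds in predictions.items():
    print(level, preds[:5])
```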

See [examples/multilevel_example.py](https://github.com/InseeFrLab/torchTextClassifiers/blob/main/examples/multilevel_example.py)
for the full runnable script.

---

## contrib reference

| Class | Description |
|---|---|
| `MultiLevelTextClassificationModel` | Shared `TokenEmbedder`, one `SentenceEmbedder` + `ClassificationHead` per task |
| `MultiLevelCrossEntropyLoss` | Per-task cross-entropy, optionally weighted by `num_classes` |

```python
from torchTextClassifiers.contrib import (
    MultiLevelTextClassificationModel,
    MultiLevelCrossEntropyLoss,
)
```

These classes are reference implementations — use them directly or as a
starting point for your own architecture.

---

## Saving and loading

`save` and `load` work the same way regardless of which path was used. Custom
models are serialised as a pickle of the model structure plus a separate
state-dict file; the `_custom_model` flag in the checkpoint tells `load` which
strategy to use.

```python
classifier.save("my_classifier/")
loaded = torchTextClassifiers.load("my_classifier/")
```

---

## Next steps

- **Architecture overview**: {doc}`../architecture/overview` — component reference and design philosophy
- **API reference**: {doc}`../api/wrapper` — full `torchTextClassifiers` API
- **contrib source**: `torchTextClassifiers/contrib/multilevel.py`
19 changes: 19 additions & 0 deletions docs/source/tutorials/index.md
@@ -10,6 +10,7 @@ multiclass_classification
mixed_features
explainability
multilabel_classification
custom_model
```

## Overview
@@ -116,6 +117,21 @@ Assign multiple labels to each text sample for complex classification scenarios.
**Difficulty:** Advanced | **Time:** 30 minutes
:::

:::{grid-item-card} {fas}`puzzle-piece` Custom Architectures
:link: custom_model
:link-type: doc

Plug any PyTorch model into the torchTextClassifiers wrapper via `from_model`.

**What you'll learn:**
- When to go beyond `TextClassificationModel`
- The required `forward` / `num_classes` / `categorical_variable_net` interface
- Using `contrib` classes as reference implementations
- Multi-task classification with `MultiLevelTextClassificationModel`

**Difficulty:** Advanced | **Time:** 30 minutes
:::

::::

## Learning Path
@@ -130,13 +146,16 @@ graph LR
C --> F[Multilabel Classification]
D --> E[Explainability]
F --> E
D --> G[Custom Architectures]
F --> G

style A fill:#e3f2fd
style B fill:#bbdefb
style C fill:#90caf9
style D fill:#64b5f6
style E fill:#1976d2
style F fill:#42a5f5
style G fill:#0d47a1
```

1. **Start with**: {doc}`../getting_started/quickstart` - Get familiar with the basics