3 changes: 2 additions & 1 deletion README.md
@@ -4,11 +4,12 @@ A unified, extensible framework for text classification with categorical variabl

## 🚀 Features

- - **Mixed input support**: Handle text data alongside categorical variables seamlessly.
+ - **Complex input support**: Handle text data alongside categorical variables seamlessly.
- **Unified yet highly customizable**:
- Use any tokenizer from HuggingFace or fastText's original n-gram tokenizer.
- Manipulate the components (`TextEmbedder`, `CategoricalVariableNet`, `ClassificationHead`) to easily create custom architectures, including **self-attention**. Each of them is a `torch.nn.Module`!
- The `TextClassificationModel` class combines these components and can be extended for custom behavior.
+ - **Multiclass / multilabel classification support**: Support for both multiclass (exactly one label is true) and multilabel (several labels can be true) classification tasks.
- **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging.
- **Easy experimentation**: Simple API for training, evaluating, and predicting with minimal code:
- The `torchTextClassifiers` wrapper class orchestrates the tokenizer and the model for you
3 changes: 2 additions & 1 deletion notebooks/example.ipynb
@@ -430,9 +430,10 @@
"outputs": [],
"source": [
"# test the TextEmbedder: it takes as input a tensor of token ids and outputs a tensor of embeddings\n",
"\n",
"text_embedder_output = text_embedder(input_ids=batch[\"input_ids\"], attention_mask=batch[\"attention_mask\"])\n",
"\n",
"print(\"TextEmbedder input: \", text_embedder_input.input_ids)\n",
"print(\"TextEmbedder input: \", batch[\"input_ids\"])\n",
"print(\"TextEmbedder output shape: \", text_embedder_output.shape)"
]
},
350 changes: 350 additions & 0 deletions notebooks/multilabel_classification.ipynb
@@ -0,0 +1,350 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0",
"metadata": {},
"source": [
"# Multilabel classification"
]
},
{
"cell_type": "markdown",
"id": "1",
"metadata": {},
"source": [
"In **multilabel classification**, each instance can be assigned multiple labels simultaneously. This is different from multiclass classification, where each instance is assigned to one and only one class from a set of classes.\n",
"\n",
"This notebook shows how to use torchTextClassifiers to perform multilabel classification."
]
},
{
"cell_type": "markdown",
"id": "2",
"metadata": {},
"source": [
"## Ragged-lists approach"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"from torchTextClassifiers import ModelConfig, TrainingConfig, torchTextClassifiers\n",
"from torchTextClassifiers.dataset import TextClassificationDataset\n",
"from torchTextClassifiers.model import TextClassificationModel, TextClassificationModule\n",
"from torchTextClassifiers.model.components import (\n",
" AttentionConfig,\n",
" CategoricalVariableNet,\n",
" ClassificationHead,\n",
" TextEmbedder,\n",
" TextEmbedderConfig,\n",
")\n",
"from torchTextClassifiers.tokenizers import HuggingFaceTokenizer\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"id": "4",
"metadata": {},
"source": [
"Let's use fake data.\n",
"\n",
"Look at `labels`: it is a list of lists, where each inner list contains the labels for the corresponding instance.\n",
"\n",
"We're indeed in a multilabel classification setting, where each instance can have multiple labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5",
"metadata": {},
"outputs": [],
"source": [
"sample_text_data = [\n",
" \"This is a positive example\",\n",
" \"This is a negative example\",\n",
" \"Another positive case\",\n",
" \"Another negative case\",\n",
" \"Good example here\",\n",
" \"Bad example here\",\n",
"]\n",
"\n",
"labels = [[0, 1, 5], [0, 4], [1, 5], [0, 1, 4], [1, 5], [0]]"
]
},
{
"cell_type": "markdown",
"id": "6",
"metadata": {},
"source": [
"Note that `labels` is not a nice object to manipulate: each inner list has different lengths. You can not convert it to a tensor or a numpy array directly.\n",
"\n",
"This is called a *jagged array* or *ragged array*.\n",
"\n",
"Yet, you do not need to change anything: torchTextClassifiers can handle this kind of data directly."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7",
"metadata": {},
"outputs": [],
"source": [
"labels = np.array(labels) # This does not work !"
]
},
{
"cell_type": "markdown",
"id": "8",
"metadata": {},
"source": [
"Let's import a pre-trained tokenizer from HuggingFace."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = HuggingFaceTokenizer.load_from_pretrained(\n",
" \"google-bert/bert-base-uncased\", output_dim=126\n",
")"
]
},
{
"cell_type": "markdown",
"id": "10",
"metadata": {},
"source": [
"And create our input numpy array."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [],
"source": [
"X = np.array(\n",
" sample_text_data\n",
")\n",
"\n",
"print(X.shape)\n",
"\n",
"Y = labels # only for the sake of clarity, but it remains a ragged array here"
]
},
{
"cell_type": "markdown",
"id": "12",
"metadata": {},
"source": [
"We initialize a very simple model, no categorical features, no attention, just text input and multilabel output.\n",
"\n",
"In this setting, we advise to use `torch.nn.BCEWithLogitsLoss()` as loss function in the training config. \n",
"\n",
"Each label is treated as a separate (but not independent, because we output the joint prediction vector) binary classification problem (where we try to estimate the probability of inclusion), whereas in the default setting (multiclass classification) the model uses `torch.nn.CrossEntropyLoss()`, that implies a *competition* among classes.\n",
"\n",
"Note that we won't enforce this change of loss and if you do not specify it, the default loss (CrossEntropyLoss) will be used."
]
},
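{
"cell_type": "markdown",
"id": "25",
"metadata": {},
"source": [
"A quick illustration with toy logits (values chosen only for illustration): `BCEWithLogitsLoss` applies a sigmoid to each logit independently, so per-label probabilities need not sum to 1, while `CrossEntropyLoss` applies a softmax across classes, so the probabilities compete and sum to 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26",
"metadata": {},
"outputs": [],
"source": [
"# Toy logits for three labels/classes (illustration only)\n",
"logits = torch.tensor([[2.0, -1.0, 0.5]])\n",
"\n",
"print(torch.sigmoid(logits))  # independent per-label probabilities (do not sum to 1)\n",
"print(torch.softmax(logits, dim=1))  # competing class probabilities (sum to 1)"
]
},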
{
"cell_type": "code",
"execution_count": null,
"id": "13",
"metadata": {},
"outputs": [],
"source": [
"embedding_dim = 96\n",
"n_layers = 2\n",
"n_head = 4\n",
"n_kv_head = n_head\n",
"sequence_len = tokenizer.output_dim\n",
"num_classes = max(max(label_list) for label_list in labels) + 1\n",
"\n",
"model_config = ModelConfig(\n",
" embedding_dim=embedding_dim,\n",
" num_classes=num_classes,\n",
")\n",
"\n",
"training_config = TrainingConfig(\n",
" lr=1e-3,\n",
" batch_size=4,\n",
" num_epochs=1,\n",
" loss=torch.nn.BCEWithLogitsLoss(), # change the loss here\n",
")"
]
},
{
"cell_type": "markdown",
"id": "14",
"metadata": {},
"source": [
"Here, do not forget to set `ragged_multilabel=True` !"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "15",
"metadata": {},
"outputs": [],
"source": [
"ttc = torchTextClassifiers(\n",
" tokenizer=tokenizer,\n",
" model_config=model_config,\n",
" ragged_multilabel=True, # This is key !\n",
")"
]
},
{
"cell_type": "markdown",
"id": "16",
"metadata": {},
"source": [
"And you can train !"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17",
"metadata": {},
"outputs": [],
"source": [
"ttc.train(\n",
" X_train=X,\n",
" y_train=Y,\n",
" training_config=training_config,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "18",
"metadata": {},
"source": [
"What happens behind the hood, is that we efficiently convert your ragged lists of labels into a binary matrix, where each row corresponds to an instance and each column to a label. A value of 1 indicates the presence of a label for an instance, while 0 indicates its absence: **it is a one-hot version** of your ragged lists.\n",
"\n",
"You can have a look [here](../torchTextClassifiers/dataset/dataset.py#L85)."
]
},
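{
"cell_type": "markdown",
"id": "27",
"metadata": {},
"source": [
"For intuition, here is a minimal sketch of that conversion (an illustration only, not the library's actual implementation; `to_multi_hot` is a hypothetical helper): each ragged list of label indices becomes one row of a multi-hot float matrix."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "28",
"metadata": {},
"outputs": [],
"source": [
"# Sketch of the ragged-lists-to-multi-hot conversion (illustration only,\n",
"# not the library's implementation)\n",
"def to_multi_hot(ragged_labels, n_classes):\n",
"    out = np.zeros((len(ragged_labels), n_classes), dtype=np.float32)\n",
"    for row, label_list in enumerate(ragged_labels):\n",
"        out[row, label_list] = 1.0  # set a 1 for every label present in this instance\n",
"    return out\n",
"\n",
"print(to_multi_hot(labels, num_classes))"
]
},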
{
"cell_type": "markdown",
"id": "19",
"metadata": {},
"source": [
"## One-hot / multidimensional output approach"
]
},
{
"cell_type": "markdown",
"id": "20",
"metadata": {},
"source": [
"You can also choose to directly provide a one-hot / multidimensional array as labels.\n",
"\n",
"For each sample, you have a vector of size equal to the number of labels, with 1s and 0s indicating the presence or absence of each label - or float values between 0 and 1, indicating the ground truth probability of each label.\n",
"\n",
"You do not have ragged lists anymore: **set `ragged_multilabel=False`** in the ``ttc`` initialization (it is very important, otherwise it will interpret it as a bag of labels as previously ! - we will throw a warning if we detect that your labels are one-hot encoded while you set `ragged_multilabel=True`, but we won't enforce anything).\n",
"\n",
"Also, convert your labels to a numpy array - it is possible now !"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21",
"metadata": {},
"outputs": [],
"source": [
"# We put 1s here, but it could be any float value (probabilities...)\n",
"labels = [[1., 1., 0., 0., 0., 1.],\n",
" [1., 0., 0., 0., 1., 0.],\n",
" [0., 1., 0., 0., 0., 1.],\n",
" [1., 1., 0., 0., 1., 0.],\n",
" [0., 1., 0., 0., 0., 1.],\n",
" [1., 0., 0., 0., 1., 0.]]\n",
"Y = np.array(labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22",
"metadata": {},
"outputs": [],
"source": [
"ttc = torchTextClassifiers(\n",
" tokenizer=tokenizer,\n",
" model_config=model_config,\n",
") # We removed the ragged_multilabel flag here, it is False by default"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23",
"metadata": {},
"outputs": [],
"source": [
"ttc.train(\n",
" X_train=X,\n",
" y_train=Y,\n",
" training_config=training_config,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "24",
"metadata": {},
"source": [
"As discussed, you can also put probabilities in `labels`. \n",
"\n",
"In this case, once again, you can use:\n",
"\n",
"- `torch.nn.BCEWithLogitsLoss()` as loss function in the training config, if you are in a multilabel setting.\n",
"- `torch.nn.CrossEntropyLoss()` as loss function in the training config, if you are in a *soft* multiclass setting (i.e. each instance has only one label, but you provide probabilities instead of class indices). Normally, your ground truth probabilities should sum to 1 for each instance in this case.\n",
"\n",
"We won't enforce anything that PyTorch does not enforce, so make sure to choose the right loss function for your task."
]
}
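,
{
"cell_type": "markdown",
"id": "29",
"metadata": {},
"source": [
"A small standalone check of the *soft* multiclass case (toy values; assuming PyTorch >= 1.10, where `CrossEntropyLoss` accepts probability targets): each row of `soft_targets` sums to 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30",
"metadata": {},
"outputs": [],
"source": [
"# Toy soft targets: each row is a probability distribution over the 6 classes\n",
"soft_targets = torch.tensor([[0.7, 0.1, 0.0, 0.0, 0.1, 0.1],\n",
"                             [0.0, 0.5, 0.0, 0.0, 0.5, 0.0]])\n",
"logits = torch.randn(2, 6)  # dummy model outputs\n",
"\n",
"# CrossEntropyLoss softmaxes the logits and compares them with the target distributions\n",
"print(torch.nn.CrossEntropyLoss()(logits, soft_targets))"
]
}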
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}