feat(medcat-plugins) trainable linker for the embedding linker#392

Open
adam-sutton-1992 wants to merge 7 commits into main from feat(EmbeddingLinker)_trainable_linker

Conversation

adam-sutton-1992 (Contributor) commented Apr 1, 2026

Hihi,

Here's the PR for the trainable embedding linker as foretold by our ancestors.

I'll list the main parts of it here, then the performances.

  1. Creates a new module trainable_embedding_linker that inherits from embedding_linker. The trainable version contains only the functionality that is unique to it.
  2. Splits the modelling functionality out of embedding_linker into a transformer_context_model module. This has two classes: ContextModel, which handles all embedding tasks (embed_cuis, ...), and ModelForEmbeddingLinking, which inherits from the transformers nn.Module and holds the model logic. It's largely a wrapper around the language model with a few choices of how to embed.
  3. Adds additional variables to the config. These are self-explanatory; all of them are training-related and are effectively trade-offs between time, compute, and performance.
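As a rough sketch of the split described above (class names come from this description; the method bodies are placeholder stubs, not the actual implementation):

```python
# Placeholder sketch of the inheritance described above. Names come from the
# PR description; the method bodies are illustrative stubs, not real code.
class EmbeddingLinker:
    """Static linker: builds embeddings once and links against them."""
    def create_embeddings(self) -> None:
        ...

class TrainableEmbeddingLinker(EmbeddingLinker):
    """Only the functionality unique to training lives here."""
    def train_on_batch(self) -> None:
        ...

    def refresh_structure(self) -> None:
        ...
```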

Onto performances:

This is trained and tested with three datasets that all contain SNOMED CT labels: the SNOMED CT Entity Benchmark, Distemist, and COMETA. Because this requires training, I use an 80/20 train/test split. Given the state of cat.train at the time of writing, I built the training loop outside of the medcat library.

Effectively it looks like this:

# first pass: supervised training, then update the trainable linker
cat.trainer.train_supervised_raw(train_projects, test_size=0, nepochs=1)
cat._pipeline._components[-1].train_on_batch()
cat._pipeline._components[-1].refresh_structure()
cat._pipeline._components[-1].create_embeddings()
get_stats(cat=cat, data=test_projects, use_project_filters=False)
# later epochs: same loop, without refresh_structure
for i in range(7):
    cat.trainer.train_supervised_raw(train_projects, test_size=0, nepochs=1)
    cat._pipeline._components[-1].train_on_batch()
    cat._pipeline._components[-1].create_embeddings()
    get_stats(cat=cat, data=test_projects, use_project_filters=False)

Here's the baseline performance without training (which would be comparing to a normal embedding_linker):

Epoch: 0, Prec: 0.081820475543968, Rec: 0.3308877119673485, F1: 0.13119870535198244

Here's the best performance with a few hyperparams I've found to be optimal:

Epoch: 0, Prec: 0.11907464089601173, Rec: 0.5048605035481676, F1: 0.19269978201382124

These hyper-params I mention are, with a bit of commentary:

cat.config.components.linking.embedding_model_name = "abhinand/MedEmbed-small-v0.1" (we default to "sentence-transformers/all-MiniLM-L6-v2" because it's kind of a standard and it's small (6 layers), but MedEmbed is in my experience the best embedding model I've found for medical text)
context_window_size = 35 (performance should increase with more context; if you use the mention_mask, the gains past 14 tokens are very marginal)
cat._pipeline._components[-1].context_model.model.unfreeze_top_n_lm_layers(4) (As you increase the number of unfrozen layers, performance increases at the cost of storing more gradients and extra compute. You also transform the embedding space, requiring more data. I haven't gone beyond unfreezing four layers, because returns were already marginal at that point.)
lr=1e-4, weight_decay=0.01 (These are hardcoded; TODO: they shouldn't be. Oops. If you're only training the linear projection layer you can set the LR as high as 1e-3; once you start affecting transformer layers, a high LR significantly hurts performance.)
cat._pipeline._components[-1].cnf_l.train_on_names = True (You can train on CUIs or on all potential names. A full CDB has about 3 million names, or 600k CUIs. Names performs slightly better; CUIs trains quite a bit faster. It's a one-time cost for training a model, so something to consider.)
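Pulled together, the settings above look roughly like this (assuming a loaded `cat` model pack with the trainable linker as the last pipeline component; the `context_window_size` path is my shorthand from above, not necessarily the exact attribute):

```python
# Hedged consolidation of the hyperparameters above; attribute paths follow
# the snippets in this description and may not match the final API exactly.
cat.config.components.linking.embedding_model_name = "abhinand/MedEmbed-small-v0.1"
linker = cat._pipeline._components[-1]
linker.context_model.model.unfreeze_top_n_lm_layers(4)
linker.cnf_l.train_on_names = True
# context_window_size = 35; lr=1e-4 and weight_decay=0.01 are currently
# hardcoded (see the TODO above).
```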

Here are some earlier iterations' performances:

Without mention_attention functionality:
Epoch: 0, Prec: 0.10870886017906067, Rec: 0.4608991494532199, F1: 0.17592386464826354
It performed best with smaller context windows (size 10).

An earlier experiment, unfreezing layers (these runs excluded the COMETA dataset, so just Distemist and the SNOMED CT Benchmark):

ALL LAYERS FROZEN
Epoch: 0, Prec: 0.10861838458277627, Rec: 0.40456670917592763, F1: 0.17125743885040035
UNFREEZING TOP 1 LAYERS
Epoch: 0, Prec: 0.11208035088291403, Rec: 0.4174662941819507, F1: 0.1767163258223325
UNFREEZING TOP 2 LAYERS
Epoch: 0, Prec: 0.11370157819225252, Rec: 0.4235394145511964, F1: 0.17927559702835402
UNFREEZING TOP 4 LAYERS
Epoch: 0, Prec: 0.11482145768791782, Rec: 0.4276691364022835, F1: 0.18103758547997326
TODO:

  1. Test that trained models saved and then loaded perform the same as when trained and tested live.
  2. Make LR and weight decay configurable via config parameters.
  3. Test another embedding model (e.g. "sentence-transformers/all-MiniLM-L6-v2") to compare performance.

@adam-sutton-1992 adam-sutton-1992 requested a review from mart-r April 1, 2026 20:13
mart-r (Collaborator) left a comment

I must say, some of the details here go over my head.

But I think overall it looks really good!

I did leave a few comments / questions. So feel free to address what's relevant.

names_scores.shape[1],
)
continue
similarity = cui_scores[i, best_idx].item()
mart-r (Collaborator):

Surely this isn't right with respect to the condition before?
You're checking that the best_idx is within the bounds of dimension 2 of names_scores but then taking it from cui_scores instead. And the previous implementation was using names_scores here. And so are the other conditions.

Perhaps this is a deliberate change in which one is used, but I'm just not sure.

logger.warning(
"Attempting to train an embedding linker. This is not required."
"Attempting to train a static embedding linker. "
"This is not possible / required."
mart-r (Collaborator):

Perhaps add something along the lines "Use 'trainable_embedding_linker' if you wish to fine-tune / train the model" for clarity.

self.cnf_l.embed_per_n_batches > 0
and self.number_of_batches > self.cnf_l.embed_per_n_batches
):
print(
mart-r (Collaborator):

Perhaps use a logger instead.
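Something like the following would match the module-logger pattern used elsewhere in the linker (a sketch: the function name and message wording are mine, not from the PR):

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical replacement for the flagged print call: route the progress
# message through the module logger at INFO level instead of stdout.
def report_embedding_refresh(number_of_batches: int, embed_per_n_batches: int) -> None:
    logger.info(
        "Refreshing embeddings after %d batches (embed_per_n_batches=%d)",
        number_of_batches,
        embed_per_n_batches,
    )
```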

def serialise_to(self, folder_path: str) -> None:
# Ensure final partial batch is not dropped before saving model state.
logger.info("Flushing final training batch before saving model.")
logger.info("This is grandfathered in from trainer.py restraints.")
mart-r (Collaborator):

I see this is the main workhorse in terms of working around the restrictions before #374.

Perhaps we can add a check for the medcat version and, if it's greater than or equal to 2.7, avoid the train here and the log message?
Plus, we'd need to implement the change in the train method as well, i.e. if it's 2.7 or later then just train on batch upon the last entity in a document.

But then again doing it after every document might be problematic if the documents are short (i.e not many entities) and/or the last batch in it has very few entities. So I'm open to leaving it as is if you prefer.
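A minimal sketch of the version gate suggested above (the "major.minor" parsing is deliberately naive; a real implementation would use something like packaging.version, and 2.7 stands for whichever medcat release lifts the restriction):

```python
# Sketch: decide whether the pre-2.7 workaround (flushing the final batch in
# serialise_to) is still needed. Version parsing here is deliberately naive.
def needs_final_batch_flush(medcat_version: str) -> bool:
    major, minor = (int(part) for part in medcat_version.split(".")[:2])
    return (major, minor) < (2, 7)
```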

# cui_names = [
# max(self.cdb.cui2info[cui]["names"], key=len) for cui in self._cui_keys
# ]
cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
mart-r (Collaborator):

Now CDB.get_name is used, whereas before the longest name was used.
Is that a deliberate change in behaviour?
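For illustration, with a toy `cui2info` mapping and a stand-in `get_name` (not the real CDB API), the two behaviours can pick different strings:

```python
# Toy example of the behaviour change discussed above: the old code took the
# longest registered name, while CDB.get_name may return a different (e.g.
# preferred) name. Data and get_name here are stand-ins, not medcat's API.
cui2info = {"C0027051": {"names": ["MI", "heart attack", "myocardial infarction"]}}

def get_name(cui: str) -> str:
    # stand-in for CDB.get_name: pretend the preferred name is the first one
    return cui2info[cui]["names"][0]

longest_name = max(cui2info["C0027051"]["names"], key=len)  # old behaviour
preferred_name = get_name("C0027051")                       # new behaviour
```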

EXAMPLE_MODEL_PACK_ZIP = os.path.join(RESOURCES_PATH, "mct2_model_pack.zip")
UNPACKED_EXAMPLE_MODEL_PACK_PATH = os.path.join(
RESOURCES_PATH, "mct2_model_pack")
UNPACKED_EXAMPLE_MODEL_PACK_PATH = os.path.join(RESOURCES_PATH, "mct2_model_pack")
mart-r (Collaborator):

Changes to this file don't seem to have changed any logic. The new formatting may be better, but is it worth muddying what's being changed?

Just a comment - we can keep it as is. But in general, I like to avoid changes that aren't necessary to make things work unless they're clearly fixing something.
