
Unified embedding subpackage + unimodal reuse + unified mode for Transformer/Mamba/Jamba#881

Closed
Rian354 wants to merge 2 commits into sunlabuiuc:master from Multimodal-PyHealth:ra/aligned-processors
Conversation

@Rian354 Rian354 commented Mar 3, 2026

Summary

  • Introduces a BaseEmbeddingModel ABC so that all embedding models share a common embedding_dim property and forward contract
  • Moves all four embedding models into a pyhealth/models/embedding/ subpackage (base, vanilla, vision, text, unified)
  • The UnifiedMultimodalEmbeddingModel IMAGE encoder now delegates to PatchEmbedding from VisionEmbeddingModel (plus mean pooling) instead of an inline helper
  • New: a field_embeddings parameter lets callers inject pre-built EmbeddingModel, VisionEmbeddingModel, or TextEmbeddingModel instances, reusing trained weights without rebuilding from scratch
  • New: a unified_embedding parameter on Transformer, EHRMamba, and JambaEHR enables unified multi-modal mode: a single temporally aligned backbone over all fields instead of per-field encoders
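The shared contract can be illustrated with a minimal sketch. This is a stdlib-only mock (the real BaseEmbeddingModel in pyhealth/models/embedding/base.py is an nn.Module, and ToyEmbedding here is an invented example class, not part of the PR):

```python
from abc import ABC, abstractmethod

class BaseEmbeddingModel(ABC):
    """Sketch of the contract described in the PR: every embedding model
    exposes an embedding_dim property and a forward method."""

    @property
    @abstractmethod
    def embedding_dim(self) -> int:
        """Width E of the vectors produced by forward()."""

    @abstractmethod
    def forward(self, batch):
        """Map raw field inputs to (B, S, E) embeddings."""

class ToyEmbedding(BaseEmbeddingModel):
    """Hypothetical subclass: embeds every token as a zero vector of width E."""

    def __init__(self, dim: int):
        self._dim = dim

    @property
    def embedding_dim(self) -> int:
        return self._dim

    def forward(self, batch):
        # One E-wide vector per token, preserving (B, S) structure.
        return [[[0.0] * self._dim for _ in seq] for seq in batch]
```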

What changed

| Old path | New path |
| --- | --- |
| models/embedding.py | models/embedding/vanilla.py - EmbeddingModel |
| models/vision_embedding.py | models/embedding/vision.py - VisionEmbeddingModel |
| models/text_embedding.py | models/embedding/text.py - TextEmbeddingModel (TextEmbedding alias kept) |
| models/unified_embedding.py | models/embedding/unified.py |
| (new) | models/embedding/base.py - BaseEmbeddingModel ABC |
| (new) | models/embedding/__init__.py |

field_embeddings in UnifiedMultimodalEmbeddingModel

| Pre-built type | What is extracted | Dim-mismatch handling |
| --- | --- | --- |
| EmbeddingModel | embedding_layers[field] (Embedding/Linear) | nn.Linear bridge |
| VisionEmbeddingModel | embedding_layers[field] + _MeanPool | nn.Linear bridge |
| TextEmbeddingModel | transformer (BERT) + fc (projection) | chained nn.Linear |
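The bridging rule can be sketched as follows. This is an illustrative stand-in, not PyHealth code: the real implementation inserts a learned nn.Linear projection, while this stdlib-only mock pads or truncates just to show when a bridge is needed (make_bridge is a hypothetical helper name):

```python
def make_bridge(src_dim: int, tgt_dim: int):
    """If a pre-built unimodal model emits src_dim-wide vectors but the
    unified model expects tgt_dim, insert a bridge; otherwise pass through.
    The real code uses nn.Linear(src_dim, tgt_dim) here."""
    if src_dim == tgt_dim:
        return lambda v: v  # dims already agree: identity, no bridge

    def bridge(v):
        # Stand-in for a learned linear projection: zero-pad or truncate.
        return (list(v) + [0.0] * tgt_dim)[:tgt_dim]

    return bridge
```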

unified_embedding in downstream models

| Model | Before | After (unified mode) |
| --- | --- | --- |
| Transformer | one TransformerLayer per field, fc(n*E) | single TransformerLayer, fc(E) |
| EHRMamba | one MambaBlock stack per field, fc(n*E) | single MambaBlock stack, fc(E) |
| JambaEHR | one JambaLayer per field, fc(n*E) | single JambaLayer, fc(E) |
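The fc-head change in the table amounts to a simple dimension rule. The helper below is illustrative only (fc_in_features is not a PyHealth function): in per-field mode the pooled per-field outputs are concatenated before the head, while unified mode feeds the head a single pooled vector:

```python
def fc_in_features(n_fields: int, embedding_dim: int, unified: bool) -> int:
    """Input width of the final fc head, per the PR's before/after table:
    per-field mode concatenates n_fields pooled outputs (n * E); unified
    mode runs one backbone over the merged sequence, so the head sees E."""
    return embedding_dim if unified else n_fields * embedding_dim
```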

Backward compatibility

All names previously importable from pyhealth.models are preserved, and the TextEmbedding alias is maintained. Per-field mode remains the default: omitting unified_embedding keeps the existing behaviour.

Rian354 added 2 commits March 2, 2026 22:23
Introduces a shared BaseEmbeddingModel ABC (embedding_dim property +
forward) and reorganises the four embedding models into a unified
pyhealth/models/embedding/ package:

  base.py    - BaseEmbeddingModel abstract base class
  vanilla.py - EmbeddingModel (codes, sequences, timeseries)
  vision.py  - VisionEmbeddingModel (patch/CNN/ResNet)
  text.py    - TextEmbeddingModel (renamed from TextEmbedding,
               TextEmbedding kept as a backward-compatibility alias)
  unified.py - UnifiedMultimodalEmbeddingModel, IMAGE encoding now
               delegates to PatchEmbedding from vision.py + mean pool
               instead of an inline _build_image_encoder helper

All existing public exports in pyhealth/models/__init__.py are unchanged, with test imports updated to the new paths.
…tream models

UnifiedMultimodalEmbeddingModel now takes prebuilt unimodal embedding
models via the new field_embeddings parameter, pulling their trained
encoder weights instead of building from scratch:

  - EmbeddingModel  -> embedding_layers[field] (nn.Embedding / nn.Linear)
  - VisionEmbeddingModel -> embedding_layers[field] backbone + mean pool
  - TextEmbeddingModel   -> transformer (BERT) + fc (projection)

  Dimension mismatches are handled automatically with an nn.Linear bridge.

Transformer, EHRMamba, and JambaEHR have an optional
unified_embedding parameter (UnifiedMultimodalEmbeddingModel). When
provided the model switches to unified mode:

  - All temporal fields are jointly embedded and time-sorted by
    UnifiedMultimodalEmbeddingModel -> single (B, S_total, E) sequence
  - A single backbone (TransformerLayer / MambaBlock stack / JambaLayer)
    processes the interleaved sequence instead of one backbone per field
  - The fc head takes embedding_dim inputs (not n_fields * embedding_dim)
  - Per-field mode is unchanged; backward compatibility is preserved
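The temporal merge described above can be sketched in a few lines. This is a stdlib-only illustration of what _build_unified_inputs is described as doing; the function name below, the (timestamp, embedding) event format, and the use of plain lists instead of tensors are all assumptions for the sketch:

```python
def build_unified_sequence(field_events):
    """Merge per-field lists of (timestamp, embedding) events into one
    time-sorted stream, so a single backbone can process the interleaved
    multi-modal sequence instead of one backbone per field."""
    merged = []
    for field, events in field_events.items():
        for t, emb in events:
            merged.append((t, field, emb))
    merged.sort(key=lambda e: e[0])  # temporal alignment across modalities
    return merged
```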

_forward_unified and _build_unified_inputs helpers are added to each
downstream model. 7 new tests cover field_embeddings reuse, projection
bridging, and unified-mode forward + backward through all three models.
@Rian354 Rian354 closed this Mar 3, 2026
