Unified embedding subpackage + unimodal reuse + unified mode for Transformer/Mamba/Jamba#881
Closed
Rian354 wants to merge 2 commits into sunlabuiuc:master from
Conversation
Introduces a shared BaseEmbeddingModel ABC (an embedding_dim property plus a forward contract) and reorganises the four embedding models into a unified pyhealth/models/embedding/ package:
- base.py - BaseEmbeddingModel abstract base class
- vanilla.py - EmbeddingModel (codes, sequences, timeseries)
- vision.py - VisionEmbeddingModel (patch/CNN/ResNet)
- text.py - TextEmbeddingModel (renamed from TextEmbedding; TextEmbedding kept as a backward-compatibility alias)
- unified.py - UnifiedMultimodalEmbeddingModel; IMAGE encoding now delegates to PatchEmbedding from vision.py plus mean pooling instead of an inline _build_image_encoder helper

All existing public exports in pyhealth/models/__init__.py are unchanged, with test imports updated to the new paths.
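The shared contract can be illustrated with a minimal sketch. The real BaseEmbeddingModel in the PR presumably also subclasses torch.nn.Module; only the abstract interface (embedding_dim property + forward) is shown here, and ToyEmbedding is a hypothetical example, not PyHealth code:

```python
from abc import ABC, abstractmethod

class BaseEmbeddingModel(ABC):
    """Minimal sketch of the shared embedding contract (interface only;
    the real class in the PR is also an nn.Module)."""

    @property
    @abstractmethod
    def embedding_dim(self) -> int:
        """Output embedding dimensionality shared by all embedding models."""

    @abstractmethod
    def forward(self, inputs):
        """Embed `inputs` into vectors of size embedding_dim."""

class ToyEmbedding(BaseEmbeddingModel):
    """Hypothetical concrete model that returns zero vectors."""

    def __init__(self, dim: int):
        self._dim = dim

    @property
    def embedding_dim(self) -> int:
        return self._dim

    def forward(self, inputs):
        # One embedding vector per input item.
        return [[0.0] * self._dim for _ in inputs]
```

Any concrete embedding model (vanilla, vision, text, unified) then exposes the same two members, which is what lets downstream code treat them interchangeably.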
UnifiedMultimodalEmbeddingModel now accepts prebuilt unimodal embedding models via the new field_embeddings parameter, reusing their trained encoder weights instead of building encoders from scratch:
- EmbeddingModel -> embedding_layers[field] (nn.Embedding / nn.Linear)
- VisionEmbeddingModel -> embedding_layers[field] backbone + mean pool
- TextEmbeddingModel -> transformer (BERT) + fc (projection)
Dimension mismatches are handled automatically with an nn.Linear bridge.
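The bridging rule can be sketched without PyTorch: when a prebuilt model's output dimension differs from the unified embedding_dim, a linear bridge (an nn.Linear in the actual code) is planned; otherwise the encoder is reused as-is. The names below (plan_field_embeddings, needs_bridge) are illustrative, not PyHealth API:

```python
def needs_bridge(source_dim: int, target_dim: int) -> bool:
    """A projection is required only when dimensions differ."""
    return source_dim != target_dim

def plan_field_embeddings(field_dims: dict, target_dim: int) -> dict:
    """Map each field to either a 'reuse' plan or a ('bridge', in, out)
    plan, standing in for nn.Linear(in, out) in the real model."""
    plan = {}
    for field, dim in field_dims.items():
        if needs_bridge(dim, target_dim):
            plan[field] = ("bridge", dim, target_dim)
        else:
            plan[field] = ("reuse",)
    return plan
```

For example, a 256-dim vision encoder feeding a 128-dim unified model gets a 256-to-128 bridge, while a 128-dim code embedder is reused directly.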
Transformer, EHRMamba, and JambaEHR gain an optional unified_embedding parameter (a UnifiedMultimodalEmbeddingModel). When provided, the model switches to unified mode:
- All temporal fields are jointly embedded and time-sorted by UnifiedMultimodalEmbeddingModel into a single (B, S_total, E) sequence
- A single backbone (TransformerLayer / MambaBlock stack / JambaLayer) processes the interleaved sequence instead of one backbone per field
- The fc head takes embedding_dim inputs (not n_fields * embedding_dim)
- Per-field mode is unchanged, so backward compatibility is preserved
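The change to the fc head can be stated numerically; this helper is purely illustrative, not code from the PR:

```python
def fc_input_dim(n_fields: int, embedding_dim: int, unified: bool) -> int:
    """Per-field mode concatenates one backbone output per field before
    the fc head; unified mode feeds a single embedding_dim-wide output."""
    return embedding_dim if unified else n_fields * embedding_dim
```

With 3 fields and embedding_dim = 128, per-field mode produces a 384-wide fc input, while unified mode produces 128.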
_forward_unified and _build_unified_inputs helpers are added to each downstream model. Seven new tests cover field_embeddings reuse, projection bridging, and unified-mode forward and backward passes through all three models.
Summary
- BaseEmbeddingModel ABC: all embedding models share a common embedding_dim property and forward contract
- New pyhealth/models/embedding/ subpackage (base, vanilla, vision, text, unified)
- UnifiedMultimodalEmbeddingModel's IMAGE encoder delegates to PatchEmbedding from VisionEmbeddingModel (+ mean pool) instead of an inline helper
- field_embeddings parameter lets callers inject pre-built EmbeddingModel, VisionEmbeddingModel, or TextEmbeddingModel instances, reusing trained weights without rebuilding from scratch
- unified_embedding parameter on Transformer, EHRMamba, and JambaEHR enables unified multimodal mode: a single temporally aligned backbone over all fields instead of per-field encoders

What changed
File reorganisation:
- models/embedding.py -> models/embedding/vanilla.py (EmbeddingModel)
- models/vision_embedding.py -> models/embedding/vision.py (VisionEmbeddingModel)
- models/text_embedding.py -> models/embedding/text.py (TextEmbeddingModel; TextEmbedding alias kept)
- models/unified_embedding.py -> models/embedding/unified.py
- New: models/embedding/base.py (BaseEmbeddingModel ABC), models/embedding/__init__.py

field_embeddings in UnifiedMultimodalEmbeddingModel:
- EmbeddingModel: embedding_layers[field] (Embedding/Linear), nn.Linear bridge
- VisionEmbeddingModel: embedding_layers[field] + _MeanPool, nn.Linear bridge
- TextEmbeddingModel: transformer (BERT) + fc (projection), nn.Linear

unified_embedding in downstream models:
- Transformer: TransformerLayer per field with fc(n*E) -> single TransformerLayer with fc(E)
- EHRMamba: MambaBlock stack per field with fc(n*E) -> single MambaBlock stack with fc(E)
- JambaEHR: JambaLayer per field with fc(n*E) -> single JambaLayer with fc(E)

Backward compatibility
All names previously importable from pyhealth.models are preserved. The TextEmbedding alias is maintained. Per-field mode is the default (no unified_embedding -> existing behaviour).