Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@
from .concare import ConCare, ConCareLayer
from .contrawr import ContraWR, ResBlock2D
from .deepr import Deepr, DeeprLayer
from .embedding import EmbeddingModel
from .embedding import (
BaseEmbeddingModel,
EmbeddingModel,
VisionEmbeddingModel,
TextEmbeddingModel,
TextEmbedding, # backward compat alias
UnifiedMultimodalEmbeddingModel,
SinusoidalTimeEmbedding,
PatchEmbedding,
Permute,
init_embedding_with_pretrained,
)
from .gamenet import GAMENet, GAMENetLayer
from .jamba_ehr import JambaEHR, JambaLayer
from .logistic_regression import LogisticRegression
Expand Down Expand Up @@ -38,8 +49,4 @@
from .transformers_model import TransformersModel
from .ehrmamba import EHRMamba, MambaBlock
from .vae import VAE
from .vision_embedding import VisionEmbeddingModel
from .text_embedding import TextEmbedding
from .sdoh import SdohClassifier
from .medlink import MedLink
from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding
99 changes: 88 additions & 11 deletions pyhealth/models/ehrmamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel
from pyhealth.models.embedding import EmbeddingModel
from pyhealth.models.embedding.unified import UnifiedMultimodalEmbeddingModel
from pyhealth.models.utils import get_last_visit
from pyhealth.processors import (
MultiHotProcessor,
Expand Down Expand Up @@ -111,13 +112,20 @@ class EHRMamba(BaseModel):
Electronic Health Records (arxiv 2405.14567). Uses Mamba (SSM) for linear
complexity in sequence length; supports long EHR sequences.

When ``unified_embedding`` is supplied the model switches to **unified
mode**: all temporal fields are jointly embedded and time-sorted by
:class:`UnifiedMultimodalEmbeddingModel`, then processed by a *single*
stack of :class:`MambaBlock` layers rather than one stack per field.

Args:
dataset: SampleDataset for token/embedding setup.
embedding_dim: Embedding and hidden dimension. Default 128.
num_layers: Number of Mamba blocks. Default 2.
state_size: SSM state size per channel. Default 16.
conv_kernel: Causal conv kernel size in block. Default 4.
dropout: Dropout before classification head. Default 0.1.
unified_embedding: Optional pre-built UnifiedMultimodalEmbeddingModel.
When provided, enables unified multi-modal mode.
"""

def __init__(
Expand All @@ -128,26 +136,26 @@ def __init__(
state_size: int = 16,
conv_kernel: int = 4,
dropout: float = 0.1,
unified_embedding: Optional[UnifiedMultimodalEmbeddingModel] = None,
):
super().__init__(dataset=dataset)
self.embedding_dim = embedding_dim
self.num_layers = num_layers
self.state_size = state_size
self.conv_kernel = conv_kernel
self.dropout_rate = dropout
self._use_unified = unified_embedding is not None

assert len(self.label_keys) == 1, "EHRMamba supports single label key only"
self.label_key = self.label_keys[0]
self.mode = self.dataset.output_schema[self.label_key]

self.embedding_model = EmbeddingModel(dataset, embedding_dim)
self.feature_processors = {
k: self.dataset.input_processors[k] for k in self.feature_keys
}
output_size = self.get_output_size()
self.dropout = nn.Dropout(dropout)

self.blocks = nn.ModuleDict()
for feature_key in self.feature_keys:
self.blocks[feature_key] = nn.ModuleList(
if self._use_unified:
self.embedding_model = unified_embedding
self._unified_blocks = nn.ModuleList(
[
MambaBlock(
d_model=embedding_dim,
Expand All @@ -157,10 +165,76 @@ def __init__(
for _ in range(num_layers)
]
)

output_size = self.get_output_size()
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(len(self.feature_keys) * embedding_dim, output_size)
self.fc = nn.Linear(embedding_dim, output_size)
else:
self.embedding_model = EmbeddingModel(dataset, embedding_dim)
self.feature_processors = {
k: self.dataset.input_processors[k] for k in self.feature_keys
}
self.blocks = nn.ModuleDict()
for feature_key in self.feature_keys:
self.blocks[feature_key] = nn.ModuleList(
[
MambaBlock(
d_model=embedding_dim,
state_size=state_size,
conv_kernel=conv_kernel,
)
for _ in range(num_layers)
]
)
self.fc = nn.Linear(len(self.feature_keys) * embedding_dim, output_size)

    def _build_unified_inputs(
        self, kwargs: Dict[str, Any]
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """Build the inputs dict required by UnifiedMultimodalEmbeddingModel.

        For every feature key, looks up the processor's ``schema()`` to find
        which positions of the processor output tuple hold the ``value``,
        ``time`` and ``mask`` tensors, moves each present tensor to the model
        device, and collects them into a per-field sub-dict.

        Args:
            kwargs: The raw forward() keyword arguments; each feature key maps
                to either a single tensor or a tuple of tensors as produced by
                the dataset's input processors.

        Returns:
            Mapping of field name to a dict with any of the keys
            ``"value"``, ``"time"``, ``"mask"`` (only those named in the
            processor schema), all on ``self.device``.
        """
        inputs: Dict[str, Dict[str, torch.Tensor]] = {}
        for field_name in self.feature_keys:
            feature = kwargs[field_name]
            if isinstance(feature, torch.Tensor):
                # Normalize a bare tensor to a 1-tuple so schema indexing works.
                # NOTE(review): assumes a single-tensor feature corresponds to a
                # single-entry schema — confirm against processor implementations.
                feature = (feature,)
            schema = self.dataset.input_processors[field_name].schema()
            field_dict: Dict[str, torch.Tensor] = {}
            if "value" in schema:
                field_dict["value"] = feature[schema.index("value")].to(self.device)
            if "time" in schema:
                field_dict["time"] = feature[schema.index("time")].to(self.device)
            if "mask" in schema:
                field_dict["mask"] = feature[schema.index("mask")].to(self.device)
            inputs[field_name] = field_dict
        return inputs

def _forward_unified(self, **kwargs) -> Dict[str, torch.Tensor]:
"""Forward pass in unified-embedding mode.

Calls UnifiedMultimodalEmbeddingModel to produce a single
temporally-sorted event sequence, then encodes it with one shared
MambaBlock stack and pools to the last valid event.
"""
inputs = self._build_unified_inputs(kwargs)
out = self.embedding_model(inputs)
x = out["sequence"] # (B, S_total, E)
mask = out["mask"].bool() # (B, S_total)

for blk in self._unified_blocks:
x = blk(x)

last_h = get_last_visit(x, mask)
logits = self.fc(self.dropout(last_h))
y_prob = self.prepare_y_prob(logits)
results: Dict[str, torch.Tensor] = {
"loss": torch.tensor(0.0), # placeholder, overwritten below
"y_prob": y_prob,
"logit": logits,
}
if self.label_key in kwargs:
y_true = kwargs[self.label_key].to(self.device)
results["loss"] = self.get_loss_function()(logits, y_true)
results["y_true"] = y_true
if kwargs.get("embed", False):
results["embed"] = last_h
return results

@staticmethod
def _split_temporal(feature: Any) -> Tuple[Optional[torch.Tensor], Any]:
Expand Down Expand Up @@ -211,6 +285,9 @@ def _pool_embedding(x: torch.Tensor) -> torch.Tensor:
return x

def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
if self._use_unified:
return self._forward_unified(**kwargs)

patient_emb = []
embedding_inputs: Dict[str, torch.Tensor] = {}
masks: Dict[str, torch.Tensor] = {}
Expand Down
37 changes: 37 additions & 0 deletions pyhealth/models/embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Embedding models for PyHealth multimodal pipelines.

All embedding models share the :class:`BaseEmbeddingModel` interface:
they expose an ``embedding_dim`` property and a ``forward`` method that
transforms processor output tensors into dense vector embeddings.

Available models:

- :class:`EmbeddingModel` — generic encoder for codes, sequences, timeseries
- :class:`VisionEmbeddingModel` — ViT-style patch encoder for medical images (Josh)
- :class:`TextEmbeddingModel` — BERT-based encoder for clinical text (Rian)
- :class:`UnifiedMultimodalEmbeddingModel` — temporally-aligned multi-modal encoder

Helper utilities:

- :class:`SinusoidalTimeEmbedding` — continuous time positional encoding
- :func:`init_embedding_with_pretrained` — load GloVe-style pretrained vectors
"""

from .base import BaseEmbeddingModel
from .vanilla import EmbeddingModel, init_embedding_with_pretrained
from .vision import VisionEmbeddingModel, PatchEmbedding, Permute
from .text import TextEmbeddingModel, TextEmbedding
from .unified import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding

__all__ = [
"BaseEmbeddingModel",
"EmbeddingModel",
"VisionEmbeddingModel",
"PatchEmbedding",
"Permute",
"TextEmbeddingModel",
"TextEmbedding",
"UnifiedMultimodalEmbeddingModel",
"SinusoidalTimeEmbedding",
"init_embedding_with_pretrained",
]
30 changes: 30 additions & 0 deletions pyhealth/models/embedding/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod


class BaseEmbeddingModel(ABC):
    """Common interface implemented by every PyHealth embedding model.

    The contract is intentionally minimal:

    - ``embedding_dim`` reports the dimensionality of the produced vectors.
    - ``forward`` maps processor output tensors to dense vector embeddings.

    Known concrete implementations:

    - :class:`EmbeddingModel` – generic encoder for codes, sequences, timeseries
    - :class:`VisionEmbeddingModel` – patch-based encoder for medical images (Josh)
    - :class:`TextEmbeddingModel` – BERT-based encoder for clinical text (Rian)
    - :class:`UnifiedMultimodalEmbeddingModel` – temporally-aligned multi-modal encoder
    """

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Map processor outputs to embedding tensors."""
        ...

    @property
    @abstractmethod
    def embedding_dim(self) -> int:
        """Dimensionality of the vectors produced by :meth:`forward`."""
        ...
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Text embedding module for multimodal PyHealth pipelines.

Author: Rian

This module provides a Transformer-based text encoder for clinical/medical text.
It is designed to integrate with PyHealth's multimodal fusion architecture.

Expand All @@ -14,8 +16,8 @@
- torch

Example:
>>> from pyhealth.models.text_embedding import TextEmbedding
>>> encoder = TextEmbedding(embedding_dim=256)
>>> from pyhealth.models.embedding import TextEmbeddingModel
>>> encoder = TextEmbeddingModel(embedding_dim=256)
>>> embeddings, mask = encoder(["Patient has fever.", "Follow-up."])
>>> embeddings.shape # [2, T, 256]
"""
Expand All @@ -28,11 +30,13 @@
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

from .base import BaseEmbeddingModel


logger = logging.getLogger(__name__)


class TextEmbedding(nn.Module):
class TextEmbeddingModel(nn.Module, BaseEmbeddingModel):
"""Encodes clinical text into embeddings for multimodal fusion.

This module wraps a pretrained Hugging Face transformer (default:
Expand All @@ -55,7 +59,7 @@ class TextEmbedding(nn.Module):

Example: A 300-token note with chunk_size=128 becomes 3 chunks:
Chunk 1: [CLS] + tokens[0:126] + [SEP] = 128 tokens
Chunk 2: [CLS] + tokens[126:252] + [SEP] = 128 tokens
Chunk 2: [CLS] + tokens[126:252] + [SEP] = 128 tokens
Chunk 3: [CLS] + tokens[252:300] + [SEP] = 50 tokens

Pooling Modes:
Expand Down Expand Up @@ -117,7 +121,7 @@ class TextEmbedding(nn.Module):
Example:
Basic usage with default parameters:

>>> encoder = TextEmbedding(embedding_dim=256)
>>> encoder = TextEmbeddingModel(embedding_dim=256)
>>> texts = ["Patient presents with chest pain.", "Routine checkup."]
>>> embeddings, mask = encoder(texts)
>>> embeddings.shape
Expand All @@ -127,14 +131,14 @@ class TextEmbedding(nn.Module):

Using chunk-level pooling for efficiency:

>>> encoder = TextEmbedding(pooling="cls", embedding_dim=128)
>>> encoder = TextEmbeddingModel(pooling="cls", embedding_dim=128)
>>> long_note = "..." * 1000 # Very long clinical note
>>> emb, mask = encoder([long_note])
>>> emb.shape # [1, num_chunks, 128] instead of [1, thousands, 128]

Backward-compatible single tensor return:

>>> encoder = TextEmbedding(return_mask=False)
>>> encoder = TextEmbeddingModel(return_mask=False)
>>> embeddings = encoder(["Test"]) # Just tensor, no tuple
"""

Expand All @@ -159,7 +163,7 @@ def __init__(
"""
super().__init__()
self.model_name = model_name
self.embedding_dim = embedding_dim
self._embedding_dim = embedding_dim
self.chunk_size = chunk_size
self.max_chunks = max_chunks
self.pooling = pooling
Expand All @@ -184,6 +188,10 @@ def __init__(
# This aligns text embeddings with other modalities in a shared E' space
self.fc = nn.Linear(self.transformer.config.hidden_size, embedding_dim)

    @property
    def embedding_dim(self) -> int:
        """Output embedding dimension E' of the vectors this encoder returns."""
        return self._embedding_dim

def _chunk_and_encode(
self, text: str, device: torch.device
) -> torch.Tensor:
Expand Down Expand Up @@ -232,11 +240,6 @@ def _chunk_and_encode(
chunks = [[self.tokenizer.cls_token_id, self.tokenizer.sep_token_id]]

# Step 3: Apply max_chunks limit (performance guardrail)
# Rationale: Clinical notes can be 10K+ tokens. Without a cap:
# - Memory usage explodes (each chunk needs transformer forward pass)
# - Silent OOMs in production environments
# - Inference time becomes unpredictable
# We warn rather than silently truncate so users can adjust.
if self.max_chunks is not None and len(chunks) > self.max_chunks:
original_chunks = len(chunks)
chunks = chunks[: self.max_chunks]
Expand Down Expand Up @@ -319,18 +322,6 @@ def forward(

If return_mask=False (backward compatibility):
torch.Tensor: Just the embeddings tensor [B, T, E']

Note:
The return_mask parameter exists for backward compatibility.
New code should use the default return_mask=True to get masks
needed for downstream attention layers.

Example:
>>> encoder = TextEmbedding(embedding_dim=128)
>>> emb, mask = encoder(["Hello world", "A longer text here"])
>>> emb.shape # [2, T, 128] where T is max tokens
>>> mask.shape # [2, T]
>>> mask[0].sum() # Number of valid tokens in first sample
"""
# Normalize single string to list
if isinstance(text, str):
Expand All @@ -357,7 +348,7 @@ def forward(

# Pad embedding tensor with zeros
if pad_len > 0:
padding = torch.zeros(pad_len, self.embedding_dim, device=device)
padding = torch.zeros(pad_len, self._embedding_dim, device=device)
e = torch.cat([e, padding], dim=0)
padded.append(e)

Expand All @@ -377,3 +368,7 @@ def forward(
return embeddings, mask
else:
return embeddings


# Alias for backward compatibility
TextEmbedding = TextEmbeddingModel
Loading