Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3c37a01
added mvcl_training_sleepedf_task
heyan3 Apr 5, 2026
10f4c6d
task outputs xt,dx,xf
heyan3 Apr 6, 2026
679b4ba
Set input_schema and output_schema and do not redefine it
coderookie1994 Apr 9, 2026
7999c3f
Merge pull request #1 from coderookie1994/mvcl_task
coderookie1994 Apr 9, 2026
f939d3a
Merge branch 'master' of github.com:coderookie1994/PyHealth into cont…
coderookie1994 Apr 18, 2026
202311f
Init commit for the model implementation
coderookie1994 Apr 11, 2026
ce5ff2e
added function to load external data in task and update finetune model
heyan3 Apr 12, 2026
12b2c2d
Encoder encodes xt, xd, xf and doesn't augment z_k
coderookie1994 Apr 16, 2026
f411e23
added symmetry selfloss;updated the finetuning classifier
heyan3 Apr 16, 2026
24ed61f
fixed and improved task to match papers
heyan3 Apr 17, 2026
df04ed9
improved augmentation method, noise + frequency
heyan3 Apr 18, 2026
9b9a8fc
Added type hints and fixed the return type for the 'forward' method
coderookie1994 Apr 18, 2026
8035333
Merge pull request #5 from coderookie1994/wip/mvcl_model_impl
coderookie1994 Apr 18, 2026
6411297
Fixed data loss bug and updated the loss function.
coderookie1994 Apr 18, 2026
37027ab
1. H dimentsion/permute fix(critial); 2. add positional encoding; 3. …
heyan3 Apr 19, 2026
ebb0787
small updated in task; plus unit test/exmaples/api docs
heyan3 Apr 20, 2026
535bf2a
Merge pull request #6 from coderookie1994/CL/update_task
coderookie1994 Apr 20, 2026
dcf32ae
Merge pull request #7 from coderookie1994/potential_fixes_update
coderookie1994 Apr 20, 2026
8ae76ac
remove loop; added docstrings
heyan3 Apr 20, 2026
f2ef01b
Merge pull request #8 from coderookie1994/CL/taskUpdate
coderookie1994 Apr 20, 2026
f0b71e1
Generic MVCL model
coderookie1994 Apr 21, 2026
aae6aa5
attempt at making the tests run faster
coderookie1994 Apr 21, 2026
d242387
Added tests and updated docstrings
coderookie1994 Apr 22, 2026
3510816
Updated the notebook
coderookie1994 Apr 22, 2026
fffcf23
Merge pull request #9 from coderookie1994/generic_mvcl_model
coderookie1994 Apr 22, 2026
578d1ad
Updated RST files
coderookie1994 Apr 22, 2026
9e02d37
Merge branch 'master' into contrastive_learning
coderookie1994 Apr 22, 2026
486e89d
Captured output from the example notebook
coderookie1994 Apr 22, 2026
9065290
fixed feedbacks on gradescope
heyan3 May 6, 2026
5f9f923
Merge pull request #10 from coderookie1994/cl_feedback_fixes
coderookie1994 May 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -205,4 +205,5 @@ API Reference
models/pyhealth.models.TextEmbedding
models/pyhealth.models.BIOT
models/pyhealth.models.unified_multimodal_embedding_docs
models/pyhealth.models.MVCL
models/pyhealth.models.califorest
7 changes: 7 additions & 0 deletions docs/api/models/pyhealth.models.MVCL.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.models.MultiViewContrastiveModel
=========================================

.. autoclass:: pyhealth.models.MultiViewContrastiveModel
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/api/tasks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ Available Tasks
Readmission Prediction <tasks/pyhealth.tasks.readmission_prediction>
Sleep Staging <tasks/pyhealth.tasks.sleep_staging>
Sleep Staging (SleepEDF) <tasks/pyhealth.tasks.SleepStagingSleepEDF>
MVCL Training (SleepEDF EEG) <tasks/pyhealth.tasks.MVCLTrainingSleepEEG>
Temple University EEG Tasks <tasks/pyhealth.tasks.temple_university_EEG_tasks>
Sleep Staging v2 <tasks/pyhealth.tasks.sleep_staging_v2>
Benchmark EHRShot <tasks/pyhealth.tasks.benchmark_ehrshot>
Expand Down
7 changes: 7 additions & 0 deletions docs/api/tasks/pyhealth.tasks.MVCLTrainingSleepEEG.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pyhealth.tasks.MVCLTrainingSleepEEG
===================================

.. autoclass:: pyhealth.tasks.MVCLTrainingSleepEEG
:members:
:undoc-members:
:show-inheritance:
556 changes: 556 additions & 0 deletions examples/mvcl_training_sleepedf.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyhealth/datasets/sleepedf.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(
dataset_name: Optional[str] = None,
config_path: Optional[str] = None,
subset: Optional[str] = "cassette",
dev: bool = False,
) -> None:
subset = (subset or "cassette").lower()
if subset not in {"cassette", "telemetry"}:
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(
tables=default_tables,
dataset_name=dataset_name or "sleepedf",
config_path=config_path,
dev=dev,
)

def prepare_metadata_cassette(self, root: str) -> None:
Expand Down
4 changes: 3 additions & 1 deletion pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,6 @@
from .sdoh import SdohClassifier
from .medlink import MedLink
from .unified_embedding import UnifiedMultimodalEmbeddingModel, SinusoidalTimeEmbedding
from .califorest import CaliForest
from .multi_view_contrastive_time_series_model import MultiViewContrastiveTimeSeriesModel
from .mvcl_model import MultiViewContrastiveModel
from .califorest import CaliForest
302 changes: 302 additions & 0 deletions pyhealth/models/multi_view_contrastive_time_series_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
"""Multi-view contrastive time-series model for PyHealth datasets."""

import math
from typing import Any, Tuple, cast

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel


class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding (Vaswani et al., 2017).

    Precomputes a ``[1, max_len, hidden_dim]`` table of interleaved
    sine/cosine waves at geometrically spaced frequencies, adds the first
    ``seq_len`` rows to the input, and applies dropout to the result.
    """

    def __init__(self, hidden_dim, dropout=0.1, max_len=1024):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len).unsqueeze(1)
        # Frequencies decay geometrically from 1 down to ~1/10000.
        frequencies = torch.exp(
            torch.arange(0, hidden_dim, 2) * (-math.log(10000.0) / hidden_dim)
        )
        table = torch.zeros(max_len, hidden_dim)
        table[:, 0::2] = torch.sin(positions * frequencies)  # even dims
        table[:, 1::2] = torch.cos(positions * frequencies)  # odd dims
        # Leading batch axis of size 1 so the table broadcasts over batches.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        # Add the encoding for the first x.size(1) positions, then dropout.
        return self.dropout(x + self.pe[:, :x.size(1)])

class MultiViewContrastiveTimeSeriesModel(BaseModel):
    """Multi-view contrastive model for time-series tensors.

    This model follows the multi-view contrastive learning setup used for
    time-domain, derivative, and frequency-domain views. Each view is projected,
    encoded with a Transformer encoder, fused with cross-view attention, and
    used for either contrastive pretraining or downstream classification.

    Args:
        dataset (SampleDataset): Dataset with ``xt``, ``xd``, and ``xf`` tensor
            inputs and one output label.
        training_stage (str): Training stage, either ``"pretrain"`` for
            contrastive representation learning or ``"finetune"`` for
            classification. Default is ``"pretrain"``.
        num_classes (int): Number of classes used by the classification head.
            Default is 3.
        **kwargs: Additional keyword arguments kept for PyHealth model API
            compatibility.

    Attributes:
        hidden_dim: Hidden dimension used by projections, encoders, and fusion.
        lambda_cl: Weight for the contrastive penalty during finetuning.
        tau: Temperature used by the NT-Xent contrastive loss.

    Raises:
        ValueError: If ``training_stage`` is not ``"pretrain"`` or
            ``"finetune"``.

    Examples:
        >>> from pyhealth.models import MultiViewContrastiveTimeSeriesModel
        >>> model = MultiViewContrastiveTimeSeriesModel(
        ...     dataset=sample_dataset,
        ...     training_stage="pretrain",
        ...     num_classes=5,
        ... )
        >>> output = model(xt=xt, xd=xd, xf=xf)
        >>> sorted(output.keys())
        ['loss', 'z_d', 'z_f', 'z_t']
    """

    def __init__(
        self,
        dataset: SampleDataset,
        training_stage: str = "pretrain",
        num_classes: int = 3,
        **kwargs: Any
    ):
        super().__init__(dataset=dataset)
        # Fail fast on a bad stage: previously an invalid value silently made
        # forward() return {}, surfacing much later as KeyError('loss').
        if training_stage not in {"pretrain", "finetune"}:
            raise ValueError(
                "training_stage must be 'pretrain' or 'finetune', "
                f"got {training_stage!r}"
            )
        self.hidden_dim = 128
        self.training_stage = training_stage
        self.lambda_cl = 0.1  # weight of the contrastive term when finetuning
        self.tau = 0.07  # NT-Xent temperature
        self.num_classes = num_classes

        # Per-view input projections: each scalar sample -> hidden_dim.
        self.temporal_projection = nn.Linear(1, self.hidden_dim)
        self.derivative_projection = nn.Linear(1, self.hidden_dim)
        self.frequency_projection = nn.Linear(1, self.hidden_dim)

        self.pos_encoder = PositionalEncoding(self.hidden_dim, dropout=0.1)

        def make_encoder() -> nn.TransformerEncoder:
            # One independent Transformer encoder per view.
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=self.hidden_dim, nhead=4, batch_first=True, dropout=0.2
            )
            return nn.TransformerEncoder(encoder_layer, num_layers=3)

        self.encoder_t: nn.TransformerEncoder = make_encoder()
        self.encoder_d: nn.TransformerEncoder = make_encoder()
        self.encoder_f: nn.TransformerEncoder = make_encoder()

        # Cross-view fusion: attention over the three encoded views.
        self.fusion_mha: nn.MultiheadAttention = nn.MultiheadAttention(
            embed_dim=self.hidden_dim, num_heads=4, batch_first=True
        )
        self.fusion_layer_norm: nn.LayerNorm = nn.LayerNorm(self.hidden_dim)

        def projector() -> nn.Sequential:
            # Maps [pre-fusion ; post-fusion] pooled features -> hidden_dim.
            return nn.Sequential(
                nn.Linear(self.hidden_dim * 2, self.hidden_dim),
                nn.LayerNorm(self.hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(self.hidden_dim, self.hidden_dim),
            )

        self.temporal_feature_projector: nn.Sequential = projector()
        self.derivative_feature_projector: nn.Sequential = projector()
        self.frequency_feature_projector: nn.Sequential = projector()

        # Classification head (used only in the "finetune" stage).
        self.classifier_mha = nn.MultiheadAttention(
            embed_dim=self.hidden_dim, num_heads=1, batch_first=True
        )
        self.classifier = nn.Linear(self.hidden_dim * 3, self.num_classes)

    def augment_time(self, x: torch.Tensor, std: float = 0.1) -> torch.Tensor:
        """Time-domain jitter augmentation: add Gaussian noise with the given std."""
        noise = torch.randn_like(x) * std
        return x + noise

    def augment_freq(self, sample: torch.Tensor, perturb_ratio: float = 0.05) -> torch.Tensor:
        """Frequency-domain augmentation (remove and add frequencies).

        Mirrors the TF-C augmentation: the masked (removed) spectrum and the
        perturbed (added) spectrum are summed to form the augmented view.
        """
        removed_frequency = self.remove_frequency(sample, perturb_ratio)
        added_frequency = self.add_frequency(sample, perturb_ratio)
        return removed_frequency + added_frequency

    def remove_frequency(self, x: torch.Tensor, perturb_ratio: float = 0.0) -> torch.Tensor:
        """Zero out roughly ``perturb_ratio`` of the frequency bins at random."""
        mask = torch.rand(x.shape, device=x.device) > perturb_ratio
        return x * mask

    def add_frequency(self, x: torch.Tensor, perturb_ratio: float = 0.0) -> torch.Tensor:
        """Add small random amplitudes to roughly ``perturb_ratio`` of the bins."""
        mask = torch.rand(x.shape, device=x.device) > (1 - perturb_ratio)
        max_amplitude = x.max()
        # Perturbations are capped at 10% of the global maximum amplitude.
        random_amplitude = torch.rand(mask.shape, device=x.device) * (max_amplitude * 0.1)
        perturbation = mask * random_amplitude
        return x + perturbation

    def ntxent_loss(self, zis: torch.Tensor, zjs: torch.Tensor, tau: float) -> torch.Tensor:
        """2N x 2N NTXentLoss aligned with the TFC implementation."""
        batch_size = zis.size(0)

        # Normalize the representations
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)

        # Concatenate into 2N
        representations = torch.cat([zjs, zis], dim=0)  # [2N, hidden_dim]

        # Compute 2Nx2N cosine similarity matrix
        similarity_matrix = torch.mm(representations, representations.T)

        # Extract the positive pairs (offset by batch_size)
        l_pos = torch.diag(similarity_matrix, batch_size)
        r_pos = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1)

        # Create a mask to remove self-similarity (the diagonal)
        mask = (~torch.eye(2 * batch_size, 2 * batch_size, dtype=torch.bool, device=zis.device))

        # Extract negatives (everything except the diagonal)
        negatives = similarity_matrix[mask].view(2 * batch_size, -1)

        # Concatenate logits: [positives, negatives]
        logits = torch.cat((positives, negatives), dim=1)
        logits /= tau

        # The positive sample is always at index 0 for each row
        labels = torch.zeros(2 * batch_size, dtype=torch.long, device=zis.device)

        # PyTorch CrossEntropy applies the log-softmax calculation
        loss = F.cross_entropy(logits, labels, reduction="sum")

        return loss / (2 * batch_size)

    def _forward_features(self, x_t: torch.Tensor, x_d: torch.Tensor, x_f: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode the three views and fuse them with cross-view attention.

        Args:
            x_t: Time-domain view, shape ``[N, L, 1]``.
            x_d: Derivative view, shape ``[N, L, 1]``.
            x_f: Frequency-domain view, shape ``[N, L, 1]``.

        Returns:
            Tuple ``(z_t, z_d, z_f)`` of per-view embeddings, each
            ``[N, hidden_dim]``.
        """
        x_t = self.temporal_projection(x_t)
        x_d = self.derivative_projection(x_d)
        x_f = self.frequency_projection(x_f)

        x_t = self.pos_encoder(x_t)
        x_d = self.pos_encoder(x_d)
        x_f = self.pos_encoder(x_f)

        h_t = self.encoder_t(x_t)
        h_d = self.encoder_d(x_d)
        h_f = self.encoder_f(x_f)

        # Stack views, then fold the view axis into the batch so one MHA call
        # attends within each (sample, view) sequence independently.
        batch_size, seq_length, _ = h_t.shape
        view_sequence = torch.stack([h_t, h_d, h_f], dim=2)
        flattened_views = (
            view_sequence
            .permute(0, 2, 1, 3)
            .contiguous()
            .view(batch_size * 3, seq_length, self.hidden_dim)
        )

        attention_output, _ = self.fusion_mha(
            flattened_views, flattened_views, flattened_views
        )
        # Residual connection + LayerNorm, as in standard Transformer blocks.
        fused_views = self.fusion_layer_norm(attention_output + flattened_views)

        fused_views = fused_views.view(batch_size, 3, seq_length, self.hidden_dim).permute(0, 2, 1, 3)
        h_t_star, h_d_star, h_f_star = (
            fused_views[:, :, 0, :],
            fused_views[:, :, 1, :],
            fused_views[:, :, 2, :],
        )

        # Pool across sequence length and concatenate with pre-interaction features
        h_t_pool = torch.cat([h_t.mean(dim=1), h_t_star.mean(dim=1)], dim=-1)
        h_d_pool = torch.cat([h_d.mean(dim=1), h_d_star.mean(dim=1)], dim=-1)
        h_f_pool = torch.cat([h_f.mean(dim=1), h_f_star.mean(dim=1)], dim=-1)

        z_t = self.temporal_feature_projector(h_t_pool)
        z_d = self.derivative_feature_projector(h_d_pool)
        z_f = self.frequency_feature_projector(h_f_pool)

        return z_t, z_d, z_f

    def forward(self, **kwargs) -> dict[str, torch.Tensor]:
        """Run a pretraining or finetuning step.

        Expects ``xt``, ``xd``, and ``xf`` tensors (or lists of tensors/arrays)
        in ``kwargs``; in the finetune stage the label key from the dataset
        schema must also be present.

        Returns:
            Pretrain stage: ``{"loss", "z_t", "z_d", "z_f"}``.
            Finetune stage: ``{"loss", "logit", "y_prob", "y_true"}``.

        Raises:
            ValueError: If ``self.training_stage`` is not a known stage.
        """
        temporal_tensor = self._prepare_tensor(kwargs.get("xt"))  # [N, L, 1]
        derivative_tensor = self._prepare_tensor(kwargs.get("xd"))  # [N, L, 1]
        frequency_tensor = self._prepare_tensor(kwargs.get("xf"))  # [N, L, 1]

        if self.training_stage == "pretrain":
            x_t_aug = self.augment_time(temporal_tensor)
            x_d_aug = self.augment_time(derivative_tensor)
            x_f_aug = self.augment_freq(frequency_tensor)

            z_t, z_d, z_f = self._forward_features(temporal_tensor, derivative_tensor, frequency_tensor)
            z_t_aug, z_d_aug, z_f_aug = self._forward_features(x_t_aug, x_d_aug, x_f_aug)

            # Per-view NT-Xent between clean and augmented embeddings.
            loss = self.ntxent_loss(z_t, z_t_aug, self.tau) + \
                self.ntxent_loss(z_d, z_d_aug, self.tau) + \
                self.ntxent_loss(z_f, z_f_aug, self.tau)

            # Dict strictly containing torch.Tensor
            return {
                "loss": loss,
                "z_t": z_t,
                "z_d": z_d,
                "z_f": z_f
            }

        elif self.training_stage == "finetune":
            z_t, z_d, z_f = self._forward_features(temporal_tensor, derivative_tensor, frequency_tensor)

            # Cross-view attention for classification
            stacked_emb = torch.stack([z_t, z_d, z_f], dim=1)  # [batch_size, 3, hidden_dim]
            attention_output, _ = self.classifier_mha(stacked_emb, stacked_emb, stacked_emb)
            emb = attention_output + stacked_emb  # Residual connection

            z_combined = emb.reshape(emb.size(0), -1)  # Flatten to [batch_size, 3 * hidden_dim]
            logits = self.classifier(z_combined)

            # Use PyHealth's automatic label parsing
            label_key = self.label_keys[0]
            y_true = cast(torch.Tensor, kwargs[label_key])
            y_true = y_true.to(logits.device)

            # Use PyHealth's automatic loss function mapping
            criterion = self.get_loss_function()
            # Cross entropy expects raw logits, not argmax class indices.
            loss_ce = criterion(logits, y_true)

            # Contrastive penalty during finetuning
            x_t_aug = self.augment_time(temporal_tensor)
            x_d_aug = self.augment_time(derivative_tensor)
            x_f_aug = self.augment_freq(frequency_tensor)

            z_t_aug, z_d_aug, z_f_aug = self._forward_features(x_t_aug, x_d_aug, x_f_aug)
            loss_cl = self.ntxent_loss(z_t, z_t_aug, self.tau) + \
                self.ntxent_loss(z_d, z_d_aug, self.tau) + \
                self.ntxent_loss(z_f, z_f_aug, self.tau)

            total_loss = (self.lambda_cl * loss_cl) + loss_ce

            # Return PyHealth's expected dictionary schema
            return {
                "loss": total_loss,
                "logit": logits,
                "y_prob": self.prepare_y_prob(logits),  # Autocast to prob
                "y_true": y_true
            }
        # Previously this silently returned {}, which produced a confusing
        # KeyError('loss') in the trainer. __init__ validates the stage, so
        # this is only reachable if training_stage is mutated afterwards.
        raise ValueError(
            f"Unknown training_stage: {self.training_stage!r}; "
            "expected 'pretrain' or 'finetune'."
        )

    def _prepare_tensor(self, x: Any) -> torch.Tensor:
        """Converts lists to batched tensors, enforces float32, and moves to device."""
        if isinstance(x, list):
            if isinstance(x[0], torch.Tensor):
                x = torch.stack(x)
            else:
                x = torch.from_numpy(np.stack(x))

        # Enforce standard float precision and push to GPU
        return x.float().to(self.device)
Loading