Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
078ecf9
T1: add PromptEHR source files, task stub, and examples
jalengg Mar 1, 2026
6a65142
T2: add PromptEHRGenerationMIMIC3 BaseTask with demographics
jalengg Mar 1, 2026
39ec171
T3+fix: Refactor PromptEHR to BaseModel; fix T2 early exit + T3 revie…
jalengg Mar 1, 2026
68f0ca3
T5: Add PromptEHR PyHealth 2.0 generation example
jalengg Mar 1, 2026
54a0836
T4: Add PromptEHR PyHealth 2.0 training example
jalengg Mar 1, 2026
9e9589a
T7: Update PromptEHR docstrings to Google/PyHealth style
jalengg Mar 1, 2026
ad0b27b
T8: Add PromptEHR integration tests (8 pass, 4 skip MIMIC-III)
jalengg Mar 1, 2026
a7be297
Add PromptEHR Colab notebook: demographic-conditioned synthetic EHR g…
jalengg Mar 2, 2026
b1cc36d
Fix: idempotent Drive mount in Colab notebook
jalengg Mar 4, 2026
5ab7596
Fix: guard scipy/mne-dependent task imports to fix Colab numpy 2.x ca…
jalengg Mar 4, 2026
4f7edb5
Feat: persist MIMIC-III files to Drive, skip re-upload on reconnect
jalengg Mar 4, 2026
394e128
Fix: guard PIL/mne-dependent dataset imports to fix Colab import cascade
jalengg Mar 4, 2026
8c176e9
Fix: force-reinstall PyHealth in setup cell; update preamble SHA
jalengg Mar 4, 2026
e77178a
Chore: update preamble SHA to 8c176e9c
jalengg Mar 4, 2026
6a4e1c8
Chore: switch preamble timestamp to UTC (no SHA lag)
jalengg Mar 4, 2026
6bb347a
Fix: guard optional-dep model imports to fix Colab sklearn/scipy cascade
jalengg Mar 4, 2026
2dad59c
Chore: update notebook timestamp 2026-03-04 08:21:17 UTC
jalengg Mar 4, 2026
517260f
Fix: guard EEGAbnormalTUAB/EEGEventsTUEV at source; rewrite preamble
jalengg Mar 4, 2026
c3618c7
Fix: install scipy>=1.14 AFTER PyHealth to prevent --force-reinstall …
jalengg Mar 4, 2026
2d4b5be
Fix: add Pillow>=10.4.0 to post-install step to fix mixed PIL state
jalengg Mar 4, 2026
b30c27b
Fix: guard ImageProcessor/TimeImageProcessor in processors/__init__.p…
jalengg Mar 4, 2026
2a50760
Chore: update notebook timestamp to 2026-03-04 09:55:13 (UTC)
jalengg Mar 4, 2026
5aa7800
Fix: Drive mount guard + makedirs ordering — files no longer re-uploa…
jalengg Mar 4, 2026
bf203a6
Fix: force-reinstall numpy+scipy post-PyHealth to clear mixed-version…
jalengg Mar 4, 2026
9ac2960
Chore: fix notebook timestamp to 2026-03-04 10:20:35 (UTC)
jalengg Mar 4, 2026
3a9c4cd
Fix: processors/__init__ __all__ + RuntimeError guard + numpy>=2.0.0
jalengg Mar 4, 2026
cbdd115
Chore: fix notebook timestamp to 2026-03-04 10:31:20 (UTC)
jalengg Mar 4, 2026
8872b7d
Fix: numpy version ceiling, force-reinstall cascade, Drive stale mount
jalengg Mar 4, 2026
364d6f6
Fix numpy mixed-version error: replace --force-reinstall with uninsta…
jalengg Mar 4, 2026
121999f
Remove step 2 dep upgrades causing Pillow mixed-version state
jalengg Mar 4, 2026
732d207
Fix Pillow mixed-version state: force-reinstall Pillow after PyHealth…
jalengg Mar 4, 2026
d05d445
Fix Colab PIL error: hide torchvision during BART import in PromptEHR
jalengg Mar 4, 2026
cd42f5b
Fix PyHealth 2.0 API: remove code_mapping kwarg, use unique_patient_ids
jalengg Mar 4, 2026
8c1120e
Remove icustays from MIMIC3Dataset defaults (same fix as HALO c52aa0b0)
jalengg Mar 4, 2026
cb0f6f9
Fix device mismatch in synthesize_dataset: use bart_model's device
jalengg Mar 4, 2026
f5a7d4c
Fix beam search crash in generate: set num_beams=1 explicitly
jalengg Mar 4, 2026
d591d0b
Fix decode_tokens: skip BOS/EOS/PAD instead of breaking on them
jalengg Mar 4, 2026
3a551c5
Cleanup: simplify PromptEHR notebook to match HALO structure
jalengg Mar 5, 2026
a967111
Cleanup: remove duplicate reference section and footer from notebook
jalengg Mar 5, 2026
7c1ac57
Fix: suppress wandb prompt and early_stopping warning during training
jalengg Mar 6, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions examples/generate_synthetic_mimic3_promptehr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""PromptEHR: Synthetic MIMIC-III Patient Generation.

Load a trained PromptEHR checkpoint and generate synthetic patients.

Reference:
Wang et al. "PromptEHR: Conditional Electronic Healthcare Records
Generation with Prompt Learning." EMNLP 2023.
https://arxiv.org/abs/2211.01761
"""

import json

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.models import PromptEHR
from pyhealth.tasks import promptehr_generation_mimic3_fn

MIMIC3_ROOT = "/srv/local/data/physionet.org/files/mimiciii/1.4"
CHECKPOINT_PATH = "./save/promptehr/checkpoint.pt"
OUTPUT_PATH = "./save/promptehr/synthetic_patients.json"
NUM_SAMPLES = 10_000

# 1. Load dataset + apply task (needed for processor/vocab reconstruction)
dataset = MIMIC3Dataset(
root=MIMIC3_ROOT,
tables=["patients", "admissions", "diagnoses_icd"],
code_mapping={},
)
sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)

# 2. Load checkpoint
model = PromptEHR(dataset=sample_dataset)
model.load_model(CHECKPOINT_PATH)
print(f"Loaded checkpoint from {CHECKPOINT_PATH}")

# 3. Generate
print(f"Generating {NUM_SAMPLES} synthetic patients...")
synthetic = model.synthesize_dataset(num_samples=NUM_SAMPLES)
print(f"Generated {len(synthetic)} patients")

# 4. Save
with open(OUTPUT_PATH, "w") as f:
json.dump(synthetic, f, indent=2)
print(f"Saved to {OUTPUT_PATH}")

# Summary stats
avg_visits = sum(len(p["visits"]) for p in synthetic) / len(synthetic)
print(f"Average visits per patient: {avg_visits:.2f}")
192 changes: 192 additions & 0 deletions examples/promptehr_mimic3_colab.ipynb

Large diffs are not rendered by default.

47 changes: 47 additions & 0 deletions examples/promptehr_mimic3_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""PromptEHR: Training on MIMIC-III.

Train PromptEHR for synthetic EHR generation using PyHealth 2.0 API.

Reference:
Wang et al. "PromptEHR: Conditional Electronic Health Records Generation
with Prompt Learning." CHIL 2023.
"""

from pyhealth.datasets import MIMIC3Dataset, split_by_patient
from pyhealth.models import PromptEHR
from pyhealth.tasks import promptehr_generation_mimic3_fn

MIMIC3_ROOT = "/srv/local/data/physionet.org/files/mimiciii/1.4"

# 1. Load MIMIC-III
dataset = MIMIC3Dataset(
root=MIMIC3_ROOT,
tables=["patients", "admissions", "diagnoses_icd"],
code_mapping={},
)

# 2. Apply generation task
sample_dataset = dataset.set_task(promptehr_generation_mimic3_fn)
print(f"Patients: {len(sample_dataset)}")
sample_dataset.stat()

# 3. Split
train, val, test = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])

# 4. Initialize model
model = PromptEHR(
dataset=sample_dataset,
n_num_features=1,
cat_cardinalities=[2],
d_hidden=128,
prompt_length=1,
epochs=20,
batch_size=16,
lr=1e-5,
warmup_steps=1000,
save_dir="./save/promptehr/",
)

# 5. Train
model.train_model(train, val)
print("Training complete. Checkpoint saved to ./save/promptehr/")
25 changes: 20 additions & 5 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,16 @@ def __init__(self, *args, **kwargs):

from .base_dataset import BaseDataset
from .cardiology import CardiologyDataset
from .chestxray14 import ChestXray14Dataset
try:
from .chestxray14 import ChestXray14Dataset
except ImportError:
pass # PIL/torchvision unavailable
from .clinvar import ClinVarDataset
from .cosmic import COSMICDataset
from .covid19_cxr import COVID19CXRDataset
try:
from .covid19_cxr import COVID19CXRDataset
except ImportError:
pass # PIL/torchvision unavailable
from .dreamt import DREAMTDataset
from .ehrshot import EHRShotDataset
from .eicu import eICUDataset
Expand All @@ -63,7 +69,10 @@ def __init__(self, *args, **kwargs):
from .omop import OMOPDataset
from .sample_dataset import SampleBuilder, SampleDataset, create_sample_dataset
from .shhs import SHHSDataset
from .sleepedf import SleepEDFDataset
try:
from .sleepedf import SleepEDFDataset
except ImportError:
pass # mne unavailable
from .bmd_hs import BMDHSDataset
from .support2 import Support2Dataset
from .tcga_prad import TCGAPRADDataset
Expand All @@ -76,8 +85,14 @@ def __init__(self, *args, **kwargs):
split_by_visit,
split_by_visit_conformal,
)
from .tuab import TUABDataset
from .tuev import TUEVDataset
try:
from .tuab import TUABDataset
except ImportError:
pass # mne unavailable; TUABDataset not registered
try:
from .tuev import TUEVDataset
except ImportError:
pass # mne unavailable; TUEVDataset not registered
from .utils import (
collate_fn_dict,
collate_fn_dict_with_padding,
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/datasets/mimic3.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
if config_path is None:
logger.info("No config path provided, using default config")
config_path = Path(__file__).parent / "configs" / "mimic3.yaml"
default_tables = ["patients", "admissions", "icustays"]
default_tables = ["patients", "admissions"]
tables = default_tables + tables
if "prescriptions" in tables:
warnings.warn(
Expand Down
5 changes: 4 additions & 1 deletion pyhealth/datasets/tuab.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

from typing import Optional
from .base_dataset import BaseDataset
from pyhealth.tasks import EEGAbnormalTUAB
try:
from pyhealth.tasks import EEGAbnormalTUAB
except ImportError:
EEGAbnormalTUAB = None # mne unavailable; TUABDataset.default_task will raise if called

logger = logging.getLogger(__name__)

Expand Down
5 changes: 4 additions & 1 deletion pyhealth/datasets/tuev.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@

from typing import Optional
from .base_dataset import BaseDataset
from pyhealth.tasks import EEGEventsTUEV
try:
from pyhealth.tasks import EEGEventsTUEV
except ImportError:
EEGEventsTUEV = None # mne unavailable; TUEVDataset.default_task will raise if called

logger = logging.getLogger(__name__)

Expand Down
76 changes: 56 additions & 20 deletions pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from .adacare import AdaCare, AdaCareLayer
from .agent import Agent, AgentLayer
from .base_model import BaseModel
from .biot import BIOT
from .cnn import CNN, CNNLayer
try:
from .biot import BIOT
except ImportError:
pass # einops unavailable
try:
from .cnn import CNN, CNNLayer
except ImportError:
pass # PIL/torchvision unavailable
from .concare import ConCare, ConCareLayer
from .contrawr import ContraWR, ResBlock2D
from .deepr import Deepr, DeeprLayer
Expand All @@ -12,33 +18,63 @@
from .logistic_regression import LogisticRegression
from .gan import GAN
from .gnn import GAT, GCN
from .graph_torchvision_model import Graph_TorchvisionModel
from .grasp import GRASP, GRASPLayer
try:
from .graph_torchvision_model import Graph_TorchvisionModel
except ImportError:
pass # torchvision unavailable
try:
from .grasp import GRASP, GRASPLayer
except ImportError:
pass # sklearn unavailable
from .medlink import MedLink
from .micron import MICRON, MICRONLayer
from .mlp import MLP
from .molerec import MoleRec, MoleRecLayer
try:
from .molerec import MoleRec, MoleRecLayer
except ImportError:
pass # rdkit unavailable
from .promptehr import PromptEHR
from .retain import RETAIN, RETAINLayer
from .rnn import MultimodalRNN, RNN, RNNLayer
from .safedrug import SafeDrug, SafeDrugLayer
try:
from .safedrug import SafeDrug, SafeDrugLayer
except ImportError:
pass # rdkit unavailable
from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer
from .stagenet import StageNet, StageNetLayer
from .stagenet_mha import StageAttentionNet, StageNetAttentionLayer
from .tcn import TCN, TCNLayer
from .tfm_tokenizer import (
TFMTokenizer,
TFM_VQVAE2_deep,
TFM_TOKEN_Classifier,
get_tfm_tokenizer_2x2x8,
get_tfm_token_classifier_64x4,
load_embedding_weights,
)
from .torchvision_model import TorchvisionModel
try:
from .tfm_tokenizer import (
TFMTokenizer,
TFM_VQVAE2_deep,
TFM_TOKEN_Classifier,
get_tfm_tokenizer_2x2x8,
get_tfm_token_classifier_64x4,
load_embedding_weights,
)
except ImportError:
pass # einops unavailable
try:
from .torchvision_model import TorchvisionModel
except ImportError:
pass # torchvision unavailable
from .transformer import Transformer, TransformerLayer
from .transformers_model import TransformersModel
try:
from .transformers_model import TransformersModel
except ImportError:
pass # transformers unavailable
from .ehrmamba import EHRMamba, MambaBlock
from .vae import VAE
from .vision_embedding import VisionEmbeddingModel
from .text_embedding import TextEmbedding
from .sdoh import SdohClassifier
from .medlink import MedLink
try:
from .vision_embedding import VisionEmbeddingModel
except ImportError:
pass # PIL/torchvision unavailable
try:
from .text_embedding import TextEmbedding
except ImportError:
pass # transformers unavailable
try:
from .sdoh import SdohClassifier
except ImportError:
pass # transformers/peft unavailable
41 changes: 41 additions & 0 deletions pyhealth/models/promptehr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""PromptEHR: Prompt-based BART model for synthetic EHR generation.

This module provides a demographic-conditioned sequence-to-sequence model
for generating realistic synthetic electronic health records.

Main components:
- PromptEHR: Main model class (inherits from BaseModel)
- ConditionalPromptEncoder: Demographic conditioning with reparameterization
- PromptBartEncoder: Modified BART encoder with prompt injection
- PromptBartDecoder: Modified BART decoder with prompt injection
- VisitStructureSampler: Utility for structure-constrained generation
- Generation functions: sample_demographics, parse_sequence_to_visits, etc.
"""

from .model import PromptEHR
from .conditional_prompt import ConditionalPromptEncoder
from .bart_encoder import PromptBartEncoder
from .bart_decoder import PromptBartDecoder
from .visit_sampler import VisitStructureSampler
from .generation import (
DemographicSampler,
sample_demographics,
decode_patient_demographics,
parse_sequence_to_visits,
generate_patient_sequence_conditional,
generate_patient_with_structure_constraints
)

__all__ = [
"PromptEHR",
"ConditionalPromptEncoder",
"PromptBartEncoder",
"PromptBartDecoder",
"VisitStructureSampler",
"DemographicSampler",
"sample_demographics",
"decode_patient_demographics",
"parse_sequence_to_visits",
"generate_patient_sequence_conditional",
"generate_patient_with_structure_constraints",
]
Loading