Simple Diffusion Language Modeling
dLLM is a library that unifies the training and evaluation of diffusion language models, bringing transparency and reproducibility to the entire development pipeline:
- dLLM provides scalable training pipelines (based on the `transformers` Trainer), with support for LoRA, DeepSpeed, FSDP, and beyond.
- dLLM provides unified evaluation pipelines (based on `lm-evaluation-harness`) that abstract away inference details and make customization simple.
- Built on these components, dLLM provides minimal pretraining / finetuning / evaluation recipes for open-weight models (e.g., LLaDA and Dream), as well as implementations of training algorithms such as MDLM (masked diffusion), BD3LM (block diffusion), and Edit Flows (a minimal sketch of the masked-diffusion objective follows this list).
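For readers new to masked diffusion, the sketch below shows the general shape of an MDLM-style training step: sample a noise level, mask tokens accordingly, predict the originals, and reweight the loss. It is a simplified illustration under generic assumptions, not dLLM's actual `MDLMTrainer`; every name in it is hypothetical.

```python
# Simplified, illustrative MDLM-style loss (not dLLM's implementation).
import torch
import torch.nn.functional as F

def mdlm_loss(model, input_ids, mask_token_id, eps=1e-3):
    """Mask each token with probability t ~ U(eps, 1), predict it back, reweight by 1/t."""
    b, l = input_ids.shape
    t = torch.rand(b, 1, device=input_ids.device) * (1 - eps) + eps   # noise level per sequence
    mask = torch.rand(b, l, device=input_ids.device) < t              # positions to corrupt
    noisy_ids = torch.where(mask, torch.full_like(input_ids, mask_token_id), input_ids)
    logits = model(input_ids=noisy_ids).logits                        # [b, l, vocab]
    ce = F.cross_entropy(logits.transpose(1, 2), input_ids, reduction="none")
    # Only masked positions contribute; the 1/t weight gives the usual ELBO-style estimator.
    return (ce * mask / t).sum() / mask.sum().clamp(min=1)
```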
[2025/12] Tiny-A2D: We released a collection of SOTA small (0.5B/0.6B) diffusion models adapted from AR models, with fully open recipes for converting ANY AR model (e.g., Qwen, LLaMA, and GPT-2) into a diffusion model. See examples/a2d for training / inference / evaluation instructions.
[2025/11] BERT-Chat: We released a collection of BERTs finetuned to chat with diffusion, with open recipes for turning ANY BERT encoder (e.g., BERT, RoBERTa, ModernBERT) into a diffusion model. See examples/bert for training / inference / evaluation instructions.
- examples/llada: Pretraining, finetuning, and evaluating LLaDA / LLaDA-MoE.
- examples/dream: Pretraining, finetuning, and evaluating Dream.
- examples/a2d: Finetuning any autoregressive model to generate text with masked diffusion / block diffusion.
- examples/bert: Finetuning any BERT into a lightweight chatbot.
- examples/editflow: Educational reference for training Edit Flows models, demonstrating how to extend existing DLLMs (e.g., LLaDA, Dream, BERT-Chat) with edit operations (insertion, deletion, and substitution) and how to pretrain or finetune Edit Flows models from scratch on public data.
- More upcoming.
# create and activate conda environment
conda create -n dllm python=3.10 -y
conda activate dllm
# install pytorch with CUDA 12.4 (other pytorch/cuda versions should also work)
conda install cuda=12.4 -c nvidia
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 \
--index-url https://download.pytorch.org/whl/cu124
# install dllm package
pip install -e .

# initialize `lm-evaluation-harness` submodule
git submodule update --init --recursive
# install submodule in editable mode with IFEval & Math dependencies
pip install -e "lm-evaluation-harness[ifeval,math]"For Slurm users, update scripts/train.slurm.sh for your cluster:
- #SBATCH --partition=mllm_safety # Note: adjust this for your cluster
- #SBATCH --quotatype=spot # Note: adjust this for your cluster
+ #SBATCH --partition=YOUR_PARTITION
+ #SBATCH --quotatype=YOUR_QUOTATYPE

Next, create a directory for your job logs:
mkdir logs

This folder will store the log files generated by your sbatch jobs.
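Optionally, verify the environment before launching any jobs. The snippet below is only an illustrative sanity check, not part of the official setup:

```python
# Illustrative sanity check: editable install and GPU visibility.
import torch
import transformers
import dllm  # should import cleanly after `pip install -e .`

print("dllm location:", dllm.__file__)
print("transformers version:", transformers.__version__)
print("CUDA available:", torch.cuda.is_available())
```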
# modules for training / sampling
dllm
├── core                # Core reusable modules shared across `dllm/pipelines`
│   ├── samplers
│   ├── schedulers
│   └── trainers
├── data
├── pipelines           # Application-specific training & inference pipelines
│   ├── bert
│   ├── dream
│   ├── editflow
│   └── llada
│       ├── models      # Model architecture and configs
│       ├── sampler.py  # Inference module
│       ├── trainer.py  # Training module
│       └── eval.py     # Evaluation module
├── tools
└── utils
# entry points for training / sampling
examples
├── bert
├── dream
├── editflow
└── llada
    ├── chat.py     # Interactive inference example
    ├── sample.py   # Inference example
    ├── pt.py       # Pretraining example
    ├── README.md   # Documentation (you are here)
    ├── sft.py      # Supervised finetuning example
    └── eval.sh     # Evaluation script
A typical training entry script (for example, examples/llada/sft.py) looks like this:
import transformers
import dllm
parser = transformers.HfArgumentParser(...)  # script-specific dataclasses omitted here
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# ----- Model ------------------------------------------------------------------
model = dllm.utils.get_model(model_args=model_args)
# ----- Tokenizer --------------------------------------------------------------
tokenizer = dllm.utils.get_tokenizer(model_args=model_args)
# ----- Dataset ----------------------------------------------------------------
dataset = "..."
# ----- Training --------------------------------------------------------------
trainer = dllm.core.trainers.MDLMTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
args=training_args,
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
label_pad_token_id=tokenizer.pad_token_id,
),
)
trainer.train()

You can launch the training job locally with accelerate, or submit it to a Slurm cluster using sbatch.
# Run locally (ZeRO-2 on 8 GPUs with 4bit quantization and LoRA)
accelerate launch \
--config_file scripts/accelerate_configs/zero2.yaml \
examples/llada/sft.py \
--num_train_epochs 4 \
--load_in_4bit True --lora True

# Submit to a Slurm cluster (FSDP on 1 node, 8 GPUs)
sbatch --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/llada/sft.py" \
--num_train_epochs 4
# Submit to a Slurm cluster (FSDP on 2 nodes, 16 GPUs)
sbatch --nodes=2 --gres=gpu:8 scripts/train.slurm.sh \
--accelerate_config "fsdp" \
--script_path "examples/llada/sft.py" \
--num_train_epochs 4

See Features for specific training recipes.
- Use a subset of data: `--dataset_args "allenai/tulu-3-sft-mixture[train:10000,test:1000]"`
- Concatenate datasets: `--dataset_args "allenai/tulu-3-sft-mixture+HuggingFaceTB/smoltalk"`
- Train with LoRA and 4-bit quantization: `--load_in_4bit True --lora True` (see the sketch after this list)
- Train with different distributed training methods: `--accelerate_config "ddp,zero-{1,2,3},fsdp"`
- Load the pretraining dataset in streaming mode: `--streaming True`
- Preprocess the SFT dataset before training (e.g., LLaDA):

# Preprocess SFT data
python dllm/tools/preprocess_sft_dataset.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --sft_map_fn_path "dllm.utils.default_mdlm_sft_map_fn" \
    --dataset_args "allenai/tulu-3-sft-mixture" \
    --output_dir "data/sft/llada/tulu-3-sft-mixture" \
    --num_proc 64

# SFT with preprocessed data
accelerate launch \
    --config_file scripts/accelerate_configs/fsdp.yaml \
    examples/llada/sft.py \
    --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
    --dataset_args "data/sft/llada/tulu-3-sft-mixture" \
    --load_preprocessed_data True \
    ...
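As a rough illustration of what a `--lora True` flag typically corresponds to, the sketch below wraps a model with LoRA adapters via the `peft` library. The rank, alpha, and target modules are hypothetical defaults for illustration, not dLLM's actual configuration:

```python
# Illustrative only: how a LoRA flag is typically realized with `peft`.
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,                  # low-rank adapter dimension (hypothetical default)
    lora_alpha=32,         # scaling factor
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(model, lora_config)  # only adapter weights remain trainable
model.print_trainable_parameters()
```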
We provide unified samplers that abstract away inference details.
A typical inference entry script (for example, examples/llada/sample.py) looks like this:
import dllm
# `script_args` is parsed from the command line in the full script (omitted here)
model = dllm.utils.get_model(model_args=script_args).eval()
tokenizer = dllm.utils.get_tokenizer(model_args=script_args)
sampler = dllm.core.samplers.MDLMSampler(model=model, tokenizer=tokenizer)
messages = [
[{"role": "user", "content": "Lily runs 12 km/h for 4 hours. How far in 8 hours?"}],
[{"role": "user", "content": "Please write an educational python function."}],
]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
)
outputs = sampler.sample(inputs, return_dict=True)
sequences = dllm.utils.decode_trim(tokenizer, outputs.sequences.tolist(), inputs)
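To give a sense of what the sampler abstracts away, here is a simplified sketch of a typical masked-diffusion decoding loop: start from an all-`[MASK]` completion and, at each step, commit the most confident predictions. It is illustrative only and does not reflect the exact logic of `MDLMSampler`:

```python
# Illustrative masked-diffusion decoding loop (not dLLM's MDLMSampler).
import torch

@torch.no_grad()
def masked_diffusion_sample(model, prompt_ids, mask_token_id, gen_len=128, steps=64):
    """Fill an all-[MASK] completion by unmasking the most confident tokens each step."""
    x = torch.cat(
        [prompt_ids, torch.full((1, gen_len), mask_token_id, device=prompt_ids.device)], dim=1
    )
    per_step = max(gen_len // steps, 1)
    for _ in range(steps):
        still_masked = x[0] == mask_token_id
        if not still_masked.any():
            break
        probs = model(input_ids=x).logits[0].softmax(-1)   # [seq_len, vocab]
        conf, pred = probs.max(-1)                          # confidence and argmax per position
        conf[~still_masked] = -1.0                          # only unmask masked positions
        idx = conf.topk(min(per_step, int(still_masked.sum()))).indices
        x[0, idx] = pred[idx]                               # commit the most confident tokens
    return x
```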
You can also try the interactive chat script (for example, examples/llada/chat.py) for visualized multi-turn dialogue:

python -u examples/llada/chat.py --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct"

Read the (optional) Evaluation setup before running evaluation.
For example, to evaluate LLaDA-8B-Instruct on MMLU_Pro, run:
accelerate launch --num_processes 4 \
dllm/pipelines/llada/eval.py \
--tasks "mmlu_pro" \
--model "llada" \
--apply_chat_template \
--num_fewshot 0 \
--model_args "pretrained=GSAI-ML/LLaDA-8B-Instruct,is_check_greedy=False,mc_num=1,max_new_tokens=256,steps=256,block_size=256,cfg=0.0"We also provide scripts to automatically evaluate LLaDA, Dream, and BERT-Chat on all benchmarks.
For example, you can launch examples/llada/eval.sh directly using the following commands:
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Instruct" --instruct True
bash examples/llada/eval.sh --model_name_or_path "GSAI-ML/LLaDA-8B-Base" --instruct False

@misc{dllm,
author = {Zhanhui Zhou and Lingjie Chen and Hanghang Tong and Dawn Song},
title = {dLLM: Simple Diffusion Language Modeling},
year = {2025},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/ZHZisZZ/dllm}},
}

