Med-R1 is a reinforcement learning-enhanced vision-language model (VLM) designed for generalizable medical reasoning. Built on Qwen2-VL, Med-R1 uses Group Relative Policy Optimization (GRPO) to support 8 diverse imaging modalities and 5 core diagnostic tasks, achieving high performance with parameter efficiency.
> The Qwen2.5-VL checkpoint and training code are now publicly available on Hugging Face!

Med-R1 explores how reinforcement learning can improve medical reasoning in VLMs. Unlike traditional supervised fine-tuning (SFT), which may overfit to specific tasks, Med-R1 leverages reward-guided optimization to promote robust, diverse, and interpretable reasoning paths.
Med-R1 is optimized with GRPO, producing stable training behavior across imaging types and diagnostic tasks.
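GRPO scores each sampled answer relative to the other answers drawn for the same prompt, so no separate value model is needed. The snippet below is a minimal, illustrative sketch of that group-relative advantage computation; the function name and the simple correctness reward are assumptions for illustration, not the repo's exact implementation.

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Normalize each response's reward against the other responses for the same prompt.

    rewards: tensor of shape (num_prompts, num_generations); e.g. 1.0 when the
    model's <answer> letter matches the ground truth, 0.0 otherwise.
    """
    mean = rewards.mean(dim=1, keepdim=True)
    std = rewards.std(dim=1, keepdim=True)
    return (rewards - mean) / (std + eps)

# One prompt, four sampled answers, only the second one correct.
print(group_relative_advantages(torch.tensor([[0.0, 1.0, 0.0, 0.0]])))
```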
```bash
conda create -n med-r1 python=3.11
conda activate med-r1
bash setup.sh
```

> Note: If you encounter issues during setup, please ensure your environment aligns with `./src/requirements.txt`.
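Optionally, a quick sanity check (not part of the Med-R1 codebase) that the key dependencies import cleanly and a GPU is visible:

```python
# Quick environment sanity check; not part of the Med-R1 codebase.
import torch
import transformers

print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
print("CUDA available:", torch.cuda.is_available())
```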
Checkpoints are available for both the Qwen2-VL and Qwen2.5-VL backbones.
We provide cross-modality checkpoints, each trained on a specific imaging type:
- CT
- MRI
- X-Ray
- Fundus Photography (FP)
- Dermoscopy (Der)
- Microscopy (Micro)
- Optical Coherence Tomography (OCT)
- Ultrasound (US)
We also release cross-task checkpoints, each focusing on a key diagnostic function:
- Anatomy Identification (AI)
- Disease Diagnosis (DD)
- Lesion Grading (LG)
- Modality Recognition (MR)
- Other Biological Attributes (OBA)
All input images should be resized to 384×384 resolution (a resizing sketch follows the example below). Below is an example of the expected input JSON format:

```json
[
  {
    "image": "Images/Chest CT Scan/test/adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib/000139 (9).png",
    "problem": "What imaging technique is employed for obtaining this image? A)Mammogram, B)PET, C)CT, D)Fluoroscopy",
    "solution": "<answer> C </answer>"
  },
  ...
]
```
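The 384×384 requirement can be handled with a small preprocessing pass. The sketch below uses Pillow; `Images_384` is a hypothetical output directory, not a path from the repo.

```python
from pathlib import Path
from PIL import Image

SRC_DIR = Path("Images")      # original images, as referenced in the JSON
DST_DIR = Path("Images_384")  # hypothetical output directory for resized copies

for src in SRC_DIR.rglob("*.png"):
    dst = DST_DIR / src.relative_to(SRC_DIR)
    dst.parent.mkdir(parents=True, exist_ok=True)
    Image.open(src).convert("RGB").resize((384, 384)).save(dst)
```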
To launch GRPO training on a single modality (CT in this example) with two GPUs, run:

```bash
torchrun --nproc_per_node=2 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr="127.0.0.1" \
    --master_port=12345 \
    src/open_r1/grpo_vqa_nothink.py \
    --output_dir ./output/Modality_CT \
    --model_name_or_path ./checkpoints/Qwen2.5-VL-3B-Instruct \
    --dataset_name ./data/VQA/CT_384 \
    --max_prompt_length 1024 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 2 \
    --logging_steps 1 \
    --bf16 \
    --report_to wandb \
    --gradient_checkpointing false \
    --attn_implementation flash_attention_2 \
    --max_pixels 401408 \
    --num_train_epochs 2 \
    --run_name Qwen2.5-VL-3B-GRPO-CT \
    --save_steps 200 \
    --save_only_model true \
    --num_generations 4
```
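For intuition, the reward that drives GRPO on this multiple-choice data can be as simple as checking whether the letter inside the generated `<answer>` tag matches the ground truth. The sketch below is illustrative only and is not the exact reward implemented in `src/open_r1/grpo_vqa_nothink.py`; `extract_choice` and `reward` are helpers defined here for clarity.

```python
import re

def extract_choice(text: str) -> str | None:
    """Return the option letter inside the first <answer> ... </answer> tag, if any."""
    m = re.search(r"<answer>\s*([A-Z])\s*</answer>", text)
    return m.group(1) if m else None

def reward(completion: str, solution: str) -> float:
    """1.0 for a correctly formatted, correct answer; 0.0 otherwise."""
    pred, gold = extract_choice(completion), extract_choice(solution)
    return float(pred is not None and pred == gold)

# Example using the JSON sample above, whose solution is "<answer> C </answer>".
print(reward("<think> The image shows lung parenchyma ... </think> <answer> C </answer>",
             "<answer> C </answer>"))  # 1.0
```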
Load a Med-R1 checkpoint for inference:

```python
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

MODEL_PATH = "..."  # path to a Med-R1 checkpoint

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
```

Med-R1 generates chain-of-thought (CoT) responses for medical visual queries:
```python
import json

from tqdm import tqdm
from qwen_vl_utils import process_vision_info

PROMPT_PATH = "..."  # path to the input JSON described above
BSZ = 8              # inference batch size; adjust to available GPU memory
all_outputs = []     # decoded answers, in input order

with open(PROMPT_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> and final choice (A, B, C, D ...) in <answer> </answer> tags."

# Build one chat message per sample, pairing the image with the templated question.
messages = []
for i in data:
    message = [{
        "role": "user",
        "content": [
            {"type": "image", "image": f"file://{i['image']}"},
            {"type": "text", "text": QUESTION_TEMPLATE.format(Question=i['problem'])}
        ]
    }]
    messages.append(message)

# Run batched greedy decoding and keep only the newly generated tokens.
for i in tqdm(range(0, len(messages), BSZ)):
    batch = messages[i:i + BSZ]
    text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch]
    image_inputs, video_inputs = process_vision_info(batch)
    inputs = processor(text=text, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")
    inputs = inputs.to("cuda")
    outputs = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
    trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, outputs)]
    decoded = processor.batch_decode(trimmed, skip_special_tokens=True)
    all_outputs.extend(decoded)
```

Evaluation is then carried out across all imaging modalities and diagnostic tasks.
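For a rough accuracy number, the generated outputs can be scored by comparing the letter inside each predicted `<answer>` tag with the ground-truth `solution` field. This is an illustrative sketch, not the repo's evaluation script; `choice` is a helper defined here.

```python
import re

def choice(text: str) -> str | None:
    """Pull the option letter out of an <answer> ... </answer> tag."""
    m = re.search(r"<answer>\s*([A-Z])\s*</answer>", text)
    return m.group(1) if m else None

# `all_outputs` and `data` come from the inference loop above.
correct = sum(choice(o) == choice(d["solution"]) for o, d in zip(all_outputs, data))
print(f"Accuracy over {len(data)} questions: {correct / len(data):.2%}")
```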
We thank the authors of OmniMedVQA and R1-V for their open-source contributions.
- R1-V GitHub
- OmniMedVQA GitHub
If you find Med-R1 useful, please cite:

```bibtex
@article{lai2025med,
  title={Med-R1: Reinforcement Learning for Generalizable Medical Reasoning in Vision-Language Models},
  author={Lai, Yuxiang and Zhong, Jike and Li, Ming and Zhao, Shitian and Yang, Xiaofeng},
  journal={arXiv preprint arXiv:2503.13939},
  year={2025}
}
```



