Skip to content

Commit e06a3d7

Browse files
authored
Add Sortformer CUDA export and Linux/Windows CUDA CI coverage (pytorch#17865)
## Summary This PR adds CUDA coverage for Sortformer in both Linux and Windows CI, and updates the Sortformer example/export path so CUDA artifacts are exportable and runnable end-to-end. ## What Changed - Added Sortformer to CUDA export/e2e matrices in: - `.github/workflows/cuda.yml` (Linux CUDA) - `.github/workflows/cuda-windows.yml` (Windows CUDA runtime, Linux export) - Extended CI export/test scripts for Sortformer: - `.ci/scripts/export_model_artifact.sh` - Added `nvidia/diar_streaming_sortformer_4spk-v2` support - Added Sortformer-specific export path - Enforced non-quantized Sortformer export - `.ci/scripts/test_model_e2e.sh` - Added Sortformer model routing, test audio download, and runner invocation - `.ci/scripts/test_model_e2e_windows.ps1` - Added Sortformer runner path/args and expected-output validation - Enabled Sortformer CUDA build targets: - `examples/models/sortformer/CMakePresets.json` - Added `sortformer-cuda` configure/build/workflow presets - `Makefile` - Added `sortformer-cuda` target and help text - Updated Sortformer runner to accept CUDA named-data blob: - `examples/models/sortformer/main.cpp` - Added `--data_path` - `examples/models/sortformer/sortformer_runner.h/.cpp` - Added constructor overload/path handling for optional `.ptd` - Updated Sortformer exporter for CUDA backends: - `examples/models/sortformer/export_sortformer.py` - Added backend choices: `cuda`, `cuda-windows` - Added CUDA/CUDA-Windows lowering path - Writes external tensor data via `write_tensor_data_to_file(output_dir)` - Verifies `aoti_cuda_blob.ptd` exists in output dir - Added explicit print for blob write location ## Validation - `python -m py_compile examples/models/sortformer/export_sortformer.py` - CI coverage is now wired for: - Linux CUDA export + e2e Sortformer - Windows CUDA e2e Sortformer (using exported artifact)
1 parent 4c20ef1 commit e06a3d7

12 files changed

Lines changed: 245 additions & 17 deletions

File tree

.ci/scripts/export_model_artifact.sh

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Arguments:
2222
- mistralai/Voxtral-Mini-4B-Realtime-2602
2323
- openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo})
2424
- google/gemma-3-4b-it
25+
- nvidia/diar_streaming_sortformer_4spk-v2
2526
- nvidia/parakeet-tdt
2627
2728
quant_name Quantization type (optional, default: non-quantized)
@@ -45,6 +46,7 @@ Examples:
4546
export_model_artifact.sh metal "mistralai/Voxtral-Mini-4B-Realtime-2602" "quantized-int4-metal"
4647
export_model_artifact.sh metal "mistralai/Voxtral-Mini-4B-Realtime-2602" "non-quantized" "." "vr-streaming"
4748
export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed"
49+
export_model_artifact.sh cuda-windows "nvidia/diar_streaming_sortformer_4spk-v2" "non-quantized" "./output"
4850
export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output"
4951
export_model_artifact.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./output"
5052
export_model_artifact.sh xnnpack "nvidia/parakeet-tdt" "quantized-8da4w" "./output"
@@ -157,6 +159,14 @@ case "$HF_MODEL" in
157159
PREPROCESSOR_FEATURE_SIZE=""
158160
PREPROCESSOR_OUTPUT=""
159161
;;
162+
nvidia/diar_streaming_sortformer_4spk-v2)
163+
MODEL_NAME="sortformer"
164+
TASK=""
165+
MAX_SEQ_LEN=""
166+
EXTRA_PIP=""
167+
PREPROCESSOR_FEATURE_SIZE=""
168+
PREPROCESSOR_OUTPUT=""
169+
;;
160170
mistralai/Voxtral-Mini-4B-Realtime-2602)
161171
MODEL_NAME="voxtral_realtime"
162172
TASK=""
@@ -167,7 +177,7 @@ case "$HF_MODEL" in
167177
;;
168178
*)
169179
echo "Error: Unsupported model '$HF_MODEL'"
170-
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt"
180+
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt"
171181
exit 1
172182
;;
173183
esac
@@ -247,6 +257,42 @@ if [ "$MODEL_NAME" = "parakeet" ]; then
247257
exit 0
248258
fi
249259

260+
# Sortformer uses a custom export script
261+
if [ "$MODEL_NAME" = "sortformer" ]; then
262+
if [ "$QUANT_NAME" != "non-quantized" ]; then
263+
echo "Error: Sortformer currently supports only non-quantized export"
264+
exit 1
265+
fi
266+
267+
pip install -r examples/models/sortformer/install_requirements.txt
268+
269+
SORTFORMER_BACKEND="$DEVICE"
270+
if [ "$DEVICE" = "cuda-windows" ]; then
271+
SORTFORMER_BACKEND="cuda-windows"
272+
elif [ "$DEVICE" = "cuda" ]; then
273+
SORTFORMER_BACKEND="cuda"
274+
elif [ "$DEVICE" = "xnnpack" ]; then
275+
SORTFORMER_BACKEND="xnnpack"
276+
else
277+
SORTFORMER_BACKEND="portable"
278+
fi
279+
280+
python -m executorch.examples.models.sortformer.export_sortformer \
281+
--hf-model "${HF_MODEL}" \
282+
--backend "${SORTFORMER_BACKEND}" \
283+
--output-dir "${OUTPUT_DIR}"
284+
285+
test -f "${OUTPUT_DIR}/sortformer.pte"
286+
mv "${OUTPUT_DIR}/sortformer.pte" "${OUTPUT_DIR}/model.pte"
287+
# CUDA saves named data to separate .ptd file, XNNPACK/portable do not.
288+
if [ "$DEVICE" = "cuda" ] || [ "$DEVICE" = "cuda-windows" ]; then
289+
test -f "${OUTPUT_DIR}/aoti_cuda_blob.ptd"
290+
fi
291+
ls -al "${OUTPUT_DIR}"
292+
echo "::endgroup::"
293+
exit 0
294+
fi
295+
250296
# Voxtral Realtime uses a custom export script
251297
if [ "$MODEL_NAME" = "voxtral_realtime" ]; then
252298
pip install safetensors huggingface_hub

.ci/scripts/test_model_e2e.sh

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Arguments:
1919
hf_model HuggingFace model ID (required)
2020
Supported models:
2121
- mistralai/Voxtral-Mini-3B-2507
22+
- nvidia/diar_streaming_sortformer_4spk-v2
2223
- openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo})
2324
- google/gemma-3-4b-it
2425
- Qwen/Qwen3-0.6B
@@ -44,6 +45,7 @@ Arguments:
4445
Examples:
4546
test_model_e2e.sh metal "openai/whisper-small" "non-quantized"
4647
test_model_e2e.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" "./model_output"
48+
test_model_e2e.sh cuda "nvidia/diar_streaming_sortformer_4spk-v2" "non-quantized" "./model_output"
4749
test_model_e2e.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./model_output"
4850
test_model_e2e.sh xnnpack "nvidia/parakeet-tdt" "quantized-8da4w" "./model_output"
4951
test_model_e2e.sh metal "mistralai/Voxtral-Mini-4B-Realtime-2602" "non-quantized" "." "vr-streaming"
@@ -176,6 +178,18 @@ case "$HF_MODEL" in
176178
AUDIO_FILE="test_audio.wav"
177179
IMAGE_PATH=""
178180
;;
181+
nvidia/diar_streaming_sortformer_4spk-v2)
182+
MODEL_NAME="sortformer"
183+
RUNNER_TARGET="sortformer_runner"
184+
RUNNER_PATH="sortformer"
185+
EXPECTED_OUTPUT="Speaker 1"
186+
PREPROCESSOR=""
187+
TOKENIZER_URL=""
188+
TOKENIZER_FILE=""
189+
AUDIO_URL="https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav"
190+
AUDIO_FILE="poem.wav"
191+
IMAGE_PATH=""
192+
;;
179193
mistralai/Voxtral-Mini-4B-Realtime-2602)
180194
MODEL_NAME="voxtral_realtime"
181195
RUNNER_TARGET="voxtral_realtime_runner"
@@ -190,7 +204,7 @@ case "$HF_MODEL" in
190204
;;
191205
*)
192206
echo "Error: Unsupported model '$HF_MODEL'"
193-
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt"
207+
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt"
194208
exit 1
195209
;;
196210
esac
@@ -203,8 +217,8 @@ echo "::endgroup::"
203217
echo "::group::Prepare $MODEL_NAME Artifacts"
204218

205219

206-
# Download tokenizer files (skip for parakeet and voxtral_realtime which bundle tokenizer in export)
207-
if [ "$MODEL_NAME" != "parakeet" ] && [ "$MODEL_NAME" != "voxtral_realtime" ]; then
220+
# Download tokenizer files (skip for models that bundle tokenizer in export or do not use one)
221+
if [ "$MODEL_NAME" != "parakeet" ] && [ "$MODEL_NAME" != "voxtral_realtime" ] && [ "$MODEL_NAME" != "sortformer" ]; then
208222
if [ "$TOKENIZER_FILE" != "" ]; then
209223
curl -L $TOKENIZER_URL/$TOKENIZER_FILE -o $MODEL_DIR/$TOKENIZER_FILE
210224
else
@@ -296,6 +310,12 @@ EOF
296310
RUNNER_ARGS="$RUNNER_ARGS --data_path ${MODEL_DIR}/aoti_cuda_blob.ptd"
297311
fi
298312
;;
313+
sortformer)
314+
RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --audio_path ${MODEL_DIR}/$AUDIO_FILE"
315+
if [ "$DEVICE" = "cuda" ]; then
316+
RUNNER_ARGS="$RUNNER_ARGS --data_path ${MODEL_DIR}/aoti_cuda_blob.ptd"
317+
fi
318+
;;
299319
voxtral_realtime)
300320
RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0"
301321
# Add CUDA data path if present

.ci/scripts/test_model_e2e_windows.ps1

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,17 @@ switch ($HfModel) {
6464
$audioUrl = "https://dldata-public.s3.us-east-2.amazonaws.com/2086-149220-0033.wav"
6565
$audioFile = "test_audio.wav"
6666
}
67+
"nvidia/diar_streaming_sortformer_4spk-v2" {
68+
$runnerTarget = "sortformer_runner"
69+
$runnerPath = "sortformer"
70+
$runnerPreset = "sortformer-cuda"
71+
$expectedOutput = "Speaker 1"
72+
$preprocessor = ""
73+
$tokenizerUrl = ""
74+
$tokenizerFile = ""
75+
$audioUrl = "https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav"
76+
$audioFile = "poem.wav"
77+
}
6778
"mistralai/Voxtral-Mini-4B-Realtime-2602" {
6879
$runnerTarget = "voxtral_realtime_runner"
6980
$runnerPath = "voxtral_realtime"
@@ -76,7 +87,7 @@ switch ($HfModel) {
7687
$audioFile = "poem.wav"
7788
}
7889
default {
79-
throw "Unsupported model '$HfModel'. Supported: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/parakeet-tdt"
90+
throw "Unsupported model '$HfModel'. Supported: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt"
8091
}
8192
}
8293

@@ -182,6 +193,13 @@ try {
182193
"--data_path", $cudaBlob
183194
)
184195
}
196+
"nvidia/diar_streaming_sortformer_4spk-v2" {
197+
$runnerArgs = @(
198+
"--model_path", $modelPte,
199+
"--audio_path", (Join-Path -Path $resolvedModelDir -ChildPath $audioFile),
200+
"--data_path", $cudaBlob
201+
)
202+
}
185203
"mistralai/Voxtral-Mini-4B-Realtime-2602" {
186204
$runnerArgs += @(
187205
"--temperature", "0",

.github/workflows/cuda-windows.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ jobs:
4141
- model_repo: "nvidia"
4242
model_name: "parakeet-tdt"
4343
quant: "quantized-int4-weight-only"
44+
- model_repo: "nvidia"
45+
model_name: "diar_streaming_sortformer_4spk-v2"
46+
quant: "non-quantized"
4447
- model_repo: "mistralai"
4548
model_name: "Voxtral-Mini-4B-Realtime-2602"
4649
quant: "quantized-int4-tile-packed"
@@ -113,6 +116,9 @@ jobs:
113116
- model_repo: "nvidia"
114117
model_name: "parakeet-tdt"
115118
quant: "quantized-int4-weight-only"
119+
- model_repo: "nvidia"
120+
model_name: "diar_streaming_sortformer_4spk-v2"
121+
quant: "non-quantized"
116122
- model_repo: "mistralai"
117123
model_name: "Voxtral-Mini-4B-Realtime-2602"
118124
quant: "quantized-int4-tile-packed"

.github/workflows/cuda.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ jobs:
139139
name: "Voxtral-Mini-3B-2507"
140140
- repo: "mistralai"
141141
name: "Voxtral-Mini-4B-Realtime-2602"
142+
- repo: "nvidia"
143+
name: "diar_streaming_sortformer_4spk-v2"
142144
- repo: "openai"
143145
name: "whisper-small"
144146
- repo: "openai"
@@ -168,6 +170,15 @@ jobs:
168170
repo: "mistralai"
169171
name: "Voxtral-Mini-4B-Realtime-2602"
170172
quant: "quantized-int4-weight-only"
173+
# Sortformer currently supports only non-quantized export
174+
- model:
175+
repo: "nvidia"
176+
name: "diar_streaming_sortformer_4spk-v2"
177+
quant: "quantized-int4-tile-packed"
178+
- model:
179+
repo: "nvidia"
180+
name: "diar_streaming_sortformer_4spk-v2"
181+
quant: "quantized-int4-weight-only"
171182
with:
172183
timeout: 90
173184
secrets-env: EXECUTORCH_HF_TOKEN
@@ -214,6 +225,8 @@ jobs:
214225
name: "Voxtral-Mini-3B-2507"
215226
- repo: "mistralai"
216227
name: "Voxtral-Mini-4B-Realtime-2602"
228+
- repo: "nvidia"
229+
name: "diar_streaming_sortformer_4spk-v2"
217230
- repo: "openai"
218231
name: "whisper-small"
219232
- repo: "openai"
@@ -241,6 +254,15 @@ jobs:
241254
repo: "mistralai"
242255
name: "Voxtral-Mini-4B-Realtime-2602"
243256
quant: "quantized-int4-weight-only"
257+
# Sortformer currently supports only non-quantized export
258+
- model:
259+
repo: "nvidia"
260+
name: "diar_streaming_sortformer_4spk-v2"
261+
quant: "quantized-int4-tile-packed"
262+
- model:
263+
repo: "nvidia"
264+
name: "diar_streaming_sortformer_4spk-v2"
265+
quant: "quantized-int4-weight-only"
244266
with:
245267
timeout: 90
246268
runner: linux.g5.4xlarge.nvidia.gpu

Makefile

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
# - voxtral_realtime: Realtime speech-to-text model (CPU, CUDA, Metal)
1919
# - whisper: Speech recognition model (CPU, CUDA, Metal)
2020
# - parakeet: Speech recognition model (CPU, CUDA, Metal)
21-
# - sortformer: Speaker diarization model (CPU)
21+
# - sortformer: Speaker diarization model (CPU, CUDA)
2222
# - silero_vad: Voice activity detection model (CPU)
2323
# - llama: Text generation model (CPU)
2424
# - llava: Vision + language model (CPU)
@@ -91,7 +91,7 @@
9191
#
9292
# ==============================================================================
9393

94-
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
94+
.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help
9595

9696
help:
9797
@echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`. Available targets:"
@@ -109,6 +109,7 @@ help:
109109
@echo " parakeet-cuda-debug - Build Parakeet runner with CUDA backend (debug mode)"
110110
@echo " parakeet-cpu - Build Parakeet runner with CPU backend"
111111
@echo " parakeet-metal - Build Parakeet runner with Metal backend (macOS only)"
112+
@echo " sortformer-cuda - Build Sortformer runner with CUDA backend"
112113
@echo " sortformer-cpu - Build Sortformer runner with CPU backend"
113114
@echo " silero-vad-cpu - Build Silero VAD runner with CPU backend"
114115
@echo " llama-cuda - Build Llama runner with CUDA backend"
@@ -218,6 +219,15 @@ parakeet-metal:
218219
@echo "✓ Build complete!"
219220
@echo " Binary: cmake-out/examples/models/parakeet/parakeet_runner"
220221

222+
sortformer-cuda:
223+
@echo "==> Building and installing ExecuTorch with CUDA..."
224+
cmake --workflow --preset llm-release-cuda
225+
@echo "==> Building Sortformer runner with CUDA..."
226+
cd examples/models/sortformer && cmake --workflow --preset sortformer-cuda
227+
@echo ""
228+
@echo "✓ Build complete!"
229+
@echo " Binary: cmake-out/examples/models/sortformer/sortformer_runner"
230+
221231
sortformer-cpu:
222232
@echo "==> Building and installing ExecuTorch..."
223233
cmake --workflow --preset llm-release

examples/models/sortformer/CMakePresets.json

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,19 @@
1515
"name": "sortformer-cpu",
1616
"displayName": "Sortformer runner (CPU)",
1717
"inherits": ["sortformer-base"]
18+
},
19+
{
20+
"name": "sortformer-cuda",
21+
"displayName": "Sortformer runner (CUDA)",
22+
"inherits": ["sortformer-base"],
23+
"cacheVariables": {
24+
"EXECUTORCH_BUILD_CUDA": "ON"
25+
},
26+
"condition": {
27+
"type": "inList",
28+
"string": "${hostSystemName}",
29+
"list": ["Linux", "Windows"]
30+
}
1831
}
1932
],
2033
"buildPresets": [
@@ -23,6 +36,12 @@
2336
"displayName": "Build Sortformer runner (CPU)",
2437
"configurePreset": "sortformer-cpu",
2538
"targets": ["sortformer_runner"]
39+
},
40+
{
41+
"name": "sortformer-cuda",
42+
"displayName": "Build Sortformer runner (CUDA)",
43+
"configurePreset": "sortformer-cuda",
44+
"targets": ["sortformer_runner"]
2645
}
2746
],
2847
"workflowPresets": [
@@ -39,6 +58,20 @@
3958
"name": "sortformer-cpu"
4059
}
4160
]
61+
},
62+
{
63+
"name": "sortformer-cuda",
64+
"displayName": "Configure and build Sortformer runner (CUDA)",
65+
"steps": [
66+
{
67+
"type": "configure",
68+
"name": "sortformer-cuda"
69+
},
70+
{
71+
"type": "build",
72+
"name": "sortformer-cuda"
73+
}
74+
]
4275
}
4376
]
4477
}

0 commit comments

Comments
 (0)