Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions tools/convert/seko_talk_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
Model Merge and Multi-Precision Conversion Script

This script supports four conversion modes:
1. 'both' (default): Convert both R2V model and audio adapter
2. 'r2v': Only convert R2V model (R2V + distill via LoRA)
1. 'both' (default): Convert both R2V model and audio adapter (R2V + distill via LoRA)
2. 'r2v': Only convert R2V model (single model → FP32 → BF16/FP8, no distill merge)
3. 'audio': Only convert audio adapter
4. 'merged': Only merge R2V + distill (no precision conversion)

Pipeline:
- R2V model: R2V + distill via LoRA → merged.safetensors (FP32) → BF16/FP8
- R2V (both/merged): R2V + distill via LoRA → merged.safetensors (FP32) → BF16/FP8
- R2V (r2v only): R2V model → merged.safetensors (FP32) → BF16/FP8
- Audio adapter: (optional: + LoRA) → audio_adapter.pt → BF16 → FP8

Usage Examples:
Expand All @@ -19,11 +20,10 @@
--audio_adapter /path/to/audio_adapter.pt \
--output_dir /data/output

# Only convert R2V model
# Only convert R2V model (no distill merge)
python tools/convert/seko_talk_converter.py \
--mode r2v \
--r2v_model /path/to/model.pt \
--distill_model /path/to/model_ema.pt \
--output_dir /data/output

# Only convert audio adapter
Expand Down Expand Up @@ -64,6 +64,7 @@
"""

import argparse
import shutil
import subprocess
import sys
from pathlib import Path
Expand Down Expand Up @@ -277,9 +278,12 @@ def step5_convert_audio_adapter_to_fp8(output_dir: Path):

def validate_args(args):
"""验证参数并返回路径对象"""
if args.mode in ["both", "r2v", "merged"]:
if args.mode in ["both", "merged"]:
if not args.r2v_model or not args.distill_model:
raise ValueError("--r2v_model and --distill_model are required for 'both', 'r2v', and 'merged' modes")
raise ValueError("--r2v_model and --distill_model are required for 'both' and 'merged' modes")
if args.mode == "r2v":
if not args.r2v_model:
raise ValueError("--r2v_model is required for 'r2v' mode")

if args.mode in ["both", "audio"]:
if not args.audio_adapter:
Expand Down Expand Up @@ -337,8 +341,8 @@ def main():
default="both",
help="Conversion mode: 'both' (default), 'r2v' (only R2V model), 'audio' (only audio adapter), or 'merged' (only merge R2V+distill)",
)
parser.add_argument("--r2v_model", type=str, help="Path to R2V model (.pt) [required for 'both' and 'r2v' modes]")
parser.add_argument("--distill_model", type=str, help="Path to distillation model (.pt) [required for 'both' and 'r2v' modes]")
parser.add_argument("--r2v_model", type=str, help="Path to R2V model (.pt) [required for 'both', 'r2v', 'merged' modes]")
parser.add_argument("--distill_model", type=str, help="Path to distillation model (.pt) [required for 'both' and 'merged' modes]")
parser.add_argument("--audio_adapter", type=str, help="Path to audio adapter (.pt) [required for 'both' and 'audio' modes]")
parser.add_argument("--audio_lora", type=str, help="Path to audio LoRA (.pt/.safetensors) [optional, for merging with audio adapter]")
parser.add_argument("--audio_lora_alpha", type=float, default=8.0, help="Alpha for audio LoRA merge (default: 8.0)")
Expand Down Expand Up @@ -368,15 +372,22 @@ def main():
paths["r2v"] = backward_convert_model(paths["r2v"], output_dir, "model")
if paths["distill"]:
paths["distill"] = backward_convert_model(paths["distill"], output_dir, "model_ema")

if args.mode in ["both", "audio"] and paths["audio"]:
logger.info("\n>>> BACKWARD CONVERSION: Audio Adapter")
paths["audio"] = backward_convert_model(paths["audio"], output_dir, "audio_adapter")

# Process R2V model
if args.mode in ["both", "r2v", "merged"]:
logger.info("\n>>> Processing R2V MODEL")
merged_path = step1_merge_via_lora(paths["r2v"], paths["distill"], output_dir, args.lora_alpha, temp_dir)
if args.mode == "r2v":
merged_path = output_dir / "merged.safetensors"
if paths["r2v"].suffix in [".pt", ".pth"]:
checkpoint_to_safetensors(paths["r2v"], merged_path, "STEP 1: Convert R2V model to safetensors (FP32)")
else:
shutil.copy(paths["r2v"], merged_path)
logger.info(f" ✓ Copied R2V model to: {merged_path}")
Comment on lines +384 to +388
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic implicitly assumes that any r2v_model file without a .pt or .pth extension is a .safetensors file. This is not robust: a file in an unsupported format would be copied to merged.safetensors unchanged and would likely only fail later, in the BF16/FP8 conversion steps, with a confusing error. It is better to check explicitly for the .safetensors extension and raise an error for any other format, which makes the script more robust and gives the user a clearer error message. Additionally, the copy should be skipped when the source and destination are the same path, because shutil.copy raises SameFileError when asked to copy a file onto itself.

Suggested change
if paths["r2v"].suffix in [".pt", ".pth"]:
checkpoint_to_safetensors(paths["r2v"], merged_path, "STEP 1: Convert R2V model to safetensors (FP32)")
else:
shutil.copy(paths["r2v"], merged_path)
logger.info(f" ✓ Copied R2V model to: {merged_path}")
if paths["r2v"].suffix in [".pt", ".pth"]:
checkpoint_to_safetensors(paths["r2v"], merged_path, "STEP 1: Convert R2V model to safetensors (FP32)")
elif paths["r2v"].suffix == ".safetensors":
if paths["r2v"] != merged_path:
shutil.copy(paths["r2v"], merged_path)
logger.info(f" ✓ Copied R2V model to: {merged_path}")
else:
logger.info(f" ✓ R2V model is already at the target location: {merged_path}")
else:
raise ValueError(f"Unsupported file format for r2v_model: {paths['r2v'].suffix}. Must be .pt, .pth, or .safetensors")

else:
merged_path = step1_merge_via_lora(paths["r2v"], paths["distill"], output_dir, args.lora_alpha, temp_dir)

if args.mode != "merged":
step2_convert_merged_to_bf16(merged_path, output_dir)
Expand Down