Codex/create training and inference scripts #2791
Open · ZZUZSL1024 wants to merge 6 commits into modelscope:main from ZZUZSL1024:codex/create-training-and-inference-scripts
Changes from all commits · 6 commits
- ec9a603 Update funasr_wss_server.py (ZZUZSL1024)
- 0f5b9cc Update funasr_wss_client.py (ZZUZSL1024)
- 38cff38 Add files via upload (ZZUZSL1024)
- d4b9d9d Update funasr_wss_server.py (ZZUZSL1024)
- 655e654 Update funasr_wss_client.py (ZZUZSL1024)
- 0c96d6c Add LoRA Paraformer README (ZZUZSL1024)
examples/industrial_data_pretraining/paraformer/README_LoRA_zh.md (+112, -0)
# Paraformer LoRA Fine-tuning Guide

This document explains how to fine-tune Paraformer with LoRA in FunASR, with complete examples for training, inference, and CER evaluation.

## 1. Prerequisites

1. Prepare `train.jsonl` and `val.jsonl` files that meet FunASR's format requirements (a sketch of the jsonl layout follows the command below).
2. Change to the repository root (example path):

```bash
cd /workspace/FunASR
```
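For reference, a minimal sketch of one possible jsonl layout — the field names (`key`, `source`, `target`) match what `lora_infer.py` reads, but the exact schema expected by FunASR's `AudioDataset`/`IndexDSJsonl` may require additional fields, and all paths below are placeholders:

```jsonl
{"key": "utt_0001", "source": "/workspace/FunASR/data/wav/utt_0001.wav", "target": "今天天气怎么样"}
{"key": "utt_0002", "source": "/workspace/FunASR/data/wav/utt_0002.wav", "target": "打开空调"}
```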
## 2. Training Configuration

LoRA configuration file:

```
examples/industrial_data_pretraining/paraformer/conf/paraformer_lora.yaml
```

Key fields:
- `model`: base model name or local model path.
- `lora_only`: whether to train only the LoRA parameters.
- `lora_bias`: LoRA bias training strategy (`none`/`all`/`lora_only`).
- `encoder_conf.lora_*` / `decoder_conf.lora_*`: LoRA hyperparameters (rank/alpha/dropout).
- `train_data_set_list`/`valid_data_set_list`: training/validation jsonl files.

To override any setting, pass `++key=value` on the command line, as in the sketch below.
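For example, to try a larger LoRA rank and a lower learning rate without editing the YAML — the values here are illustrative, and the `torchrun` invocation mirrors the one inside `lora_finetune.sh`:

```bash
torchrun --nproc_per_node 1 \
  funasr/bin/train_ds.py \
  --config-path examples/industrial_data_pretraining/paraformer/conf \
  --config-name paraformer_lora.yaml \
  ++encoder_conf.lora_rank=16 \
  ++decoder_conf.lora_rank=16 \
  ++optim_conf.lr=5e-5 \
  ++output_dir=examples/industrial_data_pretraining/paraformer/outputs_lora
```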
## 3. Training Script

Script:

```
examples/industrial_data_pretraining/paraformer/lora_finetune.sh
```

You only need to verify the data paths inside the script:

```bash
data_dir="${workspace}/data/list"
train_data="${data_dir}/train.jsonl"
val_data="${data_dir}/val.jsonl"
```

Run:

```bash
bash examples/industrial_data_pretraining/paraformer/lora_finetune.sh
```

Training logs and model outputs are saved to:

```
examples/industrial_data_pretraining/paraformer/outputs_lora
```
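The script redirects all of its output to `log.txt` inside that directory (see `log_file` in `lora_finetune.sh`), so training progress can be followed with:

```bash
tail -f examples/industrial_data_pretraining/paraformer/outputs_lora/log.txt
```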
## 4. Inference Script

The inference scripts read jsonl input and produce `text.hyp` / `text.ref`:

- Python script: `examples/industrial_data_pretraining/paraformer/lora_infer.py`
- Shell wrapper: `examples/industrial_data_pretraining/paraformer/lora_infer.sh`

After updating the paths in `lora_infer.sh`, run:

```bash
bash examples/industrial_data_pretraining/paraformer/lora_infer.sh
```

Default output directory:

```
examples/industrial_data_pretraining/paraformer/outputs_lora/infer
```
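Each line of `text.hyp` / `text.ref` has the form `<key> <text>` — this is exactly how `lora_infer.py` writes them; the keys and text below are placeholders:

```
utt_0001 <recognized text for utt_0001>
utt_0002 <recognized text for utt_0002>
```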
## 5. CER Evaluation

Evaluation script:

```
examples/industrial_data_pretraining/paraformer/lora_cer.sh
```

Run:

```bash
bash examples/industrial_data_pretraining/paraformer/lora_cer.sh
```

CER statistics are written to:

```
examples/industrial_data_pretraining/paraformer/outputs_lora/infer/text.cer
```
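The script is a thin wrapper around FunASR's WER/CER metric module; the underlying call (copied from `lora_cer.sh`) can also be pointed at arbitrary hyp/ref files:

```bash
python -m funasr.metrics.wer \
  ++ref_file=examples/industrial_data_pretraining/paraformer/outputs_lora/infer/text.ref \
  ++hyp_file=examples/industrial_data_pretraining/paraformer/outputs_lora/infer/text.hyp \
  ++cer_file=examples/industrial_data_pretraining/paraformer/outputs_lora/infer/text.cer \
  ++cn_postprocess=false
```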
## 6. FAQ

1. **Training does not converge or accuracy is poor**
   - Try adjusting `lora_rank`, `lora_alpha`, `lora_dropout`.
   - Adjust `optim_conf.lr` and `train_conf.max_epoch`.

2. **Inference fails because the config cannot be found**
   - Make sure `config.yaml` exists in the training output directory, and set the correct `config_path` and `config_name` in the inference script.

3. **Multi-GPU training**
   - Set `CUDA_VISIBLE_DEVICES`; the script computes `gpu_num` automatically, as sketched below.
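For example (GPU indices are illustrative):

```bash
# Four visible GPUs -> the script's awk call counts 4 comma-separated entries.
CUDA_VISIBLE_DEVICES=0,1,2,3 bash examples/industrial_data_pretraining/paraformer/lora_finetune.sh
```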
---

For further customization, edit `paraformer_lora.yaml` directly or override settings via command-line `++key=value` arguments.
examples/industrial_data_pretraining/paraformer/conf/paraformer_lora.yaml (+51, -0)
```yaml
# LoRA finetune config for Paraformer
# You can override data paths and hyper-parameters by command-line ++key=value.

# model hub name or local model dir
model: iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
model_revision: master

# LoRA settings
lora_only: true
lora_bias: none

encoder_conf:
  lora_list: ["q", "k", "v", "o"]
  lora_rank: 8
  lora_alpha: 16
  lora_dropout: 0.05

decoder_conf:
  lora_list: ["q", "k", "v", "o"]
  lora_rank: 8
  lora_alpha: 16
  lora_dropout: 0.05

# dataset
train_data_set_list: data/list/train.jsonl
valid_data_set_list: data/list/val.jsonl

dataset: AudioDataset
dataset_conf:
  index_ds: IndexDSJsonl
  data_split_num: 1
  batch_sampler: BatchSampler
  batch_size: 6000
  sort_size: 1024
  batch_type: token
  num_workers: 4

# training
train_conf:
  max_epoch: 30
  log_interval: 10
  resume: true
  validate_interval: 2000
  save_checkpoint_interval: 2000
  keep_nbest_models: 10
  avg_nbest_model: 5
  use_deepspeed: false

optim: adam
optim_conf:
  lr: 0.0001
```
examples/industrial_data_pretraining/paraformer/lora_cer.sh (+23, -0)
```bash
#!/usr/bin/env bash
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

set -euo pipefail

workspace=$(pwd)

infer_dir="${workspace}/examples/industrial_data_pretraining/paraformer/outputs_lora/infer"
ref_file="${infer_dir}/text.ref"
hyp_file="${infer_dir}/text.hyp"
cer_file="${infer_dir}/text.cer"

python -m funasr.metrics.wer \
    ++ref_file="${ref_file}" \
    ++hyp_file="${hyp_file}" \
    ++cer_file="${cer_file}" \
    ++cn_postprocess=false

# Show final CER summary
if [ -f "${cer_file}" ]; then
    tail -n 3 "${cer_file}"
fi
```
examples/industrial_data_pretraining/paraformer/lora_finetune.sh (+46, -0)
```bash
#!/usr/bin/env bash
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

set -euo pipefail

workspace=$(pwd)

# which gpu to train or finetune
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0"}
gpu_num=$(echo ${CUDA_VISIBLE_DEVICES} | awk -F "," '{print NF}')

# data dir, which contains train.jsonl/val.jsonl
# NOTE: update these paths to your dataset jsonl files.
data_dir="${workspace}/data/list"
train_data="${data_dir}/train.jsonl"
val_data="${data_dir}/val.jsonl"

# config
config_path="${workspace}/examples/industrial_data_pretraining/paraformer/conf"
config_name="paraformer_lora.yaml"

# exp output dir
output_dir="${workspace}/examples/industrial_data_pretraining/paraformer/outputs_lora"
log_file="${output_dir}/log.txt"

mkdir -p "${output_dir}"

DISTRIBUTED_ARGS="
    --nnodes ${WORLD_SIZE:-1} \
    --nproc_per_node ${gpu_num} \
    --node_rank ${RANK:-0} \
    --master_addr ${MASTER_ADDR:-127.0.0.1} \
    --master_port ${MASTER_PORT:-26669}
"

echo "log_file: ${log_file}"

torchrun ${DISTRIBUTED_ARGS} \
    funasr/bin/train_ds.py \
    --config-path "${config_path}" \
    --config-name "${config_name}" \
    ++train_data_set_list="${train_data}" \
    ++valid_data_set_list="${val_data}" \
    ++output_dir="${output_dir}" \
    &> "${log_file}"
```
examples/industrial_data_pretraining/paraformer/lora_infer.py (+83, -0)
```python
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-

import argparse
import json
import os
from typing import List, Tuple

from omegaconf import OmegaConf

from funasr import AutoModel


def load_jsonl(jsonl_path: str) -> Tuple[List[str], List[str]]:
    """Collect utterance keys and reference targets from a jsonl file."""
    keys = []
    targets = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            key = record.get("key")
            # Fall back to a nested source.key when the top-level key is absent.
            if key is None and isinstance(record.get("source"), dict):
                key = record["source"].get("key")
            keys.append(key or "")
            targets.append(record.get("target", ""))
    return keys, targets


def build_model(args: argparse.Namespace):
    """Build an AutoModel, merging the training config with CLI overrides."""
    kwargs = {}
    if args.config_path and args.config_name:
        cfg_path = os.path.join(args.config_path, args.config_name)
        cfg = OmegaConf.load(cfg_path)
        kwargs.update(OmegaConf.to_container(cfg, resolve=True))
    if args.model:
        kwargs["model"] = args.model
    if args.init_param:
        kwargs["init_param"] = args.init_param
    kwargs["device"] = args.device
    if args.batch_size:
        kwargs["batch_size"] = args.batch_size
    return AutoModel(**kwargs)


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default=None, help="model name or model dir")
    parser.add_argument("--config-path", type=str, default=None, help="config directory")
    parser.add_argument("--config-name", type=str, default=None, help="config filename")
    parser.add_argument("--init-param", type=str, default=None, help="model checkpoint path")
    parser.add_argument("--input-jsonl", type=str, required=True, help="input jsonl with source/target")
    parser.add_argument("--output-dir", type=str, required=True, help="output directory")
    parser.add_argument("--device", type=str, default="cuda:0", help="cuda:0 or cpu")
    parser.add_argument("--batch-size", type=int, default=1, help="batch size for inference")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    keys, targets = load_jsonl(args.input_jsonl)

    model = build_model(args)
    results = model.generate(input=args.input_jsonl, batch_size=args.batch_size)

    hyp_path = os.path.join(args.output_dir, "text.hyp")
    ref_path = os.path.join(args.output_dir, "text.ref")

    # Write hypotheses and references as "<key> <text>" lines, aligned by index.
    with open(hyp_path, "w", encoding="utf-8") as hyp_f, open(
        ref_path, "w", encoding="utf-8"
    ) as ref_f:
        for idx, result in enumerate(results):
            key = keys[idx] if idx < len(keys) else result.get("key", f"utt_{idx}")
            hyp = result.get("text", "")
            ref = targets[idx] if idx < len(targets) else ""
            hyp_f.write(f"{key} {hyp}\n")
            ref_f.write(f"{key} {ref}\n")

    print(f"hyp saved to: {hyp_path}")
    print(f"ref saved to: {ref_path}")


if __name__ == "__main__":
    main()
```
examples/industrial_data_pretraining/paraformer/lora_infer.sh (+32, -0)
```bash
#!/usr/bin/env bash
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

set -euo pipefail

workspace=$(pwd)

# model path and config (from training output)
model_dir="${workspace}/examples/industrial_data_pretraining/paraformer/outputs_lora"
init_param="${model_dir}/model.pt"
config_path="${model_dir}"
config_name="config.yaml"

# input jsonl (must contain source/target)
input_jsonl="${workspace}/data/list/val.jsonl"

# output directory
output_dir="${model_dir}/infer"

# device
device="cuda:0"

python ${workspace}/examples/industrial_data_pretraining/paraformer/lora_infer.py \
    --model "${model_dir}" \
    --config-path "${config_path}" \
    --config-name "${config_name}" \
    --init-param "${init_param}" \
    --input-jsonl "${input_jsonl}" \
    --output-dir "${output_dir}" \
    --device "${device}" \
    --batch-size 1
```
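If no GPU is available, the Python script can also be invoked directly with `--device cpu` — the flags below are the ones defined in `lora_infer.py`, and the paths are the wrapper's defaults (relative to the repository root):

```bash
python examples/industrial_data_pretraining/paraformer/lora_infer.py \
  --model examples/industrial_data_pretraining/paraformer/outputs_lora \
  --config-path examples/industrial_data_pretraining/paraformer/outputs_lora \
  --config-name config.yaml \
  --init-param examples/industrial_data_pretraining/paraformer/outputs_lora/model.pt \
  --input-jsonl data/list/val.jsonl \
  --output-dir examples/industrial_data_pretraining/paraformer/outputs_lora/infer \
  --device cpu
```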
Review comment on `lora_infer.py`: In the `load_jsonl` function, the key extraction logic could be clearer and more robust. When no key is found, the current code falls back to an empty string, which makes it hard to trace records downstream. To stay consistent with the output logic (line 72 of the file), consider generating a unique identifier when the key is missing.
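A minimal sketch of that suggestion (hypothetical, not part of this PR): fall back to an index-based identifier, matching the `utt_{idx}` convention already used in `main()`:

```python
import json
from typing import List, Tuple


def load_jsonl(jsonl_path: str) -> Tuple[List[str], List[str]]:
    """Variant of load_jsonl that never yields an empty key."""
    keys: List[str] = []
    targets: List[str] = []
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            if not line.strip():
                continue
            record = json.loads(line)
            key = record.get("key")
            if key is None and isinstance(record.get("source"), dict):
                key = record["source"].get("key")
            # Fall back to utt_<line-index> instead of "" so records stay traceable.
            keys.append(key or f"utt_{idx}")
            targets.append(record.get("target", ""))
    return keys, targets
```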