@@ -41,7 +41,6 @@ class MultitalkerTranscriptionConfig:
 
     # Required configs
     diar_model: Optional[str] = None  # Path to a .nemo file
-    diar_pretrained_name: Optional[str] = None  # Name of a pretrained model
     max_num_of_spks: Optional[int] = 4  # maximum number of speakers
     parallel_speaker_strategy: bool = True  # whether to use parallel speaker strategy
     masked_asr: bool = True  # whether to use masked ASR
@@ -73,7 +72,6 @@ class MultitalkerTranscriptionConfig:
 
     # ASR Configs
     asr_model: Optional[str] = None
-    device: str = 'cuda'
    audio_file: Optional[str] = None
    manifest_file: Optional[str] = None
    att_context_size: Optional[List[int]] = field(default_factory=lambda: [70, 13])
@@ -214,8 +212,8 @@ def main(cfg: MultitalkerTranscriptionConfig) -> Union[MultitalkerTranscriptionC
     if cfg.random_seed:
         pl.seed_everything(cfg.random_seed)
 
-    if cfg.diar_model is None and cfg.diar_pretrained_name is None:
-        raise ValueError("Both cfg.diar_model and cfg.pretrained_name cannot be None!")
+    if cfg.diar_model is None:
+        raise ValueError("cfg.diar_model cannot be None!")
     if cfg.audio_file is None and cfg.manifest_file is None:
         raise ValueError("Both cfg.audio_file and cfg.manifest_file cannot be None!")
 
@@ -246,8 +244,7 @@ def main(cfg: MultitalkerTranscriptionConfig) -> Union[MultitalkerTranscriptionC
     elif cfg.diar_model.endswith(".nemo"):
         diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.diar_model, map_location=map_location)
     else:
-
-        diar_model = SortformerEncLabelModel.from_pretrained(model_name=cfg.diar_model, map_location=map_location)
+        raise ValueError("cfg.diar_model must end with .ckpt or .nemo!")
     # Model setup for inference
     trainer = pl.Trainer(devices=device, accelerator=accelerator)
     diar_model.set_trainer(trainer)
@@ -297,7 +294,7 @@ def main(cfg: MultitalkerTranscriptionConfig) -> Union[MultitalkerTranscriptionC
     # Initialize to avoid "possibly used before assignment" error
     multispk_asr_streamer = None
 
-    asr_model = asr_model.to(cfg.device)
+    asr_model = asr_model.to(map_location)
     asr_model.eval()
 
     # chunk_size is set automatically for models trained for streaming.
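Review note: below is a minimal sketch of how the diarization-model loading path reads after this PR, reconstructed from the hunks above. The import path, the .ckpt branch (load_from_checkpoint), and the load_diar_model wrapper are assumptions drawn from context; only the .nemo branch and the new ValueError appear verbatim in the diff.

# Sketch only: reconstructs the post-PR loading logic from the hunks above.
# The wrapper function and the .ckpt branch are assumptions, not part of the diff.
from nemo.collections.asr.models import SortformerEncLabelModel

def load_diar_model(cfg, map_location: str = "cpu") -> SortformerEncLabelModel:
    # diar_model is now strictly required; the diar_pretrained_name fallback is gone
    if cfg.diar_model is None:
        raise ValueError("cfg.diar_model cannot be None!")
    if cfg.diar_model.endswith(".ckpt"):
        # assumed branch: Lightning-style checkpoint loading for .ckpt files
        diar_model = SortformerEncLabelModel.load_from_checkpoint(
            checkpoint_path=cfg.diar_model, map_location=map_location
        )
    elif cfg.diar_model.endswith(".nemo"):
        diar_model = SortformerEncLabelModel.restore_from(
            restore_path=cfg.diar_model, map_location=map_location
        )
    else:
        # from_pretrained(model_name=...) fallback removed: only file paths are accepted
        raise ValueError("cfg.diar_model must end with .ckpt or .nemo!")
    return diar_model

The same device handle (map_location) now also moves the ASR model, replacing the removed cfg.device field, so the diarizer and ASR model always land on the same device.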