diff --git a/.gitignore b/.gitignore index 1ff2a92cac64..4bc9a2cacb4f 100644 --- a/.gitignore +++ b/.gitignore @@ -179,3 +179,4 @@ examples/neural_graphs/*.yml .hydra/ nemo_experiments/ +balu_codes/test_experiments/ \ No newline at end of file diff --git a/balu_codes/configs/c1.yaml b/balu_codes/configs/c1.yaml new file mode 100644 index 000000000000..827ee6269274 --- /dev/null +++ b/balu_codes/configs/c1.yaml @@ -0,0 +1,286 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large +labelled_manifest: false +exp_dir: /tmp/bld56_dataset_v1/tmp/ +wandb: + run_name: au_pdec_uman_stok + project: NEMO_TEST + create_wandb_logger: false + log_model: false +use_video_modality: false +use_pretrained_dec: true +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_train_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: true + num_workers: 10 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 + min_duration: 0.1 + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + bucketing_strategy: synced_randomized + override_snr_ratio: 0.5 + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.5 + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.5 + use_start_end_token: false 
+tokenizer: + dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/vocab.txt + spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.vocab +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 +av_encoder: + d_model: 768 + nhead: 8 + num_layers: 4 + dropout: 0.1 +v_model: + feat_dim: 512 +encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 512 + num_classes: 128 + vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + 
- ▁we + - ▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false +adapters: + linear_adapter: + keep: false + name: AV_v1 + dim: 64 + activation: swish + norm_position: pre + dropout: 0.1 + multi_head_attention_adapter: + keep: false + rel_position_multi_head_attention_adapter: + keep: false +variational_noise: + start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/configs/c10.yaml b/balu_codes/configs/c10.yaml new file mode 100644 index 000000000000..38db1de93c87 --- /dev/null +++ b/balu_codes/configs/c10.yaml @@ -0,0 +1,289 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large +labelled_manifest: false +exp_dir: /tmp/bld56_dataset_v1/tmp/ +wandb: + run_name: au_ndec_uman_ntok + project: NEMO_TEST + create_wandb_logger: true + log_model: false +use_video_modality: true +use_pretrained_dec: false +label_pred_head: + keep: false + num_classes: 44 +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_train_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: true + sample_rate: 16000 + batch_size: 96 + shuffle: true + 
num_workers: 10 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 + min_duration: 0.1 + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + bucketing_strategy: synced_randomized + override_snr_ratio: 0.5 + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: true + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.5 + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_test_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: true + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.5 + use_start_end_token: false +tokenizer: + dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/vocab.txt + spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.vocab +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 +av_encoder: + d_model: 512 + nhead: 8 + num_layers: 4 + dropout: 0.1 +v_model: 
+ feat_dim: 768 +encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 512 + num_classes: 356 + vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + - ▁we + - ▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false +adapters: + linear_adapter: + keep: true + name: AV_v1 + dim: 64 + activation: swish + norm_position: pre + dropout: 0.1 + multi_head_attention_adapter: + keep: false + rel_position_multi_head_attention_adapter: + 
keep: false +variational_noise: + start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/configs/c11.yaml b/balu_codes/configs/c11.yaml new file mode 100644 index 000000000000..c548ec98e27e --- /dev/null +++ b/balu_codes/configs/c11.yaml @@ -0,0 +1,289 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large +labelled_manifest: false +exp_dir: /tmp/bld56_dataset_v1/tmp/ +wandb: + run_name: au_ndec_lman_ntok + project: NEMO_TEST + create_wandb_logger: true + log_model: false +use_video_modality: true +use_pretrained_dec: false +label_pred_head: + keep: true + num_classes: 44 +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_train_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: true + sample_rate: 16000 + batch_size: 96 + shuffle: true + num_workers: 10 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 + min_duration: 0.1 + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + bucketing_strategy: synced_randomized + override_snr_ratio: 0.5 + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: true + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.5 + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_test_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: true + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.5 + use_start_end_token: false +tokenizer: + dir: 
/home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/vocab.txt + spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.vocab +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 +av_encoder: + d_model: 512 + nhead: 8 + num_layers: 4 + dropout: 0.1 +v_model: + feat_dim: 768 +encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 512 + num_classes: 356 + vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + - ▁we + - 
▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false +adapters: + linear_adapter: + keep: true + name: AV_v1 + dim: 64 + activation: swish + norm_position: pre + dropout: 0.1 + multi_head_attention_adapter: + keep: false + rel_position_multi_head_attention_adapter: + keep: false +variational_noise: + start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/configs/c2.yaml b/balu_codes/configs/c2.yaml new file mode 100644 index 000000000000..6a91158d2c2c --- /dev/null +++ b/balu_codes/configs/c2.yaml @@ -0,0 +1,286 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large +labelled_manifest: true +exp_dir: /tmp/bld56_dataset_v1/tmp/ +wandb: + run_name: au_ndec_lman_ntok + project: NEMO_TEST + create_wandb_logger: true + log_model: false +use_video_modality: false +use_pretrained_dec: false +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_train.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: true + num_workers: 10 + pin_memory: true + use_start_end_token: false + 
trim_silence: false + max_duration: 20.0 + min_duration: 0.1 + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + bucketing_strategy: synced_randomized + override_snr_ratio: 0.6 + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.6 + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.6 + use_start_end_token: false +tokenizer: + dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/vocab.txt + spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.vocab +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 +av_encoder: + d_model: 512 + nhead: 8 + num_layers: 4 + dropout: 0.1 +v_model: + feat_dim: 512 +encoder: + _target_: 
nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 512 + num_classes: 356 + vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + - ▁we + - ▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false +adapters: + linear_adapter: + keep: true + name: AV_v1 + dim: 64 + activation: swish + norm_position: pre + dropout: 0.1 + multi_head_attention_adapter: + keep: false + rel_position_multi_head_attention_adapter: + keep: false +variational_noise: + 
start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/configs/c3.yaml b/balu_codes/configs/c3.yaml new file mode 100644 index 000000000000..0cfb64391b68 --- /dev/null +++ b/balu_codes/configs/c3.yaml @@ -0,0 +1,286 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large +labelled_manifest: true +exp_dir: /tmp/bld56_dataset_v1/tmp/ +wandb: + run_name: au_ndec_uman_stok + project: NEMO_TEST + create_wandb_logger: true + log_model: false +use_video_modality: false +use_pretrained_dec: false +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_train_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: true + num_workers: 10 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 + min_duration: 0.1 + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + bucketing_strategy: synced_randomized + override_snr_ratio: 0.7 + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.7 + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.7 + use_start_end_token: false +tokenizer: + dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/ + type: bpe + 
model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/vocab.txt + spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.vocab +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 +av_encoder: + d_model: 768 + nhead: 8 + num_layers: 4 + dropout: 0.1 +v_model: + feat_dim: 512 +encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 512 + num_classes: 128 + vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + - ▁we + - ▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce 
+ - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false +adapters: + linear_adapter: + keep: true + name: AV_v1 + dim: 64 + activation: swish + norm_position: pre + dropout: 0.1 + multi_head_attention_adapter: + keep: false + rel_position_multi_head_attention_adapter: + keep: false +variational_noise: + start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/configs/c3_au_with_same_av_arch.yaml b/balu_codes/configs/c3_au_with_same_av_arch.yaml new file mode 100644 index 000000000000..7fdecab3df03 --- /dev/null +++ b/balu_codes/configs/c3_au_with_same_av_arch.yaml @@ -0,0 +1,309 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large # CHANGE, BPE: is a must since, it is used to load audio encoder. +labelled_manifest: true # CHANGE +exp_dir: /tmp/bld56_dataset_v1/tmp/ # CHANGE +wandb: + run_name: "au_ndec_lman_ntok_NArch_0.5" # CHANGE + project: "NEMO_TEST" # CHANGE + create_wandb_logger: true # CHANGE + log_model: False # CHANGE + +use_video_modality: true # CHANGE +use_pretrained_dec: false # CHANGE +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it1_70/annotations/manifest_train.json # CHANGE + video_frame_rate: 5 # CHANGE + get_vid_feats: true # CHANGE, always keep it to true. 
+ get_zero_vid_feats: true # CHANGE + sample_rate: 16000 + batch_size: 96 # CHANGE + shuffle: true + num_workers: 8 # CHANGE + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 # CHANGE + min_duration: 0.1 + is_tarred: false # CHANGE + tarred_audio_filepaths: null # CHANGE + shuffle_n: 2048 + bucketing_strategy: synced_randomized + override_snr_ratio: 0.5 # CHANGE if float, then considers as snr, if None then goes by manifest. + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it1_70/annotations/manifest_eval.json # CHANGE + video_frame_rate: 5 # CHANGE + get_vid_feats: true # CHANGE, always keep it to true. + get_zero_vid_feats: true # CHANGE + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 8 + pin_memory: true + override_snr_ratio: 0.5 # CHANGE if float, then considers as snr, if None then goes by manifest. + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it1_70/annotations/manifest_eval.json # CHANGE + video_frame_rate: 5 # CHANGE + get_vid_feats: true # CHANGE, always keep it to true. + get_zero_vid_feats: true # CHANGE + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 8 + pin_memory: true + override_snr_ratio: 0.5 # CHANGE if float, then considers as snr, if None then goes by manifest. 
+ use_start_end_token: false + +# NEW TOKENIZER +tokenizer: # CHANGE + dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/vocab.txt + spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.vocab + + +# OLD TOKENIZER +# tokenizer: # CHANGE # CHANGE THE NUM CLASSES TO 128 in DEC +# dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/ +# type: bpe +# model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.model +# vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/vocab.txt +# spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.vocab + +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 + +av_encoder: # CHANGE + d_model: 512 + nhead: 8 + num_layers: 4 + dropout: 0.1 + +v_model: # CHANGE + feat_dim: 512 + + +encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + 
xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 512 + num_classes: 356 # CHANGE to 356 for new tok, else 128. + vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + - ▁we + - ▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false # CHANGE + +adapters: # CHANGE + linear_adapter: + keep: true + name: "AV_v1" #@param {type:"string"} + dim: 64 #@param {type:"integer"} + activation: "swish" #@param {type:"string"} + norm_position: "pre" #@param ["pre", "post"] + dropout: 0.1 #@param {type:"number"} + multi_head_attention_adapter: + keep: false # TODO @Balu: Needs deeper understanding of config. 
ref: tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb + rel_position_multi_head_attention_adapter: + keep: false # TODO @Balu: Needs deeper understanding of config. ref: tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb + + + + + +variational_noise: + start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/configs/c4.yaml b/balu_codes/configs/c4.yaml new file mode 100644 index 000000000000..f82a69fd51a9 --- /dev/null +++ b/balu_codes/configs/c4.yaml @@ -0,0 +1,309 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large # CHANGE, BPE: is a must since, it is used to load audio encoder. +labelled_manifest: false # CHANGE +exp_dir: /tmp/bld56_dataset_v1/tmp/ # CHANGE +wandb: + run_name: "av_ndec_uman_stok" # CHANGE + project: "NEMO_TEST" # CHANGE + create_wandb_logger: true # CHANGE + log_model: False # CHANGE + +use_video_modality: true # CHANGE +use_pretrained_dec: false # CHANGE +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it1_70/annotations/manifest_train_no_label.json # CHANGE + video_frame_rate: 5 # CHANGE + get_vid_feats: true # CHANGE, always keep it to true. + get_zero_vid_feats: false # CHANGE + sample_rate: 16000 + batch_size: 96 # CHANGE + shuffle: true + num_workers: 8 # CHANGE + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 # CHANGE + min_duration: 0.1 + is_tarred: false # CHANGE + tarred_audio_filepaths: null # CHANGE + shuffle_n: 2048 + bucketing_strategy: synced_randomized + override_snr_ratio: 0.5 # CHANGE if float, then considers as snr, if None then goes by manifest. 
+ bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it1_70/annotations/manifest_eval_no_label.json # CHANGE + video_frame_rate: 5 # CHANGE + get_vid_feats: true # CHANGE, always keep it to true. + get_zero_vid_feats: false # CHANGE + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 8 + pin_memory: true + override_snr_ratio: 0.5 # CHANGE: if a float, it is considered as the SNR; if None, it goes by the manifest. + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it1_70/annotations/manifest_eval_no_label.json # CHANGE + video_frame_rate: 5 # CHANGE + get_vid_feats: true # CHANGE, always keep it to true. + get_zero_vid_feats: false # CHANGE + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 8 + pin_memory: true + override_snr_ratio: 0.5 # CHANGE: if a float, it is considered as the SNR; if None, it goes by the manifest. + use_start_end_token: false + +# NEW TOKENIZER +# tokenizer: # CHANGE +# dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/ +# type: bpe +# model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.model +# vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/vocab.txt +# spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.vocab + + +# OLD TOKENIZER +tokenizer: # CHANGE # CHANGE THE NUM CLASSES TO 128 in DEC + dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/vocab.txt + spe_tokenizer_vocab:
/home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.vocab + +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 + +av_encoder: # CHANGE + d_model: 512 + nhead: 8 + num_layers: 4 + dropout: 0.1 + +v_model: # CHANGE + feat_dim: 512 + + +encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 512 + num_classes: 128 # CHANGE to 356 for new tok, else 128. 
+ vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + - ▁we + - ▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false # CHANGE + +adapters: # CHANGE + linear_adapter: + keep: true + name: "AV_v1" #@param {type:"string"} + dim: 64 #@param {type:"integer"} + activation: "swish" #@param {type:"string"} + norm_position: "pre" #@param ["pre", "post"] + dropout: 0.1 #@param {type:"number"} + multi_head_attention_adapter: + keep: false # TODO @Balu: Needs deeper understanding of config. ref: tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb + rel_position_multi_head_attention_adapter: + keep: false # TODO @Balu: Needs deeper understanding of config. 
ref: tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb + + + + + +variational_noise: + start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/configs/c5.yaml b/balu_codes/configs/c5.yaml new file mode 100644 index 000000000000..ace661ec92dd --- /dev/null +++ b/balu_codes/configs/c5.yaml @@ -0,0 +1,289 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large +labelled_manifest: false +exp_dir: /tmp/bld56_dataset_v1/tmp/ +wandb: + run_name: av_ndec_lman_ntok + project: NEMO_TEST + create_wandb_logger: true + log_model: false +use_video_modality: true +use_pretrained_dec: false +label_pred_head: + keep: true + num_classes: 44 +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_train_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: true + num_workers: 10 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 + min_duration: 0.1 + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + bucketing_strategy: synced_randomized + override_snr_ratio: 0.5 + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.5 + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.5 + 
use_start_end_token: false +tokenizer: + dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/vocab.txt + spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.vocab +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 +av_encoder: + d_model: 512 + nhead: 8 + num_layers: 4 + dropout: 0.1 +v_model: + feat_dim: 768 +encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 512 + num_classes: 356 + vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - 
ll + - es + - ▁it + - ve + - ur + - ▁we + - ▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false +adapters: + linear_adapter: + keep: true + name: AV_v1 + dim: 64 + activation: swish + norm_position: pre + dropout: 0.1 + multi_head_attention_adapter: + keep: false + rel_position_multi_head_attention_adapter: + keep: false +variational_noise: + start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/configs/c6.yaml b/balu_codes/configs/c6.yaml new file mode 100644 index 000000000000..1f0fa02fa62d --- /dev/null +++ b/balu_codes/configs/c6.yaml @@ -0,0 +1,286 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large +labelled_manifest: false +exp_dir: /tmp/bld56_dataset_v1/tmp/ +wandb: + run_name: av_ndec_uman_stok + project: NEMO_TEST + create_wandb_logger: true + log_model: false +use_video_modality: true +use_pretrained_dec: false +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_train_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: true + num_workers: 10 + 
pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 + min_duration: 0.1 + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + bucketing_strategy: synced_randomizedp + override_snr_ratio: 0.7 + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.7 + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.7 + use_start_end_token: false +tokenizer: + dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/vocab.txt + spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.vocab +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 +av_encoder: + d_model: 768 + nhead: 8 + num_layers: 4 + dropout: 0.1 +v_model: + feat_dim: 768 
+encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 768 + num_classes: 128 + vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + - ▁we + - ▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false +adapters: + linear_adapter: + keep: true + name: AV_v1 + dim: 64 + activation: swish + norm_position: pre + dropout: 0.1 + multi_head_attention_adapter: + keep: false + rel_position_multi_head_attention_adapter: + keep: false 
+variational_noise: + start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/configs/c7.yaml b/balu_codes/configs/c7.yaml new file mode 100644 index 000000000000..5590558bf4a5 --- /dev/null +++ b/balu_codes/configs/c7.yaml @@ -0,0 +1,286 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large +labelled_manifest: false +exp_dir: /tmp/bld56_dataset_v1/tmp/ +wandb: + run_name: pre_av_ndec_uman_ntok + project: NEMO_TEST + create_wandb_logger: true + log_model: false +use_video_modality: true +use_pretrained_dec: false +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/pretraining_train_manifest.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: true + sample_rate: 16000 + batch_size: 96 + shuffle: true + num_workers: 10 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 + min_duration: 0.1 + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + bucketing_strategy: synced_randomizedp + override_snr_ratio: 0.0 + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/pretraining_eval_manifest.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: true + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.0 + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/pretraining_eval_manifest.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: true + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.0 + use_start_end_token: false +tokenizer: + dir: 
/home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/vocab.txt + spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.vocab +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 +av_encoder: + d_model: 512 + nhead: 8 + num_layers: 4 + dropout: 0.1 +v_model: + feat_dim: 768 +encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 512 + num_classes: 356 + vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + - ▁we + - 
▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false +adapters: + linear_adapter: + keep: true + name: AV_v1 + dim: 64 + activation: swish + norm_position: pre + dropout: 0.1 + multi_head_attention_adapter: + keep: false + rel_position_multi_head_attention_adapter: + keep: false +variational_noise: + start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/configs/c8.yaml b/balu_codes/configs/c8.yaml new file mode 100644 index 000000000000..bee33a706210 --- /dev/null +++ b/balu_codes/configs/c8.yaml @@ -0,0 +1,286 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large +labelled_manifest: true +exp_dir: /tmp/bld56_dataset_v1/tmp/ +wandb: + run_name: av_ndec_lman_stok_fullau + project: NEMO_TEST + create_wandb_logger: true + log_model: false +use_video_modality: true +use_pretrained_dec: false +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_train_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: true + num_workers: 10 + pin_memory: true + use_start_end_token: 
false + trim_silence: false + max_duration: 20.0 + min_duration: 0.1 + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + bucketing_strategy: synced_randomized + override_snr_ratio: 0.7 + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.7 + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.7 + use_start_end_token: false +tokenizer: + dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/vocab.txt + spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.vocab +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 +av_encoder: + d_model: 768 + nhead: 8 + num_layers: 4 + dropout: 0.1 +v_model: + feat_dim: 768 +encoder: + _target_: 
nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 768 + num_classes: 128 + vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + - ▁we + - ▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false +adapters: + linear_adapter: + keep: false + name: AV_v1 + dim: 64 + activation: swish + norm_position: pre + dropout: 0.1 + multi_head_attention_adapter: + keep: false + rel_position_multi_head_attention_adapter: + keep: false +variational_noise: + 
start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/configs/c9.yaml b/balu_codes/configs/c9.yaml new file mode 100644 index 000000000000..87f1a1697796 --- /dev/null +++ b/balu_codes/configs/c9.yaml @@ -0,0 +1,289 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large +labelled_manifest: false +exp_dir: /tmp/bld56_dataset_v1/tmp/ +wandb: + run_name: av_ndec_uman_ntok + project: NEMO_TEST + create_wandb_logger: true + log_model: false +use_video_modality: true +use_pretrained_dec: false +label_pred_head: + keep: false + num_classes: 44 +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_train_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: true + num_workers: 10 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 + min_duration: 0.1 + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + bucketing_strategy: synced_randomizedp + override_snr_ratio: 0.5 + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.5 + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json + video_frame_rate: 5 + get_vid_feats: true + get_zero_vid_feats: false + sample_rate: 16000 + batch_size: 96 + shuffle: false + num_workers: 10 + pin_memory: true + override_snr_ratio: 0.5 + use_start_end_token: false +tokenizer: + dir: 
/home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/vocab.txt + spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.vocab +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 +av_encoder: + d_model: 512 + nhead: 8 + num_layers: 4 + dropout: 0.1 +v_model: + feat_dim: 768 +encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 512 + num_classes: 356 + vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + - ▁we + - 
▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false +adapters: + linear_adapter: + keep: true + name: AV_v1 + dim: 64 + activation: swish + norm_position: pre + dropout: 0.1 + multi_head_attention_adapter: + keep: false + rel_position_multi_head_attention_adapter: + keep: false +variational_noise: + start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/ctc_model_QuartzNet15x5Base copy.yaml b/balu_codes/ctc_model_QuartzNet15x5Base copy.yaml new file mode 100644 index 000000000000..364e6a337e23 --- /dev/null +++ b/balu_codes/ctc_model_QuartzNet15x5Base copy.yaml @@ -0,0 +1,149 @@ +preprocessor: + cls: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + params: + normalize: per_feature + window_size: 0.02 + sample_rate: 16000 + window_stride: 0.01 + window: hann + features: 64 + n_fft: 512 + frame_splicing: 1 + dither: 1.0e-05 + stft_conv: false +spec_augment: + cls: nemo.collections.asr.modules.SpectrogramAugmentation + params: + rect_freq: 50 + rect_masks: 5 + rect_time: 120 + +a_model_name: QuartzNet15x5Base-En # CHANGE +sample_rate: 16000 +labels: + - ' ' + - a + - b + - c + - d + - e + - f + - g + - h + - i + - j + - k + - l + - m + - n 
+ - o + - p + - q + - r + - s + - t + - u + - v + - w + - x + - y + - z + - '''' +train_ds: + manifest_filepath: /disk1/it1/annotations/manifest_train.json # CHANGE + video_frame_rate: 5 # CHANGE + # - - /data2/nemo_asr/nemo_asr_set_3.0/bucket1/tarred_audio_manifest.json #TBD # CHANGE + sample_rate: 16000 + batch_size: 1 # CHANGE + shuffle: true + num_workers: 0 # CHANGE + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 # CHANGE + min_duration: 0.1 + is_tarred: false + tarred_audio_filepaths: null # CHANGE + shuffle_n: 2048 + bucketing_strategy: synced_randomized + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 + +validation_ds: + manifest_filepath: /disk1/it1/annotations/manifest_train.json # CHANGE + video_frame_rate: 5 # CHANGE + # - /manifests/librispeech/librivox-dev-other.json #TBD # CHANGE + sample_rate: 16000 + batch_size: 1 + shuffle: false + num_workers: 0 + pin_memory: true + use_start_end_token: false +test_ds: + manifest_filepath: /disk1/it1/annotations/manifest_train.json + # - /manifests/librispeech/librivox-dev-other.json #TBD + sample_rate: 16000 + batch_size: 1 + shuffle: false + num_workers: 0 + pin_memory: true + use_start_end_token: false + +av_encoder: # CHANGE + d_model: 512 + nhead: 4 + num_layers: 2 + dropout: 0.1 + +v_model: # CHANGE + feat_dim: 512 + +decoder: + cls: nemo.collections.asr.modules.ConvASRDecoder + params: + feat_in: 512 + num_classes: 28 + vocabulary: + - ' ' + - a + - b + - c + - d + - e + - f + - g + - h + - i + - j + - k + - l + - m + - n + - o + - p + - q + - r + - s + - t + - u + - v + - w + - x + - y + - z + - '''' +optim: + name: novograd + lr: 0.01 + betas: + - 0.8 + - 0.5 + weight_decay: 0.001 +target: nemo.collections.asr.models.av_ctc_bpe_models.AV_EncDecCTCModel diff --git a/balu_codes/infer_av_asr.py b/balu_codes/infer_av_asr.py new file mode 100644 index 000000000000..7e95b165f40c --- /dev/null +++ b/balu_codes/infer_av_asr.py @@ -0,0 +1,70 
@@ +import os +import sys +sys.path.insert(0, os.path.abspath('/home/bld56/gsoc/nemo/NeMo-opensource/')) +import nemo.collections.asr as nemo_asr +from omegaconf import OmegaConf +import torch +import json + +# Function to load the model from a .nemo file +def load_model(nemo_file_path): + model = nemo_asr.models.AV_EncDecCTCModelBPE.restore_from(nemo_file_path) + model.eval() + return model + +# Function to perform inference on a single sample +def infer_single_sample(model, sample): + # Prepare input data + audio_file = sample['audio_filepath'] + video_file = sample['video_filepath'] + feature_file = sample['feature_file'] + duration = sample['duration'] + + # Perform inference + transcription = model.transcribe( + audio=[audio_file], + return_hypotheses = True, + override_duration = duration, + ) + + return transcription[0] + +# Function to run inference on a manifest file +def run_inference(manifest_file_path, nemo_file_path, output_file_path): + # Load the model + model = load_model(nemo_file_path) + + # Read the manifest file + with open(manifest_file_path, 'r') as f: + manifest_data = [json.loads(line.strip()) for line in f] + + # Run inference on each sample in the manifest + results = [] + for sample in manifest_data: + transcription = infer_single_sample(model, sample) + result = { + 'audio_filepath': sample['audio_filepath'], + 'video_filepath': sample['video_filepath'], + 'feature_file': sample['feature_file'], + 'duration': sample['duration'], + 'transcription': transcription + } + results.append(result) + + # Save the results to the output file + with open(output_file_path, 'w') as f: + for result in results: + f.write(json.dumps(result) + '\n') + + print(f"Inference completed. 
Results saved to {output_file_path}") + +# Main function +def main(): + manifest_file_path = '/tmp/bld56_dataset_v1/it2/annotations/manifest_eval.json' # Path to your input manifest file + nemo_file_path = '/tmp/bld56_dataset_v1/tmp/av_ndec_lman_ntok_0.5/2024-08-16_11-16-34/checkpoints/av_ndec_lman_ntok_0.5.nemo' # Path to your trained .nemo file + output_file_path = 'temp.json' # Path to save the inference results + + run_inference(manifest_file_path, nemo_file_path, output_file_path) + +if __name__ == "__main__": + main() diff --git a/balu_codes/infer_test.ipynb b/balu_codes/infer_test.ipynb new file mode 100644 index 000000000000..253e57438f4e --- /dev/null +++ b/balu_codes/infer_test.ipynb @@ -0,0 +1,919 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## For validation with a given model path" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "sys.path.insert(0, os.path.abspath('/home/bld56/gsoc/nemo/NeMo-opensource/'))\n", + "import nemo.core as nemo_core\n", + "from nemo.core import adapter_mixins\n", + "from nemo.utils import exp_manager\n", + "import nemo.collections.asr as nemo_asr\n", + "import nemo\n", + "import json\n", + "from omegaconf import OmegaConf, open_dict\n", + "import torch\n", + "from pytorch_lightning import Trainer\n", + "from lightning.pytorch.loggers import WandbLogger\n", + "from torchmetrics.text import WordErrorRate\n", + "import warnings\n", + "import argparse" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def load_and_configure_model(config_file_path):\n", + " conf = OmegaConf.load(config_file_path)\n", + " overrides = OmegaConf.from_cli()\n", + " updated_conf = OmegaConf.merge(conf, overrides)\n", + " OmegaConf.set_struct(updated_conf, True)\n", + " model = nemo_asr.models.AV_EncDecCTCModelBPE(updated_conf)\n", + "\n", + " 
model.setup_training_data(model.cfg.train_ds)\n", + " return model, conf\n", + "\n", + "# Function to freeze and unfreeze model parameters based on adapters\n", + "def manage_model_adapters(model, conf):\n", + " # Freeze the entire model\n", + " model.freeze()\n", + " \n", + " # Determine which modules to train based on configuration\n", + " if model.cfg.use_video_modality:\n", + " modules_to_train = [\n", + " model.a_linear, model.v_linear, model.av_encoder, model.av_enocder_layer, \n", + " model.a_modal_embs, model.v_modal_embs, model.decoder, model.a_pos_enc, model.v_pos_enc\n", + " ]\n", + " elif not model.cfg.use_video_modality and model.cfg.use_pretrained_dec:\n", + " modules_to_train = [model.a_model.decoder]\n", + " else: # not model.cfg.use_video_modality and not model.cfg.use_pretrained_dec\n", + " modules_to_train = [model.decoder]\n", + " \n", + " # Set the selected modules to training mode and enable gradients\n", + " for module in modules_to_train:\n", + " module.train()\n", + " for param in module.parameters():\n", + " param.requires_grad = True\n", + "\n", + " # Handle adapter configurations if needed\n", + " if conf.adapters.linear_adapter.keep:\n", + " model.a_model.freeze()\n", + " model.a_model.set_enabled_adapters(enabled=False)\n", + " model.a_model.set_enabled_adapters(name=conf.adapters.linear_adapter.name, enabled=True)\n", + " model.a_model.unfreeze_enabled_adapters()\n", + " else:\n", + " model.a_model.unfreeze()\n", + "\n", + "# Function to set up the trainer\n", + "def setup_trainer():\n", + " torch.set_float32_matmul_precision('high')\n", + " accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n", + " trainer = Trainer(\n", + " devices=1, accelerator=accelerator, \n", + " # strategy=\"ddp_find_unused_parameters_true\",\n", + " # strategy=\"ddp_notebook\",\n", + " max_epochs=100,\n", + " enable_checkpointing=False, logger=False,\n", + " log_every_n_steps=5, check_val_every_n_epoch=1,\n", + " )\n", + " return trainer\n", + "\n", 
+ "# Function to set up experiment manager\n", + "def setup_exp_manager(trainer, model):\n", + " os.environ.pop('NEMO_EXPM_VERSION', None)\n", + "\n", + " exp_config = exp_manager.ExpManagerConfig(\n", + " exp_dir=model.cfg.exp_dir,\n", + " name=f'{model.cfg.wandb.run_name}',\n", + " checkpoint_callback_params=exp_manager.CallbackParams(\n", + " monitor=\"val_u_wer\",\n", + " mode=\"min\",\n", + " always_save_nemo=True,\n", + " save_best_model=True,\n", + " ),\n", + " create_wandb_logger=model.cfg.wandb.create_wandb_logger,\n", + " wandb_logger_kwargs=OmegaConf.create({\"project\": f\"{model.cfg.wandb.project}\", \"name\": f\"{model.cfg.wandb.run_name}_{model.cfg.train_ds.override_snr_ratio}\", \"log_model\": model.cfg.wandb.log_model}),\n", + " )\n", + "\n", + " exp_config = OmegaConf.structured(exp_config)\n", + " logdir = exp_manager.exp_manager(trainer, exp_config)\n", + " if model.cfg.wandb.create_wandb_logger:\n", + " trainer.loggers[1].log_hyperparams(OmegaConf.to_container(model.cfg)) # wandb logger\n", + " # log the manifest file to wandb server\n", + " trainer.loggers[1].experiment.log_artifact(f\"{model.cfg.train_ds.manifest_filepath}\")\n", + " trainer.loggers[1].experiment.log_artifact(f\"{model.cfg.validation_ds.manifest_filepath}\")\n", + " \n", + " return logdir\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-08-25 15:44:38 mixins:172] Tokenizer SentencePieceTokenizer initialized with 128 tokens\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-08-25 15:44:38 modelPT:176] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.\n", + " Train config : \n", + " manifest_filepath:\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket1/tarred_audio_manifest.json\n", + " - - 
/data2/nemo_asr/nemo_asr_set_3.0/bucket2/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket3/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket4/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket5/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket6/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket7/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket8/tarred_audio_manifest.json\n", + " sample_rate: 16000\n", + " batch_size: 1\n", + " shuffle: true\n", + " num_workers: 4\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " trim_silence: false\n", + " max_duration: 20.0\n", + " min_duration: 0.1\n", + " is_tarred: true\n", + " tarred_audio_filepaths:\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket1/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket2/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket3/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket4/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket5/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket6/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket7/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket8/audio__OP_0..8191_CL_.tar\n", + " shuffle_n: 2048\n", + " bucketing_strategy: synced_randomized\n", + " bucketing_batch_size:\n", + " - 34\n", + " - 30\n", + " - 26\n", + " - 22\n", + " - 18\n", + " - 16\n", + " - 12\n", + " - 8\n", + " \n", + "[NeMo W 2024-08-25 15:44:38 modelPT:183] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
\n", + " Validation config : \n", + " manifest_filepath:\n", + " - /manifests/librispeech/librivox-dev-other.json\n", + " - /manifests/librispeech/librivox-dev-clean.json\n", + " - /manifests/librispeech/librivox-test-other.json\n", + " - /manifests/librispeech/librivox-test-clean.json\n", + " sample_rate: 16000\n", + " batch_size: 32\n", + " shuffle: false\n", + " num_workers: 8\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " \n", + "[NeMo W 2024-08-25 15:44:38 modelPT:189] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a valid configuration file to setup the test data loader(s).\n", + " Test config : \n", + " manifest_filepath:\n", + " - /manifests/librispeech/librivox-dev-other.json\n", + " - /manifests/librispeech/librivox-dev-clean.json\n", + " - /manifests/librispeech/librivox-test-other.json\n", + " - /manifests/librispeech/librivox-test-clean.json\n", + " sample_rate: 16000\n", + " batch_size: 32\n", + " shuffle: false\n", + " num_workers: 8\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-08-25 15:44:38 features:305] PADDING: 0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-08-25 15:44:39 save_restore_connector:263] Model EncDecCTCModelBPE was successfully restored from /home/bld56/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo.\n", + "EncDecCTCModelBPE(\n", + " (preprocessor): AudioToMelSpectrogramPreprocessor(\n", + " (featurizer): FilterbankFeatures()\n", + " )\n", + " (encoder): ConformerEncoder(\n", + " (pre_encode): ConvSubsampling(\n", + " (out): Linear(in_features=10240, out_features=512, bias=True)\n", + " (conv): Sequential(\n", + " (0): Conv2d(1, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (1): ReLU(inplace=True)\n", + " (2): 
Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (3): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (pos_enc): RelPositionalEncoding(\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (layers): ModuleList(\n", + " (0-17): 18 x ConformerLayer(\n", + " (norm_feed_forward1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (feed_forward1): ConformerFeedForward(\n", + " (linear1): Linear(in_features=512, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (linear2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " (norm_conv): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (conv): ConformerConvolution(\n", + " (pointwise_conv1): Conv1d(512, 1024, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): CausalConv1D(512, 512, kernel_size=(31,), stride=(1,), groups=512)\n", + " (batch_norm): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (activation): Swish()\n", + " (pointwise_conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (norm_self_att): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (self_attn): RelPositionMultiHeadAttention(\n", + " (linear_q): Linear(in_features=512, out_features=512, bias=True)\n", + " (linear_k): Linear(in_features=512, out_features=512, bias=True)\n", + " (linear_v): Linear(in_features=512, out_features=512, bias=True)\n", + " (linear_out): Linear(in_features=512, out_features=512, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (linear_pos): Linear(in_features=512, out_features=512, bias=False)\n", + " )\n", + " (norm_feed_forward2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (feed_forward2): ConformerFeedForward(\n", + " (linear1): Linear(in_features=512, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, 
inplace=False)\n", + " (linear2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (norm_out): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " )\n", + " (decoder): ConvASRDecoder(\n", + " (decoder_layers): Sequential(\n", + " (0): Conv1d(512, 129, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (loss): CTCLoss()\n", + " (spec_augmentation): SpectrogramAugmentation(\n", + " (spec_augment): SpecAugment()\n", + " )\n", + " (wer): WER()\n", + ")\n" + ] + } + ], + "source": [ + "# Main function to execute the workflow\n", + "# def main(config_file_path, args):\n", + "# config_file_path = '/home/bld56/gsoc/nemo/NeMo-opensource/balu_codes/configs/c1.yaml'\n", + "# model, conf = load_and_configure_model(config_file_path)\n", + "# ckpt_path = f\"/tmp/bld56_dataset_v1/saved_models/pre_av_ndec_uman_ntok--val_u_wer=0.0809-epoch=11.ckpt\"\n", + "ckpt_path = f\"/home/bld56/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo\"\n", + "model = nemo_asr.models.AV_EncDecCTCModelBPE.restore_from(ckpt_path, override_config_path=None) \n", + "model.cfg.train_ds.manifest_filepath = '/tmp/bld56_dataset_v1/it2/annotations/manifest_train_no_label.json'\n", + "model.cfg.validation_ds.manifest_filepath = '/tmp/bld56_dataset_v1/it2/annotations/manifest_eval_no_label.json'\n", + "model.cfg.test_ds.manifest_filepath = '/tmp/bld56_dataset_v1/it2/annotations/manifest_test_no_label.json'\n", + "print(model)\n", + "# model.cfg.wandb.run_name += 'pre+'\n", + "# manage_model_adapters(model, conf)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-08-25 15:44:39 nemo_logging:349] /home/bld56/.miniconda3/envs/nemo/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The 
`srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/bld56/.miniconda3/envs/nemo/lib/python3.10/sit ...\n", + " \n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n" + ] + } + ], + "source": [ + "trainer = setup_trainer()\n", + "model.set_trainer(trainer)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-08-25 15:46:29 nemo_logging:349] /home/bld56/.miniconda3/envs/nemo/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/bld56/.miniconda3/envs/nemo/lib/python3.10/sit ...\n", + " \n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n", + "[NeMo W 2024-08-25 15:46:29 nemo_logging:349] /home/bld56/.miniconda3/envs/nemo/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:105: Total length of `list` across ranks is zero. 
Please make sure this was your intention.\n", + " \n" + ] + }, + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.validate(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## From Aug 16 weekly meeting: developing the transcribe function" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "sys.path.insert(0, os.path.abspath('/home/bld56/gsoc/nemo/NeMo-opensource/'))\n", + "import nemo.collections.asr as nemo_asr\n", + "import json\n", + "import nemo.collections.asr.data.av_to_text" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Function to load the model from a .nemo file\n", + "def load_model(nemo_file_path):\n", + " model = nemo_asr.models.AV_EncDecCTCModelBPE.restore_from(nemo_file_path)\n", + " model.eval()\n", + " return model\n", + "\n", + "# Function to perform inference on a single sample\n", + "def infer_single_sample(model, sample):\n", + " # Prepare input data\n", + " audio_file = sample['audio_filepath']\n", + " video_file = sample['video_filepath']\n", + " feature_file = sample['feature_file']\n", + " duration = sample['duration']\n", + " \n", + " # Perform inference\n", + " transcription = model.transcribe(\n", + " audio=[audio_file],\n", + " return_hypotheses = True,\n", + " override_duration = duration,\n", + " )\n", + " \n", + " return transcription[0]\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "# Load the tokenizer model from the specified path\n", + "def load_tokenizer(tokenizer_model_path):\n", + " tokenizer = spm.SentencePieceProcessor()\n", + " tokenizer.load(tokenizer_model_path)\n", + " return tokenizer\n", + "\n", 
+ "# tokenizer_path = \"/home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.model\"\n", + "# for i in range(self.tokenizer.vocab_size):\n", + "# piece = self.tokenizer.ids_to_tokens([i])\n", + "# piece = piece[0]\n", + "# vocabulary[piece] = i + 1\n", + "# tokenizer = load_tokenizer(tokenizer_path)\n", + "config = {}" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-08-17 12:36:32 mixins:172] Tokenizer SentencePieceTokenizer initialized with 356 tokens\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-08-17 12:36:32 modelPT:176] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.\n", + " Train config : \n", + " manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_train.json\n", + " video_frame_rate: 5\n", + " get_vid_feats: true\n", + " get_zero_vid_feats: false\n", + " sample_rate: 16000\n", + " batch_size: 32\n", + " shuffle: true\n", + " num_workers: 11\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " trim_silence: false\n", + " max_duration: 20.0\n", + " min_duration: 0.1\n", + " is_tarred: false\n", + " tarred_audio_filepaths: null\n", + " shuffle_n: 2048\n", + " bucketing_strategy: synced_randomized\n", + " override_snr_ratio: 0.7\n", + " bucketing_batch_size:\n", + " - 34\n", + " - 30\n", + " - 26\n", + " - 22\n", + " - 18\n", + " - 16\n", + " - 12\n", + " - 8\n", + " \n", + "[NeMo W 2024-08-17 12:36:32 modelPT:183] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
\n", + " Validation config : \n", + " manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval.json\n", + " video_frame_rate: 5\n", + " get_vid_feats: true\n", + " get_zero_vid_feats: false\n", + " sample_rate: 16000\n", + " batch_size: 32\n", + " shuffle: false\n", + " num_workers: 11\n", + " pin_memory: true\n", + " override_snr_ratio: 0.7\n", + " use_start_end_token: false\n", + " \n", + "[NeMo W 2024-08-17 12:36:32 modelPT:189] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a valid configuration file to setup the test data loader(s).\n", + " Test config : \n", + " manifest_filepath: /tmp/bld56_dataset_v1/it2/annotations/manifest_eval.json\n", + " video_frame_rate: 5\n", + " get_vid_feats: true\n", + " get_zero_vid_feats: false\n", + " sample_rate: 16000\n", + " batch_size: 32\n", + " shuffle: false\n", + " num_workers: 11\n", + " pin_memory: true\n", + " override_snr_ratio: 0.7\n", + " use_start_end_token: false\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-08-17 12:36:32 cloud:58] Found existing object /home/bld56/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo.\n", + "[NeMo I 2024-08-17 12:36:32 cloud:64] Re-using file from: /home/bld56/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo\n", + "[NeMo I 2024-08-17 12:36:32 common:815] Instantiating model from pre-trained checkpoint\n", + "Updated encoder _target_ model : nemo.collections.asr.modules.conformer_encoder.ConformerEncoderAdapter\n", + "[NeMo I 2024-08-17 12:36:32 cloud:58] Found existing object /home/bld56/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo.\n", + "[NeMo I 2024-08-17 12:36:32 cloud:64] Re-using file from: 
/home/bld56/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo\n", + "[NeMo I 2024-08-17 12:36:32 common:815] Instantiating model from pre-trained checkpoint\n", + "[NeMo I 2024-08-17 12:36:33 mixins:172] Tokenizer SentencePieceTokenizer initialized with 128 tokens\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-08-17 12:36:33 modelPT:176] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.\n", + " Train config : \n", + " manifest_filepath:\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket1/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket2/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket3/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket4/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket5/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket6/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket7/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket8/tarred_audio_manifest.json\n", + " sample_rate: 16000\n", + " batch_size: 1\n", + " shuffle: true\n", + " num_workers: 4\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " trim_silence: false\n", + " max_duration: 20.0\n", + " min_duration: 0.1\n", + " is_tarred: true\n", + " tarred_audio_filepaths:\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket1/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket2/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket3/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket4/audio__OP_0..8191_CL_.tar\n", + " - - 
/data2/nemo_asr/nemo_asr_set_3.0/bucket5/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket6/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket7/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket8/audio__OP_0..8191_CL_.tar\n", + " shuffle_n: 2048\n", + " bucketing_strategy: synced_randomized\n", + " bucketing_batch_size:\n", + " - 34\n", + " - 30\n", + " - 26\n", + " - 22\n", + " - 18\n", + " - 16\n", + " - 12\n", + " - 8\n", + " \n", + "[NeMo W 2024-08-17 12:36:33 modelPT:183] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). \n", + " Validation config : \n", + " manifest_filepath:\n", + " - /manifests/librispeech/librivox-dev-other.json\n", + " - /manifests/librispeech/librivox-dev-clean.json\n", + " - /manifests/librispeech/librivox-test-other.json\n", + " - /manifests/librispeech/librivox-test-clean.json\n", + " sample_rate: 16000\n", + " batch_size: 32\n", + " shuffle: false\n", + " num_workers: 8\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " \n", + "[NeMo W 2024-08-17 12:36:33 modelPT:189] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a valid configuration file to setup the test data loader(s).\n", + " Test config : \n", + " manifest_filepath:\n", + " - /manifests/librispeech/librivox-dev-other.json\n", + " - /manifests/librispeech/librivox-dev-clean.json\n", + " - /manifests/librispeech/librivox-test-other.json\n", + " - /manifests/librispeech/librivox-test-clean.json\n", + " sample_rate: 16000\n", + " batch_size: 32\n", + " shuffle: false\n", + " num_workers: 8\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-08-17 12:36:33 features:305] 
PADDING: 0\n", + "[NeMo I 2024-08-17 12:36:34 save_restore_connector:263] Model EncDecCTCModelBPE was successfully restored from /home/bld56/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo.\n", + "[NeMo I 2024-08-17 12:36:34 save_restore_connector:263] Model AV_EncDecCTCModelBPE was successfully restored from /tmp/bld56_dataset_v1/tmp/av_ndec_lman_ntok_0.5/2024-08-16_11-16-34/checkpoints/av_ndec_lman_ntok_0.5.nemo.\n" + ] + }, + { + "data": { + "text/plain": [ + "AV_EncDecCTCModelBPE(\n", + " (a_model): EncDecCTCModelBPE(\n", + " (preprocessor): AudioToMelSpectrogramPreprocessor(\n", + " (featurizer): FilterbankFeatures()\n", + " )\n", + " (encoder): ConformerEncoderAdapter(\n", + " (pre_encode): ConvSubsampling(\n", + " (out): Linear(in_features=10240, out_features=512, bias=True)\n", + " (conv): Sequential(\n", + " (0): Conv2d(1, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", + " (3): ReLU(inplace=True)\n", + " )\n", + " )\n", + " (pos_enc): RelPositionalEncoding(\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (layers): ModuleList(\n", + " (0-17): 18 x ConformerLayer(\n", + " (norm_feed_forward1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (feed_forward1): ConformerFeedForward(\n", + " (linear1): Linear(in_features=512, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (linear2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " (norm_conv): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (conv): ConformerConvolution(\n", + " (pointwise_conv1): Conv1d(512, 1024, kernel_size=(1,), stride=(1,))\n", + " (depthwise_conv): CausalConv1D(512, 512, kernel_size=(31,), stride=(1,), groups=512)\n", + " (batch_norm): BatchNorm1d(512, 
eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (activation): Swish()\n", + " (pointwise_conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,))\n", + " )\n", + " (norm_self_att): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (self_attn): RelPositionMultiHeadAttention(\n", + " (linear_q): Linear(in_features=512, out_features=512, bias=True)\n", + " (linear_k): Linear(in_features=512, out_features=512, bias=True)\n", + " (linear_v): Linear(in_features=512, out_features=512, bias=True)\n", + " (linear_out): Linear(in_features=512, out_features=512, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (linear_pos): Linear(in_features=512, out_features=512, bias=False)\n", + " )\n", + " (norm_feed_forward2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (feed_forward2): ConformerFeedForward(\n", + " (linear1): Linear(in_features=512, out_features=2048, bias=True)\n", + " (activation): Swish()\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (linear2): Linear(in_features=2048, out_features=512, bias=True)\n", + " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (norm_out): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (adapter_layer): ModuleDict(\n", + " (AV_v1): LinearAdapter(\n", + " (module): Sequential(\n", + " (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (1): Linear(in_features=512, out_features=64, bias=False)\n", + " (2): SiLU(inplace=True)\n", + " (3): Linear(in_features=64, out_features=512, bias=False)\n", + " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (decoder): ConvASRDecoder(\n", + " (decoder_layers): Sequential(\n", + " (0): Conv1d(512, 129, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (loss): CTCLoss()\n", + " (spec_augmentation): SpectrogramAugmentation(\n", + " (spec_augment): SpecAugment()\n", + " )\n", + " (wer): WER()\n", + " )\n", + " (a_linear): 
Linear(in_features=512, out_features=512, bias=True)\n", + " (v_linear): Linear(in_features=768, out_features=512, bias=True)\n", + " (av_enocder_layer): TransformerEncoderLayer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)\n", + " )\n", + " (linear1): Linear(in_features=512, out_features=2048, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (linear2): Linear(in_features=2048, out_features=512, bias=True)\n", + " (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (av_encoder): TransformerEncoder(\n", + " (layers): ModuleList(\n", + " (0-3): 4 x TransformerEncoderLayer(\n", + " (self_attn): MultiheadAttention(\n", + " (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)\n", + " )\n", + " (linear1): Linear(in_features=512, out_features=2048, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (linear2): Linear(in_features=2048, out_features=512, bias=True)\n", + " (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", + " (dropout1): Dropout(p=0.1, inplace=False)\n", + " (dropout2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " (a_modal_embs): Embedding(1, 512)\n", + " (v_modal_embs): Embedding(1, 512)\n", + " (a_pos_enc): Embedding(10000, 512)\n", + " (v_pos_enc): Embedding(10000, 512)\n", + " (decoder): ConvASRDecoder(\n", + " (decoder_layers): Sequential(\n", + " (0): Conv1d(512, 357, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (loss): CTCLoss()\n", + " (wer): AV_WER()\n", + ")" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ 
+ "manifest_file_path = '/tmp/bld56_dataset_v1/it2/annotations/manifest_eval.json' # Path to your input manifest file\n", + "nemo_file_path = '/tmp/bld56_dataset_v1/tmp/av_ndec_lman_ntok_0.5/2024-08-16_11-16-34/checkpoints/av_ndec_lman_ntok_0.5.nemo' # Path to your trained .nemo file\n", + "output_file_path = 'temp.json' # Path to save the inference results\n", + "model = load_model(nemo_file_path)\n", + "model.to('cpu')" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-08-17 12:36:34 collections:321] Dataset loaded with 2200 files totalling 6.11 hours\n", + "[NeMo I 2024-08-17 12:36:34 collections:323] 0 files were filtered totalling 0.00 hours\n" + ] + } + ], + "source": [ + "dataset = nemo.collections.asr.data.av_to_text.AVToBPEDataset(\n", + " manifest_filepath='/tmp/bld56_dataset_v1/it2/annotations/manifest_eval.json',\n", + " tokenizer= model.tokenizer,\n", + " sample_rate= 16000,\n", + " int_values=config.get('int_values', False),\n", + " max_duration=config.get('max_duration', None),\n", + " min_duration=config.get('min_duration', None),\n", + " max_utts=config.get('max_utts', 0),\n", + " trim=config.get('trim_silence', False),\n", + " use_start_end_token=config.get('use_start_end_token', True),\n", + " return_sample_id=config.get('return_sample_id', False),\n", + " channel_selector=config.get('channel_selector', None),\n", + " video_frame_rate=config.get('video_frame_rate', 5),\n", + " get_vid_feats=config.get('get_vid_feats', True),\n", + " get_zero_vid_feats = config.get('get_zero_vid_feats', False),\n", + " override_snr_ratio = config.get('override_snr_ratio', None),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "batch_size = 1\n", + "dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, 
collate_fn=dataset.collate_fn)" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "so regular and complete a part of normal everyday living that finding newspapers on the news then buying them morning and night was takenso regular and complete a part of normal everyday living that finding newspapers on the news then buying them morning and night was taken \n", + "\n", + "\n", + "++++++ re reguular and and com compleletee a p parart of of n normmalal eververyy dayay l liivving that that f findding n nwssppaapperss on on the neewssstandnds b buyying the them m mororninging and and n niightt+ wasas t taakinging+++rereggullar and and compleletee a a parart of of norormmal e eververy d dayay l liivving that that f findinging n nwssppaapperss on the the neewssstandnds b buyyinging themm m mornning and and+ n nighght+ wasas t t takk+ing+ ararararllararararararararlararararararararararararararararararararararlararararararararlararar\n" + ] + } + ], + "source": [ + "signal, signal_len, video_input_signal, transcript, transcript_len = dataloader.__iter__().__next__()\n", + "log_probs, encoded_len, predictions = model.forward(audio_input_signal=signal, audio_input_signal_length=signal_len, video_input_signal=video_input_signal)\n", + "loss_value = model.loss(\n", + " log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len\n", + " )\n", + "# print(transcript, predictions)\n", + "# tokenizer = load_tokenizer('/home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_toknizer/tokenizer.model')\n", + "# model.wer.decoding.decode_tokens_to_str(predictions[0].cpu().numpy().tolist())\n", + "# replace predictions[0] where 356 to 355\n", + "predictions[0][predictions[0] == 356] = 355\n", + "print(model.wer.decoding.decode_tokens_to_str(transcript[0].cpu().numpy().tolist()))\n", + "print('\\n')\n", + 
"print(model.wer.decoding.decode_tokens_to_str(predictions[0].cpu().numpy().tolist()))" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "++++++ re reguular and and com compleletee a p parart of of n normmalal eververyy dayay l liivving that that f findding n nwssppaapperss on on the neewssstandnds b buyying the them m mororninging and and n niightt+ wasas t taakinging+++rereggullar and and compleletee a a parart of of norormmal e eververy d dayay l liivving that that f findinging n nwssppaapperss on the the neewssstandnds b buyyinging themm m mornning and and+ n nighght+ wasas t t takk+ing+ ararararllararararararararlararararararararararararararararararararararlararararararararlararar\n" + ] + } + ], + "source": [ + "import re\n", + "temp_str = model.wer.decoding.decode_tokens_to_str(predictions[0].cpu().numpy().tolist())\n", + "r_tags = re.findall(r'', temp_str)\n", + "for tag in r_tags:\n", + " unlabelled_h = temp_str.replace(tag, '')\n", + "print(unlabelled_h)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# Function to run inference on a manifest file\n", + "def run_inference(manifest_file_path, nemo_file_path, output_file_path):\n", + " # Load the model\n", + " model = load_model(nemo_file_path)\n", + " \n", + " # Read the manifest file\n", + " with open(manifest_file_path, 'r') as f:\n", + " manifest_data = [json.loads(line.strip()) for line in f]\n", + " \n", + " # Run inference on each sample in the manifest\n", + " results = []\n", + " for sample in manifest_data:\n", + " transcription = infer_single_sample(model, sample)\n", + " result = {\n", + " 'audio_filepath': sample['audio_filepath'],\n", + " 'video_filepath': sample['video_filepath'],\n", + " 'feature_file': sample['feature_file'],\n", + " 'duration': sample['duration'],\n", + " 'transcription': transcription\n", 
+ " }\n", + " results.append(result)\n", + " \n", + " # Save the results to the output file\n", + " with open(output_file_path, 'w') as f:\n", + " for result in results:\n", + " f.write(json.dumps(result) + '\\n')\n", + "\n", + " print(f\"Inference completed. Results saved to {output_file_path}\")\n", + "\n", + "# Main function\n", + "def main():\n", + " manifest_file_path = '/tmp/bld56_dataset_v1/it2/annotations/manifest_eval.json' # Path to your input manifest file\n", + " nemo_file_path = '/tmp/bld56_dataset_v1/tmp/av_ndec_lman_ntok_0.5/2024-08-16_11-16-34/checkpoints/av_ndec_lman_ntok_0.5.nemo' # Path to your trained .nemo file\n", + " output_file_path = 'temp.json' # Path to save the inference results\n", + " \n", + " run_inference(manifest_file_path, nemo_file_path, output_file_path)\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nemo", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/balu_codes/model_config_from_transcribe_py.yaml b/balu_codes/model_config_from_transcribe_py.yaml new file mode 100644 index 000000000000..8e784965be43 --- /dev/null +++ b/balu_codes/model_config_from_transcribe_py.yaml @@ -0,0 +1,309 @@ +sample_rate: 16000 +log_prediction: false +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large # CHANGE, BPE: is a must since, it is used to load audio encoder. 
+labelled_manifest: False # CHANGE +exp_dir: /tmp/bld56_dataset_v1/tmp/ # CHANGE +wandb: + run_name: "snr_0.7_ada+df" # CHANGE + project: "NEMO_TEST" # CHANGE + create_wandb_logger: false # CHANGE + log_model: False # CHANGE + +use_video_modality: false # CHANGE +use_pretrained_dec: true # CHANGE +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it1_70/annotations/manifest_train_no_label.json # CHANGE + video_frame_rate: 5 # CHANGE + get_vid_feats: true # CHANGE, always keep it to true. + get_zero_vid_feats: false # CHANGE + sample_rate: 16000 + batch_size: 8 # CHANGE + shuffle: true + num_workers: 12 # CHANGE + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 # CHANGE + min_duration: 0.1 + is_tarred: false # CHANGE + tarred_audio_filepaths: null # CHANGE + shuffle_n: 2048 + bucketing_strategy: synced_randomized + override_snr_ratio: 0.6 # CHANGE if float, then it is considered as snr, if None then goes by manifest. + bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it1_70/annotations/manifest_eval_no_label.json # CHANGE + video_frame_rate: 5 # CHANGE + get_vid_feats: true # CHANGE, always keep it to true. + get_zero_vid_feats: false # CHANGE + sample_rate: 16000 + batch_size: 32 + shuffle: false + num_workers: 12 + pin_memory: true + override_snr_ratio: 0.6 # CHANGE if float, then it is considered as snr, if None then goes by manifest. + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it1_70/annotations/manifest_eval_no_label.json # CHANGE + video_frame_rate: 5 # CHANGE + get_vid_feats: true # CHANGE, always keep it to true. + get_zero_vid_feats: false # CHANGE + sample_rate: 16000 + batch_size: 32 + shuffle: false + num_workers: 12 + pin_memory: true + override_snr_ratio: 0.6 # CHANGE if float, then it is considered as snr, if None then goes by manifest.
+ use_start_end_token: false + +# NEW TOKENIZER +# tokenizer: # CHANGE +# dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/ +# type: bpe +# model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.model +# vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/vocab.txt +# spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.vocab + + +# OLD TOKENIZER +tokenizer: # CHANGE # CHANGE THE NUM CLASSES TO 128 in DEC + dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/vocab.txt + spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.vocab + +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 + +av_encoder: # CHANGE + d_model: 512 + nhead: 8 + num_layers: 4 + dropout: 0.1 + +v_model: # CHANGE + feat_dim: 512 + + +encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + 
xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 512 + num_classes: 128 # CHANGE to 356 for new tok, else 128. + vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + - ▁we + - ▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 2.0 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-06 +compute_eval_loss: false # CHANGE + +adapters: # CHANGE + linear_adapter: + keep: false + name: "AV_v1" #@param {type:"string"} + dim: 64 #@param {type:"integer"} + activation: "swish" #@param {type:"string"} + norm_position: "pre" #@param ["pre", "post"] + dropout: 0.1 #@param {type:"number"} + multi_head_attention_adapter: + keep: false # TODO @Balu: Needs deeper understanding of config. 
ref: tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb + rel_position_multi_head_attention_adapter: + keep: false # TODO @Balu: Needs deeper understanding of config. ref: tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb + + + + + +variational_noise: + start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/saving_with_comments.yaml b/balu_codes/saving_with_comments.yaml new file mode 100644 index 000000000000..b72034e77dd2 --- /dev/null +++ b/balu_codes/saving_with_comments.yaml @@ -0,0 +1,309 @@ +sample_rate: 16000 +log_prediction: true +ctc_reduction: mean_batch +skip_nan_grad: false +a_model_name: BPE:stt_en_conformer_ctc_large # CHANGE, BPE: is a must since, it is used to load audio encoder. +labelled_manifest: false # CHANGE +exp_dir: /tmp/bld56_dataset_v1/tmp/ # CHANGE +wandb: + run_name: "av_ndec_uman_stok" # CHANGE + project: "NEMO_TEST" # CHANGE + create_wandb_logger: true # CHANGE + log_model: False # CHANGE + +use_video_modality: true # CHANGE +use_pretrained_dec: false # CHANGE +train_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it1_70/annotations/manifest_train_no_label.json # CHANGE + video_frame_rate: 5 # CHANGE + get_vid_feats: true # CHANGE, always keep it to true. + get_zero_vid_feats: false # CHANGE + sample_rate: 16000 + batch_size: 64 # CHANGE + shuffle: true + num_workers: 8 # CHANGE + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 20.0 # CHANGE + min_duration: 0.1 + is_tarred: false # CHANGE + tarred_audio_filepaths: null # CHANGE + shuffle_n: 2048 + bucketing_strategy: synced_randomizedp + override_snr_ratio: 0.5 # CHANGE if float, then it is considered as snr, if None then goes by manifest.
+ bucketing_batch_size: + - 34 + - 30 + - 26 + - 22 + - 18 + - 16 + - 12 + - 8 +validation_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it1_70/annotations/manifest_eval_no_label.json # CHANGE + video_frame_rate: 5 # CHANGE + get_vid_feats: true # CHANGE, always keep it to true. + get_zero_vid_feats: false # CHANGE + sample_rate: 16000 + batch_size: 64 + shuffle: false + num_workers: 8 + pin_memory: true + override_snr_ratio: 0.5 # CHANGE if float, then it is considered as snr, if None then goes by manifest. + use_start_end_token: false +test_ds: + manifest_filepath: /tmp/bld56_dataset_v1/it1_70/annotations/manifest_eval_no_label.json # CHANGE + video_frame_rate: 5 # CHANGE + get_vid_feats: true # CHANGE, always keep it to true. + get_zero_vid_feats: false # CHANGE + sample_rate: 16000 + batch_size: 64 + shuffle: false + num_workers: 8 + pin_memory: true + override_snr_ratio: 0.5 # CHANGE if float, then it is considered as snr, if None then goes by manifest. + use_start_end_token: false + +# NEW TOKENIZER +# tokenizer: # CHANGE +# dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/ +# type: bpe +# model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.model +# vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/vocab.txt +# spe_tokenizer_vocab: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/final_tokenizer/tokenizer.vocab + + +# OLD TOKENIZER +tokenizer: # CHANGE # CHANGE THE NUM CLASSES TO 128 in DEC + dir: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/ + type: bpe + model_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.model + vocab_path: /home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/vocab.txt + spe_tokenizer_vocab:
/home/bld56/gsoc/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/init_toknizer/tokenizer.vocab + +preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: 16000 + normalize: per_feature + window_size: 0.025 + window_stride: 0.01 + window: hann + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 1.0e-05 + pad_to: 0 + pad_value: 0.0 +spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 + time_masks: 10 + freq_width: 27 + time_width: 0.05 + +av_encoder: # CHANGE + d_model: 512 + nhead: 8 + num_layers: 4 + dropout: 0.1 + +v_model: # CHANGE + feat_dim: 512 + + +encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: 80 + feat_out: -1 + n_layers: 18 + d_model: 512 + subsampling: striding + subsampling_factor: 4 + subsampling_conv_channels: 512 + ff_expansion_factor: 4 + self_attention_model: rel_pos + n_heads: 8 + att_context_size: + - -1 + - -1 + xscaling: true + untie_biases: true + pos_emb_max_len: 5000 + conv_kernel_size: 31 + conv_norm_type: batch_norm + dropout: 0.1 + dropout_emb: 0.0 + dropout_att: 0.1 +decoder: + _target_: nemo.collections.asr.modules.ConvASRDecoder + feat_in: 512 + num_classes: 128 # CHANGE to 356 for new tok, else 128. 
+ vocabulary: + - + - ▁ + - s + - t + - e + - d + - o + - ▁the + - a + - i + - ▁a + - u + - 'y' + - m + - l + - 'n' + - p + - re + - c + - h + - r + - ▁s + - g + - ▁to + - er + - ing + - f + - ▁and + - an + - ▁i + - k + - ▁that + - '''' + - ▁of + - ▁in + - w + - ▁p + - ed + - or + - al + - ar + - ▁f + - en + - in + - b + - ▁you + - ▁w + - ▁b + - le + - ll + - es + - ▁it + - ve + - ur + - ▁we + - ▁re + - ▁be + - ly + - ▁is + - ▁he + - ▁o + - ▁c + - it + - ▁n + - ▁on + - un + - ▁t + - 'on' + - se + - th + - ce + - ▁do + - ic + - ▁for + - ▁th + - ion + - ch + - ▁was + - ri + - ent + - ▁g + - ver + - ▁co + - li + - ▁ha + - ▁ma + - la + - ro + - v + - us + - ▁ca + - ▁di + - ▁this + - ra + - ▁st + - ▁e + - ▁not + - ▁so + - ▁de + - ▁have + - ter + - ir + - ▁go + - ation + - ▁with + - ate + - ▁me + - ▁mo + - ment + - ▁con + - ▁but + - vi + - ▁pro + - ▁ho + - j + - ▁com + - ight + - ▁know + - ▁what + - ect + - ▁ex + - ▁some + - ▁would + - ▁like + - x + - ▁his + - q + - z +optim: + name: adamw + lr: 0.2 + betas: + - 0.9 + - 0.98 + weight_decay: 0.001 + sched: + name: NoamAnnealing + d_model: 512 + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1.0e-07 +compute_eval_loss: false # CHANGE + +adapters: # CHANGE + linear_adapter: + keep: true + name: "AV_v1" #@param {type:"string"} + dim: 64 #@param {type:"integer"} + activation: "swish" #@param {type:"string"} + norm_position: "pre" #@param ["pre", "post"] + dropout: 0.1 #@param {type:"number"} + multi_head_attention_adapter: + keep: false # TODO @Balu: Needs deeper understanding of config. ref: tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb + rel_position_multi_head_attention_adapter: + keep: false # TODO @Balu: Needs deeper understanding of config. 
ref: tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb + + + + + +variational_noise: + start_step: 0 + std: 0.0 +target: nemo.collections.asr.models.ctc_bpe_models.AV_EncDecCTCModelBPE +nemo_version: 1.9.0rc0 diff --git a/balu_codes/testing_av_code.ipynb b/balu_codes/testing_av_code.ipynb new file mode 100644 index 000000000000..f37945b9f2cc --- /dev/null +++ b/balu_codes/testing_av_code.ipynb @@ -0,0 +1,703 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "sys.path.insert(0, os.path.abspath('/home/bld56/gsoc/nemo/NeMo-opensource'))\n", + "import nemo\n", + "print(nemo)\n", + "import nemo.core as nemo_core\n", + "from nemo.core import adapter_mixins\n", + "from nemo.utils import exp_manager\n", + "import nemo.collections.asr as nemo_asr\n", + "import nemo\n", + "import json\n", + "from omegaconf import OmegaConf, open_dict\n", + "import torch\n", + "from pytorch_lightning import Trainer\n", + "from lightning.pytorch.loggers import WandbLogger\n", + "from torchmetrics.text import WordErrorRate\n", + "import warnings\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Function to load and configure the model\n", + "def load_and_configure_model(config_file_path):\n", + " conf = OmegaConf.load(config_file_path)\n", + " overrides = OmegaConf.from_cli()\n", + " updated_conf = OmegaConf.merge(conf, overrides)\n", + " OmegaConf.set_struct(updated_conf, True)\n", + " model = nemo_asr.models.AV_EncDecCTCModelBPE(updated_conf)\n", + "\n", + " model.setup_training_data(model.cfg.train_ds)\n", + " return model, conf\n", + "\n", + "# Function to freeze and unfreeze model parameters based on adapters\n", + "def manage_model_adapters(model, conf):\n", + " # Freeze the entire model\n", + " model.freeze()\n", + " \n", + " # Determine 
which modules to train based on configuration\n", + " if model.cfg.use_video_modality:\n", + " modules_to_train = [\n", + " model.a_linear, model.v_linear, model.av_encoder, model.av_enocder_layer, \n", + " model.a_modal_embs, model.v_modal_embs, model.decoder\n", + " ]\n", + " elif not model.cfg.use_video_modality and model.cfg.use_pretrained_dec:\n", + " modules_to_train = [model.a_model.decoder]\n", + " else: # not model.cfg.use_video_modality and not model.cfg.use_pretrained_dec\n", + " modules_to_train = [model.decoder]\n", + " \n", + " # Set the selected modules to training mode and enable gradients\n", + " for module in modules_to_train:\n", + " module.train()\n", + " for param in module.parameters():\n", + " param.requires_grad = True\n", + "\n", + " # Handle adapter configurations if needed\n", + " if conf.adapters.linear_adapter.keep:\n", + " model.a_model.freeze()\n", + " model.a_model.set_enabled_adapters(enabled=False)\n", + " model.a_model.set_enabled_adapters(name=conf.adapters.linear_adapter.name, enabled=True)\n", + " model.a_model.unfreeze_enabled_adapters()\n", + "\n", + "\n", + "# Function to set up the trainer\n", + "def setup_trainer():\n", + " accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n", + " trainer = Trainer(\n", + " devices=1, accelerator=accelerator, max_epochs=100,\n", + " enable_checkpointing=False, logger=False,\n", + " log_every_n_steps=5, check_val_every_n_epoch=1\n", + " )\n", + " return trainer" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Function to set up experiment manager\n", + "def setup_exp_manager(trainer, model):\n", + " os.environ.pop('NEMO_EXPM_VERSION', None)\n", + "\n", + " exp_config = exp_manager.ExpManagerConfig(\n", + " exp_dir=model.cfg.exp_dir,\n", + " name=f'{model.cfg.wandb.run_name}',\n", + " checkpoint_callback_params=exp_manager.CallbackParams(\n", + " monitor=\"val_u_wer\",\n", + " mode=\"min\",\n", + " 
always_save_nemo=True,\n", + " save_best_model=True,\n", + " ),\n", + " create_wandb_logger=model.cfg.wandb.create_wandb_logger,\n", + " wandb_logger_kwargs=OmegaConf.create({\"project\": f\"{model.cfg.wandb.project}\", \"name\": f\"{model.cfg.wandb.run_name}\", \"log_model\": model.cfg.wandb.log_model}),\n", + " )\n", + "\n", + " exp_config = OmegaConf.structured(exp_config)\n", + " logdir = exp_manager.exp_manager(trainer, exp_config)\n", + " if model.cfg.wandb.create_wandb_logger:\n", + " trainer.loggers[1].log_hyperparams(OmegaConf.to_container(model.cfg)) # wandb logger\n", + " return logdir\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "final_results = {}" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-08-02 11:46:40 mixins:172] Tokenizer SentencePieceTokenizer initialized with 356 tokens\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-08-02 11:46:42 collections:321] Dataset loaded with 22247 files totalling 61.80 hours\n", + "[NeMo I 2024-08-02 11:46:42 collections:323] 0 files were filtered totalling 0.00 hours\n", + "[NeMo I 2024-08-02 11:46:42 collections:321] Dataset loaded with 2447 files totalling 6.80 hours\n", + "[NeMo I 2024-08-02 11:46:42 collections:323] 0 files were filtered totalling 0.00 hours\n", + "[NeMo I 2024-08-02 11:46:42 collections:321] Dataset loaded with 2447 files totalling 6.80 hours\n", + "[NeMo I 2024-08-02 11:46:42 collections:323] 0 files were filtered totalling 0.00 hours\n", + "[NeMo I 2024-08-02 11:46:42 cloud:58] Found existing object /home/bld56/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo.\n", + "[NeMo I 2024-08-02 11:46:42 cloud:64] Re-using file from: 
/home/bld56/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo\n", + "[NeMo I 2024-08-02 11:46:42 common:815] Instantiating model from pre-trained checkpoint\n", + "Updated encoder _target_ model : nemo.collections.asr.modules.conformer_encoder.ConformerEncoderAdapter\n", + "[NeMo I 2024-08-02 11:46:42 cloud:58] Found existing object /home/bld56/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo.\n", + "[NeMo I 2024-08-02 11:46:42 cloud:64] Re-using file from: /home/bld56/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo\n", + "[NeMo I 2024-08-02 11:46:42 common:815] Instantiating model from pre-trained checkpoint\n", + "[NeMo I 2024-08-02 11:46:43 mixins:172] Tokenizer SentencePieceTokenizer initialized with 128 tokens\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-08-02 11:46:43 modelPT:176] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.\n", + " Train config : \n", + " manifest_filepath:\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket1/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket2/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket3/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket4/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket5/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket6/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket7/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket8/tarred_audio_manifest.json\n", + " sample_rate: 16000\n", + " batch_size: 1\n", + " shuffle: 
true\n", + " num_workers: 4\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " trim_silence: false\n", + " max_duration: 20.0\n", + " min_duration: 0.1\n", + " is_tarred: true\n", + " tarred_audio_filepaths:\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket1/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket2/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket3/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket4/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket5/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket6/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket7/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket8/audio__OP_0..8191_CL_.tar\n", + " shuffle_n: 2048\n", + " bucketing_strategy: synced_randomized\n", + " bucketing_batch_size:\n", + " - 34\n", + " - 30\n", + " - 26\n", + " - 22\n", + " - 18\n", + " - 16\n", + " - 12\n", + " - 8\n", + " \n", + "[NeMo W 2024-08-02 11:46:43 modelPT:183] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
\n", + " Validation config : \n", + " manifest_filepath:\n", + " - /manifests/librispeech/librivox-dev-other.json\n", + " - /manifests/librispeech/librivox-dev-clean.json\n", + " - /manifests/librispeech/librivox-test-other.json\n", + " - /manifests/librispeech/librivox-test-clean.json\n", + " sample_rate: 16000\n", + " batch_size: 32\n", + " shuffle: false\n", + " num_workers: 8\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " \n", + "[NeMo W 2024-08-02 11:46:43 modelPT:189] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a valid configuration file to setup the test data loader(s).\n", + " Test config : \n", + " manifest_filepath:\n", + " - /manifests/librispeech/librivox-dev-other.json\n", + " - /manifests/librispeech/librivox-dev-clean.json\n", + " - /manifests/librispeech/librivox-test-other.json\n", + " - /manifests/librispeech/librivox-test-clean.json\n", + " sample_rate: 16000\n", + " batch_size: 32\n", + " shuffle: false\n", + " num_workers: 8\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-08-02 11:46:43 features:305] PADDING: 0\n", + "[NeMo I 2024-08-02 11:46:46 save_restore_connector:263] Model EncDecCTCModelBPE was successfully restored from /home/bld56/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo.\n", + "[NeMo I 2024-08-02 11:46:48 collections:321] Dataset loaded with 22247 files totalling 61.80 hours\n", + "[NeMo I 2024-08-02 11:46:48 collections:323] 0 files were filtered totalling 0.00 hours\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:719] Setting adapter 'AV_v1' status : Enabled = False\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:734] Setting adapter 'AV_v1' status : Enabled = True\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module 
encoder.layers.0.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.1.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.2.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.3.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.4.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.5.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.6.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.7.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.8.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.9.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.10.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 
adapter_mixins:405] Froze module encoder.layers.11.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.12.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.13.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.14.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.15.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.16.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:405] Froze module encoder.layers.17.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-08-02 11:46:48 adapter_mixins:435] Unfrozen adapter : AV_v1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-08-02 11:46:48 nemo_logging:349] /home/bld56/.miniconda3/envs/nemo/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. 
HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/bld56/.miniconda3/envs/nemo/lib/python3.10/sit ...\n", + " \n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-08-02 11:46:48 exp_manager:396] Experiments will be logged at /tmp/bld56_dataset_v1/tmp/au_ndec_lman_ntok_NArch_0.5/2024-08-02_11-46-48\n", + "[NeMo I 2024-08-02 11:46:48 exp_manager:856] TensorboardLogger has been set up\n", + "[NeMo I 2024-08-02 11:46:48 exp_manager:871] WandBLogger has been set up\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlakshmipathi-balaji\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.17.5" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /tmp/bld56_dataset_v1/tmp/wandb/run-20240802_114650-2024-08-02_11-46-48" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run au_ndec_lman_ntok_NArch_0.5 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/lakshmipathi-balaji/NEMO_TEST" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/lakshmipathi-balaji/NEMO_TEST/runs/2024-08-02_11-46-48" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# snr_list = [1,0.95,0.9,0.85,...0.5]\n", + "# snr_list = [1, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55, 0.5]\n", + "snr_list = [0.5]\n", + "for snr in snr_list:\n", + " config_file_path = \"/home/bld56/gsoc/nemo/NeMo-opensource/balu_codes/configs/c3_au_with_same_av_arch.yaml\"\n", + " model, conf = load_and_configure_model(config_file_path)\n", + " manage_model_adapters(model, conf)\n", + "\n", + " trainer = setup_trainer()\n", + " model.set_trainer(trainer)\n", + " logdir = setup_exp_manager(trainer, model)\n", + " warnings.filterwarnings(\"ignore\", category=UserWarning, message=\"PySoundFile failed. 
Trying audioread instead.\")\n", + " warnings.filterwarnings(\"ignore\", category=FutureWarning, message=\"librosa.core.audio.__audioread_load\\n\\tDeprecated as of librosa version 0.10.0.\\n\\tIt will be removed in librosa version 1.0.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + " | Name | Type | Params | Mode \n", + "----------------------------------------------------------------------\n", + "0 | a_model | EncDecCTCModelBPE | 122 M | eval \n", + "1 | a_linear | Linear | 262 K | train\n", + "2 | v_linear | Linear | 262 K | train\n", + "3 | av_enocder_layer | TransformerEncoderLayer | 3.2 M | train\n", + "4 | av_encoder | TransformerEncoder | 12.6 M | train\n", + "5 | a_modal_embs | Embedding | 512 | train\n", + "6 | v_modal_embs | Embedding | 512 | train\n", + "7 | a_pos_enc | Embedding | 5.1 M | eval \n", + "8 | v_pos_enc | Embedding | 5.1 M | eval \n", + "9 | decoder | ConvASRDecoder | 183 K | train\n", + "10 | loss | CTCLoss | 0 | eval \n", + "11 | wer | AV_WER | 0 | eval \n", + "----------------------------------------------------------------------\n", + "17.7 M Trainable params\n", + "131 M Non-trainable params\n", + "149 M Total params\n", + "597.643 Total estimated model params size (MB)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.summarize()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-08-02 11:47:06 nemo_logging:349] /home/bld56/.miniconda3/envs/nemo/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. 
HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/bld56/.miniconda3/envs/nemo/lib/python3.10/sit ...\n", + " \n", + "You are using a CUDA device ('NVIDIA L40S') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a5235241cf934dc8b3dfdeda510d7a1e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Validate metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ global_step 0.0 │\n", + "│ val_acc 0.0 │\n", + "│ val_l_wer 1.984375 │\n", + "│ val_loss 1231.60107421875 │\n", + "│ val_u_wer 2.1525423526763916 │\n", + "│ val_wer 2.1525423526763916 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Validate metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m global_step \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m val_acc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m val_l_wer \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.984375 \u001b[0m\u001b[35m \u001b[0m│\n", + 
"│\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1231.60107421875 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m val_u_wer \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 2.1525423526763916 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m val_wer \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 2.1525423526763916 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "[{'global_step': 0.0,\n", + " 'val_l_wer': 1.984375,\n", + " 'val_u_wer': 2.1525423526763916,\n", + " 'val_acc': 0.0,\n", + " 'val_loss': 1231.60107421875,\n", + " 'val_wer': 2.1525423526763916}]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.validate(model, model.test_dataloader())" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-08-01 13:37:16 nemo_logging:349] /home/bld56/.miniconda3/envs/nemo/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/bld56/.miniconda3/envs/nemo/lib/python3.10/sit ...\n", + " \n", + "You are using a CUDA device ('NVIDIA L40S') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. 
For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "64457a2ff8264c45a48e8df4f89a2ff0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | | 0/? [00:00\" \n", + " will be used during training (effective maximum steps = 139100) - \n", + " Parameters : \n", + " (d_model: 512\n", + " warmup_steps: 2000\n", + " warmup_ratio: null\n", + " min_lr: 1.0e-06\n", + " max_steps: 139100\n", + " )\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params | Mode\n", + "-----------------------------------------------------\n", + "0 | a_model | EncDecCTCModelBPE | 122 M | eval\n", + "1 | decoder | ConvASRDecoder | 183 K | eval\n", + "2 | loss | CTCLoss | 0 | eval\n", + "3 | wer | AV_WER | 0 | eval\n", + "-----------------------------------------------------\n", + "1.2 M Trainable params\n", + "121 M Non-trainable params\n", + "122 M Total params\n", + "491.530 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "73fda7c1563740c6a0586b11d0af8638", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? 
[00:00 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:543\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 541\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 542\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 543\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 544\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 545\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:44\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m 
\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 44\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 47\u001b[0m _call_teardown_hook(trainer)\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:579\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 573\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 574\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 575\u001b[0m ckpt_path,\n\u001b[1;32m 576\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 577\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module 
\u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 578\u001b[0m )\n\u001b[0;32m--> 579\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 581\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 582\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:986\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 981\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 983\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 984\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 985\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 986\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 988\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 989\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 990\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 991\u001b[0m 
log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1028\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1026\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining:\n\u001b[1;32m 1027\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m isolate_rng():\n\u001b[0;32m-> 1028\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_sanity_check\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1029\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[1;32m 1030\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfit_loop\u001b[38;5;241m.\u001b[39mrun()\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1057\u001b[0m, in \u001b[0;36mTrainer._run_sanity_check\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1054\u001b[0m call\u001b[38;5;241m.\u001b[39m_call_callback_hooks(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_sanity_check_start\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1056\u001b[0m \u001b[38;5;66;03m# run eval step\u001b[39;00m\n\u001b[0;32m-> 1057\u001b[0m \u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1059\u001b[0m 
call\u001b[38;5;241m.\u001b[39m_call_callback_hooks(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_sanity_check_end\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1061\u001b[0m \u001b[38;5;66;03m# reset logger connector\u001b[39;00m\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py:182\u001b[0m, in \u001b[0;36m_no_grad_context.._decorator\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 180\u001b[0m context_manager \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mno_grad\n\u001b[1;32m 181\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m context_manager():\n\u001b[0;32m--> 182\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mloop_run\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:135\u001b[0m, in \u001b[0;36m_EvaluationLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mis_last_batch \u001b[38;5;241m=\u001b[39m data_fetcher\u001b[38;5;241m.\u001b[39mdone\n\u001b[1;32m 134\u001b[0m \u001b[38;5;66;03m# run step hooks\u001b[39;00m\n\u001b[0;32m--> 135\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdataloader_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mdataloader_iter\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 136\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m:\n\u001b[1;32m 137\u001b[0m \u001b[38;5;66;03m# this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support\u001b[39;00m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py:396\u001b[0m, in \u001b[0;36m_EvaluationLoop._evaluation_step\u001b[0;34m(self, batch, batch_idx, dataloader_idx, dataloader_iter)\u001b[0m\n\u001b[1;32m 390\u001b[0m hook_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_step\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mtesting \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_step\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 391\u001b[0m step_args \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 392\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_build_step_args_from_hook_kwargs(hook_kwargs, hook_name)\n\u001b[1;32m 393\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m using_dataloader_iter\n\u001b[1;32m 394\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m (dataloader_iter,)\n\u001b[1;32m 395\u001b[0m )\n\u001b[0;32m--> 396\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhook_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mstep_args\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 398\u001b[0m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_processed()\n\u001b[1;32m 400\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m using_dataloader_iter:\n\u001b[1;32m 401\u001b[0m \u001b[38;5;66;03m# update the hook kwargs now that the step method might have consumed the iterator\u001b[39;00m\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:311\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 308\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 310\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 311\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 313\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 314\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:411\u001b[0m, in \u001b[0;36mStrategy.validation_step\u001b[0;34m(self, *args, 
**kwargs)\u001b[0m\n\u001b[1;32m 409\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 410\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 411\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/gsoc/nemo/NeMo-opensource/nemo/collections/asr/models/av_ctc_models.py:704\u001b[0m, in \u001b[0;36mAV_EncDecCTCModel.validation_step\u001b[0;34m(self, batch, batch_idx, dataloader_idx)\u001b[0m\n\u001b[1;32m 703\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mvalidation_step\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch, batch_idx, dataloader_idx\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m):\n\u001b[0;32m--> 704\u001b[0m metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_pass\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mdataloader_idx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 705\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mval_dataloaders) \u001b[38;5;241m==\u001b[39m \u001b[38;5;28mlist\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mval_dataloaders) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 706\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvalidation_step_outputs[dataloader_idx]\u001b[38;5;241m.\u001b[39mappend(metrics)\n", + "File \u001b[0;32m~/gsoc/nemo/NeMo-opensource/nemo/collections/asr/models/av_ctc_models.py:672\u001b[0m, in \u001b[0;36mAV_EncDecCTCModel.validation_pass\u001b[0;34m(self, batch, batch_idx, dataloader_idx)\u001b[0m\n\u001b[1;32m 665\u001b[0m \u001b[38;5;66;03m# if isinstance(batch, DALIOutputs) and batch.has_processed_signal:\u001b[39;00m\n\u001b[1;32m 666\u001b[0m \u001b[38;5;66;03m# log_probs, encoded_len, predictions = self.forward(\u001b[39;00m\n\u001b[1;32m 667\u001b[0m \u001b[38;5;66;03m# processed_signal=signal, processed_signal_length=signal_len\u001b[39;00m\n\u001b[1;32m 668\u001b[0m \u001b[38;5;66;03m# )\u001b[39;00m\n\u001b[1;32m 669\u001b[0m \u001b[38;5;66;03m# else:\u001b[39;00m\n\u001b[1;32m 670\u001b[0m log_probs, encoded_len, predictions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(audio_input_signal\u001b[38;5;241m=\u001b[39msignal, audio_input_signal_length\u001b[38;5;241m=\u001b[39msignal_len, video_input_signal\u001b[38;5;241m=\u001b[39mvideo_input_signal)\n\u001b[0;32m--> 672\u001b[0m loss_value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloss\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 673\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mlog_probs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_probs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtranscript\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_lengths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoded_len\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_lengths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtranscript_len\u001b[49m\n\u001b[1;32m 674\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 675\u001b[0m loss_value, metrics \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39madd_interctc_losses(\n\u001b[1;32m 676\u001b[0m loss_value, transcript, transcript_len, compute_wer\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, log_wer_num_denom\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, log_prefix\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mval_\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 677\u001b[0m )\n\u001b[1;32m 679\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwer\u001b[38;5;241m.\u001b[39mupdate(\n\u001b[1;32m 680\u001b[0m predictions\u001b[38;5;241m=\u001b[39mlog_probs, targets\u001b[38;5;241m=\u001b[39mtranscript, targets_lengths\u001b[38;5;241m=\u001b[39mtranscript_len, predictions_lengths\u001b[38;5;241m=\u001b[39mencoded_len,\n\u001b[1;32m 681\u001b[0m )\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: 
ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", + "File \u001b[0;32m~/gsoc/nemo/NeMo-opensource/nemo/core/classes/common.py:1064\u001b[0m, in \u001b[0;36mtypecheck.__call__\u001b[0;34m(self, wrapped, instance, args, kwargs)\u001b[0m\n\u001b[1;32m 1061\u001b[0m instance\u001b[38;5;241m.\u001b[39m_validate_input_types(input_types\u001b[38;5;241m=\u001b[39minput_types, ignore_collections\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mignore_collections, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1063\u001b[0m \u001b[38;5;66;03m# Call the method - this can be forward, or any other callable method\u001b[39;00m\n\u001b[0;32m-> 1064\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mwrapped\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1066\u001b[0m instance\u001b[38;5;241m.\u001b[39m_attach_and_validate_output_types(\n\u001b[1;32m 1067\u001b[0m output_types\u001b[38;5;241m=\u001b[39moutput_types, ignore_collections\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mignore_collections, out_objects\u001b[38;5;241m=\u001b[39moutputs\n\u001b[1;32m 1068\u001b[0m )\n\u001b[1;32m 1070\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n", + "File \u001b[0;32m~/gsoc/nemo/NeMo-opensource/nemo/collections/asr/losses/ctc.py:77\u001b[0m, in \u001b[0;36mCTCLoss.forward\u001b[0;34m(self, log_probs, targets, input_lengths, target_lengths)\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;66;03m# here we transpose because we expect [B, T, 
D] while PyTorch assumes [T, B, D]\u001b[39;00m\n\u001b[1;32m 76\u001b[0m log_probs \u001b[38;5;241m=\u001b[39m log_probs\u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m---> 77\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 78\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_probs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_probs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtargets\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_lengths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_lengths\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_lengths\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtarget_lengths\u001b[49m\n\u001b[1;32m 79\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 80\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_apply_reduction:\n\u001b[1;32m 81\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreduce(loss, target_lengths)\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/torch/nn/modules/loss.py:1785\u001b[0m, in \u001b[0;36mCTCLoss.forward\u001b[0;34m(self, log_probs, targets, input_lengths, target_lengths)\u001b[0m\n\u001b[1;32m 1784\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, log_probs: Tensor, targets: Tensor, input_lengths: Tensor, target_lengths: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m-> 1785\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mctc_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlog_probs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtargets\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_lengths\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_lengths\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mblank\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreduction\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1786\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mzero_infinity\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.miniconda3/envs/nemo/lib/python3.10/site-packages/torch/nn/functional.py:2687\u001b[0m, in \u001b[0;36mctc_loss\u001b[0;34m(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity)\u001b[0m\n\u001b[1;32m 2680\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_variadic(log_probs, targets, input_lengths, target_lengths):\n\u001b[1;32m 2681\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 2682\u001b[0m ctc_loss,\n\u001b[1;32m 2683\u001b[0m (log_probs, targets, input_lengths, target_lengths),\n\u001b[1;32m 2684\u001b[0m log_probs, targets, input_lengths, target_lengths,\n\u001b[1;32m 2685\u001b[0m blank\u001b[38;5;241m=\u001b[39mblank, reduction\u001b[38;5;241m=\u001b[39mreduction, zero_infinity\u001b[38;5;241m=\u001b[39mzero_infinity\n\u001b[1;32m 2686\u001b[0m )\n\u001b[0;32m-> 2687\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mctc_loss\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2688\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_probs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mtargets\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minput_lengths\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget_lengths\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mblank\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_Reduction\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_enum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreduction\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mzero_infinity\u001b[49m\n\u001b[1;32m 2689\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mRuntimeError\u001b[0m: blank must be in label range" + ] + } + ], + "source": [ + "trainer.fit(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model.summarize()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/balu_codes/train_av_asr.py b/balu_codes/train_av_asr.py new file mode 100644 index 000000000000..1270a2d5f7a0 --- /dev/null +++ b/balu_codes/train_av_asr.py @@ -0,0 +1,155 @@ +import os +import sys +sys.path.insert(0, os.path.abspath('/home/bld56/gsoc/nemo/NeMo-opensource/')) +import nemo.core as nemo_core +from nemo.core import adapter_mixins +from nemo.utils import exp_manager +import nemo.collections.asr as nemo_asr +import nemo +import json +from omegaconf import OmegaConf, open_dict +import torch +from pytorch_lightning import Trainer +from lightning.pytorch.loggers import WandbLogger +from torchmetrics.text import WordErrorRate +import warnings +import argparse + +# Function to 
def load_and_configure_model(config_file_path):
    """Instantiate the audio-visual CTC-BPE model from a YAML config.

    The YAML file is loaded, merged with whatever ``OmegaConf.from_cli()``
    extracts from the command line, switched to struct mode and handed to
    ``AV_EncDecCTCModelBPE``. The training data is then set up from
    ``model.cfg.train_ds``.

    Args:
        config_file_path: Path to the YAML configuration file.

    Returns:
        Tuple ``(model, conf)``. ``conf`` is the configuration exactly as
        loaded from disk, i.e. without the command-line overrides.
    """
    conf = OmegaConf.load(config_file_path)
    # NOTE(review): from_cli() reads sys.argv, which also carries the argparse
    # flags of this script (--config, --snr, ...) -- confirm that merging them
    # into the model config is intended.
    overrides = OmegaConf.from_cli()
    print(overrides)
    updated_conf = OmegaConf.merge(conf, overrides)
    OmegaConf.set_struct(updated_conf, True)
    model = nemo_asr.models.AV_EncDecCTCModelBPE(updated_conf)

    model.setup_training_data(model.cfg.train_ds)
    # NOTE(review): the un-merged `conf` is returned, so CLI overrides are not
    # visible to manage_model_adapters() -- confirm this is deliberate.
    return model, conf


def manage_model_adapters(model, conf):
    """Freeze the model and re-enable training for the configured sub-modules.

    Args:
        model: Model returned by ``load_and_configure_model``.
        conf: Configuration providing ``adapters.linear_adapter.keep`` and
            ``adapters.linear_adapter.name``.
    """
    # Freeze the entire model
    model.freeze()

    # Determine which modules to train based on configuration
    if model.cfg.use_video_modality:
        modules_to_train = [
            model.a_linear, model.v_linear, model.av_encoder, model.av_enocder_layer,
            model.a_modal_embs, model.v_modal_embs, model.decoder, model.a_pos_enc, model.v_pos_enc
        ]
    elif model.cfg.use_pretrained_dec:
        # Audio only, reusing the decoder of the pretrained audio model.
        modules_to_train = [model.a_model.decoder]
    else:
        # Audio only with a freshly created decoder.
        modules_to_train = [model.decoder]

    # Set the selected modules to training mode and enable gradients
    for module in modules_to_train:
        module.train()
        for param in module.parameters():
            param.requires_grad = True

    # Handle adapter configurations if needed
    if conf.adapters.linear_adapter.keep:
        # Train only the named adapter inside the audio model.
        model.a_model.freeze()
        model.a_model.set_enabled_adapters(enabled=False)
        model.a_model.set_enabled_adapters(name=conf.adapters.linear_adapter.name, enabled=True)
        model.a_model.unfreeze_enabled_adapters()
    else:
        model.a_model.unfreeze()


def setup_trainer():
    """Create the Lightning ``Trainer`` used by this script.

    Uses the GPU accelerator when CUDA is available (CPU otherwise), all
    visible devices, DDP with unused-parameter detection and 100 epochs.
    Checkpointing and logging are disabled on the trainer itself.

    Returns:
        The configured ``Trainer``.
    """
    torch.set_float32_matmul_precision('high')
    accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
    trainer = Trainer(
        devices=-1, accelerator=accelerator, strategy="ddp_find_unused_parameters_true",
        max_epochs=100,
        enable_checkpointing=False, logger=False,
        log_every_n_steps=5, check_val_every_n_epoch=1,
    )
    return trainer


def setup_exp_manager(trainer, model):
    """Attach the NeMo experiment manager (and optionally W&B) to the trainer.

    Args:
        trainer: Trainer created by ``setup_trainer``.
        model: Model whose ``cfg.exp_dir`` and ``cfg.wandb`` sections drive the setup.

    Returns:
        The log directory returned by ``exp_manager.exp_manager``.
    """
    os.environ.pop('NEMO_EXPM_VERSION', None)

    exp_config = exp_manager.ExpManagerConfig(
        exp_dir=model.cfg.exp_dir,
        name=f'{model.cfg.wandb.run_name}',
        checkpoint_callback_params=exp_manager.CallbackParams(
            monitor="val_u_wer",
            mode="min",
            always_save_nemo=True,
            save_best_model=True,
        ),
        create_wandb_logger=model.cfg.wandb.create_wandb_logger,
        wandb_logger_kwargs=OmegaConf.create({"project": f"{model.cfg.wandb.project}", "name": f"{model.cfg.wandb.run_name}_{model.cfg.train_ds.override_snr_ratio}", "log_model": model.cfg.wandb.log_model}),
    )

    exp_config = OmegaConf.structured(exp_config)
    logdir = exp_manager.exp_manager(trainer, exp_config)
    if model.cfg.wandb.create_wandb_logger:
        # NOTE(review): assumes the W&B logger is the second registered logger -- confirm.
        trainer.loggers[1].log_hyperparams(OmegaConf.to_container(model.cfg))  # wandb logger
        # log the manifest file to wandb server
        trainer.loggers[1].experiment.log_artifact(f"{model.cfg.train_ds.manifest_filepath}")
        trainer.loggers[1].experiment.log_artifact(f"{model.cfg.validation_ds.manifest_filepath}")

    return logdir


def selective_load(model, checkpoint_path):
    """Load only the checkpoint tensors whose name and shape match the model.

    Args:
        model: Model to update in place.
        checkpoint_path: Path to a checkpoint containing a ``'state_dict'`` entry.

    Returns:
        The same ``model`` instance after loading.
    """
    checkpoint = torch.load(checkpoint_path)
    state_dict = checkpoint['state_dict']

    # Filter out unnecessary keys
    model_state_dict = model.state_dict()
    filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict and v.size() == model_state_dict[k].size()}

    # Update the existing model state dict with the filtered state dict from the checkpoint
    model_state_dict.update(filtered_state_dict)

    # Load the updated state dict back into the model
    model.load_state_dict(model_state_dict)
    # print(f"Loaded keys from checkpoint: {filtered_state_dict.keys()}")
    print(model)
    return model


def main(config_file_path, args):
    """Build the model, optionally warm-start it, and run validation.

    Args:
        config_file_path: Path to the YAML configuration file.
        args: Parsed command-line namespace; only ``resume_pretrained`` is read here.
    """
    model, conf = load_and_configure_model(config_file_path)
    if args.resume_pretrained:
        ckpt_path = "/tmp/bld56_dataset_v1/saved_models/pre_av_ndec_uman_ntok--val_u_wer=0.0809-epoch=11.ckpt"
        # checkpoint = torch.load(ckpt_path)
        # print(checkpoint['state_dict'].keys())
        # model.load_state_dict(checkpoint['state_dict'])
        model = selective_load(model, ckpt_path)
        model.cfg.wandb.run_name += 'pre+'
    manage_model_adapters(model, conf)

    trainer = setup_trainer()
    model.set_trainer(trainer)
    logdir = setup_exp_manager(trainer, model)
    # trainer.fit(model)
    trainer.validate(model)


def _parse_bool(value):
    """Convert a command-line token such as ``"True"`` or ``"false"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string; this parser interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _build_arg_parser():
    """Return the argument parser for this script (flags and defaults unchanged)."""
    parser = argparse.ArgumentParser(description='Train AV ASR model')
    parser.add_argument('--config', type=int, default=5, help='Config number to use for training')
    parser.add_argument('--snr', type=float, default=0.7, help='SNR ratio to use for training')
    # NOTE(review): --gpus is parsed but never read; setup_trainer() uses devices=-1.
    parser.add_argument('--gpus', type=int, default=1, help='Number of GPUs to use for training')
    parser.add_argument('--resume_pretrained', type=_parse_bool, default=False, help='Resume training from pretrained model')
    return parser


if __name__ == "__main__":
    # add config number args
    args = _build_arg_parser().parse_args()
    config_file_path = f"/home/bld56/gsoc/nemo/NeMo-opensource/balu_codes/configs/c{args.config}.yaml"
    # load yaml file
    with open(config_file_path) as file:
        config = OmegaConf.load(file)
    config['train_ds']['override_snr_ratio'] = args.snr
    config['validation_ds']['override_snr_ratio'] = args.snr
    config['test_ds']['override_snr_ratio'] = args.snr
    # NOTE(review): the shared YAML file is rewritten in place; runs launched
    # concurrently against the same config number will overwrite each other.
    with open(config_file_path, 'w') as file:
        OmegaConf.save(config, file)
    warnings.filterwarnings("ignore", category=UserWarning, message="PySoundFile failed. Trying audioread instead.")
    warnings.filterwarnings("ignore", category=FutureWarning, message="librosa.core.audio.__audioread_load\n\tDeprecated as of librosa version 0.10.0.\n\tIt will be removed in librosa version 1.0.")
    main(config_file_path, args)
+# gput065 +# CUDA_VISIBLE_DEVICES=0 /home/bld56/.miniconda3/envs/nemo/bin/python train_av_asr.py --config 1 --snr 0.5 & +# sleep 100 +# CUDA_VISIBLE_DEVICES=1 /home/bld56/.miniconda3/envs/nemo/bin/python train_av_asr.py --config 5 --snr 0.5 & + +# gput063 +# CUDA_VISIBLE_DEVICES=0 /home/bld56/.miniconda3/envs/nemo/bin/python train_av_asr.py --config 6 --snr 0.5 & + +# PRETRAINING, gput063 +# bash /home/bld56/gsoc/general/set_up_node.sh + +# PRETRAINED USING +source activate /home/bld56/.miniconda3/envs/nemo +# bash /home/bld56/gsoc/general/set_up_node.sh + +# av_ndec_lman_ntok +# CUDA_VISIBLE_DEVICES=0 /home/bld56/.miniconda3/envs/nemo/bin/python train_av_asr.py --config 5 --snr 0.5 --resume_pretrained True & + +# sleep 40 +# av_ndec_uman_ntok +# CUDA_VISIBLE_DEVICES=1 /home/bld56/.miniconda3/envs/nemo/bin/python train_av_asr.py --config 9 --snr 0.5 --resume_pretrained True & + +# au_ndec_lman_ntok +CUDA_VISIBLE_DEVICES=0 /home/bld56/.miniconda3/envs/nemo/bin/python train_av_asr.py --config 10 --snr 0.5 --resume_pretrained True & + +sleep 40 +# au_ndec_uman_ntok +CUDA_VISIBLE_DEVICES=1 /home/bld56/.miniconda3/envs/nemo/bin/python train_av_asr.py --config 11 --snr 0.5 --resume_pretrained True & + + +wait \ No newline at end of file diff --git a/balu_codes/transcribe.py b/balu_codes/transcribe.py new file mode 100644 index 000000000000..0a8a1d1f6285 --- /dev/null +++ b/balu_codes/transcribe.py @@ -0,0 +1,12 @@ +# import nemo.collections.asr as nemo_asr +import sys +import os +sys.path.append(os.path.abspath('/home/bld56/gsoc/nemo/NeMo-opensource')) +import nemo.collections.asr as nemo_asr + +def load_model(model_name): + model = nemo_asr.models.ASRModel.from_pretrained(model_name) + return model +model = load_model("stt_en_conformer_ctc_large") +# model = load_model("QuartzNet15x5Base-En") +model.transcribe(["/disk1/it1/mixed_audios/009LTXtP4vE_c053b1_171114BCPC_SLASH_171114-BC-PC_DOT_mp3_00035.wav"]) \ No newline at end of file diff --git 
def get_av_char_dataset(config: dict, augmentor: Optional['AudioAugmentor'] = None) -> av_to_text.AVToCharDataset:
    """
    Build a character-level audio-visual dataset from a config mapping.

    Args:
        config: Settings for the AVToCharDataset. ``manifest_filepath`` and
            ``sample_rate`` are required; every other key falls back to the
            default given below.
        augmentor: Optional AudioAugmentor applied to the audio data.

    Returns:
        An instance of AVToCharDataset.
    """
    if 'labels' not in config:
        logging.warning("dataset does not have explicitly defined labels")

    option = config.get
    dataset_kwargs = dict(
        manifest_filepath=config['manifest_filepath'],
        labels=option('labels', None),
        sample_rate=config['sample_rate'],
        int_values=option('int_values', False),
        augmentor=augmentor,
        max_duration=option('max_duration', None),
        min_duration=option('min_duration', None),
        max_utts=option('max_utts', 0),
        blank_index=option('blank_index', -1),
        unk_index=option('unk_index', -1),
        normalize=option('normalize_transcripts', False),
        trim=option('trim_silence', False),
        parser=option('parser', 'en'),
        return_sample_id=option('return_sample_id', False),
        channel_selector=option('channel_selector', None),
        video_frame_rate=option('video_frame_rate', 3),
        get_vid_feats=option('get_vid_feats', False),
        get_zero_vid_feats=option('get_zero_vid_feats', False),
    )
    return av_to_text.AVToCharDataset(**dataset_kwargs)


def get_av_to_text_char_dataset_from_config(
    config, local_rank: int, global_rank: int, world_size: int, preprocessor_cfg: Optional[DictConfig] = None
):
    """
    Construct an AV-To-Text character dataset from a config.

    Args:
        config: dataset config
        local_rank: model local rank (unused here)
        global_rank: model global rank, forwarded to the augmentation setup
        world_size: world size, forwarded to the augmentation setup
        preprocessor_cfg: preprocessor config (unused here)

    Returns:
        The constructed dataset, or None when ``manifest_filepath`` is present
        in the config but set to None.
    """
    # Augmentations are resolved first, exactly as before, even if the
    # manifest check below ends up rejecting the config.
    augmentor = (
        process_augmentations(config['augmentor'], global_rank=global_rank, world_size=world_size)
        if 'augmentor' in config
        else None
    )

    manifest_is_null = 'manifest_filepath' in config and config['manifest_filepath'] is None
    if manifest_is_null:
        logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
        return None

    return get_av_char_dataset(config=config, augmentor=augmentor)


def get_av_bpe_dataset(
    config: dict, tokenizer: 'TokenizerSpec', augmentor: Optional['AudioAugmentor'] = None
) -> av_to_text.AVToBPEDataset:
    """
    Build a BPE / word-piece audio-visual dataset from a config mapping.

    Args:
        config: Settings for the AVToBPEDataset. ``manifest_filepath`` and
            ``sample_rate`` are required; every other key falls back to the
            default given below.
        tokenizer: An instance of a TokenizerSpec object.
        augmentor: Optional AudioAugmentor applied to the audio data.

    Returns:
        An instance of AVToBPEDataset.
    """
    option = config.get
    dataset_kwargs = dict(
        manifest_filepath=config['manifest_filepath'],
        tokenizer=tokenizer,
        sample_rate=config['sample_rate'],
        int_values=option('int_values', False),
        augmentor=augmentor,
        max_duration=option('max_duration', None),
        min_duration=option('min_duration', None),
        max_utts=option('max_utts', 0),
        trim=option('trim_silence', False),
        use_start_end_token=option('use_start_end_token', True),
        return_sample_id=option('return_sample_id', False),
        channel_selector=option('channel_selector', None),
        video_frame_rate=option('video_frame_rate', 3),
        get_vid_feats=option('get_vid_feats', False),
        get_zero_vid_feats=option('get_zero_vid_feats', False),
        override_snr_ratio=option('override_snr_ratio', None),
    )
    return av_to_text.AVToBPEDataset(**dataset_kwargs)


def get_av_to_text_bpe_dataset_from_config(
    config,
    local_rank: int,
    global_rank: int,
    world_size: int,
    tokenizer,
    preprocessor_cfg: Optional[DictConfig] = None,
):
    """
    Construct an AV-To-Text BPE dataset from a config.

    Args:
        config: BPE dataset config
        local_rank: model local rank (unused here)
        global_rank: model global rank, forwarded to the augmentation setup
        world_size: world size, forwarded to the augmentation setup
        tokenizer: BPE tokenizer
        preprocessor_cfg: preprocessor config (unused here)

    Returns:
        The constructed dataset, or None when ``manifest_filepath`` is present
        in the config but set to None.
    """
    # Augmentations are resolved first, exactly as before, even if the
    # manifest check below ends up rejecting the config.
    augmentor = (
        process_augmentations(config['augmentor'], global_rank=global_rank, world_size=world_size)
        if 'augmentor' in config
        else None
    )

    manifest_is_null = 'manifest_filepath' in config and config['manifest_filepath'] is None
    if manifest_is_null:
        logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}")
        return None

    return get_av_bpe_dataset(config=config, tokenizer=tokenizer, augmentor=augmentor)
def _speech_collate_fn(batch, pad_id, get_vid_feats):
    """Collate a list of samples into padded batch tensors.

    Each sample is a tuple laid out as::

        (audio, audio_len, [video_feat,] tokens, tokens_len, label[, sample_id])

    where ``video_feat`` is present only when ``get_vid_feats`` is true and the
    trailing ``sample_id`` is optional. The signals are expected to be 1d torch
    tensors (i.e. mono audio).

    Args:
        batch: Sequence of sample tuples as described above.
        pad_id: Value used to right-pad token sequences to the longest one.
        get_vid_feats: Whether each sample carries a video feature tensor.

    Returns:
        Tuple ``(audio_signal, audio_lengths, [video_feat_signal,] tokens,
        tokens_lengths, labels[, sample_ids])``. ``audio_signal`` and
        ``audio_lengths`` are ``None`` when the first sample has no audio
        length. ``sample_ids`` is an int32 tensor and only present when the
        samples carry an id.

    Raises:
        ValueError: If the samples do not have the expected number of fields.
    """
    packed_batch = list(zip(*batch))

    # Number of fields in a sample that does not carry a sample id.
    num_fields = 6 if get_vid_feats else 5
    if len(packed_batch) == num_fields + 1:
        sample_ids = packed_batch[-1]
    elif len(packed_batch) == num_fields:
        sample_ids = None
    else:
        raise ValueError(f"Expects {num_fields} or {num_fields + 1} tensors in the batch!")

    audio_lengths = packed_batch[1]
    # tokens_len sits right before the label in both layouts.
    tokens_lengths = packed_batch[num_fields - 2]

    max_audio_len = 0
    has_audio = audio_lengths[0] is not None
    if has_audio:
        max_audio_len = max(audio_lengths).item()
    max_tokens_len = max(tokens_lengths).item()

    audio_signal, tokens, video_feat_signal, labels = [], [], [], []
    for sample in batch:
        sig, sig_len = sample[0], sample[1]
        if get_vid_feats:
            video_feat, tokens_i, tokens_i_len, label = sample[2:6]
        else:
            tokens_i, tokens_i_len, label = sample[2:5]

        if has_audio:
            sig_len = sig_len.item()
            if sig_len < max_audio_len:
                pad = (0, max_audio_len - sig_len)
                sig = torch.nn.functional.pad(sig, pad)
            audio_signal.append(sig)
        if get_vid_feats:
            video_feat_signal.append(video_feat)
        tokens_i_len = tokens_i_len.item()
        if tokens_i_len < max_tokens_len:
            pad = (0, max_tokens_len - tokens_i_len)
            tokens_i = torch.nn.functional.pad(tokens_i, pad, value=pad_id)
        tokens.append(tokens_i)
        labels.append(label)

    if has_audio:
        audio_signal = torch.stack(audio_signal)
        audio_lengths = torch.stack(audio_lengths)
    else:
        audio_signal, audio_lengths = None, None
    if get_vid_feats:
        # Video features are stacked without padding, so every sample must
        # provide a tensor of the same shape.
        video_feat_signal = torch.stack(video_feat_signal)
    tokens = torch.stack(tokens)
    tokens_lengths = torch.stack(tokens_lengths)
    labels = torch.stack(labels)
    base_output = [audio_signal, audio_lengths, tokens, tokens_lengths, labels]

    if get_vid_feats:
        base_output.insert(2, video_feat_signal)

    if sample_ids is not None:
        sample_ids = torch.tensor(sample_ids, dtype=torch.int32)
        base_output.append(sample_ids)

    return tuple(base_output)
class ASR_AV_ManifestProcessor:
    """
    Class that processes a manifest json file containing paths to audio files, transcripts, and durations (in seconds).
    Each new line is a different sample. Example below:
    {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147}
    ...
    {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt":
    "utterance_id", "ctm_utt": "en_4156", "side": "A", "video_featpath": "/path/to/video_feat.npy"}
    Args:
        manifest_filepath: Path to manifest json as described above. Can be comma-separated paths.
        parser: Str for a language specific preprocessor or a callable.
        max_duration: If audio exceeds this length, do not include in dataset.
        min_duration: If audio is less than this length, do not include in dataset.
        max_utts: Limit number of utterances.
        bos_id: Id of beginning of sequence symbol to append if not None.
        eos_id: Id of end of sequence symbol to append if not None.
        pad_id: Id of pad symbol. Defaults to 0.
        index_by_file_id: Forwarded to the underlying collection.
    """

    def __init__(
        self,
        manifest_filepath: str,
        parser: Union[str, Callable],
        max_duration: Optional[float] = None,
        min_duration: Optional[float] = None,
        max_utts: int = 0,
        bos_id: Optional[int] = None,
        eos_id: Optional[int] = None,
        pad_id: int = 0,
        index_by_file_id: bool = False,
    ):
        self.parser = parser

        self.collection = collections.ASR_AV_AudioText(
            manifests_files=manifest_filepath,
            parser=parser,
            min_duration=min_duration,
            max_duration=max_duration,
            max_number=max_utts,
            index_by_file_id=index_by_file_id,
        )

        self.eos_id = eos_id
        self.bos_id = bos_id
        self.pad_id = pad_id

    def process_text_by_id(self, index: int) -> Tuple[List[int], int]:
        """Return ``(tokens, length)`` for the sample at position ``index`` of the collection."""
        sample = self.collection[index]
        return self.process_text_by_sample(sample)

    def process_text_by_file_id(self, file_id: str) -> Tuple[List[int], int]:
        """Return ``(tokens, length)`` for the first manifest entry mapped to ``file_id``."""
        manifest_idx = self.collection.mapping[file_id][0]
        sample = self.collection[manifest_idx]
        return self.process_text_by_sample(sample)

    def process_text_by_sample(self, sample: collections.ASR_AV_AudioText.OUTPUT_TYPE) -> Tuple[List[int], int]:
        """Return the sample's text tokens, wrapped in BOS/EOS ids when configured, and their count."""
        t, tl = sample.text_tokens, len(sample.text_tokens)

        if self.bos_id is not None:
            t = [self.bos_id] + t
            tl += 1
        if self.eos_id is not None:
            t = t + [self.eos_id]
            tl += 1

        return t, tl


def cache_datastore_manifests(
    manifest_filepaths: Union[str, List[str]],
    cache_audio: bool = False,
    shared_cache: Optional[bool] = None,
    num_workers: Optional[int] = None,
    max_num_workers: int = 20,
):
    """Cache manifests and audio from an object store.
    It is assumed that remote manifests are using relative paths.

    Nothing is done unless at least one manifest path is a datastore path.

    Args:
        manifest_filepaths: list of paths to manifest files (list of strings or a string with `,` as separator)
        cache_audio: If True, audio from manifest will also be cached
        shared_cache: Optional, True if cache is shared across all nodes
        num_workers: Optional, number of workers to be used for download
        max_num_workers: max number of workers to be used for download, used when setting num_workers automatically

    Raises:
        RuntimeError: If some audio files could not be downloaded, or if torch
            distributed is not initialized and this is not global rank zero.
    """
    if isinstance(manifest_filepaths, str):
        manifest_filepaths = manifest_filepaths.split(',')

    num_datastore_manifests = sum(
        [is_datastore_path(f) for f in manifest_filepaths])

    if num_datastore_manifests > 0:
        # Local utility function
        def cache_data(manifest_filepaths, cache_audio, num_workers, max_num_workers):
            """Cache manifests and audio data from object store.
            """
            # Determine the number of workers to use
            if num_workers is None:
                # os.cpu_count() may return None when the count cannot be determined.
                num_workers = (os.cpu_count() or 1) - 1
            num_workers = min(num_workers, max_num_workers)

            # Process each manifest file
            for manifest_file in manifest_filepaths:
                # If manifest is on a data store, then cache it.
                # Otherwise, nothing to do.
                if is_datastore_path(manifest_file):
                    logging.info('Cache manifest file: %s', manifest_file)
                    cached_manifest_file = DataStoreObject(manifest_file).get()
                    logging.info('Cached at: %s', str(cached_manifest_file))

                    if cache_audio:
                        # Each audio file from manifest will be cached.
                        # NOTE(review): only `audio_filepath` entries are cached;
                        # confirm whether video features need caching as well.
                        logging.info(
                            'Cache audio from manifest file: %s', manifest_file)
                        # Assumes that manifest is using relative paths
                        manifest_dir = os.path.dirname(manifest_file)
                        # Prepare all store objects
                        audio_objects = []
                        with open(cached_manifest_file, 'r') as f:
                            for line in f:
                                item = json.loads(line)
                                store_path = os.path.join(
                                    manifest_dir, item['audio_filepath'])
                                audio_objects.append(
                                    DataStoreObject(store_path=store_path))

                        if num_workers is not None and num_workers > 1:
                            logging.debug(
                                'Using multiprocessing with num_workers: %d.', num_workers)
                            with multiprocessing.Pool(processes=num_workers) as p:
                                result = list(
                                    tqdm(p.imap(datastore_object_get, audio_objects), total=len(
                                        audio_objects))
                                )
                        else:
                            logging.debug('Using a single process.')
                            result = []
                            for audio_object in tqdm(audio_objects):
                                result.append(audio_object.get() is not None)

                        if not all(result):
                            raise RuntimeError(
                                'Some files not downloaded successfully')
                        logging.info('Caching complete')

                else:
                    # Nothing to do here
                    logging.debug(
                        'Manifest is not on a data store: %s', manifest_file)

        if torch.distributed.is_available() and torch.distributed.is_initialized():
            logging.debug(
                'Distributed environment is available and initialized.')

            # Handle distributed environment
            if shared_cache is None:
                shared_cache = is_datastore_cache_shared()

            if shared_cache:
                logging.debug(
                    'Cache is shared among nodes, cache data on global rank zero.')
                is_rank_zero = is_global_rank_zero()
            else:
                logging.debug(
                    'Cache is not shared among nodes, cache data on local rank zero.')
                local_rank = int(os.environ.get("LOCAL_RANK", 0))
                is_rank_zero = local_rank == 0

            if is_rank_zero:
                logging.info('Cache data from %s rank 0',
                             'global' if shared_cache else 'local')
                cache_data(
                    manifest_filepaths=manifest_filepaths,
                    cache_audio=cache_audio,
                    num_workers=num_workers,
                    max_num_workers=max_num_workers,
                )
            # Every rank waits here until the caching rank has finished.
            logging.debug('Reached barrier')
            torch.distributed.barrier()

        elif is_global_rank_zero():
            # Handle non-distributed environment, e.g., if running on a single GPU
            logging.warning(
                'Torch distributed is not initialized and caching may be prone to data race conditions. '
                'Now caching data from global rank 0. If there are other ranks and they pass this '
                'before rank 0, errors might result.'
            )
            cache_data(
                manifest_filepaths=manifest_filepaths,
                cache_audio=cache_audio,
                num_workers=num_workers,
                max_num_workers=max_num_workers,
            )
        else:
            raise RuntimeError(
                'Torch distributed is not initialized and caching on nodes other than global rank zero is disabled '
                'to avoid race condition between different ranks. To ensure distributed environment is '
                'initialized, please update data config to use `defer_setup = True`.'
            )
Defaults to False + bos_id: Id of beginning of sequence symbol to append if not None + eos_id: Id of end of sequence symbol to append if not None + pad_id: Id of pad symbol. Defaults to 0 + return_sample_id (bool): whether to return the sample_id as a part of each sample + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. + video_frame_rate (int): Frame rate of video, used to calculate duration of video + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. + """ + return { + 'audio_signal': [NeuralType(('B', 'T'), AudioSignal())], + 'a_sig_length': [NeuralType(tuple('B'), LengthsType())], + 'video_input_signal': [NeuralType(('B', 'T', 'D'), ChannelType(), optional=True)], + 'transcripts': [NeuralType(('B', 'T'), LabelsType())], + 'transcript_length': [NeuralType(tuple('B'), LengthsType())], + 'sample_id': [NeuralType(tuple('B'), LengthsType(), optional=True)], + } + + def __init__( + self, + manifest_filepath: str, + parser: Union[str, Callable], + sample_rate: int, + int_values: bool = False, + augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None, + max_duration: Optional[int] = None, + min_duration: Optional[int] = None, + max_utts: int = 0, + trim: bool = False, + bos_id: Optional[int] = None, + eos_id: Optional[int] = None, + pad_id: int = 0, + return_sample_id: bool = False, + channel_selector: Optional[ChannelSelectorType] = None, + video_frame_rate: int = 5, + get_vid_feats: bool = True, + get_zero_vid_feats: bool = False, + override_snr_ratio: Optional[float] = None, + ): + if type(manifest_filepath) == str: + manifest_filepath = manifest_filepath.split(",") + + # If necessary, cache manifests and audio from object store + # TODO: @Balu, include cache_video + 
cache_datastore_manifests( + manifest_filepaths=manifest_filepath, cache_audio=True) + + self.manifest_processor = ASR_AV_ManifestProcessor( + manifest_filepath=manifest_filepath, + parser=parser, + max_duration=max_duration, + min_duration=min_duration, + max_utts=max_utts, + bos_id=bos_id, + eos_id=eos_id, + pad_id=pad_id, + ) + self.featurizer = WaveformFeaturizer( + sample_rate=sample_rate, int_values=int_values, augmentor=augmentor) + self.trim = trim + self.return_sample_id = return_sample_id + self.channel_selector = channel_selector + self.video_frame_rate = video_frame_rate + self.get_vid_feats = get_vid_feats + self.get_zero_vid_feats = get_zero_vid_feats + self.override_snr_ratio = override_snr_ratio + + def get_manifest_sample(self, sample_id): + return self.manifest_processor.collection[sample_id] + + def __getitem__(self, index): + if isinstance(index, IterableABC): + return [self._process_sample(_index) for _index in index] + else: + return self._process_sample(index) + + def calculate_rms(self, audio): + """Calculate the RMS (root mean square) level of an audio signal.""" + return torch.sqrt(torch.mean(audio ** 2)) + + def adjust_volume(self, audio, target_rms): + """Adjust the audio's volume to a target RMS level.""" + current_rms = self.calculate_rms(audio) + return audio * (target_rms / (current_rms + 1e-9)) # Avoid division by zero + + def _mix_audios(self, noisy_audio_feats, clean_audio_feats, snr, target_sr=16000): + if self.override_snr_ratio is not None: + snr = self.override_snr_ratio + rms1 = self.calculate_rms(clean_audio_feats) + rms2 = self.calculate_rms(noisy_audio_feats) + mean_rms = (rms1 + rms2) / 2 + + noisy_audio_feats = self.adjust_volume(noisy_audio_feats, mean_rms) + clean_audio_feats = self.adjust_volume(clean_audio_feats, mean_rms) + + assert len(clean_audio_feats) >= 10*target_sr, f"Audio length is too short: {len(clean_audio_feats)}" + + if len(noisy_audio_feats) < len(clean_audio_feats): + noisy_audio_feats = 
torch.nn.functional.pad(noisy_audio_feats, (0, len(clean_audio_feats) - len(noisy_audio_feats))) + + min_len = min(10*target_sr, len(clean_audio_feats)) + noisy_audio_feats = noisy_audio_feats[:min_len] + clean_audio_feats = clean_audio_feats[:min_len] + + mixed_audio = snr * clean_audio_feats + (1 - snr) * noisy_audio_feats + + return mixed_audio + + + def _process_sample(self, index): + sample = self.manifest_processor.collection[index] + offset = sample.offset + + if offset is None: + offset = 0 + + clean_audio_features = self.featurizer.process( + sample.audio_file, + offset=offset, + duration=sample.duration, + trim=self.trim, + orig_sr=sample.orig_sr, + channel_selector=self.channel_selector, + ) + if self.override_snr_ratio != 0.0: + audio = AudioSegment.from_file(sample.video_file, format="mp4") + samples_pydub = np.array(audio.get_array_of_samples(), dtype=np.float32) + noise_features = torch.tensor(samples_pydub, dtype=torch.float32) + noise_features = noise_features / (2**(8 * audio.sample_width) / 2) + mixed_features = self._mix_audios(noise_features, clean_audio_features, snr = sample.snr) + else: + mixed_features = clean_audio_features + f, fl = mixed_features, torch.tensor(mixed_features.shape[0]).long() + + # TODO: @Balu, saving audio temporarily + # os.makedirs(f"/tmp/bld56_dataset_v1/audioset/temp_sample_check/snr_{self.override_snr_ratio}", exist_ok=True) + # save_audio_path = f"/tmp/bld56_dataset_v1/audioset/temp_sample_check/snr_{self.override_snr_ratio}/{sample.video_file.split('/')[-1].split('.')[0]}_{sample.audio_file.split('/')[-1].split('.')[0]}.wav" + # import torchaudio + # torchaudio.save(save_audio_path, f.unsqueeze(0), 16000) + + if self.get_vid_feats: + if not self.get_zero_vid_feats: + # check if file exists + assert os.path.exists( + sample.video_featfile), f"Video feature file {sample.video_featfile} does not exist" + vf = np.load(sample.video_featfile) + # uniformly sample self.video_frame_rate frames from video at shape 0. 
+ assert vf.shape[0] == self.video_frame_rate*sample.duration, f"Video feature file {sample.video_featfile} has {vf.shape[0]} frame_feats, expected {self.video_frame_rate}" + vf = torch.from_numpy(vf) + # make it torch float + vf = vf.float() + else: + vf = torch.zeros( + self.video_frame_rate*sample.duration, 768).float() + + t, tl = self.manifest_processor.process_text_by_sample(sample=sample) + + output = [f, fl, torch.tensor(t).long(), torch.tensor(tl).long(), torch.tensor(sample.label).long()] + + if self.get_vid_feats: + output.insert(2, vf) + + if self.return_sample_id: + output.append(index) + + output = tuple(output) + + return output + + def __len__(self): + # return 5 + # return 100 + return len(self.manifest_processor.collection) + + def _collate_fn(self, batch): + return _speech_collate_fn(batch, pad_id=self.manifest_processor.pad_id, get_vid_feats=self.get_vid_feats) + + +class AVToCharDataset(_AVTextDataset): + """ + Dataset that loads tensors via a json file containing paths to audio + files, transcripts, and durations (in seconds). Each new line is a + different sample. Example below: + {"audio_filepath": "/path/to/audio.wav", "text_filepath": + "/path/to/audio.txt", "duration": 23.147} + ... + {"audio_filepath": "/path/to/audio.wav", "text": "the + transcription", "offset": 301.75, "duration": 0.82, "utt": + "utterance_id", "ctm_utt": "en_4156", "side": "A"} + + Args: + manifest_filepath: Path to manifest json as described above. Can + be comma-separated paths. + labels: String containing all the possible characters to map to + sample_rate (int): Sample rate to resample loaded audio to + int_values (bool): If true, load samples as 32-bit integers. Defauts to False. 
+ augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor + object used to augment loaded audio + max_duration: If audio exceeds this length, do not include in dataset + min_duration: If audio is less than this length, do not include + in dataset + max_utts: Limit number of utterances + blank_index: blank character index, default = -1 + unk_index: unk_character index, default = -1 + normalize: whether to normalize transcript text (default): True + bos_id: Id of beginning of sequence symbol to append if not None + eos_id: Id of end of sequence symbol to append if not None + return_sample_id (bool): whether to return the sample_id as a part of each sample + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. + video_frame_rate (int): Frame rate of video, used to calculate duration of video + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. 
+ """ + return { + 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), + 'a_sig_length': NeuralType(tuple('B'), LengthsType()), + 'hidden_states': NeuralType(('B', 'T', 'D'), ImageFeatureValue(), optional=True), + 'transcripts': NeuralType(('B', 'T'), LabelsType()), + 'transcript_length': NeuralType(tuple('B'), LengthsType()), + 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), + } + + def __init__( + self, + manifest_filepath: str, + labels: Union[str, List[str]], + sample_rate: int, + int_values: bool = False, + augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None, + max_duration: Optional[float] = None, + min_duration: Optional[float] = None, + max_utts: int = 0, + blank_index: int = -1, + unk_index: int = -1, + normalize: bool = True, + trim: bool = False, + bos_id: Optional[int] = None, + eos_id: Optional[int] = None, + pad_id: int = 0, + parser: Union[str, Callable] = 'en', + return_sample_id: bool = False, + channel_selector: Optional[ChannelSelectorType] = None, + video_frame_rate: int = 3, + get_vid_feats: bool = True, + get_zero_vid_feats: bool = False, + override_snr_ratio: Optional[float] = None, + ): + self.labels = labels + + parser = parsers.make_parser( + labels=labels, name=parser, unk_id=unk_index, blank_id=blank_index, do_normalize=normalize + ) + + super().__init__( + manifest_filepath=manifest_filepath, + parser=parser, + sample_rate=sample_rate, + int_values=int_values, + augmentor=augmentor, + max_duration=max_duration, + min_duration=min_duration, + max_utts=max_utts, + trim=trim, + bos_id=bos_id, + eos_id=eos_id, + pad_id=pad_id, + return_sample_id=return_sample_id, + channel_selector=channel_selector, + video_frame_rate=video_frame_rate, + get_vid_feats=get_vid_feats, + get_zero_vid_feats=get_zero_vid_feats, + override_snr_ratio=override_snr_ratio, + ) + + +class AVToBPEDataset(_AVTextDataset): + """ + Dataset that loads tensors via a json file containing paths to audio + files, transcripts, and 
durations (in seconds). Each new line is a + different sample. Example below: + {"audio_filepath": "/path/to/audio.wav", "text_filepath": + "/path/to/audio.txt", "duration": 23.147} + ... + {"audio_filepath": "/path/to/audio.wav", "text": "the + transcription", "offset": 301.75, "duration": 0.82, "utt": + "utterance_id", "ctm_utt": "en_4156", "side": "A"} + + In practice, the dataset and manifest used for character encoding and byte pair encoding + are exactly the same. The only difference lies in how the dataset tokenizes the text in + the manifest. + + Args: + manifest_filepath: Path to manifest json as described above. Can + be comma-separated paths. + tokenizer: A subclass of the Tokenizer wrapper found in the common collection, + nemo.collections.common.tokenizers.TokenizerSpec. ASR Models support a subset of + all available tokenizers. + sample_rate (int): Sample rate to resample loaded audio to + int_values (bool): If true, load samples as 32-bit integers. Defauts to False. + augmentor (nemo.collections.asr.parts.perturb.AudioAugmentor): An AudioAugmentor + object used to augment loaded audio + max_duration: If audio exceeds this length, do not include in dataset + min_duration: If audio is less than this length, do not include + in dataset + max_utts: Limit number of utterances + trim: Whether to trim silence segments + use_start_end_token: Boolean which dictates whether to add [BOS] and [EOS] + tokens to beginning and ending of speech respectively. + return_sample_id (bool): whether to return the sample_id as a part of each sample + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. 
+ video_frame_rate (int): Frame rate of video, used to calculate duration of video + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports. + """ + if self.get_vid_feats: + return { + 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), + 'a_sig_length': NeuralType(tuple('B'), LengthsType()), + 'hidden_states': NeuralType(('B', 'T', 'D'), ImageFeatureValue(), optional=True), + 'transcripts': NeuralType(('B', 'T'), LabelsType()), + 'transcript_length': NeuralType(tuple('B'), LengthsType()), + 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), + } + else: + return { + 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), + 'a_sig_length': NeuralType(tuple('B'), LengthsType()), + 'transcripts': NeuralType(('B', 'T'), LabelsType()), + 'transcript_length': NeuralType(tuple('B'), LengthsType()), + 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), + } + + + def __init__( + self, + manifest_filepath: str, + tokenizer: 'nemo.collections.common.tokenizers.TokenizerSpec', + sample_rate: int, + int_values: bool = False, + augmentor: 'nemo.collections.asr.parts.perturb.AudioAugmentor' = None, + max_duration: Optional[int] = None, + min_duration: Optional[int] = None, + max_utts: int = 0, + trim: bool = False, + use_start_end_token: bool = True, + return_sample_id: bool = False, + channel_selector: Optional[ChannelSelectorType] = None, + video_frame_rate: int = 3, + get_vid_feats: bool = True, + get_zero_vid_feats: bool = False, + override_snr_ratio: Optional[float] = None, + ): + if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0: + bos_id = tokenizer.bos_id + else: + bos_id = None + + if use_start_end_token and hasattr(tokenizer, "eos_id") and tokenizer.eos_id > 0: + eos_id = tokenizer.eos_id + else: + eos_id = None + + if hasattr(tokenizer, "pad_id") and tokenizer.pad_id > 0: + pad_id = tokenizer.pad_id + else: + pad_id = 0 + + class 
TokenizerWrapper: + def __init__(self, tokenizer): + if isinstance(tokenizer, tokenizers.aggregate_tokenizer.AggregateTokenizer): + self.is_aggregate = True + else: + self.is_aggregate = False + self._tokenizer = tokenizer + + def __call__(self, *args): + if isinstance(args[0], List) and self.is_aggregate: + t = [] + for span in args[0]: + t.extend(self._tokenizer.text_to_ids( + span['str'], span['lang'])) + return t + + t = self._tokenizer.text_to_ids(*args) + return t + + super().__init__( + manifest_filepath=manifest_filepath, + parser=TokenizerWrapper(tokenizer), + sample_rate=sample_rate, + int_values=int_values, + augmentor=augmentor, + max_duration=max_duration, + min_duration=min_duration, + max_utts=max_utts, + bos_id=bos_id, + eos_id=eos_id, + pad_id=pad_id, + trim=trim, + return_sample_id=return_sample_id, + channel_selector=channel_selector, + video_frame_rate=video_frame_rate, + get_vid_feats=get_vid_feats, + get_zero_vid_feats=get_zero_vid_feats, + override_snr_ratio=override_snr_ratio, + ) diff --git a/nemo/collections/asr/metrics/av_wer.py b/nemo/collections/asr/metrics/av_wer.py new file mode 100644 index 000000000000..26ed7a092f64 --- /dev/null +++ b/nemo/collections/asr/metrics/av_wer.py @@ -0,0 +1,263 @@ +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

from typing import List, Optional, Tuple, Union

import editdistance
import jiwer
import torch
from torchmetrics import Metric

from nemo.collections.asr.parts.submodules.ctc_decoding import AbstractCTCDecoding
from nemo.collections.asr.parts.submodules.multitask_decoding import AbstractMultiTaskDecoding
from nemo.collections.asr.parts.submodules.rnnt_decoding import AbstractRNNTDecoding
from nemo.utils import logging

import regex as re

__all__ = ['AV_WER']


def move_dimension_to_the_front(tensor, dim_index):
    """Return ``tensor`` permuted so that dimension ``dim_index`` comes first; the other dims keep their order."""
    all_dims = list(range(tensor.ndim))
    return tensor.permute(*([dim_index] + all_dims[:dim_index] + all_dims[dim_index + 1 :]))


class AV_WER(Metric):
    """
    This metric computes numerator and denominator for Overall Word Error Rate (WER) between prediction and reference
    texts. When doing distributed training/evaluation the result of ``res=WER(predictions, predictions_lengths, targets, target_lengths)``
    calls will be all-reduced between all workers using SUM operations.

    This also has options to compute WER with tags, without tags and accuracy for tag prediction too.
    TODO @Balu: Can also integrate spans of the tag predicted.

    If used with PytorchLightning LightningModule, include wer_numerator and wer_denominators inside validation_step
    results. Then aggregate (sum) then at the end of validation epoch to correctly compute validation WER.

    Example:
        def validation_step(self, batch, batch_idx):
            ...
            wer_num, wer_denom = self.__wer(predictions, predictions_len, transcript, transcript_len)
            self.val_outputs = {'val_loss': loss_value, 'val_wer_num': wer_num, 'val_wer_denom': wer_denom}
            return self.val_outputs

        def on_validation_epoch_end(self):
            ...
            wer_num = torch.stack([x['val_wer_num'] for x in self.val_outputs]).sum()
            wer_denom = torch.stack([x['val_wer_denom'] for x in self.val_outputs]).sum()
            tensorboard_logs = {'validation_loss': val_loss_mean, 'validation_avg_wer': wer_num / wer_denom}
            self.val_outputs.clear()  # free memory
            return {'val_loss': val_loss_mean, 'log': tensorboard_logs}

    Args:
        decoding: An instance of CTCDecoding, RNNTDecoding or MultiTaskDecoding.
        use_cer: Whether to use Character Error Rate instead of Word Error Rate.
        log_prediction: Whether to log a single decoded sample per call.
        fold_consecutive: Forwarded to the CTC decoder when ``decoding`` is a CTC decoding object.
        batch_dim_index: Index corresponding to batch dimension. (For RNNT.)
        dist_sync_on_step: Whether to perform reduction on forward pass of metric.
        labelled_manifest: Whether the manifest has labels or not.

    Returns:
        res: the 5-tuple produced by ``compute``:
            ``(labelled_wer, unlabelled_wer, label_accuracy, scores_unlabelled, words_unlabelled)``.
            ``labelled_wer`` is None when ``labelled_manifest`` is False; the others are float tensors.
    """

    full_state_update: bool = True

    # Inline label tags are delimited by '<' and '>' (see the comments in
    # ``seperate_labels_from_labelled_data``).
    _TAG_PATTERN = re.compile(r'<[^<>]+>')

    def __init__(
        self,
        decoding: Union[AbstractCTCDecoding, AbstractRNNTDecoding, AbstractMultiTaskDecoding],
        use_cer=False,
        log_prediction=True,
        fold_consecutive=True,
        batch_dim_index=0,
        dist_sync_on_step=False,
        labelled_manifest=False,
    ):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.decoding = decoding
        self.use_cer = use_cer
        self.log_prediction = log_prediction
        self.fold_consecutive = fold_consecutive
        self.batch_dim_index = batch_dim_index

        self.has_spl_tokens = False
        self.decode = None
        # Every decode callable takes the same five arguments so ``update`` can call it uniformly;
        # each decoder type only uses the ones it needs.
        if isinstance(self.decoding, AbstractRNNTDecoding):
            self.decode = lambda predictions, predictions_lengths, predictions_mask, input_ids, targets: self.decoding.rnnt_decoder_predictions_tensor(
                encoder_output=predictions, encoded_lengths=predictions_lengths
            )
        elif isinstance(self.decoding, AbstractCTCDecoding):
            self.decode = lambda predictions, predictions_lengths, predictions_mask, input_ids, targets: self.decoding.ctc_decoder_predictions_tensor(
                decoder_outputs=predictions,
                decoder_lengths=predictions_lengths,
                fold_consecutive=self.fold_consecutive,
            )
        elif isinstance(self.decoding, AbstractMultiTaskDecoding):
            self.has_spl_tokens = True
            self.decode = lambda predictions, prediction_lengths, predictions_mask, input_ids, targets: self.decoding.decode_predictions_tensor(
                encoder_hidden_states=predictions,
                encoder_input_mask=predictions_mask,
                decoder_input_ids=input_ids,
                return_hypotheses=False,
            )
        else:
            raise TypeError(f"WER metric does not support decoding of type {type(self.decoding)}")

        self.add_state("scores_labelled", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)
        self.add_state("words_labelled", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)
        self.add_state("scores_unlabelled", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)
        self.add_state("words_unlabelled", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)
        self.add_state("correct_label_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)
        self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)

        self.labelled_manifest = labelled_manifest

    def get_words_and_scores(self, hypotheses: List[str], references: List[str], labelled_data: bool):
        """Store the edit distance and reference length totals for one batch.

        Args:
            hypotheses: Decoded predictions.
            references: Decoded targets, paired with ``hypotheses`` by position.
            labelled_data: If truthy, the totals are written to the ``*_labelled`` states,
                otherwise to the ``*_unlabelled`` states.

        The states are overwritten (not accumulated) with this batch's totals.
        """
        words = 0
        scores = 0

        for h, r in zip(hypotheses, references):
            if self.use_cer:
                h_list = list(h)
                r_list = list(r)
            else:
                h_list = h.split()
                r_list = r.split()
            words += len(r_list)
            # Compute Levenstein's distance
            scores += editdistance.eval(h_list, r_list)

        if labelled_data:
            self.scores_labelled = torch.tensor(scores, device=self.scores_labelled.device, dtype=self.scores_labelled.dtype)
            self.words_labelled = torch.tensor(words, device=self.words_labelled.device, dtype=self.words_labelled.dtype)
        else:
            self.scores_unlabelled = torch.tensor(scores, device=self.scores_unlabelled.device, dtype=self.scores_unlabelled.dtype)
            self.words_unlabelled = torch.tensor(words, device=self.words_unlabelled.device, dtype=self.words_unlabelled.dtype)

    def seperate_labels_from_labelled_data(self, hypotheses: List[str], references: List[str]):
        """Split inline ``<...>`` tags from the text of each hypothesis/reference pair.

        Returns:
            A 5-tuple ``(unlabelled_hypotheses, unlabelled_references, labels_hypotheses,
            labels_references, correct_label_count)``. The first two hold the texts with every
            tag removed. ``labels_hypotheses`` holds the list of tags found in each hypothesis.
            ``labels_references`` holds the single reference tag, or ``[]`` when the reference
            does not contain exactly one tag. ``correct_label_count`` counts pairs where both
            sides carry exactly one tag and the tags are equal.
        """
        # Labels are embedded in the text as tags delimited by '<' and '>'; a text may contain
        # any number of such tags.
        unlabelled_hypotheses = []
        unlabelled_references = []
        labels_hypotheses = []
        labels_references = []
        correct_label_count = 0

        for h, r in zip(hypotheses, references):
            # identify the tags with <>
            # h_tags = [h[i:j+1] for i in range(len(h)) for j in range(i, len(h)) if h[i] == '<' and h[j] == '>']
            # r_tags = [r[i:j+1] for i in range(len(r)) for j in range(i, len(r)) if r[i] == '<' and r[j] == '>']
            r_tags = self._TAG_PATTERN.findall(r)
            h_tags = self._TAG_PATTERN.findall(h)
            # assert len(r_tags) == 2, f"Reference tags are not 2, they are {r_tags} for {r}" # Note we are only considering for single label.
            # Above assert doesnt apply when ps audio is of 15 seconds but words are in only first 4 seconds and noise occurs from 8 to 10 secs.
            # Remove every tag from both texts, including tags that only the hypothesis contains,
            # so the unlabelled WER is computed on tag-free text.
            unlabelled_h = self._TAG_PATTERN.sub('', h)
            unlabelled_r = self._TAG_PATTERN.sub('', r)

            unlabelled_hypotheses.append(unlabelled_h)
            unlabelled_references.append(unlabelled_r)
            labels_hypotheses.append(h_tags)
            # FOR IT1
            # if len(r_tags) == 2:
            #     labels_references.append(r_tags[0])
            # else:
            #     labels_references.append([])
            # if len(h_tags) == 2 and len(r_tags) == 2 and h_tags[0] == r_tags[0]:
            #     correct_label_count += 1

            # FOR IT2
            if len(r_tags) == 1:
                labels_references.append(r_tags[0])
            else:
                labels_references.append([])
            if len(h_tags) == 1 and len(r_tags) == 1 and h_tags[0] == r_tags[0]:
                correct_label_count += 1

        return unlabelled_hypotheses, unlabelled_references, labels_hypotheses, labels_references, correct_label_count

    def update(
        self,
        predictions: torch.Tensor,
        predictions_lengths: torch.Tensor,
        targets: torch.Tensor,
        targets_lengths: torch.Tensor,
        predictions_mask: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
    ):
        """
        Updates metric state.
        Args:
            predictions: an integer torch.Tensor of shape ``[Batch, Time, {Vocabulary}]`` (if ``batch_dim_index == 0``) or
                ``[Time, Batch]`` (if ``batch_dim_index == 1``)
            predictions_lengths: an integer torch.Tensor of shape ``[Batch]``
            targets: an integer torch.Tensor of shape ``[Batch, Time]`` (if ``batch_dim_index == 0``) or
                ``[Time, Batch]`` (if ``batch_dim_index == 1``)
            targets_lengths: an integer torch.Tensor of shape ``[Batch]``
            predictions_mask: optional mask, only passed on to multi-task decoding
            input_ids: optional decoder input ids, only passed on to multi-task decoding
        """
        references = []
        with torch.no_grad():
            tgt_lenths_cpu_tensor = targets_lengths.long().cpu()
            targets_cpu_tensor = targets.long().cpu()
            # check batch_dim_index is first dim
            if self.batch_dim_index != 0:
                targets_cpu_tensor = move_dimension_to_the_front(targets_cpu_tensor, self.batch_dim_index)
            # iterate over batch
            for ind in range(targets_cpu_tensor.shape[0]):
                tgt_len = tgt_lenths_cpu_tensor[ind].item()
                target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist()
                reference = self.decoding.decode_tokens_to_str(target)
                references.append(reference)
            hypotheses, _ = self.decode(predictions, predictions_lengths, predictions_mask, input_ids, targets)

        if self.has_spl_tokens:
            hypotheses = [self.decoding.strip_special_tokens(hyp) for hyp in hypotheses]
            references = [self.decoding.strip_special_tokens(ref) for ref in references]

        # Guard against an empty batch before indexing the first element.
        if self.log_prediction and references and hypotheses:
            logging.info(f"\n")
            logging.info(f"reference:{references[0]}")
            logging.info(f"predicted:{hypotheses[0]}")
            logging.info(f"\n")

        unlabelled_hypotheses, unlabelled_references, labels_hypotheses, labels_references, correct_label_count = self.seperate_labels_from_labelled_data(hypotheses, references)
        self.get_words_and_scores(unlabelled_hypotheses, unlabelled_references, labelled_data=False)
        self.get_words_and_scores(hypotheses, references, labelled_data=True)
        self.correct_label_count = torch.tensor(correct_label_count, device=self.correct_label_count.device, dtype=self.correct_label_count.dtype)
        self.num_samples = torch.tensor(len(references), device=self.num_samples.device, dtype=self.num_samples.dtype)

    def compute(self):
        """Return ``(labelled_wer, unlabelled_wer, label_accuracy, scores_unlabelled, words_unlabelled)``.

        ``labelled_wer`` is None unless ``self.labelled_manifest`` is set. The ratios are plain
        tensor divisions, so a zero denominator is not guarded against here.
        """
        if self.labelled_manifest:
            scores_labelled = self.scores_labelled.detach().float()
            words_labelled = self.words_labelled.detach().float()
            labelled_wer = scores_labelled / words_labelled
        else:
            scores_labelled = None
            words_labelled = None
            labelled_wer = None
        scores_unlabelled = self.scores_unlabelled.detach().float()
        words_unlabelled = self.words_unlabelled.detach().float()
        unlabelled_wer = scores_unlabelled / words_unlabelled
        correct_label_count = self.correct_label_count.detach().float()
        num_samples = self.num_samples.detach().float()

        return labelled_wer, unlabelled_wer, correct_label_count/num_samples, scores_unlabelled, words_unlabelled


# diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py
# index 23c759afc80d..4a8772de2c05 100644
# --- a/nemo/collections/asr/models/__init__.py
# +++ b/nemo/collections/asr/models/__init__.py
# @@ -23,6 +23,8 @@
#  from nemo.collections.asr.models.clustering_diarizer import ClusteringDiarizer
#  from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
#  from nemo.collections.asr.models.ctc_models import EncDecCTCModel
# +from nemo.collections.asr.models.av_ctc_models import AV_EncDecCTCModel
# +from nemo.collections.asr.models.av_ctc_bpe_models import AV_EncDecCTCModelBPE
#  from nemo.collections.asr.models.enhancement_models import (
#      EncMaskDecAudioToAudioModel,
#      PredictiveAudioToAudioModel,
# diff --git a/nemo/collections/asr/models/av_ctc_bpe_models.py b/nemo/collections/asr/models/av_ctc_bpe_models.py
# new file mode 100644
# index 000000000000..cb2b59949e75
# --- /dev/null
# +++ b/nemo/collections/asr/models/av_ctc_bpe_models.py
# @@ -0,0 +1,349 @@
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
from typing import Dict, List, Optional, Union

import torch
from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text import _AudioTextDataset
from nemo.collections.asr.data.av_to_text import _AVTextDataset
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.metrics.av_wer import AV_WER
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.models.av_ctc_models import AV_EncDecCTCModel
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.collections.asr.parts.mixins import ASRBPEMixin
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCBPEDecoding, CTCBPEDecodingConfig
from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
from nemo.core.classes.common import PretrainedModelInfo
from nemo.utils import logging, model_utils

__all__ = ['AV_EncDecCTCModelBPE']


class AV_EncDecCTCModelBPE(AV_EncDecCTCModel, ASRBPEMixin):
    """Encoder decoder CTC-based models with Byte Pair Encoding.

    Extends :class:`AV_EncDecCTCModel` with a tokenizer built from ``cfg.tokenizer``; the
    decoder vocabulary, the decoding object and the ``AV_WER`` metric are all derived from
    that tokenizer.
    """

    def __init__(self, cfg: DictConfig, trainer=None):
        """
        Build the tokenizer, inject its vocabulary into ``cfg.decoder`` and construct the model.

        Args:
            cfg: Model config. Must contain a ``tokenizer`` section; ``labelled_manifest`` and
                ``decoder.num_classes`` are read as well.
            trainer: Forwarded unchanged to the parent constructor.

        Raises:
            ValueError: If ``cfg`` has no ``tokenizer`` entry.
        """
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)

        if 'tokenizer' not in cfg:
            raise ValueError("`cfg` must have `tokenizer` config to create a tokenizer !")

        # Setup the tokenizer
        self._setup_tokenizer(cfg.tokenizer)

        # Initialize a dummy vocabulary
        vocabulary = self.tokenizer.tokenizer.get_vocab()
        self.labelled_manifest = cfg.labelled_manifest

        # Set the new vocabulary
        with open_dict(cfg):
            # sidestepping the potential overlapping tokens issue in aggregate tokenizers
            if self.tokenizer_type == "agg":
                cfg.decoder.vocabulary = ListConfig(vocabulary)
            else:
                cfg.decoder.vocabulary = ListConfig(list(vocabulary.keys()))

        # Override number of classes if placeholder provided
        num_classes = cfg.decoder["num_classes"]

        if num_classes < 1:
            logging.info(
                "\nReplacing placeholder number of classes ({}) with actual number of classes - {}".format(
                    num_classes, len(vocabulary)
                )
            )
            cfg.decoder["num_classes"] = len(vocabulary)

        super().__init__(cfg=cfg, trainer=trainer)

        # Setup decoding objects
        decoding_cfg = self.cfg.get('decoding', None)

        # In case decoding config not found, use default config
        if decoding_cfg is None:
            decoding_cfg = OmegaConf.structured(CTCBPEDecodingConfig)
            with open_dict(self.cfg):
                self.cfg.decoding = decoding_cfg

        # Replaces the decoding object created by the parent constructor with a BPE-aware one.
        self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer)

        # Setup metric with decoding strategy
        self.wer = AV_WER(
            decoding=self.decoding,
            use_cer=self._cfg.get('use_cer', False),
            dist_sync_on_step=True,
            log_prediction=self._cfg.get("log_prediction", False),
            labelled_manifest=self.labelled_manifest,
        )

    def _setup_dataloader_from_config(self, config: Optional[Dict]):
        """
        Create a ``torch.utils.data.DataLoader`` for the BPE audio-visual dataset described by ``config``.

        Args:
            config: Dataset/dataloader config. ``shuffle`` and ``batch_size`` are read directly;
                ``use_semi_sorted_batching``, ``drop_last``, ``num_workers`` and ``pin_memory`` are optional.

        Returns:
            The data loader, or ``None`` when the dataset factory returns ``None``.

        Raises:
            RuntimeError: If semi-sorted batching is requested for an unsupported dataset type.
        """
        dataset = audio_to_text_dataset.get_av_to_text_bpe_dataset_from_config(
            config=config,
            local_rank=self.local_rank,
            global_rank=self.global_rank,
            world_size=self.world_size,
            tokenizer=self.tokenizer,
            preprocessor_cfg=self.cfg.get("preprocessor", None),
        )

        if dataset is None:
            return None

        shuffle = config['shuffle']
        if isinstance(dataset, torch.utils.data.IterableDataset):
            shuffle = False

        if hasattr(dataset, 'collate_fn'):
            collate_fn = dataset.collate_fn
        elif hasattr(dataset.datasets[0], 'collate_fn'):
            # support datasets that are lists of entries
            collate_fn = dataset.datasets[0].collate_fn
        else:
            # support datasets that are lists of lists
            collate_fn = dataset.datasets[0].datasets[0].collate_fn

        batch_sampler = None
        if config.get('use_semi_sorted_batching', False):
            # The char-model loader in av_ctc_models.py gates this on `_AVTextDataset`; accept
            # both base classes here so audio-visual datasets are not rejected.
            # NOTE(review): presumably the AV BPE factory returns `_AVTextDataset` instances - confirm.
            if not isinstance(dataset, (_AudioTextDataset, _AVTextDataset)):
                raise RuntimeError(
                    "Semi Sorted Batch sampler can be used with AudioToCharDataset or AudioToBPEDataset "
                    f"but found dataset of type {type(dataset)}"
                )
            # set batch_size and batch_sampler to None to disable automatic batching
            batch_sampler = get_semi_sorted_batch_sampler(self, dataset, config)
            config['batch_size'] = None
            config['drop_last'] = False
            shuffle = False

        return torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=config['batch_size'],
            sampler=batch_sampler,
            batch_sampler=None,
            collate_fn=collate_fn,
            drop_last=config.get('drop_last', False),
            shuffle=shuffle,
            num_workers=config.get('num_workers', 0),
            pin_memory=config.get('pin_memory', False),
        )

    def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a temporary data loader which wraps the provided audio file.

        Args:
            config: A python dictionary which contains the following keys:
                manifest_filepath: (str, optional) manifest to read. When present it is used as-is
                    together with ``batch_size``; otherwise ``temp_dir``/manifest.json is used.
                paths2audio_files: (a list) of paths to audio files. The files should be relatively
                    short fragments. Recommended length per file is between 5 and 25 seconds.
                    Only read when ``manifest_filepath`` is absent.
                batch_size: (int) batch size to use during inference.
                    Bigger will result in better throughput performance but would use more memory.
                temp_dir: (str) A temporary directory where the audio manifest is temporarily
                    stored.
                num_workers: (int) number of workers. Depends of the batch_size and machine.
                    0 - only the main process will load batches, 1 - one worker (not main process)
                channel_selector: (optional) forwarded to the dataset config.
                augmentor: (optional) forwarded to the dataset config when truthy.

        Returns:
            A pytorch DataLoader for the given audio file(s).
        """
        if 'manifest_filepath' in config:
            manifest_filepath = config['manifest_filepath']
            batch_size = config['batch_size']
        else:
            manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json')
            batch_size = min(config['batch_size'], len(config['paths2audio_files']))

        # os.cpu_count() may return None when the count cannot be determined; treat that as a
        # single CPU so the default worker count becomes 0 instead of raising TypeError.
        default_num_workers = min(batch_size, (os.cpu_count() or 1) - 1)

        dl_config = {
            'manifest_filepath': manifest_filepath,
            'sample_rate': self.preprocessor._sample_rate,
            'batch_size': batch_size,
            'shuffle': False,
            'num_workers': config.get('num_workers', default_num_workers),
            'pin_memory': True,
            'channel_selector': config.get('channel_selector', None),
            'use_start_end_token': self.cfg.validation_ds.get('use_start_end_token', False),
        }

        if config.get("augmentor"):
            dl_config['augmentor'] = config.get("augmentor")

        temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
        return temporary_datalayer

    def change_vocabulary(
        self,
        new_tokenizer_dir: Union[str, DictConfig],
        new_tokenizer_type: str,
        decoding_cfg: Optional[DictConfig] = None,
    ):
        """
        Changes vocabulary of the tokenizer used during CTC decoding process.
        Use this method when fine-tuning on from pre-trained model.
        This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would
        use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need
        model to learn capitalization, punctuation and/or special characters.

        Args:
            new_tokenizer_dir: Directory path to tokenizer or a config for a new tokenizer (if the tokenizer type is `agg`)
            new_tokenizer_type: Either `agg`, `bpe` or `wpe`. `bpe` is used for SentencePiece tokenizers,
                whereas `wpe` is used for `BertTokenizer`.
            decoding_cfg: Optional decoding config. When ``None`` the current ``self.cfg.decoding`` is reused.

        Returns: None

        Raises:
            ValueError: If a ``DictConfig`` is passed for a non-`agg` tokenizer, or the tokenizer type
                is not `bpe`/`wpe` when a directory is given.
            NotADirectoryError: If ``new_tokenizer_dir`` is not an existing directory.
        """
        if isinstance(new_tokenizer_dir, DictConfig):
            if new_tokenizer_type == 'agg':
                new_tokenizer_cfg = new_tokenizer_dir
            else:
                raise ValueError(
                    f'New tokenizer dir should be a string unless the tokenizer is `agg`, but this tokenizer type is: {new_tokenizer_type}'
                )
        else:
            new_tokenizer_cfg = None

        if new_tokenizer_cfg is not None:
            tokenizer_cfg = new_tokenizer_cfg
        else:
            if not os.path.isdir(new_tokenizer_dir):
                # Single message: the original concatenated two identical f-strings, so the
                # text was emitted twice back-to-back.
                raise NotADirectoryError(
                    f'New tokenizer dir must be non-empty path to a directory. But I got: {new_tokenizer_dir}'
                )

            if new_tokenizer_type.lower() not in ('bpe', 'wpe'):
                raise ValueError('New tokenizer type must be either `bpe` or `wpe`')

            tokenizer_cfg = OmegaConf.create({'dir': new_tokenizer_dir, 'type': new_tokenizer_type})

        # Setup the tokenizer
        self._setup_tokenizer(tokenizer_cfg)

        # Initialize a dummy vocabulary
        vocabulary = self.tokenizer.tokenizer.get_vocab()

        # Set the new vocabulary
        decoder_config = copy.deepcopy(self.decoder.to_config_dict())
        # sidestepping the potential overlapping tokens issue in aggregate tokenizers
        if self.tokenizer_type == "agg":
            decoder_config.vocabulary = ListConfig(vocabulary)
        else:
            decoder_config.vocabulary = ListConfig(list(vocabulary.keys()))

        decoder_num_classes = decoder_config['num_classes']

        # Override number of classes if placeholder provided
        logging.info(
            "\nReplacing old number of classes ({}) with new number of classes - {}".format(
                decoder_num_classes, len(vocabulary)
            )
        )

        decoder_config['num_classes'] = len(vocabulary)

        # Rebuild the decoder and the losses for the new vocabulary size.
        del self.decoder
        self.decoder = EncDecCTCModelBPE.from_config_dict(decoder_config)
        del self.loss
        self.loss = CTCLoss(
            num_classes=self.decoder.num_classes_with_blank - 1,
            zero_infinity=True,
            reduction=self._cfg.get("ctc_reduction", "mean_batch"),
        )
        self.ce_loss = torch.nn.CrossEntropyLoss()

        if decoding_cfg is None:
            # Assume same decoding config as before
            decoding_cfg = self.cfg.decoding

        # Assert the decoding config with all hyper parameters
        decoding_cls = OmegaConf.structured(CTCBPEDecodingConfig)
        decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
        decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

        self.decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer)

        self.wer = AV_WER(
            decoding=self.decoding,
            use_cer=self._cfg.get('use_cer', False),
            log_prediction=self._cfg.get("log_prediction", False),
            dist_sync_on_step=True,
            labelled_manifest=self.labelled_manifest,
        )

        # Update config
        with open_dict(self.cfg.decoder):
            self._cfg.decoder = decoder_config

        with open_dict(self.cfg.decoding):
            self._cfg.decoding = decoding_cfg

        logging.info(f"Changed tokenizer to {self.decoder.vocabulary} vocabulary.")

    def change_decoding_strategy(self, decoding_cfg: DictConfig):
        """
        Changes decoding strategy used during CTC decoding process.

        Args:
            decoding_cfg: A config for the decoder, which is optional. If the decoding type
                needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here.
                When ``None``, the model's current ``cfg.decoding`` is reused.
        """
        if decoding_cfg is None:
            # Assume same decoding config as before
            logging.info("No `decoding_cfg` passed when changing decoding strategy, using internal config")
            decoding_cfg = self.cfg.decoding

        # Assert the decoding config with all hyper parameters
        decoding_cls = OmegaConf.structured(CTCBPEDecodingConfig)
        decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls))
        decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg)

        self.decoding = CTCBPEDecoding(decoding_cfg=decoding_cfg, tokenizer=self.tokenizer)

        # Rebuild the metric around the new decoding object, keeping the current metric settings.
        self.wer = AV_WER(
            decoding=self.decoding,
            use_cer=self.wer.use_cer,
            log_prediction=self.wer.log_prediction,
            dist_sync_on_step=True,
            labelled_manifest=self.labelled_manifest,
        )

        self.decoder.temperature = decoding_cfg.get('temperature', 1.0)

        # Update config
        with open_dict(self.cfg.decoding):
            self.cfg.decoding = decoding_cfg

        logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models. Currently always an empty list.
        """
        results = []

        return results
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import json +import os +import tempfile +from math import ceil +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning import Trainer +from tqdm.auto import tqdm + +from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data.av_to_text import _AVTextDataset +# from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset, DALIOutputs +# from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset +from nemo.collections.asr.models.ctc_models import EncDecCTCModel +from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE +from nemo.collections.asr.losses.ctc import CTCLoss +from nemo.collections.asr.metrics.av_wer import AV_WER +from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel +from nemo.collections.asr.parts.mixins import ASRModuleMixin, ASRTranscriptionMixin, InterCTCMixin, TranscribeConfig +from nemo.collections.asr.parts.mixins.transcription import GenericTranscriptionType, TranscriptionReturnType +from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig +from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler +from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType +# from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.common.parts.preprocessing.parsers import 
make_parser +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from nemo.core.classes.mixins import AccessMixin +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, LogprobsType, NeuralType, SpectrogramType, ImageFeatureValue +from nemo.utils import logging + +#ADAPTERS +from nemo.core import adapter_mixins +from nemo.collections.common.parts.adapter_modules import LinearAdapterConfig +from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import MultiHeadAttentionAdapterConfig +from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import RelPositionMultiHeadAttentionAdapterConfig +__all__ = ['AV_EncDecCTCModel'] + + +class AV_EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin, InterCTCMixin, ASRTranscriptionMixin): + """Base class for encoder decoder CTC-based models.""" + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable + # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 + self.world_size = 1 + if trainer is not None: + self.world_size = trainer.world_size + + super().__init__(cfg=cfg, trainer=trainer) + if "BPE:" in cfg.a_model_name: + a_model_cfg = EncDecCTCModelBPE.from_pretrained(cfg.a_model_name[4:], return_config=True) + a_model_cfg = self.update_model_config_to_support_adapter(a_model_cfg) # for adapters + self.a_model = EncDecCTCModelBPE.from_pretrained(cfg.a_model_name[4:], override_config_path=a_model_cfg) + else: + a_model_cfg = EncDecCTCModel.from_pretrained(cfg.a_model_name, return_config=True) + a_model_cfg = self.update_model_config_to_support_adapter(a_model_cfg) + self.a_model = EncDecCTCModel.from_pretrained(cfg.a_model_name, override_config_path=a_model_cfg) + + self.labelled_manifest = cfg.labelled_manifest + + + if cfg.adapters.linear_adapter.keep: + linear_adapter_cfg = 
LinearAdapterConfig( + in_features=self.a_model.encoder.d_model, + dim = cfg.adapters.linear_adapter.dim, + activation=cfg.adapters.linear_adapter.activation, + norm_position=cfg.adapters.linear_adapter.norm_position, + dropout=cfg.adapters.linear_adapter.dropout, + ) + linear_adapter_name = cfg.adapters.linear_adapter.name + self.a_model.add_adapter(name=linear_adapter_name, cfg=linear_adapter_cfg) + with open_dict(self._cfg): + if "feat_in" not in self._cfg.decoder or ( + not self._cfg.decoder.feat_in and hasattr(self.encoder, '_feat_out') + ): + self._cfg.decoder.feat_in = self.encoder._feat_out + if "feat_in" not in self._cfg.decoder or not self._cfg.decoder.feat_in: + raise ValueError("param feat_in of the decoder's config is not set!") + + if self.cfg.decoder.num_classes < 1 and self.cfg.decoder.vocabulary is not None: + logging.info( + "\nReplacing placeholder number of classes ({}) with actual number of classes - {}".format( + self.cfg.decoder.num_classes, len(self.cfg.decoder.vocabulary) + ) + ) + cfg.decoder["num_classes"] = len(self.cfg.decoder.vocabulary) + assert not (self.cfg.use_pretrained_dec and self.cfg.use_video_modality), "Pretrained decoder is not supported for video modality" + + # initialize a transformer encoder and decoder + if cfg.use_video_modality: + self.a_linear = torch.nn.Linear(in_features = self.a_model.encoder._feat_out, out_features = self.cfg.av_encoder.d_model) + self.v_linear = torch.nn.Linear(in_features = self.cfg.v_model.feat_dim, out_features = self.cfg.av_encoder.d_model) + self.av_enocder_layer = torch.nn.TransformerEncoderLayer(d_model = self.cfg.av_encoder.d_model, nhead = self.cfg.av_encoder.nhead, dropout = self.cfg.av_encoder.dropout, batch_first=True) + self.av_encoder = torch.nn.TransformerEncoder(self.av_enocder_layer, num_layers = self.cfg.av_encoder.num_layers) + if cfg.label_pred_head.keep: + self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.cfg.av_encoder.d_model)) + self.label_predictor = 
torch.nn.Linear(self.cfg.av_encoder.d_model, self.cfg.label_pred_head.num_classes) + self.ce_loss = torch.nn.CrossEntropyLoss() + # Modality embeddings + self.a_modal_embs = torch.nn.Embedding(1, self.cfg.av_encoder.d_model) + self.v_modal_embs = torch.nn.Embedding(1, self.cfg.av_encoder.d_model) + + # Trainable positional encodings + self.a_pos_enc = torch.nn.Embedding(10000, self.cfg.av_encoder.d_model) + self.v_pos_enc = torch.nn.Embedding(10000, self.cfg.av_encoder.d_model) + + + self.decoder = EncDecCTCModel.from_config_dict(self._cfg.decoder) + self.loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + + # Setup decoding objects + decoding_cfg = self.cfg.get('decoding', None) + + # In case decoding config not found, use default config + if decoding_cfg is None: + decoding_cfg = OmegaConf.structured(CTCDecodingConfig) + with open_dict(self.cfg): + self.cfg.decoding = decoding_cfg + + self.decoding = CTCDecoding(self.cfg.decoding, vocabulary=OmegaConf.to_container(self.decoder.vocabulary)) + + # Setup metric with decoding strategy + self.wer = AV_WER( + decoding=self.decoding, + use_cer=self._cfg.get('use_cer', False), + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + labelled_manifest=self.labelled_manifest + ) + + # Setup optional Optimization flags + self.setup_optimization_flags() + + # setting up interCTC loss (from InterCTCMixin) + self.setup_interctc(decoder_name='decoder', loss_name='loss', wer_name='wer') + + def update_model_config_to_support_adapter(self, model_cfg): + with open_dict(model_cfg): + adapter_metadata = adapter_mixins.get_registered_adapter(model_cfg.encoder._target_) + if adapter_metadata is not None: + model_cfg.encoder._target_ = adapter_metadata.adapter_class_path + + print("Updated encoder _target_ model :", model_cfg.encoder._target_) + return model_cfg + + def transcribe( + self, + audio: 
Union[str, List[str], torch.Tensor, np.ndarray], + batch_size: int = 4, + return_hypotheses: bool = False, + num_workers: int = 0, + channel_selector: Optional[ChannelSelectorType] = None, + augmentor: DictConfig = None, + verbose: bool = True, + override_config: Optional[TranscribeConfig] = None, + ) -> TranscriptionReturnType: + """ + If modify this function, please remember update transcribe_partial_audio() in + nemo/collections/asr/parts/utils/trancribe_utils.py + + Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. + + Args: + audio: (a single or list) of paths to audio files or a np.ndarray audio array. \ + Recommended length per file is between 5 and 25 seconds. \ + But it is possible to pass a few hours long file if enough GPU memory is available. + batch_size: (int) batch size to use during inference. + Bigger will result in better throughput performance but would use more memory. + return_hypotheses: (bool) Either return hypotheses or text + With hypotheses can do some postprocessing like getting timestamp or rescoring + num_workers: (int) number of workers for DataLoader + channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. + augmentor: (DictConfig): Augment audio samples during transcription if augmentor is applied. + verbose: (bool) whether to display tqdm progress bar + override_config: (Optional[TranscribeConfig]) override transcription config pre-defined by the user. + **Note**: All other arguments in the function will be ignored if override_config is passed. + You should call this argument as `model.transcribe(audio, override_config=TranscribeConfig(...))`. 
+ + Returns: + A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files + """ + return super().transcribe( + audio=audio, + batch_size=batch_size, + return_hypotheses=return_hypotheses, + num_workers=num_workers, + channel_selector=channel_selector, + augmentor=augmentor, + verbose=verbose, + override_config=override_config, + ) + + def change_vocabulary(self, new_vocabulary: List[str], decoding_cfg: Optional[DictConfig] = None): + """ + Changes vocabulary used during CTC decoding process. Use this method when fine-tuning on from pre-trained model. + This method changes only decoder and leaves encoder and pre-processing modules unchanged. For example, you would + use it if you want to use pretrained encoder when fine-tuning on a data in another language, or when you'd need + model to learn capitalization, punctuation and/or special characters. + + If new_vocabulary == self.decoder.vocabulary then nothing will be changed. + + Args: + + new_vocabulary: list with new vocabulary. Must contain at least 2 elements. Typically, \ + this is target alphabet. + + Returns: None + + """ + if self.decoder.vocabulary == new_vocabulary: + logging.warning(f"Old {self.decoder.vocabulary} and new {new_vocabulary} match. Not changing anything.") + else: + if new_vocabulary is None or len(new_vocabulary) == 0: + raise ValueError(f'New vocabulary must be non-empty list of chars. 
But I got: {new_vocabulary}') + decoder_config = self.decoder.to_config_dict() + new_decoder_config = copy.deepcopy(decoder_config) + new_decoder_config['vocabulary'] = new_vocabulary + new_decoder_config['num_classes'] = len(new_vocabulary) + + del self.decoder + self.decoder = EncDecCTCModel.from_config_dict(new_decoder_config) + del self.loss + self.loss = CTCLoss( + num_classes=self.decoder.num_classes_with_blank - 1, + zero_infinity=True, + reduction=self._cfg.get("ctc_reduction", "mean_batch"), + ) + self.ce_loss = torch.nn.CrossEntropyLoss() + + if decoding_cfg is None: + # Assume same decoding config as before + decoding_cfg = self.cfg.decoding + + # Assert the decoding config with all hyper parameters + decoding_cls = OmegaConf.structured(CTCDecodingConfig) + decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) + decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) + + self.decoding = CTCDecoding( + decoding_cfg=decoding_cfg, vocabulary=OmegaConf.to_container(self.decoder.vocabulary) + ) + + self.wer = AV_WER( + decoding=self.decoding, + use_cer=self._cfg.get('use_cer', False), + dist_sync_on_step=True, + log_prediction=self._cfg.get("log_prediction", False), + labelled_manifest=self.labelled_manifest + ) + + # Update config + with open_dict(self.cfg.decoder): + self._cfg.decoder = new_decoder_config + + with open_dict(self.cfg.decoding): + self.cfg.decoding = decoding_cfg + + ds_keys = ['train_ds', 'validation_ds', 'test_ds'] + for key in ds_keys: + if key in self.cfg: + with open_dict(self.cfg[key]): + self.cfg[key]['labels'] = OmegaConf.create(new_vocabulary) + + logging.info(f"Changed decoder to output to {self.decoder.vocabulary} vocabulary.") + + def change_decoding_strategy(self, decoding_cfg: DictConfig): + """ + Changes decoding strategy used during CTC decoding process. + + Args: + decoding_cfg: A config for the decoder, which is optional. 
If the decoding type + needs to be changed (from say Greedy to Beam decoding etc), the config can be passed here. + """ + if decoding_cfg is None: + # Assume same decoding config as before + logging.info("No `decoding_cfg` passed when changing decoding strategy, using internal config") + decoding_cfg = self.cfg.decoding + + # Assert the decoding config with all hyper parameters + decoding_cls = OmegaConf.structured(CTCDecodingConfig) + decoding_cls = OmegaConf.create(OmegaConf.to_container(decoding_cls)) + decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) + + self.decoding = CTCDecoding( + decoding_cfg=decoding_cfg, vocabulary=OmegaConf.to_container(self.decoder.vocabulary) + ) + + self.wer = AV_WER( + decoding=self.decoding, + use_cer=self.wer.use_cer, + log_prediction=self.wer.log_prediction, + dist_sync_on_step=True, + labelled_manifest=self.labelled_manifest + ) + + self.decoder.temperature = decoding_cfg.get('temperature', 1.0) + + # Update config + with open_dict(self.cfg.decoding): + self.cfg.decoding = decoding_cfg + + logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}") + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + # Automatically inject args from model config to dataloader config + audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate') + audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='labels') + + dataset = audio_to_text_dataset.get_av_to_text_char_dataset_from_config( + config=config, + local_rank=self.local_rank, + global_rank=self.global_rank, + world_size=self.world_size, + preprocessor_cfg=self._cfg.get("preprocessor", None), + ) + + if dataset is None: + return None + + shuffle = config['shuffle'] + if isinstance(dataset, torch.utils.data.IterableDataset): + shuffle = False + + if hasattr(dataset, 'collate_fn'): + collate_fn = dataset.collate_fn + elif hasattr(dataset.datasets[0], 'collate_fn'): + # 
support datasets that are lists of entries + collate_fn = dataset.datasets[0].collate_fn + else: + # support datasets that are lists of lists + collate_fn = dataset.datasets[0].datasets[0].collate_fn + + batch_sampler = None + if config.get('use_semi_sorted_batching', False): # This is in usable format even for our + if not isinstance(dataset, _AVTextDataset): + raise RuntimeError( + "Semi Sorted Batch sampler can be used with AudioToCharDataset or AudioToBPEDataset " + f"but found dataset of type {type(dataset)}" + ) + # set batch_size and batch_sampler to None to disable automatic batching + batch_sampler = get_semi_sorted_batch_sampler(self, dataset, config) + config['batch_size'] = None + config['drop_last'] = False + shuffle = False + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + sampler=batch_sampler, + batch_sampler=None, + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the training data loader via a Dict-like object. + + Args: + train_data_config: A config that contains the information regarding construction + of an ASR Training dataset. 
+ + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in train_data_config: + train_data_config['shuffle'] = True + + # preserve config + self._update_dataset_config(dataset_name='train', config=train_data_config) + + self._train_dl = self._setup_dataloader_from_config(config=train_data_config) + + # Need to set this because if using an IterableDataset, the length of the dataloader is the total number + # of samples rather than the number of batches, and this messes up the tqdm progress bar. + # So we set the number of steps manually (to the correct number) to fix this. + if ( + self._train_dl is not None + and hasattr(self._train_dl, 'dataset') + and isinstance(self._train_dl.dataset, torch.utils.data.IterableDataset) + ): + # We also need to check if limit_train_batches is already set. + # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, + # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). + if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float): + self._trainer.limit_train_batches = int( + self._trainer.limit_train_batches + * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) + ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "training batches will be used. Please set the trainer and rebuild the dataset." 
+ ) + + def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the validation data loader via a Dict-like object. + + Args: + val_data_config: A config that contains the information regarding construction + of an ASR Training dataset. + + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in val_data_config: + val_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + """ + Sets up the test data loader via a Dict-like object. + + Args: + test_data_config: A config that contains the information regarding construction + of an ASR Training dataset. 
+ + Supported Datasets: + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.AudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToCharDataset` + - :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset` + - :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset` + """ + if 'shuffle' not in test_data_config: + test_data_config['shuffle'] = False + + # preserve config + self._update_dataset_config(dataset_name='test', config=test_data_config) + + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.a_model.preprocessor, '_sample_rate'): + input_signal_eltype = AudioSignal(freq=self.a_model.preprocessor._sample_rate) + else: + input_signal_eltype = AudioSignal() + return { + "audio_input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), + "audio_input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "video_input_signal": NeuralType(('B', 'T', 'D'), ImageFeatureValue(), optional=True), + "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), + "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "sample_id": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + return { + "outputs": NeuralType(('B', 'T', 'D'), LogprobsType()), + "label_log_probs": NeuralType(('B', 'C'), LogprobsType(), optional=True), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "greedy_predictions": NeuralType(('B', 'T'), LabelsType()), + } + + @typecheck() + def forward( + self, audio_input_signal=None, audio_input_signal_length=None, video_input_signal= None, processed_signal=None, processed_signal_length=None + ): + """ + Forward pass of the model. 
+ + Args: + audio_input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + audio_input_signal_length: Vector of length B, that contains the individual lengths of the audio + sequences. + processed_signal: Tensor that represents a batch of processed audio signals, + of shape (B, D, T) that has undergone processing via some DALI preprocessor. + processed_signal_length: Vector of length B, that contains the individual lengths of the + processed audio sequences. + + Returns: + A tuple of 4 elements - + 1) The log probabilities tensor of shape [B, T, D]. + 2) The label prediction output of the classifier-token head, or None when that head is not used. + 3) The lengths of the acoustic sequence after propagation through the encoder, of shape [B]. + 4) The greedy token predictions of the model of shape [B, T] (via argmax) + """ + has_input_signal = audio_input_signal is not None and audio_input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) == False: + raise ValueError( + f"{self} Arguments ``audio_input`` and ``audio_input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." 
+ ) + + if not has_processed_signal: + processed_signal, processed_signal_length = self.a_model.preprocessor( + input_signal=audio_input_signal, length=audio_input_signal_length, + ) + + if self.a_model.spec_augmentation is not None and self.training: + processed_signal = self.a_model.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + + encoder_output = self.a_model.encoder(audio_signal=processed_signal, length=processed_signal_length) + encoded = encoder_output[0] + encoded_len = encoder_output[1] + if self.cfg.use_video_modality and not self.cfg.use_pretrained_dec: + # B,C,T -> B,T,C + encoded = encoded.permute(0, 2, 1) + a_encoded = self.a_linear(encoded) + v_encoded = self.v_linear(video_input_signal) + + # Add modality embeddings + B, T, C = a_encoded.size() + B, F, D = v_encoded.size() + assert C == D, "The audio and video features must have the same dimensionality" + + # Expand modality embeddings to match the dimensions of a_encoded and v_encoded + a_modal_emb_expanded = self.a_modal_embs.weight.expand(B, T, -1) # Shape: (B, T, feat_in) + v_modal_emb_expanded = self.v_modal_embs.weight.expand(B, F, -1) # Shape: (B, F, feat_in) + + a_encoded = a_encoded + a_modal_emb_expanded + v_encoded = v_encoded + v_modal_emb_expanded + + # Add positional encodings + a_pos_enc = self.a_pos_enc(torch.arange(T, device=a_encoded.device)).unsqueeze(0).expand(B, -1, -1) + v_pos_enc = self.v_pos_enc(torch.arange(F, device=v_encoded.device)).unsqueeze(0).expand(B, -1, -1) + + a_encoded = a_encoded + a_pos_enc + v_encoded = v_encoded + v_pos_enc + + # Concat and pass them through the transformer encoder + av_encoded = torch.cat((a_encoded, v_encoded), dim=1) + if self.cfg.label_pred_head.keep: + cls_token = self.cls_token.expand(encoded.size(0), -1, -1) # Expanding to batch size + av_encoded = torch.cat((cls_token, av_encoded), dim=1) # Concatenating classifier token + av_encoded = self.av_encoder(av_encoded) + + if self.cfg.label_pred_head.keep: + 
# remove the v_encoded tokens + av_encoded = av_encoded[:, :T, :] + + # B,T,C -> B,C,T + av_encoded = av_encoded.permute(0, 2, 1) + + if self.cfg.label_pred_head.keep: + # Predicting labels using the classifier token's output + cls_output = av_encoded[:, :, 0] # First token after encoding is the classifier token + label_log_probs = self.label_predictor(cls_output) + else: + label_log_probs = None + + log_probs = self.decoder(encoder_output=av_encoded) + greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) + elif (not self.cfg.use_video_modality) and (not self.cfg.use_pretrained_dec): + log_probs = self.decoder(encoder_output=encoded) + greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) + label_log_probs = None + elif (not self.cfg.use_video_modality) and self.cfg.use_pretrained_dec: + log_probs = self.a_model.decoder(encoder_output=encoded) + greedy_predictions = log_probs.argmax(dim=-1, keepdim=False) + label_log_probs = None + elif self.cfg.use_video_modality and self.cfg.use_pretrained_dec: + raise ValueError("Pretrained decoder is not supported for video modality") + + + return ( + log_probs, + label_log_probs, + encoded_len, + greedy_predictions, + ) + + # PTL-specific methods + def training_step(self, batch, batch_nb): + # Reset access registry + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + if self.is_interctc_enabled(): + AccessMixin.set_access_enabled(access_enabled=True, guid=self.model_guid) + + signal, signal_len, video_input_signal, transcript, transcript_len, label = batch + # if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + # log_probs, encoded_len, predictions = self.forward( + # processed_signal=signal, processed_signal_length=signal_len + # ) + # else: + log_probs, label_log_probs, encoded_len, predictions = self.forward(audio_input_signal=signal, audio_input_signal_length=signal_len, video_input_signal=video_input_signal) + + if hasattr(self, '_trainer') and 
self._trainer is not None: + log_every_n_steps = self._trainer.log_every_n_steps + else: + log_every_n_steps = 1 + + loss_value = self.loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + if self.cfg.label_pred_head.keep: + label_loss_value = self.ce_loss(label_log_probs, label) + # Add auxiliary losses, if registered + loss_value = self.add_auxiliary_losses(loss_value+label_loss_value) + else: + label_loss_value = 0.0 + loss_value = self.add_auxiliary_losses(loss_value) + + # only computing WER when requested in the logs (same as done for final-layer WER below) + loss_value, tensorboard_logs = self.add_interctc_losses( + loss_value, transcript, transcript_len, compute_wer=((batch_nb + 1) % log_every_n_steps == 0) + ) + + # Reset access registry + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + if self.cfg.label_pred_head.keep: + tensorboard_logs.update({'train_label_loss': label_loss_value}) + + tensorboard_logs.update( + { + 'train_loss': loss_value, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': torch.tensor(self.trainer.global_step, dtype=torch.float32), + } + ) + + if (batch_nb + 1) % log_every_n_steps == 0: + self.wer.update( + predictions=log_probs, + targets=transcript, + targets_lengths=transcript_len, + predictions_lengths=encoded_len, + ) + # wer, _, _ = self.wer.compute() + labelled_wer, unlabelled_wer, acc, scores_unlabelled, words_unlabelled = self.wer.compute() + self.wer.reset() + # tensorboard_logs.update({'training_batch_l_wer': labelled_wer, + # 'training_batch_u_wer': unlabelled_wer, + # 'training_batch_l_acc': acc, + # }) + if labelled_wer is not None: + tensorboard_logs.update({'train_l_wer': labelled_wer}) + self.log('train_l_wer', labelled_wer, on_step=True, on_epoch=False) + if unlabelled_wer is not None: + tensorboard_logs.update({'train_u_wer': unlabelled_wer}) + self.log('train_u_wer', unlabelled_wer, 
on_step=True, on_epoch=False) + if acc is not None: + tensorboard_logs.update({'train_acc': acc}) + self.log('train_acc', acc, on_step=True, on_epoch=False) + + return {'loss': loss_value+label_loss_value, 'log': tensorboard_logs} + + def predict_step(self, batch, batch_idx, dataloader_idx=0): + signal, signal_len, video_input_signal, transcript, transcript_len, label, sample_id = batch + # if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + # log_probs, encoded_len, predictions = self.forward( + # processed_signal=signal, processed_signal_length=signal_len + # ) + # else: + log_probs, label_loss_value, encoded_len, predictions = self.forward(audio_input_signal=signal, audio_input_signal_length=signal_len, video_input_signal=video_input_signal) + + transcribed_texts, _ = self.wer.decoding.ctc_decoder_predictions_tensor( + decoder_outputs=log_probs, decoder_lengths=encoded_len, return_hypotheses=False, + ) + + sample_id = sample_id.cpu().detach().numpy() + return list(zip(sample_id, transcribed_texts)) + + def validation_pass(self, batch, batch_idx, dataloader_idx=0): + if self.is_interctc_enabled(): + AccessMixin.set_access_enabled(access_enabled=True, guid=self.model_guid) + + signal, signal_len, video_input_signal, transcript, transcript_len, label = batch + # if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + # log_probs, encoded_len, predictions = self.forward( + # processed_signal=signal, processed_signal_length=signal_len + # ) + # else: + log_probs, label_log_probs, encoded_len, predictions = self.forward(audio_input_signal=signal, audio_input_signal_length=signal_len, video_input_signal=video_input_signal) + + loss_value = self.loss( + log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len + ) + if self.cfg.label_pred_head.keep: + label_loss_value = self.ce_loss(label_log_probs, label) + else: + label_loss_value = 0.0 + loss_value, metrics = self.add_interctc_losses( + loss_value, 
transcript, transcript_len, compute_wer=True, log_wer_num_denom=True, log_prefix="val_", + ) + + self.wer.update( + predictions=log_probs, targets=transcript, targets_lengths=transcript_len, predictions_lengths=encoded_len, + ) + # wer, wer_num, wer_denom = self.wer.compute() + labelled_wer, unlabelled_wer, acc, scores_unlabelled, words_unlabelled = self.wer.compute() + self.wer.reset() + metrics.update({'val_loss': loss_value, 'val_label_loss': label_loss_value, + 'val_labelled_wer': labelled_wer, 'val_unlabelled_wer': unlabelled_wer, 'val_acc': acc, 'val_wer_num': scores_unlabelled, 'val_wer_denom': words_unlabelled}) + + # self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + if labelled_wer is not None: + self.log('val_l_wer', labelled_wer, on_epoch=True, sync_dist=True) + if unlabelled_wer is not None: + self.log('val_u_wer', unlabelled_wer, on_epoch=True, sync_dist=True) + if acc is not None: + self.log('val_acc', acc, on_epoch=True, sync_dist=True) + self.log('val_loss', loss_value, sync_dist=True) + if self.cfg.label_pred_head.keep: + self.log('val_label_loss', label_loss_value, sync_dist=True) + + + # Reset access registry + if AccessMixin.is_access_enabled(self.model_guid): + AccessMixin.reset_registry(self) + + return metrics + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + metrics = self.validation_pass(batch, batch_idx, dataloader_idx) + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(metrics) + else: + self.validation_step_outputs.append(metrics) + return metrics + + def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): + metrics = super().multi_validation_epoch_end(outputs, dataloader_idx) + self.finalize_interctc_metrics(metrics, outputs, prefix="val_") + return metrics + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + metrics = 
super().multi_test_epoch_end(outputs, dataloader_idx) + self.finalize_interctc_metrics(metrics, outputs, prefix="test_") + return metrics + + def test_step(self, batch, batch_idx, dataloader_idx=0): + logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx) + test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()} + if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: + self.test_step_outputs[dataloader_idx].append(test_logs) + else: + self.test_step_outputs.append(test_logs) + return test_logs + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + + """ Transcription related methods """ + + def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig): + super()._transcribe_on_begin(audio, trcfg) + + # Freeze the encoder and decoder modules + self.encoder.freeze() + self.decoder.freeze() + + def _transcribe_on_end(self, trcfg: TranscribeConfig): + super()._transcribe_on_end(trcfg) + + # Unfreeze the encoder and decoder modules + self.encoder.unfreeze() + self.decoder.unfreeze() + + def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): + logits, logits_len, greedy_predictions = self.forward(audio_input=batch[0], audio_input_signal_length=batch[1], video_input_signal=batch[2]) + output = dict(logits=logits, logits_len=logits_len) + del greedy_predictions + return output + + def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> GenericTranscriptionType: + logits = outputs.pop('logits') + logits_len = outputs.pop('logits_len') + + current_hypotheses, all_hyp = self.decoding.ctc_decoder_predictions_tensor( + logits, decoder_lengths=logits_len, return_hypotheses=trcfg.return_hypotheses, + ) + if trcfg.return_hypotheses: + if logits.is_cuda: + # See comment in + # ctc_greedy_decoding.py::GreedyCTCInfer::forward() to + # understand this idiom. 
+ logits_cpu = torch.empty(logits.shape, dtype=logits.dtype, device=torch.device("cpu"), pin_memory=True) + logits_cpu.copy_(logits, non_blocking=True) + else: + logits_cpu = logits + logits_len = logits_len.cpu() + # dump log probs per file + for idx in range(logits_cpu.shape[0]): + current_hypotheses[idx].y_sequence = logits_cpu[idx][: logits_len[idx]] + if current_hypotheses[idx].alignments is None: + current_hypotheses[idx].alignments = current_hypotheses[idx].y_sequence + del logits_cpu + + # cleanup memory + del logits, logits_len + + hypotheses = [] + if all_hyp is None: + hypotheses += current_hypotheses + else: + hypotheses += all_hyp + + return hypotheses + + def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLoader': + """ + Setup function for a temporary data loader which wraps the provided audio file. + + Args: + config: A python dictionary which contains the following keys: + paths2audio_files: (a list) of paths to audio files. The files should be relatively short fragments. \ + Recommended length per file is between 5 and 25 seconds. + batch_size: (int) batch size to use during inference. \ + Bigger will result in better throughput performance but would use more memory. + temp_dir: (str) A temporary directory where the audio manifest is temporarily + stored. + num_workers: (int) number of workers. Depends of the batch_size and machine. \ + 0 - only the main process will load batches, 1 - one worker (not main process) + + Returns: + A pytorch DataLoader for the given audio file(s). 
+ """ + if 'manifest_filepath' in config: + manifest_filepath = config['manifest_filepath'] + batch_size = config['batch_size'] + else: + manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json') + batch_size = min(config['batch_size'], len(config['paths2audio_files'])) + + dl_config = { + 'manifest_filepath': manifest_filepath, + 'sample_rate': self.a_model.preprocessor._sample_rate, + 'labels': OmegaConf.to_container(self.decoder.vocabulary), + 'batch_size': batch_size, + 'trim_silence': False, + 'shuffle': False, + 'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)), + 'pin_memory': True, + 'channel_selector': config.get('channel_selector', None), + } + if config.get("augmentor"): + dl_config['augmentor'] = config.get("augmentor") + + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) + return temporary_datalayer + + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. 
+ """ + results = [] + + return results + + @property + def wer(self): + return self._wer + + @wer.setter + def wer(self, wer): + self._wer = wer diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 4df02b1177cd..3887a92da648 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -65,7 +65,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): not self._cfg.decoder.feat_in and hasattr(self.encoder, '_feat_out') ): self._cfg.decoder.feat_in = self.encoder._feat_out - if "feat_in" not in self._cfg.decoder or not self._cfg.decoder.feat_in: raise ValueError("param feat_in of the decoder's config is not set!") if self.cfg.decoder.num_classes < 1 and self.cfg.decoder.vocabulary is not None: diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 66def034400f..3b5f65228e8b 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -163,7 +163,8 @@ def __init__( # elif hasattr(parser, "lang") and parser.lang is not None: # text_tokens = parser(text, parser.lang) else: - raise ValueError("lang required in manifest when using aggregate tokenizers") + raise ValueError( + "lang required in manifest when using aggregate tokenizers") else: text_tokens = parser(text) else: @@ -176,7 +177,8 @@ def __init__( total_duration += duration - data.append(output_type(id_, audio_file, duration, text_tokens, offset, text, speaker, orig_sr, lang)) + data.append(output_type(id_, audio_file, duration, + text_tokens, offset, text, speaker, orig_sr, lang)) if index_by_file_id: file_id, _ = os.path.splitext(os.path.basename(audio_file)) if file_id not in self.mapping: @@ -189,12 +191,143 @@ def __init__( if do_sort_by_duration: if index_by_file_id: - logging.warning("Tried to sort dataset by duration, but cannot since 
index_by_file_id is set.") + logging.warning( + "Tried to sort dataset by duration, but cannot since index_by_file_id is set.") else: data.sort(key=lambda entity: entity.duration) - logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600) - logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600) + logging.info("Dataset loaded with %d files totalling %.2f hours", len( + data), total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", + num_filtered, duration_filtered / 3600) + + super().__init__(data) + + +class AVText(_Collection): + """List of audio-transcript text correspondence with preprocessing.""" + + AV_OUTPUT_TYPE = collections.namedtuple( + typename='AVTextEntity', + field_names='id audio_file video_file video_featfile duration text_tokens snr offset text_raw speaker orig_sr lang label', + ) + + def __init__( + self, + ids: List[int], + audio_files: List[str], + video_files: List[str], + video_featfiles: List[str], + durations: List[float], + snr_ratios: List[float], + texts: List[str], + offsets: List[str], + speakers: List[Optional[int]], + orig_sampling_rates: List[Optional[int]], + token_labels: List[Optional[int]], + langs: List[Optional[str]], + labels: List[Optional[str]], + parser: parsers.CharParser, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + index_by_file_id: bool = False, + ): + """Instantiates audio-text manifest with filters and preprocessing. + + Args: + ids: List of examples positions. + audio_files: List of audio files. + video_files: List of video files. + video_featfiles: List of video feature files. + durations: List of float durations. + texts: List of raw text transcripts. + snr_ratios: List of signal-to-noise ratios. + offsets: List of duration offsets or None. + speakers: List of optional speakers ids. 
+ orig_sampling_rates: List of original sampling rates of audio files. + langs: List of language ids, one for each sample, or None. + parser: Instance of `CharParser` to convert string to tokens. + min_duration: Minimum duration to keep entry with (default: None). + max_duration: Maximum duration to keep entry with (default: None). + max_number: Maximum number of samples to collect. + do_sort_by_duration: True if sort samples list by duration. Not compatible with index_by_file_id. + index_by_file_id: If True, saves a mapping from filename base (ID) to index in data. + """ + + output_type = self.AV_OUTPUT_TYPE + data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0 + if index_by_file_id: + self.mapping = {} + + for id_, audio_file, video_file, video_featfile, duration, offset, text, snr_ratio, speaker, orig_sr, token_labels, lang, label in zip( + ids, audio_files, video_files, video_featfiles, durations, offsets, texts, snr_ratios, speakers, orig_sampling_rates, token_labels, langs, labels + ): + # Duration filters. 
+ if min_duration is not None and duration < min_duration: + duration_filtered += duration + num_filtered += 1 + continue + + if max_duration is not None and duration > max_duration: + duration_filtered += duration + num_filtered += 1 + continue + + if token_labels is not None: + text_tokens = token_labels + else: + if text != '': + if hasattr(parser, "is_aggregate") and parser.is_aggregate and isinstance(text, str): + if lang is not None: + text_tokens = parser(text, lang) + # for future use if want to add language bypass to audio_to_text classes + # elif hasattr(parser, "lang") and parser.lang is not None: + # text_tokens = parser(text, parser.lang) + else: + raise ValueError( + "lang required in manifest when using aggregate tokenizers") + else: + text_tokens = parser(text) + else: + text_tokens = [] + + if text_tokens is None: + duration_filtered += duration + num_filtered += 1 + continue + + if label is not None: # + # replace <,>,N and then convert to int + label = int(label.replace('<', '').replace('>', '').replace('N', '')) + label = label - 1 # 0-indexed + + total_duration += duration + + data.append(output_type(id_, audio_file, video_file, video_featfile, duration, + text_tokens, snr_ratio, offset, text, speaker, orig_sr, lang, label)) + if index_by_file_id: + file_id, _ = os.path.splitext(os.path.basename(audio_file)) + if file_id not in self.mapping: + self.mapping[file_id] = [] + self.mapping[file_id].append(len(data) - 1) + + # Max number of entities filter. 
+ if len(data) == max_number: + break + + if do_sort_by_duration: + if index_by_file_id: + logging.warning( + "Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + else: + data.sort(key=lambda entity: entity.duration) + + logging.info("Dataset loaded with %d files totalling %.2f hours", len( + data), total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", + num_filtered, duration_filtered / 3600) super().__init__(data) @@ -271,7 +404,8 @@ def __init__( if lang is not None: text_tokens = parser(text, lang) else: - raise ValueError("lang required in manifest when using aggregate tokenizers") + raise ValueError( + "lang required in manifest when using aggregate tokenizers") else: text_tokens = parser(text) else: @@ -284,7 +418,8 @@ def __init__( total_duration += duration - data.append(output_type(id_, video_file, duration, text_tokens, offset, text, speaker, orig_sr, lang)) + data.append(output_type(id_, video_file, duration, + text_tokens, offset, text, speaker, orig_sr, lang)) if index_by_file_id: file_id, _ = os.path.splitext(os.path.basename(video_file)) if file_id not in self.mapping: @@ -297,12 +432,15 @@ def __init__( if do_sort_by_duration: if index_by_file_id: - logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + logging.warning( + "Tried to sort dataset by duration, but cannot since index_by_file_id is set.") else: data.sort(key=lambda entity: entity.duration) - logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600) - logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600) + logging.info("Dataset loaded with %d files totalling %.2f hours", len( + data), total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", + num_filtered, duration_filtered / 3600) super().__init__(data) @@ -343,6 +481,50 @@ def __init__(self, manifests_files: Union[str, 
List[str]], *args, **kwargs): ) +class ASR_AV_AudioText(AVText): + """`AudioText` collector from asr structured json files.""" + + def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): + """Parse lists of audio files, durations and transcripts texts. + + Args: + manifests_files: Either single string file or list of such - + manifests to yield items from. + *args: Args to pass to `AVText` constructor. + **kwargs: Kwargs to pass to `AVText` constructor. + """ + + ids, audio_files, video_files, durations, texts, offsets, video_featfiles, snr_ratios, labels = ( + [], + [], + [], + [], + [], + [], + [], + [], + [], + ) + speakers, orig_srs, token_labels, langs = [], [], [], [] + for item in manifest.av_item_iter(manifests_files): + ids.append(item['id']) + audio_files.append(item['audio_file']) + video_featfiles.append(item['feature_file']) + durations.append(item['duration']) + texts.append(item['text']) + video_files.append(item['video_file']) + snr_ratios.append(item['snr_ratio']) + offsets.append(item['offset']) + speakers.append(item['speaker']) + orig_srs.append(item['orig_sr']) + token_labels.append(item['token_labels']) + langs.append(item['lang']) + labels.append(item['label']) + super().__init__( + ids, audio_files, video_files, video_featfiles, durations, snr_ratios, texts, offsets, speakers, orig_srs, token_labels, langs, labels, *args, **kwargs + ) + + class ASRVideoText(VideoText): """`VideoText` collector from cv structured json files.""" @@ -382,7 +564,8 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): class SpeechLabel(_Collection): """List of audio-label correspondence with preprocessing.""" - OUTPUT_TYPE = collections.namedtuple(typename='SpeechLabelEntity', field_names='audio_file duration label offset',) + OUTPUT_TYPE = collections.namedtuple( + typename='SpeechLabelEntity', field_names='audio_file duration label offset',) def __init__( self, @@ -438,14 +621,18 @@ def __init__( if 
do_sort_by_duration: if index_by_file_id: - logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + logging.warning( + "Tried to sort dataset by duration, but cannot since index_by_file_id is set.") else: data.sort(key=lambda entity: entity.duration) - logging.info(f"Filtered duration for loading collection is {duration_filtered / 3600: .2f} hours.") - logging.info(f"Dataset loaded with {len(data)} items, total duration of {total_duration / 3600: .2f} hours.") + logging.info( + f"Filtered duration for loading collection is {duration_filtered / 3600: .2f} hours.") + logging.info( + f"Dataset loaded with {len(data)} items, total duration of {total_duration / 3600: .2f} hours.") self.uniq_labels = sorted(set(map(lambda x: x.label, data))) - logging.info("# {} files loaded accounting to # {} labels".format(len(data), len(self.uniq_labels))) + logging.info("# {} files loaded accounting to # {} labels".format( + len(data), len(self.uniq_labels))) super().__init__(data) @@ -502,12 +689,15 @@ def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: elif 'audio_filepath' in item: item['audio_file'] = item.pop('audio_filepath') else: - raise ValueError(f"Manifest file has invalid json line structure: {line} without proper audio file key.") - item['audio_file'] = manifest.get_full_path(audio_file=item['audio_file'], manifest_file=manifest_file) + raise ValueError( + f"Manifest file has invalid json line structure: {line} without proper audio file key.") + item['audio_file'] = manifest.get_full_path( + audio_file=item['audio_file'], manifest_file=manifest_file) # Duration. if 'duration' not in item: - raise ValueError(f"Manifest file has invalid json line structure: {line} without proper duration key.") + raise ValueError( + f"Manifest file has invalid json line structure: {line} without proper duration key.") # Label. 
if 'command' in item: @@ -517,7 +707,8 @@ def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: elif 'label' in item: pass else: - raise ValueError(f"Manifest file has invalid json line structure: {line} without proper label key.") + raise ValueError( + f"Manifest file has invalid json line structure: {line} without proper label key.") item = dict( audio_file=item['audio_file'], @@ -532,7 +723,8 @@ def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: class FeatureSequenceLabel(_Collection): """List of feature sequence of label correspondence with preprocessing.""" - OUTPUT_TYPE = collections.namedtuple(typename='FeatureSequenceLabelEntity', field_names='feature_file seq_label',) + OUTPUT_TYPE = collections.namedtuple( + typename='FeatureSequenceLabelEntity', field_names='feature_file seq_label',) def __init__( self, @@ -562,7 +754,8 @@ def __init__( for feature_file, seq_label in zip(feature_files, seq_labels): - label_tokens, uniq_labels_in_seq = self.relative_speaker_parser(seq_label) + label_tokens, uniq_labels_in_seq = self.relative_speaker_parser( + seq_label) data.append(output_type(feature_file, label_tokens)) self.uniq_labels |= uniq_labels_in_seq @@ -579,7 +772,8 @@ def __init__( if len(data) == max_number: break - logging.info("# {} files loaded including # {} unique labels".format(len(data), len(self.uniq_labels))) + logging.info("# {} files loaded including # {} unique labels".format( + len(data), len(self.uniq_labels))) super().__init__(data) def relative_speaker_parser(self, seq_label): @@ -616,7 +810,6 @@ class ASRFeatureSequenceLabel(FeatureSequenceLabel): def __init__( self, manifests_files: Union[str, List[str]], max_number: Optional[int] = None, index_by_file_id: bool = False, ): - """Parse lists of feature files and sequences of labels. 
Args: @@ -655,7 +848,8 @@ def _parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: f"Manifest file has invalid json line " f"structure: {line} without proper seq_label key." ) - item = dict(feature_file=item['feature_file'], seq_label=item['seq_label'],) + item = dict(feature_file=item['feature_file'], + seq_label=item['seq_label'],) return item @@ -754,14 +948,16 @@ def __init__( if do_sort_by_duration: if index_by_file_id: - logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + logging.warning( + "Tried to sort dataset by duration, but cannot since index_by_file_id is set.") else: data.sort(key=lambda entity: entity.duration) logging.info( "Filtered duration for loading collection is %f.", duration_filtered, ) - logging.info(f"Total {len(data)} session files loaded accounting to # {len(audio_files)} audio clips") + logging.info( + f"Total {len(data)} session files loaded accounting to # {len(audio_files)} audio clips") super().__init__(data) @@ -821,12 +1017,15 @@ def __init__( for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item_rttm): # Inference mode if self.pairwise_infer: - clus_speaker_digits = sorted(list(set([x[2] for x in clus_label_dict[item['uniq_id']]]))) + clus_speaker_digits = sorted( + list(set([x[2] for x in clus_label_dict[item['uniq_id']]]))) if item['rttm_file']: base_scale_index = max(self.emb_dict.keys()) _sess_spk_dict = self.emb_dict[base_scale_index][item['uniq_id']]['mapping'] - sess_spk_dict = {int(v.split('_')[-1]): k for k, v in _sess_spk_dict.items()} - rttm_speaker_digits = [int(v.split('_')[1]) for k, v in _sess_spk_dict.items()] + sess_spk_dict = { + int(v.split('_')[-1]): k for k, v in _sess_spk_dict.items()} + rttm_speaker_digits = [int(v.split('_')[1]) + for k, v in _sess_spk_dict.items()] if self.seq_eval_mode: clus_speaker_digits = rttm_speaker_digits else: @@ -838,14 +1037,17 @@ def __init__( rttm_labels = [] with open(item['rttm_file'], 'r') 
as f: for line in f.readlines(): - start, end, speaker = self.split_rttm_line(line, decimals=3) - rttm_labels.append('{} {} {}'.format(start, end, speaker)) + start, end, speaker = self.split_rttm_line( + line, decimals=3) + rttm_labels.append( + '{} {} {}'.format(start, end, speaker)) speaker_set = set() for rttm_line in rttm_labels: spk_str = rttm_line.split()[-1] speaker_set.add(spk_str) speaker_list = sorted(list(speaker_set)) - sess_spk_dict = {key: val for key, val in enumerate(speaker_list)} + sess_spk_dict = {key: val for key, + val in enumerate(speaker_list)} target_spks = tuple(sess_spk_dict.keys()) clus_speaker_digits = target_spks rttm_speaker_digits = target_spks @@ -853,7 +1055,8 @@ def __init__( if len(clus_speaker_digits) <= 2: spk_comb_list = [(0, 1)] else: - spk_comb_list = [x for x in combinations(clus_speaker_digits, 2)] + spk_comb_list = [x for x in combinations( + clus_speaker_digits, 2)] for target_spks in spk_comb_list: audio_files.append(item['audio_file']) @@ -923,9 +1126,11 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: f"Manifest file has invalid json line " f"structure: {line} without proper audio file key." ) item['audio_file'] = os.path.expanduser(item['audio_file']) - item['uniq_id'] = os.path.splitext(os.path.basename(item['audio_file']))[0] + item['uniq_id'] = os.path.splitext( + os.path.basename(item['audio_file']))[0] if 'duration' not in item: - raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper duration key.") + raise ValueError( + f"Manifest file has invalid json line " f"structure: {line} without proper duration key.") item = dict( audio_file=item['audio_file'], uniq_id=item['uniq_id'], @@ -940,7 +1145,8 @@ class Audio(_Collection): """Prepare a list of all audio items, filtered by duration. 
""" - OUTPUT_TYPE = collections.namedtuple(typename='Audio', field_names='audio_files duration offset text') + OUTPUT_TYPE = collections.namedtuple( + typename='Audio', field_names='audio_files duration offset text') def __init__( self, @@ -992,8 +1198,10 @@ def __init__( if do_sort_by_duration: data.sort(key=lambda entity: entity.duration) - logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600) - logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600) + logging.info("Dataset loaded with %d files totalling %.2f hours", len( + data), total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", + num_filtered, duration_filtered / 3600) super().__init__(data) @@ -1033,7 +1241,8 @@ def __init__( offset_list.append(item['offset']) text_list.append(item['text']) - super().__init__(audio_files_list, duration_list, offset_list, text_list, *args, **kwargs) + super().__init__(audio_files_list, duration_list, + offset_list, text_list, *args, **kwargs) def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: """Parse a single line from a manifest file. 
@@ -1067,9 +1276,11 @@ def get_audio_file(item: Dict, manifest_key: Union[str, List[str]]): elif isinstance(item_key, list): audio_file += item_key else: - raise ValueError(f'Unexpected type {type(item_key)} of item for key {key}: {item_key}') + raise ValueError( + f'Unexpected type {type(item_key)} of item for key {key}: {item_key}') else: - raise ValueError(f'Unexpected type {type(manifest_key)} of manifest_key: {manifest_key}') + raise ValueError( + f'Unexpected type {type(manifest_key)} of manifest_key: {manifest_key}') return audio_file @@ -1085,21 +1296,25 @@ def get_audio_file(item: Dict, manifest_key: Union[str, List[str]]): # Get full path to audio file(s) if isinstance(audio_file, str): # This dictionary entry points to a single file - audio_files[audio_key] = manifest.get_full_path(audio_file, manifest_file) + audio_files[audio_key] = manifest.get_full_path( + audio_file, manifest_file) elif isinstance(audio_file, Iterable): # This dictionary entry points to multiple files # Get the files and keep the list structure for this key - audio_files[audio_key] = [manifest.get_full_path(f, manifest_file) for f in audio_file] + audio_files[audio_key] = [manifest.get_full_path( + f, manifest_file) for f in audio_file] elif audio_file is None and audio_key.startswith('target'): # For inference, we don't need the target audio_files[audio_key] = None else: - raise ValueError(f'Unexpected type {type(audio_file)} of audio_file: {audio_file}') + raise ValueError( + f'Unexpected type {type(audio_file)} of audio_file: {audio_file}') item['audio_files'] = audio_files # Handle duration if 'duration' not in item: - raise ValueError(f'Duration not available in line: {line}. Manifest file: {manifest_file}') + raise ValueError( + f'Duration not available in line: {line}. 
Manifest file: {manifest_file}') # Handle offset if 'offset' not in item: @@ -1117,7 +1332,8 @@ def get_audio_file(item: Dict, manifest_key: Union[str, List[str]]): class FeatureLabel(_Collection): """List of feature sequence and their label correspondence with preprocessing.""" - OUTPUT_TYPE = collections.namedtuple(typename='FeatureLabelEntity', field_names='feature_file label duration',) + OUTPUT_TYPE = collections.namedtuple( + typename='FeatureLabelEntity', field_names='feature_file label duration',) def __init__( self, @@ -1172,13 +1388,17 @@ def __init__( if do_sort_by_duration: if index_by_file_id: - logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + logging.warning( + "Tried to sort dataset by duration, but cannot since index_by_file_id is set.") else: data.sort(key=lambda entity: entity.duration) - logging.info(f"Filtered duration for loading collection is {duration_filtered / 2600:.2f} hours.") - logging.info(f"Dataset loaded with {len(data)} items, total duration of {total_duration / 3600: .2f} hours.") - logging.info("# {} files loaded including # {} unique labels".format(len(data), len(self.uniq_labels))) + logging.info( + f"Filtered duration for loading collection is {duration_filtered / 2600:.2f} hours.") + logging.info( + f"Dataset loaded with {len(data)} items, total duration of {total_duration / 3600: .2f} hours.") + logging.info("# {} files loaded including # {} unique labels".format( + len(data), len(self.uniq_labels))) super().__init__(data) @@ -1194,7 +1414,6 @@ def __init__( *args, **kwargs, ): - """Parse lists of feature files and sequences of labels. Args: @@ -1236,15 +1455,18 @@ def _parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: raise ValueError( f"Manifest file has invalid json line " f"structure: {line} without proper 'feature_file' key." 
) - item['feature_file'] = manifest.get_full_path(audio_file=item['feature_file'], manifest_file=manifest_file) + item['feature_file'] = manifest.get_full_path( + audio_file=item['feature_file'], manifest_file=manifest_file) # Label. if 'label' in item: item['label'] = item.pop('label') else: - raise ValueError(f"Manifest file has invalid json line structure: {line} without proper 'label' key.") + raise ValueError( + f"Manifest file has invalid json line structure: {line} without proper 'label' key.") - item = dict(feature_file=item['feature_file'], label=item['label'], duration=item['duration']) + item = dict(feature_file=item['feature_file'], + label=item['label'], duration=item['duration']) return item @@ -1332,7 +1554,8 @@ def __init__( if lang is not None: text_tokens = parser(text, lang) else: - raise ValueError("lang required in manifest when using aggregate tokenizers") + raise ValueError( + "lang required in manifest when using aggregate tokenizers") else: text_tokens = parser(text) else: @@ -1346,7 +1569,8 @@ def __init__( total_duration += duration data.append( - output_type(id_, feat_file, rttm_file, duration, text_tokens, offset, text, speaker, orig_sr, lang) + output_type(id_, feat_file, rttm_file, duration, + text_tokens, offset, text, speaker, orig_sr, lang) ) if index_by_file_id: file_id, _ = os.path.splitext(os.path.basename(feat_file)) @@ -1360,12 +1584,15 @@ def __init__( if do_sort_by_duration: if index_by_file_id: - logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + logging.warning( + "Tried to sort dataset by duration, but cannot since index_by_file_id is set.") else: data.sort(key=lambda entity: entity.duration) - logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600) - logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600) + logging.info("Dataset loaded with %d files totalling %.2f hours", len( + data), 
total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", + num_filtered, duration_filtered / 3600) super().__init__(data) diff --git a/nemo/collections/common/parts/preprocessing/manifest.py b/nemo/collections/common/parts/preprocessing/manifest.py index 1d49bd7c7019..fba05937d7b5 100644 --- a/nemo/collections/common/parts/preprocessing/manifest.py +++ b/nemo/collections/common/parts/preprocessing/manifest.py @@ -188,6 +188,170 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: return item +def av_item_iter( + manifests_files: Union[str, List[str]], parse_func: Callable[[str, Optional[str]], Dict[str, Any]] = None +) -> Iterator[Dict[str, Any]]: + """Iterate through json lines of provided manifests. + + NeMo ASR pipelines often assume certain manifest files structure. In + particular, each manifest file should consist of line-per-sample files with + each line being correct json dict. Each such json dict should have a field + for audio file string, a field for duration float and a field for text + string. Offset also could be additional field and is set to None by + default. + + Args: + manifests_files: Either single string file or list of such - + manifests to yield items from. + + parse_func: A callable function which accepts as input a single line + of a manifest and optionally the manifest file itself, + and parses it, returning a dictionary mapping from str -> Any. + + Yields: + Parsed key to value item dicts. + + Raises: + ValueError: If met invalid json line structure. 
+ """ + + if isinstance(manifests_files, str): + manifests_files = [manifests_files] + + if parse_func is None: + parse_func = __av_parse_item + + errors = defaultdict(list) + k = -1 + logging.debug('Manifest files: %s', str(manifests_files)) + for manifest_file in manifests_files: + logging.debug('Using manifest file: %s', str(manifest_file)) + cached_manifest_file = DataStoreObject(manifest_file).get() + logging.debug('Cached at: %s', str(cached_manifest_file)) + with open(expanduser(cached_manifest_file), 'r') as f: + for line in f: + line = line.strip() + if not line: + continue + k += 1 + try: + item = parse_func(line, manifest_file) + except json.JSONDecodeError: + errors[str(manifest_file)].append(line) + continue + item['id'] = k + yield item + + if len(errors) > 0: + for filename, lines in errors.items(): + logging.error("=============================================") + logging.error(f"Failed to parse {len(lines)} lines from manifest file: {filename}") + for line in lines: + logging.error(f"-- Failed to parse line: `{line}`") + raise RuntimeError("Failed to parse some lines from manifest files. See logs for more details.") + + +def __av_parse_item(line: str, manifest_file: str) -> Dict[str, Any]: + item = json.loads(line) + + # Audio file + if 'audio_filename' in item: + item['audio_file'] = item.pop('audio_filename') + elif 'audio_filepath' in item: + item['audio_file'] = item.pop('audio_filepath') + + # Video File + if 'video_filename' in item: + item['video_file'] = item.pop('video_filename') + elif 'video_filepath' in item: + item['video_file'] = item.pop('video_filepath') + + if 'video_file' not in item and 'audio_file' not in item: + raise ValueError( + f"Manifest file {manifest_file} has invalid json line structure: {line} without proper audio/video file key." + ) + + # If the audio/video path is a relative path and does not exist, + # try to attach the parent directory of manifest to the audio path. 
+ # Revert to the original path if the new path still doesn't exist. + # Assume that the audio path is like "wavs/xxxxxx.wav". + if 'audio_file' in item: + item['audio_file'] = get_full_path(audio_file=item['audio_file'], manifest_file=manifest_file) + if 'video_file' in item: + item['video_file'] = get_full_path(audio_file=item['video_file'], manifest_file=manifest_file) + + # Duration. + if 'duration' not in item: + raise ValueError( + f"Manifest file {manifest_file} has invalid json line structure: {line} without proper duration key." + ) + + # Text. + if 'text' in item: + pass + elif 'text_filepath' in item: + with open(item.pop('text_filepath'), 'r') as f: + item['text'] = f.read().replace('\n', '') + elif 'normalized_text' in item: + item['text'] = item['normalized_text'] + else: + item['text'] = "" + + # Optional RTTM file + if 'rttm_file' in item: + pass + elif 'rttm_filename' in item: + item['rttm_file'] = item.pop('rttm_filename') + elif 'rttm_filepath' in item: + item['rttm_file'] = item.pop('rttm_filepath') + else: + item['rttm_file'] = None + if item['rttm_file'] is not None: + item['rttm_file'] = get_full_path(audio_file=item['rttm_file'], manifest_file=manifest_file) + + # Optional audio feature file + if 'feature_file' in item: + pass + elif 'feature_filename' in item: + item['feature_file'] = item.pop('feature_filename') + elif 'feature_filepath' in item: + item['feature_file'] = item.pop('feature_filepath') + else: + item['feature_file'] = None + if item['feature_file'] is not None: + item['feature_file'] = get_full_path(audio_file=item['feature_file'], manifest_file=manifest_file) + + # Optional snr_ratio + if 'snr' in item: + pass + elif 'snr_ratio' in item: + item['snr'] = item.pop('snr_ratio') + else: + item['snr'] = None + + if 'label' in item: + item['label'] = item.pop('label') + else: + item['label'] = None + + item = dict( + audio_file=item.get('audio_file', None), + video_file=item.get('video_file', None), + duration=item['duration'], + 
text=item['text'], + rttm_file=item['rttm_file'], + feature_file=item['feature_file'], + offset=item.get('offset', None), + speaker=item.get('speaker', None), + orig_sr=item.get('orig_sample_rate', None), + token_labels=item.get('token_labels', None), + lang=item.get('lang', None), + snr_ratio=item.get('snr', None), + label=item.get('label', None), + ) + return item + + def is_tarred_dataset(audio_file: str, manifest_file: Optional[str] = None) -> bool: if "/" in audio_file or manifest_file is None: # audio files in a tarred dataset don't have `/` in their paths diff --git a/scripts/tokenizers/sentencepiece_model_pb2.py b/scripts/tokenizers/sentencepiece_model_pb2.py new file mode 100644 index 000000000000..bd196c531a2c --- /dev/null +++ b/scripts/tokenizers/sentencepiece_model_pb2.py @@ -0,0 +1,764 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: sentencepiece_model.proto + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='sentencepiece_model.proto', + package='sentencepiece', + syntax='proto2', + serialized_options=b'H\003', + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\xa4\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 
\x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05\x12\x16\n\tbos_piece\x18. 
\x01(\t:\x03\x12\x17\n\teos_piece\x18/ \x01(\t:\x04\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\x12\"\n\x18seed_sentencepieces_file\x18\x36 \x01(\t:\x00\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03' +) + + + 
+_TRAINERSPEC_MODELTYPE = _descriptor.EnumDescriptor( + name='ModelType', + full_name='sentencepiece.TrainerSpec.ModelType', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='UNIGRAM', index=0, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='BPE', index=1, number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='WORD', index=2, number=3, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='CHAR', index=3, number=4, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=1553, + serialized_end=1606, +) +_sym_db.RegisterEnumDescriptor(_TRAINERSPEC_MODELTYPE) + +_MODELPROTO_SENTENCEPIECE_TYPE = _descriptor.EnumDescriptor( + name='Type', + full_name='sentencepiece.ModelProto.SentencePiece.Type', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='NORMAL', index=0, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='UNKNOWN', index=1, number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='CONTROL', index=2, number=3, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='USER_DEFINED', index=3, number=4, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='BYTE', index=4, number=6, + serialized_options=None, + type=None, + 
create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='UNUSED', index=5, number=5, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=2359, + serialized_end=2443, +) +_sym_db.RegisterEnumDescriptor(_MODELPROTO_SENTENCEPIECE_TYPE) + + +_TRAINERSPEC = _descriptor.Descriptor( + name='TrainerSpec', + full_name='sentencepiece.TrainerSpec', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='input', full_name='sentencepiece.TrainerSpec.input', index=0, + number=1, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='input_format', full_name='sentencepiece.TrainerSpec.input_format', index=1, + number=7, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='model_prefix', full_name='sentencepiece.TrainerSpec.model_prefix', index=2, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='model_type', full_name='sentencepiece.TrainerSpec.model_type', index=3, + number=3, type=14, cpp_type=8, label=1, + has_default_value=True, default_value=1, 
+ message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='vocab_size', full_name='sentencepiece.TrainerSpec.vocab_size', index=4, + number=4, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=8000, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='accept_language', full_name='sentencepiece.TrainerSpec.accept_language', index=5, + number=5, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='self_test_sample_size', full_name='sentencepiece.TrainerSpec.self_test_sample_size', index=6, + number=6, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='enable_differential_privacy', full_name='sentencepiece.TrainerSpec.enable_differential_privacy', index=7, + number=50, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='differential_privacy_noise_level', full_name='sentencepiece.TrainerSpec.differential_privacy_noise_level', index=8, + number=51, type=2, 
cpp_type=6, label=1, + has_default_value=True, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='differential_privacy_clipping_threshold', full_name='sentencepiece.TrainerSpec.differential_privacy_clipping_threshold', index=9, + number=52, type=4, cpp_type=4, label=1, + has_default_value=True, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='character_coverage', full_name='sentencepiece.TrainerSpec.character_coverage', index=10, + number=10, type=2, cpp_type=6, label=1, + has_default_value=True, default_value=float(0.9995), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='input_sentence_size', full_name='sentencepiece.TrainerSpec.input_sentence_size', index=11, + number=11, type=4, cpp_type=4, label=1, + has_default_value=True, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='shuffle_input_sentence', full_name='sentencepiece.TrainerSpec.shuffle_input_sentence', index=12, + number=19, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=True, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + 
name='mining_sentence_size', full_name='sentencepiece.TrainerSpec.mining_sentence_size', index=13, + number=12, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\030\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='training_sentence_size', full_name='sentencepiece.TrainerSpec.training_sentence_size', index=14, + number=13, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=b'\030\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='seed_sentencepiece_size', full_name='sentencepiece.TrainerSpec.seed_sentencepiece_size', index=15, + number=14, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=1000000, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='shrinking_factor', full_name='sentencepiece.TrainerSpec.shrinking_factor', index=16, + number=15, type=2, cpp_type=6, label=1, + has_default_value=True, default_value=float(0.75), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='max_sentence_length', full_name='sentencepiece.TrainerSpec.max_sentence_length', index=17, + number=18, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=4192, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, 
file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='num_threads', full_name='sentencepiece.TrainerSpec.num_threads', index=18, + number=16, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=16, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='num_sub_iterations', full_name='sentencepiece.TrainerSpec.num_sub_iterations', index=19, + number=17, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=2, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='max_sentencepiece_length', full_name='sentencepiece.TrainerSpec.max_sentencepiece_length', index=20, + number=20, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=16, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='split_by_unicode_script', full_name='sentencepiece.TrainerSpec.split_by_unicode_script', index=21, + number=21, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=True, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='split_by_number', full_name='sentencepiece.TrainerSpec.split_by_number', index=22, + number=23, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=True, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='split_by_whitespace', full_name='sentencepiece.TrainerSpec.split_by_whitespace', index=23, + number=22, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=True, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='treat_whitespace_as_suffix', full_name='sentencepiece.TrainerSpec.treat_whitespace_as_suffix', index=24, + number=24, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='allow_whitespace_only_pieces', full_name='sentencepiece.TrainerSpec.allow_whitespace_only_pieces', index=25, + number=26, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='split_digits', full_name='sentencepiece.TrainerSpec.split_digits', index=26, + number=25, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='pretokenization_delimiter', full_name='sentencepiece.TrainerSpec.pretokenization_delimiter', index=27, + number=53, type=9, cpp_type=9, label=1, + has_default_value=True, 
default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='control_symbols', full_name='sentencepiece.TrainerSpec.control_symbols', index=28, + number=30, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='user_defined_symbols', full_name='sentencepiece.TrainerSpec.user_defined_symbols', index=29, + number=31, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='required_chars', full_name='sentencepiece.TrainerSpec.required_chars', index=30, + number=36, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='byte_fallback', full_name='sentencepiece.TrainerSpec.byte_fallback', index=31, + number=35, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='vocabulary_output_piece_score', full_name='sentencepiece.TrainerSpec.vocabulary_output_piece_score', 
index=32, + number=32, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=True, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='hard_vocab_limit', full_name='sentencepiece.TrainerSpec.hard_vocab_limit', index=33, + number=33, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=True, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='use_all_vocab', full_name='sentencepiece.TrainerSpec.use_all_vocab', index=34, + number=34, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='unk_id', full_name='sentencepiece.TrainerSpec.unk_id', index=35, + number=40, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='bos_id', full_name='sentencepiece.TrainerSpec.bos_id', index=36, + number=41, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=1, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='eos_id', full_name='sentencepiece.TrainerSpec.eos_id', index=37, + number=42, type=5, cpp_type=1, 
label=1, + has_default_value=True, default_value=2, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='pad_id', full_name='sentencepiece.TrainerSpec.pad_id', index=38, + number=43, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=-1, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='unk_piece', full_name='sentencepiece.TrainerSpec.unk_piece', index=39, + number=45, type=9, cpp_type=9, label=1, + has_default_value=True, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='bos_piece', full_name='sentencepiece.TrainerSpec.bos_piece', index=40, + number=46, type=9, cpp_type=9, label=1, + has_default_value=True, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='eos_piece', full_name='sentencepiece.TrainerSpec.eos_piece', index=41, + number=47, type=9, cpp_type=9, label=1, + has_default_value=True, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='pad_piece', full_name='sentencepiece.TrainerSpec.pad_piece', index=42, + number=48, type=9, cpp_type=9, label=1, + 
has_default_value=True, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='unk_surface', full_name='sentencepiece.TrainerSpec.unk_surface', index=43, + number=44, type=9, cpp_type=9, label=1, + has_default_value=True, default_value=b" \342\201\207 ".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='train_extremely_large_corpus', full_name='sentencepiece.TrainerSpec.train_extremely_large_corpus', index=44, + number=49, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='seed_sentencepieces_file', full_name='sentencepiece.TrainerSpec.seed_sentencepieces_file', index=45, + number=54, type=9, cpp_type=9, label=1, + has_default_value=True, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _TRAINERSPEC_MODELTYPE, + ], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[(200, 536870912), ], + oneofs=[ + ], + serialized_start=45, + serialized_end=1617, +) + + +_NORMALIZERSPEC = _descriptor.Descriptor( + name='NormalizerSpec', + full_name='sentencepiece.NormalizerSpec', + filename=None, + file=DESCRIPTOR, + containing_type=None, + 
create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='sentencepiece.NormalizerSpec.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='precompiled_charsmap', full_name='sentencepiece.NormalizerSpec.precompiled_charsmap', index=1, + number=2, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=b"", + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='add_dummy_prefix', full_name='sentencepiece.NormalizerSpec.add_dummy_prefix', index=2, + number=3, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=True, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='remove_extra_whitespaces', full_name='sentencepiece.NormalizerSpec.remove_extra_whitespaces', index=3, + number=4, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=True, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='escape_whitespaces', full_name='sentencepiece.NormalizerSpec.escape_whitespaces', index=4, + number=5, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=True, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, 
extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='normalization_rule_tsv', full_name='sentencepiece.NormalizerSpec.normalization_rule_tsv', index=5, + number=6, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[(200, 536870912), ], + oneofs=[ + ], + serialized_start=1620, + serialized_end=1829, +) + + +_SELFTESTDATA_SAMPLE = _descriptor.Descriptor( + name='Sample', + full_name='sentencepiece.SelfTestData.Sample', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='input', full_name='sentencepiece.SelfTestData.Sample.input', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='expected', full_name='sentencepiece.SelfTestData.Sample.expected', index=1, + number=2, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + 
oneofs=[ + ], + serialized_start=1900, + serialized_end=1941, +) + +_SELFTESTDATA = _descriptor.Descriptor( + name='SelfTestData', + full_name='sentencepiece.SelfTestData', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='samples', full_name='sentencepiece.SelfTestData.samples', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[_SELFTESTDATA_SAMPLE, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[(200, 536870912), ], + oneofs=[ + ], + serialized_start=1831, + serialized_end=1952, +) + + +_MODELPROTO_SENTENCEPIECE = _descriptor.Descriptor( + name='SentencePiece', + full_name='sentencepiece.ModelProto.SentencePiece', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='piece', full_name='sentencepiece.ModelProto.SentencePiece.piece', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=b"".decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='score', full_name='sentencepiece.ModelProto.SentencePiece.score', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, 
create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='type', full_name='sentencepiece.ModelProto.SentencePiece.type', index=2, + number=3, type=14, cpp_type=8, label=1, + has_default_value=True, default_value=1, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + _MODELPROTO_SENTENCEPIECE_TYPE, + ], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[(200, 536870912), ], + oneofs=[ + ], + serialized_start=2244, + serialized_end=2454, +) + +_MODELPROTO = _descriptor.Descriptor( + name='ModelProto', + full_name='sentencepiece.ModelProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='pieces', full_name='sentencepiece.ModelProto.pieces', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='trainer_spec', full_name='sentencepiece.ModelProto.trainer_spec', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='normalizer_spec', full_name='sentencepiece.ModelProto.normalizer_spec', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + 
is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='self_test_data', full_name='sentencepiece.ModelProto.self_test_data', index=3, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='denormalizer_spec', full_name='sentencepiece.ModelProto.denormalizer_spec', index=4, + number=5, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[_MODELPROTO_SENTENCEPIECE, ], + enum_types=[ + ], + serialized_options=None, + is_extendable=True, + syntax='proto2', + extension_ranges=[(200, 536870912), ], + oneofs=[ + ], + serialized_start=1955, + serialized_end=2465, +) + +_TRAINERSPEC.fields_by_name['model_type'].enum_type = _TRAINERSPEC_MODELTYPE +_TRAINERSPEC_MODELTYPE.containing_type = _TRAINERSPEC +_SELFTESTDATA_SAMPLE.containing_type = _SELFTESTDATA +_SELFTESTDATA.fields_by_name['samples'].message_type = _SELFTESTDATA_SAMPLE +_MODELPROTO_SENTENCEPIECE.fields_by_name['type'].enum_type = _MODELPROTO_SENTENCEPIECE_TYPE +_MODELPROTO_SENTENCEPIECE.containing_type = _MODELPROTO +_MODELPROTO_SENTENCEPIECE_TYPE.containing_type = _MODELPROTO_SENTENCEPIECE +_MODELPROTO.fields_by_name['pieces'].message_type = _MODELPROTO_SENTENCEPIECE +_MODELPROTO.fields_by_name['trainer_spec'].message_type = _TRAINERSPEC +_MODELPROTO.fields_by_name['normalizer_spec'].message_type = _NORMALIZERSPEC +_MODELPROTO.fields_by_name['self_test_data'].message_type = 
_SELFTESTDATA +_MODELPROTO.fields_by_name['denormalizer_spec'].message_type = _NORMALIZERSPEC +DESCRIPTOR.message_types_by_name['TrainerSpec'] = _TRAINERSPEC +DESCRIPTOR.message_types_by_name['NormalizerSpec'] = _NORMALIZERSPEC +DESCRIPTOR.message_types_by_name['SelfTestData'] = _SELFTESTDATA +DESCRIPTOR.message_types_by_name['ModelProto'] = _MODELPROTO +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +TrainerSpec = _reflection.GeneratedProtocolMessageType('TrainerSpec', (_message.Message,), { + 'DESCRIPTOR' : _TRAINERSPEC, + '__module__' : 'sentencepiece_model_pb2' + # @@protoc_insertion_point(class_scope:sentencepiece.TrainerSpec) + }) +_sym_db.RegisterMessage(TrainerSpec) + +NormalizerSpec = _reflection.GeneratedProtocolMessageType('NormalizerSpec', (_message.Message,), { + 'DESCRIPTOR' : _NORMALIZERSPEC, + '__module__' : 'sentencepiece_model_pb2' + # @@protoc_insertion_point(class_scope:sentencepiece.NormalizerSpec) + }) +_sym_db.RegisterMessage(NormalizerSpec) + +SelfTestData = _reflection.GeneratedProtocolMessageType('SelfTestData', (_message.Message,), { + + 'Sample' : _reflection.GeneratedProtocolMessageType('Sample', (_message.Message,), { + 'DESCRIPTOR' : _SELFTESTDATA_SAMPLE, + '__module__' : 'sentencepiece_model_pb2' + # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData.Sample) + }) + , + 'DESCRIPTOR' : _SELFTESTDATA, + '__module__' : 'sentencepiece_model_pb2' + # @@protoc_insertion_point(class_scope:sentencepiece.SelfTestData) + }) +_sym_db.RegisterMessage(SelfTestData) +_sym_db.RegisterMessage(SelfTestData.Sample) + +ModelProto = _reflection.GeneratedProtocolMessageType('ModelProto', (_message.Message,), { + + 'SentencePiece' : _reflection.GeneratedProtocolMessageType('SentencePiece', (_message.Message,), { + 'DESCRIPTOR' : _MODELPROTO_SENTENCEPIECE, + '__module__' : 'sentencepiece_model_pb2' + # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto.SentencePiece) + }) + , + 'DESCRIPTOR' : _MODELPROTO, + '__module__' : 
'sentencepiece_model_pb2' + # @@protoc_insertion_point(class_scope:sentencepiece.ModelProto) + }) +_sym_db.RegisterMessage(ModelProto) +_sym_db.RegisterMessage(ModelProto.SentencePiece) + + +DESCRIPTOR._options = None +_TRAINERSPEC.fields_by_name['mining_sentence_size']._options = None +_TRAINERSPEC.fields_by_name['training_sentence_size']._options = None +# @@protoc_insertion_point(module_scope) diff --git a/tools/nemo_forced_aligner/align.py b/tools/nemo_forced_aligner/align.py index d298e8072d58..3f795606b4b7 100644 --- a/tools/nemo_forced_aligner/align.py +++ b/tools/nemo_forced_aligner/align.py @@ -148,7 +148,7 @@ class AlignmentConfig: simulate_cache_aware_streaming: Optional[bool] = False # Output file configs - save_output_file_formats: List[str] = field(default_factory=lambda: ["ctm", "ass"]) + save_output_file_formats: List[str] = field(default_factory=lambda: ["ctm"]) ctm_file_config: CTMFileConfig = field(default_factory=lambda: CTMFileConfig()) ass_file_config: ASSFileConfig = field(default_factory=lambda: ASSFileConfig()) diff --git a/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb b/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb index 94e2caa17a58..d3db783559a3 100644 --- a/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb +++ b/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb @@ -1,22 +1,23 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "accelerator": "GPU" - }, "cells": [ { "cell_type": "code", + "execution_count": 1, "metadata": { "id": "EGV_ioUHqhun" }, + "outputs": [ + { + "data": { + "text/plain": [ + "'\\nRemember to restart the runtime for the kernel to pick up any upgraded packages (e.g. 
matplotlib)!\\nAlternatively, you can uncomment the exit() below to crash and restart the kernel, in the case\\nthat you want to use the \"Run All Cells\" (or similar) option.\\n'" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "\"\"\"\n", "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", @@ -33,14 +34,14 @@ "\"\"\"\n", "\n", "# Install dependencies\n", - "!pip install wget\n", - "!apt-get install sox libsndfile1 ffmpeg libsox-fmt-mp3\n", - "!pip install text-unidecode\n", - "!pip install matplotlib>=3.3.2\n", + "# !pip install wget\n", + "# !apt-get install sox libsndfile1 ffmpeg libsox-fmt-mp3\n", + "# !pip install text-unidecode\n", + "# !pip install matplotlib>=3.3.2\n", "\n", "## Install NeMo\n", "BRANCH = 'main'\n", - "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", + "# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", "\n", "\"\"\"\n", "Remember to restart the runtime for the kernel to pick up any upgraded packages (e.g. 
matplotlib)!\n", @@ -48,9 +49,7 @@ "that you want to use the \"Run All Cells\" (or similar) option.\n", "\"\"\"\n", "# exit()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -75,9 +74,11 @@ }, { "cell_type": "code", + "execution_count": 2, "metadata": { "id": "1cjMaek4rY8-" }, + "outputs": [], "source": [ "import os\n", "import glob\n", @@ -86,15 +87,15 @@ "import wget\n", "import copy\n", "from omegaconf import OmegaConf, open_dict" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 3, "metadata": { "id": "8wqTRjpNruZD" }, + "outputs": [], "source": [ "data_dir = 'datasets/'\n", "\n", @@ -103,23 +104,21 @@ "\n", "if not os.path.exists(\"scripts\"):\n", " os.makedirs(\"scripts\")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 4, "metadata": { "id": "TSTb6b5DriWG" }, + "outputs": [], "source": [ "import nemo\n", "import nemo.collections.asr as nemo_asr\n", "from nemo.collections.asr.metrics.wer import word_error_rate\n", "from nemo.utils import logging, exp_manager" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -138,6 +137,9 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "k40Q84TNnU6O" + }, "source": [ "## Hugging Face\n", "\n", @@ -156,59 +158,71 @@ "Code steps:\n", "- Now below, run `login()`\n", "- Paste your preserved HF TOKEN API KEY to the text box.\"" - ], - "metadata": { - "id": "k40Q84TNnU6O" - } + ] }, { "cell_type": "code", + "execution_count": 5, "metadata": { "id": "27h1i8qa7WFE" }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f4940f6a57be4278b1fc21d1266f7289", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "import matplotlib.pyplot as plt\n", "\n", @@ -607,9 +845,7 @@ "plt.xlabel(\"# of occurrences\")\n", "plt.ylabel(\"# of tokens\")\n", "plt.xlim(0, MAX_COUNT);" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -622,9 +858,19 @@ }, { "cell_type": "code", + "execution_count": 24, "metadata": { "id": "9G6laS0ojV-B" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of tokens with <= 5 occurrences : 1041\n" + ] + } + ], "source": [ "UNCOMMON_TOKENS_COUNT = 5\n", "\n", @@ -635,9 +881,7 @@ " chars_with_infrequent_occurrence.update(set(token_list))\n", "\n", "print(f\"Number of tokens with <= {UNCOMMON_TOKENS_COUNT} occurrences : {len(chars_with_infrequent_occurrence)}\")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -652,9 +896,20 @@ }, { "cell_type": "code", + "execution_count": 25, "metadata": { "id": "jnh_pnL2jWAY" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original train+dev+test vocab size : 1433\n", + "New train vocab size : 1255\n" + ] + } + ], "source": [ "all_tokens = set.union(train_dev_set, test_set)\n", "print(f\"Original train+dev+test vocab size : {len(all_tokens)}\")\n", @@ -662,9 +917,7 @@ "extra_kanji = set(test_oov)\n", "train_token_set = all_tokens - extra_kanji\n", "print(f\"New train vocab size : {len(train_token_set)}\")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -692,38 +945,48 @@ }, { "cell_type": "code", + "execution_count": 26, "metadata": { - "id": "kaX9WzK15Q6t", - "cellView": "form" + "cellView": "form", + "id": "kaX9WzK15Q6t" }, + "outputs": [], "source": [ "#@title Dakuten normalization\n", "perform_dakuten_normalization = True #@param [\"True\", \"False\"] {type:\"raw\"}\n", "PERFORM_DAKUTEN_NORMALIZATION = bool(perform_dakuten_normalization)" - ], - 
"execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 27, "metadata": { "id": "HiEZVEshOp-y" }, + "outputs": [], "source": [ "import unicodedata\n", "def process_dakuten(text):\n", " normalized_text = unicodedata.normalize('NFD', text)\n", " normalized_text = normalized_text.replace(\"\\u3099\", \"\").replace(\"\\u309A\", \"\")\n", " return normalized_text" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 28, "metadata": { "id": "pV4kOgpvjWGg" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "After dakuten normalization, number of train tokens : 1210\n" + ] + } + ], "source": [ "if PERFORM_DAKUTEN_NORMALIZATION:\n", " normalized_train_token_set = set()\n", @@ -733,11 +996,8 @@ "\n", " print(f\"After dakuten normalization, number of train tokens : {len(normalized_train_token_set)}\")\n", "else:\n", - " normalized_train_token_set = train_token_set\n", - "" - ], - "execution_count": null, - "outputs": [] + " normalized_train_token_set = train_token_set\n" + ] }, { "cell_type": "markdown", @@ -754,9 +1014,11 @@ }, { "cell_type": "code", + "execution_count": 29, "metadata": { "id": "NN3asqvsrp_S" }, + "outputs": [], "source": [ "# Preprocessing steps\n", "import re\n", @@ -780,9 +1042,7 @@ " text = data['text']\n", " data['text'] = process_dakuten(text)\n", " return data" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -797,9 +1057,11 @@ }, { "cell_type": "code", + "execution_count": 30, "metadata": { "id": "mwNtHeHLjqJl" }, + "outputs": [], "source": [ "# Processing pipeline\n", "def apply_preprocessors(manifest, preprocessors):\n", @@ -809,15 +1071,15 @@ "\n", " print(\"Finished processing manifest !\")\n", " return manifest" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 31, "metadata": { "id": "xB06YHmDr-Ja" }, + "outputs": [], "source": [ "# List of 
pre-processing functions\n", "PREPROCESSORS = [\n", @@ -825,15 +1087,166 @@ " remove_extra_kanji,\n", " remove_dakuten,\n", "]" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 32, "metadata": { "id": "4lqUvpkrr7bQ" }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "85d3ee86b95547fc9781be780c913166", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Applying remove_special_characters: 0%| | 0/722 [00:00>\n", - "VOCAB_SIZE = len(train_dev_set) + 2" - ], - "execution_count": null, - "outputs": [] + "# VOCAB_SIZE = len(train_dev_set) + 2\n", + "VOCAB_SIZE = 128" + ] }, { "cell_type": "markdown", @@ -1489,34 +2045,467 @@ }, { "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "train_manifest_cleaned = \"/workspace/dataset/train_clean/manifest.json\"\n", + "dev_manifest_cleaned = \"/workspace/dataset/validation/manifest.json\"" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer_dir = \"/workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer\"" + ] + }, + { + "cell_type": "code", + "execution_count": 51, "metadata": { "id": "yT-SBPN2Ox6Y" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:root:Finished extracting manifest : /workspace/dataset/train_clean/manifest.json\n", + "INFO:root:Finished extracting manifest : /workspace/dataset/validation/manifest.json\n", + "INFO:root:Finished extracting all manifests ! 
Number of sentences : 236954\n", + "[NeMo I 2024-07-18 18:51:58 sentencepiece_tokenizer:316] Processing /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/text_corpus/document.txt and store at /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128\n", + "sentencepiece_trainer.cc(178) LOG(INFO) Running command: --input=/workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/text_corpus/document.txt --model_prefix=/workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128/tokenizer --vocab_size=128 --shuffle_input_sentence=true --hard_vocab_limit=false --model_type=bpe --character_coverage=1.0 --bos_id=-1 --eos_id=-1\n", + "sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : \n", + "trainer_spec {\n", + " input: /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/text_corpus/document.txt\n", + " input_format: \n", + " model_prefix: /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128/tokenizer\n", + " model_type: BPE\n", + " vocab_size: 128\n", + " self_test_sample_size: 0\n", + " character_coverage: 1\n", + " input_sentence_size: 0\n", + " shuffle_input_sentence: 1\n", + " seed_sentencepiece_size: 1000000\n", + " shrinking_factor: 0.75\n", + " max_sentence_length: 4192\n", + " num_threads: 16\n", + " num_sub_iterations: 2\n", + " max_sentencepiece_length: 16\n", + " split_by_unicode_script: 1\n", + " split_by_number: 1\n", + " split_by_whitespace: 1\n", + " split_digits: 0\n", + " pretokenization_delimiter: \n", + " treat_whitespace_as_suffix: 0\n", + " allow_whitespace_only_pieces: 0\n", + " required_chars: \n", + " byte_fallback: 0\n", + " vocabulary_output_piece_score: 1\n", + " train_extremely_large_corpus: 0\n", + " seed_sentencepieces_file: \n", + " hard_vocab_limit: 0\n", + " use_all_vocab: 0\n", + " unk_id: 0\n", + " bos_id: -1\n", + " eos_id: -1\n", + " pad_id: -1\n", + " unk_piece: \n", 
+ " bos_piece: \n", + " eos_piece: \n", + " pad_piece: \n", + " unk_surface: ⁇ \n", + " enable_differential_privacy: 0\n", + " differential_privacy_noise_level: 0\n", + " differential_privacy_clipping_threshold: 0\n", + "}\n", + "normalizer_spec {\n", + " name: nmt_nfkc\n", + " add_dummy_prefix: 1\n", + " remove_extra_whitespaces: 1\n", + " escape_whitespaces: 1\n", + " normalization_rule_tsv: \n", + "}\n", + "denormalizer_spec {}\n", + "trainer_interface.cc(353) LOG(INFO) SentenceIterator is not specified. Using MultiFileSentenceIterator.\n", + "trainer_interface.cc(185) LOG(INFO) Loading corpus: /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/text_corpus/document.txt\n", + "trainer_interface.cc(409) LOG(INFO) Loaded all 236954 sentences\n", + "trainer_interface.cc(425) LOG(INFO) Adding meta_piece: \n", + "trainer_interface.cc(430) LOG(INFO) Normalizing sentences...\n", + "trainer_interface.cc(539) LOG(INFO) all chars count=46492234\n", + "trainer_interface.cc(560) LOG(INFO) Alphabet size=38\n", + "trainer_interface.cc(561) LOG(INFO) Final character coverage=1\n", + "trainer_interface.cc(592) LOG(INFO) Done! preprocessed 236954 sentences.\n", + "trainer_interface.cc(598) LOG(INFO) Tokenizing input sentences with whitespace: 236954\n", + "trainer_interface.cc(609) LOG(INFO) Done! 48952\n", + "bpe_model_trainer.cc(159) LOG(INFO) Updating active symbols. 
max_freq=1625104 min_freq=1\n", + "bpe_model_trainer.cc(268) LOG(INFO) Added: freq=314950 size=20 all=1288 active=1250 piece=▁p\n", + "bpe_model_trainer.cc(268) LOG(INFO) Added: freq=185513 size=40 all=2095 active=2057 piece=▁you\n", + "bpe_model_trainer.cc(268) LOG(INFO) Added: freq=121485 size=60 all=3198 active=3160 piece=▁ha\n", + "bpe_model_trainer.cc(268) LOG(INFO) Added: freq=72557 size=80 all=4263 active=4225 piece=se\n", + "trainer_interface.cc(687) LOG(INFO) Saving model: /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128/tokenizer.model\n", + "trainer_interface.cc(699) LOG(INFO) Saving vocabs: /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128/tokenizer.vocab\n", + "Serialized tokenizer at location : /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128\n", + "INFO:root:Done!\n" + ] + } + ], "source": [ "!python scripts/process_asr_text_tokenizer.py \\\n", " --manifest=$train_manifest_cleaned,$dev_manifest_cleaned \\\n", - " --vocab_size=$VOCAB_SIZE \\\n", + " --vocab_size=128 \\\n", " --data_root=$tokenizer_dir \\\n", " --tokenizer=\"spe\" \\\n", " --spe_type=$TOKENIZER_TYPE \\\n", " --spe_character_coverage=1.0 \\\n", " --no_lower_case \\\n", " --log" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 52, "metadata": { "id": "G5TxLHtKPW4E" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenizer directory : /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v228/\n" + ] + } + ], "source": [ "TOKENIZER_DIR = f\"{tokenizer_dir}/tokenizer_spe_{TOKENIZER_TYPE}_v{VOCAB_SIZE}/\"\n", "print(\"Tokenizer directory :\", TOKENIZER_DIR)" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokens : '' '' '' '' '' '' '' 
'' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '' '+'\n", + " \"\" \"\" \n" + ] + } ], - "execution_count": null, - "outputs": [] + "source": [ + "# tokens is a list of , , ... tokens\n", + "tokens = [f\"\" for i in range(1, 228 + 1)] # Just added 228 tokens, you can add more\n", + "tokens_string = \"\"\"' '\"\"\".join(tokens)\n", + "tokens_string = f\"'{tokens_string}'\"\n", + "print(f\"Tokens : {tokens_string}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO: Created token '' at ID 128\n", + "INFO: Created token '' at ID 129\n", + "INFO: Created token '' at ID 130\n", + "INFO: Created token '' at ID 131\n", + "INFO: Created token '' at ID 132\n", + "INFO: Created token '' at ID 133\n", + "INFO: Created token '' at ID 134\n", + "INFO: Created token '' at ID 135\n", + "INFO: Created token '' at ID 136\n", + "INFO: Created token '' at ID 137\n", + "INFO: Created token '' at ID 138\n", + "INFO: Created token '' at ID 139\n", + "INFO: Created token '' at ID 140\n", + "INFO: Created token '' at ID 141\n", + "INFO: Created token '' at ID 142\n", + "INFO: Created token '' at ID 143\n", + "INFO: Created token '' at ID 144\n", + "INFO: Created token '' at ID 145\n", + "INFO: Created token '' at ID 146\n", + "INFO: Created token '' at ID 147\n", + "INFO: 
Created token '' at ID 148\n", + "INFO: Created token '' at ID 149\n", + "INFO: Created token '' at ID 150\n", + "INFO: Created token '' at ID 151\n", + "INFO: Created token '' at ID 152\n", + "INFO: Created token '' at ID 153\n", + "INFO: Created token '' at ID 154\n", + "INFO: Created token '' at ID 155\n", + "INFO: Created token '' at ID 156\n", + "INFO: Created token '' at ID 157\n", + "INFO: Created token '' at ID 158\n", + "INFO: Created token '' at ID 159\n", + "INFO: Created token '' at ID 160\n", + "INFO: Created token '' at ID 161\n", + "INFO: Created token '' at ID 162\n", + "INFO: Created token '' at ID 163\n", + "INFO: Created token '' at ID 164\n", + "INFO: Created token '' at ID 165\n", + "INFO: Created token '' at ID 166\n", + "INFO: Created token '' at ID 167\n", + "INFO: Created token '' at ID 168\n", + "INFO: Created token '' at ID 169\n", + "INFO: Created token '' at ID 170\n", + "INFO: Created token '' at ID 171\n", + "INFO: Created token '' at ID 172\n", + "INFO: Created token '' at ID 173\n", + "INFO: Created token '' at ID 174\n", + "INFO: Created token '' at ID 175\n", + "INFO: Created token '' at ID 176\n", + "INFO: Created token '' at ID 177\n", + "INFO: Created token '' at ID 178\n", + "INFO: Created token '' at ID 179\n", + "INFO: Created token '' at ID 180\n", + "INFO: Created token '' at ID 181\n", + "INFO: Created token '' at ID 182\n", + "INFO: Created token '' at ID 183\n", + "INFO: Created token '' at ID 184\n", + "INFO: Created token '' at ID 185\n", + "INFO: Created token '' at ID 186\n", + "INFO: Created token '' at ID 187\n", + "INFO: Created token '' at ID 188\n", + "INFO: Created token '' at ID 189\n", + "INFO: Created token '' at ID 190\n", + "INFO: Created token '' at ID 191\n", + "INFO: Created token '' at ID 192\n", + "INFO: Created token '' at ID 193\n", + "INFO: Created token '' at ID 194\n", + "INFO: Created token '' at ID 195\n", + "INFO: Created token '' at ID 196\n", + "INFO: Created token '' at ID 197\n", + "INFO: 
Created token '' at ID 198\n", + "INFO: Created token '' at ID 199\n", + "INFO: Created token '' at ID 200\n", + "INFO: Created token '' at ID 201\n", + "INFO: Created token '' at ID 202\n", + "INFO: Created token '' at ID 203\n", + "INFO: Created token '' at ID 204\n", + "INFO: Created token '' at ID 205\n", + "INFO: Created token '' at ID 206\n", + "INFO: Created token '' at ID 207\n", + "INFO: Created token '' at ID 208\n", + "INFO: Created token '' at ID 209\n", + "INFO: Created token '' at ID 210\n", + "INFO: Created token '' at ID 211\n", + "INFO: Created token '' at ID 212\n", + "INFO: Created token '' at ID 213\n", + "INFO: Created token '' at ID 214\n", + "INFO: Created token '' at ID 215\n", + "INFO: Created token '' at ID 216\n", + "INFO: Created token '' at ID 217\n", + "INFO: Created token '' at ID 218\n", + "INFO: Created token '' at ID 219\n", + "INFO: Created token '' at ID 220\n", + "INFO: Created token '' at ID 221\n", + "INFO: Created token '' at ID 222\n", + "INFO: Created token '' at ID 223\n", + "INFO: Created token '' at ID 224\n", + "INFO: Created token '' at ID 225\n", + "INFO: Created token '' at ID 226\n", + "INFO: Created token '' at ID 227\n", + "INFO: Created token '' at ID 228\n", + "INFO: Created token '' at ID 229\n", + "INFO: Created token '' at ID 230\n", + "INFO: Created token '' at ID 231\n", + "INFO: Created token '' at ID 232\n", + "INFO: Created token '' at ID 233\n", + "INFO: Created token '' at ID 234\n", + "INFO: Created token '' at ID 235\n", + "INFO: Created token '' at ID 236\n", + "INFO: Created token '' at ID 237\n", + "INFO: Created token '' at ID 238\n", + "INFO: Created token '' at ID 239\n", + "INFO: Created token '' at ID 240\n", + "INFO: Created token '' at ID 241\n", + "INFO: Created token '' at ID 242\n", + "INFO: Created token '' at ID 243\n", + "INFO: Created token '' at ID 244\n", + "INFO: Created token '' at ID 245\n", + "INFO: Created token '' at ID 246\n", + "INFO: Created token '' at ID 247\n", + "INFO: 
Created token '' at ID 248\n", + "INFO: Created token '' at ID 249\n", + "INFO: Created token '' at ID 250\n", + "INFO: Created token '' at ID 251\n", + "INFO: Created token '' at ID 252\n", + "INFO: Created token '' at ID 253\n", + "INFO: Created token '' at ID 254\n", + "INFO: Created token '' at ID 255\n", + "INFO: Created token '' at ID 256\n", + "INFO: Created token '' at ID 257\n", + "INFO: Created token '' at ID 258\n", + "INFO: Created token '' at ID 259\n", + "INFO: Created token '' at ID 260\n", + "INFO: Created token '' at ID 261\n", + "INFO: Created token '' at ID 262\n", + "INFO: Created token '' at ID 263\n", + "INFO: Created token '' at ID 264\n", + "INFO: Created token '' at ID 265\n", + "INFO: Created token '' at ID 266\n", + "INFO: Created token '' at ID 267\n", + "INFO: Created token '' at ID 268\n", + "INFO: Created token '' at ID 269\n", + "INFO: Created token '' at ID 270\n", + "INFO: Created token '' at ID 271\n", + "INFO: Created token '' at ID 272\n", + "INFO: Created token '' at ID 273\n", + "INFO: Created token '' at ID 274\n", + "INFO: Created token '' at ID 275\n", + "INFO: Created token '' at ID 276\n", + "INFO: Created token '' at ID 277\n", + "INFO: Created token '' at ID 278\n", + "INFO: Created token '' at ID 279\n", + "INFO: Created token '' at ID 280\n", + "INFO: Created token '' at ID 281\n", + "INFO: Created token '' at ID 282\n", + "INFO: Created token '' at ID 283\n", + "INFO: Created token '' at ID 284\n", + "INFO: Created token '' at ID 285\n", + "INFO: Created token '' at ID 286\n", + "INFO: Created token '' at ID 287\n", + "INFO: Created token '' at ID 288\n", + "INFO: Created token '' at ID 289\n", + "INFO: Created token '' at ID 290\n", + "INFO: Created token '' at ID 291\n", + "INFO: Created token '' at ID 292\n", + "INFO: Created token '' at ID 293\n", + "INFO: Created token '' at ID 294\n", + "INFO: Created token '' at ID 295\n", + "INFO: Created token '' at ID 296\n", + "INFO: Created token '' at ID 297\n", + "INFO: 
Created token '' at ID 298\n", + "INFO: Created token '' at ID 299\n", + "INFO: Created token '' at ID 300\n", + "INFO: Created token '' at ID 301\n", + "INFO: Created token '' at ID 302\n", + "INFO: Created token '' at ID 303\n", + "INFO: Created token '' at ID 304\n", + "INFO: Created token '' at ID 305\n", + "INFO: Created token '' at ID 306\n", + "INFO: Created token '' at ID 307\n", + "INFO: Created token '' at ID 308\n", + "INFO: Created token '' at ID 309\n", + "INFO: Created token '' at ID 310\n", + "INFO: Created token '' at ID 311\n", + "INFO: Created token '' at ID 312\n", + "INFO: Created token '' at ID 313\n", + "INFO: Created token '' at ID 314\n", + "INFO: Created token '' at ID 315\n", + "INFO: Created token '' at ID 316\n", + "INFO: Created token '' at ID 317\n", + "INFO: Created token '' at ID 318\n", + "INFO: Created token '' at ID 319\n", + "INFO: Created token '' at ID 320\n", + "INFO: Created token '' at ID 321\n", + "INFO: Created token '' at ID 322\n", + "INFO: Created token '' at ID 323\n", + "INFO: Created token '' at ID 324\n", + "INFO: Created token '' at ID 325\n", + "INFO: Created token '' at ID 326\n", + "INFO: Created token '' at ID 327\n", + "INFO: Created token '' at ID 328\n", + "INFO: Created token '' at ID 329\n", + "INFO: Created token '' at ID 330\n", + "INFO: Created token '' at ID 331\n", + "INFO: Created token '' at ID 332\n", + "INFO: Created token '' at ID 333\n", + "INFO: Created token '' at ID 334\n", + "INFO: Created token '' at ID 335\n", + "INFO: Created token '' at ID 336\n", + "INFO: Created token '' at ID 337\n", + "INFO: Created token '' at ID 338\n", + "INFO: Created token '' at ID 339\n", + "INFO: Created token '' at ID 340\n", + "INFO: Created token '' at ID 341\n", + "INFO: Created token '' at ID 342\n", + "INFO: Created token '' at ID 343\n", + "INFO: Created token '' at ID 344\n", + "INFO: Created token '' at ID 345\n", + "INFO: Created token '' at ID 346\n", + "INFO: Created token '' at ID 347\n", + "INFO: 
Created token '' at ID 348\n", + "INFO: Created token '' at ID 349\n", + "INFO: Created token '' at ID 350\n", + "INFO: Created token '' at ID 351\n", + "INFO: Created token '' at ID 352\n", + "INFO: Created token '' at ID 353\n", + "INFO: Created token '' at ID 354\n", + "INFO: Created token '+' at ID 355\n", + "INFO: New tokenizer vocab size: 356\n", + "INFO: Created new tokenizer at: /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128/tokenizer_new.model\n" + ] + } + ], + "source": [ + "# ! protoc --python_out=/workspace/nemo/NeMo-opensource/scripts/tokenizers/ sentencepiece_model.proto\n", + "\n", + "!python /workspace/nemo/NeMo-opensource/scripts/tokenizers/add_special_tokens_to_sentencepiece.py \\\n", + "--input_file /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128/tokenizer.model \\\n", + "--output_file /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128/tokenizer_new.model\\\n", + "--is_userdefined \\\n", + "--tokens $tokens_string" + ] + }, + { + "cell_type": "code", + "execution_count": 81, + "metadata": {}, + "outputs": [], + "source": [ + "import sentencepiece as spm\n", + "\n", + "sp = spm.SentencePieceProcessor()\n", + "sp.load('/workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128/tokenizer_new.model')\n", + "\n", + "vocab_list = [sp.id_to_piece(i) for i in range(sp.get_piece_size())]\n", + "# Save the vocabulary to a file\n", + "with open('new_tokenizer.vocab', 'w') as vocab_file:\n", + " for token in vocab_list:\n", + " vocab_file.write(token + '\\n')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [], + "source": [ + "# Updating vocab.txt\n", + "vocab_file = '/workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128/vocab.txt'\n", + "new_vocab_file = 
'/workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128/new_vocab.txt'\n", + "# read the existing vocab file and add , , ... tokens\n", + "with open(vocab_file, 'r') as f:\n", + " lines = f.readlines()\n", + "\n", + "# Add the new tokens to the vocab file\n", + "with open(new_vocab_file, 'w') as f:\n", + " for line in lines:\n", + " f.write(line)\n", + " for token in tokens:\n", + " f.write(token + '\\n')" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "!cp new_tokenizer.vocab /workspace/nemo/NeMo-opensource/tutorials/asr/tokenizers/av_tokenizer/tokenizer_spe_bpe_v128/" + ] }, { "cell_type": "markdown", @@ -1531,9 +2520,19 @@ }, { "cell_type": "code", + "execution_count": 61, "metadata": { "id": "8sAz2_RyMu7J" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of tokens : 128\n" + ] + } + ], "source": [ "# Number of tokens in tokenizer -\n", "with open(os.path.join(TOKENIZER_DIR, 'tokenizer.vocab')) as f:\n", @@ -1541,15 +2540,23 @@ "\n", "num_tokens = len(tokens)\n", "print(\"Number of tokens : \", num_tokens)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 62, "metadata": { "id": "zktPYPCxNXNO" }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The text in this dataset is too small to construct a tokenizer with vocab size = 228. Current number of tokens = 128. Please reconstruct the tokenizer with fewer tokens\n" + ] + } + ], "source": [ "if num_tokens < VOCAB_SIZE:\n", " print(\n", @@ -1557,9 +2564,7 @@ " f\"with vocab size = {VOCAB_SIZE}. Current number of tokens = {num_tokens}. 
\"\n", " f\"Please reconstruct the tokenizer with fewer tokens\"\n", " )" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -1574,14 +2579,14 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "mmSj18iQQTZx" }, + "outputs": [], "source": [ "model = nemo_asr.models.ASRModel.from_pretrained(\"stt_en_citrinet_512\", map_location='cpu')" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -1601,15 +2606,15 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "FmFQKwGkoaIx" }, + "outputs": [], "source": [ "# Preserve the decoder parameters in case weight matching can be done later\n", "pretrained_decoder = model.decoder.state_dict()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -1624,14 +2629,14 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "-8SKfYSVorgg" }, + "outputs": [], "source": [ "model.change_vocabulary(new_tokenizer_dir=TOKENIZER_DIR, new_tokenizer_type=\"bpe\")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -1648,9 +2653,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "367FBtRDorkT" }, + "outputs": [], "source": [ "# Insert preserved model weights if shapes match\n", "if model.decoder.decoder_layers[0].weight.shape == pretrained_decoder['decoder_layers.0.weight'].shape:\n", @@ -1658,9 +2665,7 @@ " logging.info(\"Decoder shapes matched - restored weights from pre-trained model\")\n", "else:\n", " logging.info(\"\\nDecoder shapes did not match - could not restore decoder weights from pre-trained model.\")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -1675,22 +2680,24 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "lfDW0gQVpm4d" }, + "outputs": [], "source": [ "#@title Freeze Encoder { display-mode: \"form\" }\n", "freeze_encoder = True 
#@param [\"False\", \"True\"] {type:\"raw\"}\n", "freeze_encoder = bool(freeze_encoder)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "oLkm96zkplrX" }, + "outputs": [], "source": [ "if freeze_encoder:\n", " model.encoder.freeze()\n", @@ -1699,9 +2706,7 @@ "else:\n", " model.encoder.unfreeze()\n", " logging.info(\"Model encoder has been un-frozen\")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -1718,14 +2723,14 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "pBYAd_2-R2r3" }, + "outputs": [], "source": [ "cfg = copy.deepcopy(model.cfg)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -1740,9 +2745,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "NfbtgTC-RyzF" }, + "outputs": [], "source": [ "# Setup new tokenizer\n", "cfg.tokenizer.dir = TOKENIZER_DIR\n", @@ -1750,9 +2757,7 @@ "\n", "# Set tokenizer config\n", "model.cfg.tokenizer = cfg.tokenizer" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -1769,21 +2774,23 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "wnw-ygClmg7t" }, + "outputs": [], "source": [ "# Setup train/val/test configs\n", "print(OmegaConf.to_yaml(cfg.train_ds))" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "OlOowK7rRAvs" }, + "outputs": [], "source": [ "# Setup train, validation, test configs\n", "with open_dict(cfg):\n", @@ -1810,23 +2817,21 @@ " cfg.test_ds.pin_memory = True\n", " cfg.test_ds.use_start_end_token = True\n", " cfg.test_ds.trim_silence = True" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "y98ZAhBtRtoD" }, + "outputs": [], "source": [ "# setup model with new configs\n", 
"model.setup_training_data(cfg.train_ds)\n", "model.setup_multiple_validation_data(cfg.validation_ds)\n", "model.setup_multiple_test_data(cfg.test_ds)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -1850,9 +2855,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "ozJDj6BktKw-" }, + "outputs": [], "source": [ "def analyse_ctc_failures_in_model(model):\n", " count_ctc_failures = 0\n", @@ -1889,52 +2896,52 @@ " model = model.train()\n", "\n", " return count_ctc_failures, am_seq_lengths, target_seq_lengths" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "hJGUcq2BtKzw" }, + "outputs": [], "source": [ "results = analyse_ctc_failures_in_model(model)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "crEWxvI2tK2S" }, + "outputs": [], "source": [ "num_ctc_failures, am_seq_lengths, target_seq_lengths = results" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "L8M0-mSI1Jp5" }, + "outputs": [], "source": [ "if num_ctc_failures > 0:\n", " logging.warning(f\"\\nCTC loss will fail for {num_ctc_failures} samples ({num_ctc_failures * 100./ float(len(am_seq_lengths))} % of samples)!\\n\"\n", " f\"Increase the vocabulary size of the tokenizer so that this number becomes close to zero !\")\n", "else:\n", " logging.info(\"No CTC failure cases !\")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "00wKre0W1Jsx" }, + "outputs": [], "source": [ "# Compute average ratio of T / U\n", "avg_T = sum(am_seq_lengths) / float(len(am_seq_lengths))\n", @@ -1949,9 +2956,7 @@ "print(f\"Average Target sequence length = {avg_U}\")\n", "print()\n", "print(f\"Ratio of Average AM sequence length to target sequence length = 
{avg_length_ratio}\")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -1966,14 +2971,14 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "sS-xoplxSTJv" }, + "outputs": [], "source": [ "print(OmegaConf.to_yaml(cfg.optim))" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -1988,9 +2993,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "Io55nnbdXoeG" }, + "outputs": [], "source": [ "with open_dict(model.cfg.optim):\n", " model.cfg.optim.lr = 0.025\n", @@ -1998,9 +3005,7 @@ " model.cfg.optim.sched.warmup_steps = None # Remove default number of steps of warmup\n", " model.cfg.optim.sched.warmup_ratio = 0.10 # 10 % warmup\n", " model.cfg.optim.sched.min_lr = 1e-9" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -2015,9 +3020,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "6Vb35_oRh_sV" }, + "outputs": [], "source": [ "with open_dict(model.cfg.spec_augment):\n", " model.cfg.spec_augment.freq_masks = 2\n", @@ -2026,9 +3033,7 @@ " model.cfg.spec_augment.time_width = 0.05\n", "\n", "model.spec_augmentation = model.from_config_dict(model.cfg.spec_augment)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -2043,30 +3048,30 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "cellView": "form", "id": "UfUlPXZS6vlV" }, + "outputs": [], "source": [ "#@title Metric\n", "use_cer = True #@param [\"False\", \"True\"] {type:\"raw\"}\n", "log_prediction = True #@param [\"False\", \"True\"] {type:\"raw\"}\n", "\n" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "6qpbMNZh68p9" }, + "outputs": [], "source": [ "model.wer.use_cer = use_cer\n", "model.wer.log_prediction = log_prediction" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", 
@@ -2083,9 +3088,11 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "bonpx5sRS07M" }, + "outputs": [], "source": [ "import torch\n", "import pytorch_lightning as ptl\n", @@ -2111,15 +3118,15 @@ "\n", "# finally, update the model's internal config\n", "model.cfg = model._cfg" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "SR4CiViFS8Ww" }, + "outputs": [], "source": [ "from nemo.utils import exp_manager\n", "\n", @@ -2141,15 +3148,15 @@ "config = OmegaConf.structured(config)\n", "\n", "logdir = exp_manager.exp_manager(trainer, config)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "OlvyYwYWTsl6" }, + "outputs": [], "source": [ "try:\n", " from google import colab\n", @@ -2163,21 +3170,19 @@ " %tensorboard --logdir /content/experiments/lang-$LANGUAGE/ASR-Model-Language-$LANGUAGE/\n", "else:\n", " print(\"To use tensorboard, please use this notebook in a Google Colab environment.\")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "6X21Q2qfVLvG" }, + "outputs": [], "source": [ "%%time\n", "trainer.fit(model)" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -2192,16 +3197,16 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "id": "DoWNVNYGOaMX" }, + "outputs": [], "source": [ "save_path = f\"Model-{LANGUAGE}.nemo\"\n", "model.save_to(f\"{save_path}\")\n", "print(f\"Model saved at path : {os.getcwd() + os.path.sep + save_path}\")" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -2216,5 +3221,29 @@ "While the focus was on a small dataset for Japanese, nearly all of this information can be used for larger datasets and other scenarios where compute is limited, or the model's size prevents fine-tuning the entire model." 
] } - ] -} \ No newline at end of file + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tutorials/asr/Streaming_ASR.ipynb b/tutorials/asr/Streaming_ASR.ipynb index a4701dc025d8..0dac0f23dc31 100644 --- a/tutorials/asr/Streaming_ASR.ipynb +++ b/tutorials/asr/Streaming_ASR.ipynb @@ -1,48 +1,5 @@ { "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "lJz6FDU1lRzc" - }, - "outputs": [], - "source": [ - "\"\"\"\n", - "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n", - "\n", - "Instructions for setting up Colab are as follows:\n", - "1. Open a new Python 3 notebook.\n", - "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", - "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", - "4. Run this cell to set up dependencies.\n", - "5. 
Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n", - "\n\nNOTE: User is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use.\n", - "\"\"\"\n", - "# If you're using Google Colab and not running locally, run this cell.\n", - "\n", - "## Install dependencies\n", - "!pip install wget\n", - "!apt-get install sox libsndfile1 ffmpeg\n", - "!pip install text-unidecode\n", - "!pip install matplotlib>=3.3.2\n", - "\n", - "## Install NeMo\n", - "BRANCH = 'main'\n", - "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", - "\n", - "## Grab the config we'll use in this example\n", - "!mkdir configs\n", - "!wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/config.yaml\n", - "\n", - "\"\"\"\n", - "Remember to restart the runtime for the kernel to pick up any upgraded packages (e.g. matplotlib)!\n", - "Alternatively, you can uncomment the exit() below to crash and restart the kernel, in the case\n", - "that you want to use the \"Run All Cells\" (or similar) option.\n", - "\"\"\"\n", - "# exit()" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -83,33 +40,41 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# If something goes wrong during data processing, un-comment the following line to delete the cached dataset \n", "# !rm -rf datasets/mini-dev-clean\n", - "!mkdir -p datasets/mini-dev-clean" + "!mkdir -p /disk1/ksingla/datasets/mini-dev-clean" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "100%|███████████████████████████████████████████| 38/38 [00:03<00:00, 12.53it/s]\n" + ] + } + ], "source": [ "!python 
../../scripts/dataset_processing/get_librispeech_data.py \\\n", - " --data_root \"datasets/mini-dev-clean/\" \\\n", + " --data_root \"/disk1/ksingla/datasets/mini-dev-clean/\" \\\n", " --data_sets dev_clean_2" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ - "manifest = \"datasets/mini-dev-clean/dev_clean_2.json\"" + "manifest = \"/disk1/ksingla/datasets/mini-dev-clean/dev_clean_2.json\"" ] }, { @@ -121,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -144,13 +109,51 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers\n", + " built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)\n", + " configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzimg --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx 
--enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-pocketsphinx --enable-librsvg --enable-libmfx --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\n", + " libavutil 56. 70.100 / 56. 70.100\n", + " libavcodec 58.134.100 / 58.134.100\n", + " libavformat 58. 76.100 / 58. 76.100\n", + " libavdevice 58. 13.100 / 58. 13.100\n", + " libavfilter 7.110.100 / 7.110.100\n", + " libswscale 5. 9.100 / 5. 9.100\n", + " libswresample 3. 9.100 / 3. 9.100\n", + " libpostproc 55. 9.100 / 55. 9.100\n", + "\u001b[0;33mGuessed Channel Layout for Input Stream #0.0 : mono\n", + "\u001b[0mInput #0, concat, from 'concat_file.txt':\n", + " Duration: N/A, start: 0.000000, bitrate: 256 kb/s\n", + " Stream #0:0: Audio: pcm_s16le ([1][0][0][0] / 0x0001), 16000 Hz, mono, s16, 256 kb/s\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output #0, wav, to '/disk1/ksingla/datasets/mini-dev-clean/concatenated_audio.wav':\n", + " Metadata:\n", + " ISFT : Lavf58.76.100\n", + " Stream #0:0: Audio: pcm_s16le ([1][0][0][0] / 0x0001), 16000 Hz, mono, s16, 256 kb/s\n", + "Stream mapping:\n", + " Stream #0:0 -> #0:0 (copy)\n", + "Press [q] to stop, [?] 
for help\n", + "size= 28251kB time=00:15:03.96 bitrate= 256.0kbits/s speed= 898x \n", + "video:0kB audio:28251kB subtitle:0kB other streams:0kB global headers:0kB muxing overhead: 0.000270%\n", + "Finished concatenating audio file!\n" + ] + } + ], "source": [ "new_duration, ref_transcript = concat_audio(manifest, 15*60)\n", "\n", - "concat_audio_path = \"datasets/mini-dev-clean/concatenated_audio.wav\"\n", + "concat_audio_path = \"/disk1/ksingla/datasets/mini-dev-clean/concatenated_audio.wav\"\n", "\n", "!ffmpeg -t {new_duration} -safe 0 -f concat -i concat_file.txt -c copy -t {new_duration} {concat_audio_path} -y\n", "print(\"Finished concatenating audio file!\")" @@ -166,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -178,11 +181,22 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'cpu'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device = 'cpu'\n", "device" ] }, @@ -195,22 +209,95 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-07-05 23:24:06 cloud:58] Found existing object /root/.cache/torch/NeMo/NeMo_1.23.0/stt_en_conformer_ctc_small/5d2d8e5b2b5adb8f5091363c6ba19c55/stt_en_conformer_ctc_small.nemo.\n", + "[NeMo I 2024-07-05 23:24:06 cloud:64] Re-using file from: /root/.cache/torch/NeMo/NeMo_1.23.0/stt_en_conformer_ctc_small/5d2d8e5b2b5adb8f5091363c6ba19c55/stt_en_conformer_ctc_small.nemo\n", + "[NeMo I 2024-07-05 23:24:06 common:924] Instantiating model from pre-trained checkpoint\n", + "[NeMo I 2024-07-05 23:24:07 
mixins:172] Tokenizer SentencePieceTokenizer initialized with 1024 tokens\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-07-05 23:24:07 modelPT:165] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.\n", + " Train config : \n", + " manifest_filepath: /data/NeMo_ASR_SET/English/v2.0/train/tarred_audio_manifest.json\n", + " sample_rate: 16000\n", + " batch_size: 64\n", + " shuffle: true\n", + " num_workers: 8\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " trim_silence: false\n", + " max_duration: 20.0\n", + " min_duration: 0.1\n", + " shuffle_n: 2048\n", + " is_tarred: true\n", + " tarred_audio_filepaths: /data/NeMo_ASR_SET/English/v2.0/train/audio__OP_0..4095_CL_.tar\n", + " \n", + "[NeMo W 2024-07-05 23:24:07 modelPT:172] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
\n", + " Validation config : \n", + " manifest_filepath:\n", + " - /data/ASR/LibriSpeech/librispeech_withsp2/manifests/librivox-dev-other.json\n", + " - /data/ASR/LibriSpeech/librispeech_withsp2/manifests/librivox-dev-clean.json\n", + " - /data/ASR/LibriSpeech/librispeech_withsp2/manifests/librivox-test-other.json\n", + " - /data/ASR/LibriSpeech/librispeech_withsp2/manifests/librivox-test-clean.json\n", + " sample_rate: 16000\n", + " batch_size: 64\n", + " shuffle: false\n", + " num_workers: 8\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " is_tarred: false\n", + " tarred_audio_filepaths: na\n", + " \n", + "[NeMo W 2024-07-05 23:24:07 modelPT:178] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a valid configuration file to setup the test data loader(s).\n", + " Test config : \n", + " manifest_filepath:\n", + " - /data/ASR/LibriSpeech/librispeech_withsp2/manifests/librivox-test-other.json\n", + " - /data/ASR/LibriSpeech/librispeech_withsp2/manifests/librivox-dev-clean.json\n", + " - /data/ASR/LibriSpeech/librispeech_withsp2/manifests/librivox-dev-other.json\n", + " - /data/ASR/LibriSpeech/librispeech_withsp2/manifests/librivox-test-clean.json\n", + " sample_rate: 16000\n", + " batch_size: 64\n", + " shuffle: false\n", + " num_workers: 8\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " is_tarred: false\n", + " tarred_audio_filepaths: na\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-07-05 23:24:07 features:289] PADDING: 0\n", + "[NeMo I 2024-07-05 23:24:07 save_restore_connector:249] Model EncDecCTCModelBPE was successfully restored from /root/.cache/torch/NeMo/NeMo_1.23.0/stt_en_conformer_ctc_small/5d2d8e5b2b5adb8f5091363c6ba19c55/stt_en_conformer_ctc_small.nemo.\n" + ] + } + ], "source": [ "# Clear up memory\n", - "torch.cuda.empty_cache()\n", - "gc.collect()\n", - "model = 
nemo_asr.models.EncDecCTCModelBPE.from_pretrained(\"stt_en_conformer_ctc_large\", map_location=device)\n", - "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "#torch.cuda.empty_cache()\n", + "#gc.collect()\n", + "model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(\"stt_en_conformer_ctc_small\", map_location=device)\n", + "device = 'cpu'\n", "# device = 'cpu' # You can transcribe even longer samples on the CPU, though it will take much longer !\n", "model = model.to(device)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -233,12 +320,25 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "EncDecCTCModel.transcribe() got an unexpected keyword argument 'device'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[11], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m autocast():\n\u001b[0;32m----> 2\u001b[0m transcript \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtranscribe\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mconcat_audio_path\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mcpu\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n", + "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115\u001b[0m, in 
\u001b[0;36mcontext_decorator..decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 115\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mTypeError\u001b[0m: EncDecCTCModel.transcribe() got an unexpected keyword argument 'device'" + ] + } + ], "source": [ "with autocast():\n", - " transcript = model.transcribe([concat_audio_path], batch_size=1)[0]" + " transcript = model.transcribe([concat_audio_path], batch_size=1, device=\"cpu\")[0]" ] }, { diff --git a/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb b/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb index c9c547a8383e..629a45756d52 100644 --- a/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb +++ b/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb @@ -1,29 +1,8 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "provenance": [], - "collapsed_sections": [ - "dm-qqTdZDUlZ", - "GGKgsW5gvAuf", - "0CqpJGR6ecYW" - ], - "toc_visible": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "accelerator": "GPU" - }, "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "id": "pEYsuj0J9pId" }, @@ -38,28 +17,85 @@ "3. 
Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", "4. Run this cell to set up dependencies.\n", "5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n", - "\n\nNOTE: User is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use.\n", + "\n", + "\n", + "NOTE: User is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use.\n", "\"\"\"\n", "# If you're using Google Colab and not running locally, run this cell.\n", - "import os\n", "\n", "# Install dependencies\n", - "!pip install wget\n", - "!apt-get install sox libsndfile1 ffmpeg\n", - "!pip install text-unidecode\n", - "!pip install matplotlib>=3.3.2\n", + "# !pip install wget\n", + "# !apt-get install sox libsndfile1 ffmpeg\n", + "# !pip install text-unidecode\n", + "# !pip install matplotlib>=3.3.2\n", + "\n", "\n", "## Install NeMo\n", "BRANCH = 'main'\n", - "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", + "# !python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n", "\n", "## Grab the config we'll use in this example\n", "# !mkdir configs\n", "# !wget -P configs/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/conf/contextnet_rnnt/contextnet_rnnt.yaml" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/workspace/nemo/NeMo-opensource/nemo/collections/asr/__init__.py\n", + "/workspace/nemo/NeMo-opensource/nemo/core/__init__.py\n", + "/workspace/nemo/NeMo-opensource/nemo/__init__.py\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "\n", + "\n", + "# Insert local paths at the beginning of sys.path\n", + "sys.path.insert(0, 
os.path.abspath('/workspace/nemo/NeMo-opensource/'))\n", + "\n", + "import nemo.collections.asr as nemo_asr\n", + "print(nemo_asr.__file__)\n", + "import nemo.core as nemo_core\n", + "print(nemo_core.__file__)\n", + "from nemo.core import adapter_mixins\n", + "import nemo\n", + "print(nemo.__file__)\n", + "# Restore the site-packages paths\n", + "# sys.path.extend(site_packages_paths)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/lightning/__init__.py\n" + ] + } + ], + "source": [ + "import lightning\n", + "print(lightning.__file__)" + ] + }, { "cell_type": "markdown", + "metadata": { + "id": "cTV4WLrArmxS" + }, "source": [ "# ASR Domain Adaptation with Adapters\n", "\n", @@ -70,13 +106,13 @@ "-----\n", "\n", "In this tutorial, we will showcase **Adapters** : A powerful method to efficiently adapt a pre-trained model to a new dataset (with minimal amounts of data, even just 30 minutes !) with minimal compute resources (on a single GPU, in around 10 minutes of training time).\n" - ], - "metadata": { - "id": "cTV4WLrArmxS" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "f-LSGyL4xw9c" + }, "source": [ "## What are Adapters?\n", "\n", @@ -99,13 +135,13 @@ "-----\n", "\n", "Adapter modules such as this are usually initialized. The initial output of the adapter will always be zeros to prevent degradation of the original model's performance due to the addition of such modules." 
- ], - "metadata": { - "id": "f-LSGyL4xw9c" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "YGn1__-Jv2Bq" + }, "source": [ "## Advantages of Adapters\n", "\n", @@ -116,13 +152,13 @@ "- **Fast convergence**: Since the adapters only need to learn to modify the module's output slightly, and each adapter has a trivial parameter cost, they converge rapidly.\n", "- **Adapt only the encoder**: Adapters can be used anywhere, but they are most commonly used in just the encoder, keeping the decoder modules frozen. This allows the decoder to be unaffected by costly CTC/RNN-T training, which takes time to converge, and just the adapter modules in the encoder need to be updated.\n", "- **Dynamic and flexible adaptation**: Since adapter modules can be added any number of times, a single shared \"core\" model can have multiple adapters that are enabled/disabled dynamically to adapt to numerous scenarios. This potentially offers the case where a single \"core\" model is shared across multiple users, and each user has a small, personal adapter module used for personalization. " - ], - "metadata": { - "id": "YGn1__-Jv2Bq" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "8d7y1cygv4MP" + }, "source": [ "## Limitations of Adapters\n", "\n", @@ -135,13 +171,13 @@ " - **Note**: There is nothing fundamentally wrong with still changing the vocabulary of a model that supports adapters. The benefits of adapters will reduce significantly and require costly training (similar in time and memory to finetuning). The model can no longer recover its performance by disabling all of its adapters.\n", "- **Easy to overfit**: Since adapters enable domain adaptation on very small amounts of speech data, it is trivial to rapidly overfit these datasets and significantly degrade performance on the original domain. \n", " - **Note**: This can be overcome with some experimentation, further boosted by the fast experimentation cycle that adapters enable." 
- ], - "metadata": { - "id": "8d7y1cygv4MP" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "mtYWTi0irkS6" + }, "source": [ "# Dataset preparation\n", "\n", @@ -152,13 +188,15 @@ "First, we prepare some datasets that the original model was **not trained on**, making it a new domain to be adapted. \n", "\n", "In this tutorial, we will be utilizing the `AN4` dataset - also known as the Alphanumeric dataset, which was collected and published by Carnegie Mellon University. We chose this dataset primarily because it is **very small in size** (`<1 hours of training data`), **easy to overfit when training from scratch / fine-tuning by changing the decoder** (`previous tutorials can mostly get around 10-20% WER with fine-tuning without hyperparameter tuning`), and its **text is perfectly supported by the tokenization/decoding scheme of the model**." - ], - "metadata": { - "id": "mtYWTi0irkS6" - } + ] }, { "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "NpKgT6q5-gNk" + }, + "outputs": [], "source": [ "import os\n", "\n", @@ -167,15 +205,37 @@ "\n", "if not os.path.exists(\"scripts/process_an4_data.py\"):\n", " !wget -P scripts/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/scripts/dataset_processing/process_an4_data.py" - ], - "metadata": { - "id": "NpKgT6q5-gNk" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "0wZZuUDi_gEV" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "******\n", + "Tarfile already exists.\n", + "Finished conversion.\n", + "******\n", + "Preparing AN4 dataset ...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/workspace/nemo/NeMo-opensource/tutorials/asr/asr_adapters/scripts/process_an4_data.py:43: FutureWarning: get_duration() keyword argument 'filename' has been renamed to 'path' in version 0.10.0.\n", + "\tThis alias will be removed in version 1.0.\n", + " 
duration = librosa.core.get_duration(filename=audio_path)\n", + "\u001b[0mAN4 prepared !\n" + ] + } + ], "source": [ "import wget\n", "import tarfile \n", @@ -220,28 +280,26 @@ " --data_root=$an4_path\n", "\n", "print(\"AN4 prepared !\")" - ], - "metadata": { - "id": "0wZZuUDi_gEV" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "9fiqQeWDAXsH" + }, + "outputs": [], "source": [ "# Manifest filepaths\n", "TRAIN_MANIFEST = os.path.join(data_dir, \"an4\", \"train_manifest.json\")\n", "TEST_MANIFEST = os.path.join(data_dir, \"an4\", \"test_manifest.json\")" - ], - "metadata": { - "id": "9fiqQeWDAXsH" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "q2nxi5RzAfZ5" + }, "source": [ "# Prepare the \"base\" model\n", "\n", @@ -250,39 +308,40 @@ "-----\n", "\n", "Most importantly, we discuss a simple way to enable Adapter specific support to a pre-trained model checkpoint - by modifying the `encoder` config before loading the model." 
- ], - "metadata": { - "id": "q2nxi5RzAfZ5" - } + ] }, { "cell_type": "code", + "execution_count": 7, + "metadata": { + "id": "F-wt9y5iAali" + }, + "outputs": [], "source": [ "import torch\n", "from omegaconf import OmegaConf, open_dict\n", "from pytorch_lightning import Trainer\n", - "\n", - "import nemo.collections.asr as nemo_asr" - ], - "metadata": { - "id": "F-wt9y5iAali" - }, - "execution_count": null, - "outputs": [] + "from lightning.pytorch.loggers import WandbLogger\n", + "wandb_logger = WandbLogger(project=\"NEMO_TEST\")\n", + "# import nemo.collections.asr as nemo_asr" + ] }, { "cell_type": "code", - "source": [ - "model_name = \"stt_en_conformer_ctc_small\"" - ], + "execution_count": 8, "metadata": { "id": "uVOfU7gsCI5u" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "model_name = \"stt_en_conformer_ctc_large\"" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "TitUAeq67Hkl" + }, "source": [ "## Prepare an Adapter-compatible Encoder\n", "\n", @@ -293,35 +352,46 @@ "- Extract the model config from the \"base\" model.\n", "- Update the `encoder` section of the config to a subclass of that model (which does have Adapter support)\n", "- Initialize the model with this new config, therefore enabling adapter support." - ], - "metadata": { - "id": "TitUAeq67Hkl" - } + ] }, { "cell_type": "markdown", - "source": [ - "- Extract just the config of the model." - ], "metadata": { "id": "5V5UY-5c8FDv" - } + }, + "source": [ + "- Extract just the config of the model." 
+ ] }, { "cell_type": "code", - "source": [ - "cfg = nemo_asr.models.ASRModel.from_pretrained(model_name, return_config=True)" - ], + "execution_count": 9, "metadata": { "id": "RzwLAHVqAqD9" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-07-22 20:42:32 cloud:58] Found existing object /root/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo.\n", + "[NeMo I 2024-07-22 20:42:32 cloud:64] Re-using file from: /root/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo\n", + "[NeMo I 2024-07-22 20:42:32 common:815] Instantiating model from pre-trained checkpoint\n" + ] + } + ], + "source": [ + "cfg = nemo_asr.models.ASRModel.from_pretrained(model_name, return_config=True)" + ] }, { "cell_type": "code", + "execution_count": 10, + "metadata": { + "id": "O6xAz38-A_Bh" + }, + "outputs": [], "source": [ - "from nemo.core import adapter_mixins\n", "\n", "# Utility method to check and update the model config\n", "def update_model_config_to_support_adapter(model_cfg):\n", @@ -332,55 +402,156 @@ " \n", " print(\"Updated encoder _target_ model :\", model_cfg.encoder._target_)\n", " return model_cfg" - ], - "metadata": { - "id": "O6xAz38-A_Bh" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", - "source": [ - "- Update the model config's `encoder` section to support Adapters." - ], "metadata": { "id": "TDk2VMXI8OkG" - } + }, + "source": [ + "- Update the model config's `encoder` section to support Adapters." 
+ ] }, { "cell_type": "code", - "source": [ - "cfg = update_model_config_to_support_adapter(cfg)" - ], + "execution_count": 11, "metadata": { "id": "iyp4xUOLBi0v" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Updated encoder _target_ model : nemo.collections.asr.modules.conformer_encoder.ConformerEncoderAdapter\n" + ] + } + ], + "source": [ + "cfg = update_model_config_to_support_adapter(cfg)" + ] }, { "cell_type": "markdown", - "source": [ - "- Finally load the model with the updated config." - ], "metadata": { "id": "26NTK00w8VIt" - } + }, + "source": [ + "- Finally load the model with the updated config." + ] }, { "cell_type": "code", - "source": [ - "model = nemo_asr.models.ASRModel.from_pretrained(model_name, override_config_path=cfg)" - ], + "execution_count": 12, "metadata": { "id": "7r36mkUGBvsy" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-07-22 20:42:32 cloud:58] Found existing object /root/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo.\n", + "[NeMo I 2024-07-22 20:42:32 cloud:64] Re-using file from: /root/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo\n", + "[NeMo I 2024-07-22 20:42:32 common:815] Instantiating model from pre-trained checkpoint\n", + "[NeMo I 2024-07-22 20:42:33 mixins:172] Tokenizer SentencePieceTokenizer initialized with 128 tokens\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-07-22 20:42:33 modelPT:176] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.\n", + " Train config : \n", + " manifest_filepath:\n", + " - - 
/data2/nemo_asr/nemo_asr_set_3.0/bucket1/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket2/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket3/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket4/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket5/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket6/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket7/tarred_audio_manifest.json\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket8/tarred_audio_manifest.json\n", + " sample_rate: 16000\n", + " batch_size: 1\n", + " shuffle: true\n", + " num_workers: 4\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " trim_silence: false\n", + " max_duration: 20.0\n", + " min_duration: 0.1\n", + " is_tarred: true\n", + " tarred_audio_filepaths:\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket1/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket2/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket3/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket4/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket5/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket6/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket7/audio__OP_0..8191_CL_.tar\n", + " - - /data2/nemo_asr/nemo_asr_set_3.0/bucket8/audio__OP_0..8191_CL_.tar\n", + " shuffle_n: 2048\n", + " bucketing_strategy: synced_randomized\n", + " bucketing_batch_size:\n", + " - 34\n", + " - 30\n", + " - 26\n", + " - 22\n", + " - 18\n", + " - 16\n", + " - 12\n", + " - 8\n", + " \n", + "[NeMo W 2024-07-22 20:42:33 modelPT:183] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the 
validation data loader(s). \n", + " Validation config : \n", + " manifest_filepath:\n", + " - /manifests/librispeech/librivox-dev-other.json\n", + " - /manifests/librispeech/librivox-dev-clean.json\n", + " - /manifests/librispeech/librivox-test-other.json\n", + " - /manifests/librispeech/librivox-test-clean.json\n", + " sample_rate: 16000\n", + " batch_size: 32\n", + " shuffle: false\n", + " num_workers: 8\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " \n", + "[NeMo W 2024-07-22 20:42:33 modelPT:189] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a valid configuration file to setup the test data loader(s).\n", + " Test config : \n", + " manifest_filepath:\n", + " - /manifests/librispeech/librivox-dev-other.json\n", + " - /manifests/librispeech/librivox-dev-clean.json\n", + " - /manifests/librispeech/librivox-test-other.json\n", + " - /manifests/librispeech/librivox-test-clean.json\n", + " sample_rate: 16000\n", + " batch_size: 32\n", + " shuffle: false\n", + " num_workers: 8\n", + " pin_memory: true\n", + " use_start_end_token: false\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-07-22 20:42:33 features:305] PADDING: 0\n", + "[NeMo I 2024-07-22 20:42:34 save_restore_connector:263] Model EncDecCTCModelBPE was successfully restored from /root/.cache/torch/NeMo/NeMo_2.0.0rc0/stt_en_conformer_ctc_large/afb212c5bcf904e326b5e5751e7c7465/stt_en_conformer_ctc_large.nemo.\n" + ] + } + ], + "source": [ + "model = nemo_asr.models.ASRModel.from_pretrained(model_name, override_config_path=cfg)" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "x0C2r7388cRd" + }, "source": [ "-----\n", "\n", @@ -391,13 +562,29 @@ "**Recommendation**:\n", "\n", "You should normally start with 1-5 epochs of adaptation over your entire new domain, and then increase or decrease your number of training steps to trade off a balance in accuracy on general speech." 
- ], - "metadata": { - "id": "x0C2r7388cRd" - } + ] }, { "cell_type": "code", + "execution_count": 13, + "metadata": { + "id": "sWRUXzjQMWN5" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: GPU available: True (cuda), used: True\n", + "WARNING: Logging before flag parsing goes to stderr.\n", + "I0722 20:42:51.693691 129733185169216 rank_zero.py:64] GPU available: True (cuda), used: True\n", + "INFO: TPU available: False, using: 0 TPU cores\n", + "I0722 20:42:51.715736 129733185169216 rank_zero.py:64] TPU available: False, using: 0 TPU cores\n", + "INFO: HPU available: False, using: 0 HPUs\n", + "I0722 20:42:51.716401 129733185169216 rank_zero.py:64] HPU available: False, using: 0 HPUs\n" + ] + } + ], "source": [ "accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'\n", "max_steps = 300\n", @@ -407,28 +594,26 @@ " log_every_n_steps=5, check_val_every_n_epoch=3)\n", "\n", "model.set_trainer(trainer)" - ], - "metadata": { - "id": "sWRUXzjQMWN5" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "tJBriqr3tQV7" + }, + "outputs": [], "source": [ "# utility method\n", "import json\n", "from nemo.collections.asr.parts.utils.manifest_utils import read_manifest\n" - ], - "metadata": { - "id": "tJBriqr3tQV7" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "dm-qqTdZDUlZ" + }, "source": [ "## [Optional] Check if the new domain is compatible with the original decoder\n", "\n", @@ -437,180 +622,209 @@ "-----\n", "\n", "If this check fails, the training run might crash, or silently allow the model to learn to produce `⁇` tokens (when using SentencePiece tokenizers)." 
- ], - "metadata": { - "id": "dm-qqTdZDUlZ" - } + ] }, { "cell_type": "markdown", - "source": [ - "### Parse the base character set" - ], "metadata": { "id": "UKTiAPV_sdFI" - } + }, + "source": [ + "### Parse the base character set" + ] }, { "cell_type": "code", - "source": [ - "train_data = read_manifest(TRAIN_MANIFEST)\n", - "base_sets = [set(list(sample['text'])) for sample in train_data]\n", - "base_charset = set([])\n", - "for charset in base_sets:\n", - " base_charset.update(charset)\n", - "base_charset = list(sorted(list(base_charset)))\n", - "\n", - "print(\"Base charset :\", base_charset)" - ], + "execution_count": 15, "metadata": { "id": "WgogR3taD7NA" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "# train_data = read_manifest(TRAIN_MANIFEST)\n", + "# base_sets = [set(list(sample['text'])) for sample in train_data]\n", + "# base_charset = set([])\n", + "# for charset in base_sets:\n", + "# base_charset.update(charset)\n", + "# base_charset = list(sorted(list(base_charset)))\n", + "\n", + "# print(\"Base charset :\", base_charset)" + ] }, { "cell_type": "markdown", - "source": [ - "### Check if there are invalid characters" - ], "metadata": { "id": "x-0fzrfPshJj" - } + }, + "source": [ + "### Check if there are invalid characters" + ] }, { "cell_type": "code", - "source": [ - "def check_valid_charset_in_vocab(model, charset):\n", - " model_vocab = model.decoder.vocabulary\n", - " num_invalid = 0\n", - "\n", - " for char in charset:\n", - " if char != ' ' and char not in model_vocab:\n", - " print(f\"Character `{char}` does not exist in the base character set of the original model !\")\n", - " num_invalid += 1\n", - "\n", - " print(\"Number of invalid tokens :\", num_invalid)" - ], + "execution_count": 34, "metadata": { "id": "5laUkRf5Eb6l" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "# def check_valid_charset_in_vocab(model, charset):\n", + "# model_vocab = model.decoder.vocabulary\n", + 
"# num_invalid = 0\n", + "\n", + "# for char in charset:\n", + "# if char != ' ' and char not in model_vocab:\n", + "# print(f\"Character `{char}` does not exist in the base character set of the original model !\")\n", + "# num_invalid += 1\n", + "\n", + "# print(\"Number of invalid tokens :\", num_invalid)" + ] }, { "cell_type": "code", - "source": [ - "check_valid_charset_in_vocab(model, base_charset)" - ], + "execution_count": 16, "metadata": { "id": "5rEUqs7AFh5j" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of invalid tokens : 0\n" + ] + } + ], + "source": [ + "# check_valid_charset_in_vocab(model, base_charset)" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "Sf-2EHznGkI1" + }, "source": [ "# Evaluate original performance on AN4 dev set\n", "\n", "Now that we possess a model capable of supporting adapters, let us quickly test the performance of the pre-trained model on the AN4 test set without any training or fine-tuning." 
- ], - "metadata": { - "id": "Sf-2EHznGkI1" - } + ] }, { "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "Ak4v4aWjGoQH" + }, + "outputs": [], "source": [ "if not os.path.exists('scripts/transcribe_speech.py'):\n", " !wget -P scripts/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/transcribe_speech.py\n", "\n", "if not os.path.exists('scripts/speech_to_text_eval.py'):\n", " !wget -P scripts/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/asr/speech_to_text_eval.py" - ], - "metadata": { - "id": "Ak4v4aWjGoQH" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", - "source": [ - "# temporarily save current model\n", - "model.save_to(\"/content/unadapted_model.nemo\")" - ], + "execution_count": 18, "metadata": { "id": "OVlBKWCiIHw7" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "# # temporarily save current model\n", + "# model.save_to(\"/content/unadapted_model.nemo\")" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "r03iDw9k-dAm" + }, "source": [ "-----\n", "\n", "The following evaluation script will properly transcribe the AN4 test set, and score it against its ground truth." 
- ], - "metadata": { - "id": "r03iDw9k-dAm" - } + ] }, { "cell_type": "code", - "source": [ - "!python scripts/speech_to_text_eval.py \\\n", - " model_path=\"/content/unadapted_model.nemo\" \\\n", - " dataset_manifest=$TEST_MANIFEST \\\n", - " output_filename=\"/content/unadapted_predictions.json\" \\\n", - " batch_size=32 \\\n", - " use_cer=False" - ], + "execution_count": 19, "metadata": { "id": "C6YbPt70H0-N" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Traceback (most recent call last):\n", + " File \"/workspace/nemo/NeMo-opensource/tutorials/asr/asr_adapters/scripts/speech_to_text_eval.py\", line 71, in \n", + " import transcribe_speech\n", + " File \"/workspace/nemo/NeMo-opensource/tutorials/asr/asr_adapters/scripts/transcribe_speech.py\", line 29, in \n", + " from nemo.collections.asr.models.aed_multitask_models import parse_multitask_prompt\n", + "ImportError: cannot import name 'parse_multitask_prompt' from 'nemo.collections.asr.models.aed_multitask_models' (/usr/local/lib/python3.10/dist-packages/nemo/collections/asr/models/aed_multitask_models.py)\n" + ] + } + ], + "source": [ + "# !python scripts/speech_to_text_eval.py \\\n", + "# model_path=\"/content/unadapted_model.nemo\" \\\n", + "# dataset_manifest=$TEST_MANIFEST \\\n", + "# output_filename=\"/content/unadapted_predictions.json\" \\\n", + "# batch_size=32 \\\n", + "# use_cer=False" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "2VBQO3w3swu8" + }, "source": [ "------\n", "\n", "Check the predictions of the current model" - ], - "metadata": { - "id": "2VBQO3w3swu8" - } + ] }, { "cell_type": "code", - "source": [ - "!head -n 5 /content/unadapted_predictions.json" - ], + "execution_count": 20, "metadata": { "id": "SE8uoRLsJA9F" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "head: cannot open '/content/unadapted_predictions.json' 
for reading: No such file or directory\n" + ] + } + ], + "source": [ + "# !head -n 5 /content/unadapted_predictions.json" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "muRBgHHe-n7E" + }, "source": [ "-----\n", "\n", "Overall, the model does quite well, obtaining roughly 6% Word Error Rate without prior training on this dataset. \n", "\n", "**Note**: Pre-trained models in NeMo are trained on several thousand hours of speech, so it is unsurprising that this model is this accurate without any training on this toy dataset. For more realistic cases, we usually observe the range of 10-30% WER for out-of-domain speech." - ], - "metadata": { - "id": "muRBgHHe-n7E" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "b-L3prIzs3CW" + }, "source": [ "# Setup training and evaluation of the model\n", "\n", @@ -621,22 +835,37 @@ "**Note**: Each model may have special parameters in their data loader. Please refer to the configs of the pre-trained models to determine what additional changes are necessary. 
Below recommendations are primarily for Conformer CTC and may differ from model to model.\n", "\n", "You can parse the model config via - `print(OmegaConf.to_yaml(model.cfg))`" - ], - "metadata": { - "id": "b-L3prIzs3CW" - } + ] }, { "cell_type": "markdown", - "source": [ - "## Setup dataloaders" - ], "metadata": { "id": "V2WirN5KJpsD" - } + }, + "source": [ + "## Setup dataloaders" + ] }, { "cell_type": "code", + "execution_count": 47, + "metadata": { + "id": "F0GIxhyCJmFv" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-07-19 09:32:48 collections:199] Dataset loaded with 948 files totalling 0.71 hours\n", + "[NeMo I 2024-07-19 09:32:48 collections:201] 0 files were filtered totalling 0.00 hours\n", + "[NeMo I 2024-07-19 09:32:48 collections:199] Dataset loaded with 130 files totalling 0.10 hours\n", + "[NeMo I 2024-07-19 09:32:48 collections:201] 0 files were filtered totalling 0.00 hours\n", + "[NeMo I 2024-07-19 09:32:48 collections:199] Dataset loaded with 130 files totalling 0.10 hours\n", + "[NeMo I 2024-07-19 09:32:48 collections:201] 0 files were filtered totalling 0.00 hours\n" + ] + } + ], "source": [ "with open_dict(model.cfg):\n", " # Train Dataloader\n", @@ -651,28 +880,63 @@ "model.setup_training_data(model.cfg.train_ds)\n", "model.setup_multiple_validation_data(model.cfg.validation_ds)\n", "model.setup_multiple_test_data(model.cfg.validation_ds)" - ], - "metadata": { - "id": "F0GIxhyCJmFv" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "T3VuqcGTNuIJ" + }, "source": [ "## Setup Spectrogram Augmentation\n", "\n", "For this experiment we will continue to use the original spec augmentation config in the base model, however you may find better results by modifying the strength of this augmentation.\n", "\n", "**Note**: The script inside ASR examples **disables spec augment entirely**. 
This is done in order to provide a stable default to measure the best possible adaptation case, but may severely degrade the performance on general speech. Please be careful when copying the hyper parameters from the tutorial to the script for large scale experimentation." - ], - "metadata": { - "id": "T3VuqcGTNuIJ" - } + ] }, { "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sample_rate : 16000\n", + "log_prediction : True\n", + "ctc_reduction : mean_batch\n", + "skip_nan_grad : False\n", + "train_ds : {'manifest_filepath': 'datasets/an4/train_manifest.json', 'sample_rate': 16000, 'batch_size': 32, 'shuffle': True, 'num_workers': 4, 'pin_memory': True, 'use_start_end_token': False, 'trim_silence': False, 'max_duration': 20.0, 'min_duration': 0.1, 'is_tarred': False, 'tarred_audio_filepaths': None, 'shuffle_n': 2048, 'bucketing_strategy': 'synced_randomized', 'bucketing_batch_size': [34, 30, 26, 22, 18, 16, 12, 8]}\n", + "validation_ds : {'manifest_filepath': 'datasets/an4/test_manifest.json', 'sample_rate': 16000, 'batch_size': 32, 'shuffle': False, 'num_workers': 8, 'pin_memory': True, 'use_start_end_token': False}\n", + "test_ds : {'manifest_filepath': 'datasets/an4/test_manifest.json', 'sample_rate': 16000, 'batch_size': 32, 'shuffle': False, 'num_workers': 8, 'pin_memory': True, 'use_start_end_token': False}\n", + "tokenizer : {'dir': '/tokenizers/NeMo_ASR_SET/English/asr_set_3.0/tokenizer_spe_unigram_v128', 'type': 'bpe', 'model_path': 'nemo:e06949b0b85a485e9f280ea6d19e5492_tokenizer.model', 'vocab_path': 'nemo:53bbc634b62446de83525753e95a50ac_vocab.txt', 'spe_tokenizer_vocab': 'nemo:ff63e3c43c5f4b95bff702425366a4a6_tokenizer.vocab'}\n", + "preprocessor : {'_target_': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'sample_rate': 16000, 'normalize': 'per_feature', 'window_size': 0.025, 'window_stride': 0.01, 'window': 'hann', 'features': 80, 
'n_fft': 512, 'log': True, 'frame_splicing': 1, 'dither': 1e-05, 'pad_to': 0, 'pad_value': 0.0}\n", + "spec_augment : {'_target_': 'nemo.collections.asr.modules.SpectrogramAugmentation', 'freq_masks': 2, 'time_masks': 10, 'freq_width': 27, 'time_width': 0.05}\n", + "encoder : {'_target_': 'nemo.collections.asr.modules.conformer_encoder.ConformerEncoderAdapter', 'feat_in': 80, 'feat_out': -1, 'n_layers': 18, 'd_model': 512, 'subsampling': 'striding', 'subsampling_factor': 4, 'subsampling_conv_channels': 512, 'ff_expansion_factor': 4, 'self_attention_model': 'rel_pos', 'n_heads': 8, 'att_context_size': [-1, -1], 'xscaling': True, 'untie_biases': True, 'pos_emb_max_len': 5000, 'conv_kernel_size': 31, 'conv_norm_type': 'batch_norm', 'dropout': 0.1, 'dropout_emb': 0.0, 'dropout_att': 0.1}\n", + "decoder : {'_target_': 'nemo.collections.asr.modules.ConvASRDecoder', 'feat_in': 512, 'num_classes': 128, 'vocabulary': ['', '▁', 's', 't', 'e', 'd', 'o', '▁the', 'a', 'i', '▁a', 'u', 'y', 'm', 'l', 'n', 'p', 're', 'c', 'h', 'r', '▁s', 'g', '▁to', 'er', 'ing', 'f', '▁and', 'an', '▁i', 'k', '▁that', \"'\", '▁of', '▁in', 'w', '▁p', 'ed', 'or', 'al', 'ar', '▁f', 'en', 'in', 'b', '▁you', '▁w', '▁b', 'le', 'll', 'es', '▁it', 've', 'ur', '▁we', '▁re', '▁be', 'ly', '▁is', '▁he', '▁o', '▁c', 'it', '▁n', '▁on', 'un', '▁t', 'on', 'se', 'th', 'ce', '▁do', 'ic', '▁for', '▁th', 'ion', 'ch', '▁was', 'ri', 'ent', '▁g', 'ver', '▁co', 'li', '▁ha', '▁ma', 'la', 'ro', 'v', 'us', '▁ca', '▁di', '▁this', 'ra', '▁st', '▁e', '▁not', '▁so', '▁de', '▁have', 'ter', 'ir', '▁go', 'ation', '▁with', 'ate', '▁me', '▁mo', 'ment', '▁con', '▁but', 'vi', '▁pro', '▁ho', 'j', '▁com', 'ight', '▁know', '▁what', 'ect', '▁ex', '▁some', '▁would', '▁like', 'x', '▁his', 'q', 'z']}\n", + "optim : {'name': 'adamw', 'lr': 2.0, 'betas': [0.9, 0.98], 'weight_decay': 0.001, 'sched': {'name': 'NoamAnnealing', 'd_model': 512, 'warmup_steps': 10000, 'warmup_ratio': None, 'min_lr': 1e-06}}\n", + "compute_eval_loss : False\n", + 
"variational_noise : {'start_step': 0, 'std': 0.0}\n", + "target : nemo.collections.asr.models.ctc_bpe_models.EncDecCTCModelBPE\n", + "nemo_version : 1.9.0rc0\n", + "decoding : {'strategy': 'greedy', 'preserve_alignments': None, 'compute_timestamps': None, 'word_seperator': ' ', 'ctc_timestamp_type': 'all', 'batch_dim_index': 0, 'greedy': {'preserve_alignments': False, 'compute_timestamps': False, 'preserve_frame_confidence': False, 'confidence_method_cfg': {'name': 'entropy', 'entropy_type': 'tsallis', 'alpha': 0.33, 'entropy_norm': 'exp', 'temperature': 'DEPRECATED'}}, 'beam': {'beam_size': 4, 'search_type': 'default', 'preserve_alignments': False, 'compute_timestamps': False, 'return_best_hypothesis': True, 'beam_alpha': 1.0, 'beam_beta': 0.0, 'kenlm_path': None, 'flashlight_cfg': {'lexicon_path': None, 'boost_path': None, 'beam_size_token': 16, 'beam_threshold': 20.0, 'unk_weight': -inf, 'sil_weight': 0.0}, 'pyctcdecode_cfg': {'beam_prune_logp': -10.0, 'token_min_logp': -5.0, 'prune_history': False, 'hotwords': None, 'hotword_weight': 10.0}}, 'confidence_cfg': {'preserve_frame_confidence': False, 'preserve_token_confidence': False, 'preserve_word_confidence': False, 'exclude_blank': True, 'aggregation': 'min', 'tdt_include_duration': False, 'method_cfg': {'name': 'entropy', 'entropy_type': 'tsallis', 'alpha': 0.33, 'entropy_norm': 'exp', 'temperature': 'DEPRECATED'}}, 'temperature': 1.0}\n" + ] + } + ], + "source": [ + "for key in model.cfg.keys():\n", + " print(f\"{key} : {model.cfg[key]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": { + "id": "T-XFuaA3OlOB" + }, + "outputs": [], "source": [ "with open_dict(model.cfg):\n", " # Spec Augment\n", @@ -682,15 +946,13 @@ " model.cfg.spec_augment.time_width = model.cfg.spec_augment.time_width # Can be changed\n", "\n", "model.spec_augmentation = model.from_config_dict(model.cfg.spec_augment)" - ], - "metadata": { - "id": "T-XFuaA3OlOB" - }, - "execution_count": null, - "outputs": [] + 
] }, { "cell_type": "markdown", + "metadata": { + "id": "xGpdUWl_tGuA" + }, "source": [ "## Setup optimizer and scheduler\n", "\n", @@ -701,25 +963,76 @@ "Feel free to modify these values to see the effect on adapters' convergence.\n", "\n", "**Note**: The hyper parameters below correspond to the base model and may not match those applied in the ASR examples! Please note that the script in the examples defaults to an **AdamW** optimizer with a **CosineAnnealing** scheduler, whereas the config of Conformers is generally an **AdamW** optimizer with a **NoamAnnealing** scheduler. The *learning rate*, *weight decay* and other hyper parameters may not be exactly the same between the tutorial and the example scripts, so please be careful when transferring the hyper parameters for large scale experiments." - ], - "metadata": { - "id": "xGpdUWl_tGuA" - } + ] }, { "cell_type": "code", - "source": [ - "if 'optim' in model.cfg:\n", - " print(OmegaConf.to_yaml(model.cfg.optim))" - ], + "execution_count": 50, "metadata": { "id": "UDEIfMTcP6j6" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name: adamw\n", + "lr: 2.0\n", + "betas:\n", + "- 0.9\n", + "- 0.98\n", + "weight_decay: 0.001\n", + "sched:\n", + " name: NoamAnnealing\n", + " d_model: 512\n", + " warmup_steps: 10000\n", + " warmup_ratio: null\n", + " min_lr: 1.0e-06\n", + "\n" + ] + } + ], + "source": [ + "if 'optim' in model.cfg:\n", + " print(OmegaConf.to_yaml(model.cfg.optim))" + ] }, { "cell_type": "code", + "execution_count": 51, + "metadata": { + "id": "tp_8FGPcKjMd" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-07-19 09:32:52 modelPT:767] Optimizer config = AdamW (\n", + " Parameter Group 0\n", + " amsgrad: False\n", + " betas: [0.9, 0.98]\n", + " capturable: False\n", + " differentiable: False\n", + " eps: 1e-08\n", + " foreach: None\n", + " fused: None\n", + " lr: 0.1\n", + " maximize: 
False\n", + " weight_decay: 0.0\n", + " )\n", + "[NeMo I 2024-07-19 09:32:52 lr_scheduler:923] Scheduler \"\" \n", + " will be used during training (effective maximum steps = 300) - \n", + " Parameters : \n", + " (d_model: 512\n", + " warmup_steps: 100\n", + " warmup_ratio: null\n", + " min_lr: 1.0e-06\n", + " max_steps: 300\n", + " )\n" + ] + } + ], "source": [ "with open_dict(model.cfg):\n", " model.cfg.optim.lr = 0.1\n", @@ -727,63 +1040,88 @@ " model.cfg.optim.sched.warmup_steps = 100\n", "\n", "model.setup_optimization(model.cfg.optim);" - ], - "metadata": { - "id": "tp_8FGPcKjMd" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "AGrThAt9Qh0D" + }, "source": [ "# Adapters: Supported Components\n", "\n", "A NeMo model may have multiple types of adapters that are supported in each of their components. Let us see at a glance what are some of the adapter types supported by the Conformer ASR model.\n", "\n", "**Note**: Every domain may support their own types of adapters, and use them in different ways. Please refer to the documentation of each domain for information on the adapter support." - ], - "metadata": { - "id": "AGrThAt9Qh0D" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "Wq1JLbNvROcL" + }, "source": [ "-----\n", "Let's start with the modules in which the model will support adapters. We can select these adapters with a special syntax to construct \"Module adapters\".\n", "\n", "**Note**: `''` refers to the \"default\" adapter - usually the `encoder` but it is model dependent. It may also be that no specific modules are provided, in which case only `default` adapters will be available." 
- ], - "metadata": { - "id": "Wq1JLbNvROcL" - } + ] }, { "cell_type": "code", - "source": [ - "if hasattr(model, 'adapter_module_names'):\n", - " print(model.adapter_module_names)" - ], + "execution_count": 53, "metadata": { "id": "fRIDhU8RVBwi" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['', 'encoder', 'decoder', 'joint']\n" + ] + } + ], + "source": [ + "if hasattr(model, 'adapter_module_names'):\n", + " print(model.adapter_module_names)" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "u5BOWWBjfQwN" + }, "source": [ "-----\n", "Next, we can try to obtain the accepted types of each of the child modules in the Model." - ], - "metadata": { - "id": "u5BOWWBjfQwN" - } + ] }, { "cell_type": "code", - "source": [ + "execution_count": 54, + "metadata": { + "id": "iNnSp_azQ2u8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Module : ConformerEncoderAdapter\n", + "\n", + "\n", + "\n", + "\n", + "Module : ConvASRDecoder\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(model.children())\n", "for module in model.children():\n", " if hasattr(module, 'get_accepted_adapter_types'):\n", " types = module.get_accepted_adapter_types()\n", @@ -792,26 +1130,24 @@ " for tp in types:\n", " print(tp)\n", " print()" - ], - "metadata": { - "id": "iNnSp_azQ2u8" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "YXTC4LiSnB2O" + }, "source": [ "-----\n", "\n", "As you can see, a single component of the model may support one or more adapter types (or none at all)! Below, we will experiment with the simple Linear Adapters, but as an exercise, you might try to use other adapter types present here." 
- ], - "metadata": { - "id": "YXTC4LiSnB2O" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "WFCUrYxnGPt3" + }, "source": [ "# Adapters: Creation and Preparation\n", "\n", @@ -822,40 +1158,52 @@ "We first import a config for a basic `LinearAdapter` most often used in literature. \n", "\n", "`LinearAdapter` is a simple network comprising LayerNorm, a bottleneck Linear layer, an activation, and an upcast Linear layer (so that input and output channel dim match). We provide some configuration parameters (such as the input dim and the bottleneck dim)." - ], - "metadata": { - "id": "WFCUrYxnGPt3" - } + ] }, { "cell_type": "code", - "source": [ - "from nemo.collections.common.parts.adapter_modules import LinearAdapterConfig" - ], + "execution_count": 55, "metadata": { "id": "oZZr6vSntuyX" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "from nemo.collections.common.parts.adapter_modules import LinearAdapterConfig\n", + "from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import MultiHeadAttentionAdapterConfig\n", + "from nemo.collections.asr.parts.submodules.adapters.multi_head_attention_adapter_module import RelPositionMultiHeadAttentionAdapterConfig" + ] }, { "cell_type": "code", + "execution_count": 58, + "metadata": { + "id": "dlj0Yud4MxOi" + }, + "outputs": [], "source": [ "#%% [code]\n", "#@title Adapter Setup { display-mode: \"form\" }\n", "adapter_name = \"AN4\" #@param {type:\"string\"}\n", - "adapter_dim = 32 #@param {type:\"integer\"}\n", + "adapter_dim = 64 #@param {type:\"integer\"}\n", "adapter_activation = \"swish\" #@param {type:\"string\"}\n", "adapter_norm_position = \"pre\" #@param [\"pre\", \"post\"]" - ], - "metadata": { - "id": "dlj0Yud4MxOi" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 59, + "metadata": { + "id": "Uv8WRQkXU3mu" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"LinearAdapterConfig(in_features=512, dim=64, activation='swish', norm_position='pre', dropout=0.0, adapter_strategy=ResidualAddAdapterStrategyConfig(stochastic_depth=0.0, l2_lambda=0.0, _target_='nemo.core.classes.mixins.adapter_mixin_strategies.ResidualAddAdapterStrategy'), _target_='nemo.collections.common.parts.adapter_modules.LinearAdapter')\n" + ] + } + ], "source": [ "adapter_cfg = LinearAdapterConfig(\n", " in_features=model.cfg.encoder.d_model, # conformer specific model dim. Every layer emits this dim at its output.\n", @@ -864,83 +1212,127 @@ " norm_position=adapter_norm_position, # whether to use LayerNorm at the beginning or the end of the adapter\n", ")\n", "print(adapter_cfg)" - ], - "metadata": { - "id": "Uv8WRQkXU3mu" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "pIECyKxit58r" + }, "source": [ "## Add a new adapter module\n", "\n", "Now that our adapter config is ready. Next, we perform a check to see what is the size of the original model and what its size will be after adding the adapter module." 
- ], - "metadata": { - "id": "pIECyKxit58r" - } + ] }, { "cell_type": "code", - "source": [ - "model.summarize()" - ], + "execution_count": 60, "metadata": { "id": "-MbSTbYiYtnB" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "data": { + "text/plain": [ + " | Name | Type | Params | Mode \n", + "--------------------------------------------------------------------------------\n", + "0 | preprocessor | AudioToMelSpectrogramPreprocessor | 0 | train\n", + "1 | encoder | ConformerEncoderAdapter | 121 M | train\n", + "2 | decoder | ConvASRDecoder | 66.2 K | train\n", + "3 | loss | CTCLoss | 0 | train\n", + "4 | spec_augmentation | SpectrogramAugmentation | 0 | train\n", + "5 | wer | WER | 0 | train\n", + "--------------------------------------------------------------------------------\n", + "121 M Trainable params\n", + "0 Non-trainable params\n", + "121 M Total params\n", + "486.005 Total estimated model params size (MB)" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.summarize()" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "vjYmPbwCC0LZ" + }, "source": [ "-----\n", "\n", "Next, we use `add_adapter` to add adapter blocks to the `encoder`.\n", "\n", "A single line can be used to add adapter modules to every layer of the `encoder` module. We pass it a unique name to identify this adapter and the adapter config (which can be helpful to enable or disable adapters later)." 
- ], - "metadata": { - "id": "vjYmPbwCC0LZ" - } + ] }, { "cell_type": "code", - "source": [ - "model.add_adapter(name=adapter_name, cfg=adapter_cfg)" - ], + "execution_count": 61, "metadata": { "id": "El6ewd1GX9V7" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "model.add_adapter(name=adapter_name, cfg=adapter_cfg)" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "jMsmj1W-DTSd" + }, "source": [ "-----\n", "\n", "As expected, the number of parameters increased by a marginal amount (roughly 200,000 parameters)." - ], - "metadata": { - "id": "jMsmj1W-DTSd" - } + ] }, { "cell_type": "code", - "source": [ - "model.summarize()" - ], + "execution_count": 62, "metadata": { "id": "rIvw0_8iYpHW" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "data": { + "text/plain": [ + " | Name | Type | Params | Mode \n", + "--------------------------------------------------------------------------------\n", + "0 | preprocessor | AudioToMelSpectrogramPreprocessor | 0 | train\n", + "1 | encoder | ConformerEncoderAdapter | 122 M | train\n", + "2 | decoder | ConvASRDecoder | 66.2 K | train\n", + "3 | loss | CTCLoss | 0 | train\n", + "4 | spec_augmentation | SpectrogramAugmentation | 0 | train\n", + "5 | wer | WER | 0 | train\n", + "--------------------------------------------------------------------------------\n", + "122 M Trainable params\n", + "0 Non-trainable params\n", + "122 M Total params\n", + "490.798 Total estimated model params size (MB)" + ] + }, + "execution_count": 62, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.summarize()" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "RH6cXPW2ZHdZ" + }, "source": [ "## Enable / Disable Adapters\n", "\n", @@ -949,25 +1341,34 @@ "For this purpose, we utilize the `model.set_enabled_adapters` method - it takes an optional `name` and a boolean value for `enabled`. 
If a name is not passed, it will enable/disable all available adapters.\n", "\n", "**Note**: We recommend training one adapter at a time, disjoint from all other adapters. As such, it simplifies the selection of adapters for each particular domain. To do so - **disable all adapters first, then enable only the newly added adapter**." - ], - "metadata": { - "id": "RH6cXPW2ZHdZ" - } + ] }, { "cell_type": "code", - "source": [ - "model.set_enabled_adapters(enabled=False) # disable all adapters\n", - "model.set_enabled_adapters(name=adapter_name, enabled=True) # enable only the current adapter we want to train" - ], + "execution_count": 63, "metadata": { "id": "ogUfDkjdZKHu" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-07-19 09:39:28 adapter_mixins:719] Setting adapter 'AN4' status : Enabled = False\n", + "[NeMo I 2024-07-19 09:39:28 adapter_mixins:734] Setting adapter 'AN4' status : Enabled = True\n" + ] + } + ], + "source": [ + "model.set_enabled_adapters(enabled=False) # disable all adapters\n", + "model.set_enabled_adapters(name=adapter_name, enabled=True) # enable only the current adapter we want to train" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "V87SBzdDY1x1" + }, "source": [ "## Training only the adapter(s)\n", "\n", @@ -976,25 +1377,51 @@ "We provide the general utility methods for this purpose - `model.freeze()` and `model.unfreeze_enabled_adapters()`. \n", "\n", "The second method will look up all the enabled adapters selected in the previous step and enable their gradient calculation so that they can be trained." 
- ], - "metadata": { - "id": "V87SBzdDY1x1" - } + ] }, { "cell_type": "code", - "source": [ - "model.freeze()\n", - "model.unfreeze_enabled_adapters()" - ], + "execution_count": 64, "metadata": { "id": "RN2YayAoYzaI" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.0.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.1.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.2.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.3.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.4.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.5.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.6.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.7.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.8.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + 
"[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.9.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.10.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.11.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.12.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.13.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.14.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.15.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.16.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:405] Froze module encoder.layers.17.conv.batch_norm: BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + "[NeMo I 2024-07-19 09:39:34 adapter_mixins:435] Unfrozen adapter : AN4\n" + ] + } + ], + "source": [ + "model.freeze()\n", + "model.unfreeze_enabled_adapters()" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "5PriDOuwEbmp" + }, "source": [ "### Why are BatchNormalization layers being frozen?\n", "\n", @@ -1006,39 +1433,82 
@@ "\n", "For this reason, `unfreeze_enabled_adapters()` has an argument `freeze_batchnorm=True` as the default. It will find all the batch normalization layers and disable this flag so that it will the encoder layers remain exactly frozen even during adapter finetuning. This allows the original model performance to be recovered.\n", "\n" - ], - "metadata": { - "id": "5PriDOuwEbmp" - } + ] }, { "cell_type": "code", - "source": [ - "model.summarize()" - ], + "execution_count": 65, "metadata": { "id": "Lf3pdwQ2Zch5" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "data": { + "text/plain": [ + " | Name | Type | Params | Mode\n", + "-------------------------------------------------------------------------------\n", + "0 | preprocessor | AudioToMelSpectrogramPreprocessor | 0 | eval\n", + "1 | encoder | ConformerEncoderAdapter | 122 M | eval\n", + "2 | decoder | ConvASRDecoder | 66.2 K | eval\n", + "3 | loss | CTCLoss | 0 | eval\n", + "4 | spec_augmentation | SpectrogramAugmentation | 0 | eval\n", + "5 | wer | WER | 0 | eval\n", + "-------------------------------------------------------------------------------\n", + "1.2 M Trainable params\n", + "121 M Non-trainable params\n", + "122 M Total params\n", + "490.798 Total estimated model params size (MB)" + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.summarize()" + ] }, { "cell_type": "markdown", + "metadata": { + "id": "JI6C_TYGGgyZ" + }, "source": [ "-----\n", "\n", "Here we see that after the above steps, we will be training just ~ 200,000 parameters out of a 10+ M parameter model." 
- ], - "metadata": { - "id": "JI6C_TYGGgyZ" - } + ] }, { "cell_type": "code", + "execution_count": 31, + "metadata": { + "id": "w9ciIw-2bSHq" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-07-08 10:06:22 exp_manager:396] Experiments will be logged at experiments/ASR-Adapters/2024-07-08_10-06-22\n", + "[NeMo I 2024-07-08 10:06:22 exp_manager:856] TensorboardLogger has been set up\n", + "[NeMo I 2024-07-08 10:06:22 exp_manager:871] WandBLogger has been set up\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2024-07-08 10:06:22 exp_manager:966] The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to 300. Please ensure that max_steps will run for at least 3 epochs to ensure that checkpointing will not error out.\n" + ] + } + ], "source": [ "# Prepare NeMo's Experiment manager to handle checkpoint saving and logging for us\n", "from nemo.utils import exp_manager\n", "\n", + "\n", "# Environment variable generally used for multi-node multi-gpu training.\n", "# In notebook environments, this flag is unnecessary and can cause logs of multiple training runs to overwrite each other.\n", "os.environ.pop('NEMO_EXPM_VERSION', None)\n", @@ -1052,85 +1522,929 @@ " always_save_nemo=True,\n", " save_best_model=True,\n", " ),\n", + " create_wandb_logger=True,\n", + " wandb_logger_kwargs=OmegaConf.create({\"project\": \"NEMO_TEST\", \"name\": \"ASR-Adapters\", \"log_model\":\"all\"}),\n", ")\n", "\n", "exp_config = OmegaConf.structured(exp_config)\n", "\n", "logdir = exp_manager.exp_manager(trainer, exp_config)" - ], - "metadata": { - "id": "w9ciIw-2bSHq" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", - "source": [ - "# Finally, train the adapters\n", - "trainer.fit(model)" - ], + "execution_count": 33, "metadata": { "id": "cY2TJod3ZfyE" }, - "execution_count": null, - "outputs": [] + "outputs": [ + { + "name": "stderr", + 
"output_type": "stream", + "text": [ + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2024-07-07 13:27:08 modelPT:767] Optimizer config = AdamW (\n", + " Parameter Group 0\n", + " amsgrad: False\n", + " betas: [0.9, 0.98]\n", + " capturable: False\n", + " differentiable: False\n", + " eps: 1e-08\n", + " foreach: None\n", + " fused: None\n", + " lr: 0.1\n", + " maximize: False\n", + " weight_decay: 0.0\n", + " )\n", + "[NeMo I 2024-07-07 13:27:08 lr_scheduler:923] Scheduler \"\" \n", + " will be used during training (effective maximum steps = 300) - \n", + " Parameters : \n", + " (d_model: 176\n", + " warmup_steps: 100\n", + " warmup_ratio: null\n", + " min_lr: 1.0e-06\n", + " max_steps: 300\n", + " )\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + " | Name | Type | Params | Mode\n", + "-------------------------------------------------------------------------------\n", + "0 | preprocessor | AudioToMelSpectrogramPreprocessor | 0 | eval\n", + "1 | encoder | ConformerEncoderAdapter | 13.2 M | eval\n", + "2 | decoder | ConvASRDecoder | 181 K | eval\n", + "3 | loss | CTCLoss | 0 | eval\n", + "4 | spec_augmentation | SpectrogramAugmentation | 0 | eval\n", + "5 | wer | WER | 0 | eval\n", + "-------------------------------------------------------------------------------\n", + "185 K Trainable params\n", + "13.2 M Non-trainable params\n", + "13.3 M Total params\n", + "53.360 Total estimated model params size (MB)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b9415ef40e0b4ef581d18cb559be9768", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? 
[00:00: store_path=/content/unadapted_predictions.json, local_path=/content/unadapted_predictions.json", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m/workspace/nemo/NeMo-opensource/nemo/collections/asr/parts/utils/manifest_utils.py:477\u001b[0m, in \u001b[0;36mread_manifest\u001b[0;34m(manifest)\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 477\u001b[0m f \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mmanifest\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mr\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mutf-8\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m:\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/content/unadapted_predictions.json'", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mException\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[53], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m original_transcripts \u001b[38;5;241m=\u001b[39m \u001b[43mread_manifest\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m/content/unadapted_predictions.json\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m adapter_disabled_transcripts \u001b[38;5;241m=\u001b[39m 
read_manifest(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m/content/adapter_disabled_predictions.json\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m orig, new \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(original_transcripts, adapter_disabled_transcripts):\n", + "File \u001b[0;32m/workspace/nemo/NeMo-opensource/nemo/collections/asr/parts/utils/manifest_utils.py:479\u001b[0m, in \u001b[0;36mread_manifest\u001b[0;34m(manifest)\u001b[0m\n\u001b[1;32m 477\u001b[0m f \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mopen\u001b[39m(manifest\u001b[38;5;241m.\u001b[39mget(), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m, encoding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m:\n\u001b[0;32m--> 479\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mManifest file could not be opened: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmanifest\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 481\u001b[0m errors \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 482\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m line \u001b[38;5;129;01min\u001b[39;00m f\u001b[38;5;241m.\u001b[39mreadlines():\n", + "\u001b[0;31mException\u001b[0m: Manifest file could not be opened: : store_path=/content/unadapted_predictions.json, local_path=/content/unadapted_predictions.json" + ] + } + ], "source": [ "original_transcripts = read_manifest('/content/unadapted_predictions.json')\n", "adapter_disabled_transcripts = read_manifest('/content/adapter_disabled_predictions.json')\n", @@ -1244,39 +2901,39 @@ " print(\"Original = \", orig['pred_text'])\n", " print(\"Adapters disabled = \", new['pred_text']) \n", " print()" - ], - "metadata": { - "id": 
"YFKN7QYuvBzP" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "0CqpJGR6ecYW" + }, "source": [ "# [EXTRA] Add as many adapters as needed\n", "\n", "Now that we have showcased how to utilize adapters for domain adaptation, we can take this further and adapt even more datasets - as many as needed!\n", "\n", "There is no implicit restriction on how many adapters can be added, as shown below. Still, we do recommend freezing all adapters and training only one at a time to prevent cross-interaction between adapters." - ], - "metadata": { - "id": "0CqpJGR6ecYW" - } + ] }, { "cell_type": "code", - "source": [ - "model.add_adapter(name=\"AN4-v2\", cfg=adapter_cfg)" - ], + "execution_count": null, "metadata": { "id": "13vZHFFEeK_g" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "model.add_adapter(name=\"AN4-v2\", cfg=adapter_cfg)" + ] }, { "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "iOrJ72SUelp6" + }, + "outputs": [], "source": [ "model.set_enabled_adapters(enabled=False)\n", "model.set_enabled_adapters(name='AN4-v2', enabled=True)\n", @@ -1285,15 +2942,13 @@ "model.unfreeze_enabled_adapters()\n", "\n", "model.summarize()" - ], - "metadata": { - "id": "iOrJ72SUelp6" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "EIli6c_OvKDH" + }, "source": [ "# Further reading\n", "\n", @@ -1302,10 +2957,38 @@ "Please follow the following articles that discuss the use of adapters in ASR - \n", "- [Exploiting Adapters for Cross-lingual Low-resource Speech Recognition](https://arxiv.org/abs/2105.11905)\n", "- [Efficient Adapter Transfer of Self-Supervised Speech Models for Automatic Speech Recognition](https://arxiv.org/abs/2202.03218)\n" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [ + "dm-qqTdZDUlZ", + "GGKgsW5gvAuf", + "0CqpJGR6ecYW" ], - "metadata": { - "id": 
"EIli6c_OvKDH" - } + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" } - ] + }, + "nbformat": 4, + "nbformat_minor": 0 }