# Source code for malaya_speech.torch_model.torchaudio

import torch
from malaya_speech.model.frame import Frame
from malaya_speech.utils.bpe import load_sentencepiece, SentencePieceTokenProcessor
from malaya_speech.utils.torch_featurization import (
    FeatureExtractor,
    RNNTBeamSearch,
    post_process_hypos,
    conformer_rnnt_base,
    conformer_rnnt_tiny,
    conformer_rnnt_medium,
    emformer_rnnt_base,
)
from malaya_boilerplate.torch_utils import to_tensor_cuda, to_numpy

# Registry of supported checkpoints: maps the public model identifier to the
# zero-argument factory (imported from malaya_speech.utils.torch_featurization)
# that builds the matching RNN-T network. `Conformer.__init__` looks the
# requested model up here, so an identifier missing from this dict raises
# KeyError at construction time.
model_mapping = {
    'mesolitica/conformer-base': conformer_rnnt_base,
    'mesolitica/conformer-tiny': conformer_rnnt_tiny,
    'mesolitica/conformer-medium': conformer_rnnt_medium,
    'mesolitica/emformer-base': emformer_rnnt_base,
}


class Conformer(torch.nn.Module):
    """
    RNN-T speech-to-text model bundled with its tokenizer, feature extractor
    and beam-search decoder.

    The network is built from `model_mapping[model]`, its weights are loaded
    from `pth`, and the tokenizer / feature extractor are built from
    `sp_model` / `stats_file`. Calling the module (`forward`) transcribes a
    list of waveforms one by one with beam search.
    """

    # Audio / framing constants exposed as class attributes. They are not read
    # anywhere inside this class.
    # NOTE(review): presumably sample_rate is in Hz and the others describe the
    # streaming (emformer) framing - confirm units against the callers.
    sample_rate = 16000
    segment_length = 16
    hop_length = 160
    right_context_length = 4

    def __init__(self, pth, sp_model, stats_file, model, name):
        """
        Build the network, load its weights and set up decoding helpers.

        Parameters
        ----------
        pth:
            checkpoint passed to `torch.load`; it is always mapped to CPU first.
        sp_model:
            argument forwarded to `SentencePieceTokenProcessor`.
        stats_file:
            argument forwarded to `FeatureExtractor`.
        model: str
            key of `model_mapping` selecting the architecture. A name
            containing 'emformer' also enables feature padding and marks the
            model as streaming.
        name:
            stored as `self.__name__`.

        Raises
        ------
        KeyError
            if `model` is not a key of `model_mapping`.
        """
        super().__init__()

        # Evaluated once; the same flag drives feature padding and the
        # `rnnt_streaming` marker below.
        is_emformer = 'emformer' in model

        # Attribute assignment order is kept as-is: assigning modules to an
        # nn.Module registers them in order, which determines what
        # `next(self.parameters())` returns in `forward`.
        network = model_mapping[model]()
        network.load_state_dict(torch.load(pth, map_location='cpu'))
        self.model = network
        self.tokenizer = SentencePieceTokenProcessor(sp_model)
        self.feature_extractor = FeatureExtractor(stats_file, pad=is_emformer)
        # The blank symbol uses the index right after the last SentencePiece
        # piece, i.e. the vocabulary size.
        self.blank_idx = self.tokenizer.sp_model.get_piece_size()
        self.decoder = RNNTBeamSearch(self.model, self.blank_idx)

        self.__model__ = model
        self.__name__ = name
        self.rnnt_streaming = is_emformer

    def forward(self, inputs, beam_width: int = 20):
        """
        Transcribe inputs using beam decoder.

        Parameters
        ----------
        inputs: List[np.array]
            List[np.array] or List[malaya_speech.model.frame.Frame].
        beam_width: int, optional (default=20)
            beam size for beam decoder.

        Returns
        -------
        result: List[Tuple]
            one `post_process_hypos` output per input, in input order.
            An empty `inputs` yields an empty list.
        """
        # Device is inferred from the first registered parameter.
        cuda = next(self.parameters()).is_cuda

        # Unwrap Frame objects to their raw array; anything else is used as-is.
        waveforms = [
            item.array if isinstance(item, Frame) else item
            for item in inputs
        ]

        # NOTE(review): decoding runs with autograd enabled here; callers doing
        # pure inference may want to wrap the call in `torch.no_grad()`.
        results = []
        for waveform in waveforms:
            mel, mel_len = self.feature_extractor(waveform)
            # presumably moves the tensors to GPU when `cuda` is True - see
            # malaya_boilerplate.torch_utils.to_tensor_cuda.
            mel = to_tensor_cuda(mel, cuda)
            mel_len = to_tensor_cuda(mel_len, cuda)
            hypotheses = self.decoder(mel, mel_len, beam_width)
            results.append(
                post_process_hypos(hypotheses, self.tokenizer.sp_model)
            )
        return results

    def beam_decoder(self, inputs, beam_width: int = 20):
        """
        Transcribe inputs using beam decoder.

        Parameters
        ----------
        inputs: List[np.array]
            List[np.array] or List[malaya_speech.model.frame.Frame].
        beam_width: int, optional (default=20)
            beam size for beam decoder.

        Returns
        -------
        result: List[str]
            for every input, element `[0][0]` of the `forward` result, i.e.
            the first field of the first hypothesis.
        """
        decoded = self.forward(inputs=inputs, beam_width=beam_width)
        return [hypos[0][0] for hypos in decoded]