Source code for malaya_speech.streaming.torchaudio

"""
https://pytorch.org/audio/stable/tutorials/online_asr_tutorial.html
"""
import collections
from datetime import datetime
from malaya_speech.utils.validator import check_pipeline
from malaya_speech.utils.torch_featurization import StreamReader, torchaudio_available
from malaya_speech.torch_model.torchaudio import Conformer
from malaya_speech.streaming import stream as base_stream
from functools import partial
import torch
import logging

logger = logging.getLogger(__name__)

if StreamReader is None:
    logger.warning(
        f'`torchaudio.io.StreamReader` is not available, `{__name__}` cannot be used.')


class ContextCacher:
    """Cache the end of input data and prepend the next input data with it.

    Args:
        segment_length (int): The size of the main segment.
            If the incoming segment is shorter, the segment is padded.
        context_length (int): The size of the context, cached and appended.
    """

    def __init__(self, segment_length: int, context_length: int):
        self.segment_length = segment_length
        self.context_length = context_length
        self.context = torch.zeros([context_length])

    def __call__(self, chunk: torch.Tensor):
        if chunk.size(0) < self.segment_length:
            chunk = torch.nn.functional.pad(chunk, (0, self.segment_length - chunk.size(0)))
        chunk_with_context = torch.cat((self.context, chunk))
        self.context = chunk[-self.context_length:]
        return chunk_with_context
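
# Example (not part of the original module): a minimal sketch of how
# `ContextCacher` behaves. The chunk sizes below are illustrative only;
# they match the module defaults (segment_length=2560, context_length=640).
#
#     cacher = ContextCacher(segment_length=2560, context_length=640)
#     out = cacher(torch.zeros([2560]))   # shape [3200]: 640 cached samples + 2560 new samples
#     out = cacher(torch.zeros([1000]))   # short chunks are right-padded to 2560 before concatenation
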
def _base_stream(
    src,
    format=None,
    option=None,
    buffer_size: int = 4096,
    sample_rate: int = 16000,
    segment_length: int = 2560,
):
    if StreamReader is None:
        raise ValueError(
            '`torchaudio.io.StreamReader` is not available, please make sure your ffmpeg is installed properly.')

    streamer = StreamReader(src=src, format=format, option=option, buffer_size=buffer_size)
    streamer.add_basic_audio_stream(frames_per_chunk=segment_length, sample_rate=sample_rate)
    logger.info(streamer.get_src_stream_info(0))
    return streamer.stream()


class Audio:
    def __init__(
        self,
        src,
        vad_model=None,
        format=None,
        option=None,
        buffer_size: int = 4096,
        sample_rate: int = 16000,
        segment_length: int = 2560,
        mode_utterence: bool = True,
        hard_utterence: bool = True,
        **kwargs,
    ):
        self.vad_model = vad_model
        self.stream_iterator = _base_stream(
            src=src,
            format=format,
            option=option,
            buffer_size=buffer_size,
            sample_rate=sample_rate,
            segment_length=segment_length,
        )
        self.segment_length = segment_length
        self.mode_utterence = mode_utterence
        self.hard_utterence = hard_utterence

    def destroy(self):
        pass

    def vad_collector(self, num_padding_frames=20, ratio=0.75):
        """
        Generator that yields series of consecutive audio frames comprising each utterance,
        separated by yielding a single None.
        Determines voice activity by the ratio of voiced frames inside the padding window.
        Uses a buffer to include `num_padding_frames` prior to being triggered.
        Example: (frame, ..., frame, None, frame, ..., frame, None, ...)
                  |---utterance---|       |---utterance---|
        """
        ring_buffer = collections.deque(maxlen=num_padding_frames)
        triggered = False

        for i, (chunk,) in enumerate(self.stream_iterator):
            frame = chunk[:, 0].numpy()

            if len(frame) != self.segment_length:
                continue

            if self.vad_model:
                try:
                    is_speech = self.vad_model(frame)
                    if isinstance(is_speech, dict):
                        is_speech = is_speech['vad']
                except Exception as e:
                    logger.debug(e)
                    is_speech = False
            else:
                is_speech = True

            logger.debug(is_speech)
            frame = (frame, i * self.segment_length)

            if self.mode_utterence:
                if not self.hard_utterence:
                    yield frame

                if not triggered:
                    ring_buffer.append((frame, is_speech))
                    num_voiced = len([f for f, speech in ring_buffer if speech])
                    if num_voiced > ratio * ring_buffer.maxlen:
                        triggered = True
                        if self.hard_utterence:
                            for f, s in ring_buffer:
                                yield f
                        ring_buffer.clear()
                else:
                    if self.hard_utterence:
                        yield frame
                    ring_buffer.append((frame, is_speech))
                    num_unvoiced = len(
                        [f for f, speech in ring_buffer if not speech]
                    )
                    if num_unvoiced > ratio * ring_buffer.maxlen:
                        triggered = False
                        yield None
                        ring_buffer.clear()
            else:
                yield frame
                ring_buffer.append((frame, is_speech))
                num_unvoiced = len(
                    [f for f, speech in ring_buffer if not speech]
                )
                if num_unvoiced > ratio * ring_buffer.maxlen:
                    yield None
                    ring_buffer.clear()
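
# Example (not part of the original module): a hedged sketch of driving
# `Audio.vad_collector` directly. The source path and `vad_pipeline` are
# assumptions; any callable that returns a boolean (or a dict with a 'vad'
# key) per frame can serve as `vad_model`.
#
#     audio = Audio(src='speech/example.wav', vad_model=vad_pipeline)
#     utterance = []
#     for frame in audio.vad_collector():
#         if frame is None:                # None marks the end of an utterance
#             utterance = []
#         else:
#             samples, offset = frame      # frame is (numpy samples, sample offset)
#             utterance.append(samples)
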
def stream(
    src,
    vad_model=None,
    asr_model=None,
    classification_model=None,
    format=None,
    option=None,
    buffer_size: int = 4096,
    sample_rate: int = 16000,
    segment_length: int = 2560,
    num_padding_frames: int = 20,
    ratio: float = 0.75,
    min_length: float = 0.1,
    max_length: float = 10.0,
    realtime_print: bool = True,
    **kwargs,
):
    """
    Stream audio using the torchaudio library.

    Parameters
    ----------
    src: str
        Supported `src` for `torchaudio.io.StreamReader`,
        https://pytorch.org/audio/stable/tutorials/streamreader_basic_tutorial.html#sphx-glr-tutorials-streamreader-basic-tutorial-py
    vad_model: object, optional (default=None)
        VAD model / pipeline.
    asr_model: object, optional (default=None)
        ASR model / pipeline, will transcribe each subsample in realtime.
    classification_model: object, optional (default=None)
        classification pipeline, will classify each subsample in realtime.
    format: str, optional (default=None)
        Supported `format` for `torchaudio.io.StreamReader`,
        https://pytorch.org/audio/stable/generated/torchaudio.io.StreamReader.html#torchaudio.io.StreamReader
    option: dict, optional (default=None)
        Supported `option` for `torchaudio.io.StreamReader`,
        https://pytorch.org/audio/stable/generated/torchaudio.io.StreamReader.html#torchaudio.io.StreamReader
    buffer_size: int, optional (default=4096)
        Supported `buffer_size` for `torchaudio.io.StreamReader`, buffer size in bytes.
        Used only when `src` is a file-like object,
        https://pytorch.org/audio/stable/generated/torchaudio.io.StreamReader.html#torchaudio.io.StreamReader
    sample_rate: int, optional (default=16000)
        output sample rate.
    segment_length: int, optional (default=2560)
        usually derived from `asr_model.segment_length * asr_model.hop_length`,
        size of audio chunks; actual size in seconds is `segment_length / sample_rate`.
    num_padding_frames: int, optional (default=20)
        size of acceptable padding frames for the queue.
    ratio: float, optional (default=0.75)
        if 75% of the queue is positive, assume it is voice activity.
    min_length: float, optional (default=0.1)
        minimum length (seconds) to accept a subsample.
    max_length: float, optional (default=10.0)
        maximum length (seconds) to accept a subsample.
    realtime_print: bool, optional (default=True)
        Will print results for ASR.
    **kwargs: keyword arguments
        passed to the `malaya_speech.streaming.pyaudio.Audio` interface.

    Returns
    -------
    result : List[dict]
    """
    return base_stream(
        audio_class=partial(Audio, src=src, format=format, option=option, buffer_size=buffer_size),
        vad_model=vad_model,
        asr_model=asr_model,
        classification_model=classification_model,
        sample_rate=sample_rate,
        segment_length=segment_length,
        num_padding_frames=num_padding_frames,
        ratio=ratio,
        min_length=min_length,
        max_length=max_length,
        realtime_print=realtime_print,
        **kwargs,
    )
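
# Example (not part of the original module): a hedged sketch of calling
# `stream` on a local file. `vad_pipeline` stands for any VAD model /
# pipeline built elsewhere with malaya_speech; the path is illustrative.
#
#     results = stream(
#         src='speech/podcast/example.wav',
#         vad_model=vad_pipeline,
#     )
#     # every element of `results` is a dict with at least 'wav_data' and 'timestamp'
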
def stream_rnnt(
    src,
    asr_model=None,
    classification_model=None,
    format=None,
    option=None,
    beam_width: int = 10,
    buffer_size: int = 4096,
    sample_rate: int = 16000,
    segment_length: int = 2560,
    context_length: int = 640,
    realtime_print: bool = True,
    **kwargs,
):
    """
    Parameters
    ----------
    src: str
        Supported `src` for `torchaudio.io.StreamReader`.
        Read more at https://pytorch.org/audio/stable/tutorials/streamreader_basic_tutorial.html#sphx-glr-tutorials-streamreader-basic-tutorial-py
        or https://pytorch.org/audio/stable/tutorials/streamreader_advanced_tutorial.html#sphx-glr-tutorials-streamreader-advanced-tutorial-py
    asr_model: object, optional (default=None)
        ASR model / pipeline, will transcribe each subsample in realtime.
        Must be an instance of `malaya_speech.torch_model.torchaudio.Conformer`.
    classification_model: object, optional (default=None)
        classification pipeline, will classify each subsample in realtime.
    format: str, optional (default=None)
        Supported `format` for `torchaudio.io.StreamReader`,
        https://pytorch.org/audio/stable/generated/torchaudio.io.StreamReader.html#torchaudio.io.StreamReader
    option: dict, optional (default=None)
        Supported `option` for `torchaudio.io.StreamReader`,
        https://pytorch.org/audio/stable/generated/torchaudio.io.StreamReader.html#torchaudio.io.StreamReader
    buffer_size: int, optional (default=4096)
        Supported `buffer_size` for `torchaudio.io.StreamReader`, buffer size in bytes.
        Used only when `src` is a file-like object,
        https://pytorch.org/audio/stable/generated/torchaudio.io.StreamReader.html#torchaudio.io.StreamReader
    sample_rate: int, optional (default=16000)
        sample rate of the input; the stream will be resampled automatically.
    segment_length: int, optional (default=2560)
        usually derived from `asr_model.segment_length * asr_model.hop_length`,
        size of audio chunks; actual size in seconds is `segment_length / sample_rate`.
    context_length: int, optional (default=640)
        usually derived from `asr_model.right_context_length * asr_model.hop_length`,
        size of appended context chunks, only useful for streaming RNNT.
    beam_width: int, optional (default=10)
        width for beam decoding.
    realtime_print: bool, optional (default=True)
        Will print results for ASR.
    """
    if not isinstance(asr_model, Conformer):
        raise ValueError('`asr_model` only supports Emformer RNNT.')

    if not getattr(asr_model, 'rnnt_streaming', False):
        raise ValueError('`asr_model` only supports Emformer RNNT.')

    if classification_model:
        check_pipeline(
            classification_model, 'classification', 'classification_model'
        )

    if asr_model.feature_extractor.pad:
        asr_model.feature_extractor.pad = False

    stream_iterator = _base_stream(
        src=src,
        format=format,
        option=option,
        buffer_size=buffer_size,
        sample_rate=sample_rate,
    )

    cacher = ContextCacher(segment_length, context_length)

    @torch.inference_mode()
    def run_inference(state=None, hypothesis=None):
        results = []
        try:
            for i, (chunk,) in enumerate(stream_iterator, start=1):
                audio = chunk[:, 0]
                wav_data = {
                    'wav_data': audio.numpy(),
                    'timestamp': datetime.now(),
                }
                segment = cacher(audio)
                features, length = asr_model.feature_extractor(segment)
                hypos, state = asr_model.decoder.infer(
                    features, length, beam_width, state=state, hypothesis=hypothesis)
                hypothesis = hypos[0]
                transcript = asr_model.tokenizer(hypothesis[0], lstrip=False)
                wav_data['asr_model'] = transcript

                if len(transcript.strip()) and classification_model:
                    t_ = classification_model(wav_data['wav_data'])
                    if isinstance(t_, dict):
                        t_ = t_['classification']
                    wav_data['classification_model'] = t_

                if realtime_print:
                    print(transcript, end='', flush=True)

                results.append(wav_data)

        except KeyboardInterrupt:
            pass

        except Exception as e:
            raise e

        return results

    return run_inference()