Source code for malaya_speech.model.huggingface

import tensorflow as tf
import numpy as np
from itertools import groupby
from malaya_speech.model.frame import Frame
from malaya_speech.utils.astype import int_to_float
from malaya_speech.utils.padding import sequence_1d
from malaya_speech.utils.char import HF_CTC_VOCAB, HF_CTC_VOCAB_IDX
from malaya_speech.utils.char import decode as char_decode
from malaya_speech.utils.read import resample
from malaya_speech.utils.activation import softmax
from malaya_speech.utils.aligner import (
    get_trellis,
    backtrack,
    merge_repeats,
    merge_words,
)
from malaya_speech.model.abstract import Abstract
from scipy.special import log_softmax
from typing import Callable


def batching(audios):
    # pad every waveform to the length of the longest one in the batch
    batch, lens = sequence_1d(audios, return_len=True)
    attentions = [[1] * l for l in lens]
    attentions = sequence_1d(attentions)
    normed_input_values = []

    # zero-mean, unit-variance normalisation over the valid (unpadded) samples only,
    # then zero out the padded tail
    for vector, length in zip(batch, attentions.sum(-1)):
        normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
        if length < normed_slice.shape[0]:
            normed_slice[length:] = 0.0

        normed_input_values.append(normed_slice)

    normed_input_values = np.array(normed_input_values)
    return normed_input_values.astype(np.float32), attentions
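
# A minimal usage sketch of `batching` (not from the source), assuming two mono
# float waveforms of different lengths: the shorter one is zero-padded, and each
# waveform is normalised over its valid samples only.
#
# >>> import numpy as np
# >>> audios = [np.random.uniform(-1, 1, 16000), np.random.uniform(-1, 1, 8000)]
# >>> normed, attentions = batching(audios)
# >>> normed.shape
# (2, 16000)
# >>> attentions[1][8000:].sum()
# 0.0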


class HuggingFace_CTC(Abstract):
    def __init__(self, hf_model, model, name):
        self.hf_model = hf_model
        self.__model__ = model
        self.__name__ = name
    def greedy_decoder(self, inputs):
        """
        Transcribe inputs using greedy decoder.

        Parameters
        ----------
        inputs: List[np.array]
            List[np.array] or List[malaya_speech.model.frame.Frame].

        Returns
        -------
        result: List[str]
        """
        logits = self.predict_logits(inputs=inputs)
        argmax = np.argmax(logits, axis=-1)
        results = []
        for i in range(len(argmax)):
            tokens = char_decode(argmax[i], lookup=HF_CTC_VOCAB + ['_'])
            grouped_tokens = [token_group[0] for token_group in groupby(tokens)]
            filtered_tokens = list(filter(lambda token: token != '_', grouped_tokens))
            r = ''.join(filtered_tokens).strip()
            results.append(r)
        return results
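    # Worked example of the CTC collapse above (a sketch, not from the source):
    # with '_' as the blank token, a frame-level argmax such as
    # ['h', 'h', 'e', '_', 'l', 'l', '_', 'l', 'o'] is first collapsed by
    # `groupby` into ['h', 'e', '_', 'l', '_', 'l', 'o'], then the blanks are
    # filtered out, yielding the transcription 'hello'.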
    def predict(self, inputs):
        """
        Transcribe inputs using greedy decoder.

        Parameters
        ----------
        inputs: List[np.array]
            List[np.array] or List[malaya_speech.model.frame.Frame].

        Returns
        -------
        result: List[str]
        """
        return self.greedy_decoder(inputs=inputs)
    def predict_logits(self, inputs, norm_func=softmax):
        """
        Predict logits from inputs.

        Parameters
        ----------
        inputs: List[np.array]
            List[np.array] or List[malaya_speech.model.frame.Frame].
        norm_func: Callable, optional (default=malaya_speech.utils.activation.softmax)

        Returns
        -------
        result: List[np.array]
        """
        inputs = [
            input.array if isinstance(input, Frame) else input
            for input in inputs
        ]
        normed_input_values, attentions = batching(inputs)
        out = self.hf_model(normed_input_values, attention_mask=attentions)
        return norm_func(out[0].numpy(), axis=-1)
    def gradio(self, record_mode: bool = True, lm_func: Callable = None, **kwargs):
        """
        Transcribe an input on a Gradio interface using the greedy decoder
        (or `lm_func`, if provided).

        Parameters
        ----------
        record_mode: bool, optional (default=True)
            if True, Gradio will use record mode, else, file upload mode.
        lm_func: Callable, optional (default=None)
            if not None, the logits with shape [T, D] will be passed to `lm_func`
            and its output returned.
        **kwargs: keyword arguments for `iface.launch`.
        """
        try:
            import gradio as gr
        except BaseException:
            raise ModuleNotFoundError(
                'gradio not installed. Please install it by `pip install gradio` and try again.'
            )

        def pred(audio):
            sample_rate, data = audio
            if len(data.shape) == 2:
                data = np.mean(data, axis=1)
            data = int_to_float(data)
            data = resample(data, sample_rate, 16000)
            if lm_func is not None:
                logits = self.predict_logits(inputs=[data])[0]
                return lm_func(logits)
            else:
                return self.greedy_decoder(inputs=[data])[0]

        title = 'HuggingFace-Wav2Vec2-STT'
        if lm_func is not None:
            title = f'{title} with LM'

        description = 'It will take some time for the first run, after that, it should be really fast.'

        if record_mode:
            input = 'microphone'
        else:
            input = 'audio'

        iface = gr.Interface(pred, input, 'text', title=title, description=description)
        return iface.launch(**kwargs)
    def __call__(self, input):
        """
        Transcribe input using greedy decoder.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.

        Returns
        -------
        result: str
        """
        return self.predict([input])[0]
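
# A minimal usage sketch for `HuggingFace_CTC` (not from the source), assuming
# `hf_model` is a TensorFlow CTC model from `transformers` (e.g. `TFWav2Vec2ForCTC`)
# whose output vocabulary matches `HF_CTC_VOCAB`; the model id and audio path are
# placeholders.
#
# >>> import malaya_speech
# >>> from transformers import TFWav2Vec2ForCTC
# >>> hf_model = TFWav2Vec2ForCTC.from_pretrained('<huggingface-model-id>')
# >>> model = HuggingFace_CTC(hf_model=hf_model, model='<model-name>', name='speech-to-text')
# >>> y, sr = malaya_speech.load('<audio-16k.wav>')
# >>> model.predict([y])  # returns a list with one transcription string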
class HuggingFace_Aligner(Abstract):
    def __init__(self, hf_model, model, name):
        self.hf_model = hf_model
        self.__model__ = model
        self.__name__ = name
    def predict(self, input, transcription: str, sample_rate: int = 16000):
        """
        Force align `transcription` against `input` audio, will return character
        and word alignments.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.
        transcription: str
            transcription of input audio.
        sample_rate: int, optional (default=16000)
            sample rate for `input`.

        Returns
        -------
        result: Dict[chars_alignment, words_alignment, alignment]
        """
        input = input.array if isinstance(input, Frame) else input
        normed_input_values, attentions = batching([input])
        out = self.hf_model(normed_input_values, attention_mask=attentions)
        logits = out[0].numpy()
        o = log_softmax(logits, axis=-1)[0]
        tokens = [HF_CTC_VOCAB_IDX[c] for c in transcription]
        trellis = get_trellis(o, tokens, blank_id=len(HF_CTC_VOCAB))
        path = backtrack(trellis, o, tokens, blank_id=len(HF_CTC_VOCAB))
        segments = merge_repeats(path, transcription)
        word_segments = merge_words(segments)

        # seconds per logits frame, used to convert frame indices to timestamps
        t = (len(input) / sample_rate) / o.shape[0]

        chars_alignment = []
        for s in segments:
            chars_alignment.append({
                'text': s.label,
                'start': s.start * t,
                'end': s.end * t,
                'start_t': s.start,
                'end_t': s.end,
                'score': s.score,
            })

        words_alignment = []
        for s in word_segments:
            words_alignment.append({
                'text': s.label,
                'start': s.start * t,
                'end': s.end * t,
                'start_t': s.start,
                'end_t': s.end,
                'score': s.score,
            })

        return {
            'chars_alignment': chars_alignment,
            'words_alignment': words_alignment,
            'alignment': trellis,
        }
    def __call__(self, input, transcription: str):
        """
        Force align `transcription` against `input` audio, will return character
        and word alignments.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.
        transcription: str
            transcription of input audio.

        Returns
        -------
        result: Dict[chars_alignment, words_alignment, alignment]
        """
        return self.predict(input, transcription)
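
# A minimal usage sketch for `HuggingFace_Aligner` (not from the source), reusing
# the `hf_model` and audio `y` from the sketch above; the transcription is a
# placeholder. In the returned alignments, `start`/`end` are in seconds while
# `start_t`/`end_t` are frame indices over the logits.
#
# >>> aligner = HuggingFace_Aligner(hf_model=hf_model, model='<model-name>', name='force-alignment')
# >>> result = aligner(y, 'sebut perkataan angka')
# >>> sorted(result.keys())
# ['alignment', 'chars_alignment', 'words_alignment']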