import torch
import numpy as np
from itertools import groupby
from malaya_speech.model.frame import Frame
from malaya_speech.utils.astype import int_to_float
from malaya_speech.utils.padding import sequence_1d
from malaya_speech.utils.char import HF_CTC_VOCAB, HF_CTC_VOCAB_IDX
from malaya_speech.utils.char import decode as char_decode
from malaya_speech.utils.read import resample
from malaya_speech.utils.activation import softmax
from malaya_speech.utils.aligner import (
    get_trellis,
    backtrack,
    merge_repeats,
    merge_words,
)
from malaya_speech.utils.subword import merge_bpe_tokens
from malaya_speech.model.abstract import Abstract
from malaya_boilerplate.torch_utils import to_tensor_cuda, to_numpy
from scipy.special import log_softmax
from typing import Callable
import logging

logger = logging.getLogger(__name__)
whisper_available = False
try:
    import whisper

    whisper_available = True
except Exception as e:
    logger.warning(
        '`openai-whisper` is not installed, so the native whisper processor is not available; '
        'will use the huggingface processor instead.')


def batching(audios):
    batch, lens = sequence_1d(audios, return_len=True)
    attentions = [[1] * l for l in lens]
    attentions = sequence_1d(attentions)
    normed_input_values = []
    for vector, length in zip(batch, attentions.sum(-1)):
        # zero-mean, unit-variance normalization over the real (unpadded) samples only
        normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
        if length < normed_slice.shape[0]:
            normed_slice[length:] = 0.0
        normed_input_values.append(normed_slice)
    normed_input_values = np.array(normed_input_values)
    return normed_input_values.astype(np.float32), attentions
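
# A minimal usage sketch for `batching` (the arrays `audio_a` and `audio_b`
# are hypothetical mono float waveforms at 16 kHz):
#
#     batch, attentions = batching([audio_a, audio_b])
#     # batch: float32 [2, max_len], zero-mean/unit-variance per utterance
#     # with the padded tail zeroed; attentions: [2, max_len], where 1 marks
#     # real samples, matching the Wav2Vec2 attention-mask convention.
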

class CTC(torch.nn.Module):
    def __init__(self, hf_model, model, name):
        super().__init__()
        self.hf_model = hf_model
        self.__model__ = model
        self.__name__ = name

    def greedy_decoder(self, inputs):
        """
        Transcribe inputs using greedy decoder.

        Parameters
        ----------
        inputs: List[np.array]
            List[np.array] or List[malaya_speech.model.frame.Frame].

        Returns
        -------
        result: List[str]
        """
        logits = self.predict_logits(inputs=inputs)
        argmax = np.argmax(logits, axis=-1)
        results = []
        for i in range(len(argmax)):
            tokens = char_decode(argmax[i], lookup=HF_CTC_VOCAB + ['_'])
            # collapse repeated frame labels, then drop the CTC blank '_'
            grouped_tokens = [token_group[0] for token_group in groupby(tokens)]
            filtered_tokens = list(filter(lambda token: token != '_', grouped_tokens))
            r = ''.join(filtered_tokens).strip()
            results.append(r)
        return results

    def predict(self, inputs):
        """
        Transcribe inputs using greedy decoder.

        Parameters
        ----------
        inputs: List[np.array]
            List[np.array] or List[malaya_speech.model.frame.Frame].

        Returns
        -------
        result: List[str]
        """
        return self.greedy_decoder(inputs=inputs)

    def predict_logits(self, inputs, norm_func=softmax):
        """
        Predict logits from inputs.

        Parameters
        ----------
        inputs: List[np.array]
            List[np.array] or List[malaya_speech.model.frame.Frame].
        norm_func: Callable, optional (default=malaya_speech.utils.activation.softmax)

        Returns
        -------
        result: List[np.array]
        """
        inputs = [
            input.array if isinstance(input, Frame) else input
            for input in inputs
        ]
        cuda = next(self.hf_model.parameters()).is_cuda
        normed_input_values, attentions = batching(inputs)
        normed_input_values = to_tensor_cuda(torch.tensor(normed_input_values), cuda)
        attentions = to_tensor_cuda(torch.tensor(attentions), cuda)
        out = self.hf_model(normed_input_values, attention_mask=attentions)
        return norm_func(to_numpy(out[0]), axis=-1)

    def gradio(self, record_mode: bool = True,
               lm_func: Callable = None,
               **kwargs):
        """
        Transcribe an input on a Gradio interface using the greedy decoder,
        or `lm_func` if provided.

        Parameters
        ----------
        record_mode: bool, optional (default=True)
            if True, Gradio will use record mode, else, file upload mode.
        lm_func: Callable, optional (default=None)
            if not None, the logits with shape [T, D] will be passed to `lm_func`.
        **kwargs: keyword arguments for `iface.launch`.
        """
        try:
            import gradio as gr
        except BaseException:
            raise ModuleNotFoundError(
                'gradio not installed. Please install it by `pip install gradio` and try again.'
            )

        def pred(audio):
            sample_rate, data = audio
            if len(data.shape) == 2:
                data = np.mean(data, axis=1)
            data = int_to_float(data)
            data = resample(data, sample_rate, 16000)
            if lm_func is not None:
                logits = self.predict_logits(inputs=[data])[0]
                return lm_func(logits)
            else:
                return self.greedy_decoder(inputs=[data])[0]

        title = 'HuggingFace-Wav2Vec2-STT'
        if lm_func is not None:
            title = f'{title} with LM'
        description = 'The first transcription will take some time; subsequent runs should be much faster.'
        if record_mode:
            input = 'microphone'
        else:
            input = 'audio'
        iface = gr.Interface(pred, input, 'text', title=title, description=description)
        return iface.launch(**kwargs)

    def __call__(self, input):
        """
        Transcribe input using greedy decoder.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.

        Returns
        -------
        result: str
        """
        return self.predict([input])[0]
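
# A minimal usage sketch for `CTC.greedy_decoder` (the loader path and the
# checkpoint name below are assumptions; check the malaya-speech STT docs):
#
#     import malaya_speech
#
#     model = malaya_speech.stt.ctc.huggingface(model='mesolitica/wav2vec2-xls-r-300m-mixed')
#     y, sr = malaya_speech.load('speech/example.wav')  # 16 kHz mono
#     print(model.greedy_decoder([y]))
#
# Greedy CTC decoding takes the argmax label per frame, collapses repeats,
# then drops the blank '_', e.g. 'hh_e_ll_llo' collapses to 'hello'.
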

class Aligner(torch.nn.Module):
    def __init__(self, hf_model, model, name):
        super().__init__()
        self.hf_model = hf_model
        self.__model__ = model
        self.__name__ = name

    def predict(self, input, transcription: str, sample_rate: int = 16000):
        """
        Force align transcription to the input audio.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.
        transcription: str
            transcription of input audio.
        sample_rate: int, optional (default=16000)
            sample rate for `input`.

        Returns
        -------
        result: Dict[chars_alignment, words_alignment, alignment]
        """
        input = input.array if isinstance(input, Frame) else input
        cuda = next(self.hf_model.parameters()).is_cuda
        normed_input_values, attentions = batching([input])
        normed_input_values = to_tensor_cuda(torch.tensor(normed_input_values), cuda)
        attentions = to_tensor_cuda(torch.tensor(attentions), cuda)
        out = self.hf_model(normed_input_values, attention_mask=attentions)
        logits = to_numpy(out[0])
        o = log_softmax(logits, axis=-1)[0]
        tokens = [HF_CTC_VOCAB_IDX[c] for c in transcription]
        trellis = get_trellis(o, tokens, blank_id=len(HF_CTC_VOCAB))
        path = backtrack(trellis, o, tokens, blank_id=len(HF_CTC_VOCAB))
        segments = merge_repeats(path, transcription)
        word_segments = merge_words(segments)
        # seconds of audio represented by one CTC output frame
        t = (len(input) / sample_rate) / o.shape[0]
        chars_alignment = []
        for s in segments:
            chars_alignment.append({'text': s.label,
                                    'start': s.start * t,
                                    'end': s.end * t,
                                    'start_t': s.start,
                                    'end_t': s.end,
                                    'score': s.score})
        words_alignment = []
        for s in word_segments:
            words_alignment.append({'text': s.label,
                                    'start': s.start * t,
                                    'end': s.end * t,
                                    'start_t': s.start,
                                    'end_t': s.end,
                                    'score': s.score})
        return {
            'chars_alignment': chars_alignment,
            'words_alignment': words_alignment,
            'alignment': trellis,
        }

    def __call__(self, input, transcription: str):
        """
        Force align transcription to the input audio.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.
        transcription: str
            transcription of input audio.

        Returns
        -------
        result: Dict[chars_alignment, words_alignment, alignment]
        """
        return self.predict(input, transcription)
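
# A worked example of the frame-to-seconds conversion in `Aligner.predict`
# (the numbers are illustrative, not from a real run): 5 s of audio at
# 16 kHz gives len(input) = 80000; if the model emits 249 output frames,
# t = (80000 / 16000) / 249 ≈ 0.0201 s per frame, so a character segment
# spanning frames 40..55 lands at roughly 0.80 s .. 1.10 s in the waveform.
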

class Seq2Seq(torch.nn.Module):
    def __init__(self, hf_model, processor, model, name, use_whisper_processor=False, **kwargs):
        super().__init__()
        self.hf_model = hf_model
        self.processor = processor
        self.__model__ = model
        self.__name__ = name
        if use_whisper_processor:
            if 'whisper' not in model.lower():
                logger.warning(
                    '`use_whisper_processor` is only available for whisper models, will fallback to huggingface processor.')
                use_whisper_processor = False
            if not whisper_available:
                logger.warning(
                    'openai-whisper not installed. Please install it by `pip install openai-whisper` and try again. Will fallback to huggingface processor.')
                use_whisper_processor = False
        self.use_whisper_processor = use_whisper_processor

    def generate(self, inputs, skip_special_tokens: bool = True, **kwargs):
        """
        Transcribe inputs.

        Parameters
        ----------
        inputs: List[np.array]
            List[np.array] or List[malaya_speech.model.frame.Frame].
        skip_special_tokens: bool, optional (default=True)
            skip special tokens during decoding.
        **kwargs: keyword arguments passed to the huggingface `generate` method.
            Read more at https://huggingface.co/docs/transformers/main_classes/text_generation

        Returns
        -------
        result: List[str]
        """
        inputs = [
            input.array if isinstance(input, Frame) else input
            for input in inputs
        ]
        cuda = next(self.hf_model.parameters()).is_cuda
        if self.use_whisper_processor:
            mels = []
            for k in range(len(inputs)):
                audio = whisper.pad_or_trim(inputs[k].astype(np.float32))
                mel = whisper.log_mel_spectrogram(audio)
                mels.append({'input_features': mel})
            batch = self.processor.feature_extractor.pad(mels, return_tensors='pt')
            input_features = batch.input_features
        else:
            input_features = self.processor(
                inputs, return_tensors='pt', sampling_rate=16000).input_features
        input_features = to_tensor_cuda(input_features, cuda)
        outputs = self.hf_model.generate(input_features, **kwargs)
        # `predict_logits` needs the raw generation output, not decoded strings
        if kwargs.get('return_dict_in_generate', False):
            return outputs
        return self.processor.tokenizer.batch_decode(
            outputs, skip_special_tokens=skip_special_tokens)

    def predict_logits(self, inputs, norm_func=softmax, **kwargs):
        """
        Predict logits from inputs.

        Parameters
        ----------
        inputs: List[np.array]
            List[np.array] or List[malaya_speech.model.frame.Frame].
        norm_func: Callable, optional (default=malaya_speech.utils.activation.softmax)

        Returns
        -------
        result: List[np.array]
        """
        # num_beams=1 is greedy decoding; only actual beam search is rejected
        if kwargs.get('num_beams', 0) > 1:
            raise ValueError('beam decoding is not supported.')

        outputs = self.generate(
            inputs=inputs,
            output_attentions=True,
            output_hidden_states=True,
            output_scores=True,
            return_dict_in_generate=True,
            **kwargs,
        )
        # `outputs.scores` is a tuple of [B, V] tensors, one per decoding step
        stacked = torch.stack(outputs.scores)
        return norm_func(to_numpy(stacked), axis=-1)

    def __call__(self, input, **kwargs):
        """
        Transcribe input.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.

        Returns
        -------
        result: str
        """
        return self.generate([input], **kwargs)[0]
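
# A minimal usage sketch for `Seq2Seq.generate` (the loader path and the
# checkpoint name below are assumptions; check the malaya-speech STT docs):
#
#     import malaya_speech
#
#     model = malaya_speech.stt.seq2seq.huggingface(model='mesolitica/finetune-whisper-base-ms-singlish-v2')
#     y, sr = malaya_speech.load('speech/example.wav')  # 16 kHz mono
#     print(model.generate([y], max_length=256))
#
# Any HuggingFace generation argument can be forwarded, e.g.
# `do_sample=True, temperature=0.7` to sample instead of decoding greedily.
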

class Seq2SeqAligner(torch.nn.Module):
    def __init__(self, hf_model, processor, model, name, **kwargs):
        super().__init__()
        self.hf_model = hf_model
        self.processor = processor
        self.tokenizer = self.processor.tokenizer
        self.__model__ = model
        self.__name__ = name
        # Whisper's encoder halves the mel frame rate with a stride-2 convolution,
        # so each cross-attention column covers `hop_length * 2` audio samples.
        self.AUDIO_SAMPLES_PER_TOKEN = processor.feature_extractor.hop_length * 2
        self.AUDIO_TIME_PER_TOKEN = self.AUDIO_SAMPLES_PER_TOKEN / processor.feature_extractor.sampling_rate

    def predict(
        self,
        input,
        transcription: str,
        lang: str = 'ms',
        median_filter_size: int = 7,
    ):
        """
        Force align transcription to the input audio using cross-attention weights and DTW.
        Based on https://github.com/openai/whisper/blob/main/notebooks/Multilingual_ASR.ipynb

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.
        transcription: str
            transcription of input audio.
        lang: str, optional (default='ms')
            if you feed Singlish speech, it is better to set the language to `en`.
        median_filter_size: int, optional (default=7)
            sliding median size.

        Returns
        -------
        result: Dict[subwords_alignment, words_alignment, alignment]
        """
        try:
            from dtw import dtw
            from scipy.signal import medfilt
        except Exception as e:
            raise ModuleNotFoundError(
                'dtw-python not installed. Please install it by `pip install dtw-python` and try again.'
            )

        input = input.array if isinstance(input, Frame) else input
        cuda = next(self.hf_model.parameters()).is_cuda
        input_features = self.processor([input], return_tensors='pt').input_features
        input_features = to_tensor_cuda(input_features, cuda)
        label = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(
            f'<|startoftranscript|><|{lang}|><|transcribe|><|notimestamps|>{transcription}<|endoftext|>'))
        labels = self.tokenizer.pad([{'input_ids': label}], return_tensors='pt')
        with torch.no_grad():
            o = self.hf_model(
                input_features=input_features,
                labels=labels['input_ids'],
                output_attentions=True,
                return_dict=True,
            )
        duration = len(input)
        # stack cross-attentions from every decoder layer and keep only the
        # encoder frames that cover the actual audio duration
        weights = torch.cat(o['cross_attentions'])
        weights = weights[:, :, :, : duration // self.AUDIO_SAMPLES_PER_TOKEN].cpu()
        weights = medfilt(weights, (1, 1, 1, median_filter_size))
        weights = torch.tensor(weights).softmax(dim=-1)
        w = weights / weights.norm(dim=-2, keepdim=True)
        matrix = w.mean(axis=(0, 1))
        alignment = dtw(-matrix.double().numpy())
        xticks = np.arange(0, matrix.shape[1], 1 / self.AUDIO_TIME_PER_TOKEN)
        xticklabels = (xticks * self.AUDIO_TIME_PER_TOKEN).round().astype(np.int32)
        yticklabels = self.tokenizer.convert_ids_to_tokens(labels['input_ids'][0])
        yticks = np.arange(len(yticklabels))
        # True at position i means the DTW path advanced to a new token there
        jumps = np.pad(np.diff(alignment.index1s), (1, 0), constant_values=1).astype(bool)
        jump_times = alignment.index2s[jumps] * self.AUDIO_TIME_PER_TOKEN
        subwords_alignment = []
        for i in range(len(yticklabels)):
            d = {
                'text': yticklabels[i],
                'start': 0.0 if i == 0 else jump_times[i - 1],
                'end': jump_times[i],
            }
            subwords_alignment.append(d)
        merged_bpes = merge_bpe_tokens(
            zip(yticklabels, subwords_alignment), rejected=self.tokenizer.all_special_tokens)
        words_alignment = []
        for m in merged_bpes:
            if isinstance(m[1], list):
                start = m[1][0]['start']
                end = m[1][-1]['end']
            else:
                start = m[1]['start']
                end = m[1]['end']
            words_alignment.append({
                'text': m[0],
                'start': start,
                'end': end,
            })
        alignment_x = alignment.index2s
        alignment_y = alignment.index1s
        return {
            'subwords_alignment': subwords_alignment,
            'words_alignment': words_alignment,
            'alignment': to_numpy(matrix),
            'alignment_x': alignment_x,
            'alignment_y': alignment_y,
            'xticks': xticks,
            'xticklabels': xticklabels,
            'yticks': yticks,
            'yticklabels': yticklabels,
        }
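
# A worked example of the Whisper timing constants above: with
# hop_length = 160 and sampling_rate = 16000, AUDIO_SAMPLES_PER_TOKEN is
# 320 samples (the encoder halves the mel frame rate), so
# AUDIO_TIME_PER_TOKEN = 320 / 16000 = 0.02 s. Each column of the
# cross-attention matrix therefore covers 20 ms of audio, and DTW over
# -matrix finds the monotonic token-to-time path with maximum cumulative
# attention.
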

class XVector(torch.nn.Module):
    def __init__(self, hf_model, processor, model, name):
        super().__init__()
        self.hf_model = hf_model
        self.processor = processor
        self.__model__ = model
        self.__name__ = name

    def vectorize(self, inputs):
        """
        Vectorize inputs.

        Parameters
        ----------
        inputs: List[np.array]
            List[np.array] or List[malaya_speech.model.frame.Frame].

        Returns
        -------
        result: np.array
            L2-normalized embeddings with shape [B, D].
        """
        inputs = [
            input.array if isinstance(input, Frame) else input
            for input in inputs
        ]
        cuda = next(self.hf_model.parameters()).is_cuda
        inputs = self.processor(inputs, return_tensors='pt', sampling_rate=16000, padding=True)
        for k in inputs.keys():
            inputs[k] = to_tensor_cuda(inputs[k], cuda)
        embeddings = self.hf_model(**inputs).embeddings
        embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
        return to_numpy(embeddings)

    def forward(self, inputs):
        return self.vectorize(inputs)
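
# A minimal usage sketch for `XVector.vectorize` (the loader path and the
# checkpoint name below are assumptions; check the malaya-speech
# speaker-vector docs). Because the embeddings are L2-normalized, a dot
# product is the cosine similarity between speakers:
#
#     import malaya_speech
#     import numpy as np
#
#     model = malaya_speech.speaker_vector.huggingface(model='microsoft/wavlm-base-sv')
#     v = model.vectorize([speaker_a, speaker_b])  # hypothetical waveforms
#     similarity = np.dot(v[0], v[1])  # in [-1, 1]; higher means same speaker
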