Speech-to-Text CTC#

Encoder model + CTC loss

This tutorial is available as an IPython notebook at malaya-speech/example/stt-ctc-model.

This module is not language independent, so it is not safe to use on other languages. The pretrained models were trained on hyperlocal languages.

This is an application of malaya-speech Pipeline. Read more about malaya-speech Pipeline at malaya-speech/example/pipeline.

[1]:
import malaya_speech
import numpy as np
from malaya_speech import Pipeline

List available CTC model#

[2]:
malaya_speech.stt.available_ctc()
[2]:
Model                          Size (MB)  Quantized Size (MB)  WER       CER        WER-LM    CER-LM     Language
hubert-conformer-tiny          36.6       10.3                 0.335968  0.0882573  0.199227  0.0635223  [malay]
hubert-conformer               115        31.1                 0.238714  0.0608998  0.141479  0.0450751  [malay]
hubert-conformer-large         392        100                  0.220314  0.054927   0.128006  0.0385329  [malay]
hubert-conformer-large-3mixed  392        100                  0.241126  0.0787939  0.132761  0.057482   [malay, singlish, mandarin]
best-rq-conformer-tiny         36.6       10.3                 0.319291  0.078988   0.179582  0.055521   [malay]
best-rq-conformer              115        31.1                 0.253678  0.0658045  0.154206  0.0482278  [malay]
best-rq-conformer-large        392        100                  0.234651  0.0601605  0.130082  0.044521   [malay]
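
WER and CER in the table are word and character error rates, so lower is better. The sketch below shows how such rates are conventionally computed from edit distance. It is an illustration only, not the evaluation script behind the numbers above, and the reference sentence is a made-up example.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def wer(reference, hypothesis):
    """Word error rate: word-level edit distance divided by reference length."""
    ref_words = reference.split()
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)


def cer(reference, hypothesis):
    """Character error rate: character-level edit distance divided by reference length."""
    return edit_distance(reference, hypothesis) / len(reference)


# Hypothetical reference transcript, used only to illustrate the metrics.
reference = 'helo nama saya husein'
hypothesis = 'helo nama saya musin'

word_error_rate = wer(reference, hypothesis)
char_error_rate = cer(reference, hypothesis)
print(word_error_rate, char_error_rate)
```

One wrong word out of four gives a WER of 0.25, while the CER is much smaller because most characters of the wrong word are still correct.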

Google Speech-to-Text accuracy#

We tested malaya-speech models and Google Speech-to-Text on the same Malay dataset to compare them. See the notebook at benchmark-google-speech-malay-dataset.ipynb.

[3]:
malaya_speech.stt.google_accuracy
[3]:
{'malay': {'WER': 0.164775, 'CER': 0.059732},
 'singlish': {'WER': 0.4941349, 'CER': 0.3026296}}

Again, even though some models beat Google Speech-to-Text on CER, we should be skeptical of these scores: the test set and postprocessing might favour malaya-speech.
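
To make the comparison concrete, the sketch below hard-codes the Malay-only rows of the model table and the Google Malay scores, then lists which models score lower (better) than Google. It assumes the WER-LM and CER-LM columns are scores obtained with a language model, which the table does not state explicitly. The helper `models_beating_google` is a made-up name for illustration.

```python
# Malay-only rows copied from the available_ctc() table: (WER, CER, WER-LM, CER-LM).
models = {
    'hubert-conformer-tiny': (0.335968, 0.0882573, 0.199227, 0.0635223),
    'hubert-conformer': (0.238714, 0.0608998, 0.141479, 0.0450751),
    'hubert-conformer-large': (0.220314, 0.054927, 0.128006, 0.0385329),
    'best-rq-conformer-tiny': (0.319291, 0.078988, 0.179582, 0.055521),
    'best-rq-conformer': (0.253678, 0.0658045, 0.154206, 0.0482278),
    'best-rq-conformer-large': (0.234651, 0.0601605, 0.130082, 0.044521),
}

# Malay entry copied from malaya_speech.stt.google_accuracy.
google = {'WER': 0.164775, 'CER': 0.059732}


def models_beating_google(column, metric):
    """Return the sorted names of models whose score in `column` is below Google's `metric`."""
    return sorted(name for name, scores in models.items()
                  if scores[column] < google[metric])


wer_without_lm = models_beating_google(0, 'WER')
cer_without_lm = models_beating_google(1, 'CER')
wer_with_lm = models_beating_google(2, 'WER')

print('WER, no LM:', wer_without_lm)
print('CER, no LM:', cer_without_lm)
print('WER, with LM:', wer_with_lm)
```

Without a language model, no model beats Google on WER and only the large HuBERT model beats it on CER, which is one more reason to read the scores with caution.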

Load CTC model#

def deep_ctc(
    model: str = 'hubert-conformer', quantized: bool = False, **kwargs
):
    """
    Load Encoder-CTC ASR model.

    Parameters
    ----------
    model : str, optional (default='hubert-conformer')
        Model architecture supported. Allowed values:

        * ``'hubert-conformer-tiny'`` - Finetuned HuBERT Conformer TINY.
        * ``'hubert-conformer'`` - Finetuned HuBERT Conformer.
        * ``'hubert-conformer-large'`` - Finetuned HuBERT Conformer LARGE.
        * ``'hubert-conformer-large-3mixed'`` - Finetuned HuBERT Conformer LARGE for (Malay + Singlish + Mandarin) languages.
        * ``'best-rq-conformer-tiny'`` - Finetuned BEST-RQ Conformer TINY.
        * ``'best-rq-conformer'`` - Finetuned BEST-RQ Conformer.
        * ``'best-rq-conformer-large'`` - Finetuned BEST-RQ Conformer LARGE.


    quantized : bool, optional (default=False)
        If True, will load the 8-bit quantized model.
        A quantized model is not necessarily faster; it depends entirely on the machine.

    Returns
    -------
    result : malaya_speech.model.tf.Wav2Vec2_CTC class
    """
[4]:
model = malaya_speech.stt.deep_ctc(model = 'hubert-conformer-large')

Load Quantized deep model#

To load the 8-bit quantized model, simply pass quantized = True. The default is False.

Expect a slight accuracy drop from the quantized model. It is also not necessarily faster than the normal 32-bit float model; that depends entirely on the machine.

[5]:
quantized_model = malaya_speech.stt.deep_ctc(model = 'hubert-conformer-large', quantized = True)
WARNING:root:Load quantized model will cause accuracy drop.
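
Because the speed of a quantized model depends on the machine, it is worth measuring rather than assuming. The helper below is a minimal timing sketch; `average_wall_time` and the stand-in workload `dummy_decode` are hypothetical names introduced here so the example runs on its own.

```python
import time


def average_wall_time(fn, *args, repeats=3):
    """Return the mean wall-clock seconds of fn(*args) over `repeats` runs."""
    fn(*args)  # warm-up run, excluded from the measurement
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats


# Stand-in workload so the sketch is self-contained; it is not a real decoder.
def dummy_decode(batch):
    return [sum(samples) for samples in batch]


elapsed = average_wall_time(dummy_decode, [[0.1] * 1000] * 3)
print(f'{elapsed:.6f} s per call')
```

With the models loaded above, the same helper could be pointed at `model.greedy_decoder` and `quantized_model.greedy_decoder` with an identical list of inputs to compare the two on your own hardware.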

Load sample#

[6]:
ceramah, sr = malaya_speech.load('speech/khutbah/wadi-annuar.wav')
record1, sr = malaya_speech.load('speech/record/savewav_2020-11-26_22-36-06_294832.wav')
record2, sr = malaya_speech.load('speech/record/savewav_2020-11-26_22-40-56_929661.wav')
[8]:
import IPython.display as ipd

ipd.Audio(ceramah, rate = sr)

As we can hear, the speaker uses the Kedahan dialect with some Arabic words mixed in. Let's see how well our model handles it.

[9]:
ipd.Audio(record1, rate = sr)
[10]:
ipd.Audio(record2, rate = sr)

Predict using greedy decoder#

def greedy_decoder(self, inputs):
    """
    Transcribe inputs using greedy decoder.

    Parameters
    ----------
    inputs: List[np.array]
        List[np.array] or List[malaya_speech.model.frame.Frame].

    Returns
    -------
    result: List[str]
    """
[9]:
%%time

model.greedy_decoder([ceramah, record1, record2])
CPU times: user 16.5 s, sys: 5.54 s, total: 22 s
Wall time: 4.2 s
[9]:
['jadi dalam perjalanan ini dunia yang susah ini ketika nabi mengajar muaz bin jabal tadi ni alah maini',
 'helo nama saya esin saya tak suka mandi ketak saya masak',
 'helo nama saya musin saya suka mandi saya mandi titiap hari']
[10]:
%%time

quantized_model.greedy_decoder([ceramah, record1, record2])
CPU times: user 16.7 s, sys: 5.5 s, total: 22.2 s
Wall time: 4.15 s
[10]:
['jadi dalam perjalanan ini dunia yang susah ini ketika nabi mengajar muaz bin jabal tadi ni alah maini',
 'helo nama saya esin saya tak suka mandi ketak saya masak',
 'helo nama saya musin saya suka mandi saya mandi titiap hari']
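
Greedy CTC decoding conventionally takes the most likely class at each frame, merges consecutive repeats, and drops the blank symbol. The sketch below illustrates that rule on a toy vocabulary. The vocabulary, the blank position, and the function names are assumptions for illustration; this is not malaya-speech's own implementation.

```python
import numpy as np


def ctc_greedy_decode(probs, vocab, blank_index):
    """Collapse the per-frame argmax path: merge repeats, then drop blanks."""
    best_path = probs.argmax(axis=-1)
    chars = []
    previous = None
    for index in best_path:
        if index != previous and index != blank_index:
            chars.append(vocab[index])
        previous = index
    return ''.join(chars)


def frames_from_path(path, num_classes):
    """Build toy per-frame scores whose argmax follows `path`."""
    probs = np.full((len(path), num_classes), 0.01)
    probs[np.arange(len(path)), path] = 0.9
    return probs


# Toy vocabulary (hypothetical); the blank symbol is assumed to be the last class.
vocab = ['a', 'h', 'i']
blank = 3

# Frames: h h _ a a _ i, which collapses to "hai".
text = ctc_greedy_decode(frames_from_path([1, 1, 3, 0, 0, 3, 2], 4), vocab, blank)

# A blank between two identical symbols keeps both: a _ a gives "aa".
doubled = ctc_greedy_decode(frames_from_path([0, 3, 0], 4), vocab, blank)

print(text, doubled)
```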

Predict using beam decoder#

def beam_decoder(self, inputs, beam_width: int = 100):
    """
    Transcribe inputs using beam decoder.

    Parameters
    ----------
    inputs: List[np.array]
        List[np.array] or List[malaya_speech.model.frame.Frame].
    beam_width: int, optional (default=100)
        beam size for beam decoder.

    Returns
    -------
    result: List[str]
    """
[11]:
%%time

model.beam_decoder([ceramah, record1, record2])
CPU times: user 26.9 s, sys: 11.8 s, total: 38.7 s
Wall time: 21.9 s
[11]:
['jadi dalam perjalanan ini dunia yang susah ini ketika nabi mengajar muaz bin jabal tadi ni alah ma ini',
 'helo nama saya esin saya tak suka mandi ketak saya masak',
 'helo nama saya musin saya suka mandi saya mandi titiap hari']
[12]:
%%time

quantized_model.beam_decoder([ceramah, record1, record2])
CPU times: user 26.5 s, sys: 11 s, total: 37.5 s
Wall time: 19.3 s
[12]:
['jadi dalam perjalanan ini dunia yang susah ini ketika nabi mengajar muaz bin jabal tadi ni alah ma ini',
 'helo nama saya esin saya tak suka mandi ketak saya masak',
 'helo nama saya musin saya suka mandi saya mandi titiap hari']
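
A beam decoder keeps several candidate transcripts per frame instead of one, and sums the probability of every alignment that collapses to the same text. That extra bookkeeping is consistent with the longer wall times above. The following is a minimal sketch of CTC prefix beam search without a language model, assuming the blank symbol is the last class. It works in plain probability space for readability (a practical version would use log-probabilities) and is not the decoder malaya-speech uses internally.

```python
from collections import defaultdict

import numpy as np


def ctc_prefix_beam_search(probs, beam_width=100, blank_index=None):
    """Return the most probable label sequence under CTC for one utterance."""
    num_frames, num_classes = probs.shape
    if blank_index is None:
        blank_index = num_classes - 1  # assumption: blank is the last class

    # prefix -> (probability of ending in blank, probability of ending in non-blank)
    beams = {(): (1.0, 0.0)}
    for t in range(num_frames):
        candidates = defaultdict(lambda: [0.0, 0.0])
        for prefix, (p_blank, p_non_blank) in beams.items():
            total = p_blank + p_non_blank
            for c in range(num_classes):
                p = float(probs[t, c])
                if c == blank_index:
                    # A blank never changes the collapsed prefix.
                    candidates[prefix][0] += total * p
                elif prefix and prefix[-1] == c:
                    # A repeat without a blank in between merges into the same prefix.
                    candidates[prefix][1] += p_non_blank * p
                    # A repeat after a blank extends the prefix.
                    candidates[prefix + (c,)][1] += p_blank * p
                else:
                    candidates[prefix + (c,)][1] += total * p
        ranked = sorted(candidates.items(),
                        key=lambda item: item[1][0] + item[1][1],
                        reverse=True)
        beams = {prefix: (pb, pnb) for prefix, (pb, pnb) in ranked[:beam_width]}

    best_prefix = max(beams, key=lambda prefix: sum(beams[prefix]))
    return list(best_prefix)


# Two frames, classes: 0 = 'a', 1 = blank.
probs = np.array([[0.4, 0.6],
                  [0.4, 0.6]])
greedy_path = probs.argmax(axis=1).tolist()
best = ctc_prefix_beam_search(probs, beam_width=2)
print(greedy_path, best)
```

In this toy case the greedy path is blank in both frames, which decodes to an empty string with probability 0.36, while the alignments that collapse to a single 'a' sum to 0.64, so beam search returns `[0]`.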

Predict logits#

def predict_logits(self, inputs, norm_func=softmax):
    """
    Predict logits from inputs.

    Parameters
    ----------
    inputs: List[np.array]
        List[np.array] or List[malaya_speech.model.frame.Frame].
    norm_func: Callable, optional (default=malaya.utils.activation.softmax)


    Returns
    -------
    result: List[np.array]
    """
[13]:
%%time

logits = model.predict_logits([ceramah, record1, record2])
CPU times: user 28.1 s, sys: 11.9 s, total: 40 s
Wall time: 22.6 s
[16]:
logits, logits[0].shape, logits[1].shape
[16]:
([array([[1.8236330e-14, 1.0061867e-09, 9.1841962e-10, ..., 1.0257798e-12,
          1.4030079e-09, 2.5916536e-04],
         [8.9639464e-15, 1.7033805e-09, 4.6274323e-10, ..., 6.6502495e-13,
          8.0938090e-10, 4.3023058e-04],
         [3.1887525e-15, 2.8573975e-08, 1.4154197e-10, ..., 3.1815506e-13,
          2.4540896e-09, 1.0617726e-03],
         ...,
         [8.8733091e-16, 1.2685842e-08, 6.6173206e-11, ..., 3.7230733e-13,
          4.0759787e-12, 1.5587657e-04],
         [6.6120927e-16, 6.5794703e-09, 5.7139283e-11, ..., 4.9776345e-13,
          4.2690807e-12, 8.7887602e-05],
         [5.4730716e-16, 2.1361308e-09, 9.3025414e-11, ..., 5.9886417e-13,
          1.5507430e-11, 1.0300719e-04]], dtype=float32),
  array([[2.5148340e-15, 3.8227799e-09, 1.0598683e-09, ..., 4.1977169e-13,
          7.1061929e-10, 3.9750684e-04],
         [2.9357024e-14, 4.2996367e-11, 8.0736609e-09, ..., 3.3508845e-12,
          3.6292098e-11, 3.5416318e-07],
         [1.9910680e-15, 7.9035152e-09, 3.4416633e-10, ..., 1.3513049e-12,
          1.3454201e-09, 1.0257948e-04],
         ...,
         [3.6507369e-16, 3.2854697e-09, 3.0207747e-10, ..., 8.0089668e-13,
          1.2993007e-10, 1.5945344e-03],
         [4.2754438e-16, 6.5015402e-09, 2.0224432e-10, ..., 8.0650563e-13,
          1.6817953e-10, 1.4560440e-03],
         [3.1852618e-16, 6.3009313e-09, 1.1704311e-10, ..., 5.8428116e-13,
          2.6164873e-10, 1.3219280e-03]], dtype=float32),
  array([[7.4554222e-15, 8.0248808e-10, 4.3944155e-09, ..., 7.1492242e-13,
          1.8890089e-09, 8.0676808e-05],
         [2.0308339e-15, 2.7493043e-09, 9.9622988e-10, ..., 3.6468658e-13,
          4.4663002e-09, 4.2535696e-04],
         [9.8700091e-16, 5.8224430e-09, 9.5636510e-10, ..., 2.5640658e-13,
          6.0229244e-10, 5.0613942e-04],
         ...,
         [6.6229983e-16, 2.5211668e-09, 1.0523804e-10, ..., 8.1805699e-13,
          2.0467164e-10, 1.5298247e-03],
         [7.8388564e-16, 3.5194940e-09, 8.4871700e-11, ..., 6.7016152e-13,
          1.1333524e-10, 1.0733902e-03],
         [6.9106448e-16, 3.5260113e-09, 6.5353133e-11, ..., 4.5797551e-13,
          1.2999080e-10, 9.0871076e-04]], dtype=float32)],
 (499, 39),
 (299, 39))
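
Every value printed above lies between 0 and 1, which is what one would expect when the default norm_func, softmax, is applied over the class axis of each frame. The sketch below shows a numerically stable softmax in NumPy on random stand-in data shaped like the output (frames x 39 classes). It is an assumption about what such a function does, not the library's code.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Random stand-in for raw scores: 5 frames, 39 classes.
rng = np.random.default_rng(0)
raw = rng.normal(size=(5, 39)).astype(np.float32)

probs = softmax(raw)

# Some external decoders expect log-probabilities; clip to avoid log(0).
log_probs = np.log(np.clip(probs, 1e-15, 1.0))

print(probs.shape, probs.sum(axis=1))
```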

You can feed the output of predict_logits into ctc-decoders or pyctcdecode, together with a language model, to get better results.
