Source code for malaya_speech.utils.aligner

import numpy as np
from dataclasses import dataclass
from malaya_speech.utils.char import CTC_VOCAB as labels


def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0):
    from scipy.stats import betabinom

    x = np.arange(0, phoneme_count)
    mel_text_probs = []
    for i in range(1, mel_count + 1):
        a, b = scaling_factor * i, scaling_factor * (mel_count + 1 - i)
        mel_i_prob = betabinom(phoneme_count, a, b).pmf(x)
        mel_text_probs.append(mel_i_prob)
    return np.array(mel_text_probs)


def mas(attn_map, width=1):
    # assumes mel x text
    opt = np.zeros_like(attn_map)
    attn_map = np.log(attn_map)
    attn_map[0, 1:] = -np.inf
    log_p = np.zeros_like(attn_map)
    log_p[0, :] = attn_map[0, :]
    prev_ind = np.zeros_like(attn_map, dtype=np.int64)
    for i in range(1, attn_map.shape[0]):
        for j in range(attn_map.shape[1]):
            prev_j = np.arange(max(0, j - width), j + 1)
            prev_log = np.array([log_p[i - 1, prev_idx] for prev_idx in prev_j])

            ind = np.argmax(prev_log)
            log_p[i, j] = attn_map[i, j] + prev_log[ind]
            prev_ind[i, j] = prev_j[ind]

    # now backtrack
    curr_text_idx = attn_map.shape[1] - 1
    for i in range(attn_map.shape[0] - 1, -1, -1):
        opt[i, curr_text_idx] = 1
        curr_text_idx = prev_ind[i, curr_text_idx]
    opt[0, curr_text_idx] = 1

    assert opt.sum(0).all()
    assert opt.sum(1).all()

    return opt


def binarize_attention(attn, in_len, out_len):
    b_size = attn.shape[0]
    attn_cpu = attn
    attn_out = np.zeros_like(attn)
    for ind in range(b_size):
        hard_attn = mas(attn_cpu[ind, 0, : out_len[ind], : in_len[ind]])
        attn_out[ind, 0, : out_len[ind], : in_len[ind]] = hard_attn
    return attn_out


[docs]@dataclass class Point: token_index: int time_index: int score: float
[docs]@dataclass class Segment: label: str start: int end: int score: float def __repr__(self): return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" @property def length(self): return self.end - self.start
def get_trellis(emission, tokens, blank_id=len(labels)): num_frame = emission.shape[0] num_tokens = len(tokens) trellis = np.full((num_frame + 1, num_tokens + 1), -float('inf')) trellis[:, 0] = 0 for t in range(num_frame): trellis[t + 1, 1:] = np.maximum( trellis[t, 1:] + emission[t, blank_id], trellis[t, :-1] + emission[t, tokens], ) return trellis def backtrack(trellis, emission, tokens, blank_id=len(labels)): # Note: # j and t are indices for trellis, which has extra dimensions # for time and tokens at the beginning. # When referring to time frame index `T` in trellis, # the corresponding index in emission is `T-1`. # Similarly, when referring to token index `J` in trellis, # the corresponding index in transcript is `J-1`. j = trellis.shape[1] - 1 t_start = np.argmax(trellis[:, j]).item() path = [] for t in range(t_start, 0, -1): # 1. Figure out if the current position was stay or change # Note (again): # `emission[J-1]` is the emission at time frame `J` of trellis dimension. # Score for token staying the same from time frame J-1 to T. stayed = trellis[t - 1, j] + emission[t - 1, blank_id] # Score for token changing from C-1 at T-1 to J at T. changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]] # 2. Store the path with frame-wise probability. prob = np.exp(emission[t - 1, tokens[j - 1] if changed > stayed else 0]).item() # Return token index and time index in non-trellis coordinate. path.append(Point(j - 1, t - 1, prob)) # 3. Update the token if changed > stayed: j -= 1 if j == 0: break else: raise ValueError('Failed to align') return path[::-1] def merge_repeats(path, transcript): i1, i2 = 0, 0 segments = [] while i1 < len(path): while i2 < len(path) and path[i1].token_index == path[i2].token_index: i2 += 1 score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1) segments.append( Segment( transcript[path[i1].token_index], path[i1].time_index, path[i2 - 1].time_index + 1, score, ) ) i1 = i2 return segments def merge_words(segments, separator=' '): words = [] i1, i2 = 0, 0 while i1 < len(segments): if i2 >= len(segments) or segments[i2].label == separator: if i1 != i2: segs = segments[i1:i2] word = "".join([seg.label for seg in segs]) score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs) words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score)) i1 = i2 + 1 i2 = i1 else: i2 += 1 return words
[docs]def put_comma(alignment, min_threshold: float = 0.5): """ Put comma in alignment from force alignment model. Parameters ----------- alignment: List[Dict[text, start, end]] min_threshold: float, optional (default=0.5) minimum threshold in term of seconds to assume a comma. Returns -------- result: List[str] """ r = [] for no, row in enumerate(alignment): if no > 0: if alignment[no]['start'] - alignment[no-1]['end'] >= min_threshold: r.append(',') r.append(row['text']) r.append('.') return r
[docs]def plot_alignments( alignment, subs_alignment, words_alignment, waveform, separator: str = ' ', sample_rate: int = 16000, figsize: tuple = (16, 9), plot_score_char: bool = False, plot_score_word: bool = True, ): """ plot alignment. Parameters ---------- alignment: np.array usually `alignment` output. subs_alignment: list usually `chars_alignment` or `subwords_alignment` output. words_alignment: list usually `words_alignment` output. waveform: np.array input audio. separator: str, optional (default=' ') separator between words, only useful if `subs_alignment` is character based. sample_rate: int, optional (default=16000) figsize: tuple, optional (default=(16, 9)) figure size for matplotlib `figsize`. plot_score_char: bool, optional (default=False) plot score on top of character plots. plot_score_word: bool, optional (default=True) plot score on top of word plots. """ try: import seaborn as sns import matplotlib.pyplot as plt except BaseException: raise ValueError( 'seaborn and matplotlib not installed. Please install it by `pip install matplotlib seaborn` and try again.' ) trellis_with_path = alignment.copy() if trellis_with_path.shape[1] == len(subs_alignment): trellis_with_path = np.concatenate([trellis_with_path, np.zeros((trellis_with_path.shape[0], 2))], axis=1) for i, seg in enumerate(subs_alignment): if seg['text'] != separator: trellis_with_path[seg['start_t'] + 1: seg['end_t'] + 1, i + 1] = float('nan') fig, [ax1, ax2] = plt.subplots(2, 1, figsize=figsize) ax1.imshow(trellis_with_path[1:, 1:].T, origin='lower') ax1.set_xticks([]) ax1.set_yticks([]) for word in words_alignment: ax1.axvline(word['start_t'] - 0.5) ax1.axvline(word['end_t'] - 0.5) for i, seg in enumerate(subs_alignment): if seg['text'] != separator: ax1.annotate(seg['text'], (seg['start_t'], i + 0.3)) if plot_score_char: ax1.annotate(f"{seg['score']:.2f}", (seg['start_t'], i + 4), fontsize=8) ratio = waveform.shape[0] / (alignment.shape[0] - 1) ax2.plot(waveform) for word in words_alignment: x0 = ratio * word['start_t'] x1 = ratio * word['end_t'] ax2.axvspan(x0, x1, alpha=0.1, color='red') if plot_score_word: ax2.annotate(f"{word['score']:.2f}", (x0, 0.8)) for seg in subs_alignment: if seg['text'] != separator: ax2.annotate(seg['text'], (seg['start_t'] * ratio, 0.9)) xticks = ax2.get_xticks() plt.xticks(xticks, xticks / sample_rate) ax2.set_xlabel('time (second)') ax2.set_yticks([]) ax2.set_ylim(-1.0, 1.0) ax2.set_xlim(0, waveform.shape[0]) plt.show()