Source code for malaya_speech.utils.subword

from .text_encoder.subword_encoder import SubwordTextEncoder, _trim_underscore_and_tell
from .text_encoder import pad_decr
import re
import six
import tensorflow as tf  # needed by `align_multilanguage` (tf.compat.as_text)
from typing import List
import logging

logger = logging.getLogger(__name__)

BLANK = 0

sentencepiece_available = False
try:
    import sentencepiece as spm
    sentencepiece_available = True
except Exception:
    logger.warning(
        '`sentencepiece` is not available; any model that depends on sentencepiece cannot be used.')


def get_index_multilanguage(r, tokenizers, len_vocab):
    # Map a global subword id `r` into (tokenizer index, local id) by walking
    # the cumulative vocab sizes of the concatenated vocabularies.
    for i in range(len(tokenizers)):
        sum_v = sum(len_vocab[:i + 1])
        if r < sum_v:
            return i, r - sum(len_vocab[:i])
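
# A quick worked sketch of the mapping above (toy vocab sizes, not from the
# library): with two tokenizers of vocab sizes [100, 200], global id 150 is
# past the first vocab (150 >= 100) but inside the second (150 < 300), so it
# maps to tokenizer index 1 with local id 150 - 100 = 50:
#
#     get_index_multilanguage(150, [tok_a, tok_b], [100, 200])  # -> (1, 50)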


def generate_tokenizer(
    strings: List[str],
    target_vocab_size: int = 1024,
    max_subword_length: int = 4,
    max_corpus_chars=None,
    reserved_tokens=None,
):
    """
    Build a subword dictionary.
    """
    return SubwordTextEncoder.build_from_corpus(
        strings,
        target_vocab_size=target_vocab_size,
        max_subword_length=max_subword_length,
        max_corpus_chars=max_corpus_chars,
        reserved_tokens=reserved_tokens,
    )

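# A minimal usage sketch for `generate_tokenizer`; the corpus below is a
# hypothetical toy example, not from the library:
#
#     corpus = ['saya suka makan ayam', 'dia suka minum teh']
#     tokenizer = generate_tokenizer(corpus, target_vocab_size=128)
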
def save(tokenizer, path: str):
    """
    Save subword dictionary to a text file.
    """
    tokenizer.save_to_file(path)

def load(path: str):
    """
    Load text file into subword dictionary.
    """
    return SubwordTextEncoder.load_from_file(path)

def encode(tokenizer, string: str, add_blank: bool = False):
    """
    Encode a string into its integer representation based on the tokenizer vocab.

    Parameters
    ----------
    tokenizer: object
        tokenizer object.
    string: str
    add_blank: bool, optional (default=False)
        prepend the BLANK token to the encoded sequence, required for
        transducer / transformer based models.

    Returns
    -------
    result: List[int]
    """
    r = tokenizer.encode(string)
    if add_blank:
        r = [BLANK] + r
    return r

def decode(tokenizer, ids):
    """
    Decode integer representation to string based on tokenizer vocab.

    Parameters
    ----------
    tokenizer: object
        tokenizer object.
    ids: List[int]

    Returns
    -------
    result: str
    """
    return tokenizer.decode([i for i in ids if i > 0])

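# Round-trip sketch for `encode` / `decode` (assuming `tokenizer` was built
# with `generate_tokenizer` above). With `add_blank=True` the BLANK id (0) is
# prepended, and `decode` drops it again because it filters out ids <= 0:
#
#     ids = encode(tokenizer, 'saya suka', add_blank=True)  # [0, ...]
#     decode(tokenizer, ids)                                # 'saya suka'
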
def decode_multilanguage(tokenizers, ids):
    """
    Decode integer representation to string using a list of tokenizer objects.

    Parameters
    ----------
    tokenizers: List[object]
        list of tokenizer objects.
    ids: List[int]

    Returns
    -------
    result: str
    """
    if not len(ids):
        return ''

    len_vocab = [l.vocab_size for l in tokenizers]
    last_index, v = get_index_multilanguage(ids[0], tokenizers, len_vocab)
    d, q = [], [v]
    for r in ids[1:]:
        index, v = get_index_multilanguage(r, tokenizers, len_vocab)
        if index != last_index:
            # language switch: flush the queued ids with the previous tokenizer
            d.append(decode(tokenizers[last_index], q))
            q = [v]
            last_index = index
        else:
            q.append(v)
    if len(q):
        d.append(decode(tokenizers[last_index], q))
    d = re.sub(r'[ ]+', ' ', ' '.join(d)).strip()
    return d

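# Sketch of the flat multilanguage id space (tokenizer names hypothetical):
# ids are offsets into the concatenation of every tokenizer's vocab, so a
# local id `v` of the second tokenizer appears as `v + tokenizers[0].vocab_size`:
#
#     tokenizers = [tokenizer_ms, tokenizer_en]
#     ids = [3, 7, tokenizer_ms.vocab_size + 5]  # two ms subwords, one en
#     decode_multilanguage(tokenizers, ids)
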
def align_multilanguage(tokenizers, ids, get_index=False):
    ids = pad_decr(ids)
    subword_ids = ids
    subwords_ = []
    prev_bytes = []
    prev_ids = []
    ids = []
    len_vocab = [l.vocab_size for l in tokenizers]

    def consume_prev_bytes():
        # flush any pending byte-encoded subwords into the output
        if prev_bytes:
            subwords_.extend(prev_bytes)
            ids.extend(prev_ids)
        return [], []

    for no, subword_id in enumerate(subword_ids):
        last_index, v = get_index_multilanguage(subword_id, tokenizers, len_vocab)
        subword = tokenizers[last_index]._id_to_subword(v)
        if isinstance(subword, six.binary_type):
            # byte-encoded subword
            prev_bytes.append(subword.decode('utf-8', 'replace'))
            if subword == b' ':
                prev_ids.append(None)
            else:
                prev_ids.append(no)
        else:
            # if there were bytes previously, convert them to unicode first
            prev_bytes, prev_ids = consume_prev_bytes()
            trimmed, add_space = _trim_underscore_and_tell(subword)
            ids.append(no)
            subwords_.append(trimmed)
            if add_space:
                subwords_.append(' ')
                ids.append(None)
    prev_bytes, prev_ids = consume_prev_bytes()

    if get_index:
        return subwords_, ids
    else:
        return tf.compat.as_text(''.join(subwords_))

def load_sentencepiece(model_file):
    """
    Parameters
    ----------
    model_file: str
        sentencepiece model file.

    Returns
    -------
    result: sentencepiece.SentencePieceProcessor
    """
    if not sentencepiece_available:
        raise ModuleNotFoundError(
            'sentencepiece not installed. Please install it by `pip install sentencepiece` and try again.'
        )
    return spm.SentencePieceProcessor(model_file=model_file)

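# Usage sketch (the model path is hypothetical):
#
#     sp = load_sentencepiece('toy.sp.model')
#     sp.encode('saya suka', out_type=int)   # list of subword ids
#     sp.id_to_piece(5)                      # inspect a single piece
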
class SentencePieceTokenProcessor:
    def __init__(self, sp_model_path):
        self.sp_model = load_sentencepiece(model_file=sp_model_path)
        self.post_process_remove_list = {
            self.sp_model.unk_id(),
            self.sp_model.eos_id(),
            self.sp_model.pad_id(),
        }

    def __call__(self, tokens: List[int], lstrip: bool = True) -> str:
        filtered_hypo_tokens = [
            token_index
            for token_index in tokens[1:]
            if token_index not in self.post_process_remove_list
        ]
        output_string = ''.join(
            self.sp_model.id_to_piece(filtered_hypo_tokens)
        ).replace('\u2581', ' ')

        if lstrip:
            return output_string.lstrip()
        else:
            return output_string


def merge_sentencepiece_tokens(
    paired_tokens,
    **kwargs,
):
    new_paired_tokens = []
    n_tokens = len(paired_tokens)

    i = 0
    while i < n_tokens:
        current_token, current_weight = paired_tokens[i]
        if isinstance(current_token, bytes):
            current_token = current_token.decode()
        if not current_token.startswith('▁'):
            # continuation piece: fold it (and any that follow) into the
            # previous word, collecting the weights into a list
            previous_token, previous_weight = new_paired_tokens.pop()
            merged_token = previous_token
            merged_weight = [previous_weight]
            while not current_token.startswith('▁'):
                merged_token = merged_token + current_token.replace('▁', '')
                merged_weight.append(current_weight)
                i = i + 1
                if i < n_tokens:
                    current_token, current_weight = paired_tokens[i]
                else:
                    break
            new_paired_tokens.append((merged_token, merged_weight))
        else:
            new_paired_tokens.append((current_token, current_weight))
            i = i + 1

    return new_paired_tokens


def merge_bpe_tokens(
    paired_tokens,
    rejected=['<s>', '</s>', '<unk>', '<pad>', '<mask>'],
    prefix_char='Ġ',
    **kwargs,
):
    new_paired_tokens = []
    paired_tokens = [t for t in paired_tokens if t[0] not in rejected]
    n_tokens = len(paired_tokens)

    i = 0
    while i < n_tokens:
        current_token, current_weight = paired_tokens[i]
        if isinstance(current_token, bytes):
            current_token = current_token.decode()
        if i > 0 and not current_token.startswith(prefix_char) and current_token not in rejected:
            previous_token, previous_weight = new_paired_tokens.pop()
            merged_token = previous_token
            merged_weight = [previous_weight]
            while (
                not current_token.startswith(prefix_char)
                and current_token not in rejected
            ):
                merged_token = merged_token + current_token.replace(prefix_char, '')
                merged_weight.append(current_weight)
                i = i + 1
                if i < n_tokens:
                    current_token, current_weight = paired_tokens[i]
                else:
                    break
            new_paired_tokens.append((merged_token, merged_weight))
        else:
            new_paired_tokens.append((current_token, current_weight))
            i = i + 1

    words = [
        i[0].replace(prefix_char, '')
        for i in new_paired_tokens
        if i[0] not in rejected
    ]
    weights = [i[1] for i in new_paired_tokens if i[0] not in rejected]
    return list(zip(words, weights))
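
# Sketch of `merge_sentencepiece_tokens` on (token, weight) pairs; pieces
# that do not start with the '▁' word-boundary marker are folded into the
# previous word and their weights collected into a list (toy values, not
# real model output):
#
#     pairs = [('▁sa', 0.1), ('ya', 0.2), ('▁suka', 0.3)]
#     merge_sentencepiece_tokens(pairs)
#     # -> [('▁saya', [0.1, 0.2]), ('▁suka', 0.3)]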