# Source code for malaya_speech.model.unet

from malaya_speech.utils import featurization
from malaya_speech.model.frame import Frame
from malaya_speech.utils.padding import (
    sequence_nd as padding_sequence_nd,
)
from malaya_speech.model.abstract import Abstract


class UNET(Abstract):
    """U-Net speech enhancement model operating on melspectrograms."""

    def __init__(self, input_nodes, output_nodes, sess, model, name):
        self._input_nodes = input_nodes
        self._output_nodes = output_nodes
        self._sess = sess
        self.__model__ = model
        self.__name__ = name

    def predict(self, inputs):
        """
        Enhance inputs, will return melspectrogram.

        Parameters
        ----------
        inputs: List[np.array]
            List of np.array or malaya_speech.model.frame.Frame.

        Returns
        -------
        result: List
            One unscaled melspectrogram per input, in the same order as
            `inputs`, sliced back to that input's reported length.
        """
        # Accept both raw arrays and Frame wrappers.
        arrays = [
            item.array if isinstance(item, Frame) else item
            for item in inputs
        ]
        # Transpose so the first axis (dim=0) is the one being padded.
        mels = [featurization.scale_mel(s).T for s in arrays]
        # NOTE(review): maxlen is fixed at 256 along dim=0; presumably
        # inputs longer than that are unsupported by the padder/graph --
        # TODO confirm. `lens` is used below as the per-item unpadded length.
        x, lens = padding_sequence_nd(
            mels, maxlen=256, dim=0, return_len=True
        )
        r = self._execute(
            inputs=[x],
            input_labels=['Placeholder'],
            output_labels=['logits'],
        )
        logits = r['logits']
        # The model output (channel 0) is added to the scaled input mel,
        # the padded tail is sliced off, and the sum is unscaled.
        return [
            featurization.unscale_mel(
                x[index, :length].T + logits[index, :length, :, 0].T
            )
            for index, length in enumerate(lens)
        ]

    def __call__(self, inputs):
        """
        Enhance inputs, will return melspectrogram.

        Parameters
        ----------
        inputs: List[np.array]
            List of np.array or malaya_speech.model.frame.Frame.

        Returns
        -------
        result: List
        """
        return self.predict(inputs)
class UNETSTFT(Abstract):
    """U-Net source separation model returning one output per instrument."""

    def __init__(
        self, input_nodes, output_nodes, instruments, sess, model, name
    ):
        self._input_nodes = input_nodes
        self._output_nodes = output_nodes
        # Ordered instrument names; position i maps to graph output `logits_i`.
        self._instruments = instruments
        self._sess = sess
        self.__model__ = model
        self.__name__ = name

    def predict(self, input):
        """
        Enhance inputs, will return waveform.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.

        Returns
        -------
        result: Dict
            Mapping of instrument name to the graph output `logits_{i}`,
            where i is the instrument's position in `instruments`.
        """
        signal = input.array if isinstance(input, Frame) else input
        outputs = self._execute(
            inputs=[signal],
            input_labels=['Placeholder'],
            output_labels=list(self._output_nodes.keys()),
        )
        return {
            instrument_name: outputs[f'logits_{position}']
            for position, instrument_name in enumerate(self._instruments)
        }

    def __call__(self, input):
        """
        Enhance inputs, will return waveform.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.

        Returns
        -------
        result: Dict
        """
        return self.predict(input)
class UNET1D(Abstract):
    """1D U-Net speech enhancement model operating on raw waveforms."""

    def __init__(self, input_nodes, output_nodes, sess, model, name):
        self._input_nodes = input_nodes
        self._output_nodes = output_nodes
        self._sess = sess
        self.__model__ = model
        self.__name__ = name

    def predict(self, input):
        """
        Enhance inputs, will return waveform.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.

        Returns
        -------
        result: np.array
            The graph's `logits` output for the given input.
        """
        signal = input.array if isinstance(input, Frame) else input
        return self._execute(
            inputs=[signal],
            input_labels=['Placeholder'],
            output_labels=['logits'],
        )['logits']

    def __call__(self, input):
        """
        Enhance inputs, will return waveform.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.

        Returns
        -------
        result: np.array
        """
        return self.predict(input)