from malaya_speech.utils import featurization
from malaya_speech.model.frame import Frame
from malaya_speech.utils.padding import (
sequence_nd as padding_sequence_nd,
)
from malaya_speech.model.abstract import Abstract
class UNET(Abstract):
    """
    Speech-enhancement wrapper that works on scaled mel features.

    Holds the graph nodes and session handed to the constructor. The graph
    is run through ``self._execute``, which is not defined in this class
    (presumably inherited from ``Abstract`` -- confirm).
    """

    def __init__(self, input_nodes, output_nodes, sess, model, name):
        self._input_nodes = input_nodes
        self._output_nodes = output_nodes
        self._sess = sess
        self.__model__ = model
        self.__name__ = name

    def predict(self, inputs):
        """
        Enhance inputs, will return melspectrogram.

        Parameters
        ----------
        inputs: List[np.array]
            List[np.array] or List[malaya_speech.model.frame.Frame].
            ``Frame`` items are replaced by their ``.array`` attribute.

        Returns
        -------
        result: List
            One ``featurization.unscale_mel`` output per padded batch row.
        """
        arrays = [
            item.array if isinstance(item, Frame) else item for item in inputs
        ]
        scaled = [featurization.scale_mel(array).T for array in arrays]
        # Pad along dim 0 to a fixed length of 256; the returned lengths are
        # used below to trim the padding from every row again.
        padded, lengths = padding_sequence_nd(
            scaled, maxlen=256, dim=0, return_len=True
        )
        fetched = self._execute(
            inputs=[padded],
            input_labels=['Placeholder'],
            output_labels=['logits'],
        )
        logits = fetched['logits']
        # The graph output (last axis, index 0) is added onto the padded
        # input before unscaling.
        return [
            featurization.unscale_mel(
                padded[row, : lengths[row]].T
                + logits[row, : lengths[row], :, 0].T
            )
            for row in range(len(padded))
        ]

    def __call__(self, inputs):
        """
        Shortcut for `predict(inputs)`; see `predict` for the contract.
        """
        return self.predict(inputs)
class UNETSTFT(Abstract):
    """
    Wrapper around a graph that yields one output per instrument.

    Holds the graph nodes, instrument names and session handed to the
    constructor. The graph is run through ``self._execute``, which is not
    defined in this class (presumably inherited from ``Abstract`` --
    confirm).
    """

    def __init__(
        self, input_nodes, output_nodes, instruments, sess, model, name
    ):
        self._input_nodes = input_nodes
        self._output_nodes = output_nodes
        self._instruments = instruments
        self._sess = sess
        self.__model__ = model
        self.__name__ = name

    def predict(self, input):
        """
        Enhance inputs, will return waveform.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.
            A ``Frame`` is replaced by its ``.array`` attribute.

        Returns
        -------
        result: Dict
            Maps each entry of the ``instruments`` constructor argument to
            the fetched ``logits_{i}`` output, where ``i`` is the position
            of that entry.
        """
        signal = input.array if isinstance(input, Frame) else input
        fetched = self._execute(
            inputs=[signal],
            input_labels=['Placeholder'],
            output_labels=list(self._output_nodes.keys()),
        )
        return {
            instrument: fetched[f'logits_{position}']
            for position, instrument in enumerate(self._instruments)
        }

    def __call__(self, input):
        """
        Enhance inputs, will return waveform. Shortcut for `predict(input)`.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.

        Returns
        -------
        result: Dict
        """
        return self.predict(input)
class UNET1D(Abstract):
    """
    Wrapper around a graph fed with a single array input.

    Holds the graph nodes and session handed to the constructor. The graph
    is run through ``self._execute``, which is not defined in this class
    (presumably inherited from ``Abstract`` -- confirm).
    """

    def __init__(self, input_nodes, output_nodes, sess, model, name):
        self._input_nodes = input_nodes
        self._output_nodes = output_nodes
        self._sess = sess
        self.__model__ = model
        self.__name__ = name

    def predict(self, input):
        """
        Enhance inputs, will return waveform.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.
            A ``Frame`` is replaced by its ``.array`` attribute.

        Returns
        -------
        result: np.array
            The fetched ``logits`` output, returned unchanged.
        """
        signal = input.array if isinstance(input, Frame) else input
        fetched = self._execute(
            inputs=[signal],
            input_labels=['Placeholder'],
            output_labels=['logits'],
        )
        return fetched['logits']

    def __call__(self, input):
        """
        Enhance inputs, will return waveform. Shortcut for `predict(input)`.

        Parameters
        ----------
        input: np.array
            np.array or malaya_speech.model.frame.Frame.

        Returns
        -------
        result: np.array
        """
        return self.predict(input)