import torch.nn as nn
import torch
from malaya_speech.model.frame import Frame
from malaya_speech.utils.torch_utils import to_tensor_cuda, to_numpy, from_log
from malaya_speech.train.model.voicefixer.base import VoiceFixer as BaseVoiceFixer
from malaya_speech.train.model.voicefixer.nvsr import NVSR as BaseNVSR
from malaya_speech.train.model.nuwave2_torch.inference import NuWave2 as BaseNuWave2
from scipy.signal import resample_poly
[docs]class VoiceFixer(BaseVoiceFixer):
def __init__(self, pth, vocoder_pth, model, name):
super(VoiceFixer, self).__init__(pth, vocoder_pth)
self.eval()
self.__model__ = model
self.__name__ = name
[docs] def predict(self, input, remove_higher_frequency: bool = True):
"""
Parameters
----------
input: np.array
np.array or malaya_speech.model.frame.Frame,
must an audio with 44100 sampling rate.
remove_higher_frequency: bool, optional (default = True)
Remove high frequency before neural upsampling.
Returns
-------
result: np.array with 44100 sampling rate
"""
input = input.array if isinstance(input, Frame) else input
wav_10k = input
cuda = next(self.parameters()).is_cuda
res = []
seg_length = 44100 * 30
break_point = seg_length
while break_point < wav_10k.shape[0] + seg_length:
segment = wav_10k[break_point - seg_length: break_point]
if remove_higher_frequency:
segment = self.remove_higher_frequency(segment)
sp, mel_noisy = self._pre(self._model, segment, cuda)
out_model = self._model(sp, mel_noisy)
denoised_mel = from_log(out_model['mel'])
out = self._model.vocoder(denoised_mel, cuda)
if torch.max(torch.abs(out)) > 1.0:
out = out / torch.max(torch.abs(out))
out, _ = self._trim_center(out, segment)
res.append(out)
break_point += seg_length
out = torch.cat(res, -1)
return to_numpy(out[0][0])
[docs] def forward(self, input, remove_higher_frequency: bool = True):
return self.predict(input=input, remove_higher_frequency=remove_higher_frequency)
[docs]class NVSR(BaseNVSR):
def __init__(self, pth, vocoder_pth, model, name):
super(NVSR, self).__init__(pth, vocoder_pth)
self.eval()
self.__model__ = model
self.__name__ = name
[docs] def predict(self, input):
"""
Parameters
----------
input: np.array
np.array or malaya_speech.model.frame.Frame,
must an audio with 44100 sampling rate.
Returns
-------
result: np.array with 44100 sampling rate
"""
input = input.array if isinstance(input, Frame) else input
return self.forward(input)
class NuWave2(BaseNuWave2):
def __init__(self, pth, model, name):
super(NuWave2, self).__init__()
self.eval()
ckpt = torch.load(pth, map_location='cpu')
self.load_state_dict(ckpt)
self.__model__ = model
self.__name__ = name
def predict(self, input, sr: int, steps: int = 8):
"""
Parameters
----------
input: np.array
np.array or malaya_speech.model.frame.Frame,
prefer 8000, 12000, 16000 or 22050 or 44000 sampling rate.
sr: int
sampling rate, prefer 8000, 12000, 16000 or 22050 or 44000 sampling rate.
steps: int, optional (default=8)
diffusion steps.
Returns
-------
result: np.array with 48k sampling rate
"""
input = input.array if isinstance(input, Frame) else input
wav = input
cuda = next(self.parameters()).is_cuda
noise_schedule = None
highcut = sr // 2
nyq = 0.5 * self.hparams.audio.sampling_rate
hi = highcut / nyq
fft_size = self.hparams.audio.filter_length // 2 + 1
band = torch.zeros(fft_size, dtype=torch.int64)
band[:int(hi * fft_size)] = 1
wav_l = resample_poly(wav, self.hparams.audio.sampling_rate, sr)
wav = torch.from_numpy(wav).unsqueeze(0)
wav_l = torch.from_numpy(wav_l.copy()).float().unsqueeze(0)
band = band.unsqueeze(0)
wav_l = to_tensor_cuda(wav_l, cuda)
band = to_tensor_cuda(band, cuda)
wav_recon, wav_list = self.inference(wav_l, band, steps, noise_schedule)
return to_numpy(wav_recon[0])
def forward(self, input, sr: int, steps: int = 8):
return self.predict(input=input, sr=sr, steps=steps)