Add fairseq onnx support and strict configuration, fixes some onnx errors
SystemPanic committed Aug 1, 2023
1 parent dc04baa commit 09c0964
Showing 1 changed file with 29 additions and 21 deletions.
50 changes: 29 additions & 21 deletions TTS/tts/models/vits.py
@@ -1725,7 +1725,7 @@ def load_checkpoint(
         assert not self.training

     def load_fairseq_checkpoint(
-        self, config, checkpoint_dir, eval=False
+        self, config, checkpoint_dir, eval=False, strict=True
     ): # pylint: disable=unused-argument, redefined-builtin
         """Load VITS checkpoints released by fairseq here: https://github.com/facebookresearch/fairseq/tree/main/examples/mms
         Performs some changes for compatibility.
@@ -1763,7 +1763,7 @@ def load_fairseq_checkpoint(
         )
         # load fairseq checkpoint
         new_chk = rehash_fairseq_vits_checkpoint(checkpoint_file)
-        self.load_state_dict(new_chk)
+        self.load_state_dict(new_chk, strict=strict)
         if eval:
             self.eval()
             assert not self.training
@@ -1844,33 +1844,38 @@ def onnx_inference(text, text_lengths, scales, sid=None, langid=None):

         # set dummy inputs
         dummy_input_length = 100
-        sequences = torch.randint(low=0, high=self.args.num_chars, size=(1, dummy_input_length), dtype=torch.long)
+        sequences = torch.randint(low=0, high=2, size=(1, dummy_input_length), dtype=torch.long)
         sequence_lengths = torch.LongTensor([sequences.size(1)])
-        speaker_id = None
-        language_id = None
-        if self.num_speakers > 1:
+        scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
+        dummy_input = (sequences, sequence_lengths, scales)
+        input_names = ["input", "input_lengths", "scales"]
+
+        if self.num_speakers > 0:
             speaker_id = torch.LongTensor([0])
-        if self.num_languages > 0 and self.embedded_language_dim > 0:
+            dummy_input += (speaker_id, )
+            input_names.append("sid")
+
+        if hasattr(self, 'num_languages') and self.num_languages > 0 and self.embedded_language_dim > 0:
             language_id = torch.LongTensor([0])
-        scales = torch.FloatTensor([self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp])
-        dummy_input = (sequences, sequence_lengths, scales, speaker_id, language_id)
+            dummy_input += (language_id, )
+            input_names.append("langid")

         # export to ONNX
         torch.onnx.export(
             model=self,
             args=dummy_input,
             opset_version=15,
             f=output_path,
             verbose=verbose,
-            input_names=["input", "input_lengths", "scales", "sid", "langid"],
+            input_names=input_names,
             output_names=["output"],
             dynamic_axes={
                 "input": {0: "batch_size", 1: "phonemes"},
                 "input_lengths": {0: "batch_size"},
                 "output": {0: "batch_size", 1: "time1", 2: "time2"},
             },
         )

         # rollback
         self.forward = _forward
         if training:
@@ -1880,7 +1885,7 @@ def onnx_inference(text, text_lengths, scales, sid=None, langid=None):

     def load_onnx(self, model_path: str, cuda=False):
         import onnxruntime as ort

         providers = [
             "CPUExecutionProvider"
             if cuda is False
@@ -1908,16 +1913,19 @@ def inference_onnx(self, x, x_lengths=None, speaker_id=None, language_id=None):
             [self.inference_noise_scale, self.length_scale, self.inference_noise_scale_dp],
             dtype=np.float32,
         )
+        input_params = {
+            "input": x,
+            "input_lengths": x_lengths,
+            "scales": scales
+        }
+        if not speaker_id is None:
+            input_params["sid"] = torch.tensor([speaker_id]).cpu().numpy()
+        if not language_id is None:
+            input_params["langid"] = torch.tensor([language_id]).cpu().numpy()

         audio = self.onnx_sess.run(
             ["output"],
-            {
-                "input": x,
-                "input_lengths": x_lengths,
-                "scales": scales,
-                "sid": None if speaker_id is None else torch.tensor([speaker_id]).cpu().numpy(),
-                "langid": None if language_id is None else torch.tensor([language_id]).cpu().numpy(),
-            },
+            input_params,
         )
         return audio[0][0]
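
A minimal usage sketch of the code paths this commit touches (not part of the diff). It assumes a standard Coqui TTS install; `VitsConfig`, `Vits.init_from_config` and `export_onnx` are taken from the surrounding package as I understand it, and the checkpoint directory and file names are placeholders:

    # Sketch only: load an MMS/fairseq checkpoint with the new strict switch, then export it.
    from TTS.tts.configs.vits_config import VitsConfig
    from TTS.tts.models.vits import Vits

    config = VitsConfig()                    # assumed: a config matching the MMS checkpoint
    model = Vits.init_from_config(config)

    # strict is forwarded to load_state_dict, so strict=False tolerates
    # missing/unexpected keys in partially matching fairseq checkpoints.
    model.load_fairseq_checkpoint(config, "/path/to/mms/checkpoint_dir", eval=True, strict=False)

    # Export now registers "sid"/"langid" ONNX inputs only when the model actually uses them.
    model.export_onnx(output_path="vits_mms.onnx", verbose=False)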

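A companion sketch for the inference side, matching the new optional handling of "sid"/"langid" in `inference_onnx`; the dummy token ids below are placeholders for whatever the model's tokenizer would produce:

    import numpy as np

    model.load_onnx("vits_mms.onnx", cuda=False)

    x = np.random.randint(0, 2, size=(1, 50)).astype(np.int64)   # placeholder input ids
    x_lengths = np.array([x.shape[1]], dtype=np.int64)

    # With the new input_params dict, "sid"/"langid" are only sent to
    # onnxruntime when speaker_id / language_id are actually given.
    wav = model.inference_onnx(x, x_lengths=x_lengths)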
