Skip to content

Commit

Permalink
Add EmbeddingManager and BaseIDManager (#1374)
Browse files Browse the repository at this point in the history
  • Loading branch information
Edresson committed Mar 31, 2022
1 parent 1b22f03 commit 060e0f9
Show file tree
Hide file tree
Showing 27 changed files with 412 additions and 404 deletions.
6 changes: 3 additions & 3 deletions TTS/bin/compute_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
use_cuda=args.use_cuda,
)

class_name_key = encoder_manager.speaker_encoder_config.class_name_key
class_name_key = encoder_manager.encoder_config.class_name_key

# compute speaker embeddings
speaker_mapping = {}
Expand All @@ -63,10 +63,10 @@
wav_file_name = os.path.basename(wav_file)
if args.old_file is not None and wav_file_name in encoder_manager.clip_ids:
# get the embedding from the old file
embedd = encoder_manager.get_d_vector_by_clip(wav_file_name)
embedd = encoder_manager.get_embedding_by_clip(wav_file_name)
else:
# extract the embedding
embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
embedd = encoder_manager.compute_embedding_from_clip(wav_file)

# create speaker_mapping if target dataset is defined
speaker_mapping[wav_file_name] = {}
Expand Down
10 changes: 5 additions & 5 deletions TTS/bin/eval_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

def compute_encoder_accuracy(dataset_items, encoder_manager):

class_name_key = encoder_manager.speaker_encoder_config.class_name_key
map_classid_to_classname = getattr(encoder_manager.speaker_encoder_config, "map_classid_to_classname", None)
class_name_key = encoder_manager.encoder_config.class_name_key
map_classid_to_classname = getattr(encoder_manager.encoder_config, "map_classid_to_classname", None)

class_acc_dict = {}

Expand All @@ -22,13 +22,13 @@ def compute_encoder_accuracy(dataset_items, encoder_manager):
wav_file = item["audio_file"]

# extract the embedding
embedd = encoder_manager.compute_d_vector_from_clip(wav_file)
if encoder_manager.speaker_encoder_criterion is not None and map_classid_to_classname is not None:
embedd = encoder_manager.compute_embedding_from_clip(wav_file)
if encoder_manager.encoder_criterion is not None and map_classid_to_classname is not None:
embedding = torch.FloatTensor(embedd).unsqueeze(0)
if encoder_manager.use_cuda:
embedding = embedding.cuda()

class_id = encoder_manager.speaker_encoder_criterion.softmax.inference(embedding).item()
class_id = encoder_manager.encoder_criterion.softmax.inference(embedding).item()
predicted_label = map_classid_to_classname[str(class_id)]
else:
predicted_label = None
Expand Down
4 changes: 2 additions & 2 deletions TTS/bin/extract_tts_spectrograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def setup_loader(ap, r, verbose=False):
precompute_num_workers=0,
use_noise_augment=False,
verbose=verbose,
speaker_id_mapping=speaker_manager.speaker_ids if c.use_speaker_embedding else None,
d_vector_mapping=speaker_manager.d_vectors if c.use_d_vector_file else None,
speaker_id_mapping=speaker_manager.ids if c.use_speaker_embedding else None,
d_vector_mapping=speaker_manager.embeddings if c.use_d_vector_file else None,
)

if c.use_phonemes and c.compute_input_seq_cache:
Expand Down
4 changes: 2 additions & 2 deletions TTS/bin/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,15 @@ def main():
print(
" > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
)
print(synthesizer.tts_model.speaker_manager.speaker_ids)
print(synthesizer.tts_model.speaker_manager.ids)
return

    # query language ids of a multi-lingual model.
if args.list_language_idxs:
print(
" > Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
)
print(synthesizer.tts_model.language_manager.language_id_mapping)
print(synthesizer.tts_model.language_manager.ids)
return

# check the arguments against a multi-speaker model.
Expand Down
4 changes: 2 additions & 2 deletions TTS/bin/train_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from trainer.trainer_utils import get_optimizer

from TTS.encoder.dataset import EncoderDataset
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_speaker_encoder_model
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
from TTS.encoder.utils.samplers import PerfectBatchSampler
from TTS.encoder.utils.training import init_training
from TTS.encoder.utils.visual import plot_embeddings
Expand Down Expand Up @@ -258,7 +258,7 @@ def main(args): # pylint: disable=redefined-outer-name
global train_classes

ap = AudioProcessor(**c.audio)
model = setup_speaker_encoder_model(c)
model = setup_encoder_model(c)

optimizer = get_optimizer(c.optimizer, c.optimizer_params, c.lr, model)

Expand Down
2 changes: 1 addition & 1 deletion TTS/encoder/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def to_camel(text):
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)


def setup_speaker_encoder_model(config: "Coqpit"):
def setup_encoder_model(config: "Coqpit"):
if config.model_params["model_name"].lower() == "lstm":
model = LSTMSpeakerEncoder(
config.model_params["input_dim"],
Expand Down
2 changes: 1 addition & 1 deletion TTS/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def index():
"index.html",
show_details=args.show_details,
use_multi_speaker=use_multi_speaker,
speaker_ids=speaker_manager.speaker_ids if speaker_manager is not None else None,
speaker_ids=speaker_manager.ids if speaker_manager is not None else None,
use_gst=use_gst,
)

Expand Down
32 changes: 14 additions & 18 deletions TTS/tts/models/base_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,18 @@ def get_aux_input_from_test_setences(self, sentence_info):
if hasattr(self, "speaker_manager"):
if config.use_d_vector_file:
if speaker_name is None:
d_vector = self.speaker_manager.get_random_d_vector()
d_vector = self.speaker_manager.get_random_embeddings()
else:
d_vector = self.speaker_manager.get_d_vector_by_speaker(speaker_name)
d_vector = self.speaker_manager.get_d_vector_by_name(speaker_name)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_speaker_id()
speaker_id = self.speaker_manager.get_random_id()
else:
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
speaker_id = self.speaker_manager.ids[speaker_name]

# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
language_id = self.language_manager.ids[language_name]

return {
"text": text,
Expand Down Expand Up @@ -279,23 +279,19 @@ def get_data_loader(
# setup multi-speaker attributes
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
if hasattr(config, "model_args"):
speaker_id_mapping = (
self.speaker_manager.speaker_ids if config.model_args.use_speaker_embedding else None
)
d_vector_mapping = self.speaker_manager.d_vectors if config.model_args.use_d_vector_file else None
speaker_id_mapping = self.speaker_manager.ids if config.model_args.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.embeddings if config.model_args.use_d_vector_file else None
config.use_d_vector_file = config.model_args.use_d_vector_file
else:
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.d_vectors if config.use_d_vector_file else None
speaker_id_mapping = self.speaker_manager.ids if config.use_speaker_embedding else None
d_vector_mapping = self.speaker_manager.embeddings if config.use_d_vector_file else None
else:
speaker_id_mapping = None
d_vector_mapping = None

# setup multi-lingual attributes
if hasattr(self, "language_manager") and self.language_manager is not None:
language_id_mapping = (
self.language_manager.language_id_mapping if self.args.use_language_embedding else None
)
language_id_mapping = self.language_manager.ids if self.args.use_language_embedding else None
else:
language_id_mapping = None

Expand Down Expand Up @@ -352,13 +348,13 @@ def _get_test_aux_input(

d_vector = None
if self.config.use_d_vector_file:
d_vector = [self.speaker_manager.d_vectors[name]["embedding"] for name in self.speaker_manager.d_vectors]
d_vector = [self.speaker_manager.embeddings[name]["embedding"] for name in self.speaker_manager.embeddings]
d_vector = (random.sample(sorted(d_vector), 1),)

aux_inputs = {
"speaker_id": None
if not self.config.use_speaker_embedding
else random.sample(sorted(self.speaker_manager.speaker_ids.values()), 1),
else random.sample(sorted(self.speaker_manager.ids.values()), 1),
"d_vector": d_vector,
"style_wav": None, # TODO: handle GST style input
}
Expand Down Expand Up @@ -405,7 +401,7 @@ def on_init_start(self, trainer):
"""Save the speaker.json and language_ids.json at the beginning of the training. Also update both paths."""
if self.speaker_manager is not None:
output_path = os.path.join(trainer.output_path, "speakers.json")
self.speaker_manager.save_speaker_ids_to_file(output_path)
self.speaker_manager.save_ids_to_file(output_path)
trainer.config.speakers_file = output_path
# some models don't have `model_args` set
if hasattr(trainer.config, "model_args"):
Expand All @@ -416,7 +412,7 @@ def on_init_start(self, trainer):

if hasattr(self, "language_manager") and self.language_manager is not None:
output_path = os.path.join(trainer.output_path, "language_ids.json")
self.language_manager.save_language_ids_to_file(output_path)
self.language_manager.save_ids_to_file(output_path)
trainer.config.language_ids_file = output_path
if hasattr(trainer.config, "model_args"):
trainer.config.model_args.language_ids_file = output_path
Expand Down
2 changes: 1 addition & 1 deletion TTS/tts/models/glow_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def init_multispeaker(self, config: Coqpit):
)
if self.speaker_manager is not None:
assert (
config.d_vector_dim == self.speaker_manager.d_vector_dim
config.d_vector_dim == self.speaker_manager.embedding_dim
), " [!] d-vector dimension mismatch b/w config and speaker manager."
# init speaker embedding layer
if config.use_speaker_embedding and not config.use_d_vector_file:
Expand Down
44 changes: 20 additions & 24 deletions TTS/tts/models/vits.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,28 +652,28 @@ def init_multispeaker(self, config: Coqpit):

# TODO: make this a function
if self.args.use_speaker_encoder_as_loss:
if self.speaker_manager.speaker_encoder is None and (
if self.speaker_manager.encoder is None and (
not self.args.speaker_encoder_model_path or not self.args.speaker_encoder_config_path
):
raise RuntimeError(
" [!] To use the speaker consistency loss (SCL) you need to specify speaker_encoder_model_path and speaker_encoder_config_path !!"
)

self.speaker_manager.speaker_encoder.eval()
self.speaker_manager.encoder.eval()
print(" > External Speaker Encoder Loaded !!")

if (
hasattr(self.speaker_manager.speaker_encoder, "audio_config")
and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
hasattr(self.speaker_manager.encoder, "audio_config")
and self.config.audio["sample_rate"] != self.speaker_manager.encoder.audio_config["sample_rate"]
):
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.audio_config["sample_rate"],
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
)
# pylint: disable=W0101,W0105
self.audio_transform = torchaudio.transforms.Resample(
orig_freq=self.config.audio.sample_rate,
new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
new_freq=self.speaker_manager.encoder.audio_config["sample_rate"],
)

def _init_speaker_embedding(self):
Expand Down Expand Up @@ -887,7 +887,7 @@ def forward( # pylint: disable=dangerous-default-value
pad_short=True,
)

if self.args.use_speaker_encoder_as_loss and self.speaker_manager.speaker_encoder is not None:
if self.args.use_speaker_encoder_as_loss and self.speaker_manager.encoder is not None:
# concate generated and GT waveforms
wavs_batch = torch.cat((wav_seg, o), dim=0)

Expand All @@ -896,7 +896,7 @@ def forward( # pylint: disable=dangerous-default-value
if self.audio_transform is not None:
wavs_batch = self.audio_transform(wavs_batch)

pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)
pred_embs = self.speaker_manager.encoder.forward(wavs_batch, l2_norm=True)

# split generated and GT speaker embeddings
gt_spk_emb, syn_spk_emb = torch.chunk(pred_embs, 2, dim=0)
Expand Down Expand Up @@ -1223,18 +1223,18 @@ def get_aux_input_from_test_sentences(self, sentence_info):
if hasattr(self, "speaker_manager"):
if config.use_d_vector_file:
if speaker_name is None:
d_vector = self.speaker_manager.get_random_d_vector()
d_vector = self.speaker_manager.get_random_embeddings()
else:
d_vector = self.speaker_manager.get_mean_d_vector(speaker_name, num_samples=None, randomize=False)
d_vector = self.speaker_manager.get_mean_embedding(speaker_name, num_samples=None, randomize=False)
elif config.use_speaker_embedding:
if speaker_name is None:
speaker_id = self.speaker_manager.get_random_speaker_id()
speaker_id = self.speaker_manager.get_random_id()
else:
speaker_id = self.speaker_manager.speaker_ids[speaker_name]
speaker_id = self.speaker_manager.ids[speaker_name]

# get language id
if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
language_id = self.language_manager.language_id_mapping[language_name]
language_id = self.language_manager.ids[language_name]

return {
"text": text,
Expand Down Expand Up @@ -1289,26 +1289,22 @@ def format_batch(self, batch: Dict) -> Dict:
d_vectors = None

# get numerical speaker ids from speaker names
if self.speaker_manager is not None and self.speaker_manager.speaker_ids and self.args.use_speaker_embedding:
speaker_ids = [self.speaker_manager.speaker_ids[sn] for sn in batch["speaker_names"]]
if self.speaker_manager is not None and self.speaker_manager.ids and self.args.use_speaker_embedding:
speaker_ids = [self.speaker_manager.ids[sn] for sn in batch["speaker_names"]]

if speaker_ids is not None:
speaker_ids = torch.LongTensor(speaker_ids)
batch["speaker_ids"] = speaker_ids

# get d_vectors from audio file names
if self.speaker_manager is not None and self.speaker_manager.d_vectors and self.args.use_d_vector_file:
d_vector_mapping = self.speaker_manager.d_vectors
if self.speaker_manager is not None and self.speaker_manager.embeddings and self.args.use_d_vector_file:
d_vector_mapping = self.speaker_manager.embeddings
d_vectors = [d_vector_mapping[w]["embedding"] for w in batch["audio_files"]]
d_vectors = torch.FloatTensor(d_vectors)

# get language ids from language names
if (
self.language_manager is not None
and self.language_manager.language_id_mapping
and self.args.use_language_embedding
):
language_ids = [self.language_manager.language_id_mapping[ln] for ln in batch["language_names"]]
if self.language_manager is not None and self.language_manager.ids and self.args.use_language_embedding:
language_ids = [self.language_manager.ids[ln] for ln in batch["language_names"]]

if language_ids is not None:
language_ids = torch.LongTensor(language_ids)
Expand Down Expand Up @@ -1490,7 +1486,7 @@ def init_from_config(config: "VitsConfig", samples: Union[List[List], List[Dict]
language_manager = LanguageManager.init_from_config(config)

if config.model_args.speaker_encoder_model_path:
speaker_manager.init_speaker_encoder(
speaker_manager.init_encoder(
config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path
)
return Vits(new_config, ap, tokenizer, speaker_manager, language_manager)
Expand Down
Loading

0 comments on commit 060e0f9

Please sign in to comment.