Commit: Implement chunking gpt_cond

erogol committed Nov 13, 2023
1 parent 6f1cba2 commit a16360a
Showing 2 changed files with 78 additions and 33 deletions.
10 changes: 8 additions & 2 deletions TTS/tts/configs/xtts_config.py
@@ -43,7 +43,12 @@ class XttsConfig(BaseTTSConfig):
Defaults to `16`.
gpt_cond_len (int):
Secs audio to be used as conditioning for the autoregressive model. Defaults to `3`.
Secs audio to be used as conditioning for the autoregressive model. Defaults to `12`.
gpt_cond_chunk_len (int):
Audio chunk size in secs. Audio is split into chunks and latents are extracted for each chunk. Then the
latents are averaged. Chunking improves stability. It must be <= gpt_cond_len.
If gpt_cond_len == gpt_cond_chunk_len, no chunking is applied. Defaults to `4`.
max_ref_len (int):
Maximum number of seconds of audio to be used as conditioning for the decoder. Defaults to `10`.
@@ -95,6 +100,7 @@ class XttsConfig(BaseTTSConfig):
num_gpt_outputs: int = 1

# cloning
gpt_cond_len: int = 3
gpt_cond_len: int = 12
gpt_cond_chunk_len: int = 4
max_ref_len: int = 10
sound_norm_refs: bool = False
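
For context, a minimal sketch of how the two cloning fields interact (the import path, field names, and defaults are taken from this diff; the rest is illustrative, not part of the commit):

    from TTS.tts.configs.xtts_config import XttsConfig

    config = XttsConfig()
    # New defaults from this commit: 12 s of reference audio, split into 4 s chunks.
    config.gpt_cond_len = 12        # total seconds of reference audio used for GPT conditioning
    config.gpt_cond_chunk_len = 4   # per-chunk seconds; must be <= gpt_cond_len (equal means no chunking)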
101 changes: 70 additions & 31 deletions TTS/tts/models/xtts.py
@@ -255,39 +255,57 @@ def device(self):
return next(self.parameters()).device

@torch.inference_mode()
def get_gpt_cond_latents(self, audio, sr, length: int = 3):
def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
"""Compute the conditioning latents for the GPT model from the given audio.
Args:
audio (tensor): audio tensor.
sr (int): Sample rate of the audio.
length (int): Length of the audio in seconds. Defaults to 3.
length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30.
chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole audio
is used without chunking. It must be <= `length`. Defaults to 6.
"""
if sr != 22050:
audio = torchaudio.functional.resample(audio, sr, 22050)
audio = audio[:, : 22050 * length]
if length > 0:
audio = audio[:, : 22050 * length]
if self.args.gpt_use_perceiver_resampler:
n_fft = 2048
hop_length = 256
win_length = 1024
style_embs = []
for i in range(0, audio.shape[1], 22050 * chunk_length):
audio_chunk = audio[:, i : i + 22050 * chunk_length]
mel_chunk = wav_to_mel_cloning(
audio_chunk,
mel_norms=self.mel_stats.cpu(),
n_fft=2048,
hop_length=256,
win_length=1024,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
)
style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None)
style_embs.append(style_emb)

# mean style embedding
cond_latent = torch.stack(style_embs).mean(dim=0)
else:
n_fft = 4096
hop_length = 1024
win_length = 4096
mel = wav_to_mel_cloning(
audio,
mel_norms=self.mel_stats.cpu(),
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
)
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
mel = wav_to_mel_cloning(
audio,
mel_norms=self.mel_stats.cpu(),
n_fft=4096,
hop_length=1024,
win_length=4096,
power=2,
normalized=False,
sample_rate=22050,
f_min=0,
f_max=8000,
n_mels=80,
)
cond_latent = self.gpt.get_style_emb(mel.to(self.device))
return cond_latent.transpose(1, 2)
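
In the perceiver-resampler branch above, the (up to `length`-second) reference audio is walked in `chunk_length`-second windows at 22050 Hz, a style embedding is extracted per chunk, and the embeddings are averaged; with the new defaults, 30 s of audio yields five 6 s chunks. Below is a standalone sketch of that chunk-and-average step; `embed` stands in for the mel-spectrogram plus `self.gpt.get_style_emb(...)` pipeline and is an assumption for illustration, not part of the diff:

    import torch

    def chunked_cond_latent(audio: torch.Tensor, sr: int, chunk_secs: int, embed) -> torch.Tensor:
        # `audio` is [1, T] at sample rate `sr`; `embed` is assumed to map a chunk
        # to a latent tensor of fixed shape regardless of chunk length.
        step = sr * chunk_secs
        latents = []
        for start in range(0, audio.shape[1], step):
            chunk = audio[:, start : start + step]  # the last chunk may be shorter
            latents.append(embed(chunk))
        # Average the per-chunk latents, as the diff does with torch.stack(style_embs).mean(dim=0).
        return torch.stack(latents).mean(dim=0)

Note that chunking only applies on the perceiver-resampler path; the else branch still extracts a single mel spectrogram over the whole clip, as before.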

@torch.inference_mode()
@@ -323,12 +341,24 @@ def get_speaker_embedding(self, audio, sr):
def get_conditioning_latents(
self,
audio_path,
max_ref_length=30,
gpt_cond_len=6,
max_ref_length=10,
gpt_cond_chunk_len=6,
librosa_trim_db=None,
sound_norm_refs=False,
load_sr=24000,
load_sr=22050,
):
"""Get the conditioning latents for the GPT model from the given audio.
Args:
audio_path (str or List[str]): Path to reference audio file(s).
max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30.
gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6.
gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= gpt_cond_len. Defaults to 6.
librosa_trim_db (int, optional): Trim silence using this `top_db` threshold. If None, no trimming is applied. Defaults to None.
sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False.
load_sr (int, optional): Sample rate to load the audio. Defaults to 22050.
"""
# deal with multiple references
if not isinstance(audio_path, list):
audio_paths = [audio_path]
@@ -349,14 +379,17 @@ def get_conditioning_latents(
if librosa_trim_db is not None:
audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

# compute latents for the decoder
speaker_embedding = self.get_speaker_embedding(audio, load_sr)
speaker_embeddings.append(speaker_embedding)

audios.append(audio)

# use a merge of all references for gpt cond latents
# merge all the audios and compute the latents for the gpt
full_audio = torch.cat(audios, dim=-1)
gpt_cond_latents = self.get_gpt_cond_latents(full_audio, load_sr, length=gpt_cond_len) # [1, 1024, T]
gpt_cond_latents = self.get_gpt_cond_latents(
full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len
) # [1, 1024, T]

if speaker_embeddings:
speaker_embedding = torch.stack(speaker_embeddings)
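
The updated `get_conditioning_latents` accepts either a single path or a list of references: the speaker embedding is computed per reference and stacked (the reduction over references falls just outside this excerpt), while the GPT conditioning latents are computed once over the concatenation of all references, now chunked as shown above. A hypothetical call sketch; `model` is assumed to be an already-loaded Xtts instance and the file names are placeholders, while the keyword names follow this diff:

    # Hypothetical usage; checkpoint loading elided.
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=["speaker_ref_1.wav", "speaker_ref_2.wav"],  # one path or a list
        gpt_cond_len=30,        # seconds of (concatenated) audio for the GPT latents
        gpt_cond_chunk_len=6,   # per-chunk seconds; must be <= gpt_cond_len
        max_ref_length=10,      # cap applied to each reference clip, in seconds
    )
    # gpt_cond_latent has shape [1, 1024, T], per the comment in the diff.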
@@ -397,6 +430,7 @@ def inference_with_config(self, text, config, ref_audio_path, language, **kwargs
"top_k": config.top_k,
"top_p": config.top_p,
"gpt_cond_len": config.gpt_cond_len,
"gpt_cond_chunk_len": config.gpt_cond_chunk_len,
"max_ref_len": config.max_ref_len,
"sound_norm_refs": config.sound_norm_refs,
}
@@ -417,7 +451,8 @@ def full_inference(
top_p=0.85,
do_sample=True,
# Cloning
gpt_cond_len=6,
gpt_cond_len=30,
gpt_cond_chunk_len=6,
max_ref_len=10,
sound_norm_refs=False,
**hf_generate_kwargs,
@@ -448,7 +483,10 @@ def full_inference(
(aka boring) outputs. Defaults to 0.8.
gpt_cond_len: (int) Length of the audio used for cloning. If audio is shorter, then audio length is used
else the first `gpt_cond_len` secs is used. Defaults to 6 seconds.
else the first `gpt_cond_len` secs is used. Defaults to 30 seconds.
gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`.
If gpt_cond_len == gpt_cond_chunk_len, no chunking is applied. Defaults to 6 seconds.
hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
@@ -461,6 +499,7 @@ def full_inference(
(gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
audio_path=ref_audio_path,
gpt_cond_len=gpt_cond_len,
gpt_cond_chunk_len=gpt_cond_chunk_len,
max_ref_length=max_ref_len,
sound_norm_refs=sound_norm_refs,
)
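
End to end, `full_inference` forwards the new cloning knobs to `get_conditioning_latents`, and `inference_with_config` reads them from the config as shown above. A hypothetical call sketch; the text, language, and reference path are placeholders, the cloning keyword names come from this diff, and the remaining argument names are assumed from the surrounding signatures:

    # Hypothetical usage; `model` is an already-loaded Xtts instance.
    out = model.full_inference(
        text="This is a cloning test.",
        ref_audio_path="reference.wav",
        language="en",
        gpt_cond_len=30,        # new default in this commit
        gpt_cond_chunk_len=6,   # conditioning audio is processed in 6 s chunks
    )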
@@ -566,7 +605,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
if overlap_len > len(wav_chunk):
# wav_chunk is smaller than overlap_len, pass on last wav_gen
if wav_gen_prev is not None:
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):]
wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :]
else:
# not expecting to reach here, as the problem only occurs on the last chunk
wav_chunk = wav_gen[-overlap_len:]
@@ -576,7 +615,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
wav_chunk[:overlap_len] += crossfade_wav

wav_overlap = wav_gen[-overlap_len:]
wav_gen_prev = wav_gen
return wav_chunk, wav_gen_prev, wav_overlap
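
For reference, `handle_chunks` stitches consecutive streamed chunks by linearly crossfading `overlap_len` samples: the tail of the previous generation is faded out while the head of the new chunk is faded in, and the two are summed. A standalone sketch of that blend (the function and variable names here are illustrative, not from the diff):

    import torch

    def linear_crossfade(prev_tail: torch.Tensor, new_head: torch.Tensor) -> torch.Tensor:
        # Blend two 1-D overlap regions of equal length, mirroring the
        # torch.linspace fades used in handle_chunks above.
        n = prev_tail.shape[0]
        fade_out = torch.linspace(1.0, 0.0, n, device=prev_tail.device)
        fade_in = torch.linspace(0.0, 1.0, n, device=new_head.device)
        return prev_tail * fade_out + new_head * fade_in

    # e.g. wav_chunk[:overlap_len] = linear_crossfade(wav_overlap, wav_chunk[:overlap_len])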
