Sidroopdaska/faster decoding #164

Draft · wants to merge 2 commits into base: main
2 changes: 0 additions & 2 deletions .gitignore
@@ -9,7 +9,6 @@
*wandb
*.parquet
*.wav
-*.pt
*.bin
*.png
*.DS_Store
@@ -22,7 +21,6 @@
*.tar
*.db
*.dat
-*.json

# Byte-compiled / optimized / DLL files
__pycache__/
1 change: 1 addition & 0 deletions fam/llm/.gitattributes
@@ -0,0 +1 @@
*.pt filter=lfs diff=lfs merge=lfs -text
3 changes: 3 additions & 0 deletions fam/llm/decoder.pt
Git LFS file not shown
107 changes: 107 additions & 0 deletions fam/llm/decoder_config.json
@@ -0,0 +1,107 @@
{
"data_path": "",
"val_data_path": "",
"wandb_run_name": "",
"_data_path": "",
"cat_encodec_first_two_hierarchies": false,
"use_second_stage": false,
"use_extra_preprocessing": false,
"input_upsampling_factor": 160,
"add_noise": false,
"_val_data_path": "",
"_wandb_run_name": "",
"val_num_dl_workers": 1,
"val_batch_size": 1,
"resblock": "1",
"num_gpus": 2,
"batch_size": 32,
"learning_rate": 0.00005,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.999,
"use_speaker_embedding": false,
"seed": 1234,
"upsample_rates": [
10,
2,
2,
2,
2
],
"upsample_kernel_sizes": [
20,
4,
4,
4,
4
],
"upsample_initial_channel": 1536,
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"activation": "snakebeta",
"snake_logscale": true,
"resolutions": [
[
1024,
120,
600
],
[
2048,
240,
1200
],
[
512,
50,
240
]
],
"mpd_reshapes": [
2,
3,
5,
7,
11
],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"inital_channels": 2048,
"segment_size": 10240,
"_comment": "below specifies size of conv_pre, and is used inside commented out data loaders!",
"num_mels": 80,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 320,
"win_size": 1024,
"sampling_rate": 24000,
"fmin": 0,
"fmax": 12000,
"fmax_for_loss": null,
"num_dl_workers": 32,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 2
}
}
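
For orientation, the upsample_rates and input_upsampling_factor above are the two quantities the new fast_inference.py code below compares before deciding whether to interpolate the first-stage embeddings. A minimal standalone sketch of that arithmetic, using values copied from this config (not code from the PR itself):

```python
import math

# Values copied from decoder_config.json above.
upsample_rates = [10, 2, 2, 2, 2]
input_upsampling_factor = 160

# The decoder's total upsampling is the product of its per-stage rates.
model_upsample_factor = math.prod(upsample_rates)  # 10 * 2 * 2 * 2 * 2 = 160

# With this config the two factors match, so the interpolation branch in
# TTS.synthesise (input_upsampling_factor != model_upsample_factor) is skipped.
assert input_upsampling_factor == model_upsample_factor
```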
125 changes: 67 additions & 58 deletions fam/llm/fast_inference.py
@@ -1,28 +1,20 @@
import json
import math
import os
import shutil
import tempfile
import time
import uuid
from pathlib import Path
from typing import Literal, Optional

import librosa
import scipy.io.wavfile # type: ignore
import torch
import tyro
from huggingface_hub import snapshot_download
from huggingface_hub import snapshot_download # type: ignore

from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
from fam.llm.decoders import EncodecDecoder
from fam.llm.fast_inference_utils import build_model, main
from fam.llm.inference import (
EncodecDecoder,
InferenceConfig,
Model,
TiltedEncodec,
TrainedBPETokeniser,
get_cached_embedding,
get_cached_file,
get_enhancer,
)
from fam.llm.inference import get_cached_embedding, get_cached_file
from fam.llm.model_decoder import EmbeddingDecoder
from fam.llm.utils import (
check_audio_file,
get_default_dtype,
@@ -35,12 +27,20 @@
posthog = PosthogClient() # see fam/telemetry/README.md for more information


class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self


class TTS:
END_OF_AUDIO_TOKEN = 1024

def __init__(
self,
model_name: str = "metavoiceio/metavoice-1B-v0.1",
decoder_config_path: str = f"{os.path.dirname(os.path.abspath(__file__))}/decoder_config.json",
decoder_checkpoint_file: str = f"{os.path.dirname(os.path.abspath(__file__))}/decoder.pt",
*,
seed: int = 1337,
output_dir: str = "outputs",
@@ -69,29 +69,25 @@ def __init__(
self._dtype = get_default_dtype()
self._device = get_device()
self._model_dir = snapshot_download(repo_id=model_name)
self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
self.output_dir = output_dir
os.makedirs(self.output_dir, exist_ok=True)
if first_stage_path:
print(f"Overriding first stage checkpoint via provided model: {first_stage_path}")
self._first_stage_ckpt = first_stage_path or f"{self._model_dir}/first_stage.pt"

second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
config_second_stage = InferenceConfig(
ckpt_path=second_stage_ckpt_path,
num_samples=1,
seed=seed,
device=self._device,
dtype=self._dtype,
compile=False,
init_from="resume",
output_dir=self.output_dir,
)
data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
self.llm_second_stage = Model(
config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
)
self.enhancer = get_enhancer("df")
if not os.path.exists(decoder_config_path):
raise ValueError(f"EmbeddingDecoder config file not found at {decoder_config_path}")

if not os.path.exists(decoder_checkpoint_file):
raise ValueError(f"EmbeddingDecoder checkpoint file not found at {decoder_checkpoint_file}")

with open(decoder_config_path) as f:
self.decoder_config = AttrDict(json.loads(f.read()))
self.decoder = EmbeddingDecoder(self.decoder_config).to(self._device)
state_dict_g = torch.load(decoder_checkpoint_file, map_location=self._device)
self.decoder.load_state_dict(state_dict_g["generator"])
self.decoder.eval()
self.decoder.remove_weight_norm()

self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
self.model, self.tokenizer, self.smodel, self.model_size = build_model(
@@ -108,7 +104,7 @@ def __init__(
self._model_name = model_name
self._telemetry_origin = telemetry_origin

def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=2.0, temperature=1.0) -> str:
"""
text: Text to speak
spk_ref_path: Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3
@@ -128,7 +124,7 @@ def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.

start = time.time()
# first stage LLM
tokens = main(
_, output_embs = main(
model=self.model,
tokenizer=self.tokenizer,
model_size=self.model_size,
@@ -138,33 +134,46 @@ def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.
guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision),
temperature=torch.tensor(temperature, device=self._device, dtype=self.precision),
)
_, extracted_audio_ids = self.first_stage_adapter.decode([tokens])

b_speaker_embs = spk_emb.unsqueeze(0)

# second stage LLM + multi-band diffusion model
wav_files = self.llm_second_stage(
texts=[text],
encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)],
speaker_embs=b_speaker_embs,
batch_size=1,
guidance_scale=None,
top_p=None,
top_k=200,
temperature=1.0,
max_new_tokens=None,
)
# TODO: run EmbeddingDecoder, and save and print output wav_file path?
output_embs = output_embs.to(dtype=torch.float32).transpose(1, 2) # (b, c, t)

model_upsample_factor = math.prod(self.decoder_config.upsample_rates) # type: ignore
if self.decoder_config.input_upsampling_factor != model_upsample_factor: # type: ignore
output_embs = torch.nn.functional.interpolate(
output_embs,
scale_factor=[
self.decoder_config.input_upsampling_factor / model_upsample_factor # type: ignore
], # [320/256] or [160 / 128],
mode="linear",
)

if self.decoder_config.add_noise: # type: ignore
output_embs = torch.cat(
[
output_embs,
torch.randn(
# add model_upsample_factor worth of noise to each input!
(output_embs.shape[0], model_upsample_factor, output_embs.shape[-1]),
device=output_embs.device,
dtype=output_embs.dtype,
),
],
dim=1,
)

with torch.no_grad():
y_g_hat = self.decoder(output_embs)
audio = y_g_hat.squeeze()
audio = audio * 32768.0
audio = audio.cpu().numpy().astype("int16")

# enhance using deepfilternet
wav_file = wav_files[0]
with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
self.enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
print(f"\nSaved audio to {wav_file}.wav")
wav_file_name = str(Path(self.output_dir) / f"synth_{uuid.uuid4()}.wav")
scipy.io.wavfile.write(wav_file_name, 24000, audio)
print(f"\nSaved audio to {wav_file_name}.wav")

# calculating real-time factor (RTF)
time_to_synth_s = time.time() - start
audio, sr = librosa.load(str(wav_file) + ".wav")
audio, sr = librosa.load(str(wav_file_name))
duration_s = librosa.get_duration(y=audio, sr=sr)
real_time_factor = time_to_synth_s / duration_s
print(f"\nTotal time to synth (s): {time_to_synth_s}")
@@ -192,7 +201,7 @@ def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.
)
)

return str(wav_file) + ".wav"
return str(wav_file_name)


if __name__ == "__main__":
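
Taken together, the new path drops the second-stage LLM, multi-band diffusion and DeepFilterNet enhancement, and instead feeds the first-stage embeddings straight into the EmbeddingDecoder vocoder. A rough usage sketch of the resulting API, assuming the default decoder files added in this PR are present (the reference clip name is a placeholder, not part of this PR):

```python
from fam.llm.fast_inference import TTS

# Uses the default decoder_config.json / decoder.pt added in this PR.
tts = TTS()

# "speaker_ref.wav" is a placeholder; per the docstring, a >=30s wav/flac/mp3
# reference (local path or public URI) is expected.
wav_path = tts.synthesise(
    text="Hello from the faster decoding branch.",
    spk_ref_path="speaker_ref.wav",
)
print(wav_path)  # e.g. outputs/synth_<uuid>.wav
```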