LoRA Fine Tuning #82

Draft · wants to merge 11 commits into main
9 changes: 9 additions & 0 deletions .gitignore
@@ -186,3 +186,12 @@ cython_debug/
#.idea/
**/.tmp
!fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt


# Local data files for testing
dataset


# Dummy dataset for fine-tuning demonstration.
# Contains 25 samples from a single speaker ("p311" from the VCTK dataset).
!dummy_dataset/**
4 changes: 3 additions & 1 deletion app.py
@@ -12,7 +12,9 @@
from fam.llm.utils import check_audio_file

#### setup model
TTS_MODEL = TTS()
lora_ckpt_path = 'saved_models/finetune_001/lora_iter_num_5.pt'
# lora_ckpt_path = None  # uncomment to run the base model without LoRA
TTS_MODEL = TTS(lora_ckpt_path=lora_ckpt_path)

#### setup interface
RADIO_CHOICES = ["Preset voices", "Upload target voice (at least 30s)"]
179 changes: 179 additions & 0 deletions dataloader.py
@@ -0,0 +1,179 @@
import os
import pathlib
import typing as tp

import julius
import torch
import torchaudio
from audiocraft.data.audio import audio_read
from encodec import EncodecModel
from torch.utils.data import Dataset

from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
from fam.llm.fast_inference_utils import encode_tokens
from fam.llm.inference import SpeakerEncoder, TrainedBPETokeniser, get_cached_embedding
from fam.llm.utils import normalize_text

MBD_SAMPLE_RATE = 24000
END_OF_AUDIO_TOKEN = 1024

class MetavoiceData(Dataset):
    def __init__(
        self,
        dataset_dir: str,
        block_size: int,
        validation_split: float,
        encodec_model: EncodecModel,
        tokenizer: TrainedBPETokeniser,
        spkemb_model: SpeakerEncoder,
        device: str,
        precision: torch.dtype,
    ):

self.dataset_dir = dataset_dir
self.block_size = block_size
self.validation_split = validation_split
self.encodec_model = encodec_model
self.tokenizer = tokenizer
self.spkemb_model = spkemb_model
self.device = device
self.precision = precision

self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=END_OF_AUDIO_TOKEN)

# Loop through dataset_dir and create a list of tuples (wav_path, text)
# File system will look like:
# dataset_dir/<utt_id>.wav and dataset_dir/<utt_id>.txt
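        # e.g. dataset_dir/p311_001.wav with its transcript in dataset_dir/p311_001.txt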
data_list = []
for audio_file in pathlib.Path(dataset_dir).glob('*.wav'):
utt_id = audio_file.stem
wav_path = f"{dataset_dir}/{utt_id}.wav"
txt_path = f"{dataset_dir}/{utt_id}.txt"
            with open(txt_path, 'r') as f:
                text = f.read().strip()  # drop the trailing newline so it is not tokenized

wav, sr = torchaudio.load(wav_path)
            if sr != MBD_SAMPLE_RATE:
                # Resample once and rewrite the file in place so later reads
                # (encodec tokens, speaker embedding) already see 24 kHz audio.
                wav = julius.resample_frac(wav, sr, MBD_SAMPLE_RATE)
                torchaudio.save(wav_path, wav, MBD_SAMPLE_RATE)

data_list.append((wav_path, text))

self._prepare_dataset(data_list)

def _prepare_dataset(self, data_list: tp.List[tp.Tuple[str, str]]):
# We take data_list, extract all prompts and encodec tokens, and append them with EOT for all of them
# This is done to prepare the dataset for the first stage of training
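        # The resulting token stream looks like:
        #   [text tokens | interleaved audio tokens | EOT | text tokens | ...]
        # i.e. one long corpus from which fixed-size training windows are cut.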

full_sequence = torch.tensor([], dtype=torch.long, device=self.device)
spk_embds = []
current_wavs = torch.tensor([], dtype=torch.float, device=self.device)
current_wav_duration = 0
for wav_path, text in data_list:
# Extract text tokenization
prompt = self._extract_text_tokens(text)

# Extract encodec tokens
encodec_tokens = self._extract_encodec_tokens(wav_path)

# Concatenate prompt and encodec tokens, and EOT token at the end
eot = torch.tensor([END_OF_AUDIO_TOKEN], dtype=torch.long, device=self.device)
sequence = torch.cat((prompt, encodec_tokens, eot))

# Append to dataset
# print("Encodec Tokens Length: ", encodec_tokens.size(0))
# print("Prompt Length: ", prompt.size(0))
# print("Tokenized Data Point length:", sequence.size(0))
# print("Prompt: ", prompt)
full_sequence = torch.cat((full_sequence, sequence), dim=-1)

# Get wav data
wav, sr = torchaudio.load(wav_path) # Load the audio file
if sr != MBD_SAMPLE_RATE:
wav = julius.resample_frac(wav, sr, MBD_SAMPLE_RATE)
if wav.ndim == 2:
wav = wav.mean(dim=0) # Average channels if stereo
wav = wav.to(self.device)
current_wavs = torch.cat((current_wavs, wav.unsqueeze(0)), dim=1) # Concatenate along time axis
current_wav_duration += wav.size(0) / MBD_SAMPLE_RATE
            if current_wav_duration >= 45:  # compute one speaker embedding per ~45 s of accumulated audio
current_wav_path = os.path.join(self.dataset_dir, "tmp_concatenated_wavs.wav")
torchaudio.save(current_wav_path, current_wavs.cpu(), MBD_SAMPLE_RATE)

# Extract speaker embeddings of the concatenated wav
spk_emb = self._extract_speaker_embeddings(current_wav_path)
spk_embds.append(spk_emb)

# Reset
current_wav_duration = 0
current_wavs = torch.tensor([], dtype=torch.float32, device=self.device)
os.remove(current_wav_path)
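
        # NOTE: trailing audio shorter than 45 s never contributes a speaker
        # embedding; with very little total audio, spk_embds could be empty.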

# Split full_sequence into training and validation
split = int(len(full_sequence) * (1 - self.validation_split))
self.train_dataset = full_sequence[:split]
self.val_dataset = full_sequence[split:]

self.spk_embds = torch.stack(spk_embds) # (N, 1, 256)

    def get_batch(self, split: tp.Literal['train', 'val'], batch_size: int):
        if split == 'train':
            data = self.train_dataset
        elif split == 'val':
            data = self.val_dataset
        else:
            raise ValueError(f"Unknown split: {split!r}")

        # Requires len(data) > block_size: with a small validation_split, the
        # val stream may be shorter than a single block.
        ix = torch.randint(0, data.size(0) - self.block_size, (batch_size,))
x = torch.stack([data[i:i+self.block_size] for i in ix])
y = torch.stack([data[i+1:i+self.block_size+1] for i in ix])

# Random batch_size number of speaker embeddings
spk_emb = self.spk_embds[torch.randint(0, self.spk_embds.size(0), (batch_size,))]
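        # NOTE: embeddings are sampled independently of the token windows, so a
        # window is not guaranteed its own utterance's embedding (harmless for
        # a single-speaker dataset such as dummy_dataset).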

return x, y, spk_emb

def _extract_text_tokens(self, text: str):
# For text tokens, one can use the tokenizer per:
# https://github.com/metavoiceio/metavoice-src/blob/main/fam/llm/inference.py#L177
text = normalize_text(text)
encoded = encode_tokens(self.tokenizer, text, device=self.device)

return encoded

def _extract_encodec_tokens(self, wav_path: str):
# read audio
wav, sr = audio_read(wav_path)

# Resample to MBD's expected sample rate
if sr != MBD_SAMPLE_RATE:
wav = julius.resample_frac(wav, sr, MBD_SAMPLE_RATE)

# Convert to mono and fix dimensionality
if wav.ndim == 2:
wav = wav.mean(axis=0, keepdims=True)
wav = wav.unsqueeze(0) # Add batch dimension

# Extract tokens
wav = wav.to(self.device)
tokens = self.encodec_model.encode(wav) # list[EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]]

tokens = tokens[0][0][0] # (8, T)

# Only return tokens in first 2 hierarchies for training stage 1
# Not sure if this is the correct approach.
tokens = tokens[:2] # (2, T)

        # Interleave the two hierarchies and flatten (2, T) -> (2T,), then add
        # 1024 to hierarchy-0 tokens to match the stage-1 output vocabulary.
        # Note: .t() before .flatten() gives true interleaving; flattening the
        # (2, T) tensor directly would concatenate the hierarchies instead.
        tokens = tokens.t().flatten().to(dtype=torch.int32)  # (2*T,)
        tokens[0::2] += END_OF_AUDIO_TOKEN
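        # Worked example with T=2: hierarchy 0 = [a0, a1], hierarchy 1 = [b0, b1]
        #   interleave -> [a0, b0, a1, b1]
        #   offset     -> [a0+1024, b0, a1+1024, b1]
        # so the two codebooks occupy disjoint id ranges in one stream.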

return tokens

def _extract_speaker_embeddings(self, wav_path: str):
# For speaker embedding, you can also follow the code at:
# https://github.com/metavoiceio/metavoice-src/blob/main/fam/llm/inference.py#L435
return get_cached_embedding(wav_path, self.spkemb_model).to(self.device, dtype=self.precision)
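
For reference, a minimal usage sketch of the loader above. This is a sketch under assumptions: block_size is set to the model's context length, and tokenizer / spkemb_model are the TrainedBPETokeniser and SpeakerEncoder produced by the existing checkpoint-loading code — neither value is fixed by this diff.

import torch
from encodec import EncodecModel

encodec = EncodecModel.encodec_model_24khz().to("cuda")
encodec.set_target_bandwidth(6.0)  # 8 codebooks, matching the (8, T) comment above

data = MetavoiceData(
    dataset_dir="dummy_dataset",
    block_size=2048,            # assumption: first-stage context length
    validation_split=0.1,
    encodec_model=encodec,
    tokenizer=tokenizer,        # from the base checkpoint (TrainedBPETokeniser)
    spkemb_model=spkemb_model,  # from the base checkpoint (SpeakerEncoder)
    device="cuda",
    precision=torch.bfloat16,
)
x, y, spk_emb = data.get_batch("train", batch_size=4)  # x, y: (4, 2048); spk_emb: (4, 1, 256)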
1 change: 1 addition & 0 deletions dummy_dataset/p311_001.txt
@@ -0,0 +1 @@
Please call Stella.
Binary file added dummy_dataset/p311_001.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_002.txt
@@ -0,0 +1 @@
Ask her to bring these things with her from the store.
Binary file added dummy_dataset/p311_002.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_003.txt
@@ -0,0 +1 @@
Six spoons of fresh snow peas, five thick slabs of blue cheese, and maybe a snack for her brother Bob.
Binary file added dummy_dataset/p311_003.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_004.txt
@@ -0,0 +1 @@
We also need a small plastic snake and a big toy frog for the kids.
Binary file added dummy_dataset/p311_004.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_005.txt
@@ -0,0 +1 @@
She can scoop these things into three red bags, and we will go meet her Wednesday at the train station.
Binary file added dummy_dataset/p311_005.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_006.txt
@@ -0,0 +1 @@
When the sunlight strikes raindrops in the air, they act as a prism and form a rainbow.
Binary file added dummy_dataset/p311_006.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_007.txt
@@ -0,0 +1 @@
The rainbow is a division of white light into many beautiful colors.
Binary file added dummy_dataset/p311_007.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_008.txt
@@ -0,0 +1 @@
These take the shape of a long round arch, with its path high above, and its two ends apparently beyond the horizon.
Binary file added dummy_dataset/p311_008.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_009.txt
@@ -0,0 +1 @@
There is, according to legend, a boiling pot of gold at one end.
Binary file added dummy_dataset/p311_009.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_010.txt
@@ -0,0 +1 @@
People look, but no one ever finds it.
Binary file added dummy_dataset/p311_010.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_011.txt
@@ -0,0 +1 @@
When a man looks for something beyond his reach, his friends say he is looking for the pot of gold at the end of a rainbow.
Binary file added dummy_dataset/p311_011.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_012.txt
@@ -0,0 +1 @@
Throughout the centuries people have explained the rainbow in various ways.
Binary file added dummy_dataset/p311_012.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_013.txt
@@ -0,0 +1 @@
Some have accepted it as a miracle without physical explanation.
Binary file added dummy_dataset/p311_013.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_014.txt
@@ -0,0 +1 @@
To the Hebrews it was a token that there would be no more universal floods.
Binary file added dummy_dataset/p311_014.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_015.txt
@@ -0,0 +1 @@
The Greeks used to imagine that it was a sign from the gods to foretell war or heavy rain.
Binary file added dummy_dataset/p311_015.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_016.txt
@@ -0,0 +1 @@
The Norsemen considered the rainbow as a bridge over which the gods passed from earth to their home in the sky.
Binary file added dummy_dataset/p311_016.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_017.txt
@@ -0,0 +1 @@
Others have tried to explain the phenomenon physically.
Binary file added dummy_dataset/p311_017.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_018.txt
@@ -0,0 +1 @@
Aristotle thought that the rainbow was caused by reflection of the sun's rays by the rain.
Binary file added dummy_dataset/p311_018.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_019.txt
@@ -0,0 +1 @@
Since then physicists have found that it is not reflection, but refraction by the raindrops which causes the rainbows.
Binary file added dummy_dataset/p311_019.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_020.txt
@@ -0,0 +1 @@
Many complicated ideas about the rainbow have been formed.
Binary file added dummy_dataset/p311_020.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_021.txt
@@ -0,0 +1 @@
The difference in the rainbow depends considerably upon the size of the drops, and the width of the colored band increases as the size of the drops increases.
Binary file added dummy_dataset/p311_021.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_022.txt
@@ -0,0 +1 @@
The actual primary rainbow observed is said to be the effect of super-imposition of a number of bows.
Binary file added dummy_dataset/p311_022.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_023.txt
@@ -0,0 +1 @@
If the red of the second bow falls upon the green of the first, the result is to give a bow with an abnormally wide yellow band, since red and green light when mixed form yellow.
Binary file added dummy_dataset/p311_023.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_024.txt
@@ -0,0 +1 @@
This is a very common type of bow, one showing mainly red and yellow, with little or no green or blue.
Binary file added dummy_dataset/p311_024.wav
1 change: 1 addition & 0 deletions dummy_dataset/p311_025.txt
@@ -0,0 +1 @@
No sooner had we got it, than we wanted it different.
Binary file added dummy_dataset/p311_025.wav
6 changes: 3 additions & 3 deletions fam/llm/adapters/flattened_encodec.py
@@ -24,9 +24,9 @@ def decode(self, tokens: list[list[int]]) -> tuple[list[int], list[list[int]]]:
if len(set([len(x) for x in extracted_audio_ids])) != 1:
min_len = min([len(x) for x in extracted_audio_ids])
max_len = max([len(x) for x in extracted_audio_ids])
print("WARNING: Number of tokens at each hierarchy must be of the same length!")
print(f"Truncating to min length of {min_len} tokens from {max_len} max.")
print([len(x) for x in extracted_audio_ids])
# print("WARNING: Number of tokens at each hierarchy must be of the same length!")
# print(f"Truncating to min length of {min_len} tokens from {max_len} max.")
# print([len(x) for x in extracted_audio_ids])
extracted_audio_ids = [x[:min_len] for x in extracted_audio_ids]

return text_ids[:-1], extracted_audio_ids
4 changes: 3 additions & 1 deletion fam/llm/fast_inference.py
@@ -33,7 +33,8 @@ class TTS:
END_OF_AUDIO_TOKEN = 1024

def __init__(
self, model_name: str = "metavoiceio/metavoice-1B-v0.1", *, seed: int = 1337, output_dir: str = "outputs"
self, model_name: str = "metavoiceio/metavoice-1B-v0.1", *, seed: int = 1337, output_dir: str = "outputs",
lora_ckpt_path: str | None = None
):
"""
model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio)
@@ -73,6 +74,7 @@ def __init__(
device=self._device,
compile=True,
compile_prefill=True,
lora_ckpt_path=lora_ckpt_path,
)

def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
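Combined with the app.py change above, enabling a fine-tuned adapter at inference is then a one-liner; passing None (the default) runs the unmodified base model. The checkpoint path below is the one app.py uses:

from fam.llm.fast_inference import TTS

tts = TTS(lora_ckpt_path="saved_models/finetune_001/lora_iter_num_5.pt")
wav_path = tts.synthesise("Hello world.", spk_ref_path="dummy_dataset/p311_001.wav")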
26 changes: 23 additions & 3 deletions fam/llm/fast_inference_utils.py
@@ -33,6 +33,8 @@
import torch._inductor.config
import tqdm

from lora import TransformerWithLoRA


def device_sync(device):
if "cuda" in device:
@@ -125,7 +127,7 @@ def prefill(
**sampling_kwargs,
) -> torch.Tensor:
# input_pos: [B, S]
logits = model(x, spk_emb, input_pos)
    logits, _ = model(x, spk_emb, input_pos)  # TransformerWithLoRA returns a second value (presumably training loss); unused at inference
return sample(logits, **sampling_kwargs)[0]


@@ -138,7 +140,7 @@ def decode_one_token(
) -> Tuple[torch.Tensor, torch.Tensor]:
# input_pos: [B, 1]
assert input_pos.shape[-1] == 1
logits = model(x, spk_emb, input_pos)
logits, _ = model(x, spk_emb, input_pos)
return sample(logits, **sampling_kwargs)


@@ -208,6 +210,10 @@ def generate(
next_token = prefill(model, prompt.view(1, -1).repeat(2, 1), spk_emb, input_pos, **sampling_kwargs)
seq = torch.cat([seq, next_token.view(1)])

print("max_new_tokens: ", max_new_tokens)
print("next token: ", next_token)
print("seq: ", seq)

input_pos = torch.tensor([T], device=device, dtype=torch.int)

generated_tokens, _ = decode_n_tokens(
@@ -220,6 +226,7 @@
end_of_audio_token=end_of_audio_token,
**sampling_kwargs,
)
print("generated tokens: ", generated_tokens)
seq = torch.cat([seq, torch.cat(generated_tokens)])

return seq
@@ -251,6 +258,7 @@ def _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision):
# from quantize import WeightOnlyInt4QuantHandler
# simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
# model = simple_quantizer.convert_for_runtime()


checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=False)
state_dict = checkpoint["model"]
@@ -319,14 +327,26 @@ def build_model(
compile_prefill: bool = False,
compile: bool = True,
device: str = "cuda",
lora_ckpt_path: str | None = None,
):
assert checkpoint_path.is_file(), checkpoint_path

print(f"Using device={device}")

print("Loading model ...")
t0 = time.time()
model, tokenizer, smodel = _load_model(checkpoint_path, spk_emb_ckpt_path, device, precision)
model, tokenizer, smodel = _load_model(
checkpoint_path,
spk_emb_ckpt_path,
device,
precision,
)

if lora_ckpt_path:
print(f"Loading LoRA from {lora_ckpt_path}")
model = TransformerWithLoRA(model, training_mode=False)
model.load_lora(lora_ckpt_path)
model = model.to(device)

device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")