Skip to content

Commit

Permalink
feat: anonymised telemetry to track usage patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
Siddharth Sharma committed Apr 16, 2024
1 parent 0345ea8 commit 9ee1cbf
Show file tree
Hide file tree
Showing 10 changed files with 1,084 additions and 644 deletions.
Binary file modified assets/bria.mp3
Binary file not shown.
33 changes: 30 additions & 3 deletions fam/llm/fast_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
get_device,
normalize_text,
)
from fam.telemetry import TelemetryEvent
from fam.telemetry.posthog import PosthogClient

posthog = PosthogClient() # see fam/telemetry/README.md for more information


class TTS:
Expand Down Expand Up @@ -68,7 +72,7 @@ def __init__(
os.makedirs(self.output_dir, exist_ok=True)
if first_stage_path:
print(f"Overriding first stage checkpoint via provided model: {first_stage_path}")
first_stage_ckpt = first_stage_path or f"{self._model_dir}/first_stage.pt"
self._first_stage_ckpt = first_stage_path or f"{self._model_dir}/first_stage.pt"

second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
config_second_stage = InferenceConfig(
Expand All @@ -90,13 +94,16 @@ def __init__(
self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
self.model, self.tokenizer, self.smodel, self.model_size = build_model(
precision=self.precision,
checkpoint_path=Path(first_stage_ckpt),
checkpoint_path=Path(self._first_stage_ckpt),
spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
device=self._device,
compile=True,
compile_prefill=True,
quantisation_mode=quantisation_mode,
)
self._seed = seed
self._quantisation_mode = quantisation_mode
self._model_name = model_name

def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
"""
Expand Down Expand Up @@ -156,8 +163,28 @@ def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.
time_to_synth_s = time.time() - start
audio, sr = librosa.load(str(wav_file) + ".wav")
duration_s = librosa.get_duration(y=audio, sr=sr)
real_time_factor = time_to_synth_s / duration_s
print(f"\nTotal time to synth (s): {time_to_synth_s}")
print(f"Real-time factor: {time_to_synth_s / duration_s:.2f}")
print(f"Real-time factor: {real_time_factor:.2f}")

posthog.capture(
TelemetryEvent(
name="user_ran_tts",
properties={
"text": text,
"temperature": temperature,
"guidance_scale": guidance_scale,
"top_p": top_p,
"spk_ref_path": spk_ref_path,
"speech_duration_s": duration_s,
"time_to_synth_s": time_to_synth_s,
"real_time_factor": round(real_time_factor, 2),
"quantisation_mode": self._quantisation_mode,
"seed": self._seed,
"first_stage_ckpt": self._first_stage_ckpt,
},
)
)

return str(wav_file) + ".wav"

Expand Down
69 changes: 45 additions & 24 deletions fam/llm/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import itertools
import math
from pathlib import Path
import time
from pathlib import Path
from typing import Any, Dict, Optional

import click
Expand All @@ -19,7 +19,12 @@
from fam.llm.model import GPT, GPTConfig
from fam.llm.preprocessing.audio_token_mode import get_params_for_mode
from fam.llm.preprocessing.data_pipeline import get_training_tuple
from fam.llm.utils import hash_dictionary
from fam.telemetry import TelemetryEvent
from fam.telemetry.posthog import PosthogClient

# see fam/telemetry/README.md for more information
posthog = PosthogClient()

dtype: Literal["bfloat16", "float16", "tfloat32", "float32"] = (
"bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "float16"
Expand Down Expand Up @@ -50,11 +55,13 @@
ckpts_save_dir = ckpts_base_dir / out_dir
os.makedirs(ckpts_save_dir, exist_ok=True)


def get_globals_state():
""" Return entirety of configuration global state which can be used for logging. """
"""Return entirety of configuration global state which can be used for logging."""
config_keys = [k for k, v in globals().items() if not k.startswith("_") and isinstance(v, (int, float, bool, str))]
return {k: globals()[k] for k in config_keys} # will be useful for logging


model_args: dict = dict(
n_layer=n_layer,
n_head=n_head,
Expand All @@ -72,6 +79,7 @@ def get_globals_state():
swiglu_multiple_of=swiglu_multiple_of,
) # start with model_args from command line


def strip_prefix(state_dict: Dict[str, Any], unwanted_prefix: str):
# TODO: this also appears in fast_inference_utils._load_model, it should be moved to a common place.
for k, v in list(state_dict.items()):
Expand Down Expand Up @@ -146,19 +154,13 @@ def main(train: Path, val: Path, model_id: str, ckpt: Optional[Path], spk_emb_ck
allow_ops_in_compiled_graph()
model = torch.compile(model) # type: ignore

def estimate_loss(dataset, iters: int=eval_iters):
""" Estimate loss on a dataset by running on `iters` batches. """
def estimate_loss(dataset, iters: int = eval_iters):
"""Estimate loss on a dataset by running on `iters` batches."""
if dataset is None:
return torch.nan
losses = []
for _, batch in zip(tqdm(range(iters)), dataset):
X, Y, SE = get_training_tuple(
batch,
causal,
num_codebooks,
speaker_cond,
device
)
X, Y, SE = get_training_tuple(batch, causal, num_codebooks, speaker_cond, device)
with ctx:
_, loss = model(X, Y, speaker_embs=SE, speaker_emb_mask=None)
losses.append(loss.item())
Expand Down Expand Up @@ -206,9 +208,7 @@ def get_lr(it):
mode_params["ctx_window"],
device,
)
train_dataloader = itertools.cycle(
DataLoader(train_dataset, batch_size, shuffle=True)
)
train_dataloader = itertools.cycle(DataLoader(train_dataset, batch_size, shuffle=True))
train_data = iter(train_dataloader)
# we do not perform any explicit checks for dataset overlap & leave it to the user
# to handle this
Expand All @@ -219,13 +219,7 @@ def get_lr(it):
eval_train_data = DataLoader(train_dataset, batch_size, shuffle=True)

batch = next(train_data)
X, Y, SE = get_training_tuple(
batch,
causal,
num_codebooks,
speaker_cond,
device
)
X, Y, SE = get_training_tuple(batch, causal, num_codebooks, speaker_cond, device)

t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process
Expand All @@ -244,11 +238,29 @@ def get_lr(it):
for param in model.parameters():
param.requires_grad = False
for param in itertools.chain(
model.transformer.ln_f.parameters(), model.transformer.h[last_n_blocks_to_finetune*-1:].parameters()
model.transformer.ln_f.parameters(), model.transformer.h[last_n_blocks_to_finetune * -1 :].parameters()
):
param.requires_grad = True
print(f"After freezing excl. last {last_n_blocks_to_finetune} transformer blocks: {trainable_count(model)=}...")

# log start of finetuning event
properties = {
**config,
**model_args,
"train": str(train),
"val": str(val),
"model_id": model_id,
"ckpt": ckpt,
"spk_emb_ckpt": spk_emb_ckpt,
}
finetune_jobid = hash_dictionary(properties)
posthog.capture(
TelemetryEvent(
name="user_started_finetuning",
properties={"finetune_jobid": finetune_jobid, **properties},
)
)

while True:
lr = get_lr(iter_num) if decay_lr else learning_rate
for param_group in optimizer.param_groups:
Expand Down Expand Up @@ -278,7 +290,9 @@ def get_lr(it):
if losses["val"] < best_val_loss:
best_val_loss = losses["val"]
if iter_num > 0:
ckpt_save_name = ckpt_save_name.replace(".pt", f"_bestval_{best_val_loss}".replace(".", "_") + ".pt")
ckpt_save_name = ckpt_save_name.replace(
".pt", f"_bestval_{best_val_loss}".replace(".", "_") + ".pt"
)
save_checkpoint = True

save_checkpoint = save_checkpoint or iter_num % save_interval == 0
Expand Down Expand Up @@ -352,7 +366,14 @@ def get_lr(it):

# termination conditions
if iter_num > max_iters:
break
# log end of finetuning event
posthog.capture(
TelemetryEvent(
name="user_completed_finetuning",
properties={"finetune_jobid": finetune_jobid},
)
)


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions fam/llm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,8 +402,8 @@ def get_cached_file(file_or_uri: str):
# hash the file path to get the cache name
_cache_name = "audio_" + hashlib.md5(file_or_uri.encode("utf-8")).hexdigest() + ext

os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")
os.makedirs(os.path.expanduser("~/.cache/metavoice/"), exist_ok=True)
cache_path = os.path.expanduser(f"~/.cache/metavoice/{_cache_name}")

if not os.path.exists(cache_path):
command = f"curl -o {cache_path} {file_or_uri}"
Expand Down
14 changes: 14 additions & 0 deletions fam/llm/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hashlib
import json
import os
import re
import subprocess
Expand Down Expand Up @@ -87,3 +89,15 @@ def get_default_dtype() -> str:

def get_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"


def hash_dictionary(d: dict):
# Serialize the dictionary into JSON with sorted keys to ensure consistency
serialized = json.dumps(d, sort_keys=True)
# Encode the serialized string to bytes
encoded = serialized.encode()
# Create a hash object (you can also use sha1, sha512, etc.)
hash_object = hashlib.sha256(encoded)
# Get the hexadecimal digest of the hash
hash_digest = hash_object.hexdigest()
return hash_digest
5 changes: 5 additions & 0 deletions fam/telemetry/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Telemetry

This directory holds all the telemetry for MetaVoice. We, MetaVoice, capture anonymized telemetry to understand usage patterns.

If you prefer to opt out of telemetry, set `ANONYMIZED_TELEMETRY=False` in an .env file at the root level of this repo.
43 changes: 43 additions & 0 deletions fam/telemetry/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import abc
from abc import abstractmethod
from dataclasses import dataclass
import os
import uuid
from pathlib import Path


@dataclass(frozen=True)
class TelemetryEvent:
name: str
properties: dict


class TelemetryClient(abc.ABC):
USER_ID_PATH = str(Path.home() / ".cache" / "metavoice" / "telemetry_user_id")
UNKNOWN_USER_ID = "UNKNOWN"
_curr_user_id = None

@abstractmethod
def capture(self, event: TelemetryEvent) -> None:
pass

@property
def user_id(self) -> str:
if self._curr_user_id:
return self._curr_user_id

# File access may fail due to permissions or other reasons. We don't want to
# crash so we catch all exceptions.
try:
if not os.path.exists(self.USER_ID_PATH):
os.makedirs(os.path.dirname(self.USER_ID_PATH), exist_ok=True)
with open(self.USER_ID_PATH, "w") as f:
new_user_id = str(uuid.uuid4())
f.write(new_user_id)
self._curr_user_id = new_user_id
else:
with open(self.USER_ID_PATH, "r") as f:
self._curr_user_id = f.read()
except Exception:
self._curr_user_id = self.UNKNOWN_USER_ID
return self._curr_user_id
40 changes: 40 additions & 0 deletions fam/telemetry/posthog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import logging
import os
import sys

from dotenv import load_dotenv
from posthog import Posthog

from fam.telemetry import TelemetryClient, TelemetryEvent

load_dotenv()
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout), logging.StreamHandler(sys.stderr)])


class PosthogClient(TelemetryClient):
def __init__(self):
self._posthog = Posthog(
project_api_key="phc_tk7IUlV7Q7lEa9LNbXxyC1sMWlCqiW6DkHyhJrbWMCS", host="https://eu.posthog.com"
)

if not os.getenv("ANONYMIZED_TELEMETRY", True) or "pytest" in sys.modules:
self._posthog.disabled = True
logger.info("Anonymized telemetry disabled. See fam/telemetry/README.md for more information.")
else:
logger.info("Anonymized telemetry enabled. See fam/telemetry/README.md for more information.")

posthog_logger = logging.getLogger("posthog")
posthog_logger.disabled = True # Silence posthog's logging

super().__init__()

def capture(self, event: TelemetryEvent) -> None:
try:
self._posthog.capture(
self.user_id,
event.name,
{**event.properties},
)
except Exception as e:
logger.error(f"Failed to send telemetry event {event.name}: {e}")
Loading

0 comments on commit 9ee1cbf

Please sign in to comment.