Skip to content

Commit

Permalink
refactor(audio.processor): remove duplicate quantization methods
Browse files Browse the repository at this point in the history
  • Loading branch information
eginhard committed Nov 15, 2023
1 parent ddbaecd commit da0f5b4
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 45 deletions.
3 changes: 2 additions & 1 deletion TTS/bin/extract_tts_spectrograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import argparse
import os
from TTS.utils.audio.numpy_transforms import quantize

import numpy as np
import torch
Expand Down Expand Up @@ -197,7 +198,7 @@ def extract_spectrograms(

# quantize and save wav
if quantize_bits > 0:
wavq = ap.quantize(wav, quantize_bits)
wavq = quantize(wav, quantize_bits)
np.save(wavq_path, wavq)

# save TTS mel
Expand Down
40 changes: 0 additions & 40 deletions TTS/utils/audio/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,43 +631,3 @@ def get_duration(self, filename: str) -> float:
filename (str): Path to the wav file.
"""
return librosa.get_duration(filename=filename)

@staticmethod
def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
    """Compand a waveform with the mu-law curve and map it to discrete levels.

    Args:
        wav (np.ndarray): Waveform to encode. Samples are not clipped here;
            for inputs within `[-1, 1]` the result lies in `[0, 2**qc - 1]`.
        qc (int): Number of bits; the companding constant is `mu = 2**qc - 1`.

    Returns:
        np.ndarray: Level indices, floored to whole numbers (floating-point
            values, as produced by `np.floor`).
    """
    mu = 2**qc - 1
    # Logarithmic compression that keeps the sign of every sample.
    companded = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
    # Shift/scale the companded signal onto the level axis; the extra 0.5
    # makes the following floor behave as round-to-nearest.
    levels = (companded + 1) / 2 * mu + 0.5
    return np.floor(levels)

@staticmethod
def mulaw_decode(wav, qc):
    """Expand a mu-law companded signal back to a linear waveform.

    This is the algebraic inverse of the logarithmic companding step only:
    the input is the signed companded signal, not the level indices returned
    by `mulaw_encode`.

    Args:
        wav: Companded signal to expand.
        qc: Number of bits; the companding constant is `mu = 2**qc - 1`.

    Returns:
        The expanded waveform, with the same sign as the input samples.
    """
    mu = 2**qc - 1
    signed_scale = np.sign(wav) / mu
    magnitude = (1 + mu) ** np.abs(wav) - 1
    return signed_scale * magnitude

@staticmethod
def encode_16bits(x):
    """Scale a waveform by `2**15` and store it as saturated `int16` samples.

    Args:
        x: Waveform to convert.

    Returns:
        The scaled samples clipped to `[-2**15, 2**15 - 1]` and cast to
        `np.int16`.
    """
    full_scale = 2**15
    scaled = x * full_scale
    # Clip before casting so out-of-range samples saturate instead of wrapping.
    return np.clip(scaled, -full_scale, full_scale - 1).astype(np.int16)

@staticmethod
def quantize(x: np.ndarray, bits: int) -> np.ndarray:
    """Map a waveform linearly onto the level axis `[0, 2**bits - 1]`.

    Args:
        x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
        bits (int): Number of quantization bits.

    Returns:
        np.ndarray: Rescaled waveform. No rounding is applied, so the values
            are generally fractional.
    """
    top_level = 2**bits - 1
    shifted = x + 1.0
    return shifted * top_level / 2

@staticmethod
def dequantize(x, bits):
    """Map level values in `[0, 2**bits - 1]` linearly back onto `[-1, 1]`.

    Args:
        x: Level values to convert.
        bits: Number of quantization bits.

    Returns:
        The rescaled waveform: level `0` becomes `-1` and level
        `2**bits - 1` becomes `1`.
    """
    top_level = 2**bits - 1
    doubled = 2 * x
    return doubled / top_level - 1
6 changes: 5 additions & 1 deletion TTS/vocoder/datasets/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
mel = ap.melspectrogram(y)
np.save(mel_path, mel)
if isinstance(config.mode, int):
quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode)
quant = (
mulaw_encode(wav=y, mulaw_qc=config.mode)
if config.model_args.mulaw
else quantize(x=y, quantize_bits=config.mode)
)
np.save(quant_path, quant)


Expand Down
5 changes: 4 additions & 1 deletion TTS/vocoder/datasets/wavernn_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize
import numpy as np
import torch
from torch.utils.data import Dataset
Expand Down Expand Up @@ -66,7 +67,9 @@ def load_item(self, index):
x_input = audio
elif isinstance(self.mode, int):
x_input = (
self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode)
mulaw_encode(wav=audio, mulaw_qc=self.mode)
if self.mulaw
else quantize(x=audio, quantize_bits=self.mode)
)
else:
raise RuntimeError("Unknown dataset mode - ", self.mode)
Expand Down
3 changes: 2 additions & 1 deletion TTS/vocoder/models/wavernn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
from dataclasses import dataclass, field
from typing import Dict, List, Tuple
from TTS.utils.audio.numpy_transforms import mulaw_decode

import numpy as np
import torch
Expand Down Expand Up @@ -399,7 +400,7 @@ def inference(self, mels, batched=None, target=None, overlap=None):
output = output[0]

if self.args.mulaw and isinstance(self.args.mode, int):
output = AudioProcessor.mulaw_decode(output, self.args.mode)
output = mulaw_decode(wav=output, mulaw_qc=self.args.mode)

# Fade-out at the end to avoid signal cutting out suddenly
fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length)
Expand Down
3 changes: 2 additions & 1 deletion notebooks/ExtractTTSpectrogram.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.audio.numpy_transforms import quantize\n",
"\n",
"%matplotlib inline\n",
"\n",
Expand Down Expand Up @@ -190,7 +191,7 @@
"\n",
" # quantize and save wav\n",
" if QUANTIZE_BITS > 0:\n",
" wavq = ap.quantize(wav, QUANTIZE_BITS)\n",
" wavq = quantize(wav, QUANTIZE_BITS)\n",
" np.save(wavq_path, wavq)\n",
"\n",
" # save TTS mel\n",
Expand Down

0 comments on commit da0f5b4

Please sign in to comment.