Add speech+llm triton trt-llm inference solution #648

Open · wants to merge 2 commits into base: master
3 changes: 0 additions & 3 deletions .gitmodules

This file was deleted.

16 changes: 15 additions & 1 deletion triton/README.md
@@ -7,7 +7,7 @@ In this tutorial, we'll go through how to run non-streaming (offline) and strea
- [Prepare Environment](#prepare-environment)
- [Deploy on Triton Inference Server](#deploy-on-triton-inference-server)
- [Quick Start](#quick-start)
- [Inference Client](client/README.md)
- [Inference Client](#benchmark-using-dataset)
- [Using TensorRT acceleration](#using-tensorrt-acceleration)
- [TRT Quick start](#trt-quick-start)
- [Benchmark for Conformer TRT encoder vs ONNX](#benchmark-for-conformer-trt-encoder-vs-onnx)
@@ -67,6 +67,20 @@ export CUDA_VISIBLE_DEVICES="your_gpu_id"
bash scripts/build_wenetspeech_zipformer_offline_trt.sh
```

## Benchmark using Dataset
```sh
git clone https://github.com/yuekaizhang/Triton-ASR-Client.git
cd Triton-ASR-Client
pip3 install -r requirements.txt
num_task=16
python3 client.py \
--server-addr localhost \
--model-name whisper \
--num-tasks $num_task \
--whisper-prompt "<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>" \
--manifest-dir ./datasets/aishell1_test
```

## Using TensorRT acceleration

### TRT Quick start
1 change: 0 additions & 1 deletion triton/client
Submodule client deleted from 597b70
9 changes: 9 additions & 0 deletions triton/speech_llm/Dockerfile.server
@@ -0,0 +1,9 @@
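# Triton Inference Server 24.08 image with the TensorRT-LLM and Python backends preinstalled.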
FROM nvcr.io/nvidia/tritonserver:24.08-trtllm-python-py3
WORKDIR /workspace

COPY build.sh .
COPY launch_server.sh .
COPY fill_template.py .
COPY model_repo_whisper_qwen_trtllm model_repo_whisper_qwen_trtllm

RUN pip install kaldialign soundfile tritonclient[grpc]
66 changes: 66 additions & 0 deletions triton/speech_llm/README.md
@@ -0,0 +1,66 @@
## Triton Inference Serving Best Practice for Speech LLM

### Model Training
See https://github.com/k2-fsa/icefall/tree/master/egs/speech_llm/ASR_LLM.

### Quick Start
Directly launch the service using docker compose.
```sh
docker compose up --build
```

### Build Image
Build the docker image from scratch.
```sh
# build from scratch, cd to the parent dir of Dockerfile.server
docker build . -f Dockerfile.server -t soar97/triton-whisper-qwen:24.08
```

### Create Docker Container
```sh
your_mount_dir=/mnt:/mnt
docker run -it --name "whisper-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-whisper-qwen:24.08
```

### Export Models to TensorRT-LLM
Inside the docker container, follow the official TensorRT-LLM guide to build the Qwen and Whisper TensorRT-LLM engines. See [here](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/whisper).

```sh
bash build.sh
```

### Launch Server
```sh
bash launch_server.sh
```
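
Once the server is up, you can optionally check that it is ready before sending traffic. A minimal sketch using the `tritonclient` gRPC API (installed in the server image; assumes the default gRPC port 8001):

```python
import tritonclient.grpc as grpcclient

# Connect to the Triton gRPC endpoint started by launch_server.sh (default port 8001).
client = grpcclient.InferenceServerClient(url="localhost:8001")
print("server ready:", client.is_server_ready())
print("infer_bls ready:", client.is_model_ready("infer_bls"))
```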

<!-- ### Launch Gradio WebUI Client
The gradio client supports text and speech as inputs.

```sh
git-lfs install
git clone https://huggingface.co/spaces/yuekai/triton-asr-client.git
cd triton-asr-client
pip3 install -r requirements.txt
python3 app.py
``` -->

### Benchmark using Dataset
```sh
git clone https://github.com/yuekaizhang/Triton-ASR-Client.git
cd Triton-ASR-Client
num_task=16
python3 client.py \
--server-addr localhost \
--model-name infer_bls \
--num-tasks $num_task \
--manifest-dir ./datasets/aishell1_test \
--compute-cer
```

### Benchmark Results
Decoding on a single A10 GPU, with audio padded to 30 s, using the aishell1 test set files.

| Model | Backend | Concurrency | RTF |
|-------|-----------|-----------------------|---------|
| Whisper Large-v2 Encoder + Qwen 1.5B | python backend speech encoder + trt-llm backend llm | 16 | 0.016 |
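
RTF here is the wall-clock decoding time divided by the total duration of the decoded audio, so values below 1 mean faster than real time. A minimal sketch of the calculation, with hypothetical numbers:

```python
def real_time_factor(decode_seconds: float, audio_seconds: float) -> float:
    """RTF = wall-clock decoding time / total duration of the decoded audio."""
    return decode_seconds / audio_seconds

# Hypothetical example: 160 s of wall-clock time to decode 10,000 s of audio.
print(real_time_factor(160.0, 10_000.0))  # 0.016
```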
19 changes: 19 additions & 0 deletions triton/speech_llm/build.sh
@@ -0,0 +1,19 @@
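# Download the packaged checkpoint from Hugging Face and build the Qwen and Whisper
# encoder TensorRT-LLM engines (via the build_*.sh scripts inside the checkpoint repo).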
huggingface_checkpoint_dir=./whisper_qwen_1.5B
huggingface-cli download --local-dir $huggingface_checkpoint_dir yuekai/whisper_qwen_multi_hans_zh_triton_checkpoint
cd $huggingface_checkpoint_dir && bash build_qwen.sh && bash build_whisper_encoder.sh && cd -

model_repo=./model_repo_whisper_qwen_trtllm_exp
rm -rf $model_repo
cp -r ./model_repo_whisper_qwen_trtllm $model_repo || exit 1

engine_path=$huggingface_checkpoint_dir/qwen2_1.5B_instruct_fp16_merged
encoder_engine_dir=$huggingface_checkpoint_dir/whisper_multi_zh
adapter_dir=$huggingface_checkpoint_dir/icefall_asr_multi-hans_whisper_qwen2_1.5B/epoch-2-avg-6.pt
max_batch=16
decoupled_mode=false
max_queue_delay_microseconds=0
n_mels=80
n_instances=8
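# Fill in the placeholder variables in each model's config.pbtxt (tensorrt_llm, speech_encoder, infer_bls).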
python3 fill_template.py -i $model_repo/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:$max_batch,decoupled_mode:${decoupled_mode},max_beam_width:1,engine_dir:${engine_path},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2000,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${max_queue_delay_microseconds}
python3 fill_template.py -i $model_repo/speech_encoder/config.pbtxt triton_max_batch_size:$max_batch,adapter_dir:$adapter_dir,encoder_engine_dir:$encoder_engine_dir,max_queue_delay_microseconds:${max_queue_delay_microseconds}
python3 fill_template.py -i $model_repo/infer_bls/config.pbtxt triton_max_batch_size:$max_batch,n_mels:$n_mels,n_instances:$n_instances,decoupled_mode:${decoupled_mode},max_queue_delay_microseconds:${max_queue_delay_microseconds}
20 changes: 20 additions & 0 deletions triton/speech_llm/docker-compose.yml
@@ -0,0 +1,20 @@
services:
  asr:
    build:
      context: .
      dockerfile: Dockerfile.server
    ports:
      - "8000:8000"
      - "8001:8001"
      - "8002:8002"
    environment:
      - PYTHONIOENCODING=utf-8
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ['0']
              capabilities: [gpu]
    command: >
      /bin/bash -c "bash build.sh && bash launch_server.sh"
38 changes: 38 additions & 0 deletions triton/speech_llm/fill_template.py
@@ -0,0 +1,38 @@
#! /usr/bin/env python3
from argparse import ArgumentParser
from string import Template
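
# Example (see build.sh): fill the placeholders of a Triton config.pbtxt in place:
#   python3 fill_template.py -i model_repo/infer_bls/config.pbtxt \
#       triton_max_batch_size:16,n_mels:80,n_instances:8,decoupled_mode:false,max_queue_delay_microseconds:0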


def main(file_path, substitutions, in_place):
    with open(file_path) as f:
        pbtxt = Template(f.read())

    sub_dict = {"max_queue_size": 0}
    for sub in substitutions.split(","):
        key, value = sub.split(":")
        sub_dict[key] = value

    pbtxt = pbtxt.safe_substitute(sub_dict)

    if in_place:
        with open(file_path, "w") as f:
            f.write(pbtxt)
    else:
        print(pbtxt)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("file_path", help="path of the .pbtxt to modify")
    parser.add_argument(
        "substitutions",
        help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2...",
    )
    parser.add_argument("--in_place",
                        "-i",
                        action="store_true",
                        help="do the operation in-place")
    args = parser.parse_args()

    main(**vars(args))
5 changes: 5 additions & 0 deletions triton/speech_llm/launch_server.sh
@@ -0,0 +1,5 @@
export CUDA_VISIBLE_DEVICES="0"

model_repo=./model_repo_whisper_qwen_trtllm_exp
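# Triton serves HTTP on port 8000, gRPC on 8001, and metrics on 8002 by default
# (the ports published in docker-compose.yml).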

tritonserver --model-repository $model_repo
@@ -0,0 +1,91 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Reference: https://github.com/openai/whisper/blob/main/whisper/audio.py
import numpy as np
import torch
import torch.nn.functional as F
import os


def mel_filters(device, n_mels: int = 128) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling the librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels in (80, 128), f"Unsupported n_mels: {n_mels}"
    with np.load(
        os.path.join(os.path.dirname(__file__), "mel_filters.npz")
    ) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
    audio: torch.Tensor,
    filters: torch.Tensor,
    n_mels: int = 128,
    n_fft: int = 400,
    hop_length: int = 160,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: torch.Tensor, shape = (*)
        A Tensor containing the audio waveform, sampled at 16 kHz

    filters: torch.Tensor
        The mel filterbank matrix, e.g. as returned by mel_filters()

    n_mels: int
        The number of Mel-frequency filters, only 80 or 128 is supported

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the log-Mel spectrogram
    """
    window = torch.hann_window(n_fft).to(audio.device)
    stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    mel_spec = filters @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    # cast to float16
    log_spec = log_spec.half()
    return log_spec

class FeatureExtractor(torch.nn.Module):
    """Compute Whisper-style log-Mel spectrogram features on the GPU."""

    def __init__(self, n_mels: int = 128):
        super().__init__()
        self.device = torch.device("cuda")
        self.n_mels = n_mels
        self.filters = mel_filters(self.device, n_mels=self.n_mels)

    def compute_feature(self, wav, target: int = 3000):
        # target=3000 frames at a 10 ms hop corresponds to 30 s of audio.
        mel = log_mel_spectrogram(wav, self.filters)
        assert mel.shape[1] <= target, f"{mel.shape[1]} > {target}, audio is too long"
        if mel.shape[1] < target:
            mel = F.pad(mel, (0, target - mel.shape[1]), mode='constant')
        mel = mel.unsqueeze(0)
        return mel
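
# A minimal usage sketch (assumption: a 16 kHz mono waveform already on the GPU):
#
#   extractor = FeatureExtractor(n_mels=80)      # n_mels must match the exported encoder
#   wav = torch.randn(16000 * 5, device="cuda")  # 5 s of dummy audio
#   mel = extractor.compute_feature(wav)         # float16 tensor of shape (1, n_mels, 3000)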
Binary file not shown.