Skip to content

Commit

Permalink
llm: support streamed responses
Browse files Browse the repository at this point in the history
  • Loading branch information
kyriediculous committed Aug 1, 2024
1 parent 0619926 commit 922f9d2
Show file tree
Hide file tree
Showing 7 changed files with 341 additions and 187 deletions.
226 changes: 109 additions & 117 deletions runner/app/pipelines/llm_generate.py
Original file line number Diff line number Diff line change
@@ -1,141 +1,133 @@
import asyncio
import logging
import os
from typing import Dict, Any, Optional
from typing import Dict, Any, List, Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from huggingface_hub import file_download, hf_hub_download
from threading import Thread
from typing import AsyncGenerator, Union, Dict, Any, Optional, List

logger = logging.getLogger(__name__)


# class LLMGeneratePipeline(Pipeline):
# def __init__(self, model_id: str):
# self.model_id = model_id
# kwargs = {
# "cache_dir": get_model_dir()
# }
# self.device = get_torch_device()
# folder_name = file_download.repo_folder_name(
# repo_id=model_id, repo_type="model"
# )
# folder_path = os.path.join(get_model_dir(), folder_name)

# # Check for fp16 variant
# has_fp16_variant = any(
# ".fp16.safetensors" in fname
# for _, _, files in os.walk(folder_path)
# for fname in files
# )
# if self.device != "cpu" and has_fp16_variant:
# logger.info("LLMGeneratePipeline loading fp16 variant for %s", model_id)
# kwargs["torch_dtype"] = torch.float16
# kwargs["variant"] = "fp16"

# # Load tokenizer
# self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

# # Load model
# self.model = AutoModelForCausalLM.from_pretrained(
# model_id, **kwargs).to(self.device)

# # Set up generation config
# self.generation_config = self.model.generation_config

# # Optional: Add optimizations
# sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
# if sfast_enabled:
# logger.info(
# "LLMGeneratePipeline will be dynamically compiled with stable-fast for %s",
# model_id,
# )
# from app.pipelines.optim.sfast import compile_model
# self.model = compile_model(self.model)

# def __call__(self, prompt: str, system_msg: Optional[str] = None,
# temperature: Optional[float] = None,
# max_tokens: Optional[int] = None, **kwargs) -> Dict[str, Any]:
# if system_msg:
# input_text = f"{system_msg}\n\n{prompt}"
# else:
# input_text = prompt

# input_ids = self.tokenizer.encode(
# input_text, return_tensors="pt").to(self.device)

# # Update generation config
# gen_kwargs = {}
# if temperature is not None:
# gen_kwargs['temperature'] = temperature
# if max_tokens is not None:
# gen_kwargs['max_new_tokens'] = max_tokens

# # Merge generation config with provided kwargs
# gen_kwargs = {**self.generation_config.to_dict(), **gen_kwargs, **kwargs}

# # Generate response
# with torch.no_grad():
# output = self.model.generate(
# input_ids,
# **gen_kwargs
# )

# # Decode the response
# response = self.tokenizer.decode(output[0], skip_special_tokens=True)

# # Calculate tokens used
# tokens_used = len(output[0])

# return {
# "response": response.strip(),
# "tokens_used": tokens_used
# }

# def __str__(self) -> str:
# return f"LLMPipeline model_id={self.model_id}"


class LLMGeneratePipeline(Pipeline):
def __init__(self, model_id: str):
self.model_id = model_id
self.device = get_torch_device()

kwargs = {
"cache_dir": get_model_dir(),
"device_map": "auto",
"torch_dtype": torch.bfloat16 if self.device != "cpu" else torch.float32,
"cache_dir": get_model_dir()
}

logger.info(f"Loading model {model_id}")
self.pipeline = pipeline(
"text-generation",
model=model_id,
tokenizer=model_id,
**kwargs
self.device = get_torch_device()
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
)
folder_path = os.path.join(get_model_dir(), folder_name)

def __call__(self, prompt: str, system_msg: str = None, **kwargs):
messages = []
if system_msg:
messages.append({"role": "system", "content": system_msg})
messages.append({"role": "user", "content": prompt})

outputs = self.pipeline(
messages,
max_new_tokens=kwargs.get("max_tokens", 256),
temperature=kwargs.get("temperature", 0.7),
# Check for fp16 variant
has_fp16_variant = any(
".fp16.safetensors" in fname
for _, _, files in os.walk(folder_path)
for fname in files
)
if self.device != "cpu" and has_fp16_variant:
logger.info("LLMGeneratePipeline loading fp16 variant for %s", model_id)
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"
elif self.device != "cpu":
kwargs["torch_dtype"] = torch.bfloat16

response = outputs[0]["generated_text"]
# Assuming the response is the last message in the conversation
response = response.split("assistant:")[-1].strip()
# Add device mapping
kwargs["device_map"] = "auto"

return {
"response": response,
"tokens_used": len(self.pipeline.tokenizer.encode(response))
}
logger.info(f"Loading model {model_id}")
self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
self.model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)

# Set up generation config
self.generation_config = self.model.generation_config

self.terminators = [
self.tokenizer.eos_token_id,
self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

# Optional: Add optimizations
sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
if sfast_enabled:
logger.info(
"LLMGeneratePipeline will be dynamically compiled with stable-fast for %s",
model_id,
)
from app.pipelines.optim.sfast import compile_model
self.model = compile_model(self.model)

async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
    """Stream a chat completion for ``prompt`` as an async generator.

    Args:
        prompt: The new user message.
        history: Optional earlier turns as ``(user, assistant)`` pairs; each
            pair is added to the conversation in order.
        system_msg: Optional system message placed first in the conversation.
        **kwargs: ``max_tokens`` (default 256) and ``temperature``
            (default 0.7) are read; other keys are ignored.

    Yields:
        Each text chunk produced by the streamer (``str``), followed by one
        final ``{"tokens_used": int}`` dict: prompt length plus the number of
        streamed chunks.
        NOTE(review): the chunk count is used as the generated-token count;
        confirm that is accurate enough for billing/accounting.

    Raises:
        Exception: anything raised while reading from the streamer (for
            example its 10 s timeout) is logged and re-raised.
    """
    conversation = []
    if system_msg:
        conversation.append({"role": "system", "content": system_msg})
    if history:
        for user, assistant in history:
            conversation.extend([
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ])
    conversation.append({"role": "user", "content": prompt})

    # add_generation_prompt=True makes the chat template open the assistant
    # turn, so the model answers instead of continuing the user message.
    input_ids = self.tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(self.model.device)
    attention_mask = torch.ones_like(input_ids)

    # An explicit None (e.g. an omitted form field) falls back to the default
    # instead of being forwarded to generate().
    max_new_tokens = kwargs.get("max_tokens")
    if max_new_tokens is None:
        max_new_tokens = 256
    temperature = kwargs.get("temperature")
    if temperature is None:
        temperature = 0.7

    streamer = TextIteratorStreamer(
        self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    # Start with the generation config
    generate_kwargs = self.generation_config.to_dict()
    # Update with our specific parameters
    generate_kwargs.update({
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "eos_token_id": self.tokenizer.eos_token_id,
        "pad_token_id": self.tokenizer.eos_token_id,
    })

    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
    if temperature == 0:
        generate_kwargs['do_sample'] = False

    # Start generation in a separate thread
    thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs)
    thread.start()

    # The streamer's iterator blocks on a queue (up to `timeout` seconds per
    # chunk). Pulling each chunk in the default executor keeps the event loop
    # free to serve other requests while the model is generating.
    loop = asyncio.get_running_loop()
    chunks = iter(streamer)
    end_of_stream = object()  # default for next(): StopIteration never crosses the future

    total_tokens = 0
    try:
        while True:
            text = await loop.run_in_executor(None, next, chunks, end_of_stream)
            if text is end_of_stream:
                break
            total_tokens += 1
            yield text
    except Exception as e:
        logger.error(f"Error during streaming: {str(e)}")
        raise

    input_length = input_ids.size(1)
    yield {"tokens_used": input_length + total_tokens}

def __str__(self) -> str:
    """Return a short description of this pipeline that includes its model id."""
    return f"LLMGeneratePipeline(model_id={self.model_id})"

def model_generate_wrapper(self, **kwargs):
    """Run ``self.model.generate(**kwargs)`` under CUDA autocast.

    Used as the target of the generation thread started by ``__call__``.

    Args:
        **kwargs: Forwarded unchanged to ``self.model.generate``.

    Raises:
        Exception: any error from ``generate`` is logged with its traceback
            and re-raised.
    """
    try:
        logger.debug("Entering model.generate")
        # torch.cuda.amp.autocast() is deprecated in favour of torch.autocast
        # and warns when CUDA is missing. Enabling it only when CUDA is
        # available keeps mixed precision on GPU and is a silent no-op on CPU.
        with torch.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            self.model.generate(**kwargs)
        logger.debug("Exiting model.generate")
    except Exception as e:
        logger.error(f"Error in model.generate: {str(e)}", exc_info=True)
        raise
68 changes: 57 additions & 11 deletions runner/app/routes/llm_generate.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
import logging
import os
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, Form, status
from fastapi.responses import JSONResponse
from typing import Annotated, Optional, List
from fastapi import APIRouter, Depends, Form, status, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, LlmResponse, TextResponse, http_error
import json

router = APIRouter()

logger = logging.getLogger(__name__)

RESPONSES = {
status.HTTP_200_OK: {"model": LlmResponse},
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


@router.post("/llm-generate", response_model=LlmResponse, responses=RESPONSES)
@router.post("/llm-generate",
response_model=LlmResponse, responses=RESPONSES)
@router.post("/llm-generate/", response_model=LlmResponse, responses=RESPONSES, include_in_schema=False)
async def llm_generate(
prompt: Annotated[str, Form()],
model_id: Annotated[str, Form()] = "",
system_msg: Annotated[str, Form()] = None,
temperature: Annotated[float, Form()] = None,
max_tokens: Annotated[int, Form()] = None,
system_msg: Annotated[str, Form()] = "",
temperature: Annotated[float, Form()] = 0.7,
max_tokens: Annotated[int, Form()] = 256,
history: Annotated[str, Form()] = "[]", # We'll parse this as JSON
stream: Annotated[bool, Form()] = False,
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
Expand All @@ -49,16 +54,57 @@ async def llm_generate(
)

try:
result = pipeline(
history_list = json.loads(history)
if not isinstance(history_list, list):
raise ValueError("History must be a JSON array")

generator = pipeline(
prompt=prompt,
system_msg=system_msg,
history=history_list,
system_msg=system_msg if system_msg else None,
temperature=temperature,
max_tokens=max_tokens
)
return JSONResponse(content=result)

if stream:
return StreamingResponse(stream_generator(generator), media_type="text/event-stream")
else:
full_response = ""
async for chunk in generator:
if isinstance(chunk, dict):
tokens_used = chunk["tokens_used"]
break
full_response += chunk

return LlmResponse(response=full_response, tokens_used=tokens_used)

except json.JSONDecodeError:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": "Invalid JSON format for history"}
)
except ValueError as ve:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(ve)}
)
except Exception as e:
logger.error(f"LLM processing error: {str(e)}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=http_error("Internal server error during LLM processing."),
content={"detail": "Internal server error during LLM processing."}
)


async def stream_generator(generator):
    """Convert pipeline output into server-sent-event ``data:`` lines.

    String items are wrapped as ``{"chunk": <item>}``. The first dict item is
    emitted as-is and ends consumption of ``generator``. A ``data: [DONE]``
    line follows once the loop finishes. If an exception is raised while
    streaming, it is logged and a single ``{"error": <message>}`` event is
    emitted instead.
    """
    try:
        async for item in generator:
            is_final = isinstance(item, dict)  # the dict is the final result
            payload = item if is_final else {'chunk': item}
            yield f"data: {json.dumps(payload)}\n\n"
            if is_final:
                break
        yield "data: [DONE]\n\n"
    except Exception as e:
        logger.error(f"Streaming error: {str(e)}")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
18 changes: 18 additions & 0 deletions runner/check_torch_cuda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import subprocess

import torch

# Diagnostic script: prints the PyTorch build, CUDA availability, the CUDA
# toolkit version reported by `nvcc`, and the name of CUDA device 0.

print(f"PyTorch version: {torch.__version__}")
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA version: {torch.version.cuda}")

# Check system CUDA version
try:
    nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
except (OSError, subprocess.CalledProcessError, UnicodeDecodeError):
    # nvcc is missing, exited non-zero, or printed undecodable output. The
    # check is best-effort, but no longer swallows KeyboardInterrupt/SystemExit.
    print("Unable to check system CUDA version")
else:
    # Take the text between "release " and the following comma.
    cuda_version = nvcc_output.split("release ")[-1].split(",")[0]
    print(f"System CUDA version: {cuda_version}")

# Print the current device. Only query it when CUDA is usable, because
# get_device_name() raises on a machine without CUDA.
if cuda_available:
    print(f"Current device: {torch.cuda.get_device_name(0)}")
else:
    print("Current device: none (CUDA not available)")
Loading

0 comments on commit 922f9d2

Please sign in to comment.