Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LLM Pipeline #137

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions runner/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
from app.pipelines.upscale import UpscalePipeline

return UpscalePipeline(model_id)
case "llm-generate":
from app.pipelines.llm_generate import LLMGeneratePipeline
return LLMGeneratePipeline(model_id)
case _:
raise EnvironmentError(
f"{pipeline} is not a valid pipeline for model {model_id}"
Expand Down Expand Up @@ -82,6 +85,10 @@ def load_route(pipeline: str) -> any:
from app.routes import upscale

return upscale.router
case "llm-generate":
from app.routes import llm_generate

return llm_generate.router
case _:
raise EnvironmentError(f"{pipeline} is not a valid pipeline")

Expand Down
186 changes: 186 additions & 0 deletions runner/app/pipelines/llm_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import asyncio
import logging
import os
import psutil
from typing import Dict, Any, List, Optional, AsyncGenerator, Union

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, BitsAndBytesConfig
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from huggingface_hub import file_download, snapshot_download
from threading import Thread

logger = logging.getLogger(__name__)

def get_max_memory():
    """Build the per-device memory budget handed to the model loaders.

    Returns:
        A dict keyed by CUDA device index (one entry per visible GPU, using
        the device's total memory) plus a ``"cpu"`` entry (currently
        available system RAM). Every value is a whole-GiB string such as
        ``"23GiB"``.
    """
    bytes_per_gib = 1024**3

    max_memory = {}
    for gpu_index in range(torch.cuda.device_count()):
        total_bytes = torch.cuda.get_device_properties(gpu_index).total_memory
        max_memory[gpu_index] = f"{total_bytes // bytes_per_gib}GiB"

    # The CPU entry goes last so the dict keeps the original key order.
    available_bytes = psutil.virtual_memory().available
    max_memory["cpu"] = f"{available_bytes // bytes_per_gib}GiB"

    logger.info(f"Max memory configuration: {max_memory}")
    return max_memory

def load_model_8bit(model_id: str, **kwargs):
    """Load a causal LM with 8-bit bitsandbytes quantization.

    Args:
        model_id: Model identifier passed to ``from_pretrained``.
        **kwargs: Extra keyword arguments forwarded unchanged to both the
            tokenizer and the model ``from_pretrained`` calls.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    memory_budget = get_max_memory()

    int8_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)

    # Placement and offload options are fixed here; caller kwargs are
    # appended after them, exactly as in a plain keyword call.
    placement_options = {
        "quantization_config": int8_config,
        "device_map": "auto",
        "max_memory": memory_budget,
        "offload_folder": "offload",
        "low_cpu_mem_usage": True,
    }
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        **placement_options,
        **kwargs,
    )

    return tokenizer, model

def load_model_fp16(model_id: str, **kwargs):
    """Load a causal LM in half precision and dispatch it across devices.

    The model skeleton is built on empty (meta) weights from the model's
    config, then the checkpoint is loaded shard by shard and placed on the
    available devices by ``load_checkpoint_and_dispatch``.

    Args:
        model_id: Model identifier used for the tokenizer, the config and the
            local snapshot lookup.
        **kwargs: Hub-resolution options (the pipeline passes ``cache_dir``
            and ``local_files_only``) forwarded to the tokenizer and config
            loaders.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # Local import: AutoConfig is not part of this module's top-level
    # transformers import list.
    from transformers import AutoConfig

    device = get_torch_device()
    max_memory = get_max_memory()

    # Check for fp16 variant
    local_model_path = os.path.join(
        get_model_dir(),
        file_download.repo_folder_name(repo_id=model_id, repo_type="model"),
    )
    has_fp16_variant = any(
        ".fp16.safetensors" in fname
        for _, _, files in os.walk(local_model_path)
        for fname in files
    )

    # Pick the dtype the checkpoint is cast to while being dispatched.
    torch_dtype = kwargs.get("torch_dtype", torch.float32)
    if device != "cpu" and has_fp16_variant:
        logger.info("Loading fp16 variant for %s", model_id)
        torch_dtype = torch.float16
    elif device != "cpu":
        torch_dtype = torch.bfloat16

    # Only the caller's hub-resolution options go to the tokenizer and config
    # loaders; dtype/variant are weight-loading concerns.
    hub_kwargs = {
        key: value
        for key, value in kwargs.items()
        if key not in ("torch_dtype", "variant")
    }

    tokenizer = AutoTokenizer.from_pretrained(model_id, **hub_kwargs)

    # Read just the config. Instantiating the full model here would load all
    # weights once only to throw them away, defeating init_empty_weights().
    config = AutoConfig.from_pretrained(model_id, **hub_kwargs)

    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    checkpoint_dir = snapshot_download(
        model_id, cache_dir=get_model_dir(), local_files_only=True
    )

    model = load_checkpoint_and_dispatch(
        model,
        checkpoint_dir,
        device_map="auto",
        max_memory=max_memory,
        no_split_module_classes=["LlamaDecoderLayer"],  # Adjust based on your model architecture
        dtype=torch_dtype,
        offload_folder="offload",
        offload_state_dict=True,
    )

    return tokenizer, model

class LLMGeneratePipeline(Pipeline):
    """Chat-style text generation pipeline backed by a causal language model.

    The model is loaded from the local model directory at construction time
    (8-bit when ``USE_8BIT=true``, otherwise fp16/bf16). Calling the instance
    returns an async generator that streams generated text.
    """

    def __init__(self, model_id: str):
        """Load the tokenizer and model for ``model_id`` from local files.

        Environment variables read:
            USE_8BIT: ``"true"`` selects 8-bit quantized loading.
            SFAST: ``"true"`` compiles the model with stable-fast.
        """
        self.model_id = model_id
        kwargs = {
            "cache_dir": get_model_dir(),
            "local_files_only": True,
        }
        self.device = get_torch_device()

        # Generate the correct folder name
        folder_path = file_download.repo_folder_name(repo_id=model_id, repo_type="model")
        self.local_model_path = os.path.join(get_model_dir(), folder_path)
        self.checkpoint_dir = snapshot_download(
            model_id, cache_dir=get_model_dir(), local_files_only=True
        )

        logger.info(f"Local model path: {self.local_model_path}")
        logger.info(f"Directory contents: {os.listdir(self.local_model_path)}")

        use_8bit = os.getenv("USE_8BIT", "").strip().lower() == "true"

        if use_8bit:
            logger.info("Using 8-bit quantization")
            self.tokenizer, self.model = load_model_8bit(model_id, **kwargs)
        else:
            logger.info("Using fp16/bf16 precision")
            self.tokenizer, self.model = load_model_fp16(model_id, **kwargs)

        logger.info(f"Model loaded and distributed. Device map: {self.model.hf_device_map}")

        # Set up generation config
        self.generation_config = self.model.generation_config

        # NOTE(review): these terminators are computed but generation below
        # only passes tokenizer.eos_token_id - confirm whether "<|eot_id|>"
        # should also stop generation.
        self.terminators = [
            self.tokenizer.eos_token_id,
            self.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

        # Optional: Add optimizations
        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
        if sfast_enabled:
            logger.info(
                "LLMGeneratePipeline will be dynamically compiled with stable-fast for %s",
                model_id,
            )
            from app.pipelines.optim.sfast import compile_model

            self.model = compile_model(self.model)

    async def __call__(
        self,
        prompt: str,
        history: Optional[List[Dict[str, Any]]] = None,
        system_msg: Optional[str] = None,
        **kwargs,
    ) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
        """Stream a reply to ``prompt``.

        Args:
            prompt: The new user message.
            history: Earlier conversation entries, inserted between the system
                message and the prompt. They are fed to the chat template
                as-is, so presumably each is a ``{"role", "content"}`` dict -
                verify against callers.
            system_msg: Optional system message placed first.
            **kwargs: ``max_tokens`` (default 256) and ``temperature``
                (default 0.7; a value of 0 or less disables sampling).

        Yields:
            Text chunks as they are produced, followed by one final dict
            ``{"tokens_used": <prompt length + number of streamed chunks>}``.
        """
        conversation = []
        if system_msg:
            conversation.append({"role": "system", "content": system_msg})
        if history:
            conversation.extend(history)
        conversation.append({"role": "user", "content": prompt})

        # add_generation_prompt appends the assistant-turn header so the model
        # answers the user instead of continuing the user's message.
        input_ids = self.tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(self.model.device)
        attention_mask = torch.ones_like(input_ids)

        max_new_tokens = kwargs.get("max_tokens", 256)
        temperature = kwargs.get("temperature", 0.7)

        streamer = TextIteratorStreamer(
            self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
        )

        generate_kwargs = self.generation_config.to_dict()
        generate_kwargs.update({
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "temperature": temperature,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.eos_token_id,
        })

        # generate() is blocking, so it runs in a worker thread and hands text
        # back through the streamer.
        thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs)
        thread.start()

        # Pulling from the streamer blocks (up to its timeout) until the next
        # chunk is ready. Doing that wait in the default executor keeps the
        # event loop free for other requests. The sentinel default makes
        # next() return instead of raising StopIteration inside the future.
        loop = asyncio.get_running_loop()
        end_of_stream = object()

        total_tokens = 0
        try:
            while True:
                text = await loop.run_in_executor(None, next, streamer, end_of_stream)
                if text is end_of_stream:
                    break
                total_tokens += 1
                yield text
        except Exception as e:
            logger.error(f"Error during streaming: {str(e)}")
            raise

        input_length = input_ids.size(1)
        yield {"tokens_used": input_length + total_tokens}

    def model_generate_wrapper(self, **kwargs):
        """Thread target: run ``model.generate`` and log any failure."""
        try:
            logger.debug("Entering model.generate")
            with torch.cuda.amp.autocast():  # Use automatic mixed precision
                self.model.generate(**kwargs)
            logger.debug("Exiting model.generate")
        except Exception as e:
            logger.error(f"Error in model.generate: {str(e)}", exc_info=True)
            raise

    def __str__(self):
        return f"LLMGeneratePipeline(model_id={self.model_id})"
110 changes: 110 additions & 0 deletions runner/app/routes/llm_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import logging
import os
from typing import Annotated, Optional, List
from fastapi import APIRouter, Depends, Form, status, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, LlmResponse, TextResponse, http_error
import json

router = APIRouter()

logger = logging.getLogger(__name__)

# Response models per HTTP status code, passed as `responses=` to the
# /llm-generate route decorators below.
RESPONSES = {
    status.HTTP_200_OK: {"model": LlmResponse},
    status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
    status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
    status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


@router.post("/llm-generate", response_model=LlmResponse, responses=RESPONSES)
@router.post(
    "/llm-generate/",
    response_model=LlmResponse,
    responses=RESPONSES,
    include_in_schema=False,
)
async def llm_generate(
    prompt: Annotated[str, Form()],
    model_id: Annotated[str, Form()] = "",
    system_msg: Annotated[str, Form()] = "",
    temperature: Annotated[float, Form()] = 0.7,
    max_tokens: Annotated[int, Form()] = 256,
    history: Annotated[str, Form()] = "[]",  # We'll parse this as JSON
    stream: Annotated[bool, Form()] = False,
    pipeline: Pipeline = Depends(get_pipeline),
    token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
    # Form endpoint for LLM text generation.
    #
    # - When AUTH_TOKEN is set in the environment, the bearer token must
    #   match it, otherwise 401.
    # - A non-empty model_id must equal the loaded pipeline's model_id,
    #   otherwise 400.
    # - history is a JSON-encoded array; anything else yields 400.
    # - stream=True returns a text/event-stream response; otherwise chunks
    #   are collected into a single LlmResponse.
    auth_token = os.environ.get("AUTH_TOKEN")
    if auth_token:
        if not token or token.credentials != auth_token:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                headers={"WWW-Authenticate": "Bearer"},
                content=http_error("Invalid bearer token"),
            )

    if model_id != "" and model_id != pipeline.model_id:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=http_error(
                f"pipeline configured with {pipeline.model_id} but called with "
                f"{model_id}"
            ),
        )

    try:
        history_list = json.loads(history)
        if not isinstance(history_list, list):
            raise ValueError("History must be a JSON array")

        generator = pipeline(
            prompt=prompt,
            history=history_list,
            system_msg=system_msg if system_msg else None,
            temperature=temperature,
            max_tokens=max_tokens,
        )

        if stream:
            return StreamingResponse(
                stream_generator(generator), media_type="text/event-stream"
            )

        # Collect text chunks until the pipeline emits its final summary dict.
        # tokens_used starts at 0 so a generator that ends without a summary
        # still produces a valid response instead of an unbound-name failure.
        pieces = []
        tokens_used = 0
        async for chunk in generator:
            if isinstance(chunk, dict):
                tokens_used = chunk["tokens_used"]
                break
            pieces.append(chunk)

        return LlmResponse(response="".join(pieces), tokens_used=tokens_used)

    # JSONDecodeError subclasses ValueError, so it must be handled first.
    # Error bodies go through http_error() like the 400/401 responses above,
    # so every error from this route shares one shape.
    except json.JSONDecodeError:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=http_error("Invalid JSON format for history"),
        )
    except ValueError as ve:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=http_error(str(ve)),
        )
    except Exception as e:
        # logger.exception keeps the traceback, which the client never sees.
        logger.exception(f"LLM processing error: {str(e)}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content=http_error("Internal server error during LLM processing."),
        )


async def stream_generator(generator):
    """Re-emit pipeline output as server-sent-event ``data:`` lines.

    String items are wrapped as ``{"chunk": ...}``. The first dict item is
    sent as-is and ends consumption. A ``[DONE]`` marker follows on normal
    completion. If iteration raises, the error is logged and sent as a single
    ``{"error": ...}`` event instead.
    """
    try:
        async for item in generator:
            is_final = isinstance(item, dict)  # This is the final result
            payload = item if is_final else {"chunk": item}
            yield f"data: {json.dumps(payload)}\n\n"
            if is_final:
                break
        yield "data: [DONE]\n\n"
    except Exception as e:
        logger.error(f"Streaming error: {str(e)}")
        yield f"data: {json.dumps({'error': str(e)})}\n\n"
5 changes: 5 additions & 0 deletions runner/app/routes/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class TextResponse(BaseModel):
chunks: List[chunk]


# Response body of the non-streaming /llm-generate endpoint.
class LlmResponse(BaseModel):
    # Full generated text, i.e. all streamed chunks concatenated by the route.
    response: str
    # Usage figure taken from the pipeline's final {"tokens_used": ...} item.
    tokens_used: int


class APIError(BaseModel):
msg: str

Expand Down
18 changes: 18 additions & 0 deletions runner/check_torch_cuda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch
import subprocess

# Diagnostic script: prints the PyTorch build, CUDA availability, the CUDA
# toolkit version reported by nvcc, and the name of GPU 0.
print(f"PyTorch version: {torch.__version__}")

cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"CUDA version: {torch.version.cuda}")

# Check system CUDA version
try:
    nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
    cuda_version = nvcc_output.split("release ")[-1].split(",")[0]
    print(f"System CUDA version: {cuda_version}")
except (OSError, subprocess.CalledProcessError, UnicodeDecodeError):
    # OSError covers a missing nvcc binary; CalledProcessError covers a
    # non-zero exit. Anything else (e.g. KeyboardInterrupt) should propagate.
    print("Unable to check system CUDA version")

# Print the current device
# get_device_name(0) raises when no CUDA device exists, so only query it when
# CUDA is actually available.
if cuda_available:
    print(f"Current device: {torch.cuda.get_device_name(0)}")
4 changes: 4 additions & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ function download_all_models() {

# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models

# Download LLM models (Warning: large model size)
huggingface-cli download meta-llama/Meta-Llama-3.1-8B-Instruct --include "*.json" "*.bin" "*.safetensors" "*.txt" --cache-dir models

}

# Enable HF transfer acceleration.
Expand Down
2 changes: 2 additions & 0 deletions runner/gen_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
image_to_video,
text_to_image,
upscale,
llm_generate
)
from fastapi.openapi.utils import get_openapi

Expand Down Expand Up @@ -85,6 +86,7 @@ def write_openapi(fname, entrypoint="runner"):
app.include_router(image_to_video.router)
app.include_router(upscale.router)
app.include_router(audio_to_text.router)
app.include_router(llm_generate.router)

use_route_names_as_operation_ids(app)

Expand Down
Loading