Commit
update vllm_inference to support vision LM
khai-meetkai committed Aug 22, 2024
2 parents 02f1de8 + 9478eaf commit af597e3
Showing 5 changed files with 25 additions and 60 deletions.
46 changes: 0 additions & 46 deletions functionary/prompt_template/internlm2_prompt_template.py
@@ -45,23 +45,6 @@ def get_assistant_prefixes(self) -> List[str]:
def get_stop_tokens_for_generation(self) -> List[str]:
return [self.eos_token]

def get_force_function_call_prefix(self, function_name: str):
return f"{function_name}\n"

def get_start_of_function_call_token(self) -> str:
return ""

def pre_process_messages_before_inference(self, messages: List[Dict]) -> List[Dict]:
"""Order the tool results by the order of tool call ids
Args:
messages (List[Dict]): List of messages
Returns:
List[Dict]: List of messages
"""
return prompt_utils.reorder_tool_messages_by_tool_call_ids(messages)

def convert_message_to_prompt(self, message: Dict) -> str:
role = message["role"]
content = message.get("content", None)
@@ -106,32 +89,3 @@ def convert_message_to_prompt(self, message: Dict) -> str:
tool_call_prompts
)
return prompt_template.format(text=total_content)

def get_force_text_generation_prefix(self):
return f"all\n"

def get_chat_template_jinja(self) -> str:
chat_template = """{% for message in messages %}
{% if message['role'] == 'user' or message['role'] == 'system' %}
{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}<br>
{% elif message['role'] == 'tool' %}
{{ '<|im_start|>' + message['role'] + '\nname=' + message['name'] + '\n' + message['content'] + '<|im_end|>' }}<br>
{% else %}
{{ '<|im_start|>' + message['role'] + '\n'}}<br>
{% if message['content'] is not none %}
{{ '>>>all\n' + message['content'] }}<br>
{% endif %}
{% if 'tool_calls' in message and message['tool_calls'] is not none %}
{% for tool_call in message['tool_calls'] %}
{{ '>>>' + tool_call['function']['name'] + '\n' + tool_call['function']['arguments'] }}<br>
{% endfor %}
{% endif %}
{{ '<|im_end|>' }}<br>
{% endif %}
{% endfor %}
{% if add_generation_prompt %}{{ '<|im_start|>{role}\n' }}{% endif %}
"""
chat_template = chat_template.replace(" ", "")
chat_template = chat_template.replace("<br>\n", "")
chat_template = chat_template.strip()
return chat_template
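
For reference, the Jinja chat template removed above encoded the `>>>`-style tool-call layout for InternLM2-based models. Below is a minimal sketch of the prompt shape it rendered for an assistant turn containing one tool call; the message values are hypothetical and only illustrate the format visible in the deleted template.

```python
# Illustrative only: prompt shape produced by the deleted get_chat_template_jinja
# for an assistant message with a single tool call. Values below are made up.
assistant_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {"function": {"name": "get_weather", "arguments": '{"city": "Hanoi"}'}}
    ],
}

# The removed template wrapped the turn in <|im_start|>/<|im_end|> and prefixed
# each tool call with ">>>" followed by the function name and its JSON arguments.
expected_prompt = (
    "<|im_start|>assistant\n"
    ">>>get_weather\n"
    '{"city": "Hanoi"}'
    "<|im_end|>"
)
```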
6 changes: 3 additions & 3 deletions functionary/vllm_inference.py
@@ -50,7 +50,7 @@ def create_error_response(


async def check_all_errors(request, served_model) -> Optional[JSONResponse]:
if request.model != served_model:
if request.model not in served_model:
return create_error_response(
status_code=HTTPStatus.NOT_FOUND,
message=f"The model `{request.model}` does not exist.",
@@ -146,7 +146,7 @@ async def process_chat_completion(
request: ChatCompletionRequest,
raw_request: Optional[Request],
tokenizer: Any,
served_model: str,
served_model: List[str],
engine_model_config: Any,
enable_grammar_sampling: bool,
engine: Any,
@@ -333,7 +333,7 @@ async def completion_stream_generator(

chunk = StreamChoice(**response)
result = ChatCompletionChunk(
id=request_id, choices=[chunk], model=served_model
id=request_id, choices=[chunk], model=model_name
)
chunk_dic = result.dict(exclude_unset=True)
chunk_data = json.dumps(chunk_dic, ensure_ascii=False)
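The functional change in this file is that `served_model` is now a list of model names rather than a single string, so request validation uses membership instead of equality. A minimal sketch of that check follows; the standalone function and its return type are simplified for illustration and are not the actual handler.

```python
# Sketch of the membership check after this change; the names mirror the diff
# above, but the real check_all_errors builds a JSON error response instead.
from typing import List, Optional

def validate_requested_model(requested: str, served_model: List[str]) -> Optional[str]:
    """Return an error message when the requested model is not served, else None."""
    if requested not in served_model:
        # In the server this becomes an HTTP 404 error response.
        return f"The model `{requested}` does not exist."
    return None

# With a single served model, equality and membership behave the same, but a
# list lets one server advertise several model names.
assert validate_requested_model("functionary-small-v3.1", ["functionary-small-v3.1"]) is None
assert validate_requested_model("unknown-model", ["functionary-small-v3.1"]) is not None
```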
29 changes: 20 additions & 9 deletions server_vllm.py
@@ -18,9 +18,9 @@
import argparse
import asyncio
import json
import logging
import re
from typing import Any, AsyncGenerator, Dict, List, Literal, Optional, Tuple, Union
import logging

import fastapi
import uvicorn
@@ -42,16 +42,25 @@
logger.addHandler(logging.StreamHandler())


served_model = None
served_model = []
app = fastapi.FastAPI()


@app.get("/v1/models")
async def show_available_models():
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(id=served_model, root=served_model, permission=[ModelPermission()])
]
"""Show available models."""
model_cards = []
if isinstance(served_model, list):
for model in served_model:
model_cards.append(
ModelCard(id=model, root=model, permission=[ModelPermission()])
)
else:
model_cards.append(
ModelCard(
id=served_model, root=served_model, permission=[ModelPermission()]
)
)
return ModelList(data=model_cards)


@@ -130,9 +139,11 @@ async def create_chat_completion(raw_request: Request):
logger.info(f"args: {args}")

if args.served_model_name is not None:
served_model = args.served_model_name
else:
served_model = args.model
logger.info(
"args.served_model_name is not used in this service and will be ignored. Served model will consist of args.model only."
)

served_model = [args.model]

engine_args = AsyncEngineArgs.from_cli_args(args)
# A separate tokenizer to map token IDs to strings.
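With `served_model` now initialized as `[args.model]`, the `/v1/models` endpoint returns one `ModelCard` per entry in that list. A minimal client-side sketch is shown below, assuming a local deployment on port 8000 and the OpenAI-compatible response shape (a list of cards under `data`); adjust the host and port to your setup.

```python
# Hypothetical client-side check against the /v1/models route defined in
# server_vllm.py above. Host and port are assumptions, not repository defaults.
import requests

resp = requests.get("http://localhost:8000/v1/models")
resp.raise_for_status()
for card in resp.json()["data"]:
    # Each card's id/root is one of the names in `served_model`
    # (with this change, just the value passed via --model).
    print(card["id"])
```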
2 changes: 1 addition & 1 deletion tests/test_prompt_creation.py
@@ -48,7 +48,7 @@ def __init__(self, *args, **kwargs):
"meetkai/functionary-small-v2.4",
"meetkai/functionary-small-v2.5",
"meetkai/functionary-medium-v3.0",
"meta-llama/Meta-Llama-3.1-8B-Instruct",
"meetkai/functionary-small-v3.1",
"OpenGVLab/InternVL2-8B",
]

2 changes: 1 addition & 1 deletion tests/test_request_handling.py
@@ -209,7 +209,7 @@ def __init__(self, *args, **kwargs):
PromptTemplateV2: "meetkai/functionary-small-v2.4",
Llama3Template: "meetkai/functionary-small-v2.5",
Llama3TemplateV3: "meetkai/functionary-medium-v3.0",
Llama31Template: "meta-llama/Meta-Llama-3.1-8B-Instruct",
Llama31Template: "meetkai/functionary-small-v3.1",
LlavaLlama: "lmms-lab/llama3-llava-next-8b",
InternLMChat: "OpenGVLab/InternVL2-8B",
}
