18 changes: 0 additions & 18 deletions README.md
@@ -27,7 +27,6 @@ Documentation and more examples: [functionary.meetkai.com](https://functionary.m
+ [2024/05/17] We release [meetkai/functionary-small-v2.5](https://huggingface.co/meetkai/functionary-small-v2.5) with better capability for function calling and code interpreter compared with [functionary-small-v2.4](https://huggingface.co/meetkai/functionary-small-v2.4)
+ [2024/05/06] Streaming support for functionary v2 to v2.4 models is released in [llama-cpp-python](https://github.com/abetlen/llama-cpp-python)!
+ [2024/05/03] Added support for serverless vLLM deployment on [Modal.com](https://modal.com/)
+ [2024/04/27] New and improved grammar sampling! Ensures 100% accuracy in generating function names, prompt template and parameters.
+ [2024/04/02] We release [meetkai/functionary-small-v2.4](https://huggingface.co/meetkai/functionary-small-v2.4) and [meetkai/functionary-medium-v2.4](https://huggingface.co/meetkai/functionary-medium-v2.4)! The first functionary models with code-interpreter ability (by passing in `{type: "code_interpreter"}` in tools)!

</details>
@@ -114,17 +113,6 @@ curl -X POST http://localhost:8000/v1/unload_lora_adapter \
```


### Grammar Sampling (Only in vLLM)

We also offer our own function-calling grammar sampling feature, which constrains the LLM's generation to always follow the prompt template and ensures 100% accuracy for function names. The parameters are generated using the efficient [lm-format-enforcer](https://github.com/noamgat/lm-format-enforcer), which ensures that the parameters follow the schema of the tool being called. To enable grammar sampling, run the vLLM server with the command-line argument <code>--enable-grammar-sampling</code>:

```shell
python3 server_vllm.py --model "meetkai/functionary-medium-v3.1" --max-model-len 8192 --tensor-parallel-size 2 --enable-grammar-sampling
```

**Note:** Grammar sampling is supported only for the V2, V3.0, and V3.2 models; it is not available for the V1 and V3.1 models.
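
For illustration, a request against the server started above might look like the sketch below. This is an assumption rather than code from this repository's docs: the `get_weather` tool, the placeholder API key, and the port are hypothetical, and the model name must match whatever the server was launched with. With grammar sampling enabled, the generated tool call is constrained to the declared function name, and its arguments are constrained to the declared parameter schema.

```python
# Minimal sketch (assumption): query the grammar-sampling-enabled vLLM server
# through its OpenAI-compatible endpoint. Tool, API key, and port are illustrative.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="functionary")

response = client.chat.completions.create(
    model="meetkai/functionary-medium-v3.1",  # must match the model served above
    messages=[{"role": "user", "content": "What is the weather in Istanbul?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # grammar sampling constrains the call to this exact name
                "description": "Get the current weather for a given city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    tool_choice="auto",
)
print(response.choices[0].message.tool_calls)
```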


### Text-Generation-Inference (TGI)

We also provide a service that performs inference on Functionary models using [Text-Generation-Inference](https://huggingface.co/docs/text-generation-inference/en/index) (TGI). Follow these steps to get started:
@@ -711,11 +699,6 @@ Evaluation function call prediction in SGD dataset. The accuracy metric measures

See training [README](functionary/train/README.md)

## Safety & Security

While it is not strictly enforced, one can enable grammar sampling to enforce type checking and make function execution more *secure*.
The main safety checks need to be done in the functions/actions themselves, such as validating the given input or the output that will be passed back to the model.
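
As a concrete illustration of such a check, the sketch below (an assumption, not code from this repository) validates the JSON arguments of a model-produced tool call against the tool's parameter schema before executing it, using the `jsonschema` package; the `get_weather` tool and its schema are hypothetical.

```python
# Minimal sketch (assumption): validate model-produced tool-call arguments
# against the tool's JSON schema before running the function.
import json

from jsonschema import ValidationError, validate

GET_WEATHER_SCHEMA = {
    "type": "object",
    "properties": {"city": {"type": "string"}},
    "required": ["city"],
    "additionalProperties": False,
}


def run_get_weather(raw_arguments: str) -> str:
    """Execute the hypothetical get_weather tool only if its arguments validate."""
    try:
        args = json.loads(raw_arguments)
        validate(instance=args, schema=GET_WEATHER_SCHEMA)
    except (json.JSONDecodeError, ValidationError) as exc:
        # Return the error to the model instead of executing with bad input.
        return f"Invalid arguments for get_weather: {exc}"
    return f"Sunny, 25°C in {args['city']}"  # placeholder result
```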

## Roadmap

- [ ] OpenAPI specification based plugin support.
@@ -724,7 +707,6 @@ Main safety checks needs to be done in the functions/actions themselves. Such as
- [X] [text-generation-inference](https://github.com/huggingface/text-generation-inference)
- [X] Streaming Support
- [X] function_call parameter to server
- [X] Grammar Sampling to ensure 100% accuracy for function and parameter names
- [X] Parallel function calling support
- [X] Python function calling support (Automatic detection of type annotations and calling them automatically)
- [X] Real world usage examples, such as creating agents.
33 changes: 1 addition & 32 deletions functionary/inference.py
@@ -1,24 +1,17 @@
from typing import Dict, List, Optional, Union

import torch
from lmformatenforcer import CharacterLevelParser, JsonSchemaParser
from lmformatenforcer.integrations.vllm import build_vllm_logits_processor
from transformers import (
LlamaForCausalLM,
LlamaTokenizer,
StoppingCriteria,
StoppingCriteriaList,
)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
_cached_build_vllm_token_enforcer_tokenizer_data,
_normalize_json_schema_object,
)
from vllm.sampling_params import LogitsProcessor

from functionary.inference_utils import StopWordsCriteria
from functionary.openai_types import ChatMessage, Function, FunctionCall, Tool
from functionary.prompt_template import get_prompt_template_from_tokenizer
from functionary.prompt_template.prompt_utils import prepare_messages_for_inference
from functionary.inference_utils import StopWordsCriteria


def tokenize(message: ChatMessage, tokenizer: LlamaTokenizer, device="cuda:0"):
@@ -100,30 +93,6 @@ def generate_message(
return ChatMessage(**result)


async def get_lm_format_enforcer_vllm_logits_processor_from_tool_name(
tool_name, tools_or_functions, tokenizer
) -> LogitsProcessor:
"""
Given a tool_name and list of tool definitions, find the json schema
of the tool with tool_name name and get the necessary vLLM logits processor
for the given tool schema."""

tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(tokenizer)
character_level_parser: CharacterLevelParser

# Get the tool schema
for tool_or_function in tools_or_functions:
if tool_or_function["name"] == tool_name:
raw_tool_schema = tool_or_function["parameters"]
break
schema = _normalize_json_schema_object(raw_tool_schema)
character_level_parser = JsonSchemaParser(schema)
logits_processor = build_vllm_logits_processor(
tokenizer_data, character_level_parser
)
return logits_processor


if __name__ == "__main__":
# First lets create an example messages list with all different types of roles and content.
functions = [
34 changes: 7 additions & 27 deletions functionary/vllm_inference.py
@@ -180,7 +180,6 @@ async def process_chat_completion(
served_model: List[str],
served_loras: List[LoRARequest],
engine_model_config: Any,
enable_grammar_sampling: bool,
engine: Any,
):
error_check_ret = await check_all_errors(request, served_model, served_loras)
@@ -216,14 +215,6 @@ tok_ids = tokenizer.encode(stop_tok, add_special_tokens=False)
tok_ids = tokenizer.encode(stop_tok, add_special_tokens=False)
stop_token_ids.append(tok_ids[-1])

# In vLLM==0.4.1, SamplingParams.logprobs has a proportional effect on latency
# We need to limit the size of SamplingParams.logprobs as a temporary fix first
# while investigating this problem in vLLM
if enable_grammar_sampling is False:
logprobs = None
else:
logprobs = 200

try:
sampling_params = SamplingParams(
n=request.n,
@@ -238,28 +229,17 @@
top_k=request.top_k,
ignore_eos=request.ignore_eos,
skip_special_tokens=False,
logprobs=logprobs,
logprobs=None,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

if enable_grammar_sampling:
result_generator = engine.generate(
prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
lora_request=lora_request,
sampling_params=sampling_params,
request_id=request_id,
tools_or_functions=tools_or_functions,
prompt_template_cls=prompt_template,
tool_choice=tool_func_choice,
)
else:
result_generator = engine.generate(
prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
lora_request=lora_request,
sampling_params=sampling_params,
request_id=request_id,
)
result_generator = engine.generate(
prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
lora_request=lora_request,
sampling_params=sampling_params,
request_id=request_id,
)

async def abort_request() -> None:
await engine.abort(request_id)