Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dist
*.swp

# Other
.gradio/certificate.pem
.DS_Store
wandb
output
Expand Down
296 changes: 64 additions & 232 deletions README.md

Large diffs are not rendered by default.

380 changes: 380 additions & 0 deletions README_original.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion fastchat/model/model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def load_compress_model(self, model_path, device, torch_dtype, revision="main"):
)

def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("one_shot")
return get_conv_template("chatgpt")


# A global registry for all model adapters
Expand Down
13 changes: 9 additions & 4 deletions fastchat/serve/api_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_api_provider_stream_iter(
max_new_tokens,
api_base=model_api_dict["api_base"],
api_key=model_api_dict["api_key"],
azure_api_version=model_api_dict.get("azure_api_version"),
)
elif model_api_dict["api_type"] == "openai_no_stream":
prompt = conv.to_openai_api_messages()
Expand All @@ -50,6 +51,7 @@ def get_api_provider_stream_iter(
api_base=model_api_dict["api_base"],
api_key=model_api_dict["api_key"],
stream=False,
azure_api_version=model_api_dict.get("azure_api_version"),
)
elif model_api_dict["api_type"] == "openai_o1":
prompt = conv.to_openai_api_messages()
Expand All @@ -61,6 +63,7 @@ def get_api_provider_stream_iter(
max_new_tokens,
api_base=model_api_dict["api_base"],
api_key=model_api_dict["api_key"],
azure_api_version=model_api_dict.get("azure_api_version"),
is_o1=True,
)
elif model_api_dict["api_type"] == "openai_assistant":
Expand Down Expand Up @@ -275,18 +278,20 @@ def openai_api_stream_iter(
api_key=None,
stream=True,
is_o1=False,
azure_api_version=None,
):
import openai

api_key = api_key or os.environ["OPENAI_API_KEY"]

if "azure" in model_name:
if azure_api_version:
logger.info(f"Using Azure API version {azure_api_version}")
client = openai.AzureOpenAI(
api_version="2023-07-01-preview",
azure_endpoint=api_base or "https://api.openai.com/v1",
api_version=azure_api_version,
azure_endpoint=api_base,
api_key=api_key,
)
else:
logger.info(f"Using OpenAI API")
client = openai.OpenAI(
base_url=api_base or "https://api.openai.com/v1",
api_key=api_key,
Expand Down
5 changes: 4 additions & 1 deletion fastchat/serve/gradio_block_arena_anony.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,10 @@ def share_click(state0, state1, model_selector0, model_selector1, request: gr.Re
)


SAMPLING_WEIGHTS = {}
SAMPLING_WEIGHTS = {
"Checklist-GPT-4-0125-Preview": 1,
"Checklist-GPT-o1": 1,
}

# target model sampling weights will be boosted.
BATTLE_TARGETS = {}
Expand Down
23 changes: 22 additions & 1 deletion fastchat/serve/gradio_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@
class State:
def __init__(self, model_name, is_vision=False):
self.conv = get_conversation_template(model_name)
logger.info(f"model_name: {model_name}")
logger.info(f"conv: {self.conv}")
self.conv_id = uuid.uuid4().hex
self.skip_next = False
self.model_name = model_name
Expand Down Expand Up @@ -271,6 +273,19 @@ def load_demo_single(context: Context, query_params):
return [state, dropdown_update]


# def load_demo(url_params, request: gr.Request):
# global models

# ip = get_ip(request)
# logger.info(f"load_demo. ip: {ip}. params: {url_params}")

# if args.model_list_mode == "reload":
# models, all_models = get_model_list(
# controller_url, args.register_api_endpoint_file, vision_arena=False
# )

# return load_demo_single(models, url_params)

def load_demo(url_params, request: gr.Request):
global models

Expand All @@ -281,8 +296,13 @@ def load_demo(url_params, request: gr.Request):
models, all_models = get_model_list(
controller_url, args.register_api_endpoint_file, vision_arena=False
)

# Create a Context object with the models
context = Context()
context.text_models = models
context.models = all_models if 'all_models' in locals() else models

return load_demo_single(models, url_params)
return load_demo_single(context, url_params)


def vote_last_response(state, vote_type, model_selector, request: gr.Request):
Expand Down Expand Up @@ -538,6 +558,7 @@ def bot_response(
)
extra_body = recommended_config.get("extra_body", None)


stream_iter = get_api_provider_stream_iter(
conv,
model_name,
Expand Down
Loading