Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions fastchat/serve/monitor/classify/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ To test your new classifier for a new category, you would have to make sure you
python label.py --config config.yaml --testing
```

If you are labeling a vision category, add the `--vision` flag to the command. This will add a new column to the input data called `image_path` that contains the path to the image corresponding to each conversation. Ensure that you update your config with the correct `image_dir` where the images are stored.

Then, add your new category bench to `tag_names` in `display_score.py`. After making sure that you also have a correctly formatted ground truth json file, you can report the performance of your classifier by running
```console
python display_score.py --bench <your_bench>
Expand Down
411 changes: 407 additions & 4 deletions fastchat/serve/monitor/classify/category.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions fastchat/serve/monitor/classify/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ task_name:

model_name: null
name: llama-3-70b-instruct
api_type: openai
endpoints:
- api_base: null
api_key: null
parallel: 50
temperature: 0.0
max_token: 512

image_dir: null # directory where vision arena images are stored

max_retry: 2
retry_sleep: 10
error_output: $ERROR$
152 changes: 141 additions & 11 deletions fastchat/serve/monitor/classify/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,95 @@ def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=No
return output


def chat_completion_anthropic(model, messages, temperature, max_tokens, api_dict=None):
    """Query an Anthropic model via the Messages API, retrying on API errors.

    Args:
        model: Anthropic model name.
        messages: OpenAI-style message dicts; a leading "system" message is
            split out and passed through the dedicated ``system`` parameter,
            since the Messages API does not accept it as a regular message.
        temperature: sampling temperature.
        max_tokens: maximum number of tokens to generate.
        api_dict: optional ``{"api_key": ...}``; falls back to the
            ANTHROPIC_API_KEY environment variable.

    Returns:
        The generated text, or API_ERROR_OUTPUT if all retries fail.
    """
    import anthropic

    if api_dict:
        api_key = api_dict["api_key"]
    else:
        api_key = os.environ["ANTHROPIC_API_KEY"]

    sys_msg = ""
    if messages[0]["role"] == "system":
        sys_msg = messages[0]["content"]
        messages = messages[1:]

    # Create the client once rather than on every retry attempt.
    client = anthropic.Anthropic(api_key=api_key)

    output = API_ERROR_OUTPUT
    for _ in range(API_MAX_RETRY):
        try:
            response = client.messages.create(
                model=model,
                messages=messages,
                stop_sequences=[anthropic.HUMAN_PROMPT],
                max_tokens=max_tokens,
                temperature=temperature,
                system=sys_msg,
            )
            output = response.content[0].text
            break
        except anthropic.APIError as e:
            # Transient API failure: log and back off before retrying.
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
    return output


def chat_completion_gemini(
    model, messages, temperature, max_tokens, api_dict=None, image_path=None
):
    """Query a Gemini model (optionally with one image), retrying on errors.

    Args:
        model: Gemini model name.
        messages: OpenAI-style message dicts; a leading "system" message is
            split out and passed as ``system_instruction``. Only the first
            non-system message is used as the prompt.
        temperature: sampling temperature.
        max_tokens: maximum number of output tokens.
        api_dict: optional ``{"api_key": ...}``; falls back to the
            GENAI_API_KEY environment variable.
        image_path: path to an image to attach when the prompt content is a
            vision-style list (single image only).

    Returns:
        The generated text, a fixed notice string if generation stopped for a
        non-STOP reason, or API_ERROR_OUTPUT if all retries fail.
    """
    import google
    import google.generativeai as genai
    from google.generativeai.types import HarmCategory, HarmBlockThreshold
    from PIL import Image

    if api_dict:
        genai.configure(api_key=api_dict["api_key"])
    else:
        genai.configure(api_key=os.environ["GENAI_API_KEY"])

    sys_msg = ""
    if messages[0]["role"] == "system":
        sys_msg = messages[0]["content"]
        messages = messages[1:]

    prompt = messages[0]["content"]
    # Vision prompts arrive as a list of parts; take the text part and attach
    # the image (only a single image is supported here).
    if type(prompt) == list:
        prompt = [prompt[0]["text"], Image.open(image_path).convert("RGB")]

    # Disable safety blocking so benign arena conversations are not filtered.
    # (The original dict listed HARM_CATEGORY_DANGEROUS_CONTENT twice; a dict
    # keeps only one entry per key, so this is the same effective setting.)
    safety_settings = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    }
    output = API_ERROR_OUTPUT
    for _ in range(API_MAX_RETRY):
        try:
            gemini = genai.GenerativeModel(
                model,
                system_instruction=sys_msg,
                # Sampling settings must go through generation_config;
                # assigning bare attributes on the model object (as the
                # previous code did) is silently ignored by the SDK.
                generation_config=genai.GenerationConfig(
                    max_output_tokens=max_tokens,
                    temperature=temperature,
                ),
            )
            response = gemini.generate_content(prompt, safety_settings=safety_settings)
            # finish_reason 1 == STOP (normal completion); anything else
            # (MAX_TOKENS, SAFETY, ...) means the output is unusable.
            if response.candidates[0].finish_reason != 1:
                print(
                    f"Gemini did not finish generating content: {response.candidates[0].finish_reason}"
                )
                output = "Gemini did not finish generating content"
            else:
                output = response.text
            break
        except google.api_core.exceptions.ResourceExhausted as e:
            # Rate limited: back off and retry.
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
        except Exception as e:
            # NOTE(review): broad catch-all retry kept from the original
            # ("TEMPORARY FIX"); it hides non-transient errors and should be
            # narrowed to specific google.api_core exceptions.
            print(type(e), e)
            time.sleep(API_RETRY_SLEEP)
    return output


def get_answer(
question: dict,
model_name: str,
Expand All @@ -98,6 +187,7 @@ def get_answer(
api_dict: dict,
categories: list,
testing: bool,
api_type: str,
):
if "category_tag" in question:
category_tag = question["category_tag"]
Expand All @@ -107,14 +197,34 @@ def get_answer(
output_log = {}

for category in categories:
conv = category.pre_process(question["prompt"])
output = chat_completion_openai(
model=model_name,
messages=conv,
temperature=temperature,
max_tokens=max_tokens,
api_dict=api_dict,
)
conv = category.pre_process(question)
if api_type == "openai":
output = chat_completion_openai(
model=model_name,
messages=conv,
temperature=temperature,
max_tokens=max_tokens,
api_dict=api_dict,
)
elif api_type == "anthropic":
output = chat_completion_anthropic(
model=model_name,
messages=conv,
temperature=temperature,
max_tokens=max_tokens,
api_dict=api_dict,
)
elif api_type == "gemini":
output = chat_completion_gemini(
model=model_name,
messages=conv,
temperature=temperature,
max_tokens=max_tokens,
api_dict=api_dict,
image_path=question.get("image_path"),
)
else:
raise ValueError(f"api_type {api_type} not supported")
# Dump answers
category_tag[category.name_tag] = category.post_process(output)

Expand Down Expand Up @@ -169,6 +279,7 @@ def find_required_tasks(row):
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--testing", action="store_true")
parser.add_argument("--vision", action="store_true")
args = parser.parse_args()

enter = input(
Expand Down Expand Up @@ -199,6 +310,15 @@ def find_required_tasks(row):
assert len(input_data) == len(input_data.uid.unique())
print(f"{len(input_data)}# of input data just loaded")

if args.vision:
old_len = len(input_data)
input_data["image_hash"] = input_data.conversation_a.map(
lambda convo: convo[0]["content"][1][0]
)
input_data["image_path"] = input_data.image_hash.map(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where do we get the image_hash here?

Copy link
Collaborator Author

@lisadunlap lisadunlap Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from the conversation, is the format still [{content: [text, [images]]}?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the way i have it now doesn't support multi-image

lambda x: f"{config['image_dir']}/{x}.png"
)

if config["cache_file"]:
print("loading cache data")
with open(config["cache_file"], "rb") as f:
Expand Down Expand Up @@ -246,9 +366,18 @@ def find_required_tasks(row):
f"{name}: {len(not_labeled[not_labeled.required_tasks.map(lambda tasks: name in tasks)])}"
)

not_labeled["prompt"] = not_labeled.conversation_a.map(
lambda convo: "\n".join([convo[i]["content"] for i in range(0, len(convo), 2)])
)
if args.vision:
not_labeled["prompt"] = not_labeled.conversation_a.map(
lambda convo: "\n".join(
[convo[i]["content"][0] for i in range(0, len(convo), 2)]
)
)
else:
not_labeled["prompt"] = not_labeled.conversation_a.map(
lambda convo: "\n".join(
[convo[i]["content"] for i in range(0, len(convo), 2)]
)
)
not_labeled["prompt"] = not_labeled.prompt.map(lambda x: x[:12500])

with concurrent.futures.ThreadPoolExecutor(
Expand All @@ -270,6 +399,7 @@ def find_required_tasks(row):
if category.name_tag in row["required_tasks"]
],
args.testing,
config["api_type"],
)
futures.append(future)
for future in tqdm.tqdm(
Expand Down
34 changes: 34 additions & 0 deletions fastchat/serve/monitor/classify/vision_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Yaml config file for category classification

input_file: null # json
cache_file: null # json
output_file: null # json line

convert_to_json: True

task_name:
- captioning_v0.1
- homework_v0.1
- ocr_v0.1
- humor_v0.1
- entity_recognition_v0.1
- creative_writing_vision_v0.1
- diagram_v0.1


model_name: null
name: gemini-1.5-flash
api_type: gemini
endpoints:
- api_base: null
api_key: null

parallel: 50
temperature: 0.0
max_token: 512

image_dir: null # directory where vision arena images are stored

max_retry: 2
retry_sleep: 10
error_output: $ERROR$
Loading