-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmaplm_wrong_list_search.py
More file actions
136 lines (111 loc) · 4.77 KB
/
maplm_wrong_list_search.py
File metadata and controls
136 lines (111 loc) · 4.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
from tqdm.auto import tqdm
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
import os
import json
def find_wrong_answers_list(json_path: str,
                            dataset_type: str) -> list[dict]:
    """Collect every QA item whose ``answer`` field is a list with several entries.

    The dataset file is read as a JSON object keyed by sample id. For each
    sample, ``data["QA"]`` is walked two levels deep and every item whose
    ``"answer"`` value is a list of length greater than one is reported.

    Args:
        json_path: Directory that contains the dataset file.
        dataset_type: File name of the dataset inside ``json_path``.

    Returns:
        A list of single-key dicts ``{sample_id: {...}}``. Each inner dict
        carries ``image_path`` (``<sample_id>/photo_forward.jpg``) together
        with the ``question``, ``option``, ``answer`` and ``tag`` values copied
        from the offending item.

    Raises:
        KeyError: If a sample has no ``"QA"`` key or a reported item lacks
            ``"question"``, ``"option"`` or ``"tag"``.
    """
    dataset_file = os.path.join(json_path, dataset_type)
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(dataset_file, "r", encoding="utf-8") as j_file:
        json_data = json.load(j_file)

    wrong_answers_lst = []
    for sample_id, data in json_data.items():
        for options in data["QA"].values():
            for option in options.values():
                if "answer" not in option:
                    continue
                answer = option["answer"]
                # Only list answers holding more than one entry are reported.
                if not isinstance(answer, list) or len(answer) <= 1:
                    continue
                wrong_answers_lst.append({
                    sample_id: {
                        "image_path": os.path.join(sample_id, "photo_forward.jpg"),
                        "question": option["question"],
                        "option": option["option"],
                        "answer": answer,
                        "tag": option["tag"],
                    }
                })

    print(f"Number of wrong answers: {len(wrong_answers_lst)}")
    print(f"Dataset length: {len(json_data)}")
    return wrong_answers_lst
def load_model(model_name: str, device_map: str = "cuda"):
    """Load the Qwen2.5-VL checkpoint and its processor.

    Args:
        model_name: Identifier passed to both ``from_pretrained`` calls.
        device_map: Placement forwarded to the model's ``from_pretrained``.
            Defaults to ``"cuda"``, the value this function always used.

    Returns:
        A ``(model, processor)`` tuple. The model is loaded with
        ``torch.float16`` weights and switched to evaluation mode.
    """
    processor = AutoProcessor.from_pretrained(model_name, use_fast=True)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map=device_map,
    ).eval()
    return model, processor
def _parse_correct_answer(raw_text: str) -> str:
    """Return the ``"correct_answer"`` value from the model's JSON reply."""
    try:
        parsed = json.loads(raw_text)
    except json.JSONDecodeError:
        # The reply may carry text around the JSON object (for example a
        # Markdown code fence); retry on the outermost ``{...}`` span.
        start = raw_text.find("{")
        end = raw_text.rfind("}")
        if start == -1 or end <= start:
            raise
        parsed = json.loads(raw_text[start:end + 1])
    return parsed["correct_answer"]


def generate_answer(model, processor, message) -> str:
    """Run one chat ``message`` through the model and return its chosen answer.

    Args:
        model: Model object exposing ``generate`` and ``device``.
        processor: Processor exposing ``apply_chat_template``, ``__call__``
            and ``batch_decode``.
        message: Chat-format message list given to the chat template and to
            ``process_vision_info``.

    Returns:
        The value stored under ``"correct_answer"`` in the decoded JSON reply.

    Raises:
        json.JSONDecodeError: If no JSON object can be decoded from the reply.
        KeyError: If the decoded object has no ``"correct_answer"`` key.
    """
    text = processor.apply_chat_template(
        message, tokenize=False, add_generation_prompt=True
    )
    # Only the first return value (image inputs) is used; the second is ignored.
    image_inputs, _ = process_vision_info(message)
    inputs = processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    ).to(model.device)  # follow the model's placement instead of assuming "cuda"

    generated_ids = model.generate(**inputs, max_new_tokens=128)
    # Drop the prompt tokens so only the newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return _parse_correct_answer(output_text[0])
def generate_correct_answers(json_path: str,
                             dataset_type: str,
                             model_name: str,
                             system_prompt: str,
                             img_path: str,
                             correct_output_name: str) -> None:
    """Ask the model for a single answer to every flagged QA item and save the result.

    The flagged items come from ``find_wrong_answers_list``. Each one is sent
    to the model with its image and question, the reply is stored under
    ``"correct_answer"`` in the item, and the whole list is written as JSON.

    Args:
        json_path: Directory holding the dataset file; the output file is
            written to this directory as well.
        dataset_type: File name of the dataset inside ``json_path``.
        model_name: Identifier passed to ``load_model``.
        system_prompt: Text used as the system turn of every request.
        img_path: Directory joined with each item's ``image_path``.
        correct_output_name: File name of the JSON output inside ``json_path``.
    """
    wrong_entries = find_wrong_answers_list(json_path, dataset_type)
    model, processor = load_model(model_name)

    for entry in tqdm(wrong_entries):
        # Each entry is a single-key dict {sample_id: item}; only the item is used.
        for data in entry.values():
            # Pull the values out first: reusing double quotes inside a
            # double-quoted f-string is a SyntaxError before Python 3.12.
            question = data["question"]
            options = data["option"]
            message = [
                {
                    "role": "system",
                    "content": [
                        {"type": "text", "text": system_prompt},
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": os.path.join(img_path, data["image_path"])},
                        {"type": "text", "text": f"Question: {question}\nPossible Answers: {options}"},
                    ],
                },
            ]
            # Stored on the item in place, so it ends up in the dumped list.
            data["correct_answer"] = generate_answer(model, processor, message)

    out_path = os.path.join(json_path, correct_output_name)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(wrong_entries, f, indent=4)
if __name__ == "__main__":
    # System turn sent verbatim with every request; it asks for a JSON reply
    # with a single "correct_answer" key.
    prompt_text = """
You are a visual reasoning assistant. You will be provided with an image, a question about the image, and a list of possible answers. Analyze the image carefully and select the most appropriate answer based solely on what you see. Focus on visual cues such as movement, direction, objects, environment, and any relevant details in the scene.
You must return your answer in the following JSON format:
{
"correct_answer": "your_selected_answer_from_the_list"
}
"""

    # Keyword arguments for the run: dataset location, output file name,
    # model checkpoint, prompt and image directory.
    run_settings = {
        "json_path": "./datasets/maplm_v2",
        "dataset_type": "train_v2.json",
        "correct_output_name": "correct_answers.json",
        "model_name": "Qwen/Qwen2.5-VL-32B-Instruct-AWQ",
        "system_prompt": prompt_text,
        "img_path": "/home/cis-g1/Documents/datasets/maplm_v2/data/images",
    }
    generate_correct_answers(**run_settings)