-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathretrieve.py
More file actions
398 lines (307 loc) · 19.8 KB
/
retrieve.py
File metadata and controls
398 lines (307 loc) · 19.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
import os
import jieba # 用於中文文本分詞
import pdfplumber # 用於從PDF文件中提取文字的工具
from tqdm import tqdm
from rank_bm25 import BM25Okapi # 使用BM25演算法進行文件檢索
import io
import re
import fitz
import pytesseract
import numpy as np
from PIL import Image
from langchain_text_splitters import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer, util, CrossEncoder
def load_data(source_path: str) -> dict:
    """Load every PDF in a folder and return its text keyed by numeric file id.

    Args:
        source_path: Directory expected to contain files named ``<int>.pdf``.

    Returns:
        A dict mapping the integer file stem (``12`` for ``12.pdf``) to the
        text that :func:`read_pdf` extracts from that file.

    Entries that are not ``.pdf`` files (e.g. ``.DS_Store`` or sub-folders)
    are skipped instead of crashing the integer conversion of their names.
    """
    pdf_files = [
        file for file in os.listdir(source_path)
        if file.lower().endswith('.pdf') and os.path.isfile(os.path.join(source_path, file))
    ]
    corpus_dict = {}
    for file in tqdm(pdf_files):
        # splitext drops only the trailing extension; str.replace('.pdf', '')
        # would also remove a '.pdf' occurring elsewhere in the name.
        file_id = int(os.path.splitext(file)[0])
        corpus_dict[file_id] = read_pdf(os.path.join(source_path, file))
    return corpus_dict
def read_pdf(pdf_loc: str, page_infos: list = None) -> str:
    """Read one PDF with PyMuPDF and return its text, including OCR of images.

    Args:
        pdf_loc: Path to the PDF file.
        page_infos: Optional ``[start, end]`` pair selecting pages
            ``start .. end - 1`` (0-based, end-exclusive, like a slice).
            When falsy, every page is read.

    Returns:
        For each selected page, in order: the page's plain-text layer,
        followed by a newline plus the Tesseract OCR output for each image
        block on that page.
    """
    pdf_text = ''
    # The context manager closes the document even if OCR raises midway.
    with fitz.open(pdf_loc) as pdf:
        pages = range(page_infos[0], page_infos[1]) if page_infos else range(len(pdf))
        for page_num in pages:
            page = pdf[page_num]
            # The "text" layer already contains every text span of the page;
            # appending the spans of the text blocks again would duplicate it.
            pdf_text += page.get_text("text")
            for block in page.get_text("dict")["blocks"]:
                # PyMuPDF marks image blocks with type 1 (text blocks are 0).
                # Each image block carries its own encoded bytes, so every
                # displayed image is OCR'd exactly once.
                if block["type"] != 1:
                    continue
                try:
                    img = Image.open(io.BytesIO(block["image"]))
                    img.load()
                except OSError:
                    # Pillow cannot decode some PDF image encodings
                    # (e.g. JBIG2/JPX variants); skip them instead of
                    # aborting the whole document.
                    continue
                pdf_text += '\n' + pytesseract.image_to_string(img)
    return pdf_text
# NOTE: The block below is a disabled helper kept inside a bare module-level
# string literal, so `df_to_text` is never defined or executed. It formats
# each DataFrame row as "col: val" pairs under a "Table Data:" header.
# `pd` is not imported in this file; re-enabling it would require
# `import pandas as pd`.
"""
# 將DataFrame轉換為文本
def df_to_text(df: pd.DataFrame) -> str:
    table_text = "Table Data:\n"
    for _, row in df.iterrows():
        row_text = ', '.join([f"{col}: {val}" for col, val in row.items()])
        table_text += row_text + '\n'
    return table_text
"""
def preprocess_text(text: str) -> str:
    """Flatten whitespace and insert line breaks at sentence boundaries.

    Punctuation is kept. The function only rewrites whitespace, removes
    ``<...>`` tag-like spans, and adds newlines after sentence enders so a
    newline-aware splitter can cut the text into sentences.

    Args:
        text: Raw document or query text.

    Returns:
        The normalized text with trailing whitespace removed.
    """
    # Runs of 3+ newlines become one newline (and thus one space below).
    text = re.sub(r"\n{3,}", r"\n", text)
    # Every whitespace character (tab, newline, ...) becomes one space each;
    # runs are NOT collapsed. Afterwards the text contains no newlines.
    # Raw string: a plain '\s' literal is an invalid escape sequence.
    text = re.sub(r"\s", " ", text)
    # Break after a six-dot English ellipsis or a two-character Chinese
    # ellipsis, unless a closing quote follows.
    text = re.sub(r'(\.{6})([^"’”」』])', r"\1\n\2", text)
    text = re.sub(r'(\…{2})([^"’”」』])', r"\1\n\2", text)
    # Drop HTML-like tags such as <br> or </b>.
    text = re.sub(r'<[^>]+>', '', text)
    # Break after a single sentence-ending mark unless a closing quote follows.
    text = re.sub(r'([;;.!?。!?\?])([^”’])', r"\1\n\2", text)
    # Break after a sentence ender (not '.') plus up to two closing quotes,
    # unless another ender or a comma follows. Since the previous rule already
    # put a newline after most enders, this usually produces "\n\n".
    text = re.sub(r'([;;!?。!?\?]["’”」』]{0,2})([^;;!?,。!?\?])', r'\1\n\2', text)
    return text.rstrip()
def segment_text(text: str, max_length: int = 200) -> list[str]:
    """Greedily pack the newline-separated lines of ``text`` into segments.

    Lines are joined with single spaces while the running segment plus the
    next line stays within ``max_length`` characters. A single line longer
    than ``max_length`` becomes its own (oversized) segment; it is not cut.

    Args:
        text: Input text; newlines separate the units being packed.
        max_length: Target maximum segment length in characters.

    Returns:
        The non-empty, stripped segments in input order (``[]`` for empty
        or whitespace-only text).
    """
    segments = []
    current_segment = ""
    for sentence in text.split('\n'):
        if len(current_segment) + len(sentence) <= max_length:
            current_segment += sentence + " "
            continue
        # Flush only real content: the pending segment can be empty (or just
        # spaces) when the very first line is already longer than max_length.
        if current_segment.strip():
            segments.append(current_segment.strip())
        current_segment = sentence + " "
    if current_segment.strip():
        segments.append(current_segment.strip())
    return segments
# Split boundaries handed to RecursiveCharacterTextSplitter, listed from the
# coarsest (paragraph / line breaks, spaces) to progressively finer fall-backs
# (sentence and clause punctuation, brackets, quotes, a few symbols).
_TEXT_SPLITTER_SEPARATORS = [
    "\n\n", "\n", " ",
    ".", "。", "!", "?", "!", "?", ";", ";", "……",
    "(", ")", "(", ")", ":", ":", "、", ",", ",",
    "「", "」", "『", "』", "“", "”", "\"", "《", "》", "【", "】", "[", "]",
    ">", "<", "+", "-", "*", "=",
]
# Module-wide chunker used by the retrieve_* functions: target chunk size of
# 200 characters (measured with len) and 100 characters of overlap.
text_splitter = RecursiveCharacterTextSplitter(
    separators=_TEXT_SPLITTER_SEPARATORS,
    chunk_size=200,
    chunk_overlap=100,
    length_function=len,
)
def retrieve_BM25(qs: str, source: list[int], corpus_dict: dict[int, str]) -> int:
    """Return the id of the document in ``source`` that best matches ``qs`` under BM25.

    Each candidate document is preprocessed with ``preprocess_text``,
    tokenized with ``jieba.cut_for_search`` and scored as a whole (no
    segmentation).

    Args:
        qs: Query text.
        source: Candidate document ids; each is converted with ``int`` before
            being looked up in ``corpus_dict``.
        corpus_dict: Mapping from integer document id to raw text.

    Returns:
        The integer id of the highest-scoring candidate, or ``None`` when
        ``source`` is empty.
    """
    doc_ids = [int(file_id) for file_id in source]
    if not doc_ids:
        # BM25Okapi divides by the corpus size, so an empty corpus would raise.
        return None
    filtered_corpus = [preprocess_text(corpus_dict[doc_id]) for doc_id in doc_ids]
    tokenized_corpus = [list(jieba.cut_for_search(doc)) for doc in filtered_corpus]
    bm25 = BM25Okapi(tokenized_corpus)
    tokenized_query = list(jieba.cut_for_search(preprocess_text(qs)))
    scores = bm25.get_scores(tokenized_query)
    # Map the winner back to its id by position: this avoids re-preprocessing
    # every document in corpus_dict to find it by text, and can never return
    # an id that is not in `source`.
    return doc_ids[int(np.argmax(scores))]
def retrieve_BM25_segment(qs: str, source: list[int], corpus_dict: dict[int, str]) -> int:
    """Return the document owning the chunk that best matches ``qs`` under BM25.

    Every document in ``source`` is preprocessed and cut into chunks with
    ``text_splitter``; the chunks are tokenized with ``jieba.cut_for_search``
    and scored with BM25 against the preprocessed query.

    Args:
        qs: Query text.
        source: Candidate document ids (converted with ``int``).
        corpus_dict: Mapping from integer document id to raw text.

    Returns:
        The id of the document owning the top-scoring chunk, or ``None`` when
        the candidates produce no chunks.
    """
    segmented_corpus = []  # All chunks of all candidate documents.
    segment_to_doc = []    # segment_to_doc[i] is the document id of chunk i.
    for file_id in source:
        doc_id = int(file_id)
        segments = text_splitter.split_text(preprocess_text(corpus_dict[doc_id]))
        segmented_corpus.extend(segments)
        segment_to_doc.extend([doc_id] * len(segments))
    if not segmented_corpus:
        return None  # BM25Okapi cannot be built on an empty corpus.

    tokenized_corpus = [list(jieba.cut_for_search(seg)) for seg in segmented_corpus]
    bm25 = BM25Okapi(tokenized_corpus)
    tokenized_query = list(jieba.cut_for_search(preprocess_text(qs)))
    scores = bm25.get_scores(tokenized_query)
    # The winning chunk's position maps directly to its document; no need to
    # search the corpus for a chunk with the same text.
    return segment_to_doc[int(np.argmax(scores))]
# Sentence-embedding checkpoints that were considered for retrieval.
model_name = [
    # BERT multilingual models:
    # https://www.sbert.net/docs/sentence_transformer/pretrained_models.html#multilingual-models
    "distiluse-base-multilingual-cased-v1",
    "paraphrase-multilingual-MiniLM-L12-v2",
    # BCEmbedding model: https://pypi.org/project/BCEmbedding/
    "maidalun1020/bce-embedding-base_v1",
]
# Shared embedding model used by the embedding-based retrieve_* functions;
# loaded once at import time from the last (BCEmbedding) checkpoint above.
embedding_model = SentenceTransformer(model_name[-1])
def retrieve_BM25_Embedding(qs: str, source: list[int], corpus_dict: dict[int, str], top_n_bm25: int=10) -> int:
    """Retrieve the best document via chunk-level BM25 recall and embedding re-scoring.

    Every document in ``source`` is preprocessed and cut into chunks with
    ``text_splitter``. BM25 over ``jieba.cut_for_search`` tokens keeps the
    ``top_n_bm25`` best chunks. Among those, the chunk whose normalized
    embedding has the highest cosine similarity with the query wins.

    Args:
        qs: Query text.
        source: Candidate document ids (converted with ``int``).
        corpus_dict: Mapping from integer document id to raw text.
        top_n_bm25: Number of BM25 candidates passed to the embedding stage.

    Returns:
        The id of the document owning the winning chunk, or ``None`` when
        there are no chunks or no BM25 candidates.
    """
    segmented_corpus = []  # All chunks of all candidate documents.
    segment_to_doc = []    # segment_to_doc[i] is the document id of chunk i.
    for file_id in source:
        doc_id = int(file_id)
        segments = text_splitter.split_text(preprocess_text(corpus_dict[doc_id]))
        segmented_corpus.extend(segments)
        segment_to_doc.extend([doc_id] * len(segments))
    if not segmented_corpus:
        return None  # BM25Okapi cannot be built on an empty corpus.

    # BM25 recall. This is the same ranking BM25Okapi.get_top_n computes, but
    # keeping chunk positions so the winner maps straight to its document.
    processed_qs = preprocess_text(qs)
    bm25 = BM25Okapi([list(jieba.cut_for_search(seg)) for seg in segmented_corpus])
    bm25_scores = bm25.get_scores(list(jieba.cut_for_search(processed_qs)))
    candidate_idx = [int(i) for i in np.argsort(bm25_scores)[::-1][:top_n_bm25]]
    if not candidate_idx:
        return None

    # Embedding re-scoring: one batched encode call for all candidates.
    query_embedding = embedding_model.encode(processed_qs, convert_to_tensor=True, normalize_embeddings=True)
    candidate_embeddings = embedding_model.encode(
        [segmented_corpus[i] for i in candidate_idx],
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    similarities = util.cos_sim(query_embedding, candidate_embeddings)[0].tolist()
    best_idx = candidate_idx[int(np.argmax(similarities))]
    return segment_to_doc[best_idx]
def retrieve_onlyEmbedding(qs: str, source: list[int], corpus_dict: dict[int, str]) -> int:
    """Retrieve the best document by embedding similarity over all chunks.

    Every document in ``source`` is preprocessed and cut into chunks with
    ``text_splitter``. The chunk whose normalized embedding has the highest
    cosine similarity with the preprocessed query wins.

    Args:
        qs: Query text.
        source: Candidate document ids (converted with ``int``).
        corpus_dict: Mapping from integer document id to raw text.

    Returns:
        The id of the document owning the most similar chunk, or ``None``
        when the candidates produce no chunks.
    """
    segmented_corpus = []  # All chunks of all candidate documents.
    segment_to_doc = []    # segment_to_doc[i] is the document id of chunk i.
    for file_id in source:
        doc_id = int(file_id)
        segments = text_splitter.split_text(preprocess_text(corpus_dict[doc_id]))
        segmented_corpus.extend(segments)
        segment_to_doc.extend([doc_id] * len(segments))
    if not segmented_corpus:
        return None

    processed_qs = preprocess_text(qs)
    query_embedding = embedding_model.encode(processed_qs, convert_to_tensor=True, normalize_embeddings=True)
    # One batched encode call for every chunk instead of one call per chunk.
    segment_embeddings = embedding_model.encode(
        segmented_corpus, convert_to_tensor=True, normalize_embeddings=True
    )
    similarities = util.cos_sim(query_embedding, segment_embeddings)[0].tolist()
    # The argmax is already a chunk position, so it maps directly to its document.
    return segment_to_doc[int(np.argmax(similarities))]
# Cross-encoder checkpoint used by the *Reranker retrieve functions to score
# (query, chunk) pairs; loaded once at import time with max_length=512.
_RERANKER_CHECKPOINT = 'maidalun1020/bce-reranker-base_v1'
reranker_model = CrossEncoder(_RERANKER_CHECKPOINT, max_length=512)
def retrieve_BM25_Reranker(qs: str, source: list[int], corpus_dict: dict[int, str], top_n_bm25: int=10) -> int:
    """Retrieve the best document via chunk-level BM25 recall and cross-encoder reranking.

    Every document in ``source`` is preprocessed and cut into chunks with
    ``text_splitter``. BM25 over ``jieba.cut_for_search`` tokens keeps the
    ``top_n_bm25`` best chunks. ``reranker_model`` then scores each
    (query, chunk) pair and the highest score wins.

    Args:
        qs: Query text.
        source: Candidate document ids (converted with ``int``).
        corpus_dict: Mapping from integer document id to raw text.
        top_n_bm25: Number of BM25 candidates passed to the reranker.

    Returns:
        The id of the document owning the winning chunk, or ``None`` when
        there are no chunks or no BM25 candidates.
    """
    segmented_corpus = []  # All chunks of all candidate documents.
    segment_to_doc = []    # segment_to_doc[i] is the document id of chunk i.
    for file_id in source:
        doc_id = int(file_id)
        segments = text_splitter.split_text(preprocess_text(corpus_dict[doc_id]))
        segmented_corpus.extend(segments)
        segment_to_doc.extend([doc_id] * len(segments))
    if not segmented_corpus:
        return None  # BM25Okapi cannot be built on an empty corpus.

    # BM25 recall, keeping chunk positions (same ranking as get_top_n).
    processed_qs = preprocess_text(qs)
    bm25 = BM25Okapi([list(jieba.cut_for_search(seg)) for seg in segmented_corpus])
    bm25_scores = bm25.get_scores(list(jieba.cut_for_search(processed_qs)))
    candidate_idx = [int(i) for i in np.argsort(bm25_scores)[::-1][:top_n_bm25]]
    if not candidate_idx:
        return None  # CrossEncoder.predict cannot handle an empty batch.

    query_candidate_pairs = [(processed_qs, segmented_corpus[i]) for i in candidate_idx]
    rerank_scores = reranker_model.predict(query_candidate_pairs)
    best_idx = candidate_idx[int(np.argmax(rerank_scores))]
    return segment_to_doc[best_idx]
def retrieve_onlyReranker(qs: str, source: list[int], corpus_dict: dict[int, str]) -> int:
    """Retrieve the best document by cross-encoder scoring of every chunk.

    Every document in ``source`` is preprocessed and cut into chunks with
    ``text_splitter``. ``reranker_model`` scores each (query, chunk) pair
    and the highest score wins.

    Args:
        qs: Query text.
        source: Candidate document ids (converted with ``int``).
        corpus_dict: Mapping from integer document id to raw text.

    Returns:
        The id of the document owning the top-scoring chunk, or ``None`` when
        the candidates produce no chunks.
    """
    segmented_corpus = []  # All chunks of all candidate documents.
    segment_to_doc = []    # segment_to_doc[i] is the document id of chunk i.
    for file_id in source:
        doc_id = int(file_id)
        segments = text_splitter.split_text(preprocess_text(corpus_dict[doc_id]))
        segmented_corpus.extend(segments)
        segment_to_doc.extend([doc_id] * len(segments))
    if not segmented_corpus:
        return None  # CrossEncoder.predict cannot handle an empty batch.

    processed_qs = preprocess_text(qs)
    query_candidate_pairs = [(processed_qs, segment) for segment in segmented_corpus]
    scores = reranker_model.predict(query_candidate_pairs)
    # The argmax is already a chunk position, so it maps directly to its document.
    return segment_to_doc[int(np.argmax(scores))]
def retrieve_Embedding_Reranker(qs: str, source: list[int], corpus_dict: dict[int, str], top_n: int=10) -> int:
    """Retrieve the best document via embedding recall and cross-encoder reranking.

    Every document in ``source`` is preprocessed and cut into chunks with
    ``text_splitter``. The ``top_n`` chunks with the highest cosine
    similarity to the query (normalized embeddings) are rescored with
    ``reranker_model``; the highest reranker score wins.

    Args:
        qs: Query text.
        source: Candidate document ids (converted with ``int``).
        corpus_dict: Mapping from integer document id to raw text.
        top_n: Number of embedding candidates passed to the reranker.

    Returns:
        The id of the document owning the winning chunk, or ``None`` when
        there are no chunks or no candidates (e.g. ``top_n == 0``).
    """
    segmented_corpus = []  # All chunks of all candidate documents.
    segment_to_doc = []    # segment_to_doc[i] is the document id of chunk i.
    for file_id in source:
        doc_id = int(file_id)
        segments = text_splitter.split_text(preprocess_text(corpus_dict[doc_id]))
        segmented_corpus.extend(segments)
        segment_to_doc.extend([doc_id] * len(segments))
    if not segmented_corpus:
        return None

    processed_qs = preprocess_text(qs)
    query_embedding = embedding_model.encode(processed_qs, convert_to_tensor=True, normalize_embeddings=True)
    # One batched encode call for every chunk instead of one call per chunk.
    segment_embeddings = embedding_model.encode(
        segmented_corpus, convert_to_tensor=True, normalize_embeddings=True
    )
    similarities = util.cos_sim(query_embedding, segment_embeddings)[0].tolist()
    # Sort descending and take the first top_n. Slicing the ascending order
    # with [-top_n:] would select *every* chunk when top_n == 0.
    top_idx = [int(i) for i in np.argsort(similarities)[::-1][:top_n]]
    if not top_idx:
        return None  # CrossEncoder.predict cannot handle an empty batch.

    query_candidate_pairs = [(processed_qs, segmented_corpus[i]) for i in top_idx]
    rerank_scores = reranker_model.predict(query_candidate_pairs)
    best_idx = top_idx[int(np.argmax(rerank_scores))]
    return segment_to_doc[best_idx]
def retrieve_BM25_Embedding_Reranker(qs: str, source: list[int], corpus_dict: dict[int, str], top_n_bm25: int=10, top_n_embed: int=5) -> int:
    """Retrieve the best document via BM25 recall, embedding filtering and reranking.

    Every document in ``source`` is preprocessed and cut into chunks with
    ``text_splitter``. The pipeline has three stages:

    1. BM25 over ``jieba.cut_for_search`` tokens keeps ``top_n_bm25`` chunks.
    2. Of those, the ``top_n_embed`` chunks with the highest cosine
       similarity to the query (normalized embeddings) are kept.
    3. ``reranker_model`` scores the survivors; the highest score wins.

    Args:
        qs: Query text.
        source: Candidate document ids (converted with ``int``).
        corpus_dict: Mapping from integer document id to raw text.
        top_n_bm25: Number of BM25 candidates passed to the embedding stage.
        top_n_embed: Number of embedding survivors passed to the reranker.

    Returns:
        The id of the document owning the winning chunk, or ``None`` when
        any stage ends up with no candidates (e.g. ``top_n_embed == 0``).
    """
    segmented_corpus = []  # All chunks of all candidate documents.
    segment_to_doc = []    # segment_to_doc[i] is the document id of chunk i.
    for file_id in source:
        doc_id = int(file_id)
        segments = text_splitter.split_text(preprocess_text(corpus_dict[doc_id]))
        segmented_corpus.extend(segments)
        segment_to_doc.extend([doc_id] * len(segments))
    if not segmented_corpus:
        return None  # BM25Okapi cannot be built on an empty corpus.

    # Stage 1: BM25 recall, keeping chunk positions (same ranking as get_top_n).
    processed_qs = preprocess_text(qs)
    bm25 = BM25Okapi([list(jieba.cut_for_search(seg)) for seg in segmented_corpus])
    bm25_scores = bm25.get_scores(list(jieba.cut_for_search(processed_qs)))
    bm25_idx = [int(i) for i in np.argsort(bm25_scores)[::-1][:top_n_bm25]]
    if not bm25_idx:
        return None

    # Stage 2: embedding filter, with one batched encode call for all candidates.
    query_embedding = embedding_model.encode(processed_qs, convert_to_tensor=True, normalize_embeddings=True)
    candidate_embeddings = embedding_model.encode(
        [segmented_corpus[i] for i in bm25_idx],
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    similarities = util.cos_sim(query_embedding, candidate_embeddings)[0].tolist()
    # Sort descending and take the first top_n_embed. Slicing the ascending
    # order with [-top_n_embed:] would keep *every* candidate when it is 0.
    finalist_idx = [bm25_idx[int(j)] for j in np.argsort(similarities)[::-1][:top_n_embed]]
    if not finalist_idx:
        return None  # CrossEncoder.predict cannot handle an empty batch.

    # Stage 3: cross-encoder reranking.
    query_candidate_pairs = [(processed_qs, segmented_corpus[i]) for i in finalist_idx]
    rerank_scores = reranker_model.predict(query_candidate_pairs)
    best_idx = finalist_idx[int(np.argmax(rerank_scores))]
    return segment_to_doc[best_idx]