-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathquery.py
More file actions
61 lines (50 loc) · 2.2 KB
/
query.py
File metadata and controls
61 lines (50 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
from langchain_community.chat_models import ChatOllama
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.retrievers.multi_query import MultiQueryRetriever
from get_vector_db import get_vector_db
# Model name passed to ChatOllama in query(); taken from the LLM_MODEL
# environment variable, falling back to 'mistral' when it is not set.
LLM_MODEL = os.getenv('LLM_MODEL', 'mistral')
def get_prompt():
    """Build the two prompt templates used by the query pipeline.

    Returns:
        A ``(query_prompt, answer_prompt)`` tuple where

        * ``query_prompt`` is a ``PromptTemplate`` taking ``question`` and
          asking the model for five alternative phrasings of it, separated
          by newlines;
        * ``answer_prompt`` is a ``ChatPromptTemplate`` taking ``context``
          and ``question`` and asking for an answer based only on that
          context.
    """
    # Prompt text for rewriting the user question into several variants.
    multi_query_text = """You are an AI language model assistant. Your task is to generate five
different versions of the given user question to retrieve relevant documents from
a vector database. By generating multiple perspectives on the user question, your
goal is to help the user overcome some of the limitations of the distance-based
similarity search. Provide these alternative questions separated by newlines.
Original question: {question}"""

    # Prompt text for the final, context-restricted answer.
    answer_text = """Answer the question based ONLY on the following context:
{context}
Question: {question}
"""

    return (
        PromptTemplate(input_variables=["question"], template=multi_query_text),
        ChatPromptTemplate.from_template(answer_text),
    )
def query(input):
    """Answer a question using documents retrieved from the vector database.

    The question is expanded into multiple search queries by a
    ``MultiQueryRetriever``; whatever that retriever yields is supplied as
    ``context`` alongside the original question to the answer prompt, and
    the model's reply is passed through ``StrOutputParser``.

    Args:
        input: The user's question. A falsy value (``None``, ``""``, ...)
            short-circuits before any model or database is created.

    Returns:
        The parsed output of the chain, or ``None`` when ``input`` is falsy.
    """
    # Nothing to ask: skip model and vector-store setup entirely.
    if not input:
        return None

    chat_model = ChatOllama(model=LLM_MODEL)
    vector_store = get_vector_db()
    multi_query_prompt, answer_prompt = get_prompt()

    # The retriever uses the chat model to produce alternative phrasings of
    # the question before searching the vector store.
    multi_query_retriever = MultiQueryRetriever.from_llm(
        vector_store.as_retriever(),
        chat_model,
        prompt=multi_query_prompt,
    )

    # The same input feeds both branches: the retriever produces the context
    # while the passthrough forwards the question unchanged.
    chain_inputs = {
        "context": multi_query_retriever,
        "question": RunnablePassthrough(),
    }
    answer_chain = chain_inputs | answer_prompt | chat_model | StrOutputParser()

    return answer_chain.invoke(input)