Skip to content

Commit 5dc7795

Browse files
GFJHogue and heliamoh
authored
Safety assessment, Language detection (#101), new HybridRetriever class (#102)
* feat: expand preprocessing to a multi-step workflow. - Implement parallel execution of safety and scope check, query expansion, and language detection * feat: Add new runnables for checking question safety and scope, query expansion and conversation history management * feat:improved hybrid retrieval - Replace SelfQueryRetriever with efficient hybrid search (BM25 + vector) - Add RRF (Reciprocal Rank Fusion) support for query expansion - Implement parallel processing for improved performance * feat: Add new runnables for checking question safety and scope, query expansion and conversation history management * code quality check fixes * fix: Resolve mypy linter errors - Add type annotation for rrf_scores in retrieval_utils.py - Fix metadata dictionary comprehension in csv_chroma.py - Update retriever type annotations to use Any - Add isinstance check for BM25Retriever - Remove default values from TypedDict in base.py - Fix TypedDict expansion in postprocess method * remove: Remove reactome_kg directory from repository * feat: expand preprocessing to a multi-step workflow. - Implement parallel execution of safety and scope check, query expansion, and language detection * feat: expand preprocessing to a multi-step workflow. 
- Implement parallel execution of safety and scope check, query expansion, and language detection * feat:improved hybrid retrieval - Replace SelfQueryRetriever with efficient hybrid search (BM25 + vector) - Add RRF (Reciprocal Rank Fusion) support for query expansion - Implement parallel processing for improved performance * feat:improved answer generation, in-line citation handling and hallucination mitigation * remove irrelevant docs * [WIP] clean up changes * [WIP] clean up changes (2) * revert retrieval changes * macos-intel actions runner * fix SafetyCheck type usage * stream unsafe response to user * black spacing * cross-db use detected_language from base state * pre-release docker push * new HybridRetriever class (#102) * rewrite HybridRetriever class * fix types * fix HybridRetriever class inheritance issues * fix lint * multithread, as in Helia's code * fix typing for #102 --------- Co-authored-by: Helia Mohammadi <helia.mohammadi01@gmail.com>
1 parent e398a37 commit 5dc7795

9 files changed

Lines changed: 376 additions & 103 deletions

File tree

.github/workflows/ci.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ on:
99
push:
1010
branches:
1111
- main
12+
- pre-release
1213

1314
permissions:
1415
id-token: write
@@ -35,7 +36,7 @@ jobs:
3536
runs-on: ${{ matrix.os }}
3637
strategy:
3738
matrix:
38-
os: [ubuntu-latest, macos-13]
39+
os: [ubuntu-latest, macos-15-intel]
3940

4041
steps:
4142
- uses: actions/checkout@v4
@@ -81,7 +82,7 @@ jobs:
8182
path: /tmp/image.tar
8283

8384
docker-push:
84-
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
85+
if: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'ref/heads/pre-release') }}
8586
needs: docker-build
8687
runs-on: ubuntu-latest
8788

src/agent/profiles/base.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1-
from typing import Annotated, TypedDict
1+
from typing import Annotated, Literal, TypedDict
22

33
from langchain_core.embeddings import Embeddings
44
from langchain_core.language_models.chat_models import BaseChatModel
55
from langchain_core.messages import BaseMessage
66
from langchain_core.runnables import Runnable, RunnableConfig
77
from langgraph.graph.message import add_messages
88

9+
from agent.tasks.detect_language import create_language_detector
910
from agent.tasks.rephrase import create_rephrase_chain
11+
from agent.tasks.safety_checker import SafetyCheck, create_safety_checker
1012
from tools.external_search.state import SearchState, WebSearchResult
1113
from tools.external_search.workflow import create_search_workflow
1214

@@ -28,6 +30,11 @@ class BaseState(InputState, OutputState, total=False):
2830
rephrased_input: str # LLM-generated query from user input
2931
chat_history: Annotated[list[BaseMessage], add_messages]
3032

33+
# Preprocessing results
34+
safety: str # "true" or "false" from safety check
35+
reason_unsafe: str # Reason if unsafe
36+
detected_language: str # Detected language
37+
3138

3239
class BaseGraphBuilder:
3340
# NOTE: Anything that is common to all graph builders goes here
@@ -38,21 +45,40 @@ def __init__(
3845
embedding: Embeddings,
3946
) -> None:
4047
self.rephrase_chain: Runnable = create_rephrase_chain(llm)
48+
self.safety_checker: Runnable = create_safety_checker(llm)
49+
self.language_detector: Runnable = create_language_detector(llm)
4150
self.search_workflow: Runnable = create_search_workflow(llm)
4251

4352
async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
4453
rephrased_input: str = await self.rephrase_chain.ainvoke(
4554
{
4655
"user_input": state["user_input"],
47-
"chat_history": state["chat_history"],
56+
"chat_history": state.get("chat_history", []),
4857
},
4958
config,
5059
)
51-
return BaseState(rephrased_input=rephrased_input)
60+
safety_check: SafetyCheck = await self.safety_checker.ainvoke(
61+
{"rephrased_input": rephrased_input}, config
62+
)
63+
detected_language: str = await self.language_detector.ainvoke(
64+
{"user_input": state["user_input"]}, config
65+
)
66+
return BaseState(
67+
rephrased_input=rephrased_input,
68+
safety=safety_check.safety,
69+
reason_unsafe=safety_check.reason_unsafe,
70+
detected_language=detected_language,
71+
)
72+
73+
def proceed_with_research(self, state: BaseState) -> Literal["Continue", "Finish"]:
74+
return "Continue" if state["safety"] == "true" else "Finish"
5275

5376
async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
5477
search_results: list[WebSearchResult] = []
55-
if config["configurable"]["enable_postprocess"]:
78+
if (
79+
config["configurable"].get("enable_postprocess")
80+
and state["safety"] == "true"
81+
):
5682
result: SearchState = await self.search_workflow.ainvoke(
5783
SearchState(
5884
input=state["rephrased_input"],

src/agent/profiles/cross_database.py

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,11 @@
1515
create_uniprot_rewriter_w_reactome
1616
from agent.tasks.cross_database.summarize_reactome_uniprot import \
1717
create_reactome_uniprot_summarizer
18-
from agent.tasks.detect_language import create_language_detector
19-
from agent.tasks.safety_checker import SafetyCheck, create_safety_checker
2018
from retrievers.reactome.rag import create_reactome_rag
2119
from retrievers.uniprot.rag import create_uniprot_rag
2220

2321

2422
class CrossDatabaseState(BaseState):
25-
safety: str # LLM-assessed safety level of the user input
26-
query_language: str # language of the user input
27-
2823
reactome_query: str # LLM-generated query for Reactome
2924
reactome_answer: str # LLM-generated answer from Reactome
3025
reactome_completeness: str # LLM-assessed completeness of the Reactome answer
@@ -46,21 +41,18 @@ def __init__(
4641
self.reactome_rag: Runnable = create_reactome_rag(llm, embedding)
4742
self.uniprot_rag: Runnable = create_uniprot_rag(llm, embedding)
4843

49-
self.safety_checker = create_safety_checker(llm)
5044
self.completeness_checker = create_completeness_grader(llm)
51-
self.detect_language = create_language_detector(llm)
5245
self.write_reactome_query = create_reactome_rewriter_w_uniprot(llm)
5346
self.write_uniprot_query = create_uniprot_rewriter_w_reactome(llm)
5447
self.summarize_final_answer = create_reactome_uniprot_summarizer(
55-
llm.model_copy(update={"streaming": True})
48+
llm, streaming=True
5649
)
5750

5851
# Create graph
5952
state_graph = StateGraph(CrossDatabaseState)
6053
# Set up nodes
6154
state_graph.add_node("check_question_safety", self.check_question_safety)
6255
state_graph.add_node("preprocess_question", self.preprocess)
63-
state_graph.add_node("identify_query_language", self.identify_query_language)
6456
state_graph.add_node("conduct_research", self.conduct_research)
6557
state_graph.add_node("generate_reactome_answer", self.generate_reactome_answer)
6658
state_graph.add_node("rewrite_reactome_query", self.rewrite_reactome_query)
@@ -74,7 +66,6 @@ def __init__(
7466
state_graph.add_node("postprocess", self.postprocess)
7567
# Set up edges
7668
state_graph.set_entry_point("preprocess_question")
77-
state_graph.add_edge("preprocess_question", "identify_query_language")
7869
state_graph.add_edge("preprocess_question", "check_question_safety")
7970
state_graph.add_conditional_edges(
8071
"check_question_safety",
@@ -104,39 +95,18 @@ def __init__(
10495

10596
self.uncompiled_graph: StateGraph = state_graph
10697

107-
async def check_question_safety(
98+
def check_question_safety(
10899
self, state: CrossDatabaseState, config: RunnableConfig
109100
) -> CrossDatabaseState:
110-
result: SafetyCheck = await self.safety_checker.ainvoke(
111-
{"input": state["rephrased_input"]},
112-
config,
113-
)
114-
if result.binary_score == "No":
101+
if state["safety"] != "true":
115102
inappropriate_input = f"This is the user's question and it is NOT appropriate for you to answer: {state["user_input"]}. \n\n explain that you are unable to answer the question but you can answer questions about topics related to the Reactome Pathway Knowledgebase or UniProt Knowledgebas."
116103
return CrossDatabaseState(
117-
safety=result.binary_score,
118104
user_input=inappropriate_input,
119105
reactome_answer="",
120106
uniprot_answer="",
121107
)
122108
else:
123-
return CrossDatabaseState(safety=result.binary_score)
124-
125-
async def proceed_with_research(
126-
self, state: CrossDatabaseState
127-
) -> Literal["Continue", "Finish"]:
128-
if state["safety"] == "Yes":
129-
return "Continue"
130-
else:
131-
return "Finish"
132-
133-
async def identify_query_language(
134-
self, state: CrossDatabaseState, config: RunnableConfig
135-
) -> CrossDatabaseState:
136-
query_language: str = await self.detect_language.ainvoke(
137-
{"user_input": state["user_input"]}, config
138-
)
139-
return CrossDatabaseState(query_language=query_language)
109+
return CrossDatabaseState()
140110

141111
async def conduct_research(
142112
self, state: CrossDatabaseState, config: RunnableConfig
@@ -256,7 +226,7 @@ async def generate_final_response(
256226
final_response: str = await self.summarize_final_answer.ainvoke(
257227
{
258228
"input": state["rephrased_input"],
259-
"query_language": state["query_language"],
229+
"detected_language": state["detected_language"],
260230
"reactome_answer": state["reactome_answer"],
261231
"uniprot_answer": state["uniprot_answer"],
262232
},

src/agent/profiles/react_to_me.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from langgraph.graph.state import StateGraph
88

99
from agent.profiles.base import BaseGraphBuilder, BaseState
10+
from agent.tasks.unsafe_question import create_unsafe_answer_generator
1011
from retrievers.reactome.rag import create_reactome_rag
1112

1213

@@ -23,6 +24,9 @@ def __init__(
2324
super().__init__(llm, embedding)
2425

2526
# Create runnables (tasks & tools)
27+
self.unsafe_answer_generator: Runnable = create_unsafe_answer_generator(
28+
llm, streaming=True
29+
)
2630
self.reactome_rag: Runnable = create_reactome_rag(
2731
llm, embedding, streaming=True
2832
)
@@ -32,15 +36,40 @@ def __init__(
3236
# Set up nodes
3337
state_graph.add_node("preprocess", self.preprocess)
3438
state_graph.add_node("model", self.call_model)
39+
state_graph.add_node("generate_unsafe_response", self.generate_unsafe_response)
3540
state_graph.add_node("postprocess", self.postprocess)
3641
# Set up edges
3742
state_graph.set_entry_point("preprocess")
38-
state_graph.add_edge("preprocess", "model")
43+
state_graph.add_conditional_edges(
44+
"preprocess",
45+
self.proceed_with_research,
46+
{"Continue": "model", "Finish": "generate_unsafe_response"},
47+
)
3948
state_graph.add_edge("model", "postprocess")
49+
state_graph.add_edge("generate_unsafe_response", "postprocess")
4050
state_graph.set_finish_point("postprocess")
4151

4252
self.uncompiled_graph: StateGraph = state_graph
4353

54+
async def generate_unsafe_response(
55+
self, state: ReactToMeState, config: RunnableConfig
56+
) -> ReactToMeState:
57+
answer: str = await self.unsafe_answer_generator.ainvoke(
58+
{
59+
"language": state["detected_language"],
60+
"user_input": state["rephrased_input"],
61+
"reason_unsafe": state["reason_unsafe"],
62+
},
63+
config,
64+
)
65+
return ReactToMeState(
66+
chat_history=[
67+
HumanMessage(state["user_input"]),
68+
AIMessage(answer),
69+
],
70+
answer=answer,
71+
)
72+
4473
async def call_model(
4574
self, state: ReactToMeState, config: RunnableConfig
4675
) -> ReactToMeState:

src/agent/tasks/cross_database/summarize_reactome_uniprot.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,17 @@
3636
("system", summarization_message),
3737
(
3838
"human",
39-
"User question: {input} \n\n Language: {query_language} \n\n Reactome-drived information: \n {reactome_answer} \n\n UniProt-drived infromation: \n {uniprot_answer}.",
39+
"User question: {input} \n\n Language: {detected_language} \n\n Reactome-drived information: \n {reactome_answer} \n\n UniProt-drived infromation: \n {uniprot_answer}.",
4040
),
4141
]
4242
)
4343

4444

45-
def create_reactome_uniprot_summarizer(llm: BaseChatModel) -> Runnable:
45+
def create_reactome_uniprot_summarizer(
46+
llm: BaseChatModel, streaming: bool = False
47+
) -> Runnable:
48+
if streaming:
49+
llm = llm.model_copy(update={"streaming": True})
4650
return (summarizer_prompt | llm | StrOutputParser()).with_config(
4751
run_name="summarize_answer"
4852
)

src/agent/tasks/safety_checker.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,62 @@
44
from pydantic import BaseModel, Field
55

66
safety_check_message = """
7-
You are an expert grader with a strong understanding of ethical guidelines and responsible AI usage.
7+
You are an expert scientific assistant. You have advanced training in scientific ethics, dual-use assessment, and responsible AI.
88
9-
Your task is to determine whether a user-submitted question is appropriate and not harmful.
10-
Provide a binary output as either:
11-
Yes: The question is appropriate and not harmful.
12-
No: The question is harmful, offensive, inappropriate, or unethical.
9+
Tasks:
10+
11+
1. Safety & Appropriateness Check
12+
- Assess the question and determine if it is safe and appropriate for responsible AI use.
13+
- The question and any plausible answer to it must NOT be harmful, offensive, unethical, or violate scientific or ethical guidelines.
14+
- The question or its likely answers must NOT request, describe, or explore how to carry out harmful, illegal, dangerous, or dual-use activities (such as modifying pathogens, increasing infectivity, or other forms of misuse)—even if the scenario is hypothetical, fictional, or indirect.
15+
- The question must NOT request personal, medical, legal, or other types of advice outside the scope of scientific knowledge.
16+
- If a question is ambiguous or could plausibly be used for harm, treat it as unsafe.
17+
- Treat hypothetical, fictional, or made-up scenarios with the same level of scrutiny as real-world questions.
18+
19+
2. Reactome Relevance Check
20+
- Determine if the question is relevant to biology, life sciences, molecular biology, or related topics.
21+
- Mark questions as not relevant if they are about unrelated topics (such as programming, math, history, trivia, etc.).
22+
23+
IMPORTANT:
24+
- If the standalone question is unsafe or not relevant return "safety": "false".
25+
- If the standalone question is both safe **and** relevant, return "safety": "true".
26+
27+
Return only a JSON object in the following format:
28+
"safety": "true" or "false", // Use string, not boolean. Mark as "false" if unsafe OR not relevant.
29+
"reason_unsafe": "..." // If 'safety' is false, briefly state the reason
30+
31+
Examples:
32+
1. Q: Describe methods for bypassing cellular immune responses to viral infection
33+
"safety": "false",
34+
"reason_unsafe": "Could be used for harmful or unethical manipulation of immune evasion."
35+
36+
2. Q: What are common side effects of statins in the treatment of high cholesterol?
37+
"safety": "false",
38+
"reason_unsafe": "Requests medical advice, which is out of scientific scope."
39+
40+
3. Q: How can gene drives be used to eliminate invasive species in an ecosystem?
41+
"safety": "false",
42+
"reason_unsafe": "Enables the use of gene drives, which have significant ecological and dual-use risks."
43+
44+
4. Q: What is the role of the immune system in the treatment of cancer?
45+
"safety": "true",
46+
"reason_unsafe": ""
1347
"""
1448

1549
safety_check_prompt = ChatPromptTemplate.from_messages(
1650
[
1751
("system", safety_check_message),
18-
("human", "User question: \n\n {input}"),
52+
("human", "User question: \n\n {rephrased_input}"),
1953
]
2054
)
2155

2256

2357
class SafetyCheck(BaseModel):
24-
binary_score: str = Field(
25-
description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'Yes' or 'No'."
58+
safety: str = Field(
59+
description="Indicates whether the question is appropriate and related to molecular biology. Expected values: 'true' or 'false'."
60+
)
61+
reason_unsafe: str = Field(
62+
description="If 'safety' is false, briefly state the reason; if 'safety' is true, leave this field empty."
2663
)
2764

2865

src/agent/tasks/unsafe_question.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from langchain_core.language_models.chat_models import BaseChatModel
2+
from langchain_core.output_parsers import StrOutputParser
3+
from langchain_core.prompts import ChatPromptTemplate
4+
from langchain_core.runnables import Runnable
5+
6+
safety_check_message = """
7+
You are an expert scientific assistant operating under the React-to-Me platform. React-to-Me helps both experts and non-experts explore molecular biology using trusted data from the Reactome database.
8+
9+
You have advanced training in scientific ethics, dual-use research concerns, and responsible AI use.
10+
11+
You will receive three inputs:
12+
1. The user's question.
13+
2. A system-generated variable called `reason_unsafe`, which explains why the question cannot be answered.
14+
3. The user's preferred language (as a language code or name).
15+
16+
Your task is to clearly, respectfully, and firmly explain to the user *why* their question cannot be answered, based solely on the `reason_unsafe` input. Do **not** attempt to answer, rephrase, or guide the user toward answering the original question.
17+
18+
You must:
19+
- Respond in the user’s preferred language.
20+
- Politely explain the refusal, grounded in the `reason_unsafe`.
21+
- Emphasize React-to-Me’s mission: to support responsible exploration of molecular biology through trusted databases.
22+
- Suggest examples of appropriate topics (e.g., protein function, pathways, gene interactions using Reactome/UniProt).
23+
24+
You must not provide any workaround, implicit answer, or redirection toward unsafe content.
25+
"""
26+
27+
safety_check_prompt = ChatPromptTemplate.from_messages(
28+
[
29+
("system", safety_check_message),
30+
(
31+
"user",
32+
"Language:{language}\n\nQuestion:{user_input}\n\n Reason for unsafe or out of scope: {reason_unsafe}",
33+
),
34+
]
35+
)
36+
37+
38+
def create_unsafe_answer_generator(
39+
llm: BaseChatModel, streaming: bool = False
40+
) -> Runnable:
41+
if streaming:
42+
llm = llm.model_copy(update={"streaming": True})
43+
return (safety_check_prompt | llm | StrOutputParser()).with_config(
44+
run_name="unsafe_answer_generator"
45+
)

0 commit comments

Comments (0)