Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions align_system/algorithms/outlines_baseline_adm_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,100 @@ def _generic_object_repr(obj):
num_samples={self.num_samples},
vote_calculator_fn={_generic_object_repr(self.vote_calculator_fn)},
)""", flags=re.MULTILINE).strip()


class OutlinesRAGBaselineADMComponent(OutlinesBaselineADMComponent):
"""
Variant of OutlinesBaselineADMComponent that accepts rag_context from the
pipeline's working_output and passes it to the scenario_description_template,
system_prompt_template and prompt_template calls. A fixed system_prompt, when
set, is used verbatim and does not receive rag_context. Templates that don't
declare a rag_context parameter will silently ignore it via
call_with_coerced_args.
"""

# Builds a single dialog (optional system prompt + user prompt), samples
# `num_samples` structured responses from it, turns the sampled
# 'action_choice' values into votes, and returns a tuple of
# (top_choice, top_choice_justification, dialog).
def run(self,
scenario_state,
choices,
rag_context=None):
if self.enable_caching:
# Work on a copy so the caller's scenario_state is left untouched.
scenario_state_copy = copy.deepcopy(scenario_state)
if hasattr(scenario_state, 'elapsed_time'):
# Zeroed on the copy only; the copy's repr feeds the cache key below,
# so elapsed_time does not influence cache lookups.
scenario_state_copy.elapsed_time = 0

# Cache key material: component configuration, scenario state, the
# available choices and the retrieved RAG context.
depends = '\n'.join((
self.cache_repr(),
repr(scenario_state_copy),
repr(choices),
repr(rag_context)))

cacher = ub.Cacher('outlines_rag_baseline_adm_component', depends, verbose=0)
log.debug(f'cacher.fpath={cacher.fpath}')

cached_output = cacher.tryload()
if cached_output is not None:
log.info("Cache hit for `outlines_rag_baseline_adm_component`"
" returning cached output")
return cached_output
else:
log.info("Cache miss for `outlines_rag_baseline_adm_component` ..")

# rag_context is offered to the scenario description template as well as
# to the system prompt and prompt templates further down.
scenario_description = call_with_coerced_args(
self.scenario_description_template,
{'scenario_state': scenario_state,
'rag_context': rag_context})
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, so is this really the only difference with the base class? If so wondering if it just makes more sense to add it to the original class 🤔

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nvm. Thinking it's better to be a distinct subclass but let's make rag_context a required argument to the run function (if you're not using RAG just use the base class?)


dialog = []
# A fixed system_prompt takes precedence over system_prompt_template; it
# is used as-is (rag_context is not applied to it) and is the only
# variant tagged 'regression'.
if self.system_prompt is not None:
system_prompt = self.system_prompt
dialog.insert(0, DialogElement(role='system',
content=system_prompt,
tags=['regression']))
elif self.system_prompt_template is not None:
system_prompt = call_with_coerced_args(
self.system_prompt_template,
{'rag_context': rag_context})
dialog.insert(0, DialogElement(role='system',
content=system_prompt))

prompt = call_with_coerced_args(
self.prompt_template,
{'scenario_state': scenario_state,
'scenario_description': scenario_description,
'choices': choices,
'rag_context': rag_context})

dialog.append(DialogElement(role='user', content=prompt))

# The output schema depends only on the choices, not on rag_context.
output_schema = call_with_coerced_args(
self.output_schema_template,
{'choices': choices})

dialog_prompt = self.structured_inference_engine.dialog_to_prompt(dialog)

log.info("[bold]*RAG TAGGING DIALOG PROMPT*[/bold]",
extra={"markup": True})
log.info(dialog_prompt)

# The same prompt is submitted num_samples times; each response
# contributes one 'action_choice' to the vote calculation.
responses = self.structured_inference_engine.run_inference(
[dialog_prompt] * self.num_samples, output_schema)

votes = self.vote_calculator_fn(
choices, [r['action_choice'] for r in responses])

log.explain("[bold]*VOTES*[/bold]",
extra={"markup": True})
log.explain(votes, extra={"highlighter": JSON_HIGHLIGHTER})

# Pick the choice with the highest vote score.
top_choice, top_choice_score = max(votes.items(), key=lambda x: x[1])

# Use the reasoning of the first sampled response that picked the winning
# choice; stays "" when no response matches.
top_choice_justification = ""
for response in responses:
if response['action_choice'] == top_choice:
top_choice_justification = response['detailed_reasoning']
break

outputs = (top_choice, top_choice_justification, dialog)

if self.enable_caching:
cacher.save(outputs)

return outputs
2 changes: 2 additions & 0 deletions align_system/algorithms/pipeline_adm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def choose_action(self,
per_step_timing_stats = []

for i, step in enumerate(self.steps):
if step is None:
continue
step_returns = step.run_returns()

start_time = timer()
Expand Down
141 changes: 141 additions & 0 deletions align_system/algorithms/prompt_based_aligned_adm_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,144 @@ def run(self,
break

return top_choice, top_choice_justification, positive_dialog


class PromptBasedRAGAlignedADMComponent(PromptBasedAlignedADMComponent):
"""
Variant of PromptBasedAlignedADMComponent that accepts rag_context from the
pipeline's working_output and passes it to the scenario_description_template,
prompt_template and system_prompt_template calls. Templates that don't
declare rag_context will silently ignore it via call_with_coerced_args.
"""

# Builds a "positive" dialog aligned to the single target KDMA value and,
# when num_negative_samples > 0, a "negative" dialog aligned to
# 1 - value; samples both, combines the sampled choices into votes, and
# returns (top_choice, top_choice_justification, positive_dialog).
# Raises RuntimeError unless exactly one KDMA target is supplied.
# NOTE(review): the two ICL parameters use mutable list defaults; within
# this method they are only read (len() and as the source of extend()),
# never mutated.
def run(self,
scenario_state,
choices,
alignment_target,
positive_icl_dialog_elements=[],
negative_icl_dialog_elements=[],
rag_context=None):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same story here I think (making rag_context non-optional)

kdma_values = alignment_target.kdma_values
if len(kdma_values) != 1:
raise RuntimeError("This ADM assumes a single KDMA target, aborting!")
kdma_value = kdma_values[0]
# Normalise to a plain dict so the key lookups below work for both forms.
if isinstance(kdma_value, KDMAValue):
kdma_value = kdma_value.to_dict()

kdma = kdma_value['kdma']
value = kdma_value['value']
# Target used for the negative prompt.
# assumes value lies in [0, 1] so that 1 - value is its complement — TODO confirm
negative_value = 1 - value

scenario_description = call_with_coerced_args(
self.scenario_description_template,
{'scenario_state': scenario_state,
'rag_context': rag_context})

# The same user prompt is shared by the positive and negative dialogs.
prompt = call_with_coerced_args(
self.prompt_template,
{'scenario_state': scenario_state,
'scenario_description': scenario_description,
'choices': choices,
'rag_context': rag_context})

# Positive dialog: optional system prompt, then any ICL examples, then
# the user prompt.
positive_dialog = []
if self.system_prompt_template is not None:
positive_system_prompt = call_with_coerced_args(
self.system_prompt_template,
{'target_kdma': kdma,
'target_value': value,
'rag_context': rag_context})

positive_dialog.insert(
0, DialogElement(role='system',
content=positive_system_prompt))

if len(positive_icl_dialog_elements) > 0:
positive_dialog.extend(positive_icl_dialog_elements)

positive_dialog.append(
DialogElement(role='user', content=prompt))

positive_dialog_prompt = self.structured_inference_engine.dialog_to_prompt(
positive_dialog)

log.info("[bold]*POSITIVE RAG DIALOG PROMPT*[/bold]",
extra={"markup": True})
log.info(positive_dialog_prompt)

# Negative dialog mirrors the positive one but targets negative_value;
# it is only built when negative samples are requested.
if self.num_negative_samples > 0:
negative_dialog = []
if self.system_prompt_template is not None:
negative_system_prompt = call_with_coerced_args(
self.system_prompt_template,
{'target_kdma': kdma,
'target_value': negative_value,
'rag_context': rag_context})

negative_dialog.insert(
0, DialogElement(role='system',
content=negative_system_prompt))

if len(negative_icl_dialog_elements) > 0:
negative_dialog.extend(negative_icl_dialog_elements)

negative_dialog.append(
DialogElement(role='user', content=prompt))

negative_dialog_prompt = self.structured_inference_engine.dialog_to_prompt(
negative_dialog)

log.info("[bold]*NEGATIVE RAG DIALOG PROMPT*[/bold]",
extra={"markup": True})
log.info(negative_dialog_prompt)

# The output schema depends only on the choices, not on rag_context.
output_schema = call_with_coerced_args(
self.output_schema_template,
{'choices': choices})

positive_responses = self.structured_inference_engine.run_inference(
[positive_dialog_prompt] * self.num_positive_samples, output_schema)
positive_choices = [r['action_choice'] for r in positive_responses]
for i, positive_response in enumerate(positive_responses):
log.info("[bold]*POSITIVE RAG RESPONSE ({}, sample #{})*[/bold]".format(
kdma, i), extra={"markup": True})
log.info(positive_response, extra={"highlighter": JSON_HIGHLIGHTER})

if self.num_negative_samples > 0:
negative_responses = self.structured_inference_engine.run_inference(
[negative_dialog_prompt] * self.num_negative_samples, output_schema)
negative_choices = [r['action_choice'] for r in negative_responses]
for i, negative_response in enumerate(negative_responses):
log.info("[bold]*NEGATIVE RAG RESPONSE ({}, sample #{})*[/bold]".format(
kdma, i), extra={"markup": True})
log.info(negative_response, extra={"highlighter": JSON_HIGHLIGHTER})
else:
# The vote calculator receives None when no negative sampling was done.
negative_choices = None

votes = self.vote_calculator_fn(
choices, positive_choices, negative_choices)

log.explain("[bold]*VOTES*[/bold]",
extra={"markup": True})
log.explain(votes, extra={"highlighter": JSON_HIGHLIGHTER})

# Optionally restrict the candidates using the positive responses; the
# filtered votes are only logged when filtering changed something.
if self.filter_votes_to_positives:
filtered_votes = filter_votes_to_responses(votes, positive_choices)
if filtered_votes != votes:
log.explain("Filtering votes down to choices where we "
"have a positive response")
log.explain(filtered_votes,
extra={"highlighter": JSON_HIGHLIGHTER})
final_votes = filtered_votes
else:
final_votes = votes

# Pick the choice with the highest (possibly filtered) vote score.
top_choice, top_choice_score = max(final_votes.items(), key=lambda x: x[1])

# Justification comes from the first positive response that picked the
# winning choice; stays "" when none did.
top_choice_justification = ""
for response in positive_responses:
if response['action_choice'] == top_choice:
top_choice_justification = response['detailed_reasoning']
break

return top_choice, top_choice_justification, positive_dialog
75 changes: 75 additions & 0 deletions align_system/algorithms/rag_retreiver_adm_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
'''
Uses the LangChain framework to enable prompt-injected RAG with hydra-configurable
inference LLM model selection. The knowledge base vector store is dynamically
generated from the provided document files.

Prompt-Injection RAG ADM Components
'''
from typing import Iterable, Union, List, Dict
from os import PathLike

from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document

from align_system.algorithms.abstracts import ADMComponent

DocumentFileType = Union[str, bytes, PathLike]
DocumentFileListType = Iterable[DocumentFileType]


class LangChainRAGIndexerADMComponent(ADMComponent):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think having RAG implemented as an ADMComponent is probably OK for now, but I was thinking of it more like how we have the StructuredInferenceEngine pieces (as in the same instance could potentially be re-used by different ADM components). This might come into play for our multi-attribute ADMs where we need to do some kind of ICL retrieval for relevance computation, and then again for provided in-context examples.

# Loads the given document files, splits them into chunks, embeds the chunks
# and builds a FAISS vector store, all eagerly at construction time. Only the
# resulting retriever (plus docs_files, embedding_model_name and k) is kept
# on the instance.
#
# docs_files: paths of the files to index; each is read by _load_docs.
# embedding_model_name: forwarded to HuggingFaceEmbeddings(model_name=...).
# chunk_size / chunk_overlap / add_start_index: forwarded unchanged to
#     RecursiveCharacterTextSplitter.
# k: passed to the retriever as search_kwargs["k"].
#
# NOTE(review): docs_files is typed as an Iterable and is iterated by
# _load_docs below; if a one-shot iterator is passed, self.docs_files will
# presumably be exhausted afterwards — confirm callers pass a list.
def __init__(self,
docs_files: DocumentFileListType,
embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
chunk_size: int = 1000,
chunk_overlap: int = 200,
add_start_index: bool = True,
k: int = 5):
self.docs_files = docs_files
self.embedding_model_name = embedding_model_name
self.k = k

docs = LangChainRAGIndexerADMComponent._load_docs(docs_files)

text_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
add_start_index=add_start_index,
)
all_splits = text_splitter.split_documents(docs)

embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
vector_store = FAISS.from_documents(all_splits, embeddings)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing this vector_store is just in memory here right? Or is it backed by a file / cache on disk? Not requesting any changes here just curious (I know the documents we're using now are small, but what happens if it's a massive collection)


# The vector store itself is not kept; only the retriever view is stored.
self.retriever = vector_store.as_retriever(search_kwargs={"k": k})

def run_returns(self):
    """Return the name given to this step's ``run`` output: ``"rag_context"``."""
    output_name = "rag_context"
    return output_name

def run(self, scenario_state) -> str:
    """Retrieve passages for the scenario text and join them into one string.

    ``scenario_state.unstructured`` is used as the retriever query. Each
    retrieved document is rendered as a ``[Passage N]`` header (numbered
    from 1) followed by its stripped ``page_content``; rendered passages are
    separated by a blank line.
    """
    hits = self.retriever.invoke(scenario_state.unstructured)
    rendered = []
    for number, hit in enumerate(hits, start=1):
        body = hit.page_content.strip()
        rendered.append(f"[Passage {number}]\n{body}")
    return "\n\n".join(rendered)

def retrieve(self, query: str) -> Dict:
    """Run the retriever on ``query`` and return the hits in dict form.

    The result has two keys: ``"question"`` (the query as given) and
    ``"context"``, a list with one ``{"content", "metadata"}`` dict per
    retrieved document, in the order the retriever returned them.
    """
    context = []
    for hit in self.retriever.invoke(query):
        context.append({"content": hit.page_content, "metadata": hit.metadata})
    return {"question": query, "context": context}

@staticmethod
def _load_docs(docs_files: DocumentFileListType) -> List[Document]:
    """Read each file in ``docs_files`` into a ``Document``.

    Every file is read in full as UTF-8 text and wrapped in a ``Document``
    whose ``page_content`` is the file text and whose ``metadata["source"]``
    is the file's path as a string. Documents are returned in input order.

    Raises whatever ``open``/``read`` raise for a missing or undecodable
    file (e.g. ``FileNotFoundError``, ``UnicodeDecodeError``).
    """
    # Local import keeps this block self-contained; the module only imports
    # PathLike from os.
    from os import fsdecode

    docs = []
    for path in docs_files:
        # Explicit encoding: the default would otherwise depend on the
        # platform locale.
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        # fsdecode handles str, bytes and PathLike alike; str() on a bytes
        # path would record "b'...'" as the source.
        docs.append(Document(page_content=text,
                             metadata={"source": fsdecode(path)}))
    return docs
1 change: 1 addition & 0 deletions align_system/configs/action_based.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ defaults:
- interface: input_output_file
- adm: outlines_transformers_structured_baseline
- driver: itm_phase1
# - hydra/launcher: basic
- override hydra/job_logging: custom

loglevel: "EXPLAIN"
Expand Down
52 changes: 52 additions & 0 deletions align_system/configs/adm/tagging_fewshot_aligned_rag.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
name: tagging_fewshot_aligned_rag

defaults:
# Import defaults into this namespace (adm) as @name, for further
# customization

# Shared variables / components
- /attribute@start: start
- /attribute@salt: salt
- /attribute@bcd: bcd
- /inference_engine@structured_inference_engine: outlines_structured_multinomial
- /template/scenario_description@scenario_description_template: tagging
- /template/prompt@prompt_template: tagging
# ADM components to be used in "steps"
- /adm_component/misc@step_definitions.rag_indexer: rag_indexer
- /adm_component/icl@step_definitions.icl: tagging
- /adm_component/misc@step_definitions.format_choices: itm_format_choices
- /adm_component/direct@step_definitions.tagging_rag_aligned: tagging_rag_aligned
- /adm_component/misc@step_definitions.ensure_chosen_action: ensure_chosen_action
- /adm_component/misc@step_definitions.populate_choice_info: populate_choice_info
# Use definitions in this file to override defaults defined above
- _self_

attribute_definitions:
START: ${adm.start}
SALT: ${adm.salt}
BCD_SIEVE: ${adm.bcd}

step_definitions:
rag_indexer:
docs_files:
- /data/users/yonatan.gefen/align-system/align_system/documents/start.md
- /data/users/yonatan.gefen/align-system/align_system/documents/start_triage_flowchart.md
- /data/users/yonatan.gefen/align-system/align_system/documents/Salt.md
Comment on lines +32 to +34
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to check with Emily and Aaron here, but I don't know that we want all protocol documents available all the time (it may depend on the target, e.g. for a start target, we would use the two start documents, but not the Salt.md).

Also we have a convention for storing / using files in the repo for this kind of thing it seems (or at least use a /data/shared path rather than your own directory): https://github.com/ITM-Kitware/align-system/blob/main/align_system/configs/adm_component/icl/tagging.yaml#L17-L19

icl:
scenario_description_template: ${ref:adm.scenario_description_template}
attributes: ${adm.attribute_definitions}
prompt_template: ${ref:adm.prompt_template}
icl_generator_partial:
scenario_description_template: ${ref:adm.scenario_description_template}

instance:
_target_: align_system.algorithms.pipeline_adm.PipelineADM

steps:
# Reference the step instances we want to use in order
- ${ref:adm.step_definitions.format_choices}
- ${ref:adm.step_definitions.rag_indexer}
- ${ref:adm.step_definitions.icl}
- ${ref:adm.step_definitions.tagging_rag_aligned}
- ${ref:adm.step_definitions.ensure_chosen_action}
- ${ref:adm.step_definitions.populate_choice_info}
Loading