-
Notifications
You must be signed in to change notification settings - Fork 5
Adds RAG ADM Component and creates tagging pipeline to test it #270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -183,3 +183,144 @@ def run(self, | |
| break | ||
|
|
||
| return top_choice, top_choice_justification, positive_dialog | ||
|
|
||
|
|
||
| class PromptBasedRAGAlignedADMComponent(PromptBasedAlignedADMComponent): | ||
| """ | ||
| Variant of PromptBasedAlignedADMComponent that accepts rag_context from the | ||
| pipeline's working_output and passes it to system_prompt_template and | ||
| scenario_description_template calls. Templates that don't declare rag_context | ||
| will silently ignore it via call_with_coerced_args. | ||
| """ | ||
|
|
||
| def run(self, | ||
| scenario_state, | ||
| choices, | ||
| alignment_target, | ||
| positive_icl_dialog_elements=[], | ||
| negative_icl_dialog_elements=[], | ||
| rag_context=None): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same story here I think (making |
||
| kdma_values = alignment_target.kdma_values | ||
| if len(kdma_values) != 1: | ||
| raise RuntimeError("This ADM assumes a single KDMA target, aborting!") | ||
| kdma_value = kdma_values[0] | ||
| if isinstance(kdma_value, KDMAValue): | ||
| kdma_value = kdma_value.to_dict() | ||
|
|
||
| kdma = kdma_value['kdma'] | ||
| value = kdma_value['value'] | ||
| negative_value = 1 - value | ||
|
|
||
| scenario_description = call_with_coerced_args( | ||
| self.scenario_description_template, | ||
| {'scenario_state': scenario_state, | ||
| 'rag_context': rag_context}) | ||
|
|
||
| prompt = call_with_coerced_args( | ||
| self.prompt_template, | ||
| {'scenario_state': scenario_state, | ||
| 'scenario_description': scenario_description, | ||
| 'choices': choices, | ||
| 'rag_context': rag_context}) | ||
|
|
||
| positive_dialog = [] | ||
| if self.system_prompt_template is not None: | ||
| positive_system_prompt = call_with_coerced_args( | ||
| self.system_prompt_template, | ||
| {'target_kdma': kdma, | ||
| 'target_value': value, | ||
| 'rag_context': rag_context}) | ||
|
|
||
| positive_dialog.insert( | ||
| 0, DialogElement(role='system', | ||
| content=positive_system_prompt)) | ||
|
|
||
| if len(positive_icl_dialog_elements) > 0: | ||
| positive_dialog.extend(positive_icl_dialog_elements) | ||
|
|
||
| positive_dialog.append( | ||
| DialogElement(role='user', content=prompt)) | ||
|
|
||
| positive_dialog_prompt = self.structured_inference_engine.dialog_to_prompt( | ||
| positive_dialog) | ||
|
|
||
| log.info("[bold]*POSITIVE RAG DIALOG PROMPT*[/bold]", | ||
| extra={"markup": True}) | ||
| log.info(positive_dialog_prompt) | ||
|
|
||
| if self.num_negative_samples > 0: | ||
| negative_dialog = [] | ||
| if self.system_prompt_template is not None: | ||
| negative_system_prompt = call_with_coerced_args( | ||
| self.system_prompt_template, | ||
| {'target_kdma': kdma, | ||
| 'target_value': negative_value, | ||
| 'rag_context': rag_context}) | ||
|
|
||
| negative_dialog.insert( | ||
| 0, DialogElement(role='system', | ||
| content=negative_system_prompt)) | ||
|
|
||
| if len(negative_icl_dialog_elements) > 0: | ||
| negative_dialog.extend(negative_icl_dialog_elements) | ||
|
|
||
| negative_dialog.append( | ||
| DialogElement(role='user', content=prompt)) | ||
|
|
||
| negative_dialog_prompt = self.structured_inference_engine.dialog_to_prompt( | ||
| negative_dialog) | ||
|
|
||
| log.info("[bold]*NEGATIVE RAG DIALOG PROMPT*[/bold]", | ||
| extra={"markup": True}) | ||
| log.info(negative_dialog_prompt) | ||
|
|
||
| output_schema = call_with_coerced_args( | ||
| self.output_schema_template, | ||
| {'choices': choices}) | ||
|
|
||
| positive_responses = self.structured_inference_engine.run_inference( | ||
| [positive_dialog_prompt] * self.num_positive_samples, output_schema) | ||
| positive_choices = [r['action_choice'] for r in positive_responses] | ||
| for i, positive_response in enumerate(positive_responses): | ||
| log.info("[bold]*POSITIVE RAG RESPONSE ({}, sample #{})*[/bold]".format( | ||
| kdma, i), extra={"markup": True}) | ||
| log.info(positive_response, extra={"highlighter": JSON_HIGHLIGHTER}) | ||
|
|
||
| if self.num_negative_samples > 0: | ||
| negative_responses = self.structured_inference_engine.run_inference( | ||
| [negative_dialog_prompt] * self.num_negative_samples, output_schema) | ||
| negative_choices = [r['action_choice'] for r in negative_responses] | ||
| for i, negative_response in enumerate(negative_responses): | ||
| log.info("[bold]*NEGATIVE RAG RESPONSE ({}, sample #{})*[/bold]".format( | ||
| kdma, i), extra={"markup": True}) | ||
| log.info(negative_response, extra={"highlighter": JSON_HIGHLIGHTER}) | ||
| else: | ||
| negative_choices = None | ||
|
|
||
| votes = self.vote_calculator_fn( | ||
| choices, positive_choices, negative_choices) | ||
|
|
||
| log.explain("[bold]*VOTES*[/bold]", | ||
| extra={"markup": True}) | ||
| log.explain(votes, extra={"highlighter": JSON_HIGHLIGHTER}) | ||
|
|
||
| if self.filter_votes_to_positives: | ||
| filtered_votes = filter_votes_to_responses(votes, positive_choices) | ||
| if filtered_votes != votes: | ||
| log.explain("Filtering votes down to choices where we " | ||
| "have a positive response") | ||
| log.explain(filtered_votes, | ||
| extra={"highlighter": JSON_HIGHLIGHTER}) | ||
| final_votes = filtered_votes | ||
| else: | ||
| final_votes = votes | ||
|
|
||
| top_choice, top_choice_score = max(final_votes.items(), key=lambda x: x[1]) | ||
|
|
||
| top_choice_justification = "" | ||
| for response in positive_responses: | ||
| if response['action_choice'] == top_choice: | ||
| top_choice_justification = response['detailed_reasoning'] | ||
| break | ||
|
|
||
| return top_choice, top_choice_justification, positive_dialog | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| ''' | ||
| Uses the LangChain framework to enable prompt-injected RAG with hydra-configurable | ||
| inference LLM model selection. The knowledge base vector store is dynamically | ||
| generated from the provided document files. | ||
|
|
||
| Prompt-Injection RAG ADM Components | ||
| ''' | ||
| from typing import Iterable, Union, List, Dict | ||
| from os import PathLike | ||
|
|
||
| from langchain_text_splitters import RecursiveCharacterTextSplitter | ||
| from langchain_community.vectorstores import FAISS | ||
| from langchain_huggingface import HuggingFaceEmbeddings | ||
| from langchain_core.documents import Document | ||
|
|
||
| from align_system.algorithms.abstracts import ADMComponent | ||
|
|
||
| DocumentFileType = Union[str, bytes, PathLike] | ||
| DocumentFileListType = Iterable[DocumentFileType] | ||
|
|
||
|
|
||
| class LangChainRAGIndexerADMComponent(ADMComponent): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think having RAG implemented as an |
||
| def __init__(self, | ||
| docs_files: DocumentFileListType, | ||
| embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2", | ||
| chunk_size: int = 1000, | ||
| chunk_overlap: int = 200, | ||
| add_start_index: bool = True, | ||
| k: int = 5): | ||
| self.docs_files = docs_files | ||
| self.embedding_model_name = embedding_model_name | ||
| self.k = k | ||
|
|
||
| docs = LangChainRAGIndexerADMComponent._load_docs(docs_files) | ||
|
|
||
| text_splitter = RecursiveCharacterTextSplitter( | ||
| chunk_size=chunk_size, | ||
| chunk_overlap=chunk_overlap, | ||
| add_start_index=add_start_index, | ||
| ) | ||
| all_splits = text_splitter.split_documents(docs) | ||
|
|
||
| embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name) | ||
| vector_store = FAISS.from_documents(all_splits, embeddings) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm guessing this |
||
|
|
||
| self.retriever = vector_store.as_retriever(search_kwargs={"k": k}) | ||
|
|
||
| def run_returns(self): | ||
| return "rag_context" | ||
|
|
||
| def run(self, scenario_state) -> str: | ||
| query = scenario_state.unstructured | ||
| docs = self.retriever.invoke(query) | ||
| passages = [f"[Passage {i}]\n{doc.page_content.strip()}" | ||
| for i, doc in enumerate(docs, start=1)] | ||
| return "\n\n".join(passages) | ||
|
|
||
| def retrieve(self, query: str) -> Dict: | ||
| docs = self.retriever.invoke(query) | ||
| return { | ||
| "question": query, | ||
| "context": [ | ||
| {"content": d.page_content, "metadata": d.metadata} | ||
| for d in docs | ||
| ] | ||
| } | ||
|
|
||
| @staticmethod | ||
| def _load_docs(docs_files: DocumentFileListType) -> List[Document]: | ||
| docs = [] | ||
| for d in docs_files: | ||
| with open(d, 'r') as f: | ||
| text = f.read() | ||
| docs.append(Document(page_content=text, metadata={"source": str(d)})) | ||
| return docs | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| name: tagging_fewshot_aligned_rag | ||
|
|
||
| defaults: | ||
| # Import defaults into this namspace (adm) as @name, for further | ||
| # customization | ||
|
|
||
| # Shared variables / components | ||
| - /attribute@start: start | ||
| - /attribute@salt: salt | ||
| - /attribute@bcd: bcd | ||
| - /inference_engine@structured_inference_engine: outlines_structured_multinomial | ||
| - /template/scenario_description@scenario_description_template: tagging | ||
| - /template/prompt@prompt_template: tagging | ||
| # ADM components to be used in "steps" | ||
| - /adm_component/misc@step_definitions.rag_indexer: rag_indexer | ||
| - /adm_component/icl@step_definitions.icl: tagging | ||
| - /adm_component/misc@step_definitions.format_choices: itm_format_choices | ||
| - /adm_component/direct@step_definitions.tagging_rag_aligned: tagging_rag_aligned | ||
| - /adm_component/misc@step_definitions.ensure_chosen_action: ensure_chosen_action | ||
| - /adm_component/misc@step_definitions.populate_choice_info: populate_choice_info | ||
| # Use definitions in this file to override defaults defined above | ||
| - _self_ | ||
|
|
||
| attribute_definitions: | ||
| START: ${adm.start} | ||
| SALT: ${adm.salt} | ||
| BCD_SIEVE: ${adm.bcd} | ||
|
|
||
| step_definitions: | ||
| rag_indexer: | ||
| docs_files: | ||
| - /data/users/yonatan.gefen/align-system/align_system/documents/start.md | ||
| - /data/users/yonatan.gefen/align-system/align_system/documents/start_triage_flowchart.md | ||
| - /data/users/yonatan.gefen/align-system/align_system/documents/Salt.md | ||
|
Comment on lines
+32
to
+34
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better to check with Emily and Aaron here, but I don't know that we want all protocol documents available all the time (it may depend on the target, e.g. for a Also we have a convention for storing / using files in the repo for this kind of thing it seems (or at least use a |
||
| icl: | ||
| scenario_description_template: ${ref:adm.scenario_description_template} | ||
| attributes: ${adm.attribute_definitions} | ||
| prompt_template: ${ref:adm.prompt_template} | ||
| icl_generator_partial: | ||
| scenario_description_template: ${ref:adm.scenario_description_template} | ||
|
|
||
| instance: | ||
| _target_: align_system.algorithms.pipeline_adm.PipelineADM | ||
|
|
||
| steps: | ||
| # Reference the step instances we want to use in order | ||
| - ${ref:adm.step_definitions.format_choices} | ||
| - ${ref:adm.step_definitions.rag_indexer} | ||
| - ${ref:adm.step_definitions.icl} | ||
| - ${ref:adm.step_definitions.tagging_rag_aligned} | ||
| - ${ref:adm.step_definitions.ensure_chosen_action} | ||
| - ${ref:adm.step_definitions.populate_choice_info} | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting, so is this really the only difference with the base class? If so wondering if it just makes more sense to add it to the original class 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nvm. Thinking it's better to be a distinct subclass but let's make
rag_contexta required argument to therunfunction (if you're not using RAG just use the base class?)