|
1 | 1 | try: |
2 | | - from ..llm_factory import OpenAILLMs |
| 2 | + from ..llm_factory import OpenAILLMs, GoogleAILLMs |
3 | 3 | from .base_prompts import \ |
4 | 4 | role_prompt, conv_pref_prompt, update_conv_pref_prompt, summary_prompt, update_summary_prompt, summary_system_prompt |
5 | 5 | from ..utils.types import InvokeAgentResponseType |
6 | 6 | except ImportError: |
7 | | - from src.agents.llm_factory import OpenAILLMs |
| 7 | + from src.agents.llm_factory import OpenAILLMs, GoogleAILLMs |
8 | 8 | from src.agents.base_agent.base_prompts import \ |
9 | 9 | role_prompt, conv_pref_prompt, update_conv_pref_prompt, summary_prompt, update_summary_prompt, summary_system_prompt |
10 | 10 | from src.agents.utils.types import InvokeAgentResponseType |
@@ -35,9 +35,9 @@ class State(TypedDict): |
35 | 35 |
|
36 | 36 | class BaseAgent: |
37 | 37 | def __init__(self): |
38 | | - llm = OpenAILLMs() |
| 38 | + llm = OpenAILLMs() # OpenAILLMs() or GoogleAILLMs() |
39 | 39 | self.llm = llm.get_llm() |
40 | | - summarisation_llm = OpenAILLMs() |
| 40 | + summarisation_llm = OpenAILLMs() # OpenAILLMs() or GoogleAILLMs() |
41 | 41 | self.summarisation_llm = summarisation_llm.get_llm() |
42 | 42 | self.summary = "" |
43 | 43 | self.conversationalStyle = "" |
@@ -120,12 +120,12 @@ def summarize_conversation(self, state: State, config: RunnableConfig) -> dict: |
120 | 120 | conversationalStyle_message = self.conversation_preference_prompt |
121 | 121 |
|
122 | 122 | # STEP 1: Summarize the conversation |
123 | | - messages = state["messages"][:-1] + [SystemMessage(content=summary_message)] |
| 123 | + messages = state["messages"][:-1] + [HumanMessage(content=summary_message)] |
124 | 124 | valid_messages = self.check_for_valid_messages(messages) |
125 | 125 | summary_response = self.summarisation_llm.invoke(valid_messages) |
126 | 126 |
|
127 | 127 | # STEP 2: Analyze the conversational style |
128 | | - messages = state["messages"][:-1] + [SystemMessage(content=conversationalStyle_message)] |
| 128 | + messages = state["messages"][:-1] + [HumanMessage(content=conversationalStyle_message)] |
129 | 129 | valid_messages = self.check_for_valid_messages(messages) |
130 | 130 | conversationalStyle_response = self.summarisation_llm.invoke(valid_messages) |
131 | 131 |
|
|
0 commit comments