Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def load_embedder_config():
embedder_config = load_json_config("embedder.json")

# Process client classes
for key in ["embedder", "embedder_ollama", "embedder_google", "embedder_bedrock"]:
for key in ["embedder", "embedder_ollama", "embedder_google", "embedder_bedrock", "embedder_azure"]:
if key in embedder_config and "client_class" in embedder_config[key]:
class_name = embedder_config[key]["client_class"]
if class_name in CLIENT_CLASSES:
Expand All @@ -174,6 +174,8 @@ def get_embedder_config():
return configs.get("embedder_google", {})
elif embedder_type == 'ollama' and 'embedder_ollama' in configs:
return configs.get("embedder_ollama", {})
elif embedder_type == 'azure' and 'embedder_azure' in configs:
return configs.get("embedder_azure", {})
else:
return configs.get("embedder", {})

Expand Down Expand Up @@ -235,21 +237,41 @@ def is_bedrock_embedder():
client_class = embedder_config.get("client_class", "")
return client_class == "BedrockClient"

def is_azure_embedder():
    """Return True when the active embedder configuration targets AzureAIClient.

    The client is identified from either the already-resolved ``model_client``
    class (preferred, compared by class name) or, failing that, the raw
    ``client_class`` string stored in the configuration.

    Returns:
        bool: True if the embedder uses AzureAIClient, False otherwise.
    """
    config = get_embedder_config()
    if not config:
        # No embedder configured at all — cannot be Azure.
        return False

    client = config.get("model_client")
    if client:
        return client.__name__ == "AzureAIClient"

    # Fall back to the unresolved class-name string from the JSON config.
    return config.get("client_class", "") == "AzureAIClient"
Comment on lines +240 to +256
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This new function is_azure_embedder duplicates the logic found in is_ollama_embedder, is_google_embedder, and is_bedrock_embedder. To improve maintainability and reduce code duplication, consider creating a private helper function that takes the client name as an argument. This would centralize the checking logic.

For example, you could have a helper:

def _is_embedder_of_type(client_name: str) -> bool:
    embedder_config = get_embedder_config()
    if not embedder_config:
        return False

    model_client = embedder_config.get("model_client")
    if model_client:
        return model_client.__name__ == client_name

    client_class = embedder_config.get("client_class", "")
    return client_class == client_name

Then this function could be simplified to return _is_embedder_of_type("AzureAIClient").


def get_embedder_type():
    """Resolve the active embedder type from the current configuration.

    Each known backend predicate is probed in priority order; the first one
    that reports a match determines the result.

    Returns:
        str: 'bedrock', 'ollama', 'google', 'azure', or 'openai' (default)
    """
    # Priority-ordered (type, predicate) pairs — order matters because the
    # first matching predicate wins.
    probes = (
        ('bedrock', is_bedrock_embedder),
        ('ollama', is_ollama_embedder),
        ('google', is_google_embedder),
        ('azure', is_azure_embedder),
    )
    # 'openai' is the fallback when no specific backend matches.
    return next(
        (name for name, matches in probes if matches()),
        'openai',
    )

# Load repository and file filters configuration
def load_repo_config():
Expand Down Expand Up @@ -341,7 +363,7 @@ def load_lang_config():

# Update embedder configuration
if embedder_config:
for key in ["embedder", "embedder_ollama", "embedder_google", "embedder_bedrock", "retriever", "text_splitter"]:
for key in ["embedder", "embedder_ollama", "embedder_google", "embedder_bedrock", "embedder_azure", "retriever", "text_splitter"]:
if key in embedder_config:
configs[key] = embedder_config[key]

Expand Down
8 changes: 8 additions & 0 deletions api/config/embedder.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@
"dimensions": 256
}
},
"embedder_azure": {
"client_class": "AzureAIClient",
"batch_size": 100,
"model_kwargs": {
"model": "text-embedding-3-small",
"dimensions": 256
}
},
"retriever": {
"top_k": 20
},
Expand Down
22 changes: 9 additions & 13 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,17 @@ def patched_watch(*args, **kwargs):

import uvicorn

# Check for required environment variables
required_env_vars = ['GOOGLE_API_KEY', 'OPENAI_API_KEY']
missing_vars = [var for var in required_env_vars if not os.environ.get(var)]
if missing_vars:
logger.warning(f"Missing environment variables: {', '.join(missing_vars)}")
logger.warning("Some functionality may not work correctly without these variables.")

# Configure Google Generative AI
# Configure providers based on settings
from api.config import configs
import google.generativeai as genai
from api.config import GOOGLE_API_KEY

if GOOGLE_API_KEY:
genai.configure(api_key=GOOGLE_API_KEY)
else:
logger.warning("GOOGLE_API_KEY not configured")
# Only configure Google if it's being used as a provider
if configs.get("default_provider") == "google":
from api.config import GOOGLE_API_KEY
if GOOGLE_API_KEY:
genai.configure(api_key=GOOGLE_API_KEY)
else:
logger.warning("GOOGLE_API_KEY not configured but Google is the default provider")

if __name__ == "__main__":
# Get port from environment variable or use default
Expand Down
4 changes: 4 additions & 0 deletions api/tools/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def get_embedder(is_local_ollama: bool = False, use_google_embedder: bool = Fals
embedder_config = configs["embedder_google"]
elif embedder_type == 'bedrock':
embedder_config = configs["embedder_bedrock"]
elif embedder_type == 'azure':
embedder_config = configs["embedder_azure"]
else: # default to openai
embedder_config = configs["embedder"]
elif is_local_ollama:
Expand All @@ -37,6 +39,8 @@ def get_embedder(is_local_ollama: bool = False, use_google_embedder: bool = Fals
embedder_config = configs["embedder_ollama"]
elif current_type == 'google':
embedder_config = configs["embedder_google"]
elif current_type == 'azure':
embedder_config = configs["embedder_azure"]
else:
embedder_config = configs["embedder"]

Expand Down
4 changes: 2 additions & 2 deletions api/websocket_wiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ async def handle_websocket_chat(websocket: WebSocket):
if hasattr(last_message, 'content') and last_message.content:
tokens = count_tokens(last_message.content, request.provider == "ollama")
logger.info(f"Request size: {tokens} tokens")
if tokens > 8000:
logger.warning(f"Request exceeds recommended token limit ({tokens} > 7500)")
if tokens > 9000:
logger.warning(f"Request exceeds recommended token limit ({tokens} > 9000)")
Comment on lines +82 to +83
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The token limit 9000 is used as a magic number in both the condition and the log message. To improve maintainability and avoid potential inconsistencies (like the one fixed in this change), consider defining this value as a constant at the module level (e.g., REQUEST_TOKEN_LIMIT = 9000) and referencing it in both places.

input_too_large = True

# Create a new RAG instance for this request
Expand Down