Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions astrbot/core/provider/manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import copy
import os
import traceback
from typing import Protocol, runtime_checkable

Expand Down Expand Up @@ -406,10 +407,40 @@ def get_merged_provider_config(self, provider_config: dict) -> dict:
pc = merged_config
return pc

def _resolve_env_key_list(self, provider_config: dict) -> dict:
keys = provider_config.get("key", [])
if not isinstance(keys, list):
return provider_config
resolved_keys = []
for idx, key in enumerate(keys):
if isinstance(key, str) and key.startswith("$"):
env_key = key[1:]
if env_key.startswith("{") and env_key.endswith("}"):
env_key = env_key[1:-1]
if env_key:
env_val = os.getenv(env_key)
if env_val is None:
provider_id = provider_config.get("id")
logger.warning(
f"Provider {provider_id} 配置项 key[{idx}] 使用环境变量 {env_key} 但未设置。",
)
resolved_keys.append("")
else:
resolved_keys.append(env_val)
else:
resolved_keys.append(key)
else:
resolved_keys.append(key)
provider_config["key"] = resolved_keys
return provider_config

async def load_provider(self, provider_config: dict):
# 如果 provider_source_id 存在且不为空,则从 provider_sources 中找到对应的配置并合并
provider_config = self.get_merged_provider_config(provider_config)

if provider_config.get("provider_type", "") == "chat_completion":
provider_config = self._resolve_env_key_list(provider_config)

if not provider_config["enable"]:
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
return
Expand Down