Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,6 +1556,7 @@ class ChatProviderTemplate(TypedDict):
"enable": False,
"id": "whisper_selfhost",
"model": "tiny",
"whisper_device": "cpu",
},
"SenseVoice(Local)": {
"type": "sensevoice_stt_selfhost",
Expand Down Expand Up @@ -2555,6 +2556,12 @@ class ChatProviderTemplate(TypedDict):
"type": "string",
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
},
"whisper_device": {
"description": "推理设备",
"type": "string",
"hint": "Whisper 推理设备。Apple Silicon 可选 mps;其他环境建议使用 cpu。若指定 mps 但当前环境不可用,将自动回退到 cpu。",
"options": ["cpu", "mps"],
},
"id": {
"description": "ID",
"type": "string",
Expand Down
9 changes: 6 additions & 3 deletions astrbot/core/knowledge_base/kb_db_sqlite.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING

from sqlalchemy import delete, func, select, text, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlmodel import col, desc

from astrbot.core import logger
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.models import (
BaseKBModel,
KBDocument,
Expand All @@ -15,6 +15,9 @@
)
from astrbot.core.utils.astrbot_path import get_astrbot_knowledge_base_path

if TYPE_CHECKING:
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB


class KBSQLiteDatabase:
def __init__(self, db_path: str | None = None) -> None:
Expand Down Expand Up @@ -296,7 +299,7 @@ async def get_documents_with_metadata_batch(

return metadata_map

async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None:
async def delete_document_by_id(self, doc_id: str, vec_db: "FaissVecDB") -> None:
"""删除单个文档及其相关数据"""
# 在知识库表中删除
async with self.get_db() as session, session.begin():
Expand Down Expand Up @@ -324,7 +327,7 @@ async def get_media_by_id(self, media_id: str) -> KBMedia | None:
result = await session.execute(stmt)
return result.scalar_one_or_none()

async def update_kb_stats(self, kb_id: str, vec_db: FaissVecDB) -> None:
async def update_kb_stats(self, kb_id: str, vec_db: "FaissVecDB") -> None:
"""更新知识库统计信息"""
chunk_cnt = await vec_db.count_documents()

Expand Down
9 changes: 7 additions & 2 deletions astrbot/core/knowledge_base/kb_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import time
import uuid
from pathlib import Path
from typing import TYPE_CHECKING

import aiofiles

from astrbot.core import logger
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.provider.provider import (
EmbeddingProvider,
Expand All @@ -27,6 +27,9 @@
from .parsers.util import select_parser
from .prompts import TEXT_REPAIR_SYSTEM_PROMPT

if TYPE_CHECKING:
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB


class RateLimiter:
"""一个简单的速率限制器"""
Expand Down Expand Up @@ -160,7 +163,7 @@ async def get_rp(self) -> RerankProvider | None:
return None
return rp

async def _ensure_vec_db(self) -> FaissVecDB:
async def _ensure_vec_db(self) -> "FaissVecDB":
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")

Expand All @@ -173,6 +176,8 @@ async def _ensure_vec_db(self) -> FaissVecDB:
f"知识库 {self.kb.kb_name}({self.kb.kb_id}) 初始化重排序能力失败,将跳过重排序: {e}",
)

from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB

vec_db = FaissVecDB(
doc_store_path=str(self.kb_dir / "doc.db"),
index_store_path=str(self.kb_dir / "index.faiss"),
Expand Down
17 changes: 11 additions & 6 deletions astrbot/core/knowledge_base/retrieval/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@

import time
from dataclasses import dataclass
from typing import TYPE_CHECKING

from astrbot import logger
from astrbot.core.db.vec_db.base import Result
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
from astrbot.core.provider.provider import RerankProvider

from ..kb_helper import KBHelper

if TYPE_CHECKING:
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB


@dataclass
class RetrievalResult:
Expand Down Expand Up @@ -170,18 +173,20 @@ async def retrieve(
first_rerank = None
for kb_id in kb_ids:
vec_db = kb_options[kb_id]["vec_db"]
if not isinstance(vec_db, FaissVecDB):
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
rerank_provider = (
getattr(vec_db, "rerank_provider", None) if vec_db else None
)
if rerank_provider is None:
continue

rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if (
vec_db
and vec_db.rerank_provider
and rerank_provider
and rerank_pi
and rerank_pi == vec_db.rerank_provider.meta().id
and rerank_pi == rerank_provider.meta().id
):
first_rerank = vec_db.rerank_provider
first_rerank = rerank_provider
break
if first_rerank and retrieval_results:
try:
Expand Down
7 changes: 5 additions & 2 deletions astrbot/core/knowledge_base/retrieval/sparse_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
import json
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING

import jieba
from rank_bm25 import BM25Okapi

from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase

if TYPE_CHECKING:
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB


@dataclass
class SparseResult:
Expand Down Expand Up @@ -73,7 +76,7 @@ async def retrieve(
top_k_sparse = 0
chunks = []
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options.get(kb_id, {}).get("vec_db")
vec_db: FaissVecDB | None = kb_options.get(kb_id, {}).get("vec_db")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): FaissVecDB is only imported under TYPE_CHECKING, so this runtime type annotation will fail with NameError.

Because FaissVecDB only exists under TYPE_CHECKING, this annotation must not be evaluated at runtime. Please either quote the type (e.g. "FaissVecDB" | None), enable from __future__ import annotations in this module, or change the import so FaissVecDB is available at runtime.

if not vec_db:
continue
result = await vec_db.document_storage.get_documents(
Expand Down
24 changes: 21 additions & 3 deletions astrbot/core/provider/sources/whisper_selfhosted_source.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import os
import uuid
from functools import partial
from typing import cast

import whisper
Expand Down Expand Up @@ -28,17 +29,34 @@ def __init__(
) -> None:
super().__init__(provider_config, provider_settings)
self.set_model(provider_config["model"])
self.device = str(provider_config.get("whisper_device", "cpu")).strip().lower()
self.model = None
Comment on lines 29 to 33
Copy link

Copilot AI Apr 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR description is scoped to deferring faiss imports during startup, but this change set also introduces a new Whisper config field (whisper_device) and related behavior/logging (plus UI metadata updates). Please update the PR description to include this additional feature, or split it into a separate PR to keep the change focused and easier to review/revert independently.

Copilot uses AI. Check for mistakes.

def _resolve_device(self) -> str:
if self.device == "mps":
import torch # torch is a dependency of openai-whisper

mps_backend = getattr(torch.backends, "mps", None)
if mps_backend and mps_backend.is_available():
return "mps"
logger.warning("Whisper 已配置为使用 MPS,但当前环境不可用,将回退到 CPU。")
return "cpu"
if self.device != "cpu":
logger.warning(
"Whisper 配置了未知 device=%s,将回退到 CPU。",
self.device,
)
return "cpu"

async def initialize(self) -> None:
loop = asyncio.get_running_loop()
device = self._resolve_device()
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
self.model = await loop.run_in_executor(
None,
whisper.load_model,
self.model_name,
partial(whisper.load_model, self.model_name, device=device),
)
logger.info("Whisper 模型加载完成。")
logger.info("Whisper 模型加载完成。device=%s", device)

async def _is_silk_file(self, file_path) -> bool:
silk_header = b"SILK"
Expand Down
5 changes: 4 additions & 1 deletion astrbot/dashboard/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import base64
import traceback
from io import BytesIO
from typing import TYPE_CHECKING

from astrbot.api import logger
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.kb_helper import KBHelper
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager

if TYPE_CHECKING:
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB


async def generate_tsne_visualization(
query: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1514,6 +1514,10 @@
"description": "Notes for local Whisper deployment",
"hint": "Before enabling, install the openai-whisper library (NVIDIA users download ~2GB mainly for torch and cuda; CPU users download ~1GB), and install ffmpeg. Otherwise STT will not work."
},
"whisper_device": {
"description": "Inference device",
"hint": "Whisper inference device. Apple Silicon can use mps; other environments should use cpu. If mps is selected but unavailable, AstrBot will fall back to cpu."
},
"id": {
"description": "ID"
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1511,6 +1511,10 @@
"description": "Заметки по локальному развертыванию Whisper",
"hint": "Перед включением установите openai-whisper и ffmpeg."
},
"whisper_device": {
"description": "Устройство инференса",
"hint": "Устройство для инференса Whisper. На Apple Silicon можно выбрать mps; в остальных средах рекомендуется cpu. Если выбран mps, но он недоступен, AstrBot автоматически переключится на cpu."
},
"id": {
"description": "ID провайдера"
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1516,6 +1516,10 @@
"description": "本地部署 Whisper 模型须知",
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。"
},
"whisper_device": {
"description": "推理设备",
"hint": "Whisper 推理设备。Apple Silicon 可选 mps;其他环境建议使用 cpu。若指定 mps 但当前环境不可用,将自动回退到 cpu。"
},
"id": {
"description": "ID"
},
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_kb_manager_resilience.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ async def test_ensure_vec_db_clears_stale_init_error(
mock_vec_db.close = AsyncMock()

with patch(
"astrbot.core.knowledge_base.kb_helper.FaissVecDB",
"astrbot.core.db.vec_db.faiss_impl.vec_db.FaissVecDB",
return_value=mock_vec_db,
):
# Execute _ensure_vec_db
Expand Down