Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ LITELLM_LOG=FAKE_LOG_LEVEL
HASH_SALT="FAKE_HASH_SALT"
HASH_ALGO="FAKE_HASH_ALGO"
AUTH_TOKEN_EXPIRATION=9999
DATA_COLLECTION_HOST_PREFIX="fake_prefix"
1 change: 1 addition & 0 deletions k8s/welearn-api/values.dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ config:
nonSensitive:
PG_HOST: dev-lab-projects-backend.postgres.database.azure.com
TIKA_URL_BASE: https://tika.k8s.lp-i.dev/
DATA_COLLECTION_HOST_PREFIX: welearn
allowedHostsRegexes:
mainUrl: |-
https:\/\/welearn\.k8s\.lp-i\.dev
Expand Down
1 change: 1 addition & 0 deletions k8s/welearn-api/values.prod.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ config:
nonSensitive:
PG_HOST: prod-prod-projects-backend.postgres.database.azure.com
TIKA_URL_BASE: https://tika.k8s.lp-i.org/
DATA_COLLECTION_HOST_PREFIX: workshop
allowedHostsRegexes:
alphaUrls: |-
https://[a-zA-Z0-9-]*\.alpha-welearn\.lp-i\.org
Expand Down
52 changes: 14 additions & 38 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ langgraph-checkpoint-postgres = "^2.0.23"
azure-ai-inference = "^1.0.0b9"
azure-identity = "^1.25.0"
psycopg = {extras = ["binary"], version = "^3.2.10"}
welearn-database = "1.2.0"
welearn-database = "1.3.0"
bs4 = "^0.0.2"
urllib3 = "^2.6.3"
refinedoc = "^1.0.1"
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ env =
RUN_ENV=development
TIKA_URL_BASE=https://tika.example.com
USE_CACHED_SETTINGS=True
DATA_COLLECTION_HOST_PREFIX=workshop

filterwarnings =
ignore:.*U.*mode is deprecated:DeprecationWarning
30 changes: 25 additions & 5 deletions src/app/api/api_v1/endpoints/chat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Dict, Optional, cast
from uuid import UUID

import backoff
import psycopg
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import StreamingResponse
from langchain_core.messages import ToolMessage
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
Expand All @@ -14,6 +15,7 @@
from src.app.models import chat as models
from src.app.services.abst_chat import get_chat_service
from src.app.services.constants import subjects as subjectsDict
from src.app.services.data_collection import get_data_collection_service
from src.app.services.exceptions import (
EmptyQueryError,
InvalidQuestionError,
Expand Down Expand Up @@ -225,7 +227,7 @@ async def q_and_a_rephrase_stream(
"/chat/answer",
summary="Chat Answer",
description="This endpoint is used to get the answer to the user's query based on the provided context and history",
response_model=str,
response_model=dict[str, str | UUID | None],
)
@backoff.on_exception(
wait_gen=backoff.expo,
Expand All @@ -237,8 +239,11 @@ async def q_and_a_rephrase_stream(
factor=2,
)
async def q_and_a_ans(
body: models.ContextOut = Depends(get_params), chatfactory=Depends(get_chat_service)
) -> Optional[str]:
request: Request,
body: models.ContextOut = Depends(get_params),
chatfactory=Depends(get_chat_service),
data_collection=Depends(get_data_collection_service),
) -> dict[str, str | UUID | None] | None:
"""_summary_

Args:
Expand All @@ -250,14 +255,29 @@ async def q_and_a_ans(
str: openai chat completion content
"""

session_id = request.headers.get("X-Session-ID")

try:
content = await chatfactory.chat_message(
query=body.query,
history=body.history,
docs=body.sources,
subject=subjectsDict.get(body.subject, None),
)
return cast(str, content)

conversation_id, message_id = await data_collection.register_chat_data(
session_id=session_id,
user_query=body.query,
conversation_id=body.conversation_id,
answer_content=content,
sources=body.sources,
)

return {
"message_id": message_id,
"answer": content,
"conversation_id": conversation_id,
}
except LanguageNotSupportedError as e:
bad_request(message=e.message, msg_code=e.msg_code)

Expand Down
15 changes: 13 additions & 2 deletions src/app/api/api_v1/endpoints/metric.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from fastapi import APIRouter, Response, status
from fastapi import APIRouter, Depends, Request, Response, status
from pydantic import ValidationError
from starlette.concurrency import run_in_threadpool

from src.app.api.dependencies import get_settings
from src.app.models.metric import RowCorpusQtyDocInfo
from src.app.models.metric import DocumentClickUpdateResponse, RowCorpusQtyDocInfo
from src.app.services.data_collection import get_data_collection_service
from src.app.services.sql_db.queries import get_document_qty_table_info_sync
from src.app.utils.logger import logger as utils_logger

Expand Down Expand Up @@ -38,3 +39,13 @@ async def get_nb_docs_info_per_corpus(
if len(ret) == 0:
response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
return ret


@router.post("/clicked_document")
async def update_clicked_doc_from_chat_message(
    body: DocumentClickUpdateResponse,
    request: Request,
    data_collection=Depends(get_data_collection_service),
) -> str:
    """Record that a document returned in a chat answer was clicked.

    Delegates to the data-collection service, which is a no-op when
    collection is disabled for this host/campaign.

    NOTE(review): ``request`` is injected but never read here — confirm
    whether it can be dropped from the signature.
    """
    doc_id, message_id = body.doc_id, body.message_id
    await data_collection.register_document_click(doc_id, message_id)
    return "updated"
2 changes: 1 addition & 1 deletion src/app/api/api_v1/endpoints/micro_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
convert_embedding_bytes,
)
from src.app.services.search import SearchService, get_search_service
from src.app.services.sql_service import (
from src.app.services.sql_db.queries import (
get_context_documents,
get_subject,
get_subjects,
Expand Down
9 changes: 4 additions & 5 deletions src/app/api/api_v1/endpoints/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Response
from fastapi.concurrency import run_in_threadpool
from qdrant_client.models import ScoredPoint

from src.app.models.documents import Document
from src.app.models.search import (
Expand Down Expand Up @@ -84,7 +83,7 @@ async def get_nb_docs() -> dict[str, int]:
"/collections/{collection}",
summary="search documents in a specific collection",
description="Search documents in a specific collection",
response_model=list[ScoredPoint] | str | None,
response_model=list[Document] | str | None,
)
async def search_doc_by_collection(
background_tasks: BackgroundTasks,
Expand Down Expand Up @@ -127,7 +126,7 @@ async def search_doc_by_collection(
"/by_slices",
summary="search all slices",
description="Search slices in all collections or in collections specified",
response_model=list[ScoredPoint] | None | str,
response_model=list[Document] | None | str,
)
async def search_all_slices_by_lang(
background_tasks: BackgroundTasks,
Expand Down Expand Up @@ -158,7 +157,7 @@ async def search_all_slices_by_lang(
"/multiple_by_slices",
summary="search all slices",
description="Search slices in all collections or in collections specified",
response_model=list[ScoredPoint] | None,
response_model=list[Document] | None,
)
async def multi_search_all_slices_by_lang(
background_tasks: BackgroundTasks,
Expand Down Expand Up @@ -186,7 +185,7 @@ async def multi_search_all_slices_by_lang(
"/by_document",
summary="search all documents",
description="Search by documents, returns only one result by document id",
response_model=list[ScoredPoint] | None | str,
response_model=list[Document] | None | str,
)
async def search_all(
background_tasks: BackgroundTasks,
Expand Down
1 change: 1 addition & 0 deletions src/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Settings(BaseSettings):
"""

BACKEND_CORS_ORIGINS_REGEX: str = CLIENT_ORIGINS_REGEX
DATA_COLLECTION_HOST_PREFIX: str

def get_api_version(self, cls):
return {
Expand Down
2 changes: 1 addition & 1 deletion src/app/middleware/monitor_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fastapi.concurrency import run_in_threadpool
from starlette.middleware.base import BaseHTTPMiddleware

from src.app.services.sql_service import register_endpoint
from src.app.services.sql_db.queries import register_endpoint
from src.app.utils.logger import logger as logger_utils

logger = logger_utils(__name__)
Expand Down
6 changes: 6 additions & 0 deletions src/app/models/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ContextOut(BaseModel):
history: list[dict] = []
query: str
subject: str | None = Field(None)
conversation_id: uuid.UUID | None = Field(None)


class Role(Enum):
Expand Down Expand Up @@ -71,6 +72,11 @@ class AgentResponse(BaseModel):
docs: list[ScoredPoint] | None = None


class UserQueryMetadata(BaseModel):
    """Pair of identifiers for a persisted user chat query.

    NOTE(review): not referenced elsewhere in this diff — presumably carries
    the (conversation_id, message_id) pair produced when a query is stored by
    the data-collection service; confirm against callers.
    """

    # Conversation the query belongs to.
    conversation_id: uuid.UUID
    # Identifier of the stored message row.
    message_id: uuid.UUID


PROMPTS = Literal["STANDALONE", "NEW_QUESTIONS", "REPHRASE"]

RESPONSE_TYPE = Literal["json_object", "text"]
7 changes: 7 additions & 0 deletions src/app/models/metric.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from uuid import UUID

from pydantic import BaseModel


Expand All @@ -6,3 +8,8 @@ class RowCorpusQtyDocInfo(BaseModel):
url: str
qty_total: int
qty_in_qdrant: int


class DocumentClickUpdateResponse(BaseModel):
    """Payload identifying a clicked document within a chat answer.

    NOTE(review): despite the "Response" suffix, this model is used as the
    *request body* of POST /clicked_document — consider renaming (would
    require coordinating with API consumers, so left as-is here).
    """

    # Message (chat answer) the clicked document was returned in.
    message_id: UUID
    # Document that was clicked.
    doc_id: UUID
120 changes: 120 additions & 0 deletions src/app/services/data_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import uuid
from datetime import datetime, timedelta
from typing import Any

from fastapi import HTTPException, Request, status
from fastapi.concurrency import run_in_threadpool

from src.app.api.dependencies import get_settings
from src.app.models.documents import Document
from src.app.services.sql_db.queries import (
get_current_data_collection_campaign,
update_returned_document_click,
write_chat_answer,
write_user_query,
)
from src.app.services.sql_db.queries_user import get_user_from_session_id
from src.app.utils.logger import logger as utils_logger

logger = utils_logger(__name__)

# Module-level cache for the campaign-active flag. "expires" holds the
# datetime after which the cached value must be refreshed from the DB.
_cache: dict[str, Any] = {"is_campaign_active": None, "expires": None}

# Application settings (supplies DATA_COLLECTION_HOST_PREFIX).
settings = get_settings()


class DataCollection:
    """Records chat interactions and document clicks during a campaign.

    Collection happens only when the request host starts with the configured
    ``DATA_COLLECTION_HOST_PREFIX`` AND a data-collection campaign is
    currently active.
    """

    def __init__(self, host: str):
        """Decide whether data should be collected for the given host.

        Args:
            host: hostname of the incoming request ("" when unknown).
        """
        is_campaign_active = self.get_campaign_state()
        host_settings = settings.DATA_COLLECTION_HOST_PREFIX
        # Normalize to a plain bool so downstream checks are unambiguous.
        self.should_collect = bool(
            host.startswith(host_settings) and is_campaign_active
        )
        logger.info(
            "data_collection: host_settings=%s, is_campaign=%s, should_collect=%s",
            host_settings,
            is_campaign_active,
            self.should_collect,
        )

    def get_campaign_state(self) -> bool:
        """Returns True if a campaign is active, False otherwise.

        The database answer is cached in the module-level ``_cache`` for six
        hours to avoid a DB round-trip on every request.
        """
        now = datetime.now()
        if _cache["expires"] and now < _cache["expires"]:
            # BUG FIX: the previous code returned
            # `_cache["is_campaign_active"] is not None`, which was True
            # whenever a campaign row existed — even if the cached state was
            # False (campaign inactive). Return the cached state itself.
            return bool(_cache["is_campaign_active"])

        campaign = get_current_data_collection_campaign()

        # Store a real bool: `campaign and campaign.is_active` would cache
        # None when no campaign exists, making the cached type inconsistent.
        _cache["is_campaign_active"] = bool(campaign and campaign.is_active)
        _cache["expires"] = now + timedelta(hours=6)

        return _cache["is_campaign_active"]

    async def register_chat_data(
        self,
        session_id: str | None,
        user_query: str,
        conversation_id: uuid.UUID | None,
        answer_content: str,
        sources: list[Document],
    ) -> tuple[uuid.UUID | None, uuid.UUID | None]:
        """Persist a user query and the chat answer it produced.

        Args:
            session_id: value of the X-Session-ID header, if any.
            user_query: raw user query text.
            conversation_id: existing conversation to append to, or None to
                start a new one.
            answer_content: generated answer text.
            sources: documents cited by the answer.

        Returns:
            (conversation_id, message_id), or (None, None) when collection
            is disabled.

        Raises:
            HTTPException: 401 when the session id is missing/malformed or
                does not resolve to a known user.
        """
        if not self.should_collect:
            logger.info("data_collection is not enabled.")
            return None, None

        logger.info("data_collection is enabled. Registering chat data.")

        if not session_id:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail={
                    "message": "Session ID not found",
                    "code": "SESSION_ID_NOT_FOUND",
                },
            )

        try:
            session_uuid = uuid.UUID(session_id)
        except ValueError:
            # A malformed session id previously bubbled up as an unhandled
            # ValueError (HTTP 500); treat it like a missing session instead.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail={
                    "message": "Session ID not found",
                    "code": "SESSION_ID_NOT_FOUND",
                },
            )

        user_id = await run_in_threadpool(get_user_from_session_id, session_uuid)

        if not user_id:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail={
                    "message": "User not found",
                    "code": "USER_NOT_FOUND",
                },
            )

        # Sync DB writes are pushed to the threadpool to keep the event loop
        # free; the answer write must reuse the conversation id just created.
        conversation_id = await run_in_threadpool(
            write_user_query, user_id, user_query, conversation_id
        )

        message_id = await run_in_threadpool(
            write_chat_answer, user_id, answer_content, sources, conversation_id
        )

        return conversation_id, message_id

    async def register_document_click(
        self,
        doc_id: uuid.UUID,
        message_id: uuid.UUID,
    ) -> None:
        """Record a click on a document returned in a chat answer.

        No-op when collection is disabled.
        """
        if not self.should_collect:
            logger.info("data_collection is not enabled.")
            return

        logger.info("data_collection is enabled. Registering document click.")

        await run_in_threadpool(update_returned_document_click, doc_id, message_id)


def get_data_collection_service(request: Request) -> DataCollection:
    """FastAPI dependency that builds a DataCollection for the request host.

    Falls back to an empty host string when the request URL carries no
    hostname, so the service simply disables collection.
    """
    hostname = request.url.hostname
    return DataCollection(host=hostname if hostname is not None else "")
2 changes: 1 addition & 1 deletion src/app/services/data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy.exc import IntegrityError
from welearn_database.data.enumeration import Step

from src.app.services.sql_service import (
from src.app.services.sql_db.queries import (
write_new_data_quality_error,
write_process_state,
)
Expand Down
2 changes: 1 addition & 1 deletion src/app/services/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from src.app.models.collections import Collection
from src.app.models.documents import JourneySectionType
from src.app.services.exceptions import LanguageNotSupportedError
from src.app.services.sql_service import get_embeddings_model_id_according_name
from src.app.services.sql_db.queries import get_embeddings_model_id_according_name
from src.app.utils.decorators import log_time_and_error_sync
from src.app.utils.logger import logger as utils_logger

Expand Down
Empty file removed src/app/services/monitoring.py
Empty file.
2 changes: 1 addition & 1 deletion src/app/services/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from src.app.services.data_quality import DataQualityChecker
from src.app.services.exceptions import CollectionNotFoundError, ModelNotFoundError
from src.app.services.helpers import convert_embedding_bytes
from src.app.services.sql_service import get_subject
from src.app.services.sql_db.queries import get_subject
from src.app.utils.decorators import log_time_and_error, log_time_and_error_sync
from src.app.utils.logger import logger as logger_utils

Expand Down
2 changes: 1 addition & 1 deletion src/app/services/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlalchemy.sql import select
from welearn_database.data.models import APIKeyManagement

from src.app.services.sql_service import session_maker
from src.app.services.sql_db.queries import session_maker
from src.app.utils.logger import logger as logger_utils

api_key_header = APIKeyHeader(name="X-API-Key")
Expand Down
Loading