Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/app/api/api_v1/endpoints/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
ModelNotFoundError,
bad_request,
)
from src.app.services.search import SearchService, get_search_service
from src.app.services.search import (
SearchService,
get_search_service,
MIX_NOT_ALLOWED_CORPUS,
)
from src.app.services.search_helpers import search_multi_inputs
from src.app.services.sql_db.queries import (
get_collections_sync,
Expand Down Expand Up @@ -66,6 +70,7 @@ async def get_corpus():
"lang": lang,
"model": model,
"corpus": f"{name}_{lang}_{model}",
"is_allowed": name not in MIX_NOT_ALLOWED_CORPUS,
}
for name, lang, model in collections
]
Expand Down
21 changes: 21 additions & 0 deletions src/app/services/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from qdrant_client.http import models as http_models
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from qdrant_client.models import FieldCondition, MatchAny

from src.app.models.collections import Collection
from src.app.models.search import (
Expand All @@ -35,6 +36,9 @@ async def get_qdrant(request: Request) -> AsyncQdrantClient:
return request.app.state.qdrant


# Corpus names that are filtered out of "by_slices" searches; the corpus
# listing endpoint reports any corpus named here with "is_allowed": False.
MIX_NOT_ALLOWED_CORPUS = ["conversation", "ipcc"]


class SearchService:
import threading

Expand Down Expand Up @@ -66,6 +70,12 @@ def __init__(self, client):
]
self.col_prefix = "collection_welearn_"

@staticmethod
def remove_unwanted_corpora(corpora: tuple[str, ...]) -> tuple[str, ...]:
    """Return ``corpora`` without the names listed in ``MIX_NOT_ALLOWED_CORPUS``.

    The relative order of the remaining corpus names is preserved, and the
    result is always a new tuple (empty when every name is excluded).
    """
    kept = []
    for name in corpora:
        if name in MIX_NOT_ALLOWED_CORPUS:
            # Excluded corpus: leave it out of the result.
            continue
        kept.append(name)
    return tuple(kept)

@staticmethod
def flavored_with_subject(
sdg_emb: ndarray, subject_emb: ndarray, discipline_factor: int | float = 2
Expand Down Expand Up @@ -238,6 +248,9 @@ async def search_handler(
subject_influence_factor=qp.influence_factor,
)

if method == SearchMethods.BY_SLICES and qp.corpora and len(qp.corpora):
qp.corpora = self.remove_unwanted_corpora(qp.corpora)

filter_content = [
FilterDefinition(key="document_corpus", value=qp.corpora),
FilterDefinition(key="document_details.readability", value=qp.readability),
Expand All @@ -253,6 +266,14 @@ async def search_handler(

data = []
if method == "by_slices":
if not filters:
filters = qdrant_models.Filter()
filters.must_not = [
FieldCondition(
key="document_corpus",
match=MatchAny(any=MIX_NOT_ALLOWED_CORPUS),
)
]
data = await self.search(
collection_info=collection.name,
embedding=embedding,
Expand Down