Commit 042dd83

Merge pull request #37 from CogStack/embeddings
Add embedding creation for MedCAT and HF NER models
2 parents 5c83f01 + a869a83 commit 042dd83

15 files changed

Lines changed: 1027 additions & 137 deletions

app/api/routers/invocation.py

Lines changed: 92 additions & 2 deletions
@@ -7,12 +7,14 @@
 import hashlib
 import logging
 import pandas as pd
+from fastapi.encoders import jsonable_encoder
+
 import app.api.globals as cms_globals
 
 from typing import Dict, List, Union, Iterator, Any
 from collections import defaultdict
 from io import BytesIO
-from starlette.status import HTTP_400_BAD_REQUEST
+from starlette.status import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
 from typing_extensions import Annotated
 from fastapi import APIRouter, Depends, Body, UploadFile, File, Request, Query, Response
 from fastapi.responses import StreamingResponse, PlainTextResponse, JSONResponse
@@ -22,7 +24,7 @@
     TextWithAnnotations,
     TextWithPublicKey,
     TextStreamItem,
-    Tags,
+    Tags, OpenAIEmbeddingsRequest, OpenAIEmbeddingsResponse,
 )
 from app.model_services.base import AbstractModelService
 from app.utils import get_settings, load_pydantic_object_from_dict
@@ -43,6 +45,7 @@
 PATH_PROCESS_BULK_FILE = "/process_bulk_file"
 PATH_REDACT = "/redact"
 PATH_REDACT_WITH_ENCRYPTION = "/redact_with_encryption"
+PATH_OPENAI_EMBEDDINGS = "/v1/embeddings"
 
 router = APIRouter()
 config = get_settings()
@@ -355,6 +358,93 @@ def get_redacted_text_with_encryption(
     return JSONResponse(content=content)
 
 
+@router.post(
+    PATH_OPENAI_EMBEDDINGS,
+    tags=[Tags.OpenAICompatible.name],
+    response_model=None,
+    dependencies=[Depends(cms_globals.props.current_active_user)],
+    description="Create embeddings based on text(s), similar to OpenAI's /v1/embeddings endpoint",
+)
+def embed_texts(
+    request: Request,
+    request_data: Annotated[OpenAIEmbeddingsRequest, Body(
+        description="Text(s) to be embedded", media_type="application/json"
+    )],
+    tracking_id: Union[str, None] = Depends(validate_tracking_id),
+    model_service: AbstractModelService = Depends(cms_globals.model_service_dep)
+) -> JSONResponse:
+    """
+    Embeds text or a list of texts, mimicking OpenAI's /v1/embeddings endpoint.
+
+    Args:
+        request (Request): The request object.
+        request_data (OpenAIEmbeddingsRequest): The request data containing model and input text(s).
+        tracking_id (Union[str, None]): An optional tracking ID of the requested task.
+        model_service (AbstractModelService): The model service dependency.
+
+    Returns:
+        JSONResponse: A response containing the embeddings of the text(s).
+    """
+    tracking_id = tracking_id or str(uuid.uuid4())
+
+    if not hasattr(model_service, "create_embeddings"):
+        error_response = {
+            "error": {
+                "message": "Model does not support embeddings",
+                "type": "invalid_request_error",
+                "param": "model",
+                "code": "model_not_supported",
+            }
+        }
+        return JSONResponse(
+            content=error_response,
+            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
+            headers={"x-cms-tracking-id": tracking_id},
+        )
+
+    input_text = request_data.input
+    model = model_service.model_name if request_data.model != model_service.model_name else request_data.model
+
+    if isinstance(input_text, str):
+        input_texts = [input_text]
+    else:
+        input_texts = input_text
+
+    try:
+        embeddings_data = []
+
+        for i, embedding in enumerate(model_service.create_embeddings(input_texts)):
+            embeddings_data.append({
+                "object": "embedding",
+                "embedding": embedding,
+                "index": i,
+            })
+
+        response = OpenAIEmbeddingsResponse(object="list", data=embeddings_data, model=model)
+
+        return JSONResponse(
+            content=jsonable_encoder(response),
+            headers={"x-cms-tracking-id": tracking_id},
+        )
+
+    except Exception as e:
+        logger.error("Failed to create embeddings")
+        logger.exception(e)
+        error_response = {
+            "error": {
+                "message": f"Failed to create embeddings: {str(e)}",
+                "type": "server_error",
+                "code": "internal_error",
+            }
+        }
+        return JSONResponse(
+            content=error_response,
+            status_code=HTTP_500_INTERNAL_SERVER_ERROR,
+            headers={"x-cms-tracking-id": tracking_id},
+        )
+
+
 def _send_annotation_num_metric(annotation_num: int, handler: str) -> None:
     cms_doc_annotations.labels(handler=handler).observe(annotation_num)
 
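The new endpoint mirrors OpenAI's /v1/embeddings contract, so it can be exercised with a plain HTTP client. A minimal sketch, not part of the commit, assuming a local deployment on port 8000 with authentication disabled; the model name is a placeholder, and the handler substitutes the serving model's own name in the response:

import requests

payload = {
    "model": "placeholder-model-name",  # assumption: the service reports its own model_name regardless
    "input": ["Patient denies chest pain.", "No history of diabetes."],
}
resp = requests.post("http://localhost:8000/v1/embeddings", json=payload, timeout=30)
resp.raise_for_status()

body = resp.json()
print(body["model"])                        # name of the serving model
print(body["data"][0]["index"])             # 0
print(len(body["data"][0]["embedding"]))    # embedding dimensionality
print(resp.headers["x-cms-tracking-id"])    # tracking ID echoed back by the endpoint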
app/api/routers/stream.py

Lines changed: 39 additions & 2 deletions
@@ -11,7 +11,7 @@
 from starlette.types import Receive, Scope, Send
 from starlette.background import BackgroundTask
 from fastapi import APIRouter, Depends, Request, Response, WebSocket, WebSocketException
-from pydantic import ValidationError
+from pydantic import ValidationError, BaseModel
 from app.domain import Tags, TextStreamItem
 from app.model_services.base import AbstractModelService
 from app.utils import get_settings
@@ -20,7 +20,6 @@
 
 PATH_STREAM_PROCESS = "/process"
 PATH_WS = "/ws"
-PATH_GENERATE= "/generate"
 
 router = APIRouter()
 config = get_settings()
@@ -57,6 +56,22 @@ async def get_entities_stream_from_jsonlines_stream(
     return _LocalStreamingResponse(annotation_stream, media_type="application/x-ndjson; charset=utf-8")
 
 
+@router.get(
+    PATH_WS,
+    tags=[Tags.Annotations.name],
+    dependencies=[Depends(cms_globals.props.current_active_user)],
+    description="WebSocket info endpoint for real-time NER entity extraction. Use ws://host:port/stream/ws to establish an actual WebSocket connection.",
+    include_in_schema=True,
+)
+async def get_inline_annotations_from_websocket_info() -> "_WebSocketInfo":
+    """
+    Information about the WebSocket endpoint for real-time NER entity extraction.
+
+    This endpoint provides documentation for the WebSocket connection available at the same path.
+    Connect to ws://host:port/stream/ws and send texts to retrieve annotated results.
+    """
+    return _WebSocketInfo()
+
 @router.websocket(PATH_WS)
 # @limiter.limit(config.PROCESS_BULK_RATE_LIMIT)  # Not supported yet
 async def get_inline_annotations_from_websocket(
@@ -189,6 +204,28 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
             await self.background()
 
 
+class _WebSocketInfo(BaseModel):
+    message: str = "WebSocket endpoint for real-time NER entity extraction"
+    example: str = """<form action="" onsubmit="send_doc(event)">
+        <input type="text" id="cms-input" autocomplete="off"/>
+        <button>Send</button>
+    </form>
+    <ul id="cms-output"></ul>
+    <script>
+        var ws = new WebSocket("ws://localhost:8000/stream/ws");
+        ws.onmessage = function(event) {
+            document.getElementById("cms-output").appendChild(
+                Object.assign(document.createElement('li'), { textContent: event.data })
+            );
+        };
+        function send_doc(event) {
+            ws.send(document.getElementById("cms-input").value);
+            event.preventDefault();
+        };
+    </script>"""
+    protocol: str = "WebSocket"
+
+
 async def _annotation_async_gen(request: Request, model_service: AbstractModelService) -> AsyncGenerator:
     try:
         buffer = ""
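Beyond the HTML snippet embedded in _WebSocketInfo, the same WebSocket at ws://host:port/stream/ws can be driven from Python. A minimal sketch, assuming a local deployment on port 8000 with authentication disabled and the third-party websockets package installed:

import asyncio
import websockets

async def annotate(text: str) -> str:
    # Open the streaming NER socket, send one document, and read back the annotated result.
    async with websockets.connect("ws://localhost:8000/stream/ws") as ws:
        await ws.send(text)
        return await ws.recv()

print(asyncio.run(annotate("Patient was prescribed 5mg of atorvastatin.")))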

app/model_services/huggingface_llm_model.py

Lines changed: 89 additions & 22 deletions
@@ -15,7 +15,7 @@
 from app import __version__ as app_version
 from app.exception import ConfigurationException
 from app.model_services.base import AbstractModelService
-from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer
+from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer, HuggingFaceLlmUnsupervisedTrainer
 from app.domain import ModelCard, ModelType, Annotation, Device
 from app.config import Settings
 from app.utils import (
@@ -62,6 +62,7 @@ def __init__(
         self._multi_label_threshold = 0.5
         self._text_generator = ThreadPoolExecutor(max_workers=50)
         self.model_name = model_name or "HuggingFace LLM model"
+        self.is_4bit_quantised = False
 
     @property
     def model(self) -> PreTrainedModel:
@@ -206,6 +207,8 @@ def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> N
             self._model.to(get_settings().DEVICE)
         if self._enable_trainer:
             self._supervised_trainer = HuggingFaceLlmSupervisedTrainer(self)
+            self._unsupervised_trainer = HuggingFaceLlmUnsupervisedTrainer(self)
+        self.is_4bit_quantised = load_in_4bit
 
     def info(self) -> ModelCard:
         """
@@ -396,29 +399,47 @@ def create_embeddings(
 
         self.model.eval()
 
-        inputs = self.tokenizer(
-            text,
-            add_special_tokens=False,
-            return_tensors="pt",
-            padding=True,
-            truncation=True,
-        )
-
-        inputs.to(self.model.device)
-
-        with torch.no_grad():
-            outputs = self.model(**inputs, output_hidden_states=True)
+        texts = [text] if isinstance(text, str) else text
+        all_embeddings = []
+
+        for txt in texts:
+            inputs = self.tokenizer(txt, add_special_tokens=False, truncation=False, padding=False)
+            input_ids = inputs["input_ids"]
+            attention_mask = inputs["attention_mask"]
+            window_size = max(self.model.config.max_position_embeddings - 2, 1)
+            stride = window_size
+            chunk_embeddings = []
+
+            for start in range(0, len(input_ids), stride):
+                end = min(start + window_size, len(input_ids))
+                chunk_inputs = {
+                    "input_ids": torch.tensor(
+                        [input_ids[start:end]], dtype=torch.long
+                    ).to(self.model.device),
+                    "attention_mask": torch.tensor(
+                        [attention_mask[start:end]], dtype=torch.long
+                    ).to(self.model.device),
+                }
+
+                with torch.no_grad():
+                    outputs = self.model(**chunk_inputs, output_hidden_states=True)
+
+                last_hidden_state = outputs.hidden_states[-1]
+                chunk_attention_mask = chunk_inputs["attention_mask"]
+                masked_hidden_states = last_hidden_state * chunk_attention_mask.unsqueeze(-1)
+                sum_hidden_states = masked_hidden_states.sum(dim=1)
+                num_tokens = chunk_attention_mask.sum(dim=1, keepdim=True)
+                chunk_embedding = sum_hidden_states / num_tokens
+                chunk_embeddings.append(chunk_embedding)
+
+                if end >= len(input_ids):
+                    break
 
-        last_hidden_state = outputs.hidden_states[-1]
-        attention_mask = inputs["attention_mask"]
-        masked_hidden_states = last_hidden_state * attention_mask.unsqueeze(-1)
-        sum_hidden_states = masked_hidden_states.sum(dim=1)
-        num_tokens = attention_mask.sum(dim=1, keepdim=True)
-        embeddings = sum_hidden_states / num_tokens
-        l2_normalised = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+            final_embedding = torch.mean(torch.cat(chunk_embeddings, dim=0), dim=0, keepdim=True)
+            l2_normalised = torch.nn.functional.normalize(final_embedding, p=2, dim=1)
+            all_embeddings.append(l2_normalised.cpu().numpy().tolist()[0])
 
-        results = l2_normalised.cpu().numpy().tolist()
-        return results[0] if isinstance(text, str) else results
+        return all_embeddings[0] if isinstance(text, str) else all_embeddings
 
     def train_supervised(
         self,
@@ -465,3 +486,49 @@ def train_supervised(
             synchronised,
             **hyperparams,
         )
+
+    def train_unsupervised(
+        self,
+        data_file: TextIO,
+        epochs: int,
+        log_frequency: int,
+        training_id: str,
+        input_file_name: str,
+        raw_data_files: Optional[List[TextIO]] = None,
+        description: Optional[str] = None,
+        synchronised: bool = False,
+        **hyperparams: Dict[str, Any],
+    ) -> Tuple[bool, str, str]:
+        """
+        Initiates unsupervised training on the model.
+
+        Args:
+            data_file (TextIO): The file containing a JSON list of texts.
+            epochs (int): The number of training epochs.
+            log_frequency (int): The number of epochs after which training metrics will be logged.
+            training_id (str): A unique identifier for the training process.
+            input_file_name (str): The name of the input file to be logged.
+            raw_data_files (Optional[List[TextIO]]): Additional raw data files to be logged. Defaults to None.
+            description (Optional[str]): The description of the training or change logs. Defaults to empty.
+            synchronised (bool): Whether to wait for the training to complete.
+            **hyperparams (Dict[str, Any]): Additional hyperparameters for training.
+
+        Returns:
+            Tuple[bool, str, str]: A tuple with the first element indicating success or failure.
+
+        Raises:
+            ConfigurationException: If the unsupervised trainer is not enabled.
+        """
+        if self._unsupervised_trainer is None:
+            raise ConfigurationException("The unsupervised trainer is not enabled")
+        return self._unsupervised_trainer.train(
+            data_file,
+            epochs,
+            log_frequency,
+            training_id,
+            input_file_name,
+            raw_data_files,
+            description,
+            synchronised,
+            **hyperparams,
+        )
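The reworked create_embeddings above replaces single-pass truncation with a sliding-window scheme: each text is tokenised without truncation, split into non-overlapping windows of at most max_position_embeddings - 2 tokens, each window's last hidden state is mean-pooled over its attention mask, the per-window means are averaged, and the result is L2-normalised. The pooling arithmetic in isolation, as a standalone sketch with illustrative names and toy tensors rather than the service's own API:

import torch

def mean_pool_windows(hidden_states, masks):
    """hidden_states[i]: (1, seq_len_i, dim); masks[i]: (1, seq_len_i) of 0/1 values."""
    chunk_means = []
    for h, m in zip(hidden_states, masks):
        summed = (h * m.unsqueeze(-1)).sum(dim=1)   # zero out padding, then sum over tokens
        counts = m.sum(dim=1, keepdim=True)         # number of real tokens in the window
        chunk_means.append(summed / counts)         # (1, dim) mean vector per window
    doc_vector = torch.mean(torch.cat(chunk_means, dim=0), dim=0, keepdim=True)  # average of window means
    return torch.nn.functional.normalize(doc_vector, p=2, dim=1)                 # unit-length embedding

# Toy example: two windows of a fake 4-dimensional hidden state
h1, m1 = torch.randn(1, 5, 4), torch.ones(1, 5)
h2, m2 = torch.randn(1, 3, 4), torch.ones(1, 3)
print(mean_pool_windows([h1, h2], [m1, m2]).shape)  # torch.Size([1, 4])

One consequence of averaging window means with equal weight is that tokens in a shorter final window contribute slightly more to the document vector than tokens in full windows.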
