|
15 | 15 | from app import __version__ as app_version |
16 | 16 | from app.exception import ConfigurationException |
17 | 17 | from app.model_services.base import AbstractModelService |
18 | | -from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer |
| 18 | +from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer, HuggingFaceLlmUnsupervisedTrainer |
19 | 19 | from app.domain import ModelCard, ModelType, Annotation, Device |
20 | 20 | from app.config import Settings |
21 | 21 | from app.utils import ( |
@@ -62,6 +62,7 @@ def __init__( |
62 | 62 | self._multi_label_threshold = 0.5 |
63 | 63 | self._text_generator = ThreadPoolExecutor(max_workers=50) |
64 | 64 | self.model_name = model_name or "HuggingFace LLM model" |
| 65 | + self.is_4bit_quantised = False |
65 | 66 |
|
66 | 67 | @property |
67 | 68 | def model(self) -> PreTrainedModel: |
@@ -206,6 +207,8 @@ def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> N |
206 | 207 | self._model.to(get_settings().DEVICE) |
207 | 208 | if self._enable_trainer: |
208 | 209 | self._supervised_trainer = HuggingFaceLlmSupervisedTrainer(self) |
| 210 | + self._unsupervised_trainer = HuggingFaceLlmUnsupervisedTrainer(self) |
| 211 | + self.is_4bit_quantised = load_in_4bit |
209 | 212 |
|
210 | 213 | def info(self) -> ModelCard: |
211 | 214 | """ |
@@ -396,29 +399,47 @@ def create_embeddings( |
396 | 399 |
|
397 | 400 | self.model.eval() |
398 | 401 |
|
399 | | - inputs = self.tokenizer( |
400 | | - text, |
401 | | - add_special_tokens=False, |
402 | | - return_tensors="pt", |
403 | | - padding=True, |
404 | | - truncation=True, |
405 | | - ) |
406 | | - |
407 | | - inputs.to(self.model.device) |
408 | | - |
409 | | - with torch.no_grad(): |
410 | | - outputs = self.model(**inputs, output_hidden_states=True) |
| 402 | + texts = [text] if isinstance(text, str) else text |
| 403 | + all_embeddings = [] |
| 404 | + |
| 405 | + for txt in texts: |
| 406 | + inputs = self.tokenizer(txt, add_special_tokens=False, truncation=False, padding=False) |
| 407 | + input_ids = inputs["input_ids"] |
| 408 | + attention_mask = inputs["attention_mask"] |
| 409 | + window_size = max(self.model.config.max_position_embeddings - 2, 1) |
| 410 | + stride = window_size |
| 411 | + chunk_embeddings = [] |
| 412 | + |
| 413 | + for start in range(0, len(input_ids), stride): |
| 414 | + end = min(start + window_size, len(input_ids)) |
| 415 | + chunk_inputs = { |
| 416 | + "input_ids": torch.tensor( |
| 417 | + [input_ids[start:end]], dtype=torch.long |
| 418 | + ).to(self.model.device), |
| 419 | + "attention_mask": torch.tensor( |
| 420 | + [attention_mask[start:end]], dtype=torch.long |
| 421 | + ).to(self.model.device), |
| 422 | + } |
| 423 | + |
| 424 | + with torch.no_grad(): |
| 425 | + outputs = self.model(**chunk_inputs, output_hidden_states=True) |
| 426 | + |
| 427 | + last_hidden_state = outputs.hidden_states[-1] |
| 428 | + chunk_attention_mask = chunk_inputs["attention_mask"] |
| 429 | + masked_hidden_states = last_hidden_state * chunk_attention_mask.unsqueeze(-1) |
| 430 | + sum_hidden_states = masked_hidden_states.sum(dim=1) |
| 431 | + num_tokens = chunk_attention_mask.sum(dim=1, keepdim=True) |
| 432 | + chunk_embedding = sum_hidden_states / num_tokens |
| 433 | + chunk_embeddings.append(chunk_embedding) |
| 434 | + |
| 435 | + if end >= len(input_ids): |
| 436 | + break |
411 | 437 |
|
412 | | - last_hidden_state = outputs.hidden_states[-1] |
413 | | - attention_mask = inputs["attention_mask"] |
414 | | - masked_hidden_states = last_hidden_state * attention_mask.unsqueeze(-1) |
415 | | - sum_hidden_states = masked_hidden_states.sum(dim=1) |
416 | | - num_tokens = attention_mask.sum(dim=1, keepdim=True) |
417 | | - embeddings = sum_hidden_states / num_tokens |
418 | | - l2_normalised = torch.nn.functional.normalize(embeddings, p=2, dim=1) |
| 438 | + final_embedding = torch.mean(torch.cat(chunk_embeddings, dim=0), dim=0, keepdim=True) |
| 439 | + l2_normalised = torch.nn.functional.normalize(final_embedding, p=2, dim=1) |
| 440 | + all_embeddings.append(l2_normalised.cpu().numpy().tolist()[0]) |
419 | 441 |
|
420 | | - results = l2_normalised.cpu().numpy().tolist() |
421 | | - return results[0] if isinstance(text, str) else results |
| 442 | + return all_embeddings[0] if isinstance(text, str) else all_embeddings |
422 | 443 |
|
423 | 444 | def train_supervised( |
424 | 445 | self, |
@@ -465,3 +486,49 @@ def train_supervised( |
465 | 486 | synchronised, |
466 | 487 | **hyperparams, |
467 | 488 | ) |
| 489 | + |
def train_unsupervised(
    self,
    data_file: TextIO,
    epochs: int,
    log_frequency: int,
    training_id: str,
    input_file_name: str,
    raw_data_files: Optional[List[TextIO]] = None,
    description: Optional[str] = None,
    synchronised: bool = False,
    **hyperparams: Any,
) -> Tuple[bool, str, str]:
    """
    Initiates unsupervised training on the model.

    Args:
        data_file (TextIO): The file containing a JSON list of texts.
        epochs (int): The number of training epochs.
        log_frequency (int): The number of epochs after which training metrics will be logged.
        training_id (str): A unique identifier for the training process.
        input_file_name (str): The name of the input file to be logged.
        raw_data_files (Optional[List[TextIO]]): Additional raw data files to be logged. Defaults to None.
        description (Optional[str]): The description of the training or change logs. Defaults to None.
        synchronised (bool): Whether to wait for the training to complete. Defaults to False.
        **hyperparams (Any): Additional hyperparameters forwarded to the trainer.

    Returns:
        Tuple[bool, str, str]: A tuple with the first element indicating success or failure.

    Raises:
        ConfigurationException: If the unsupervised trainer is not enabled.
    """
    # getattr guards against the attribute never having been initialised
    # (e.g. init_model() not called, or the trainer disabled), so callers get
    # the documented ConfigurationException instead of an AttributeError.
    trainer = getattr(self, "_unsupervised_trainer", None)
    if trainer is None:
        raise ConfigurationException("The unsupervised trainer is not enabled")
    return trainer.train(
        data_file,
        epochs,
        log_frequency,
        training_id,
        input_file_name,
        raw_data_files,
        description,
        synchronised,
        **hyperparams,
    )
0 commit comments