Add multiclass model#16
Conversation
There was a problem hiding this comment.
same here -> verschieben in #18
| self.add_bg_noise = True | ||
|
|
||
| self.multiclass = "multiclass" in self.model_name | ||
| self.mask_channels = None # number of anomaly classes (background class not encluded) |
There was a problem hiding this comment.
rename: num_anomaly_classes
| self.model_params = get_model_configuration(model_name, anomaly_size[0], debug=False) | ||
|
|
||
| # (optional) fixed roi size - can also be set to None for variable roi size | ||
| self.fixed_roi_size = (64, 64) |
There was a problem hiding this comment.
hatten wir echt keine unterscheidung hier zwischen 3D / 2D? Danke für den fix!
There was a problem hiding this comment.
Ich hatte es so gelöst, dass man Tuple beliebiger Länge übergeben konnte und für 2D nur die ersten 2 Einträge beachtet (und für 3D die ersten 3). War nicht so schön, deswegen jetzt so.
| return dict(config) | ||
| raise TypeError("Model config must be a dataclass instance or mapping.") | ||
|
|
||
| def sync_model_mask_channels(self): |
There was a problem hiding this comment.
warum brauchen wir das? Nutze die Funktion set_model_param. Bitte nichtmehr die model_param selber setzen. es gibt dafür passende set-funktionen in der configuration.py
|
|
||
| if self._config.multiclass: | ||
| if save_folder_org_mask is None: | ||
| save_folder_org_mask = os.path.join(self._config.study_folder, "org_masks") |
There was a problem hiding this comment.
die org_masks und tgt_masks wollen wir jetzt auch bei binären tasks. Dann ist es einheitlicher und schadet nicht sie auch abzuspeichern. siehe comments bei #18
| i = 0 | ||
|
|
||
| while best < self._config.feedback_threshold: | ||
| if self._config.prior_sampling: |
| if self._config.prior_sampling: | ||
| syn_anomaly_sample = self._model.generate_synth_sample_prior(clamp_01=self._config.clamp01_output, out_hw=self._config.anomaly_size[1:]) | ||
| if self._config.multiclass: | ||
| # TODO: hier noch prior einbauen? |
There was a problem hiding this comment.
yes. versuch hier auch sowenig wie möglich vom bestehenden code zu ändern.
| self._anomaly_dataset.numpy_mode = True | ||
|
|
||
| for batch in tqdm(self._anomaly_dataset): | ||
| if self._config.multiclass: |
There was a problem hiding this comment.
versuch hier die unterscheidung zu vermeiden.
Pfleiderer-Adrian
left a comment
There was a problem hiding this comment.
hier brauchen wir noch paar änderungen.
Ich würde bei dem binären fall ebenfalls die masken (ori, transformed) mit abspeichern. So sparen wir uns einige/fast alle multiclass if-abfragen. Man müsste nurnoch bei der generierung von synth. anomalien (generate_synth_sample) bei den jeweiligen modellen anpassen. wobei das auch nicht nötig ist da wir ja ein sample übergeben (kann ja auch img+maske in einem sein).
| dtype=torch.float32, | ||
| ) | ||
|
|
||
| if self._config.multiclass: |
There was a problem hiding this comment.
die if abfrage raus. egal ob binär oder multi wir nehmen immer die org und tgt masks mit.
| self._config.num_anomaly_classes = int(max_class_val) | ||
| self._config.set_model_param("num_anomaly_classes", self._config.num_anomaly_classes) | ||
|
|
||
| self._anomaly_dataset = AnomalyDataset( |
There was a problem hiding this comment.
Hier unterscheiden. Also bei conditional mit tgt & org. Aber nicht mit dem multiclassen-flag sondern mit conditional. Somit können wir evtl dann in zukunft auch auf binäre usecases conditional startegien anwenden.
There was a problem hiding this comment.
nvm wir brauchen nie tgt & ori masken in einem load. Warum nicht ein switch im AnomalDataset einabauen? Dann den switch auf config.is_conditional setzen. Dann gibt es ja nach dem tgt oder ori zurück was du halt gerade brauchst. Somit laden wir bei binär einfach immer die ori. Bei deinem mutliclass beispiel dann: beim training ori und beim generieren tgt.
There was a problem hiding this comment.
wir brauchen hier keine if-abfrage. immer eine maske zurückgeben.
| os.makedirs(save_folder, exist_ok=True) | ||
|
|
||
|
|
||
| # use feedback system to generate similar anomalies |
| if self._config.prior_sampling: | ||
| syn_anomaly_sample = self._model.generate_synth_sample_prior(clamp_01=self._config.clamp01_output, out_hw=self._config.anomaly_size[1:]) | ||
| if self._config.multiclass: | ||
| if getattr(self._config, "prior_sampling", False): |
There was a problem hiding this comment.
hier ebenfalls. Abfragen vereinheitlichen. immer maske und img übergeben
| return x | ||
| return x, org_mask_out, tgt_mask_out, fname | ||
|
|
||
| return x, org_mask_out, tgt_mask_out |
There was a problem hiding this comment.
Lass und das alles vereinfachen bisher war rückgabe immer x oder x, filename; Jetzt brauchen wir masken. Also würde ich vorschlagen neue rückgabe x, mask oder x, mask, filename. Beim Tranining brauchst du nur x, ori_mask und beim generieren brauchst du nur x, tgt_mask. Wir brauchen zu keinem zeitpunkt beides oder? -> somit können wir mit einem einfachen if im AnomalyDataset beide fälle abdecken. Wenn wir uns generell egal ob condition oder nicht auf eine rückgabefrom einigen sparen wir uns in der pipeline viele verschachtelungen die unseren hybrid data generator unnötig komplex macht.
needs identical-offset changes to work properly