Skip to content

Commit d29bd71

Browse files
committed
remove redundant __init__ by adding ClassVar
1 parent 17d690f commit d29bd71

File tree

2 files changed: 41 additions & 142 deletions

2 files changed: 41 additions & 142 deletions

openml/tasks/functions.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -424,6 +424,9 @@ def get_task(
424424
# Including class labels as part of task meta data handles
425425
# the case where data download was initially disabled
426426
if isinstance(task, (OpenMLClassificationTask, OpenMLLearningCurveTask)):
427+
assert (
428+
task.target_name is not None
429+
), "Supervised tasks must define a target feature before retrieving class labels."
427430
task.class_labels = dataset.retrieve_class_labels(task.target_name)
428431
# Clustering tasks do not have class labels
429432
# and do not offer download_split

openml/tasks/task.py

Lines changed: 38 additions & 142 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,11 @@
11
# License: BSD 3-Clause
2-
# TODO(eddbergman): Seems like a lot of the subclasses could just get away with setting
3-
# a `ClassVar` for whatever changes as their `__init__` defaults, less duplicated code.
42
from __future__ import annotations
53

64
import warnings
75
from abc import ABC
86
from enum import Enum
97
from pathlib import Path
10-
from typing import TYPE_CHECKING, Any, Sequence
8+
from typing import TYPE_CHECKING, Any, ClassVar, Sequence
119
from typing_extensions import TypedDict
1210

1311
import openml._api_calls
@@ -70,31 +68,45 @@ class OpenMLTask(OpenMLBase):
7068
Refers to the URL of the data splits used for the OpenML task.
7169
"""
7270

71+
DEFAULT_ESTIMATION_PROCEDURE_ID: ClassVar[int] = 1
72+
7373
def __init__( # noqa: PLR0913
7474
self,
75-
task_id: int | None,
7675
task_type_id: TaskType,
7776
task_type: str,
7877
data_set_id: int,
79-
estimation_procedure_id: int = 1,
78+
task_id: int | None = None,
79+
estimation_procedure_id: int | None = None,
8080
estimation_procedure_type: str | None = None,
8181
estimation_parameters: dict[str, str] | None = None,
8282
evaluation_measure: str | None = None,
8383
data_splits_url: str | None = None,
84+
target_name: str | None = None,
8485
):
8586
self.task_id = int(task_id) if task_id is not None else None
8687
self.task_type_id = task_type_id
8788
self.task_type = task_type
8889
self.dataset_id = int(data_set_id)
90+
self.target_name = target_name
91+
resolved_estimation_procedure_id = self._resolve_estimation_procedure_id(
92+
estimation_procedure_id,
93+
)
8994
self.evaluation_measure = evaluation_measure
9095
self.estimation_procedure: _EstimationProcedure = {
9196
"type": estimation_procedure_type,
9297
"parameters": estimation_parameters,
9398
"data_splits_url": data_splits_url,
9499
}
95-
self.estimation_procedure_id = estimation_procedure_id
100+
self.estimation_procedure_id = resolved_estimation_procedure_id
96101
self.split: OpenMLSplit | None = None
97102

103+
def _resolve_estimation_procedure_id(self, estimation_procedure_id: int | None) -> int:
104+
return (
105+
estimation_procedure_id
106+
if estimation_procedure_id is not None
107+
else self.DEFAULT_ESTIMATION_PROCEDURE_ID
108+
)
109+
98110
@classmethod
99111
def _entity_letter(cls) -> str:
100112
return "t"
@@ -128,7 +140,8 @@ def _get_repr_body_fields(self) -> Sequence[tuple[str, str | int | list[str]]]:
128140
if class_labels is not None:
129141
fields["# of Classes"] = len(class_labels)
130142

131-
if hasattr(self, "cost_matrix"):
143+
cost_matrix = getattr(self, "cost_matrix", None)
144+
if cost_matrix is not None:
132145
fields["Cost Matrix"] = "Available"
133146

134147
# determines the order in which the information will be printed
@@ -249,32 +262,43 @@ class OpenMLSupervisedTask(OpenMLTask, ABC):
249262
Refers to the unique identifier of task.
250263
"""
251264

265+
DEFAULT_ESTIMATION_PROCEDURE_ID: ClassVar[int] = 1
266+
252267
def __init__( # noqa: PLR0913
253268
self,
254269
task_type_id: TaskType,
255270
task_type: str,
256271
data_set_id: int,
257272
target_name: str,
258-
estimation_procedure_id: int = 1,
273+
estimation_procedure_id: int | None = None,
259274
estimation_procedure_type: str | None = None,
260275
estimation_parameters: dict[str, str] | None = None,
261276
evaluation_measure: str | None = None,
262277
data_splits_url: str | None = None,
263278
task_id: int | None = None,
279+
class_labels: list[str] | None = None,
280+
cost_matrix: np.ndarray | None = None,
264281
):
282+
resolved_estimation_procedure_id = self._resolve_estimation_procedure_id(
283+
estimation_procedure_id,
284+
)
265285
super().__init__(
266286
task_id=task_id,
267287
task_type_id=task_type_id,
268288
task_type=task_type,
269289
data_set_id=data_set_id,
270-
estimation_procedure_id=estimation_procedure_id,
290+
estimation_procedure_id=resolved_estimation_procedure_id,
271291
estimation_procedure_type=estimation_procedure_type,
272292
estimation_parameters=estimation_parameters,
273293
evaluation_measure=evaluation_measure,
274294
data_splits_url=data_splits_url,
295+
target_name=target_name,
275296
)
276297

277-
self.target_name = target_name
298+
self.class_labels = class_labels
299+
self.cost_matrix = cost_matrix
300+
if cost_matrix is not None:
301+
raise NotImplementedError("Costmatrix")
278302

279303
def get_X_and_y(self) -> tuple[pd.DataFrame, pd.Series | pd.DataFrame | None]:
280304
"""Get data associated with the current task.
@@ -325,64 +349,13 @@ class OpenMLClassificationTask(OpenMLSupervisedTask):
325349
326350
Parameters
327351
----------
328-
task_type_id : TaskType
329-
ID of the Classification task type.
330-
task_type : str
331-
Name of the Classification task type.
332-
data_set_id : int
333-
ID of the OpenML dataset associated with the Classification task.
334-
target_name : str
335-
Name of the target variable.
336-
estimation_procedure_id : int, default=None
337-
ID of the estimation procedure for the Classification task.
338-
estimation_procedure_type : str, default=None
339-
Type of the estimation procedure.
340-
estimation_parameters : dict, default=None
341-
Estimation parameters for the Classification task.
342-
evaluation_measure : str, default=None
343-
Name of the evaluation measure.
344-
data_splits_url : str, default=None
345-
URL of the data splits for the Classification task.
346-
task_id : Union[int, None]
347-
ID of the Classification task (if it already exists on OpenML).
348352
class_labels : List of str, default=None
349353
A list of class labels (for classification tasks).
350354
cost_matrix : array, default=None
351355
A cost matrix (for classification tasks).
352356
"""
353357

354-
def __init__( # noqa: PLR0913
355-
self,
356-
task_type_id: TaskType,
357-
task_type: str,
358-
data_set_id: int,
359-
target_name: str,
360-
estimation_procedure_id: int = 1,
361-
estimation_procedure_type: str | None = None,
362-
estimation_parameters: dict[str, str] | None = None,
363-
evaluation_measure: str | None = None,
364-
data_splits_url: str | None = None,
365-
task_id: int | None = None,
366-
class_labels: list[str] | None = None,
367-
cost_matrix: np.ndarray | None = None,
368-
):
369-
super().__init__(
370-
task_id=task_id,
371-
task_type_id=task_type_id,
372-
task_type=task_type,
373-
data_set_id=data_set_id,
374-
estimation_procedure_id=estimation_procedure_id,
375-
estimation_procedure_type=estimation_procedure_type,
376-
estimation_parameters=estimation_parameters,
377-
evaluation_measure=evaluation_measure,
378-
target_name=target_name,
379-
data_splits_url=data_splits_url,
380-
)
381-
self.class_labels = class_labels
382-
self.cost_matrix = cost_matrix
383-
384-
if cost_matrix is not None:
385-
raise NotImplementedError("Costmatrix")
358+
DEFAULT_ESTIMATION_PROCEDURE_ID: ClassVar[int] = 1
386359

387360

388361
class OpenMLRegressionTask(OpenMLSupervisedTask):
@@ -412,31 +385,7 @@ class OpenMLRegressionTask(OpenMLSupervisedTask):
412385
Evaluation measure used in the Regression task.
413386
"""
414387

415-
def __init__( # noqa: PLR0913
416-
self,
417-
task_type_id: TaskType,
418-
task_type: str,
419-
data_set_id: int,
420-
target_name: str,
421-
estimation_procedure_id: int = 7,
422-
estimation_procedure_type: str | None = None,
423-
estimation_parameters: dict[str, str] | None = None,
424-
data_splits_url: str | None = None,
425-
task_id: int | None = None,
426-
evaluation_measure: str | None = None,
427-
):
428-
super().__init__(
429-
task_id=task_id,
430-
task_type_id=task_type_id,
431-
task_type=task_type,
432-
data_set_id=data_set_id,
433-
estimation_procedure_id=estimation_procedure_id,
434-
estimation_procedure_type=estimation_procedure_type,
435-
estimation_parameters=estimation_parameters,
436-
evaluation_measure=evaluation_measure,
437-
target_name=target_name,
438-
data_splits_url=data_splits_url,
439-
)
388+
DEFAULT_ESTIMATION_PROCEDURE_ID: ClassVar[int] = 7
440389

441390

442391
class OpenMLClusteringTask(OpenMLTask):
@@ -467,32 +416,7 @@ class OpenMLClusteringTask(OpenMLTask):
467416
feature set for the clustering task.
468417
"""
469418

470-
def __init__( # noqa: PLR0913
471-
self,
472-
task_type_id: TaskType,
473-
task_type: str,
474-
data_set_id: int,
475-
estimation_procedure_id: int = 17,
476-
task_id: int | None = None,
477-
estimation_procedure_type: str | None = None,
478-
estimation_parameters: dict[str, str] | None = None,
479-
data_splits_url: str | None = None,
480-
evaluation_measure: str | None = None,
481-
target_name: str | None = None,
482-
):
483-
super().__init__(
484-
task_id=task_id,
485-
task_type_id=task_type_id,
486-
task_type=task_type,
487-
data_set_id=data_set_id,
488-
evaluation_measure=evaluation_measure,
489-
estimation_procedure_id=estimation_procedure_id,
490-
estimation_procedure_type=estimation_procedure_type,
491-
estimation_parameters=estimation_parameters,
492-
data_splits_url=data_splits_url,
493-
)
494-
495-
self.target_name = target_name
419+
DEFAULT_ESTIMATION_PROCEDURE_ID: ClassVar[int] = 17
496420

497421
def get_X(self) -> pd.DataFrame:
498422
"""Get data associated with the current task.
@@ -554,32 +478,4 @@ class OpenMLLearningCurveTask(OpenMLClassificationTask):
554478
Cost matrix for Learning Curve tasks.
555479
"""
556480

557-
def __init__( # noqa: PLR0913
558-
self,
559-
task_type_id: TaskType,
560-
task_type: str,
561-
data_set_id: int,
562-
target_name: str,
563-
estimation_procedure_id: int = 13,
564-
estimation_procedure_type: str | None = None,
565-
estimation_parameters: dict[str, str] | None = None,
566-
data_splits_url: str | None = None,
567-
task_id: int | None = None,
568-
evaluation_measure: str | None = None,
569-
class_labels: list[str] | None = None,
570-
cost_matrix: np.ndarray | None = None,
571-
):
572-
super().__init__(
573-
task_id=task_id,
574-
task_type_id=task_type_id,
575-
task_type=task_type,
576-
data_set_id=data_set_id,
577-
estimation_procedure_id=estimation_procedure_id,
578-
estimation_procedure_type=estimation_procedure_type,
579-
estimation_parameters=estimation_parameters,
580-
evaluation_measure=evaluation_measure,
581-
target_name=target_name,
582-
data_splits_url=data_splits_url,
583-
class_labels=class_labels,
584-
cost_matrix=cost_matrix,
585-
)
481+
DEFAULT_ESTIMATION_PROCEDURE_ID: ClassVar[int] = 13

0 commit comments

Comments
 (0)