Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions openml/tasks/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ def _create_task_from_xml(xml: str) -> OpenMLTask:
"data_set_id": inputs["source_data"]["oml:data_set"]["oml:data_set_id"],
"evaluation_measure": evaluation_measures,
}
# TODO: add OpenMLClusteringTask?
if task_type in (
TaskType.SUPERVISED_CLASSIFICATION,
TaskType.SUPERVISED_REGRESSION,
Expand All @@ -508,6 +509,10 @@ def _create_task_from_xml(xml: str) -> OpenMLTask:
common_kwargs["estimation_procedure_type"] = inputs["estimation_procedure"][
"oml:estimation_procedure"
]["oml:type"]
common_kwargs["estimation_procedure_id"] = int(
inputs["estimation_procedure"]["oml:estimation_procedure"]["oml:id"]
)

common_kwargs["estimation_parameters"] = estimation_parameters
common_kwargs["target_name"] = inputs["source_data"]["oml:data_set"]["oml:target_feature"]
common_kwargs["data_splits_url"] = inputs["estimation_procedure"][
Expand Down
5 changes: 3 additions & 2 deletions tests/test_tasks/test_classification_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def setUp(self, n_levels: int = 1):
super().setUp()
self.task_id = 119 # diabetes
self.task_type = TaskType.SUPERVISED_CLASSIFICATION
self.estimation_procedure = 1
self.estimation_procedure = 5

def test_get_X_and_Y(self):
X, Y = super().test_get_X_and_Y()
Expand All @@ -30,7 +30,8 @@ def test_download_task(self):
assert task.task_id == self.task_id
assert task.task_type_id == TaskType.SUPERVISED_CLASSIFICATION
assert task.dataset_id == 20
assert task.estimation_procedure_id == self.estimation_procedure

def test_class_labels(self):
task = get_task(self.task_id)
assert task.class_labels == ["tested_negative", "tested_positive"]
assert task.class_labels == ["tested_negative", "tested_positive"]
7 changes: 4 additions & 3 deletions tests/test_tasks/test_regression_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ class OpenMLRegressionTaskTest(OpenMLSupervisedTaskTest):

def setUp(self, n_levels: int = 1):
super().setUp()

self.estimation_procedure = 9
task_meta_data = {
"task_type": TaskType.SUPERVISED_REGRESSION,
"dataset_id": 105, # wisconsin
"estimation_procedure_id": 7,
"estimation_procedure_id": self.estimation_procedure, # non default value to test estimation procedure id
"target_name": "time",
}
_task_id = check_task_existence(**task_meta_data)
Expand All @@ -46,7 +46,7 @@ def setUp(self, n_levels: int = 1):
raise Exception(repr(e))
self.task_id = task_id
self.task_type = TaskType.SUPERVISED_REGRESSION
self.estimation_procedure = 7


def test_get_X_and_Y(self):
X, Y = super().test_get_X_and_Y()
Expand All @@ -61,3 +61,4 @@ def test_download_task(self):
assert task.task_id == self.task_id
assert task.task_type_id == TaskType.SUPERVISED_REGRESSION
assert task.dataset_id == 105
assert task.estimation_procedure_id == self.estimation_procedure
Loading