Skip to content

Commit 6278813

Browse files
Tasks from sever incorrectly uses default estimation procedure ID (#1395)
1 parent 483f467 commit 6278813

File tree

3 files changed

+12
-5
lines changed

3 files changed

+12
-5
lines changed

openml/tasks/functions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,7 @@ def _create_task_from_xml(xml: str) -> OpenMLTask:
492492
"data_set_id": inputs["source_data"]["oml:data_set"]["oml:data_set_id"],
493493
"evaluation_measure": evaluation_measures,
494494
}
495+
# TODO: add OpenMLClusteringTask?
495496
if task_type in (
496497
TaskType.SUPERVISED_CLASSIFICATION,
497498
TaskType.SUPERVISED_REGRESSION,
@@ -508,6 +509,10 @@ def _create_task_from_xml(xml: str) -> OpenMLTask:
508509
common_kwargs["estimation_procedure_type"] = inputs["estimation_procedure"][
509510
"oml:estimation_procedure"
510511
]["oml:type"]
512+
common_kwargs["estimation_procedure_id"] = int(
513+
inputs["estimation_procedure"]["oml:estimation_procedure"]["oml:id"]
514+
)
515+
511516
common_kwargs["estimation_parameters"] = estimation_parameters
512517
common_kwargs["target_name"] = inputs["source_data"]["oml:data_set"]["oml:target_feature"]
513518
common_kwargs["data_splits_url"] = inputs["estimation_procedure"][

tests/test_tasks/test_classification_task.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def setUp(self, n_levels: int = 1):
1515
super().setUp()
1616
self.task_id = 119 # diabetes
1717
self.task_type = TaskType.SUPERVISED_CLASSIFICATION
18-
self.estimation_procedure = 1
18+
self.estimation_procedure = 5
1919

2020
def test_get_X_and_Y(self):
2121
X, Y = super().test_get_X_and_Y()
@@ -30,7 +30,8 @@ def test_download_task(self):
3030
assert task.task_id == self.task_id
3131
assert task.task_type_id == TaskType.SUPERVISED_CLASSIFICATION
3232
assert task.dataset_id == 20
33+
assert task.estimation_procedure_id == self.estimation_procedure
3334

3435
def test_class_labels(self):
3536
task = get_task(self.task_id)
36-
assert task.class_labels == ["tested_negative", "tested_positive"]
37+
assert task.class_labels == ["tested_negative", "tested_positive"]

tests/test_tasks/test_regression_task.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ class OpenMLRegressionTaskTest(OpenMLSupervisedTaskTest):
1818

1919
def setUp(self, n_levels: int = 1):
2020
super().setUp()
21-
21+
self.estimation_procedure = 9
2222
task_meta_data = {
2323
"task_type": TaskType.SUPERVISED_REGRESSION,
2424
"dataset_id": 105, # wisconsin
25-
"estimation_procedure_id": 7,
25+
"estimation_procedure_id": self.estimation_procedure, # non default value to test estimation procedure id
2626
"target_name": "time",
2727
}
2828
_task_id = check_task_existence(**task_meta_data)
@@ -46,7 +46,7 @@ def setUp(self, n_levels: int = 1):
4646
raise Exception(repr(e))
4747
self.task_id = task_id
4848
self.task_type = TaskType.SUPERVISED_REGRESSION
49-
self.estimation_procedure = 7
49+
5050

5151
def test_get_X_and_Y(self):
5252
X, Y = super().test_get_X_and_Y()
@@ -61,3 +61,4 @@ def test_download_task(self):
6161
assert task.task_id == self.task_id
6262
assert task.task_type_id == TaskType.SUPERVISED_REGRESSION
6363
assert task.dataset_id == 105
64+
assert task.estimation_procedure_id == self.estimation_procedure

0 commit comments

Comments
 (0)