Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions codecarbon/core/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def check_auth(self):
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return None
r.raise_for_status()
return r.json()

def get_list_organizations(self):
Expand All @@ -100,7 +100,7 @@ def get_list_organizations(self):
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return None
r.raise_for_status()
return r.json()

def check_organization_exists(self, organization_name: str):
Expand Down Expand Up @@ -131,7 +131,7 @@ def create_organization(self, organization: OrganizationCreate):
r = requests.post(url=url, json=payload, timeout=2, headers=headers)
if r.status_code != 201:
self._log_error(url, payload, r)
return None
r.raise_for_status()
return r.json()

def get_organization(self, organization_id):
Expand All @@ -143,7 +143,7 @@ def get_organization(self, organization_id):
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return None
r.raise_for_status()
return r.json()

def update_organization(self, organization: OrganizationCreate):
Expand All @@ -156,7 +156,7 @@ def update_organization(self, organization: OrganizationCreate):
r = requests.patch(url=url, json=payload, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, payload, r)
return None
r.raise_for_status()
return r.json()

def list_projects_from_organization(self, organization_id):
Expand All @@ -168,7 +168,7 @@ def list_projects_from_organization(self, organization_id):
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return None
r.raise_for_status()
return r.json()

def create_project(self, project: ProjectCreate):
Expand All @@ -181,7 +181,7 @@ def create_project(self, project: ProjectCreate):
r = requests.post(url=url, json=payload, timeout=2, headers=headers)
if r.status_code != 201:
self._log_error(url, payload, r)
return None
r.raise_for_status()
return r.json()

def get_project(self, project_id):
Expand All @@ -193,7 +193,7 @@ def get_project(self, project_id):
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return None
r.raise_for_status()
return r.json()

def add_emission(self, carbon_emission: dict):
Expand Down Expand Up @@ -235,11 +235,11 @@ def add_emission(self, carbon_emission: dict):
r = requests.post(url=url, json=payload, timeout=2, headers=headers)
if r.status_code != 201:
self._log_error(url, payload, r)
return False
r.raise_for_status()
logger.debug(f"ApiClient - Successful upload emission {payload} to {url}")
except Exception as e:
logger.error(e, exc_info=True)
return False
raise
return True

def _create_run(self, experiment_id: str):
Expand All @@ -251,7 +251,7 @@ def _create_run(self, experiment_id: str):
logger.error(
"ApiClient FATAL The ApiClient._create_run() needs an experiment_id !"
)
return None
raise ValueError("ApiClient._create_run() needs an experiment_id")
try:
run = RunCreate(
timestamp=get_datetime_with_timezone(),
Expand All @@ -277,7 +277,7 @@ def _create_run(self, experiment_id: str):
r = requests.post(url=url, json=payload, timeout=2, headers=headers)
if r.status_code != 201:
self._log_error(url, payload, r)
return None
r.raise_for_status()
self.run_id = r.json()["id"]
logger.info(
"ApiClient Successfully registered your run on the API.\n\n"
Expand All @@ -290,8 +290,10 @@ def _create_run(self, experiment_id: str):
f"Failed to connect to API, please check the configuration. {e}",
exc_info=False,
)
raise
except Exception as e:
logger.error(e, exc_info=True)
raise

def list_experiments_from_project(self, project_id: str):
"""
Expand All @@ -302,7 +304,7 @@ def list_experiments_from_project(self, project_id: str):
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return []
r.raise_for_status()
return r.json()

def set_experiment(self, experiment_id: str):
Expand All @@ -322,7 +324,7 @@ def add_experiment(self, experiment: ExperimentCreate):
r = requests.post(url=url, json=payload, timeout=2, headers=headers)
if r.status_code != 201:
self._log_error(url, payload, r)
return None
r.raise_for_status()
return r.json()

def get_experiment(self, experiment_id):
Expand All @@ -334,7 +336,7 @@ def get_experiment(self, experiment_id):
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return None
r.raise_for_status()
return r.json()

def _log_error(self, url, payload, response):
Expand Down
20 changes: 12 additions & 8 deletions codecarbon/emissions_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,14 +470,18 @@ def _init_output_methods(self, *, api_key: str = None):
self._output_handlers.append(HTTPOutput(self._emissions_endpoint))

if self._save_to_api:
cc_api__out = CodeCarbonAPIOutput(
endpoint_url=self._api_endpoint,
experiment_id=self._experiment_id,
api_key=api_key,
conf=self._conf,
)
self.run_id = cc_api__out.run_id
self._output_handlers.append(cc_api__out)
try:
cc_api__out = CodeCarbonAPIOutput(
endpoint_url=self._api_endpoint,
experiment_id=self._experiment_id,
api_key=api_key,
conf=self._conf,
)
self.run_id = cc_api__out.run_id
self._output_handlers.append(cc_api__out)
except Exception as e:
logger.error(e, exc_info=True)
self.run_id = uuid.uuid4()
else:
self.run_id = uuid.uuid4()

Expand Down
171 changes: 171 additions & 0 deletions tests/test_api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest
from uuid import uuid4

import requests
import requests_mock

from codecarbon.core.api_client import ApiClient
Expand Down Expand Up @@ -106,3 +107,173 @@ def test_call_api(self):
tracking_mode="Machine",
)
assert api.add_emission(dataclasses.asdict(carbon_emission))

def test_create_run_error_raises(self):
"""Test that _create_run raises HTTPError on server error."""
with requests_mock.Mocker() as m:
m.post(
"http://test.com/runs",
status_code=500,
)
with self.assertRaises(requests.exceptions.HTTPError):
ApiClient(
experiment_id="experiment_id",
endpoint_url="http://test.com",
api_key="Toto",
conf=conf,
)

def test_create_run_connection_error_raises(self):
"""Test that _create_run raises ConnectionError when API is unreachable."""
with requests_mock.Mocker() as m:
m.post(
"http://test.com/runs",
exc=requests.exceptions.ConnectionError("API unreachable"),
)
with self.assertRaises(requests.exceptions.ConnectionError):
ApiClient(
experiment_id="experiment_id",
endpoint_url="http://test.com",
api_key="Toto",
conf=conf,
)

def test_add_emission_error_raises(self):
"""Test that add_emission raises HTTPError on server error."""
with requests_mock.Mocker() as m:
m.post("http://test.com/runs", json={"id": "run-id"}, status_code=201)
api = ApiClient(
experiment_id="experiment_id",
endpoint_url="http://test.com",
api_key="Toto",
conf=conf,
)
with requests_mock.Mocker() as m:
m.post("http://test.com/emissions", status_code=500)
carbon_emission = EmissionsData(
timestamp="222",
project_name="",
run_id=uuid4(),
experiment_id="test",
duration=1.5,
emissions=2.0,
emissions_rate=2.0,
cpu_energy=2,
gpu_energy=0,
ram_energy=1,
cpu_power=3.0,
gpu_power=0,
ram_power=0.15,
energy_consumed=3.0,
water_consumed=0.0,
country_name="Groland",
country_iso_code="GRD",
region="EU",
on_cloud="N",
cloud_provider="",
cloud_region="",
os="Linux",
python_version="3.8.0",
codecarbon_version="2.1.3",
gpu_count=4,
gpu_model="NVIDIA",
cpu_count=12,
cpu_model="Intel",
longitude=-7.6174,
latitude=33.5822,
ram_total_size=83948.22,
tracking_mode="Machine",
)
with self.assertRaises(requests.exceptions.HTTPError):
api.add_emission(dataclasses.asdict(carbon_emission))

def test_check_auth_error_raises(self):
"""Test that check_auth raises HTTPError on server error."""
api = ApiClient(
endpoint_url="http://test.com",
api_key="Toto",
conf=None,
create_run_automatically=False,
)
with requests_mock.Mocker() as m:
m.get("http://test.com/auth/check", status_code=401)
with self.assertRaises(requests.exceptions.HTTPError):
api.check_auth()

def test_get_list_organizations_error_raises(self):
"""Test that get_list_organizations raises HTTPError on server error."""
api = ApiClient(
endpoint_url="http://test.com",
api_key="Toto",
conf=None,
create_run_automatically=False,
)
with requests_mock.Mocker() as m:
m.get("http://test.com/organizations", status_code=500)
with self.assertRaises(requests.exceptions.HTTPError):
api.get_list_organizations()

def test_get_organization_error_raises(self):
"""Test that get_organization raises HTTPError on server error."""
api = ApiClient(
endpoint_url="http://test.com",
api_key="Toto",
conf=None,
create_run_automatically=False,
)
with requests_mock.Mocker() as m:
m.get("http://test.com/organizations/org-id", status_code=500)
with self.assertRaises(requests.exceptions.HTTPError):
api.get_organization("org-id")

def test_list_projects_error_raises(self):
"""Test that list_projects_from_organization raises HTTPError."""
api = ApiClient(
endpoint_url="http://test.com",
api_key="Toto",
conf=None,
create_run_automatically=False,
)
with requests_mock.Mocker() as m:
m.get("http://test.com/organizations/org-id/projects", status_code=500)
with self.assertRaises(requests.exceptions.HTTPError):
api.list_projects_from_organization("org-id")

def test_get_project_error_raises(self):
"""Test that get_project raises HTTPError on server error."""
api = ApiClient(
endpoint_url="http://test.com",
api_key="Toto",
conf=None,
create_run_automatically=False,
)
with requests_mock.Mocker() as m:
m.get("http://test.com/projects/proj-id", status_code=500)
with self.assertRaises(requests.exceptions.HTTPError):
api.get_project("proj-id")

def test_list_experiments_error_raises(self):
"""Test that list_experiments_from_project raises HTTPError."""
api = ApiClient(
endpoint_url="http://test.com",
api_key="Toto",
conf=None,
create_run_automatically=False,
)
with requests_mock.Mocker() as m:
m.get("http://test.com/projects/proj-id/experiments", status_code=500)
with self.assertRaises(requests.exceptions.HTTPError):
api.list_experiments_from_project("proj-id")

def test_get_experiment_error_raises(self):
"""Test that get_experiment raises HTTPError on server error."""
api = ApiClient(
endpoint_url="http://test.com",
api_key="Toto",
conf=None,
create_run_automatically=False,
)
with requests_mock.Mocker() as m:
m.get("http://test.com/experiments/exp-id", status_code=500)
with self.assertRaises(requests.exceptions.HTTPError):
api.get_experiment("exp-id")