Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions datacommons_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from datacommons_client.endpoints.node import NodeEndpoint
from datacommons_client.endpoints.observation import ObservationEndpoint
from datacommons_client.endpoints.resolve import ResolveEndpoint
from datacommons_client.utils.context import use_api_key

__all__ = [
"DataCommonsClient",
"API",
"NodeEndpoint",
"ObservationEndpoint",
"ResolveEndpoint",
"use_api_key",
]
10 changes: 9 additions & 1 deletion datacommons_client/endpoints/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
from typing import Any, Dict, Optional

from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR
from datacommons_client.utils.request_handling import check_instance_is_valid
from datacommons_client.utils.request_handling import post_request
from datacommons_client.utils.request_handling import resolve_instance_url
Expand Down Expand Up @@ -94,9 +95,16 @@ def post(self,

url = (self.base_url if endpoint is None else f"{self.base_url}/{endpoint}")

headers = self.headers
ctx_api_key = _API_KEY_CONTEXT_VAR.get()
if ctx_api_key:
# Copy headers to avoid mutating the shared client state
headers = self.headers.copy()
headers["X-API-Key"] = ctx_api_key

return post_request(url=url,
payload=payload,
headers=self.headers,
headers=headers,
all_pages=all_pages,
next_token=next_token)

Expand Down
8 changes: 6 additions & 2 deletions datacommons_client/endpoints/node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from concurrent.futures import ThreadPoolExecutor
import contextvars
from functools import partial
from functools import wraps
from typing import Literal, Optional
Expand Down Expand Up @@ -447,10 +448,13 @@ def _fetch_place_relationships(
)

# Use a thread pool to fetch ancestry graphs in parallel for each input entity
ctx = contextvars.copy_context()
with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor:
futures = [
executor.submit(build_graph_map, root=dcid, fetch_fn=fetch_fn)
for dcid in place_dcids
executor.submit(ctx.run,
build_graph_map,
root=dcid,
fetch_fn=fetch_fn) for dcid in place_dcids
]
# Gather ancestry maps and postprocess into flat or nested form
for future in futures:
Expand Down
52 changes: 52 additions & 0 deletions datacommons_client/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
import pytest

from datacommons_client import use_api_key
from datacommons_client.client import DataCommonsClient
from datacommons_client.endpoints.base import API
from datacommons_client.endpoints.node import NodeEndpoint
Expand Down Expand Up @@ -419,3 +420,54 @@ def test_client_end_to_end_surface_header_propagation_observation(
assert headers is not None
assert headers.get("x-surface") == "datagemma"
assert headers.get("X-API-Key") == "test_key"


@patch("datacommons_client.endpoints.base.post_request")
def test_use_api_key_with_observation_fetch(mock_post_request):
"""Test use_api_key override for observation fetches (non-threaded)."""

# Setup client with default key
client = DataCommonsClient(api_key="default-key")

# Configure mock to return valid response structure
mock_post_request.return_value = {"byVariable": {}, "facets": {}}

# Default usage
client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"])
mock_post_request.assert_called()
_, kwargs = mock_post_request.call_args
assert kwargs["headers"]["X-API-Key"] == "default-key"

# Context override
with use_api_key("context-key"):
client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"])
_, kwargs = mock_post_request.call_args
assert kwargs["headers"]["X-API-Key"] == "context-key"

# Back to default
client.observation.fetch(variable_dcids="sv1", entity_dcids=["geo1"])
_, kwargs = mock_post_request.call_args
assert kwargs["headers"]["X-API-Key"] == "default-key"


@patch("datacommons_client.endpoints.base.post_request")
def test_use_api_key_with_node_fetch_place_ancestors(mock_post_request):
"""Test use_api_key propagation for node graph methods (threaded)."""

client = DataCommonsClient(api_key="default-key")

# Configure mock. fetch_place_ancestors expects a dict response or list of nodes.
# NodeResponse.data is a dict.
mock_post_request.return_value = {"data": {}}

# Default usage
client.node.fetch_place_ancestors(place_dcids=["geoId/06"])
_, kwargs = mock_post_request.call_args
assert kwargs["headers"]["X-API-Key"] == "default-key"

# Context override
with use_api_key("context-key"):
# Use a different DCID to avoid hitting fetch_relationship_lru cache
client.node.fetch_place_ancestors(place_dcids=["geoId/07"])
_, kwargs = mock_post_request.call_args
assert kwargs["headers"]["X-API-Key"] == "context-key"
43 changes: 43 additions & 0 deletions datacommons_client/tests/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datacommons_client.utils.context import _API_KEY_CONTEXT_VAR
from datacommons_client.utils.context import use_api_key


def test_use_api_key_sets_var():
"""Test that use_api_key sets the context variable."""
assert _API_KEY_CONTEXT_VAR.get() is None
with use_api_key("test-key"):
assert _API_KEY_CONTEXT_VAR.get() == "test-key"
assert _API_KEY_CONTEXT_VAR.get() is None


def test_use_api_key_nested():
"""Test nested usage of use_api_key."""
with use_api_key("outer"):
assert _API_KEY_CONTEXT_VAR.get() == "outer"
with use_api_key("inner"):
assert _API_KEY_CONTEXT_VAR.get() == "inner"
assert _API_KEY_CONTEXT_VAR.get() == "outer"
assert _API_KEY_CONTEXT_VAR.get() is None


def test_use_api_key_none():
"""Test that use_api_key with None/empty does not set the variable."""
assert _API_KEY_CONTEXT_VAR.get() is None
with use_api_key(None):
assert _API_KEY_CONTEXT_VAR.get() is None
with use_api_key(""):
assert _API_KEY_CONTEXT_VAR.get() is None
56 changes: 56 additions & 0 deletions datacommons_client/utils/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from contextvars import ContextVar
from typing import Generator

_API_KEY_CONTEXT_VAR: ContextVar[str | None] = ContextVar("api_key",
default=None)


@contextmanager
def use_api_key(api_key: str | None) -> Generator[None, None, None]:
"""Context manager to set the API key for the current execution context.

If api_key is None or empty, this context manager does nothing, allowing
the underlying client to use its default API key.

Args:
api_key: The API key to use. If None or empty, no change is made.

Example:
from datacommons_client import use_api_key
# ...
client = DataCommonsClient(api_key="default-key")

# Uses "default-key"
client.observation.fetch(...)

with use_api_key("temp-key"):
# Uses "temp-key"
client.observation.fetch(...)

# Back to "default-key"
client.observation.fetch(...)
"""
if not api_key:
yield
return

token = _API_KEY_CONTEXT_VAR.set(api_key)
try:
yield
finally:
_API_KEY_CONTEXT_VAR.reset(token)
4 changes: 3 additions & 1 deletion datacommons_client/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
import contextvars
from functools import lru_cache
from typing import Callable, Literal, Optional, TypeAlias

Expand Down Expand Up @@ -108,6 +109,7 @@ def build_graph_map(

original_root = root

ctx = contextvars.copy_context()
with ThreadPoolExecutor(max_workers=max_workers) as executor:
queue = deque([root])

Expand All @@ -119,7 +121,7 @@ def build_graph_map(
# Check if the node has already been visited or is in progress
if dcid not in visited and dcid not in in_progress:
# Submit the fetch task
in_progress[dcid] = executor.submit(fetch_fn, dcid=dcid)
in_progress[dcid] = executor.submit(ctx.run, fetch_fn, dcid=dcid)

# Check if any futures are still in progress
if not in_progress:
Expand Down