Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mp_api/client/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

from .client import BaseRester
from .exceptions import MPRestError, MPRestWarning
from .settings import MAPIClientSettings
from .settings import MAPI_CLIENT_SETTINGS, MAPIClientSettings
161 changes: 90 additions & 71 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import gzip
import inspect
import itertools
import os
Expand All @@ -16,30 +17,32 @@
from functools import cache
from importlib import import_module
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
from json import JSONDecodeError
from math import ceil
from typing import TYPE_CHECKING, ForwardRef, Optional, get_args
from urllib.parse import quote, urljoin
from urllib.parse import quote

import boto3
import requests
from botocore import UNSIGNED
from botocore.config import Config
from botocore.exceptions import ClientError
from emmet.core.utils import jsanitize
from pydantic import BaseModel, create_model
from requests.adapters import HTTPAdapter
from requests.exceptions import RequestException
from smart_open import open
from tqdm.auto import tqdm
from urllib3.util.retry import Retry

from mp_api.client.core.exceptions import MPRestError
from mp_api.client.core.settings import MAPIClientSettings
from mp_api.client.core.utils import load_json, validate_ids

try:
import boto3
from botocore import UNSIGNED
from botocore.config import Config
except ImportError:
boto3 = None
from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
from mp_api.client.core.utils import (
load_json,
validate_api_key,
validate_endpoint,
validate_ids,
)

try:
import flask
Expand All @@ -51,15 +54,14 @@

from pydantic.fields import FieldInfo

from mp_api.client.core.utils import LazyImport

try:
__version__ = version("mp_api")
except PackageNotFoundError: # pragma: no cover
__version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION")


SETTINGS = MAPIClientSettings() # type: ignore


class _DictLikeAccess(BaseModel):
"""Define a pydantic mix-in which permits dict-like access to model fields."""

Expand All @@ -82,7 +84,6 @@ class BaseRester:

suffix: str = ""
document_model: type[BaseModel] | None = None
supports_versions: bool = False
primary_key: str = "material_id"

def __init__(
Expand All @@ -96,7 +97,7 @@ def __init__(
use_document_model: bool = True,
timeout: int = 20,
headers: dict | None = None,
mute_progress_bars: bool = SETTINGS.MUTE_PROGRESS_BARS,
mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS,
**kwargs,
):
"""Initialize the REST API helper class.
Expand Down Expand Up @@ -130,23 +131,17 @@ def __init__(
mute_progress_bars: Whether to disable progress bars.
**kwargs: access to legacy kwargs that may be in the process of being deprecated
"""
# TODO: think about how to migrate from PMG_MAPI_KEY
self.api_key = api_key or os.getenv("MP_API_KEY")
self.base_endpoint = self.endpoint = endpoint or os.getenv(
"MP_API_ENDPOINT", "https://api.materialsproject.org/"
)
self.api_key = validate_api_key(api_key)
self.base_endpoint = validate_endpoint(endpoint)
self.endpoint = validate_endpoint(endpoint, suffix=self.suffix)

self.debug = debug
self.include_user_agent = include_user_agent
self.use_document_model = use_document_model
self.timeout = timeout
self.headers = headers or {}
self.mute_progress_bars = mute_progress_bars
self.db_version = BaseRester._get_database_version(self.endpoint)

if self.suffix:
self.endpoint = urljoin(self.endpoint, self.suffix)
if not self.endpoint.endswith("/"):
self.endpoint += "/"
self.db_version = BaseRester._get_database_version(self.base_endpoint)

self._session = session
self._s3_client = s3_client
Expand All @@ -167,13 +162,6 @@ def session(self) -> requests.Session:

@property
def s3_client(self):
if boto3 is None:
raise MPRestError(
"boto3 not installed. To query charge density, "
"band structure, or density of states data first "
"install with: 'pip install boto3'"
)

if not self._s3_client:
self._s3_client = boto3.client(
"s3",
Expand All @@ -194,15 +182,14 @@ def _create_session(api_key, include_user_agent, headers):
user_agent = f"{mp_api_info} ({python_info} {platform_info})"
session.headers["user-agent"] = user_agent

settings = MAPIClientSettings() # type: ignore
max_retry_num = settings.MAX_RETRIES
max_retry_num = MAPI_CLIENT_SETTINGS.MAX_RETRIES
retry = Retry(
total=max_retry_num,
read=max_retry_num,
connect=max_retry_num,
respect_retry_after_header=True,
status_forcelist=[429, 504, 502], # rate limiting
backoff_factor=settings.BACKOFF_FACTOR,
backoff_factor=MAPI_CLIENT_SETTINGS.BACKOFF_FACTOR,
)
adapter = HTTPAdapter(max_retries=retry)
session.mount("http://", adapter)
Expand Down Expand Up @@ -263,11 +250,7 @@ def _post_resource(
payload = jsanitize(body)

try:
url = self.endpoint
if suburl:
url = urljoin(self.endpoint, suburl)
if not url.endswith("/"):
url += "/"
url = validate_endpoint(self.endpoint, suffix=suburl)
response = self.session.post(url, json=payload, verify=True, params=params)

if response.status_code == 200:
Expand Down Expand Up @@ -331,11 +314,7 @@ def _patch_resource(
payload = jsanitize(body)

try:
url = self.endpoint
if suburl:
url = urljoin(self.endpoint, suburl)
if not url.endswith("/"):
url += "/"
url = validate_endpoint(self.endpoint, suffix=suburl)
response = self.session.patch(url, json=payload, verify=True, params=params)

if response.status_code == 200:
Expand Down Expand Up @@ -390,20 +369,31 @@ def _query_open_data(
Returns:
dict: MontyDecoded data
"""
decoder = decoder or load_json
try:
byio = BytesIO()
self.s3_client.download_fileobj(bucket, key, byio)
byio.seek(0)
if (file_data := byio.read()).startswith(b"\x1f\x8b"):
file_data = gzip.decompress(file_data)
byio.close()

file = open(
f"s3://{bucket}/{key}",
encoding="utf-8",
transport_params={"client": self.s3_client},
)
decoder = decoder or load_json

if "jsonl" in key:
decoded_data = [decoder(jline) for jline in file.read().splitlines()]
else:
decoded_data = decoder(file.read())
if not isinstance(decoded_data, list):
decoded_data = [decoded_data]
if "jsonl" in key:
decoded_data = [decoder(jline) for jline in file_data.splitlines()]
else:
decoded_data = decoder(file_data)
if not isinstance(decoded_data, list):
decoded_data = [decoded_data]

raise_error = not decoded_data or len(decoded_data) == 0

except ClientError:
# No such object exists
raise_error = True

if raise_error:
raise MPRestError(f"No object found: s3://{bucket}/{key}")

return decoded_data, len(decoded_data) # type: ignore

Expand Down Expand Up @@ -467,14 +457,9 @@ def _query_resource(
criteria["_fields"] = ",".join(fields)

try:
url = self.endpoint
if suburl:
url = urljoin(self.endpoint, suburl)
if not url.endswith("/"):
url += "/"
url = validate_endpoint(self.endpoint, suffix=suburl)

if query_s3:
db_version = self.db_version.replace(".", "-")
if "/" not in self.suffix:
suffix = self.suffix
elif self.suffix == "molecules/summary":
Expand All @@ -490,7 +475,7 @@ def _query_resource(
bucket_suffix, prefix = "parsed", "tasks_atomate2"
else:
bucket_suffix = "build"
prefix = f"collections/{db_version}/{suffix}"
prefix = f"collections/{self.db_version.replace('.', '-')}/{suffix}"

bucket = f"materialsproject-{bucket_suffix}"
paginator = self.s3_client.get_paginator("list_objects_v2")
Expand Down Expand Up @@ -618,15 +603,15 @@ def _submit_requests( # noqa

bare_url_len = len(url_string)
max_param_str_length = (
MAPIClientSettings().MAX_HTTP_URL_LENGTH - bare_url_len # type: ignore
MAPI_CLIENT_SETTINGS.MAX_HTTP_URL_LENGTH - bare_url_len # type: ignore
)

# Next, check if default number of parallel requests works.
# If not, make slice size the minimum number of param entries
# contained in any substring of length max_param_str_length.
param_length = len(criteria[parallel_param].split(","))
slice_size = (
int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1 # type: ignore
int(param_length / MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS) or 1 # type: ignore
)

url_param_string = quote(criteria[parallel_param])
Expand Down Expand Up @@ -907,14 +892,14 @@ def _multi_thread(
params_ind = 0

with ThreadPoolExecutor(
max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS # type: ignore
max_workers=MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS # type: ignore
) as executor:
# Get list of initial futures defined by max number of parallel requests
futures = set()

for params in itertools.islice(
params_gen,
MAPIClientSettings().NUM_PARALLEL_REQUESTS, # type: ignore
MAPI_CLIENT_SETTINGS.NUM_PARALLEL_REQUESTS, # type: ignore
):
future = executor.submit(
func,
Expand Down Expand Up @@ -1276,7 +1261,7 @@ def _get_all_documents(
for key, entry in query_params.items()
if isinstance(entry, str)
and len(entry.split(",")) > 0
and key not in MAPIClientSettings().QUERY_NO_PARALLEL # type: ignore
and key not in MAPI_CLIENT_SETTINGS.QUERY_NO_PARALLEL # type: ignore
),
key=lambda item: item[1],
reverse=True,
Expand Down Expand Up @@ -1351,3 +1336,37 @@ def __str__(self): # pragma: no cover
f"{self.__class__.__name__} connected to {self.endpoint}\n\n"
f"Available fields: {', '.join(self.available_fields)}\n\n"
)


class CoreRester(BaseRester):
    """Define a BaseRester with extra features for core resters.

    Enables lazy importing / initialization of sub resters
    provided in `_sub_resters`, which should be a map
    of endpoint names to LazyImport objects.
    """

    # Class-level template; each instance works on its own copies (see __init__).
    _sub_resters: dict[str, LazyImport] = {}

    def __init__(self, **kwargs):
        """Initialize the rester and give it fresh copies of its sub resters.

        Every entry of the class-level `_sub_resters` map is copied so that
        sub resters are unset again whenever the rester is re-initialized.

        Args:
            **kwargs: Forwarded unchanged to `BaseRester.__init__`.
        """
        super().__init__(**kwargs)
        self.sub_resters = {k: v.copy() for k, v in self._sub_resters.items()}

    def __getattr__(self, v: str):
        """Return the sub rester named `v`, initializing it on first access.

        Python only calls this hook after normal attribute lookup has failed.

        Args:
            v: Name of the attribute being looked up.

        Returns:
            The entry of `sub_resters` registered under `v`.

        Raises:
            AttributeError: If `v` is not a registered sub rester.
        """
        # Read the instance dict directly: going through `self.sub_resters`
        # would re-enter __getattr__ (and recurse without bound) whenever the
        # attribute has not been assigned yet.
        sub_resters = self.__dict__.get("sub_resters")
        if sub_resters is None or v not in sub_resters:
            # Without this, unknown attributes silently evaluated to None,
            # which also made hasattr() report True for any name.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {v!r}"
            )

        lazy_rester = sub_resters[v]
        if lazy_rester._obj is None:
            # First access: call the lazy import with this rester's own
            # connection settings so the sub rester shares them.
            lazy_rester(
                api_key=self.api_key,
                endpoint=self.base_endpoint,
                # BaseRester.__init__ stores this flag as `include_user_agent`.
                include_user_agent=self.include_user_agent,
                session=self.session,
                use_document_model=self.use_document_model,
                headers=self.headers,
                mute_progress_bars=self.mute_progress_bars,
            )
        return lazy_rester

    def __dir__(self):
        """List the class attributes followed by the registered sub rester names."""
        return dir(self.__class__) + list(self._sub_resters)
29 changes: 21 additions & 8 deletions mp_api/client/core/settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
from multiprocessing import cpu_count
from typing import List

from emmet.core.settings import EmmetSettings
from pydantic import Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from pymatgen.core import _load_pmg_settings
Expand All @@ -14,12 +13,9 @@
_MUTE_PROGRESS_BAR = PMG_SETTINGS.get("MPRESTER_MUTE_PROGRESS_BARS", False)
_MAX_HTTP_URL_LENGTH = PMG_SETTINGS.get("MPRESTER_MAX_HTTP_URL_LENGTH", 2000)
_MAX_LIST_LENGTH = min(PMG_SETTINGS.get("MPRESTER_MAX_LIST_LENGTH", 10000), 10000)
_DEFAULT_ENDPOINT = "https://api.materialsproject.org/"

try:
CPU_COUNT = cpu_count()
except NotImplementedError:
pass
_EMMET_SETTINGS = EmmetSettings()
_DEFAULT_ENDPOINT = "https://api.materialsproject.org/"


class MAPIClientSettings(BaseSettings):
Expand All @@ -32,7 +28,7 @@ class MAPIClientSettings(BaseSettings):
description="Directory with test files",
)

QUERY_NO_PARALLEL: List[str] = Field(
QUERY_NO_PARALLEL: list[str] = Field(
[
"elements",
"exclude_elements",
Expand Down Expand Up @@ -93,9 +89,26 @@ class MAPIClientSettings(BaseSettings):
_DEFAULT_ENDPOINT, description="The default API endpoint to use."
)

LTOL: float = Field(
_EMMET_SETTINGS.LTOL,
description="Fractional length tolerance for structure matching",
)

STOL: float = Field(
_EMMET_SETTINGS.STOL, description="Site tolerance for structure matching."
)

ANGLE_TOL: float = Field(
_EMMET_SETTINGS.ANGLE_TOL,
description="Angle tolerance for structure matching in degrees.",
)

model_config = SettingsConfigDict(env_prefix="MPRESTER_")

@field_validator("ENDPOINT", mode="before")
@classmethod
def _get_endpoint_from_env(cls, v: str | None) -> str:
    """Support setting endpoint via MP_API_ENDPOINT environment variable.

    Args:
        v: The raw ENDPOINT value handed to the validator; may be ``None``
            or empty.

    Returns:
        ``v`` when it is truthy; otherwise the ``MP_API_ENDPOINT``
        environment variable when it is set and non-empty; otherwise
        ``_DEFAULT_ENDPOINT``.
    """
    # NOTE(review): the environment variable is consulted only when `v` is
    # falsy. If the field's truthy default reaches this validator, the
    # environment variable is never read - confirm this precedence is intended.
    return v or os.environ.get("MP_API_ENDPOINT") or _DEFAULT_ENDPOINT


# Shared settings instance, constructed once when this module is imported.
MAPI_CLIENT_SETTINGS = MAPIClientSettings()
Loading