Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mp_api/client/core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,6 +965,10 @@ def _submit_requests( # noqa
# No splitting needed - get first page
total_data = {"data": []}
initial_criteria = copy(criteria)
if isinstance(
initial_criteria.get("_page"), int
) and not initial_criteria.get("_per_page"):
initial_criteria["_per_page"] = initial_criteria.get("_limit")
data, total_num_docs = self._submit_request_and_process(
url=url,
verify=True,
Expand Down Expand Up @@ -1438,6 +1442,9 @@ def _search(
# This method should be customized for each end point to give more user friendly,
# documented kwargs.

# If user specifies page, ensure only one chunk is returned
if isinstance(kwargs.get("_page"), int) and num_chunks is None:
num_chunks = 1
return self._get_all_documents(
kwargs,
all_fields=all_fields,
Expand Down
6 changes: 5 additions & 1 deletion mp_api/client/routes/materials/electrodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class BaseElectrodeRester(BaseRester):
primary_key = "battery_id"
_exclude_search_fields: list[str] | None = None

def search( # pragma: ignore
def search(
self,
battery_ids: str | list[str] | None = None,
average_voltage: tuple[float, float] | None = None,
Expand All @@ -39,6 +39,8 @@ def search( # pragma: ignore
chunk_size: int = 1000,
all_fields: bool = True,
fields: list[str] | None = None,
_page: int | None = None,
_sort_fields: str | None = None,
) -> list[InsertionElectrodeDoc | ConversionElectrodeDoc] | list[dict]:
"""Query using a variety of search criteria.

Expand Down Expand Up @@ -77,6 +79,8 @@ def search( # pragma: ignore
all_fields (bool): Whether to return all fields in the document. Defaults to True.
fields (List[str]): List of fields in InsertionElectrodeDoc or ConversionElectrodeDoc to return data for.
Default is battery_id and last_updated if all_fields is False.
_page (int or None) : Page of the results to skip to.
_sort_fields (str or None) : Field to sort on. Including a leading "-" sign will reverse sort order.

Returns:
([InsertionElectrodeDoc or ConversionElectrodeDoc], [dict]) List of insertion/conversion electrode documents or dictionaries.
Expand Down
25 changes: 17 additions & 8 deletions mp_api/client/routes/materials/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ def search( # noqa: D417
chunk_size: int = 1000,
all_fields: bool = True,
fields: list[str] | None = None,
_page: int | None = None,
_sort_fields: str | None = None,
**kwargs,
) -> list[SummaryDoc] | list[dict]:
"""Query core data using a variety of search criteria.
Expand Down Expand Up @@ -150,6 +152,8 @@ def search( # noqa: D417
all_fields (bool): Whether to return all fields in the document. Defaults to True.
fields (List[str]): List of fields in SummaryDoc to return data for.
Default is material_id if all_fields is False.
_page (int or None) : Page of the results to skip to.
_sort_fields (str or None) : Field to sort on. Including a leading "-" sign will reverse sort order.

Returns:
([SummaryDoc], [dict]) List of SummaryDoc documents or dictionaries.
Expand Down Expand Up @@ -181,6 +185,8 @@ def search( # noqa: D417
"weighted_surface_energy",
"weighted_work_function",
"shape_factor",
"_page",
"_sort_fields",
]

min_max_name_dict = {
Expand Down Expand Up @@ -284,14 +290,17 @@ def _csrc(x):
)

for param, value in user_settings.items():
if isinstance(value, (int, float)):
value = (value, value)
query_params.update(
{
f"{min_max_name_dict[param]}_min": value[0],
f"{min_max_name_dict[param]}_max": value[1],
}
)
if param.startswith("_"):
query_params[param] = value
else:
if isinstance(value, (int, float)):
value = (value, value)
query_params.update(
{
f"{min_max_name_dict[param]}_min": value[0],
f"{min_max_name_dict[param]}_max": value[1],
}
)

if material_ids:
if isinstance(material_ids, str):
Expand Down
30 changes: 30 additions & 0 deletions tests/client/materials/test_electrodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def conversion_rester():
"num_chunks",
"all_fields",
"fields",
"_page",
"_sort_fields",
]

sub_doc_fields: list = []
Expand Down Expand Up @@ -80,3 +82,31 @@ def test_conversion_client(conversion_rester):
},
sub_doc_fields=sub_doc_fields,
)


@pytest.mark.xfail(reason="sort requires API redeployment", strict=False)
@requires_api_key
def test_pagination_sort():
num_docs = 5
with ElectrodeRester() as rester:
results_page_1 = rester.search(_page=1, chunk_size=num_docs)
results_page_2 = rester.search(_page=2, chunk_size=num_docs)
assert all(
len(results) == num_docs for results in (results_page_1, results_page_2)
)
assert {doc.battery_id for doc in results_page_1}.intersection(
{doc.battery_id for doc in results_page_2}
) == set()

ascending_e_hull = rester.search(_page=1, _sort_fields="average_voltage")
descending_e_hull = rester.search(_page=1, _sort_fields="-average_voltage")

assert sorted(
range(num_docs), key=lambda idx: ascending_e_hull[idx].average_voltage
) == list(range(num_docs))

assert sorted(
range(num_docs),
key=lambda idx: descending_e_hull[idx].average_voltage,
reverse=True,
) == list(range(num_docs))
32 changes: 31 additions & 1 deletion tests/client/materials/test_summary.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import os
from ..conftest import client_search_testing, requires_api_key

import pytest
from emmet.core.summary import HasProps
from emmet.core.symmetry import CrystalSystem
import numpy as np
from pymatgen.analysis.magnetism import Ordering
import pytest

from mp_api.client.routes.materials.summary import SummaryRester
from mp_api.client.core.exceptions import MPRestWarning, MPRestError
Expand All @@ -16,6 +17,8 @@
"num_chunks",
"all_fields",
"fields",
"_page",
"_sort_fields",
]

alt_name_dict: dict = {
Expand Down Expand Up @@ -134,3 +137,30 @@ def test_warning_messages():

with pytest.raises(MPRestError, match="not a valid property"):
_ = search_method(num_elements=10, has_props=["apples"])


@requires_api_key
def test_pagination_sort():
num_docs = 5
with SummaryRester() as rester:
results_page_1 = rester.search(_page=1, chunk_size=num_docs)
results_page_2 = rester.search(_page=2, chunk_size=num_docs)
assert all(
len(results) == num_docs for results in (results_page_1, results_page_2)
)
assert {doc.material_id for doc in results_page_1}.intersection(
{doc.material_id for doc in results_page_2}
) == set()

ascending_e_hull = rester.search(_page=1, _sort_fields="energy_above_hull")
descending_e_hull = rester.search(_page=1, _sort_fields="-energy_above_hull")

assert sorted(
range(num_docs), key=lambda idx: ascending_e_hull[idx].energy_above_hull
) == list(range(num_docs))

assert sorted(
range(num_docs),
key=lambda idx: descending_e_hull[idx].energy_above_hull,
reverse=True,
) == list(range(num_docs))
Loading