Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
changeKind: feature
packages:
- "@typespec/http-client-python"
---

Add support for xml paging
10 changes: 9 additions & 1 deletion packages/http-client-python/emitter/src/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,15 @@ function getWireNameFromPropertySegments(
if (segments[0].kind === "property") {
return segments
.filter((s) => s.kind === "property")
.map((s) => s.serializationOptions.json?.name ?? "")
.map((s) => {
if (s.serializationOptions.json) {
return s.serializationOptions.json.name;
}
if (s.serializationOptions.xml) {
return s.serializationOptions.xml.name;
}
return "";
})
.join(".");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ def has_continuation_token(self) -> bool:
def next_variable_name(self) -> str:
return "_continuation_token" if self.has_continuation_token else "next_link"

@property
def is_xml_paging(self) -> bool:
try:
return self.responses[0].item_type.xml_metadata is not None
except KeyError:
return False

def _get_attr_name(self, wire_name: str) -> str:
response_type = self.responses[0].type
if not response_type:
Expand Down Expand Up @@ -176,6 +183,9 @@ def imports(self, async_mode: bool, **kwargs: Any) -> FileImport:
file_import.merge(self.item_type.imports(**kwargs))
if self.default_error_deserialization(serialize_namespace) or self.need_deserialize:
file_import.add_submodule_import(relative_path, "_deserialize", ImportType.LOCAL)
if self.is_xml_paging:
file_import.add_submodule_import("xml.etree", "ElementTree", ImportType.STDLIB, alias="ET")
file_import.add_submodule_import(relative_path, "_convert_element", ImportType.LOCAL)
return file_import


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1372,10 +1372,13 @@ def _prepare_request_callback(self, builder: PagingOperationType) -> list[str]:
def _function_def(self) -> str:
return "def"

def _extract_data_callback(self, builder: PagingOperationType) -> list[str]: # pylint: disable=too-many-statements
def _extract_data_callback(self, builder: PagingOperationType) -> list[str]: # pylint: disable=too-many-statements,too-many-branches
retval = [f"{'async ' if self.async_mode else ''}def extract_data(pipeline_response):"]
response = builder.responses[0]
deserialized = "pipeline_response.http_response.json()"
if builder.is_xml_paging:
deserialized = "ET.fromstring(pipeline_response.http_response.text())"
else:
deserialized = "pipeline_response.http_response.json()"
if self.code_model.options["models-mode"] == "msrest":
suffix = ".http_response" if hasattr(builder, "initial_operation") else ""
deserialize_type = response.serialization_type(serialize_namespace=self.serialize_namespace)
Expand All @@ -1395,6 +1398,10 @@ def _extract_data_callback(self, builder: PagingOperationType) -> list[str]: #
item_name = builder.item_name
if self.code_model.options["models-mode"] == "msrest":
access = f".{item_name}"
elif builder.is_xml_paging:
# For XML, use .find() to navigate the element tree
item_name_array = item_name.split(".")
access = "".join([f'.find("{i}")' for i in item_name_array])
else:
item_name_array = item_name.split(".")
access = (
Expand All @@ -1412,11 +1419,17 @@ def _extract_data_callback(self, builder: PagingOperationType) -> list[str]: #
retval.append(" if cls:")
retval.append(" list_of_elem = cls(list_of_elem) # type: ignore")

cont_token_expr: Optional[str] = None # For XML, we need to extract find() result first
if builder.has_continuation_token:
location = builder.continuation_token.get("output", {}).get("location")
wire_name = builder.continuation_token.get("output", {}).get("wireName") or ""
if location == "header":
cont_token_property = f'pipeline_response.http_response.headers.get("{wire_name}") or None'
elif builder.is_xml_paging:
wire_name_array = wire_name.split(".")
wire_name_call = "".join([f'.find("{i}")' for i in wire_name_array])
cont_token_expr = f"deserialized{wire_name_call}"
cont_token_property = "_cont_token_elem.text if _cont_token_elem is not None else None"
else:
wire_name_array = wire_name.split(".")
wire_name_call = (
Expand All @@ -1429,6 +1442,11 @@ def _extract_data_callback(self, builder: PagingOperationType) -> list[str]: #
cont_token_property = "None"
elif self.code_model.options["models-mode"] == "msrest":
cont_token_property = f"deserialized.{next_link_name} or None"
elif builder.is_xml_paging:
next_link_name_array = next_link_name.split(".")
access = "".join([f'.find("{i}")' for i in next_link_name_array])
cont_token_expr = f"deserialized{access}"
cont_token_property = "_cont_token_elem.text if _cont_token_elem is not None else None"
elif builder.next_link_is_nested:
next_link_name_array = next_link_name.split(".")
access = (
Expand All @@ -1439,6 +1457,8 @@ def _extract_data_callback(self, builder: PagingOperationType) -> list[str]: #
else:
cont_token_property = f'deserialized.get("{next_link_name}") or None'
list_type = "AsyncList" if self.async_mode else "iter"
if cont_token_expr:
retval.append(f" _cont_token_elem = {cont_token_expr}")
retval.append(f" return {cont_token_property}, {list_type}(list_of_elem)")
return retval

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,9 @@ async def test_request_header_nested_response_body(client: PageableClient):
async def test_list_without_continuation(client: PageableClient):
result = [p async for p in client.page_size.list_without_continuation()]
assert_result(result)


@pytest.mark.asyncio
async def test_xml_pagination_list_with_next_link(client: PageableClient):
result = [p async for p in client.xml_pagination.list_with_next_link()]
assert_result(result)
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,8 @@ def test_request_header_nested_response_body(client: PageableClient):
def test_list_without_continuation(client: PageableClient):
result = list(client.page_size.list_without_continuation())
assert_result(result)


def test_xml_pagination_list_with_next_link(client: PageableClient):
result = list(client.xml_pagination.list_with_next_link())
assert_result(result)
Loading