Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions akd_ext/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
RepositorySearchToolOutputSchema,
RepositorySearchToolConfig,
)
from .eie import (
STACItemStatsTool,
STACItemStatsToolConfig,
STACItemStatsInputSchema,
STACItemStatsOutputSchema,
ItemStats,
)

__all__ = [
"DummyTool",
Expand All @@ -38,4 +45,9 @@
"RepositorySearchToolInputSchema",
"RepositorySearchToolOutputSchema",
"RepositorySearchToolConfig",
"STACItemStatsTool",
"STACItemStatsToolConfig",
"STACItemStatsInputSchema",
"STACItemStatsOutputSchema",
"ItemStats",
]
17 changes: 17 additions & 0 deletions akd_ext/tools/eie/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""EIE-specific tools for akd_ext."""

from .stac_stats import (
STACItemStatsTool,
STACItemStatsToolConfig,
STACItemStatsInputSchema,
STACItemStatsOutputSchema,
ItemStats,
)

# Public names of the ``eie`` tools package; each one is imported from
# ``.stac_stats`` above.
__all__ = [
"STACItemStatsTool",
"STACItemStatsToolConfig",
"STACItemStatsInputSchema",
"STACItemStatsOutputSchema",
"ItemStats",
]
123 changes: 123 additions & 0 deletions akd_ext/tools/eie/stac_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
Tool to calculate statistics for STAC items.

"""

from pydantic import BaseModel, Field
from akd._base import InputSchema, OutputSchema
from akd.tools import BaseTool, BaseToolConfig
from akd_ext.mcp import mcp_tool
from .utils import fetch_statistics_batch


class StacItemInfo(BaseModel):
    """Info about a STAC item including its COG asset URL."""

    # ``id`` is the only required field; every other field falls back to None.
    id: str = Field(description="Item ID")
    collection: str | None = Field(
        description="Collection ID",
        default=None,
    )
    datetime: str | None = Field(
        description="Item datetime",
        default=None,
    )
    # The stats tool in this module skips items whose ``asset_url`` is falsy.
    asset_url: str | None = Field(
        description="URL to the COG asset",
        default=None,
    )


class PlaceResult(BaseModel):
    """Result from place resolution."""

    # All fields are optional. ``place`` and ``bbox`` previously had no default,
    # which made them required even though they are nullable: a partial result
    # (for example only ``geometry``, or only ``error``) was rejected during
    # validation before the stats tool could return its own
    # "No geometry from get_place" message. The stats tool reads ``geometry`` only.
    place: str | None = Field(default=None, description="Resolved place name")
    bbox: list[float] | None = Field(default=None, description="Bounding box [west, south, east, north]")
    geometry: dict | None = Field(default=None, description="GeoJSON geometry (Polygon/MultiPolygon) for the place")
    error: str | None = Field(default=None, description="Error message if resolution failed")


class StacSearchResult(BaseModel):
    """Result from STAC search."""

    # Every field has a default, so an empty result needs no arguments.
    item_ids: list[str] = Field(
        description="Found item IDs",
        default_factory=list,
    )
    items: list[StacItemInfo] = Field(
        description="Item details with COG asset URLs",
        default_factory=list,
    )
    count: int = Field(
        description="Total number of items found",
        default=0,
    )
    error: str | None = Field(
        description="Error message if search failed",
        default=None,
    )


class STACItemStatsInputSchema(InputSchema):
    """Input schema for the STACItemStatsTool."""

    # Both inputs default to None; the tool reports a missing one through the
    # ``error`` field of its output.
    stac_result: StacSearchResult | None = Field(
        description="Result from stac_search tool with items and asset URLs",
        default=None,
    )
    place_result: PlaceResult | None = Field(
        description="Result from get_place tool with bbox and geometry",
        default=None,
    )


class ItemStats(BaseModel):
    """Statistics for a single COG item."""

    # All fields are optional; ``statistics`` starts as a fresh empty dict.
    id: str | None = Field(
        description="Item ID",
        default=None,
    )
    datetime: str | None = Field(
        description="Item datetime",
        default=None,
    )
    statistics: dict = Field(
        description="Per-band statistics",
        default_factory=dict,
    )
    error: str | None = Field(
        description="Error if this item failed",
        default=None,
    )


class STACItemStatsOutputSchema(OutputSchema):
    """Output schema for the STACItemStatsTool."""

    # ``items`` holds the per-item results; ``error`` is the request-level
    # message and stays None unless set explicitly.
    items: list[ItemStats] = Field(
        description="Statistics for each item",
        default_factory=list,
    )
    error: str | None = Field(
        description="Error message if request failed",
        default=None,
    )


class STACItemStatsToolConfig(BaseToolConfig):
    """Config for the STACItemStatsTool."""

    # Base URL handed to ``fetch_statistics_batch`` by the tool.
    raster_api_url: str = Field(
        description="Raster API URL",
        default="https://dev.openveda.cloud/api/raster",
    )


@mcp_tool
class STACItemStatsTool(BaseTool[STACItemStatsInputSchema, STACItemStatsOutputSchema]):
    """
    Compute statistics for STAC items over a bounding box.
    """

    # NOTE(review): the implementation clips to ``place_result.geometry`` and
    # never reads ``place_result.bbox``, so the docstring above is imprecise.
    # It is left unchanged here in case it is used as the tool description.

    input_schema = STACItemStatsInputSchema
    output_schema = STACItemStatsOutputSchema
    config_schema = STACItemStatsToolConfig

    async def _arun(self, params: STACItemStatsInputSchema) -> STACItemStatsOutputSchema:
        """Compute per-item raster statistics clipped to the place geometry.

        Args:
            params: The STAC search result (items with asset URLs) and the
                place result (GeoJSON geometry).

        Returns:
            An output whose ``items`` holds one ``ItemStats`` per search item
            that has an asset URL. When an input is missing, no item has an
            asset URL, or an exception is raised, ``items`` is empty and
            ``error`` describes the problem.
        """
        # Function-scope import so this block does not depend on a change to
        # the module's import section.
        import asyncio

        try:
            stac_result = params.stac_result
            place_result = params.place_result

            if not stac_result or not stac_result.items:
                return STACItemStatsOutputSchema(items=[], error="No items from stac_search. Run stac_search first.")
            if not place_result or not place_result.geometry:
                return STACItemStatsOutputSchema(items=[], error="No geometry from get_place. Run get_place first.")

            # Convert state items to batch format; items without an asset URL
            # are skipped.
            batch_items = [
                {"url": it.asset_url, "id": it.id, "datetime": it.datetime}
                for it in stac_result.items
                if it.asset_url
            ]

            if not batch_items:
                return STACItemStatsOutputSchema(items=[], error="No items with asset URLs found.")

            # fetch_statistics_batch is synchronous (thread pool + blocking
            # HTTP client). Calling it directly from this coroutine would stall
            # the event loop for the whole batch, so run it in a worker thread.
            results = await asyncio.to_thread(
                fetch_statistics_batch,
                items=batch_items,
                geometry=place_result.geometry,
                dst_crs="+proj=cea",
                raster_api_url=self.config.raster_api_url,
            )

            item_stats = [
                ItemStats(
                    id=r.get("id"),
                    datetime=r.get("datetime"),
                    statistics=r.get("statistics", {}),
                    error=r.get("error"),
                )
                for r in results
            ]

            return STACItemStatsOutputSchema(items=item_stats)

        except Exception as e:
            # Tool boundary: report the failure through the output schema
            # instead of raising to the caller.
            return STACItemStatsOutputSchema(items=[], error=str(e))
106 changes: 106 additions & 0 deletions akd_ext/tools/eie/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Utility functions for EIE tools."""

import httpx
import concurrent.futures


def fetch_statistics_batch(
items: list[dict],
geometry: dict,
dst_crs: str = "+proj=cea",
raster_api_url: str | None = None,
timeout: float = 60.0,
) -> list[dict]:
"""Fetch raster statistics for multiple COGs in parallel.

Args:
items: List of dicts with 'url' and optionally 'datetime', 'id'
geometry: GeoJSON geometry (Polygon or MultiPolygon) to clip the raster
dst_crs: Destination CRS for area-weighted stats
raster_api_url: Base URL for the VEDA raster API
timeout: HTTP request timeout in seconds

Returns:
List of dicts with 'url', 'datetime', 'statistics', 'error' for each item
"""
if not items:
return []

def fetch_one(item: dict) -> dict:
url = item.get("url")
item_id = item.get("id")
result = fetch_statistics(
url=url,
geometry=geometry,
dst_crs=dst_crs,
raster_api_url=raster_api_url,
timeout=30.0, # Reduced timeout per item
)
return {
"url": url,
"id": item_id,
"datetime": item.get("datetime"),
"statistics": result.get("statistics", {}),
"error": result.get("error"),
}

# Fetch in parallel (max 5 concurrent)
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
results = list(executor.map(fetch_one, items))

return results


def fetch_statistics(
url: str,
geometry: dict,
dst_crs: str = "+proj=cea",
raster_api_url: str | None = None,
timeout: float = 60.0,
) -> dict:
"""Fetch raster statistics from the VEDA raster API.

Args:
url: URL to the COG file (S3 or HTTP)
geometry: GeoJSON geometry (Polygon or MultiPolygon) to clip the raster
dst_crs: Destination CRS for area-weighted stats (default: Equal Area)
raster_api_url: Base URL for the VEDA raster API
timeout: HTTP request timeout in seconds

Returns:
Dict with 'statistics' (per-band stats) and optional 'error'
"""
if not geometry:
return {"statistics": {}, "error": "geometry is required"}

endpoint = f"{raster_api_url.rstrip('/')}/cog/statistics"

geojson_feature = {
"type": "Feature",
"properties": {},
"geometry": geometry,
}

try:
with httpx.Client(timeout=timeout) as client:
response = client.post(
endpoint,
params={"url": url, "dst_crs": dst_crs},
json=geojson_feature,
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
data = response.json()

# Parse response - VEDA returns {"properties": {"statistics": {...}}}
properties = data.get("properties", data)
stats_data = properties.get("statistics", properties)

return {"statistics": stats_data, "error": None}

except httpx.TimeoutException:
return {"statistics": {}, "error": f"Request timed out after {timeout}s"}
except httpx.HTTPStatusError as e:
return {"statistics": {}, "error": f"HTTP {e.response.status_code}: {e.response.text[:200]}"}
except Exception as e:
return {"statistics": {}, "error": str(e)}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dev = [
"pytest-cov>=6.0.0",
"pre-commit>=4.2.0",
"PyGithub>=2.1.1",
"httpx>=0.27.0"
]

[build-system]
Expand Down