Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ global-exclude *.pyc
global-exclude *.pyo
global-exclude __pycache__
global-exclude .git*
global-exclude .github
global-exclude .pytest_cache
global-exclude .coverage
global-exclude htmlcov
Expand Down
57 changes: 57 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ PyMongoSQL implements the DB API 2.0 interfaces to provide SQL-like access to Mo
- **DB API 2.0 Compliant**: Full compatibility with Python Database API 2.0 specification
- **PartiQL-based SQL Syntax**: Built on [PartiQL](https://partiql.org/tutorial.html) (SQL for semi-structured data), enabling seamless SQL querying of nested and hierarchical MongoDB documents
- **Nested Structure Support**: Query and filter deeply nested fields and arrays within MongoDB documents using standard SQL syntax
- **MongoDB Aggregate Pipeline Support**: Execute native MongoDB aggregation pipelines using SQL-like syntax with `aggregate()` function
- **SQLAlchemy Integration**: Complete ORM and Core support with dedicated MongoDB dialect
- **SQL Query Support**: SELECT statements with WHERE conditions, field selection, and aliases
- **DML Support**: Full support for INSERT, UPDATE, and DELETE operations using PartiQL syntax
Expand Down Expand Up @@ -80,6 +81,7 @@ pip install -e .
- [WHERE Clauses](#where-clauses)
- [Nested Field Support](#nested-field-support)
- [Sorting and Limiting](#sorting-and-limiting)
- [MongoDB Aggregate Function](#mongodb-aggregate-function)
- [INSERT Statements](#insert-statements)
- [UPDATE Statements](#update-statements)
- [DELETE Statements](#delete-statements)
Expand Down Expand Up @@ -235,6 +237,61 @@ Parameters are substituted into the MongoDB filter during execution, providing p
- **LIMIT**: `LIMIT 10`
- **Combined**: `ORDER BY created_at DESC LIMIT 5`

### MongoDB Aggregate Function

PyMongoSQL supports executing native MongoDB aggregation pipelines using SQL-like syntax with the `aggregate()` function. This allows you to leverage MongoDB's powerful aggregation framework while maintaining SQL-style query patterns.

**Syntax**

The `aggregate()` function accepts two parameters:
- **pipeline**: JSON string representing the MongoDB aggregation pipeline
- **options**: JSON string for aggregation options (optional, use '{}' for defaults)

**Qualified Aggregate (Collection-Specific)**

```python
cursor.execute(
"SELECT * FROM users.aggregate('[{\"$match\": {\"age\": {\"$gt\": 25}}}, {\"$group\": {\"_id\": \"$city\", \"count\": {\"$sum\": 1}}}]', '{}')"
)
results = cursor.fetchall()
```

**Unqualified Aggregate (Database-Level)**

```python
cursor.execute(
"SELECT * FROM aggregate('[{\"$match\": {\"status\": \"active\"}}]', '{\"allowDiskUse\": true}')"
)
results = cursor.fetchall()
```

**Post-Aggregation Filtering and Sorting**

You can apply WHERE, ORDER BY, and LIMIT clauses after aggregation:

```python
# Filter aggregation results
cursor.execute(
"SELECT * FROM users.aggregate('[{\"$group\": {\"_id\": \"$city\", \"total\": {\"$sum\": 1}}}]', '{}') WHERE total > 100"
)

# Sort and limit aggregation results
cursor.execute(
"SELECT * FROM products.aggregate('[{\"$match\": {\"category\": \"Electronics\"}}]', '{}') ORDER BY price DESC LIMIT 10"
)
```

**Projection Support**

```python
# Select specific fields from aggregation results
cursor.execute(
"SELECT _id, total FROM users.aggregate('[{\"$group\": {\"_id\": \"$city\", \"total\": {\"$sum\": 1}}}]', '{}')"
)
```

**Note**: The pipeline and options must be valid JSON strings enclosed in single quotes. Post-aggregation filtering (WHERE), sorting (ORDER BY), and limiting (LIMIT) are applied in Python after the aggregation executes on MongoDB.

### INSERT Statements

PyMongoSQL supports inserting documents into MongoDB collections using both PartiQL-style object literals and standard SQL INSERT VALUES syntax.
Expand Down
2 changes: 1 addition & 1 deletion pymongosql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
if TYPE_CHECKING:
from .connection import Connection

__version__: str = "0.3.3"
__version__: str = "0.3.4"

# Globals https://www.python.org/dev/peps/pep-0249/#globals
apilevel: str = "2.0"
Expand Down
165 changes: 163 additions & 2 deletions pymongosql/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _replace_placeholders(self, obj: Any, parameters: Sequence[Any]) -> Any:
"""Recursively replace ? placeholders with parameter values in filter/projection dicts"""
return SQLHelper.replace_placeholders_generic(obj, parameters, "qmark")

def _execute_execution_plan(
def _execute_find_plan(
self,
execution_plan: QueryExecutionPlan,
connection: Any = None,
Expand Down Expand Up @@ -172,6 +172,163 @@ def _execute_execution_plan(
_logger.error(f"Unexpected error during command execution: {e}")
raise OperationalError(f"Command execution error: {e}")

def _execute_aggregate_plan(
    self,
    execution_plan: QueryExecutionPlan,
    connection: Any = None,
    parameters: Optional[Sequence[Any]] = None,
) -> Optional[Dict[str, Any]]:
    """Execute a QueryExecutionPlan that represents an ``aggregate()`` call.

    The MongoDB aggregation pipeline runs server-side; any WHERE / ORDER BY /
    OFFSET / LIMIT / projection clauses from the SQL statement are then
    applied to the results in Python.

    Args:
        execution_plan: Plan carrying ``aggregate_pipeline`` and
            ``aggregate_options`` as JSON strings.
        connection: Connection object providing the ``database`` handle.
        parameters: Parameters for placeholder replacement (unused here;
            substitution happens before parsing).

    Returns:
        Command-style result dict: ``{"cursor": {"firstBatch": [...]}, "ok": 1}``.

    Raises:
        ProgrammingError: Missing collection or invalid JSON pipeline/options.
        OperationalError: Missing connection or unexpected execution failure.
        DatabaseError: The MongoDB aggregate command itself failed.
    """
    import json

    try:
        if not connection:
            raise OperationalError("No connection provided")

        db = connection.database

        if not execution_plan.collection:
            raise ProgrammingError("No collection specified in aggregate query")

        # Pipeline/options arrive as JSON strings from the SQL parser.
        try:
            pipeline = json.loads(execution_plan.aggregate_pipeline or "[]")
            options = json.loads(execution_plan.aggregate_options or "{}")
        except json.JSONDecodeError as e:
            raise ProgrammingError(f"Invalid JSON in aggregate pipeline or options: {e}")

        _logger.debug(f"Executing aggregate on collection {execution_plan.collection}")
        _logger.debug(f"Pipeline: {pipeline}")
        _logger.debug(f"Options: {options}")

        collection = db[execution_plan.collection]

        # Run the pipeline server-side; options (e.g. allowDiskUse) pass through.
        results = list(collection.aggregate(pipeline, **options))

        # WHERE clause: filtered in Python because the aggregation already ran.
        if execution_plan.filter_stage:
            _logger.debug(f"Applying additional filter: {execution_plan.filter_stage}")
            results = self._filter_results(results, execution_plan.filter_stage)

        # ORDER BY: apply sort keys in reverse order so the first key wins
        # (Python's stable sort preserves earlier orderings on ties).
        if execution_plan.sort_stage:
            for sort_dict in reversed(execution_plan.sort_stage):
                for field_name, direction in sort_dict.items():
                    # None-safe key: documents missing the sort field group
                    # together instead of raising TypeError on None comparison.
                    results = sorted(
                        results,
                        key=lambda doc, f=field_name: (doc.get(f) is None, doc.get(f)),
                        reverse=(direction == -1),
                    )

        # OFFSET / LIMIT: explicit None checks so LIMIT 0 is honored
        # (a truthiness test would silently ignore a legitimate LIMIT 0).
        if execution_plan.skip_stage is not None:
            results = results[execution_plan.skip_stage :]

        if execution_plan.limit_stage is not None:
            results = results[: execution_plan.limit_stage]

        # SELECT column list.
        if execution_plan.projection_stage:
            results = self._apply_projection(results, execution_plan.projection_stage)

        # Mimic the MongoDB command-result shape used by the find path.
        return {
            "cursor": {"firstBatch": results},
            "ok": 1,
        }

    except (ProgrammingError, OperationalError):
        raise
    except PyMongoError as e:
        _logger.error(f"MongoDB aggregate execution failed: {e}")
        raise DatabaseError(f"Aggregate execution failed: {e}")
    except Exception as e:
        _logger.error(f"Unexpected error during aggregate execution: {e}")
        raise OperationalError(f"Aggregate execution error: {e}")

@staticmethod
def _filter_results(results: list, filter_conditions: dict) -> list:
    """Return only the documents in *results* that satisfy *filter_conditions*.

    Filtering happens in Python (post-aggregation); the per-document check
    is delegated to ``_matches_filter``, which supports a subset of the
    MongoDB query operators.
    """
    # Keep every document whose fields satisfy all filter conditions.
    return [
        document
        for document in results
        if StandardQueryExecution._matches_filter(document, filter_conditions)
    ]

@staticmethod
def _matches_filter(doc: dict, filter_conditions: dict) -> bool:
"""Check if a document matches the filter conditions"""
for field, condition in filter_conditions.items():
if field == "$and":
return all(StandardQueryExecution._matches_filter(doc, cond) for cond in condition)
elif field == "$or":
return any(StandardQueryExecution._matches_filter(doc, cond) for cond in condition)
elif isinstance(condition, dict):
# Handle operators like $eq, $gt, etc.
for op, value in condition.items():
if op == "$eq":
if doc.get(field) != value:
return False
elif op == "$ne":
if doc.get(field) == value:
return False
elif op == "$gt":
if not (doc.get(field) > value):
return False
elif op == "$gte":
if not (doc.get(field) >= value):
return False
elif op == "$lt":
if not (doc.get(field) < value):
return False
elif op == "$lte":
if not (doc.get(field) <= value):
return False
else:
if doc.get(field) != condition:
return False
return True

@staticmethod
def _apply_projection(results: list, projection_stage: dict) -> list:
"""Apply projection to results"""
projected = []
include_fields = {k for k, v in projection_stage.items() if v == 1}
exclude_fields = {k for k, v in projection_stage.items() if v == 0}

for doc in results:
if include_fields:
# Include mode: only include specified fields
projected_doc = (
{"_id": doc.get("_id")} if "_id" in include_fields or "_id" not in projection_stage else {}
)
for field in include_fields:
if field != "_id" and field in doc:
projected_doc[field] = doc[field]
projected.append(projected_doc)
else:
# Exclude mode: exclude specified fields
projected_doc = {k: v for k, v in doc.items() if k not in exclude_fields}
projected.append(projected_doc)

return projected

def execute(
self,
context: ExecutionContext,
Expand All @@ -197,7 +354,11 @@ def execute(
# Parse the query
self._execution_plan = self._parse_sql(processed_query)

return self._execute_execution_plan(self._execution_plan, connection, processed_params)
# Route to appropriate execution plan handler
if hasattr(self._execution_plan, "is_aggregate_query") and self._execution_plan.is_aggregate_query:
return self._execute_aggregate_plan(self._execution_plan, connection, processed_params)
else:
return self._execute_find_plan(self._execution_plan, connection, processed_params)


class InsertExecution(ExecutionStrategy):
Expand Down
10 changes: 9 additions & 1 deletion pymongosql/sql/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,15 @@ def _build_query_plan(parse_result: "QueryParseResult") -> "QueryExecutionPlan":
parse_result.column_aliases
).sort(parse_result.sort_fields).limit(parse_result.limit_value).skip(parse_result.offset_value)

return builder.build()
# Set aggregate flags BEFORE building (needed for validation)
if hasattr(parse_result, "is_aggregate_query") and parse_result.is_aggregate_query:
builder._execution_plan.is_aggregate_query = True
builder._execution_plan.aggregate_pipeline = parse_result.aggregate_pipeline
builder._execution_plan.aggregate_options = parse_result.aggregate_options

# Now build and validate
plan = builder.build()
return plan

@staticmethod
def _build_insert_plan(parse_result: "InsertParseResult") -> "InsertExecutionPlan":
Expand Down
28 changes: 25 additions & 3 deletions pymongosql/sql/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ class QueryExecutionPlan(ExecutionPlan):
sort_stage: List[Dict[str, int]] = field(default_factory=list)
limit_stage: Optional[int] = None
skip_stage: Optional[int] = None
# Aggregate pipeline support
aggregate_pipeline: Optional[str] = None # JSON string representation of pipeline
aggregate_options: Optional[str] = None # JSON string representation of options
is_aggregate_query: bool = False # Flag indicating this is an aggregate() call

def to_dict(self) -> Dict[str, Any]:
"""Convert query plan to dictionary representation"""
return {
result = {
"collection": self.collection,
"filter": self.filter_stage,
"projection": self.projection_stage,
Expand All @@ -30,9 +34,22 @@ def to_dict(self) -> Dict[str, Any]:
"skip": self.skip_stage,
}

# Add aggregate-specific fields if present
if self.is_aggregate_query:
result["is_aggregate_query"] = True
result["aggregate_pipeline"] = self.aggregate_pipeline
result["aggregate_options"] = self.aggregate_options

return result

def validate(self) -> bool:
"""Validate the query plan"""
errors = self.validate_base()
# For aggregate queries, collection is optional (unqualified aggregate syntax)
# For regular queries, collection is required
if self.is_aggregate_query:
errors = []
else:
errors = self.validate_base()

if self.limit_stage is not None and (not isinstance(self.limit_stage, int) or self.limit_stage < 0):
errors.append("Limit must be a non-negative integer")
Expand All @@ -56,6 +73,9 @@ def copy(self) -> "QueryExecutionPlan":
sort_stage=self.sort_stage.copy(),
limit_stage=self.limit_stage,
skip_stage=self.skip_stage,
aggregate_pipeline=self.aggregate_pipeline,
aggregate_options=self.aggregate_options,
is_aggregate_query=self.is_aggregate_query,
)


Expand Down Expand Up @@ -217,7 +237,9 @@ def validate(self) -> bool:
"""Validate the current query plan"""
self._validation_errors.clear()

if not self._execution_plan.collection:
# For aggregate queries, collection is optional (unqualified aggregate syntax)
# For regular queries, collection is required
if not self._execution_plan.is_aggregate_query and not self._execution_plan.collection:
self._add_error("Collection name is required")

# Add more validation rules as needed
Expand Down
Loading