Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: CI Tests

on:
push:
branches: [ main ]
branches: [ main, "*.*.*" ]
pull_request:
branches: [ main ]
workflow_call:
Expand Down
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# PyMongoSQL

[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Test](https://github.com/passren/PyMongoSQL/actions/workflows/ci.yml/badge.svg)](https://github.com/passren/PyMongoSQL/actions/workflows/ci.yml)
[![Code Style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![License: MIT](https://img.shields.io/badge/License-MIT-purple.svg)](https://github.com/passren/PyMongoSQL/blob/0.1.2/LICENSE)
[![Python Version](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
[![MongoDB](https://img.shields.io/badge/MongoDB-4.0+-green.svg)](https://www.mongodb.com/)
[![MongoDB](https://img.shields.io/badge/MongoDB-7.0+-green.svg)](https://www.mongodb.com/)

PyMongoSQL is a Python [DB API 2.0 (PEP 249)](https://www.python.org/dev/peps/pep-0249/) client for [MongoDB](https://www.mongodb.com/). It provides a familiar SQL interface to MongoDB, allowing developers to use SQL queries to interact with MongoDB collections.

Expand Down
46 changes: 41 additions & 5 deletions pymongosql/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ def __init__(
else:
# Just create the client without testing connection
self._client = MongoClient(**self._pymongo_params)
if self._database_name:
self._database = self._client[self._database_name]
# Initialize the database according to explicit parameter or client's default
self._init_database()

def _connect(self) -> None:
"""Establish connection to MongoDB"""
Expand All @@ -91,19 +91,55 @@ def _connect(self) -> None:
# Test connection
self._client.admin.command("ping")

# Set database if specified
if self._database_name:
self._database = self._client[self._database_name]
# Initialize the database according to explicit parameter or client's default
# This may raise OperationalError if no database could be determined; allow it to bubble up
self._init_database()

_logger.info(f"Successfully connected to MongoDB at {self._host}:{self._port}")

except OperationalError:
# Allow OperationalError (e.g., no database selected) to propagate unchanged
raise
except ConnectionFailure as e:
_logger.error(f"Failed to connect to MongoDB: {e}")
raise OperationalError(f"Could not connect to MongoDB: {e}")
except Exception as e:
_logger.error(f"Unexpected error during connection: {e}")
raise DatabaseError(f"Database connection error: {e}")

def _init_database(self) -> None:
"""Internal helper to initialize `self._database`.

Behavior:
- If `database` parameter was provided explicitly, use that database name.
- Otherwise, try to use the MongoClient's default database (from the URI path).
If no default is set, leave `self._database` as None.
"""
if self._client is None:
self._database = None
return

if self._database_name is not None:
# Explicit database parameter takes precedence
try:
self._database = self._client.get_database(self._database_name)
except Exception:
# Fallback to subscription style access
self._database = self._client[self._database_name]
else:
# No explicit database; try to get client's default
try:
self._database = self._client.get_default_database()
except Exception:
# PyMongo can raise various exceptions for missing database
self._database = None

# Enforce that a database must be selected
if self._database is None:
raise OperationalError(
"No database selected. Provide 'database' parameter or include a database in the URI path."
)

@property
def client(self) -> MongoClient:
"""Get the PyMongo client"""
Expand Down
63 changes: 32 additions & 31 deletions pymongosql/cursor.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# -*- coding: utf-8 -*-
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypeVar
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, TypeVar

from pymongo.cursor import Cursor as MongoCursor
from pymongo.errors import PyMongoError

from .common import BaseCursor, CursorIterator
from .error import DatabaseError, OperationalError, ProgrammingError, SqlSyntaxError
from .result_set import ResultSet
from .sql.builder import QueryPlan
from .sql.builder import ExecutionPlan
from .sql.parser import SQLParser

if TYPE_CHECKING:
Expand All @@ -31,7 +31,7 @@ def __init__(self, connection: "Connection", **kwargs) -> None:
self._kwargs = kwargs
self._result_set: Optional[ResultSet] = None
self._result_set_class = ResultSet
self._current_query_plan: Optional[QueryPlan] = None
self._current_execution_plan: Optional[ExecutionPlan] = None
self._mongo_cursor: Optional[MongoCursor] = None
self._is_closed = False

Expand Down Expand Up @@ -78,65 +78,66 @@ def _check_closed(self) -> None:
if self._is_closed:
raise ProgrammingError("Cursor is closed")

def _parse_sql(self, sql: str) -> QueryPlan:
"""Parse SQL statement and return QueryPlan"""
def _parse_sql(self, sql: str) -> ExecutionPlan:
"""Parse SQL statement and return ExecutionPlan"""
try:
parser = SQLParser(sql)
query_plan = parser.get_query_plan()
execution_plan = parser.get_execution_plan()

if not query_plan.validate():
if not execution_plan.validate():
raise SqlSyntaxError("Generated query plan is invalid")

return query_plan
return execution_plan

except SqlSyntaxError:
raise
except Exception as e:
_logger.error(f"SQL parsing failed: {e}")
raise SqlSyntaxError(f"Failed to parse SQL: {e}")

def _execute_query_plan(self, query_plan: QueryPlan) -> None:
"""Execute a QueryPlan against MongoDB using db.command"""
def _execute_execution_plan(self, execution_plan: ExecutionPlan) -> None:
"""Execute an ExecutionPlan against MongoDB using db.command"""
try:
# Get database
if not query_plan.collection:
if not execution_plan.collection:
raise ProgrammingError("No collection specified in query")

db = self.connection.database

# Build MongoDB find command
find_command = {"find": query_plan.collection, "filter": query_plan.filter_stage or {}}
find_command = {"find": execution_plan.collection, "filter": execution_plan.filter_stage or {}}

# Convert projection stage from alias mapping to MongoDB format
if query_plan.projection_stage:
# Convert {"field": "alias"} to {"field": 1} for MongoDB
find_command["projection"] = {field: 1 for field in query_plan.projection_stage.keys()}
# Apply projection if specified (already in MongoDB format)
if execution_plan.projection_stage:
find_command["projection"] = execution_plan.projection_stage

# Apply sort if specified
if query_plan.sort_stage:
if execution_plan.sort_stage:
sort_spec = {}
for sort_dict in query_plan.sort_stage:
for sort_dict in execution_plan.sort_stage:
for field, direction in sort_dict.items():
sort_spec[field] = direction
find_command["sort"] = sort_spec

# Apply skip if specified
if query_plan.skip_stage:
find_command["skip"] = query_plan.skip_stage
if execution_plan.skip_stage:
find_command["skip"] = execution_plan.skip_stage

# Apply limit if specified
if query_plan.limit_stage:
find_command["limit"] = query_plan.limit_stage
if execution_plan.limit_stage:
find_command["limit"] = execution_plan.limit_stage

_logger.debug(f"Executing MongoDB command: {find_command}")

# Execute find command directly
result = db.command(find_command)

# Create result set from command result
self._result_set = self._result_set_class(command_result=result, query_plan=query_plan, **self._kwargs)
self._result_set = self._result_set_class(
command_result=result, execution_plan=execution_plan, **self._kwargs
)

_logger.info(f"Query executed successfully on collection '{query_plan.collection}'")
_logger.info(f"Query executed successfully on collection '{execution_plan.collection}'")

except PyMongoError as e:
_logger.error(f"MongoDB command execution failed: {e}")
Expand All @@ -161,11 +162,11 @@ def execute(self: _T, operation: str, parameters: Optional[Dict[str, Any]] = Non
_logger.warning("Parameter substitution not yet implemented, ignoring parameters")

try:
# Parse SQL to QueryPlan
self._current_query_plan = self._parse_sql(operation)
# Parse SQL to ExecutionPlan
self._current_execution_plan = self._parse_sql(operation)

# Execute the query plan
self._execute_query_plan(self._current_query_plan)
# Execute the execution plan
self._execute_execution_plan(self._current_execution_plan)

return self

Expand Down Expand Up @@ -205,7 +206,7 @@ def flush(self) -> None:
# For now, this is a no-op
pass

def fetchone(self) -> Optional[Dict[str, Any]]:
def fetchone(self) -> Optional[Sequence[Any]]:
"""Fetch the next row from the result set"""
self._check_closed()

Expand All @@ -214,7 +215,7 @@ def fetchone(self) -> Optional[Dict[str, Any]]:

return self._result_set.fetchone()

def fetchmany(self, size: Optional[int] = None) -> List[Dict[str, Any]]:
def fetchmany(self, size: Optional[int] = None) -> List[Sequence[Any]]:
"""Fetch multiple rows from the result set"""
self._check_closed()

Expand All @@ -223,7 +224,7 @@ def fetchmany(self, size: Optional[int] = None) -> List[Dict[str, Any]]:

return self._result_set.fetchmany(size)

def fetchall(self) -> List[Dict[str, Any]]:
def fetchall(self) -> List[Sequence[Any]]:
"""Fetch all remaining rows from the result set"""
self._check_closed()

Expand Down
Loading