Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ BASED_TEST_DB_URLS='postgresql://postgres:postgres@localhost:5432/postgres,mysql
- [x] CI/CD
- [x] Building and uploading packages to PyPi
- [x] Testing with multiple Python versions
- [ ] Database URL parsing and building
- [x] Database URL parsing and building
- [x] MySQL backend
- [x] Add comments and docstrings
- [x] Add lock for PostgreSQL in `force_rollback` mode and SQLite in both modes
Expand Down
2 changes: 1 addition & 1 deletion based/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.5.0post1"
__version__ = "0.6.0"

from based.backends import Session
from based.database import Database
Expand Down
13 changes: 11 additions & 2 deletions based/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,18 @@ class Backend:
_connected: bool = False
_connected_before: bool = False

def __init__(self, url: str, *, force_rollback: bool = False) -> None:
def __init__(
self,
url: typing.Optional[str] = None,
*,
host: typing.Optional[str] = None,
port: typing.Optional[str] = None,
username: typing.Optional[str] = None,
password: typing.Optional[str] = None,
database: typing.Optional[str] = None,
force_rollback: bool = False,
) -> None:
"""Details of this method should be implementation specific."""
_ = url
self._force_rollback = force_rollback

async def _connect(self) -> None:
Expand Down
28 changes: 25 additions & 3 deletions based/backends/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,32 @@ class MySQL(Backend):
_force_rollback_connection: asyncmy.Connection
_dialect: Dialect

def __init__(self, url: str, *, force_rollback: bool = False) -> None: # noqa: D107
self._url = make_url(url)
def __init__( # noqa: D107
self,
url: typing.Optional[str] = None,
*,
host: typing.Optional[str] = None,
port: typing.Optional[str] = None,
username: typing.Optional[str] = None,
password: typing.Optional[str] = None,
database: typing.Optional[str] = None,
force_rollback: bool = False,
) -> None:
if url:
self._url = make_url(url)
else:
self._url = URL.create(
username=username,
password=password,
host=host,
port=port,
database=database,
drivername="asyncmy",
query={},
)

self._force_rollback = force_rollback
self._dialect = dialect() # type: ignore
self._dialect = dialect()

async def _connect(self) -> None:
self._pool = await asyncmy.create_pool(
Expand Down
37 changes: 34 additions & 3 deletions based/backends/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from psycopg import AsyncConnection
from psycopg_pool import AsyncConnectionPool
from sqlalchemy import URL, make_url
from sqlalchemy.dialects import postgresql
from sqlalchemy.engine.interfaces import Dialect

Expand All @@ -17,10 +18,40 @@ class PostgreSQL(Backend):
_force_rollback_connection: AsyncConnection
_dialect: Dialect

def __init__(self, url: str, *, force_rollback: bool = False) -> None: # noqa: D107
self._pool = AsyncConnectionPool(url, open=False)
def __init__( # noqa: D107
self,
url: typing.Optional[str] = None,
*,
host: typing.Optional[str] = None,
port: typing.Optional[str] = None,
username: typing.Optional[str] = None,
password: typing.Optional[str] = None,
database: typing.Optional[str] = None,
force_rollback: bool = False,
) -> None:
if url:
self._url = make_url(url)
else:
self._url = URL.create(
username=username,
password=password,
host=host,
port=port,
database=database,
drivername="psycopg",
query={},
)

conninfo = (
f"user={self._url.username} "
f"password={self._url.password} "
f"host={self._url.host} "
f"port={self._url.port} "
f"dbname={self._url.database}"
)
self._pool = AsyncConnectionPool(conninfo, open=False)
self._force_rollback = force_rollback
self._dialect = postgresql.dialect() # type: ignore
self._dialect = postgresql.dialect()

async def _connect(self) -> None:
await self._pool.open()
Expand Down
2 changes: 1 addition & 1 deletion based/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class SQLite(Backend):
def __init__(self, url: str, *, force_rollback: bool = False) -> None: # noqa: D107
self._conn = connect(url, isolation_level=None)
self._force_rollback = force_rollback
self._dialect = sqlite.dialect() # type: ignore
self._dialect = sqlite.dialect()

async def _connect(self) -> None:
await self._conn
Expand Down
54 changes: 45 additions & 9 deletions based/database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from asyncio import Lock
from contextlib import asynccontextmanager
from types import TracebackType
from typing import AsyncGenerator, Optional, Type
from typing import AsyncGenerator, Literal, Optional, Type

from based.backends import Backend, Session

Expand All @@ -15,8 +15,14 @@ class Database:

def __init__(
self,
url: str,
url: Optional[str] = None,
*,
host: Optional[str] = None,
port: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
database: Optional[str] = None,
schema: Optional[Literal["postgresql", "mysql", "sqlite"]] = None,
force_rollback: bool = False,
use_lock: bool = False,
) -> None:
Expand All @@ -36,7 +42,20 @@ def __init__(
Args:
url:
Database URL should be a URL defined by RFC 1738, containing the correct
schema like `postgresql://user:password@host:port/database`.
schema like `postgresql://user:password@host:port/database`. Can be
omitted in favor of passing parameters separately.
username:
Database username.
password:
Database password.
host:
Database host.
port:
Database port.
database:
Database name.
schema:
Used database schema. Can be `postgresql` or `mysql`.
force_rollback:
If this flag is set to True, then all the queries to the database will
be made in one single transaction which will be rolled back when the
Expand All @@ -53,10 +72,11 @@ def __init__(
Can be raised when an invalid database URL is provided or the database
schema is not supported.
"""
url_parts = url.split("://")
if len(url_parts) != 2:
raise ValueError("Invalid database URL")
schema = url_parts[0]
if url is not None:
url_parts = url.split("://")
if len(url_parts) != 2:
raise ValueError("Invalid database URL")
schema = url_parts[0]

if use_lock and (force_rollback or schema == "sqlite"):
self._lock = Lock()
Expand All @@ -72,11 +92,27 @@ def __init__(
elif schema == "postgresql":
from based.backends.postgresql import PostgreSQL

self._backend = PostgreSQL(url, force_rollback=force_rollback)
self._backend = PostgreSQL(
url=url,
username=username,
password=password,
host=host,
port=port,
database=database,
force_rollback=force_rollback,
)
elif schema == "mysql":
from based.backends.mysql import MySQL

self._backend = MySQL(url, force_rollback=force_rollback)
self._backend = MySQL(
url=url,
username=username,
password=password,
host=host,
port=port,
database=database,
force_rollback=force_rollback,
)
else:
raise ValueError(f"Unknown database schema: {schema}")

Expand Down
21 changes: 21 additions & 0 deletions tests/test_mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import sqlalchemy as sa

import based


async def test_mysql_url_building(database_url: str):
if not database_url.startswith("mysql"):
return

url = sa.make_url(database_url)

async with based.Database(
username=url.username,
password=url.password,
host=url.host,
port=url.port,
database=url.database,
schema="mysql",
) as database:
async with database.session() as session:
await session.execute("SELECT 1;")
21 changes: 21 additions & 0 deletions tests/test_postgresql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import sqlalchemy as sa

import based


async def test_postgresql_url_building(database_url: str):
if not database_url.startswith("postgresql"):
return

url = sa.make_url(database_url)

async with based.Database(
username=url.username,
password=url.password,
host=url.host,
port=url.port,
database=url.database,
schema="postgresql",
) as database:
async with database.session() as session:
await session.execute("SELECT 1;")