Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.3.0
current_version = 0.3.1
commit = True
tag = True

Expand Down
125 changes: 125 additions & 0 deletions benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from sqlalchemy import create_engine, Column, Integer, String, Boolean, select, Index, update, delete
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy_memory import MemorySession
import argparse
import time
import random
from faker import Faker

try:
from sqlalchemy_memory import create_memory_engine
except ImportError:
create_memory_engine = None

Base = declarative_base()
fake = Faker()
CATEGORIES = list("ABCDEFGHIJK")

class Item(Base):
    """Benchmark row model mapped to the ``items`` table."""

    __tablename__ = "items"

    # Integer primary key; generate_items() never assigns it explicitly.
    id = Column(Integer, primary_key=True)
    name = Column(String)
    # Both columns below are indexed; they are the only columns the
    # randomly generated select queries filter on.
    active = Column(Boolean, index=True)
    category = Column(String, index=True)

def generate_items(n):
    """Lazily yield ``n`` unsaved ``Item`` objects filled with random data.

    Every item receives a Faker-generated name, a random ``active`` flag and
    a random category taken from ``CATEGORIES``.
    """
    for _ in range(n):
        # Draw the values in the same order as the keyword arguments so the
        # sequence of random calls is unchanged.
        item_name = fake.name()
        is_active = random.choice([True, False])
        item_category = random.choice(CATEGORIES)
        yield Item(name=item_name, active=is_active, category=item_category)

def generate_random_select_query():
    """Build a ``select(Item)`` statement with one or two random filters.

    With probability 0.5 the query filters on ``Item.active``. A category
    ``IN`` filter (1 to 4 distinct categories) is added with probability 0.5,
    and always when no ``active`` filter was chosen, so the statement never
    ends up without a WHERE clause.
    """
    conditions = []

    filter_on_active = random.random() < 0.5
    if filter_on_active:
        wanted_flag = random.choice([True, False])
        conditions.append(Item.active == wanted_flag)

    # The second draw always happens, even when the category filter is forced.
    filter_on_category = random.random() < 0.5
    if filter_on_category or not conditions:
        how_many = random.randint(1, 4)
        chosen_categories = random.sample(CATEGORIES, how_many)
        conditions.append(Item.category.in_(chosen_categories))

    statement = select(Item).where(*conditions)
    return statement

def inserts(Session, count):
    """Insert ``count`` generated items in a single session and time it.

    Args:
        Session: Callable returning a context-manager session that provides
            ``add_all`` and ``commit``.
        count: Number of items to generate and insert.

    Returns:
        Elapsed seconds as a float, covering session creation, the bulk add,
        the commit and session teardown.
    """
    # perf_counter() is monotonic and high resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-benchmark.
    insert_start = time.perf_counter()
    with Session() as session:
        session.add_all(generate_items(count))
        session.commit()
    insert_duration = time.perf_counter() - insert_start
    print(f"Inserted {count} items in {insert_duration:.2f} seconds.")
    return insert_duration

def selects(Session, count):
    """Run ``count`` randomly generated select queries and time them.

    Args:
        Session: Callable returning a context-manager session that provides
            ``execute``.
        count: Number of select statements to generate and execute.

    Returns:
        Elapsed seconds as a float. Statement construction happens before the
        clock starts, so only execution and result fetching are measured.
    """
    queries = [generate_random_select_query() for _ in range(count)]

    # perf_counter() is monotonic and high resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-benchmark.
    query_start = time.perf_counter()
    with Session() as session:
        for stmt in queries:
            # Materialize the rows so the cost of fetching results is included.
            list(session.execute(stmt).scalars())
    query_duration = time.perf_counter() - query_start
    print(f"Executed {count} select queries in {query_duration:.2f} seconds.")
    return query_duration

def updates(Session, random_ids):
    """Update one row per id in ``random_ids`` with fresh random values.

    Args:
        Session: Callable returning a context-manager session that provides
            ``execute`` and ``commit``.
        random_ids: Sized iterable of primary-key values to update.

    Returns:
        Elapsed seconds as a float, including the single commit issued after
        all update statements have been executed.
    """
    # perf_counter() is monotonic and high resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-benchmark.
    update_start = time.perf_counter()
    with Session() as session:
        for rid in random_ids:
            stmt = update(Item).where(Item.id == rid).values(
                name=fake.name(),
                category=random.choice(CATEGORIES),
                active=random.choice([True, False]),
            )
            session.execute(stmt)
        session.commit()
    update_duration = time.perf_counter() - update_start
    print(f"Executed {len(random_ids)} updates in {update_duration:.2f} seconds.")
    return update_duration

def deletes(Session, random_ids):
    """Delete one row per id in ``random_ids`` and time the operation.

    Args:
        Session: Callable returning a context-manager session that provides
            ``execute`` and ``commit``.
        random_ids: Sized iterable of primary-key values to delete.

    Returns:
        Elapsed seconds as a float, including the single commit issued after
        all delete statements have been executed.
    """
    # perf_counter() is monotonic and high resolution; time.time() follows the
    # wall clock and can jump if the system clock is adjusted mid-benchmark.
    delete_start = time.perf_counter()
    with Session() as session:
        for rid in random_ids:
            stmt = delete(Item).where(Item.id == rid)
            session.execute(stmt)
        session.commit()
    delete_duration = time.perf_counter() - delete_start
    print(f"Deleted {len(random_ids)} items in {delete_duration:.2f} seconds.")
    return delete_duration

def run_benchmark(db_type="sqlite", count=100_000, operations=500):
    """Run the insert/select/update/delete benchmark against one backend.

    Args:
        db_type: ``"sqlite"`` for an in-memory SQLite engine or ``"memory"``
            for the sqlalchemy-memory engine.
        count: Number of items to insert before the other phases run.
        operations: Number of select queries to run, and the upper bound on
            the number of rows updated and deleted. Defaults to 500, the
            value that used to be hard-coded.

    Raises:
        ValueError: If ``db_type`` is neither ``"sqlite"`` nor ``"memory"``.
    """
    print(f"Running benchmark: type={db_type}, count={count}")

    if db_type == "sqlite":
        engine = create_engine("sqlite:///:memory:", echo=False)
        Session = sessionmaker(engine)
    elif db_type == "memory":
        engine = create_engine("memory://")
        Session = sessionmaker(
            engine,
            class_=MemorySession,
            expire_on_commit=False,
        )
    else:
        raise ValueError("Invalid --type. Use 'sqlite' or 'memory'.")

    Base.metadata.create_all(engine)

    elapsed = inserts(Session, count)
    elapsed += selects(Session, operations)

    # random.sample() raises ValueError when asked for more elements than the
    # population holds, so small datasets (count < operations) must clamp.
    sample_size = max(0, min(operations, count))
    # NOTE(review): assumes generated primary keys are 1..count - confirm for
    # both backends.
    id_pool = range(1, count + 1)

    elapsed += updates(Session, random.sample(id_pool, sample_size))
    # A fresh sample is drawn for deletes, independent of the updated ids.
    elapsed += deletes(Session, random.sample(id_pool, sample_size))

    print(f"Total runtime for {db_type}: {elapsed:.2f} seconds.")



if __name__ == "__main__":
    # Command-line entry point: choose the backend and dataset size, then run.
    cli = argparse.ArgumentParser()
    cli.add_argument("--type", choices=["sqlite", "memory"], required=True)
    cli.add_argument("--count", type=int, default=10_000)
    options = cli.parse_args()
    run_benchmark(db_type=options.type, count=options.count)
29 changes: 29 additions & 0 deletions docs/benchmarks.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
Benchmark Comparison (20,000 items)
===================================

This benchmark compares `sqlalchemy-memory` to `in-memory SQLite` using 20,000 inserted items, followed by 500 select queries, 500 updates, and 500 deletions.

As the results show, `sqlalchemy-memory` **excels in read-heavy workloads**, delivering significantly faster query performance. SQLite is clearly faster on update and delete operations (roughly 4x and 10x in this run), but because select queries dominate the workload, the overall runtime of `sqlalchemy-memory` remains substantially lower, making it a strong choice for prototyping and simulation.

.. list-table::
   :header-rows: 1
   :widths: 25 25 25

   * - Operation
     - SQLite (in-memory)
     - sqlalchemy-memory
   * - Insert
     - 3.17 sec
     - 2.70 sec
   * - 500 Select Queries
     - 26.37 sec
     - 2.94 sec
   * - 500 Updates
     - 0.26 sec
     - 1.12 sec
   * - 500 Deletes
     - 0.09 sec
     - 0.90 sec
   * - **Total Runtime**
     - **29.89 sec**
     - **7.66 sec**
3 changes: 2 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,5 @@ Quickstart: async example
query
update
delete
commit_rollback
commit_rollback
benchmarks
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sqlalchemy-memory"
version = "0.3.0"
version = "0.3.1"
dependencies = [
"sqlalchemy>=2.0,<3.0",
"sortedcontainers>=2.4.0"
Expand Down
2 changes: 1 addition & 1 deletion sqlalchemy_memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
"AsyncMemorySession",
]

__version__ = '0.3.0'
__version__ = '0.3.1'
48 changes: 36 additions & 12 deletions sqlalchemy_memory/base/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@
from typing import Any, List
from sqlalchemy.sql import operators

from ..helpers.ordered_set import OrderedSet


class IndexManager:
__slots__ = ('hash_index', 'range_index', 'table_indexes', 'columns_mapping', )

def __init__(self):
self.hash_index = HashIndex()
self.range_index = RangeIndex()

self.table_indexes = {}
self.columns_mapping = {}


def get_indexes(self, obj):
"""
Retrieve index from object's table as dict: indexname => list of column name
Expand All @@ -21,18 +26,27 @@ def get_indexes(self, obj):
if tablename not in self.table_indexes:
self.table_indexes[tablename] = {}

pk_col_name = obj.__table__.primary_key.columns[0].name

for index in obj.__table__.indexes:
if len(index.expressions) > 1:
# Ignoring compound indexes for now ...
continue

if index.name == pk_col_name:
pk_col_name = None

self.table_indexes[tablename][index.name] = [
col.name
for col in index.expressions
]

if pk_col_name:
self.table_indexes[tablename][pk_col_name] = [pk_col_name]

return self.table_indexes[tablename]


def _column_to_index(self, tablename, colname):
"""
Get index name from tablename & column name
Expand All @@ -51,6 +65,7 @@ def _column_to_index(self, tablename, colname):

return self.columns_mapping[tablename][colname]


def _get_index_key(self, obj, columns):
if len(columns) == 1:
return getattr(obj, columns[0])
Expand All @@ -65,7 +80,7 @@ def on_insert(self, obj):

self.hash_index.add(tablename, indexname, value, obj)
self.range_index.add(tablename, indexname, value, obj)

def on_delete(self, obj):
tablename = obj.__tablename__
indexes = self.get_indexes(obj)
Expand Down Expand Up @@ -145,6 +160,7 @@ def query(self, collection, tablename, colname, operator, value):
in_range = self.range_index.query(tablename, indexname, gte=value[0], lte=value[1])
return list(set(collection) - set(in_range))


def get_selectivity(self, tablename, colname, operator, value, total_count):
"""
Estimate selectivity: higher means worst filtering.
Expand Down Expand Up @@ -187,23 +203,24 @@ class HashIndex:
Maintains insertion order of objects.
"""

__slots__ = ('index',)

def __init__(self):
self.index = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
self.index = defaultdict(lambda: defaultdict(lambda: defaultdict(OrderedSet)))


def add(self, tablename: str, indexname: str, value: Any, obj: Any):
self.index[tablename][indexname][value].append(obj)
self.index[tablename][indexname][value].add(obj)


def remove(self, tablename: str, indexname: str, value: Any, obj: Any):
lst = self.index[tablename][indexname][value]
try:
lst.remove(obj)
if not lst:
del self.index[tablename][indexname][value]
except ValueError:
pass
s = self.index[tablename][indexname][value]
s.discard(obj)
if not s:
del self.index[tablename][indexname][value]

def query(self, tablename: str, indexname: str, value: Any) -> List[Any]:
return self.index[tablename][indexname].get(value, [])
return list(self.index[tablename][indexname].get(value, []))


class RangeIndex:
Expand All @@ -215,12 +232,19 @@ class RangeIndex:
index[tablename][indexname] = SortedDict { value: [obj1, obj2, ...] }
"""

__slots__ = ('index',)

def __init__(self):
self.index = defaultdict(lambda: defaultdict(SortedDict))

def add(self, tablename: str, indexname: str, value: Any, obj: Any):
self.index[tablename][indexname].setdefault(value, []).append(obj)
index = self.index[tablename][indexname]
if value in index:
index[value].append(obj)
else:
index[value] = [obj]


def remove(self, tablename: str, indexname: str, value: Any, obj: Any):
col = self.index[tablename][indexname]
if value in col:
Expand Down
3 changes: 1 addition & 2 deletions sqlalchemy_memory/base/pending_changes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ def rollback(self):

def add(self, obj, **kwargs):
tablename = obj.__tablename__
if not any(id(x) == id(obj) for x in self._to_add[tablename]):
self._to_add[tablename].append(obj)
self._to_add[tablename].append(obj)

def delete(self, obj):
tablename = obj.__tablename__
Expand Down
8 changes: 6 additions & 2 deletions sqlalchemy_memory/base/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def __init__(self, *args, **kwargs):
def add(self, obj, **kwargs):
self.pending_changes.add(obj, **kwargs)

def add_all(self, instances, **kwargs):
for instance in instances:
self.add(instance, **kwargs)

def delete(self, obj):
self.pending_changes.delete(obj)

Expand Down Expand Up @@ -159,7 +163,7 @@ def _handle_update(self, statement: Update, **kwargs):
pk_col_name = None
for obj in collection:
if pk_col_name is None:
pk_col_name = self.store._get_primary_key_name(obj)
pk_col_name = self.store._get_primary_key_name(obj.__table__)

pk_value = getattr(obj, pk_col_name)
self.update(tablename, pk_value, data)
Expand Down Expand Up @@ -188,7 +192,7 @@ def merge(self, instance, **kwargs):
Merge a possibly detached instance into the current session
"""

pk_name = self.store._get_primary_key_name(instance)
pk_name = self.store._get_primary_key_name(instance.__table__)
pk_value = getattr(instance, pk_name)
existing = self.store.get_by_primary_key(instance, pk_value)

Expand Down
Loading