Skip to content

Commit 3a75f5e

Browse files
authored
Merge pull request #8 from leorochael/25-class-metaprog-persistent
Suggested improvements by @leorochael
2 parents bbc6643 + ee418d7 commit 3a75f5e

File tree

4 files changed

+402
-0
lines changed

4 files changed

+402
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.db
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# SQLite3 does not support parameterized table and field names,
2+
# for CREATE TABLE and PRAGMA so we must use Python string formatting.
3+
# Applying `check_identifier` to parameters prevents SQL injection.
4+
5+
import sqlite3
6+
from typing import NamedTuple
7+
8+
DEFAULT_DB_PATH = ':memory:'
9+
CONNECTION = None
10+
11+
SQL_TYPES = {
12+
int: 'INTEGER',
13+
str: 'TEXT',
14+
float: 'REAL',
15+
bytes: 'BLOB',
16+
}
17+
18+
19+
class NoConnection(Exception):
20+
"""Call connect() to open connection."""
21+
22+
23+
class SchemaMismatch(ValueError):
24+
"""The table schema doesn't match the class."""
25+
26+
def __init__(self, table_name):
27+
self.table_name = table_name
28+
29+
30+
class NoSuchRecord(LookupError):
31+
"""The given primary key does not exist."""
32+
33+
def __init__(self, pk):
34+
self.pk = pk
35+
36+
37+
class UnexpectedMultipleResults(Exception):
38+
"""Query returned more than 1 row."""
39+
40+
41+
class ColumnSchema(NamedTuple):
42+
name: str
43+
sql_type: str
44+
45+
46+
def check_identifier(name):
47+
if not name.isidentifier():
48+
raise ValueError(f'{name!r} is not an identifier')
49+
50+
51+
def connect(db_path=DEFAULT_DB_PATH):
52+
global CONNECTION
53+
CONNECTION = sqlite3.connect(db_path)
54+
return CONNECTION
55+
56+
57+
def get_connection():
58+
if CONNECTION is None:
59+
raise NoConnection()
60+
return CONNECTION
61+
62+
63+
def gen_columns_sql(fields):
64+
for name, py_type in fields.items():
65+
check_identifier(name)
66+
try:
67+
sql_type = SQL_TYPES[py_type]
68+
except KeyError as e:
69+
raise ValueError(f'type {py_type!r} is not supported') from e
70+
yield ColumnSchema(name, sql_type)
71+
72+
73+
def make_schema_sql(table_name, fields):
74+
check_identifier(table_name)
75+
pk = 'pk INTEGER PRIMARY KEY,'
76+
spcs = ' ' * 4
77+
columns = ',\n '.join(
78+
f'{field_name} {sql_type}'
79+
for field_name, sql_type in gen_columns_sql(fields)
80+
)
81+
return f'CREATE TABLE {table_name} (\n{spcs}{pk}\n{spcs}{columns}\n)'
82+
83+
84+
def create_table(table_name, fields):
85+
con = get_connection()
86+
con.execute(make_schema_sql(table_name, fields))
87+
88+
89+
def read_columns_sql(table_name):
90+
con = get_connection()
91+
check_identifier(table_name)
92+
rows = con.execute(f'PRAGMA table_info({table_name!r})')
93+
# row fields: cid name type notnull dflt_value pk
94+
return [ColumnSchema(r[1], r[2]) for r in rows]
95+
96+
97+
def valid_table(table_name, fields):
98+
table_columns = read_columns_sql(table_name)
99+
return set(gen_columns_sql(fields)) <= set(table_columns)
100+
101+
102+
def ensure_table(table_name, fields):
103+
table_columns = read_columns_sql(table_name)
104+
if len(table_columns) == 0:
105+
create_table(table_name, fields)
106+
if not valid_table(table_name, fields):
107+
raise SchemaMismatch(table_name)
108+
109+
110+
def insert_record(table_name, fields):
111+
con = get_connection()
112+
check_identifier(table_name)
113+
placeholders = ', '.join(['?'] * len(fields))
114+
sql = f'INSERT INTO {table_name} VALUES (NULL, {placeholders})'
115+
cursor = con.execute(sql, tuple(fields.values()))
116+
pk = cursor.lastrowid
117+
con.commit()
118+
cursor.close()
119+
return pk
120+
121+
122+
def fetch_record(table_name, pk):
123+
con = get_connection()
124+
check_identifier(table_name)
125+
sql = f'SELECT * FROM {table_name} WHERE pk = ? LIMIT 2'
126+
result = list(con.execute(sql, (pk,)))
127+
if len(result) == 0:
128+
raise NoSuchRecord(pk)
129+
elif len(result) == 1:
130+
return result[0]
131+
else:
132+
raise UnexpectedMultipleResults()
133+
134+
135+
def update_record(table_name, pk, fields):
136+
check_identifier(table_name)
137+
con = get_connection()
138+
names = ', '.join(fields.keys())
139+
placeholders = ', '.join(['?'] * len(fields))
140+
values = tuple(fields.values()) + (pk,)
141+
sql = f'UPDATE {table_name} SET ({names}) = ({placeholders}) WHERE pk = ?'
142+
con.execute(sql, values)
143+
con.commit()
144+
return sql, values
145+
146+
147+
def delete_record(table_name, pk):
148+
con = get_connection()
149+
check_identifier(table_name)
150+
sql = f'DELETE FROM {table_name} WHERE pk = ?'
151+
return con.execute(sql, (pk,))
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
from textwrap import dedent
2+
3+
import pytest
4+
5+
from dblib import gen_columns_sql, make_schema_sql, connect, read_columns_sql
6+
from dblib import ColumnSchema, insert_record, fetch_record, update_record
7+
from dblib import NoSuchRecord, delete_record, valid_table
8+
9+
10+
@pytest.fixture
11+
def create_movies_sql():
12+
sql = '''
13+
CREATE TABLE movies (
14+
pk INTEGER PRIMARY KEY,
15+
title TEXT,
16+
revenue REAL
17+
)
18+
'''
19+
return dedent(sql).strip()
20+
21+
22+
@pytest.mark.parametrize(
23+
'fields, expected',
24+
[
25+
(
26+
dict(title=str, awards=int),
27+
[('title', 'TEXT'), ('awards', 'INTEGER')],
28+
),
29+
(
30+
dict(picture=bytes, score=float),
31+
[('picture', 'BLOB'), ('score', 'REAL')],
32+
),
33+
],
34+
)
35+
def test_gen_columns_sql(fields, expected):
36+
result = list(gen_columns_sql(fields))
37+
assert result == expected
38+
39+
40+
def test_make_schema_sql(create_movies_sql):
41+
fields = dict(title=str, revenue=float)
42+
result = make_schema_sql('movies', fields)
43+
assert result == create_movies_sql
44+
45+
46+
def test_read_columns_sql(create_movies_sql):
47+
expected = [
48+
ColumnSchema(name='pk', sql_type='INTEGER'),
49+
ColumnSchema(name='title', sql_type='TEXT'),
50+
ColumnSchema(name='revenue', sql_type='REAL'),
51+
]
52+
with connect() as con:
53+
con.execute(create_movies_sql)
54+
result = read_columns_sql('movies')
55+
assert result == expected
56+
57+
58+
def test_read_columns_sql_no_such_table(create_movies_sql):
59+
with connect() as con:
60+
con.execute(create_movies_sql)
61+
result = read_columns_sql('no_such_table')
62+
assert result == []
63+
64+
65+
def test_insert_record(create_movies_sql):
66+
fields = dict(title='Frozen', revenue=1_290_000_000)
67+
with connect() as con:
68+
con.execute(create_movies_sql)
69+
for _ in range(3):
70+
result = insert_record('movies', fields)
71+
assert result == 3
72+
73+
74+
def test_fetch_record(create_movies_sql):
75+
fields = dict(title='Frozen', revenue=1_290_000_000)
76+
with connect() as con:
77+
con.execute(create_movies_sql)
78+
pk = insert_record('movies', fields)
79+
row = fetch_record('movies', pk)
80+
assert row == (1, 'Frozen', 1_290_000_000.0)
81+
82+
83+
def test_fetch_record_no_such_pk(create_movies_sql):
84+
with connect() as con:
85+
con.execute(create_movies_sql)
86+
with pytest.raises(NoSuchRecord) as e:
87+
fetch_record('movies', 42)
88+
assert e.value.pk == 42
89+
90+
91+
def test_update_record(create_movies_sql):
92+
fields = dict(title='Frozen', revenue=1_290_000_000)
93+
with connect() as con:
94+
con.execute(create_movies_sql)
95+
pk = insert_record('movies', fields)
96+
fields['revenue'] = 1_299_999_999
97+
sql, values = update_record('movies', pk, fields)
98+
row = fetch_record('movies', pk)
99+
assert sql == 'UPDATE movies SET (title, revenue) = (?, ?) WHERE pk = ?'
100+
assert values == ('Frozen', 1_299_999_999, 1)
101+
assert row == (1, 'Frozen', 1_299_999_999.0)
102+
103+
104+
def test_delete_record(create_movies_sql):
105+
fields = dict(title='Frozen', revenue=1_290_000_000)
106+
with connect() as con:
107+
con.execute(create_movies_sql)
108+
pk = insert_record('movies', fields)
109+
delete_record('movies', pk)
110+
with pytest.raises(NoSuchRecord) as e:
111+
fetch_record('movies', pk)
112+
assert e.value.pk == pk
113+
114+
115+
def test_persistent_valid_table(create_movies_sql):
116+
fields = dict(title=str, revenue=float)
117+
118+
with connect() as con:
119+
con.execute(create_movies_sql)
120+
con.commit()
121+
assert valid_table('movies', fields)
122+
123+
124+
def test_persistent_valid_table_false(create_movies_sql):
125+
# year field not in movies_sql
126+
fields = dict(title=str, revenue=float, year=int)
127+
128+
with connect() as con:
129+
con.execute(create_movies_sql)
130+
con.commit()
131+
assert not valid_table('movies', fields)

0 commit comments

Comments
 (0)