Skip to content

Commit 4ff0a59

Browse files
committed
refactoring after reviewers feedback + type hints
1 parent 3a75f5e commit 4ff0a59

File tree

4 files changed

+142
-74
lines changed

4 files changed

+142
-74
lines changed

25-class-metaprog/persistent/dblib.py

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,10 @@
33
# Applying `check_identifier` to parameters prevents SQL injection.
44

55
import sqlite3
6-
from typing import NamedTuple
6+
from typing import NamedTuple, Optional, Iterator, Any
77

88
DEFAULT_DB_PATH = ':memory:'
9-
CONNECTION = None
10-
11-
SQL_TYPES = {
12-
int: 'INTEGER',
13-
str: 'TEXT',
14-
float: 'REAL',
15-
bytes: 'BLOB',
16-
}
9+
CONNECTION: Optional[sqlite3.Connection] = None
1710

1811

1912
class NoConnection(Exception):
@@ -38,29 +31,45 @@ class UnexpectedMultipleResults(Exception):
3831
"""Query returned more than 1 row."""
3932

4033

34+
SQLType = str
35+
36+
TypeMap = dict[type, SQLType]
37+
38+
SQL_TYPES: TypeMap = {
39+
int: 'INTEGER',
40+
str: 'TEXT',
41+
float: 'REAL',
42+
bytes: 'BLOB',
43+
}
44+
45+
4146
class ColumnSchema(NamedTuple):
4247
name: str
43-
sql_type: str
48+
sql_type: SQLType
4449

4550

46-
def check_identifier(name):
51+
FieldMap = dict[str, type]
52+
53+
54+
def check_identifier(name: str) -> None:
4755
if not name.isidentifier():
4856
raise ValueError(f'{name!r} is not an identifier')
4957

5058

51-
def connect(db_path=DEFAULT_DB_PATH):
59+
def connect(db_path: str = DEFAULT_DB_PATH) -> sqlite3.Connection:
5260
global CONNECTION
5361
CONNECTION = sqlite3.connect(db_path)
62+
CONNECTION.row_factory = sqlite3.Row
5463
return CONNECTION
5564

5665

57-
def get_connection():
66+
def get_connection() -> sqlite3.Connection:
5867
if CONNECTION is None:
5968
raise NoConnection()
6069
return CONNECTION
6170

6271

63-
def gen_columns_sql(fields):
72+
def gen_columns_sql(fields: FieldMap) -> Iterator[ColumnSchema]:
6473
for name, py_type in fields.items():
6574
check_identifier(name)
6675
try:
@@ -70,7 +79,7 @@ def gen_columns_sql(fields):
7079
yield ColumnSchema(name, sql_type)
7180

7281

73-
def make_schema_sql(table_name, fields):
82+
def make_schema_sql(table_name: str, fields: FieldMap) -> str:
7483
check_identifier(table_name)
7584
pk = 'pk INTEGER PRIMARY KEY,'
7685
spcs = ' ' * 4
@@ -81,47 +90,46 @@ def make_schema_sql(table_name, fields):
8190
return f'CREATE TABLE {table_name} (\n{spcs}{pk}\n{spcs}{columns}\n)'
8291

8392

84-
def create_table(table_name, fields):
93+
def create_table(table_name: str, fields: FieldMap) -> None:
8594
con = get_connection()
8695
con.execute(make_schema_sql(table_name, fields))
8796

8897

89-
def read_columns_sql(table_name):
90-
con = get_connection()
98+
def read_columns_sql(table_name: str) -> list[ColumnSchema]:
9199
check_identifier(table_name)
100+
con = get_connection()
92101
rows = con.execute(f'PRAGMA table_info({table_name!r})')
93-
# row fields: cid name type notnull dflt_value pk
94-
return [ColumnSchema(r[1], r[2]) for r in rows]
102+
return [ColumnSchema(r['name'], r['type']) for r in rows]
95103

96104

97-
def valid_table(table_name, fields):
105+
def valid_table(table_name: str, fields: FieldMap) -> bool:
98106
table_columns = read_columns_sql(table_name)
99107
return set(gen_columns_sql(fields)) <= set(table_columns)
100108

101109

102-
def ensure_table(table_name, fields):
110+
def ensure_table(table_name: str, fields: FieldMap) -> None:
103111
table_columns = read_columns_sql(table_name)
104112
if len(table_columns) == 0:
105113
create_table(table_name, fields)
106114
if not valid_table(table_name, fields):
107115
raise SchemaMismatch(table_name)
108116

109117

110-
def insert_record(table_name, fields):
111-
con = get_connection()
118+
def insert_record(table_name: str, data: dict[str, Any]) -> int:
112119
check_identifier(table_name)
113-
placeholders = ', '.join(['?'] * len(fields))
120+
con = get_connection()
121+
placeholders = ', '.join(['?'] * len(data))
114122
sql = f'INSERT INTO {table_name} VALUES (NULL, {placeholders})'
115-
cursor = con.execute(sql, tuple(fields.values()))
123+
cursor = con.execute(sql, tuple(data.values()))
116124
pk = cursor.lastrowid
117125
con.commit()
118126
cursor.close()
119127
return pk
120128

121129

122-
def fetch_record(table_name, pk):
123-
con = get_connection()
130+
def fetch_record(table_name: str, pk: int) -> sqlite3.Row:
124131
check_identifier(table_name)
132+
con = get_connection()
125133
sql = f'SELECT * FROM {table_name} WHERE pk = ? LIMIT 2'
126134
result = list(con.execute(sql, (pk,)))
127135
if len(result) == 0:
@@ -132,19 +140,21 @@ def fetch_record(table_name, pk):
132140
raise UnexpectedMultipleResults()
133141

134142

135-
def update_record(table_name, pk, fields):
143+
def update_record(
144+
table_name: str, pk: int, data: dict[str, Any]
145+
) -> tuple[str, tuple[Any, ...]]:
136146
check_identifier(table_name)
137147
con = get_connection()
138-
names = ', '.join(fields.keys())
139-
placeholders = ', '.join(['?'] * len(fields))
140-
values = tuple(fields.values()) + (pk,)
148+
names = ', '.join(data.keys())
149+
placeholders = ', '.join(['?'] * len(data))
150+
values = tuple(data.values()) + (pk,)
141151
sql = f'UPDATE {table_name} SET ({names}) = ({placeholders}) WHERE pk = ?'
142152
con.execute(sql, values)
143153
con.commit()
144154
return sql, values
145155

146156

147-
def delete_record(table_name, pk):
157+
def delete_record(table_name: str, pk: int) -> sqlite3.Cursor:
148158
con = get_connection()
149159
check_identifier(table_name)
150160
sql = f'DELETE FROM {table_name} WHERE pk = ?'

25-class-metaprog/persistent/dblib_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def test_fetch_record(create_movies_sql):
7777
con.execute(create_movies_sql)
7878
pk = insert_record('movies', fields)
7979
row = fetch_record('movies', pk)
80-
assert row == (1, 'Frozen', 1_290_000_000.0)
80+
assert tuple(row) == (1, 'Frozen', 1_290_000_000.0)
8181

8282

8383
def test_fetch_record_no_such_pk(create_movies_sql):
@@ -98,7 +98,7 @@ def test_update_record(create_movies_sql):
9898
row = fetch_record('movies', pk)
9999
assert sql == 'UPDATE movies SET (title, revenue) = (?, ?) WHERE pk = ?'
100100
assert values == ('Frozen', 1_299_999_999, 1)
101-
assert row == (1, 'Frozen', 1_299_999_999.0)
101+
assert tuple(row) == (1, 'Frozen', 1_299_999_999.0)
102102

103103

104104
def test_delete_record(create_movies_sql):

25-class-metaprog/persistent/persistlib.py

Lines changed: 60 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,26 @@
44
>>> class Movie(Persistent):
55
... title: str
66
... year: int
7-
... boxmega: float
7+
... megabucks: float
88
99
Implemented behavior::
1010
1111
>>> Movie._connect() # doctest: +ELLIPSIS
1212
<sqlite3.Connection object at 0x...>
13-
>>> movie = Movie('The Godfather', 1972, 137)
13+
>>> movie = Movie(title='The Godfather', year=1972, megabucks=137)
1414
>>> movie.title
1515
'The Godfather'
16-
>>> movie.boxmega
16+
>>> movie.megabucks
1717
137.0
1818
19-
Instances always have a ``.pk`` attribute, but it is ``None`` until the
19+
Instances always have a ``._pk`` attribute, but it is ``None`` until the
2020
object is saved::
2121
22-
>>> movie.pk is None
22+
>>> movie._pk is None
2323
True
24-
>>> movie._persist()
25-
>>> movie.pk
24+
>>> movie._save()
25+
1
26+
>>> movie._pk
2627
1
2728
2829
Delete the in-memory ``movie``, and fetch the record from the database,
@@ -31,7 +32,7 @@
3132
>>> del movie
3233
>>> film = Movie[1]
3334
>>> film
34-
Movie('The Godfather', 1972, 137.0, pk=1)
35+
Movie(title='The Godfather', year=1972, megabucks=137.0, _pk=1)
3536
3637
By default, the table name is the class name lowercased, with an appended
3738
"s" for plural::
@@ -51,69 +52,89 @@ class declaration::
5152
5253
"""
5354

54-
from typing import get_type_hints
55+
from typing import Any, ClassVar, get_type_hints
5556

5657
import dblib as db
5758

5859

5960
class Field:
60-
def __init__(self, name, py_type):
61+
def __init__(self, name: str, py_type: type) -> None:
6162
self.name = name
6263
self.type = py_type
6364

64-
def __set__(self, instance, value):
65+
def __set__(self, instance: 'Persistent', value: Any) -> None:
6566
try:
6667
value = self.type(value)
67-
except TypeError as e:
68-
msg = f'{value!r} is not compatible with {self.name}:{self.type}.'
68+
except (TypeError, ValueError) as e:
69+
type_name = self.type.__name__
70+
msg = f'{value!r} is not compatible with {self.name}:{type_name}.'
6971
raise TypeError(msg) from e
7072
instance.__dict__[self.name] = value
7173

7274

7375
class Persistent:
74-
def __init_subclass__(
75-
cls, *, db_path=db.DEFAULT_DB_PATH, table='', **kwargs
76-
):
77-
super().__init_subclass__(**kwargs)
76+
_TABLE_NAME: ClassVar[str]
77+
_TABLE_READY: ClassVar[bool] = False
78+
79+
@classmethod
80+
def _fields(cls) -> dict[str, type]:
81+
return {
82+
name: py_type
83+
for name, py_type in get_type_hints(cls).items()
84+
if not name.startswith('_')
85+
}
86+
87+
def __init_subclass__(cls, *, table: str = '', **kwargs: dict):
88+
super().__init_subclass__(**kwargs) # type:ignore
7889
cls._TABLE_NAME = table if table else cls.__name__.lower() + 's'
79-
cls._TABLE_READY = False
80-
for name, py_type in get_type_hints(cls).items():
90+
for name, py_type in cls._fields().items():
8191
setattr(cls, name, Field(name, py_type))
8292

8393
@staticmethod
84-
def _connect(db_path=db.DEFAULT_DB_PATH):
94+
def _connect(db_path: str = db.DEFAULT_DB_PATH):
8595
return db.connect(db_path)
8696

8797
@classmethod
88-
def _ensure_table(cls):
98+
def _ensure_table(cls) -> str:
8999
if not cls._TABLE_READY:
90-
db.ensure_table(cls._TABLE_NAME, get_type_hints(cls))
100+
db.ensure_table(cls._TABLE_NAME, cls._fields())
91101
cls._TABLE_READY = True
92102
return cls._TABLE_NAME
93103

94-
def _fields(self):
104+
def __class_getitem__(cls, pk: int) -> 'Persistent':
105+
field_names = ['_pk'] + list(cls._fields())
106+
values = db.fetch_record(cls._TABLE_NAME, pk)
107+
return cls(**dict(zip(field_names, values)))
108+
109+
def _asdict(self) -> dict[str, Any]:
95110
return {
96111
name: getattr(self, name)
97112
for name, attr in self.__class__.__dict__.items()
98113
if isinstance(attr, Field)
99114
}
100115

101-
def __init__(self, *args, pk=None):
102-
for name, arg in zip(self._fields(), args):
116+
def __init__(self, *, _pk=None, **kwargs):
117+
field_names = self._asdict().keys()
118+
for name, arg in kwargs.items():
119+
if name not in field_names:
120+
msg = f'{self.__class__.__name__!r} has no attribute {name!r}'
121+
raise AttributeError(msg)
103122
setattr(self, name, arg)
104-
self.pk = pk
105-
106-
def __class_getitem__(cls, pk):
107-
return cls(*db.fetch_record(cls._TABLE_NAME, pk)[1:], pk=pk)
108-
109-
def __repr__(self):
110-
args = ', '.join(repr(value) for value in self._fields().values())
111-
pk = '' if self.pk is None else f', pk={self.pk}'
112-
return f'{self.__class__.__name__}({args}{pk})'
113-
114-
def _persist(self):
123+
self._pk = _pk
124+
125+
def __repr__(self) -> str:
126+
kwargs = ', '.join(
127+
f'{key}={value!r}' for key, value in self._asdict().items()
128+
)
129+
cls_name = self.__class__.__name__
130+
if self._pk is None:
131+
return f'{cls_name}({kwargs})'
132+
return f'{cls_name}({kwargs}, _pk={self._pk})'
133+
134+
def _save(self) -> int:
115135
table = self.__class__._ensure_table()
116-
if self.pk is None:
117-
self.pk = db.insert_record(table, self._fields())
136+
if self._pk is None:
137+
self._pk = db.insert_record(table, self._asdict())
118138
else:
119-
db.update_record(table, self.pk, self._fields())
139+
db.update_record(table, self._pk, self._asdict())
140+
return self._pk
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
3+
4+
from persistlib import Persistent
5+
6+
7+
def test_field_descriptor_validation_type_error():
8+
class Cat(Persistent):
9+
name: str
10+
weight: float
11+
12+
with pytest.raises(TypeError) as e:
13+
felix = Cat(name='Felix', weight=None)
14+
15+
assert str(e.value) == 'None is not compatible with weight:float.'
16+
17+
18+
def test_field_descriptor_validation_value_error():
19+
class Cat(Persistent):
20+
name: str
21+
weight: float
22+
23+
with pytest.raises(TypeError) as e:
24+
felix = Cat(name='Felix', weight='half stone')
25+
26+
assert str(e.value) == "'half stone' is not compatible with weight:float."
27+
28+
29+
def test_constructor_attribute_error():
30+
class Cat(Persistent):
31+
name: str
32+
weight: float
33+
34+
with pytest.raises(AttributeError) as e:
35+
felix = Cat(name='Felix', weight=3.2, age=7)
36+
37+
assert str(e.value) == "'Cat' has no attribute 'age'"

0 commit comments

Comments
 (0)