16 changes: 16 additions & 0 deletions README.md
@@ -9,6 +9,7 @@ It accepts as input the following sources:
* direct `CREATE TABLE` sql statements
* sqlite file path
* postgres connection string
* mysql connection object from the [mysql-connector-python](https://github.com/mysql/mysql-connector-python) library

## Installation

@@ -67,13 +68,28 @@ from sqlmodelgen import gen_code_from_postgres
code = gen_code_from_postgres('postgres://USER:PASSWORD@HOST:PORT/DBNAME')
```

### Generating from MySQL

The separate `mysql` extension is required; it can be installed with `pip install sqlmodelgen[mysql]`.

```python
import mysql.connector
from sqlmodelgen import gen_code_from_mysql

# instantiate your connection
conn = mysql.connector.connect(host='YOURHOST', port=3306, user='YOURUSER', password='YOURPASSWORD')

code = gen_code_from_mysql(conn, 'YOURDBNAME')
```

### Relationships

`sqlmodelgen` can build relationships by passing the argument `generate_relationships=True` to the following functions:

* `gen_code_from_sql`
* `gen_code_from_sqlite`
* `gen_code_from_postgres`
* `gen_code_from_mysql`

In that case `sqlmodelgen` generates relationships between the classes based on the retrieved foreign keys.
The following example
9 changes: 7 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "sqlmodelgen"
version = "0.0.14"
version = "0.0.15"
description = "Generate SQLModel code from SQL"
license = {file = "LICENSE"}
readme = "README.md"
@@ -16,13 +16,18 @@ classifiers = [
"Topic :: Software Development :: Code Generators",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12"
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14"
]

[project.optional-dependencies]
postgres = [
"psycopg[binary]>=3.2.6",
]
mysql = [
"mysql-connector-python>=9.5.0",
]

[tool.uv]
dev-dependencies = [
7 changes: 6 additions & 1 deletion src/sqlmodelgen/__init__.py
@@ -3,9 +3,14 @@
    gen_code_from_sqlite,
)

from .utils.dependency_checker import check_postgres_deps
from .utils.dependency_checker import check_postgres_deps, check_mysql_deps

if check_postgres_deps():
    from .sqlmodelgen import (
        gen_code_from_postgres
    )

if check_mysql_deps():
    from .sqlmodelgen import (
        gen_code_from_mysql
    )
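
For reference, `check_mysql_deps` comes from `utils/dependency_checker`, which is not part of this diff. A minimal sketch of how such a check could look, assuming it only needs to verify that `mysql.connector` is importable (the real helper may differ):

```python
# Hypothetical sketch, not the actual code in utils/dependency_checker.
from importlib.util import find_spec


def check_mysql_deps() -> bool:
    # True when the optional mysql-connector-python package is installed.
    try:
        return find_spec('mysql.connector') is not None
    except ModuleNotFoundError:
        return False
```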
212 changes: 212 additions & 0 deletions src/sqlmodelgen/ir/mysql/__init__.py
@@ -0,0 +1,212 @@
from typing import Iterator

from mysql.connector.connection_cext import CMySQLConnection
from mysql.connector.cursor_cext import CMySQLCursor

from sqlmodelgen.ir.ir import (
    ColIR,
    TableIR,
    SchemaIR,
    FKIR
)
from sqlmodelgen.ir.query import ColQueryData, ContraintsData, ir_build

class MySQLCollector:
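    '''
    Collects table, column and constraint metadata from a MySQL schema
    through an open connection; satisfies the QCollector protocol used
    by ir_build.
    '''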

    def __init__(
        self,
        cnx: CMySQLConnection,
        dbname: str,
    ):
        self.cnx = cnx
        self.dbname = dbname


    def collect_table_names(self) -> Iterator[str]:
        cur = self.cnx.cursor()
        yield from collect_tables(cur, self.dbname)


    def collect_columns(self, table_name: str) -> Iterator[ColQueryData]:
        cur = self.cnx.cursor()
        yield from collect_columns(cur, self.dbname, table_name)


    def collect_constraints(self) -> ContraintsData:
        cur = self.cnx.cursor()

        uniques = collect_uniques(cur, self.dbname)
        primary_keys = collect_primary_keys(cur, self.dbname)
        foreign_keys = collect_foreign_keys(cur, self.dbname)

        return ContraintsData(
            uniques=uniques,
            primary_keys=primary_keys,
            foreign_keys=foreign_keys,
        )


def collect_mysql_ir(cnx: CMySQLConnection, dbname: str) -> SchemaIR:
    return ir_build(collector=MySQLCollector(
        cnx=cnx,
        dbname=dbname
    ))


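# The helpers below run queries against information_schema to list the tables,
# columns and constraints of the given schema (database) name.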
def collect_columns(
    cur: CMySQLCursor,
    schema_name: str,
    table_name: str,
) -> Iterator[ColQueryData]:
    cur.execute(f'''SELECT
        COLUMN_NAME,
        ORDINAL_POSITION,
        IS_NULLABLE,
        DATA_TYPE,
        COLUMN_TYPE
    FROM
        information_schema.COLUMNS
    WHERE
        TABLE_SCHEMA = '{schema_name}'
        AND TABLE_NAME = '{table_name}'
    ORDER BY
        TABLE_NAME,
        ORDINAL_POSITION;''')

    for col_name, ord_pos, is_nullable, data_type, col_type in cur.fetchall():
        yield ColQueryData(
            name=col_name,
            data_type=data_type,
            is_nullable=is_nullable == 'YES'
        )


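# Returns the names of the tables defined in the given schema.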
def collect_tables(cur: CMySQLCursor, schema_name: str) -> Iterator[str]:
    cur.execute(f'''SELECT
        TABLE_NAME
    FROM
        information_schema.TABLES
    WHERE
        TABLE_SCHEMA = '{schema_name}'
    ORDER BY
        TABLE_NAME;''')

    for elem in cur.fetchall():
        yield elem[0]


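# Returns a mapping of table name -> column names covered by UNIQUE constraints.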
def collect_uniques(
    cur: CMySQLCursor,
    schema_name: str,
) -> dict[str, set[str]]:
    cur.execute('''SELECT
        tc.TABLE_SCHEMA,
        tc.TABLE_NAME,
        tc.CONSTRAINT_NAME,
        kcu.COLUMN_NAME
    FROM
        information_schema.TABLE_CONSTRAINTS tc
    JOIN
        information_schema.KEY_COLUMN_USAGE kcu
        ON tc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME
        AND tc.TABLE_SCHEMA = kcu.TABLE_SCHEMA
        AND tc.TABLE_NAME = kcu.TABLE_NAME
    WHERE
        tc.CONSTRAINT_TYPE = 'UNIQUE'
    ORDER BY
        tc.TABLE_SCHEMA,
        tc.TABLE_NAME,
        kcu.ORDINAL_POSITION;''')

    result: dict[str, set[str]] = dict()

    for table_schema, table_name, constraint_name, column_name in cur.fetchall():
        if table_schema != schema_name:
            continue

        if table_name not in result.keys():
            result[table_name] = set()

        result[table_name].add(column_name)

    return result


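# Returns a mapping of table name -> primary key column names.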
def collect_primary_keys(
    cur: CMySQLCursor,
    schema_name: str,
) -> dict[str, set[str]]:
    cur.execute(f'''SELECT
        TABLE_NAME,
        COLUMN_NAME
    FROM
        information_schema.KEY_COLUMN_USAGE
    WHERE
        CONSTRAINT_NAME = 'PRIMARY'
        AND TABLE_SCHEMA = '{schema_name}'
    ORDER BY
        TABLE_SCHEMA,
        TABLE_NAME,
        ORDINAL_POSITION;''')

    result: dict[str, set[str]] = dict()

    for table_name, col_name in cur.fetchall():
        if table_name not in result.keys():
            result[table_name] = set()

        result[table_name].add(col_name)

    return result


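# Returns, per table, a mapping of column name -> FKIR describing the referenced
# table and column, limited to keys whose source and target are in the schema.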
def collect_foreign_keys(
    cur: CMySQLCursor,
    schema_name: str,
) -> dict[str, dict[str, FKIR]]:
    cur.execute(f'''SELECT
        kcu.TABLE_SCHEMA,
        kcu.TABLE_NAME,
        kcu.CONSTRAINT_NAME,
        kcu.COLUMN_NAME,
        kcu.REFERENCED_TABLE_SCHEMA,
        kcu.REFERENCED_TABLE_NAME,
        kcu.REFERENCED_COLUMN_NAME
    FROM
        information_schema.KEY_COLUMN_USAGE kcu
    JOIN
        information_schema.TABLE_CONSTRAINTS tc
        ON kcu.CONSTRAINT_NAME = tc.CONSTRAINT_NAME
        AND kcu.TABLE_SCHEMA = tc.TABLE_SCHEMA
        AND kcu.TABLE_NAME = tc.TABLE_NAME
    WHERE
        tc.CONSTRAINT_TYPE = 'FOREIGN KEY'
        AND kcu.TABLE_SCHEMA = '{schema_name}'
        AND kcu.REFERENCED_TABLE_SCHEMA = '{schema_name}'
    ORDER BY
        kcu.TABLE_SCHEMA,
        kcu.TABLE_NAME,
        kcu.ORDINAL_POSITION;''')

    result: dict[str, dict[str, FKIR]] = dict()

    for (
        table_schema,
        table_name,
        constraint_name,
        column_name,
        referenced_table_schema,
        referenced_table_name,
        referenced_column_name,
    ) in cur.fetchall():
        table_fks = result.get(table_name)
        if table_fks is None:
            table_fks = dict()
            result[table_name] = table_fks

        table_fks[column_name] = FKIR(
            target_table=referenced_table_name,
            target_column=referenced_column_name
        )

    return result
89 changes: 89 additions & 0 deletions src/sqlmodelgen/ir/query/__init__.py
@@ -0,0 +1,89 @@
from dataclasses import dataclass
from typing import Iterable, Iterator, Protocol

from sqlmodelgen.ir.ir import (
    ColIR,
    TableIR,
    SchemaIR,
    FKIR
)

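# Backend-agnostic description of a column as returned by a collector.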
@dataclass
class ColQueryData:
    name: str
    data_type: str
    is_nullable: bool = True

@dataclass
class ContraintsData:
    uniques: dict[str, set[str]]
    primary_keys: dict[str, set[str]]
    foreign_keys: dict[str, dict[str, FKIR]]

    def is_unique(self, table_name: str, column_name: str) -> bool:
        return column_name in self.uniques.get(table_name, set())

    def is_primary_key(self, table_name: str, column_name: str) -> bool:
        return column_name in self.primary_keys.get(table_name, set())

    def get_foreign_key(self, table_name: str, column_name: str) -> FKIR | None:
        table_fks = self.foreign_keys.get(table_name)

        if table_fks is None:
            return None

        return table_fks.get(column_name)

class QCollector(Protocol):
    '''
    A protocol that any SQL metadata collector must satisfy in order to
    build a SchemaIR through ir_build.
    '''

    def collect_table_names(self) -> Iterator[str]:
        pass

    def collect_columns(self, table_name: str) -> Iterator[ColQueryData]:
        pass

    def collect_constraints(self) -> ContraintsData:
        pass


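# Builds the schema IR by collecting the constraints once, then iterating over
# every table and its columns.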
def ir_build(collector: QCollector) -> SchemaIR:
    constraints = collector.collect_constraints()

    tables_names = list(collector.collect_table_names())

    table_irs: list[TableIR] = list()
    for table_name in tables_names:
        cols_data = collector.collect_columns(table_name)

        table_irs.append(TableIR(
            name=table_name,
            col_irs=list(build_cols_ir(
                cols_data=cols_data,
                table_name=table_name,
                constraints=constraints,
            ))
        ))

    return SchemaIR(
        table_irs=table_irs
    )

def build_cols_ir(
    cols_data: Iterable[ColQueryData],
    table_name: str,
    constraints: ContraintsData
) -> Iterator[ColIR]:
    for col_data in cols_data:
        # TODO: a lot of ORs here for new constraints coming from structure, no?
        # in theory what arrives from the constraints should have priority I guess
        yield ColIR(
            name=col_data.name,
            data_type=col_data.data_type,
            primary_key=constraints.is_primary_key(table_name, col_data.name),
            not_null=not col_data.is_nullable,  # TODO: handle this into a bool
            unique=constraints.is_unique(table_name, col_data.name),
            foreign_key=constraints.get_foreign_key(table_name, col_data.name)
        )
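
A note on the design: `ir_build` depends only on the `QCollector` protocol, so any backend able to enumerate table names, columns, and constraints can reuse the same IR construction. A minimal, hypothetical in-memory collector (not part of this PR) illustrating the contract:

```python
from sqlmodelgen.ir.query import ColQueryData, ContraintsData, ir_build


class StaticCollector:
    # Hypothetical collector backed by hard-coded metadata, for illustration only.

    def collect_table_names(self):
        yield 'hero'

    def collect_columns(self, table_name: str):
        yield ColQueryData(name='id', data_type='int', is_nullable=False)
        yield ColQueryData(name='name', data_type='varchar', is_nullable=True)

    def collect_constraints(self) -> ContraintsData:
        return ContraintsData(
            uniques={},
            primary_keys={'hero': {'id'}},
            foreign_keys={},
        )


schema_ir = ir_build(collector=StaticCollector())
```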