Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions src/ferro/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
get_origin,
)

from pydantic import BaseModel, Field as PydanticField
from pydantic import BaseModel
from pydantic import Field as PydanticField
from pydantic.fields import FieldInfo

from ._core import register_model_schema
Expand Down Expand Up @@ -188,7 +189,14 @@ def __new__(mcs, name, bases, namespace, **kwargs):
if isinstance(metadata.to, ForwardRef):
target_name = metadata.to.__forward_arg__

setattr(cls, field_name, ForwardDescriptor(field_name, target_name))
setattr(
cls,
field_name,
ForwardDescriptor(
target_model_name=target_name,
field_name=field_name,
),
)
else:
setattr(cls, field_name, None)

Expand All @@ -205,9 +213,9 @@ def __new__(mcs, name, bases, namespace, **kwargs):
if "properties" in schema:
for f_name, metadata in ferro_fields.items():
if f_name in schema["properties"]:
schema["properties"][f_name][
"primary_key"
] = metadata.primary_key
schema["properties"][f_name]["primary_key"] = (
metadata.primary_key
)
prop = schema["properties"][f_name]
is_int = prop.get("type") == "integer" or any(
item.get("type") == "integer"
Expand Down
24 changes: 13 additions & 11 deletions src/ferro/relations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def resolve_relationships():
target_model,
rel.related_name,
RelationshipDescriptor(
model_name, field_name, is_one_to_one=getattr(rel, "unique", False)
target_model_name=model_name,
field_name=field_name,
is_one_to_one=getattr(rel, "unique", False),
),
)
elif isinstance(rel, ManyToManyField):
Expand All @@ -68,8 +70,8 @@ def resolve_relationships():
_MODEL_REGISTRY_PY[model_name],
field_name,
RelationshipDescriptor(
target_model.__name__,
field_name,
target_model_name=target_model.__name__,
field_name=field_name,
is_m2m=True,
join_table=join_table,
source_col=source_col,
Expand All @@ -81,8 +83,8 @@ def resolve_relationships():
target_model,
rel.related_name,
RelationshipDescriptor(
model_name,
rel.related_name,
target_model_name=model_name,
field_name=rel.related_name,
is_m2m=True,
join_table=join_table,
source_col=target_col, # Reversed for the back side
Expand Down Expand Up @@ -119,12 +121,12 @@ def resolve_relationships():
if "properties" in schema:
for f_name, metadata in model_cls.ferro_fields.items():
if f_name in schema["properties"]:
schema["properties"][f_name][
"primary_key"
] = metadata.primary_key
schema["properties"][f_name][
"autoincrement"
] = metadata.autoincrement
schema["properties"][f_name]["primary_key"] = (
metadata.primary_key
)
schema["properties"][f_name]["autoincrement"] = (
metadata.autoincrement
)
schema["properties"][f_name]["unique"] = metadata.unique
schema["properties"][f_name]["index"] = metadata.index

Expand Down
44 changes: 20 additions & 24 deletions src/ferro/relations/descriptors.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
from typing import TYPE_CHECKING

from pydantic import BaseModel

if TYPE_CHECKING:
from ferro.models import Model

from ..state import _MODEL_REGISTRY_PY


class RelationshipDescriptor:
class RelationshipDescriptor(BaseModel):
"""Descriptor that returns either a Query object or a single object (for 1:1)."""

def __init__(
self,
target_model_name: str,
field_name: str,
is_one_to_one: bool = False,
is_m2m: bool = False,
join_table: str | None = None,
source_col: str | None = None,
target_col: str | None = None,
):
self.target_model_name = target_model_name
self.field_name = field_name
self.is_one_to_one = is_one_to_one
self.is_m2m = is_m2m
self.join_table = join_table
self.source_col = source_col
self.target_col = target_col
self._target_model = None
target_model_name: str
field_name: str
is_one_to_one: bool = False
is_m2m: bool = False
join_table: str | None = None
source_col: str | None = None
target_col: str | None = None
_target_model: Model | None = None

def __get__(self, instance, owner):
if instance is None:
Expand Down Expand Up @@ -70,13 +67,12 @@ def __get__(self, instance, owner):
return query


class ForwardDescriptor:
class ForwardDescriptor(BaseModel):
"""Descriptor that handles lazy loading of a related object."""

def __init__(self, field_name: str, target_model_name: str):
self.field_name = field_name
self.target_model_name = target_model_name
self._target_model = None
target_model_name: str
field_name: str
_target_model: Model | None = None

def __get__(self, instance, owner):
if instance is None:
Expand Down
43 changes: 42 additions & 1 deletion tests/test_auto_migrate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import Annotated

import pytest
from pydantic import Field

import ferro
from ferro import Model
from pydantic import Field
from ferro.base import FerroField, ManyToManyField
from ferro.query import BackRef


class AutoMigratedUser(Model):
Expand Down Expand Up @@ -35,3 +40,39 @@ async def test_connect_without_auto_migrate():
# Manual call still works
await ferro.create_tables()
assert True


@pytest.mark.asyncio
async def test_m2m_join_table_created_during_auto_migrate():
"""Verify that the many-to-many join table is created when auto_migrate=True.
We clear registries, migrate a fresh in-memory DB, then use the M2M API; if the
join table were not created, .add() would fail. No second connection needed."""
from ferro import clear_registry, connect, reset_engine
from ferro.state import _JOIN_TABLE_REGISTRY, _MODEL_REGISTRY_PY, _PENDING_RELATIONS

reset_engine()
clear_registry()
_MODEL_REGISTRY_PY.clear()
_PENDING_RELATIONS.clear()
_JOIN_TABLE_REGISTRY.clear()

class Actor(Model):
id: Annotated[int | None, FerroField(primary_key=True)] = None
name: str
movies: Annotated[list["Movie"], ManyToManyField(related_name="actors")] = None

class Movie(Model):
id: Annotated[int | None, FerroField(primary_key=True)] = None
title: str
actors: BackRef[Actor] = None

await connect("sqlite::memory:", auto_migrate=True)

actor = await Actor.create(name="Alice")
movie = await Movie.create(title="Matrix")
await actor.movies.add(movie)

linked = await actor.movies.all()
assert len(linked) == 1
assert linked[0].id == movie.id
assert linked[0].title == "Matrix"
20 changes: 9 additions & 11 deletions tests/test_relationship_engine.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import pytest
from typing import Annotated, ForwardRef

import pytest

from ferro import (
Model,
BackRef,
FerroField,
Field,
reset_engine,
clear_registry,
ForeignKey,
BackRef,
Model,
clear_registry,
reset_engine,
)


Expand Down Expand Up @@ -150,9 +152,7 @@ def test_back_ref_via_annotated_field():
class UserAnnotated(Model):
id: Annotated[int | None, FerroField(primary_key=True)] = None
username: str
posts: Annotated[
list["PostAnnotated"] | None, Field(back_ref=True)
] = None
posts: Annotated[list["PostAnnotated"] | None, Field(back_ref=True)] = None

class PostAnnotated(Model):
id: Annotated[int | None, FerroField(primary_key=True)] = None
Expand All @@ -179,9 +179,7 @@ def test_back_ref_and_field_back_ref_raises():
class UserDouble(Model):
id: Annotated[int | None, FerroField(primary_key=True)] = None
username: str
posts: BackRef[list["PostDouble"]] = Field(
default=None, back_ref=True
)
posts: BackRef[list["PostDouble"]] = Field(default=None, back_ref=True)

class PostDouble(Model):
id: Annotated[int | None, FerroField(primary_key=True)] = None
Expand Down
12 changes: 7 additions & 5 deletions tests/test_schema_constraints.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import pytest
import sqlite3
from typing import Annotated

import pytest

from ferro import (
Model,
connect,
BackRef,
FerroField,
ForeignKey,
BackRef,
reset_engine,
Model,
clear_registry,
connect,
reset_engine,
)


Expand Down