Skip to content
Merged

V0.1 #12

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 59 additions & 3 deletions internal_admin/admin/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from typing import Any

from sqlalchemy import Boolean, Date, DateTime
from sqlalchemy.inspection import inspect as sa_inspect
from sqlalchemy.orm import Session
from sqlalchemy.sql.sqltypes import TypeDecorator

from .model_admin import ModelAdmin

Expand Down Expand Up @@ -204,9 +206,44 @@ def get_choices(self, session: Session, model_class: type[Any]) -> list[tuple[An
if not hasattr(model_class, self.field_name):
return []

# This is simplified - in practice you'd need proper relationship introspection
# For now, return empty choices
return []
try:
relationship = self._find_relationship(model_class)
if relationship is None:
return []

related_model = relationship.mapper.class_
related_mapper = sa_inspect(related_model)
pk_attr = related_mapper.primary_key[0].key

label_attr = self.display_field if hasattr(related_model, self.display_field) else None
if label_attr is None:
for candidate in ("display_name", "name", "title", "username", "email"):
if hasattr(related_model, candidate):
label_attr = candidate
break

query = session.query(related_model)
if label_attr:
query = query.order_by(getattr(related_model, label_attr).asc())
else:
query = query.order_by(getattr(related_model, pk_attr).asc())

rows = query.limit(200).all()
choices = []
for row in rows:
value = getattr(row, pk_attr)
if label_attr:
label_value = getattr(row, label_attr, None)
display = str(label_value) if label_value not in (None, "") else str(value)
else:
display = str(row)
if display.startswith("<") and " object at " in display:
display = str(value)
choices.append((value, display))

return choices
except Exception:
return []

def apply_filter(self, query: Any, value: Any) -> Any:
"""Apply foreign key filter."""
Expand All @@ -216,8 +253,27 @@ def apply_filter(self, query: Any, value: Any) -> Any:
model_class = query.column_descriptions[0]['type']
field = getattr(model_class, self.field_name)

column = model_class.__table__.columns.get(self.field_name)
if column is not None:
column_type = type(column.type)
if isinstance(column.type, TypeDecorator):
column_type = type(column.type.impl)
try:
if column_type.__name__ in {"Integer", "BigInteger", "SmallInteger"}:
value = int(value)
except (TypeError, ValueError):
return query

return query.filter(field == value)

def _find_relationship(self, model_class: type[Any]) -> Any | None:
mapper = sa_inspect(model_class)
for relationship in mapper.relationships:
for local_column in relationship.local_columns:
if local_column.key == self.field_name:
return relationship
return None


class FilterManager:
"""
Expand Down
69 changes: 56 additions & 13 deletions internal_admin/admin/form_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from typing import Any

from sqlalchemy import Boolean, Column, Date, DateTime, Float, Integer, String, Text
from sqlalchemy.inspection import inspect as sa_inspect
from sqlalchemy.orm import Session
from sqlalchemy.sql.sqltypes import TypeDecorator

from ..registry import get_registry
from .model_admin import ModelAdmin


Expand Down Expand Up @@ -53,6 +55,7 @@ def __init__(self, model_admin: ModelAdmin) -> None:
self.model_admin = model_admin
self.model = model_admin.model
self._type_mapping = self._get_type_mapping()
self._foreign_key_choice_limit = 200

def generate_form_fields(self, session: Session, instance: Any | None = None) -> list[FormField]:
"""
Expand Down Expand Up @@ -198,24 +201,64 @@ def _get_foreign_key_choices(self, column: Column, session: Session) -> list[tup
Returns:
List of (value, label) tuples
"""
choices = [("", "-- Select --")]
related_model = self._get_related_model_for_column(column)
if related_model is None:
return []

# Get the referenced table and model
list(column.foreign_keys)[0]

# Find the model class for the referenced table
# This is a simplified approach - in practice, you might need
# a more sophisticated model registry lookup
try:
# Try to find model class by table name
# This requires models to be registered or discoverable
mapper = sa_inspect(related_model)
pk_attr = mapper.primary_key[0].key

# For now, skip foreign key choices - can be implemented later
# when we have better model discovery
return choices
label_attr = self._resolve_related_label_attr(related_model)
query = session.query(related_model)
if label_attr and hasattr(related_model, label_attr):
query = query.order_by(getattr(related_model, label_attr).asc())
else:
query = query.order_by(getattr(related_model, pk_attr).asc())

rows = query.limit(self._foreign_key_choice_limit).all()
return [
(getattr(row, pk_attr), self._get_related_display_value(row, label_attr, pk_attr))
for row in rows
]
except Exception:
return choices
return []

def _get_related_model_for_column(self, column: Column) -> type[Any] | None:
relationships = sa_inspect(self.model).relationships
for relationship in relationships:
if column in relationship.local_columns:
return relationship.mapper.class_

try:
foreign_key = next(iter(column.foreign_keys))
except StopIteration:
return None

referenced_table = foreign_key.column.table
for model_class in get_registry().get_registered_models().keys():
if getattr(model_class, "__table__", None) is referenced_table:
return model_class

return None

def _resolve_related_label_attr(self, related_model: type[Any]) -> str | None:
preferred = ("display_name", "name", "title", "username", "email")
for attr_name in preferred:
if hasattr(related_model, attr_name):
return attr_name
return None

def _get_related_display_value(self, row: Any, label_attr: str | None, pk_attr: str) -> str:
if label_attr:
value = getattr(row, label_attr, None)
if value not in (None, ""):
return str(value)

value = str(row)
if value.startswith("<") and " object at " in value:
return str(getattr(row, pk_attr))
return value

def validate_form_data(self, form_data: dict[str, Any]) -> dict[str, Any]:
"""
Expand Down
44 changes: 44 additions & 0 deletions internal_admin/admin/query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from typing import Any

from sqlalchemy import or_
from sqlalchemy import Boolean, Date, DateTime, Float, Integer
from sqlalchemy.orm import Query, Session
from sqlalchemy.orm.strategy_options import selectinload
from sqlalchemy.sql.sqltypes import TypeDecorator

from .model_admin import ModelAdmin

Expand Down Expand Up @@ -199,11 +201,17 @@ def _apply_filters(self, query: Query, filters: dict[str, Any] | None) -> Query:
continue

field = getattr(self.model, field_name)
column = self.model.__table__.columns.get(field_name)

# Skip empty values
if value is None or value == "":
continue

try:
value = self._coerce_filter_value(column, value)
except ValueError:
continue

# Handle different filter types
if isinstance(value, (list, tuple)):
# Multiple values - use IN clause
Expand All @@ -217,6 +225,42 @@ def _apply_filters(self, query: Query, filters: dict[str, Any] | None) -> Query:

return query

def _coerce_filter_value(self, column: Any, value: Any) -> Any:
if column is None:
return value

if isinstance(value, (list, tuple)):
return [self._coerce_filter_value(column, item) for item in value]

column_type = type(column.type)
if isinstance(column.type, TypeDecorator):
column_type = type(column.type.impl)

if column_type == Boolean:
if isinstance(value, bool):
return value
normalized = str(value).strip().lower()
if normalized in {"true", "1", "yes", "on"}:
return True
if normalized in {"false", "0", "no", "off"}:
return False
raise ValueError("Invalid boolean filter value")

if column_type == Integer:
return int(value)
if column_type == Float:
return float(value)
if column_type == Date:
if isinstance(value, str):
return Date().python_type.fromisoformat(value)
return value
if column_type == DateTime:
if isinstance(value, str):
return DateTime().python_type.fromisoformat(value.replace("T", " "))
return value

return value

def _apply_ordering(self, query: Query, ordering: list[str] | None) -> Query:
"""
Apply ordering to query.
Expand Down
2 changes: 1 addition & 1 deletion internal_admin/templates/admin/form.html
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ <h5 class="card-title mb-0">{{ "Basic Information" if is_create else "Edit Detai
{% if field.choices %}
{% for value, display in field.choices %}
<option value="{{ value }}"
{% if form_data.get(field.name, field.default_value) == value %}selected{% endif %}>
{% if (form_data.get(field.name, field.default_value)|string) == (value|string) %}selected{% endif %}>
{{ display }}
</option>
{% endfor %}
Expand Down
71 changes: 69 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import create_engine, Column, Integer, String, Boolean
from sqlalchemy import create_engine, Column, Integer, String, Boolean, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.orm import sessionmaker, Session, relationship

from internal_admin import AdminSite, AdminConfig, ModelAdmin
from internal_admin.auth.models import AdminUser
Expand All @@ -30,6 +30,10 @@ class TestUser(Base):
is_active = Column(Boolean, default=True)
is_superuser = Column(Boolean, default=False)

@property
def display_name(self) -> str:
return self.username or f"User {self.id}"


class TestModel(Base):
"""Simple test model for admin testing."""
Expand All @@ -41,13 +45,51 @@ class TestModel(Base):
is_active = Column(Boolean, default=True)


class TestCategory(Base):
"""Related model for foreign key tests."""
__tablename__ = "test_categories"

id = Column(Integer, primary_key=True)
name = Column(String(100), nullable=False)

products = relationship("TestProduct", back_populates="category")

def __str__(self) -> str:
return self.name


class TestProduct(Base):
"""Model containing a foreign key for admin tests."""
__tablename__ = "test_products"

id = Column(Integer, primary_key=True)
name = Column(String(100), nullable=False)
category_id = Column(Integer, ForeignKey("test_categories.id"), nullable=False)
is_active = Column(Boolean, default=True)

category = relationship("TestCategory", back_populates="products")


class TestModelAdmin(ModelAdmin):
"""Test ModelAdmin configuration."""
list_display = ["id", "name", "is_active"]
search_fields = ["name", "description"]
list_filter = ["is_active"]


class TestCategoryAdmin(ModelAdmin):
"""Admin configuration for category model."""
list_display = ["id", "name"]
search_fields = ["name"]


class TestProductAdmin(ModelAdmin):
"""Admin configuration for product model."""
list_display = ["id", "name", "category_id", "is_active"]
search_fields = ["name"]
list_filter = ["category_id", "is_active"]


@pytest.fixture(scope="session")
def test_db() -> Generator[str, None, None]:
"""Create a temporary test database."""
Expand Down Expand Up @@ -87,6 +129,8 @@ def admin_site(admin_config: AdminConfig) -> AdminSite:
# Create fresh AdminSite after clearing registry
site = AdminSite(admin_config)
site.register(TestModel, TestModelAdmin)
site.register(TestCategory, TestCategoryAdmin)
site.register(TestProduct, TestProductAdmin)
return site


Expand Down Expand Up @@ -150,6 +194,29 @@ def test_objects(db_session: Session) -> list[TestModel]:
return objects


@pytest.fixture
def fk_objects(db_session: Session) -> dict[str, Any]:
"""Create related objects for foreign key tests."""
categories = [
TestCategory(name="Hardware"),
TestCategory(name="Software"),
]
db_session.add_all(categories)
db_session.flush()

products = [
TestProduct(name="Keyboard", category_id=categories[0].id, is_active=True),
TestProduct(name="IDE License", category_id=categories[1].id, is_active=True),
]
db_session.add_all(products)
db_session.commit()

return {
"categories": categories,
"products": products,
}


@pytest.fixture
def authenticated_client(client: TestClient, test_user: TestUser) -> TestClient:
"""Create authenticated test client."""
Expand Down
Loading
Loading