Skip to content

Commit 160cfce

Browse files
committed
God bless our pity souls
1 parent aead9a7 commit 160cfce

4 files changed

Lines changed: 137 additions & 140 deletions

File tree

src/api/api.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from pydantic import BaseModel
33
from fastapi.middleware.cors import CORSMiddleware
44

5-
from .models import CreateCamera, CreateZone
5+
from .models import *
66

77
import cv2
88
from fastapi.responses import StreamingResponse
@@ -224,13 +224,7 @@ async def get_camera(camera_id: int):
224224
async def update_camera(camera_id: int, updated_fields: Request):
225225
try:
226226
updated_fields = await updated_fields.json()
227-
228-
if not isinstance(updated_fields, dict):
229-
raise HTTPException(
230-
status_code=400,
231-
detail="Request body must be a JSON dict"
232-
)
233-
227+
234228
camera = self.db_manager.update_camera(camera_id, updated_fields)
235229

236230
if camera is None:
@@ -250,17 +244,9 @@ async def update_camera(camera_id: int, updated_fields: Request):
250244
)
251245

252246
@self.app.put("/zones/{zone_id}")
253-
async def update_zone(zone_id: int, updated_fields: Request):
247+
async def update_zone(zone_id: int, update: UpdateZone):
254248
try:
255-
updated_fields = await updated_fields.json()
256-
257-
if not isinstance(updated_fields, dict):
258-
raise HTTPException(
259-
status_code=400,
260-
detail="Request body must be a JSON dict"
261-
)
262-
263-
zone = self.db_manager.update_zone(zone_id, updated_fields)
249+
zone = self.db_manager.update_zone(zone_id, update)
264250

265251
if zone is None:
266252
raise HTTPException(

src/api/models.py

Lines changed: 40 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from typing import List, Dict, TypedDict, Any
2-
from pydantic import BaseModel, field_validator
1+
from typing import List, Dict, TypedDict, Any, Optional, Literal
2+
from pydantic import BaseModel, field_validator, Field, ValidationError
33
import json
44

55
class CreateCamera(BaseModel):
@@ -15,7 +15,7 @@ class CreateCamera(BaseModel):
1515
@classmethod
1616
def validate_title(cls, title):
1717
if len(title) < 1 or len(title) > 200:
18-
raise ValueError(f"Invalid camera title: {title}")
18+
raise ValidationError(f"Invalid camera title: {title}")
1919
return title
2020

2121
@field_validator('source')
@@ -27,28 +27,28 @@ def validate_source(cls, source):
2727
@classmethod
2828
def validate_latitude(cls, latitude):
2929
if latitude > 90 or latitude < -90:
30-
raise ValueError(f"Invalid latitude value: {latitude}")
30+
raise ValidationError(f"Invalid latitude value: {latitude}")
3131
return latitude
3232

3333
@field_validator('longitude')
3434
@classmethod
3535
def validate_longitude(cls, longitude):
3636
if longitude > 180 or longitude < -180:
37-
raise ValueError(f"Invalid longitude value: {longitude}")
37+
raise ValidationError(f"Invalid longitude value: {longitude}")
3838
return longitude
3939

4040
@field_validator('image_width')
4141
@classmethod
4242
def validate_image_width(cls, image_width):
4343
if image_width <= 0:
44-
raise ValueError(f"Invalid image_width value: {image_width}")
44+
raise ValidationError(f"Invalid image_width value: {image_width}")
4545
return image_width
4646

4747
@field_validator('image_height')
4848
@classmethod
4949
def validate_image_height(cls, image_height):
5050
if image_height <= 0:
51-
raise ValueError(f"Invalid image_height value: {image_height}")
51+
raise ValidationError(f"Invalid image_height value: {image_height}")
5252
return image_height
5353

5454
@field_validator('calib')
@@ -57,99 +57,52 @@ def validate_calib(cls, calib):
5757
if calib is not None:
5858
try:
5959
json.dumps(calib)
60-
except (TypeError, ValueError) as e:
61-
raise ValueError(f"Invalid calibration data: {e}")
60+
except:
61+
raise ValidationError(f"Invalid calibration data")
6262
return calib
6363

6464
class Point(BaseModel):
65-
latitude: float
66-
longitude: float
67-
x: int
68-
y: int
65+
latitude: float = Field(ge=-90, le=90)
66+
longitude: float = Field(ge=-180, le=180)
67+
x: int = Field(ge=0)
68+
y: int = Field(ge=0)
6969

7070
def __eq__(self, other):
7171
if not isinstance(other, Point):
7272
return False
7373
return (self.x == other.x and self.y == other.y)
7474

75-
@field_validator('latitude')
76-
@classmethod
77-
def validate_latitude(cls, latitude):
78-
if latitude > 90 or latitude < -90:
79-
raise ValueError(f"Invalid latitude value: {latitude}")
80-
return latitude
81-
82-
@field_validator('longitude')
83-
@classmethod
84-
def validate_longitude(cls, longitude):
85-
if longitude > 180 or longitude < -180:
86-
raise ValueError(f"Invalid longitude value: {longitude}")
87-
return longitude
88-
89-
@field_validator('x')
90-
@classmethod
91-
def validate_x(cls, x):
92-
if x < 0:
93-
raise ValueError(f"Invalid x value: {x}")
94-
95-
return x
75+
class ZoneBase(BaseModel):
76+
camera_id: Optional[int] = Field(None, ge=1)
77+
zone_type: Optional[Literal['parallel', 'standard']] = None
78+
capacity: Optional[int] = Field(None, gt=0)
79+
pay: Optional[int] = Field(None, ge=0)
80+
points: Optional[List[Point]] = None
9681

97-
@field_validator('y')
82+
@field_validator('points')
9883
@classmethod
99-
def validate_y(cls, y):
100-
if y < 0:
101-
raise ValueError(f"Invalid y value: {y}")
84+
def validate_points(cls, v: Optional[List[Point]]) -> Optional[List[Point]]:
85+
if v is None:
86+
return v
87+
88+
if len(v) != 4:
89+
raise ValidationError(f"Invalid points count: {len(v)}. Must be exactly 4 points")
10290

103-
return y
91+
for lhs in range(0, 4):
92+
for rhs in range(lhs + 1, 4):
93+
if v[lhs] == v[rhs]:
94+
raise ValidationError(f"Degenerate rectangle")
95+
96+
return v
10497

105-
class CreateZone(BaseModel):
98+
# 3. Модели запросов через наследование
99+
class CreateZone(ZoneBase):
106100
camera_id: int
107-
zone_type: str
108-
capacity: int
109-
pay: int
101+
zone_type: Literal['parallel', 'standard']
102+
capacity: int = Field(gt=0)
103+
pay: int = Field(ge=0)
110104
points: List[Point]
111105

112-
@field_validator('camera_id')
113-
@classmethod
114-
def validate_camera_id(cls, camera_id):
115-
if camera_id <= 0:
116-
raise ValueError(f"Invalid camera_id value: {camera_id}")
117-
118-
return camera_id
119-
120-
@field_validator('zone_type')
121-
@classmethod
122-
def validate_zone_type(cls, zone_type):
123-
if zone_type not in ['parallel', 'standard']:
124-
raise ValueError(f"Invalid zone_type value: {zone_type}")
125-
126-
return zone_type
127-
128-
@field_validator('capacity')
129-
@classmethod
130-
def validate_capacity(cls, capacity):
131-
if capacity <= 0:
132-
raise ValueError(f"Invalid capacity value: {capacity}")
133-
134-
return capacity
135-
136-
@field_validator('pay')
137-
@classmethod
138-
def validate_pay(cls, pay):
139-
if pay < 0:
140-
raise ValueError(f"Invalid pay value: {pay}")
141-
142-
return pay
143-
144-
@field_validator('points')
145-
@classmethod
146-
def validate_points(cls, points):
147-
if len(points) != 4:
148-
raise ValueError(f"Invalid points count: {len(points)}")
149-
150-
for lhs in range(0, len(points)):
151-
for rhs in range(lhs + 1, len(points)):
152-
if points[lhs] == points[rhs]:
153-
raise ValueError(f"Degenerate rectangle")
154-
155-
return points
106+
class UpdateZone(ZoneBase):
107+
occupied: Optional[int] = Field(None, ge=0)
108+
confidence: Optional[float] = Field(None, ge=0, le=1)

src/db_manager/db_manager.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
2-
from sqlalchemy import create_engine, inspect, text, func, update, select
2+
from sqlalchemy import create_engine, inspect, text, func, update, select, delete
33
from sqlalchemy.orm import sessionmaker, Session, joinedload
44
from sqlalchemy.exc import SQLAlchemyError
55
import contextlib
66
from typing import Generator, Optional
7+
from fastapi import HTTPException
78

89
from .models import Base, Camera, ParkingZone, ParkingZonePoint, datetime, timezone
910

@@ -120,6 +121,31 @@ def _create_tables(self):
120121
print(f"Error creating tables: {e}")
121122
raise
122123

124+
def _update_points(self, zone_id, points):
125+
if not isinstance(points, list[dict]) or len(points) != 4:
126+
raise HTTPException(
127+
status_code=400,
128+
detail="Invalid points request format"
129+
)
130+
131+
with self.get_session() as session:
132+
stmt = delete(ParkingZonePoint).where(ParkingZonePoint.parking_zone_id == zone_id)
133+
session.execute(stmt)
134+
135+
for point in points:
136+
stmt = update(ParkingZonePoint).where(ParkingZonePoint.id == camera_id)
137+
138+
stmt = stmt.values(updated_fields)
139+
140+
session.execute(stmt)
141+
142+
session.commit()
143+
144+
camera = session.query(Camera).filter(Camera.id == camera_id).one_or_none()
145+
146+
return camera.serialize() if camera is not None else None
147+
148+
123149
@contextlib.contextmanager
124150
def get_session(self) -> Generator[Session, None, None]:
125151
"""Контекстный менеджер для получения сессии"""
@@ -277,6 +303,12 @@ def get_camera(self, camera_id):
277303
return camera.serialize() if camera is not None else None
278304

279305
def update_camera(self, camera_id, updated_fields):
306+
if not isinstance(updated_fields, dict):
307+
raise HTTPException(
308+
status_code=400,
309+
detail="Request body must be a JSON dict"
310+
)
311+
280312
with self.get_session() as session:
281313
stmt = update(Camera).where(Camera.id == camera_id)
282314

@@ -290,16 +322,38 @@ def update_camera(self, camera_id, updated_fields):
290322

291323
return camera.serialize() if camera is not None else None
292324

293-
def update_zone(self, zone_id, updated_fields):
325+
def update_zone(self, zone_id, update):
326+
if self.get_zone(zone_id) is None: return None
327+
294328
with self.get_session() as session:
329+
330+
if update.points is not None:
331+
session.execute(
332+
delete(ParkingZonePoint).where(ParkingZonePoint.parking_zone_id == zone_id)
333+
)
334+
335+
for point in update.points:
336+
session.add(
337+
ParkingZonePoint(
338+
parking_zone_id=zone_id,
339+
latitude=point.latitude,
340+
longitude=point.longitude,
341+
x=point.x,
342+
y=point.y
343+
)
344+
)
345+
295346
stmt = update(ParkingZone).where(ParkingZone.id == zone_id)
296-
if "capacity" in updated_fields.keys():
297-
updated_fields['parking_lots_count'] = updated_fields.pop('capacity')
347+
348+
update_dump = update.model_dump(
349+
exclude_none=True,
350+
exclude={'points'}
351+
)
298352

299353
stmt = stmt.values(
300-
updated_fields | {"updated_at": datetime.now(timezone.utc)}
301-
if "occupied" not in updated_fields else
302-
updated_fields | {"occupancy_updated_at": datetime.now(timezone.utc)})
354+
update_dump | {"updated_at": datetime.now(timezone.utc)}
355+
if "occupied" not in update_dump else
356+
update_dump | {"occupancy_updated_at": datetime.now(timezone.utc)})
303357

304358
session.execute(stmt)
305359

0 commit comments

Comments
 (0)