Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,11 @@ jobs:
timeout-minutes: 10

steps:
- name: Check out the repo
uses: actions/checkout@v4
- uses: actions/checkout@v4
with:
fetch-depth: 0
repository: ${{ github.event.pull_request.head.repo.full_name }}
ref: ${{ github.event.pull_request.head.ref }}

- name: Setup backend
- name: Setup
id: setup
uses: ./.github/actions/setup

Expand Down
14 changes: 5 additions & 9 deletions sqlalchemy_type_annotations/run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import codegen


from codegen import Codebase
from codegen.sdk.core.detached_symbols.function_call import FunctionCall
import subprocess
Expand Down Expand Up @@ -70,9 +72,7 @@ def run(codebase: Codebase):
continue

# Check for nullable=True
is_nullable = any(
x.name == "nullable" and x.value == "True" for x in db_column_call.args
)
is_nullable = any(x.name == "nullable" and x.value == "True" for x in db_column_call.args)

# Extract the first argument for the column type
first_argument = db_column_call.args[0].source or ""
Expand Down Expand Up @@ -101,9 +101,7 @@ def run(codebase: Codebase):

# Add necessary imports
if not cls.file.has_import("Mapped"):
cls.file.add_import_from_import_string(
"from sqlalchemy.orm import Mapped\n"
)
cls.file.add_import_from_import_string("from sqlalchemy.orm import Mapped\n")

if "Optional" in new_type and not cls.file.has_import("Optional"):
cls.file.add_import_from_import_string("from typing import Optional\n")
Expand All @@ -112,9 +110,7 @@ def run(codebase: Codebase):
cls.file.add_import_from_import_string("from decimal import Decimal\n")

if "datetime" in new_type and not cls.file.has_import("datetime"):
cls.file.add_import_from_import_string(
"from datetime import datetime\n"
)
cls.file.add_import_from_import_string("from datetime import datetime\n")

if class_modified:
classes_modified += 1
Expand Down