Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ version = { source = "file", path = "tortoise/__init__.py" }
excludes = ["./**/.git", "./**/.*_cache", "examples"]
include = ["CHANGELOG.rst", "LICENSE", "README.rst"]

[tool.uv.sources]
pypika-tortoise = { git = "https://github.com/seladb/pypika-tortoise", branch = "add-functions" }

[tool.mypy]
pretty = true
exclude = ["docs"]
Expand Down
100 changes: 98 additions & 2 deletions tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,23 @@
Tournament,
)
from tortoise.contrib import test
from tortoise.contrib.test.condition import In, NotEQ
from tortoise.contrib.test import requireCapability
from tortoise.contrib.test.condition import In, NotEQ, NotIn
from tortoise.expressions import Case, F, Q, When
from tortoise.functions import Coalesce, Count, Length, Lower, Max, Trim, Upper
from tortoise.functions import (
Coalesce,
Count,
Length,
Lower,
LPad,
LTrim,
Max,
Replace,
RPad,
RTrim,
Trim,
Upper,
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -379,6 +393,88 @@ async def test_filter_by_aggregation_field_trim(db):
assert {(t.name, t.trimmed_name) for t in tournaments} == {(" 1 ", "1")}


@pytest.mark.asyncio
@pytest.mark.parametrize(
["name", "trim_chars", "trimmed_name"],
[
("xxxhellox", "x", "hello"),
("ababhelloab", "ab", "hello"),
],
)
async def test_filter_by_trim_with_chars(db, name, trim_chars, trimmed_name):
await Tournament.create(name=name)
tournaments = await Tournament.annotate(trimmed_name=Trim("name", trim_chars)).filter(
trimmed_name=trimmed_name
)

assert len(tournaments) == 1
assert {(t.name, t.trimmed_name) for t in tournaments} == {(name, trimmed_name)}


@pytest.mark.asyncio
async def test_filter_by_ltrim(db):
await Tournament.create(name=" hello ")
tournaments = await Tournament.annotate(trimmed_name=LTrim("name")).filter(
trimmed_name="hello "
)

assert len(tournaments) == 1
assert {(t.name, t.trimmed_name) for t in tournaments} == {(" hello ", "hello ")}


@pytest.mark.asyncio
async def test_filter_by_rtrim(db):
await Tournament.create(name=" hello ")
tournaments = await Tournament.annotate(trimmed_name=RTrim("name")).filter(
trimmed_name=" hello"
)

assert len(tournaments) == 1
assert {(t.name, t.trimmed_name) for t in tournaments} == {(" hello ", " hello")}


@requireCapability(dialect=NotIn("sqlite"))
@pytest.mark.asyncio
async def test_lpad(db):
await Tournament.create(name="hello")
await Tournament.create(name="my world")
tournaments = await Tournament.annotate(pad_name=LPad("name", 12, "x"))
result = set(tournament.pad_name for tournament in tournaments)
assert result == {"xxxxmy world", "xxxxxxxhello"}


@requireCapability(dialect=NotIn("sqlite"))
@pytest.mark.asyncio
async def test_rpad(db):
await Tournament.create(name="hello")
await Tournament.create(name="my world")
tournaments = await Tournament.annotate(pad_name=RPad("name", 12, "x"))
result = set(tournament.pad_name for tournament in tournaments)
assert result == {"my worldxxxx", "helloxxxxxxx"}


@pytest.mark.asyncio
async def test_replace(db):
await Tournament.create(name="Tournament A")
await Tournament.create(name="Tournament B")
tournaments = await Tournament.annotate(replaced_name=Replace("name", "Tournament", "Contest"))
result = {t.replaced_name for t in tournaments}
assert result == {"Contest A", "Contest B"}


@pytest.mark.asyncio
async def test_filter_by_replace(db):
await Tournament.create(name="1st Tournament")
await Tournament.create(name="2nd Tournament")
await Tournament.create(name="3rd Place")

tournaments = await Tournament.annotate(
replaced_name=Replace("name", "Tournament", "Contest")
).filter(replaced_name="1st Contest")
assert len(tournaments) == 1
assert {(t.name, t.replaced_name) for t in tournaments} == {("1st Tournament", "1st Contest")}


@test.requireCapability(dialect=NotEQ("mssql"))
@pytest.mark.asyncio
async def test_filter_by_aggregation_field_length(db):
Expand Down
92 changes: 91 additions & 1 deletion tortoise/functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Any

from pypika_tortoise import SqlContext, functions
from pypika_tortoise.terms import Term

from tortoise.expressions import Aggregate, Function
from tortoise.expressions import Aggregate, CombinedExpression, F, Function

##############################################################################
# Standard functions
Expand All @@ -16,6 +19,93 @@ class Trim(Function):

database_func = functions.Trim

def __init__(
self,
field: str | F | CombinedExpression | Function | Term,
trim_chars: str = " ",
*default_values: Any,
) -> None:
super().__init__(field, trim_chars, *default_values)

database_func = functions.Trim


class LTrim(Function):
"""
Trims whitespace from the left side of text.

:samp:`LTrim("{FIELD_NAME}")`
"""

database_func = functions.LTrim


class RTrim(Function):
"""
Trims whitespace from the right side of text.

:samp:`RTrim("{FIELD_NAME}")`
"""

database_func = functions.RTrim


class LPad(Function):
"""
Pads the left side of a string with a specified character to reach a certain length.

:samp:`LPad("{FIELD_NAME}", length, fill_text)`
"""

def __init__(
self,
field: str | F | CombinedExpression | Function | Term,
length: int,
fill_text: str = " ",
*default_values: Any,
) -> None:
super().__init__(field, length, fill_text, *default_values)

database_func = functions.LPad


class RPad(Function):
"""
Pads the right side of a string with a specified character to reach a certain length.

:samp:`RPad("{FIELD_NAME}", length, fill_text)`
"""

def __init__(
self,
field: str | F | CombinedExpression | Function | Term,
length: int,
fill_text: str = " ",
*default_values: Any,
) -> None:
super().__init__(field, length, fill_text, *default_values)

database_func = functions.RPad


class Replace(Function):
"""
Replaces all occurrences of a search string with a replacement string.

:samp:`Replace("{FIELD_NAME}", "search", "replacement")`
"""

def __init__(
self,
field: str | F | CombinedExpression | Function | Term,
search: str,
replacement: str,
*default_values: Any,
) -> None:
super().__init__(field, search, replacement, *default_values)

database_func = functions.Replace


class Length(Function):
"""
Expand Down
8 changes: 2 additions & 6 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading