Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions src/openjd/model/_format_strings/_dyn_constrained_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@

from typing import Any, Callable, Optional, Pattern, Union

from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler, ValidationInfo
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
import re

from .._types import ModelParsingContextInterface


class DynamicConstrainedStr(str):
"""Constrained string type for interfacing with Pydantic.
The maximum string length can be dynamically defined at runtime.

The parsing context, a subclass of ModelParsingContextInterface,
is required to construct a DynamicConstrainedStr or subclass.
This enables the FormatString to handle the supported expression types
based on the Open Job Description revision version and extensions.

Note: Does *not* run model validation when constructed.
"""

Expand All @@ -23,14 +30,17 @@ class DynamicConstrainedStr(str):
# ================================
# Reference: https://pydantic-docs.helpmanual.io/usage/types/#custom-data-types

def __new__(cls, value: str, *, context: ModelParsingContextInterface):
return super().__new__(cls, value)

@classmethod
def _get_max_length(cls) -> Optional[int]:
if callable(cls._max_length):
return cls._max_length()
return cls._max_length

@classmethod
def _validate(cls, value: str) -> Any:
def _validate(cls, value: str, info: ValidationInfo) -> Any:
if not isinstance(value, str):
raise ValueError("String required")

Expand All @@ -46,13 +56,20 @@ def _validate(cls, value: str) -> Any:
pattern: str = cls._regex if isinstance(cls._regex, str) else cls._regex.pattern
raise ValueError(f"String does not match the required pattern: {pattern}")

return cls(value)
if type(value) is cls:
return value
else:
if info.context is None:
raise ValueError(
f"Internal parsing error: No parsing context was provided during model validation for the DynamicConstrainedStr subclass {cls.__name__}."
)
return cls(value, context=info.context)

@classmethod
def __get_pydantic_core_schema__(
cls, source_type: type[Any], handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_plain_validator_function(cls._validate)
return core_schema.with_info_plain_validator_function(cls._validate)

@classmethod
def __get_pydantic_json_schema__(
Expand Down
8 changes: 4 additions & 4 deletions src/openjd/model/_format_strings/_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from .._errors import ExpressionError
from .._symbol_table import SymbolTable
from ._nodes import Node
from ._parser import Parser
from ._parser import parse_format_string_expr
from .._types import ModelParsingContextInterface


class InterpolationExpression:
expr: str
_expression_tree: Node

def __init__(self, expr: str) -> None:
def __init__(self, expr: str, *, context: ModelParsingContextInterface) -> None:
"""Constructor.

Raises:
Expand All @@ -24,10 +25,9 @@ def __init__(self, expr: str) -> None:
expr (str): The expression
"""
self.expr = expr
parser = Parser()

# Raises: ExpressionError, TokenError
self._expresion_tree = parser.parse(expr)
self._expresion_tree = parse_format_string_expr(expr, context=context)

def validate_symbol_refs(self, *, symbols: set[str]) -> None:
"""Check whether this expression can be evaluated correctly given a set of symbol names.
Expand Down
21 changes: 13 additions & 8 deletions src/openjd/model/_format_strings/_format_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .._symbol_table import SymbolTable
from ._dyn_constrained_str import DynamicConstrainedStr
from ._expression import InterpolationExpression
from .._types import ModelParsingContextInterface


@dataclass
Expand All @@ -30,12 +31,11 @@ def __init__(self, *, string: str, start: int, end: int, expr: str = "", details
)
super().__init__(msg)

def __str__(self) -> str:
return self.args[0]


class FormatString(DynamicConstrainedStr):
def __init__(self, value: str):
_processed_list: list[Union[str, ExpressionInfo]]

def __new__(cls, value: str, *, context: ModelParsingContextInterface):
"""
Instantiate a FormatString from a given string.

Expand All @@ -55,8 +55,9 @@ def __init__(self, value: str):
------
FormatStringError: if the original string is nonvalid.
"""
# Note: str is constructed in __new__, so don't call super __init__
self._processed_list: list[Union[str, ExpressionInfo]] = self._preprocess()
self = super().__new__(cls, value, context=context)
self._processed_list = self._preprocess(context=context)
return self

@property
def original_value(self) -> str:
Expand Down Expand Up @@ -125,7 +126,9 @@ def resolve(self, *, symtab: SymbolTable) -> str:

return "".join(resolved_list)

def _preprocess(self) -> list[Union[str, ExpressionInfo]]:
def _preprocess(
self, *, context: ModelParsingContextInterface
) -> list[Union[str, ExpressionInfo]]:
"""
Scans through the original string to find all interpolation expressions inside of {{ }}.
Also, validates the content of each interpolation expression inside of {{ }}.
Expand Down Expand Up @@ -187,7 +190,9 @@ def _preprocess(self) -> list[Union[str, ExpressionInfo]]:

expression_info = ExpressionInfo(braces_start, braces_end)
try:
expr = InterpolationExpression(self[expression_start:expression_end])
expr = InterpolationExpression(
self[expression_start:expression_end], context=context
)
except (ExpressionError, TokenError) as exc:
raise FormatStringError(
string=self.original_value,
Expand Down
21 changes: 19 additions & 2 deletions src/openjd/model/_format_strings/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,30 @@
from .._tokenstream import Token, TokenStream, TokenType
from ._nodes import FullNameNode, Node
from ._tokens import DotToken, NameToken
from .._types import ModelParsingContextInterface

_tokens: dict[TokenType, Type[Token]] = {TokenType.NAME: NameToken, TokenType.DOT: DotToken}


class Parser:
def parse_format_string_expr(expr: str, *, context: ModelParsingContextInterface) -> Node:
"""Generate an expression tree for the given string interpolation expression.

Args:
expr (str): A string interpolation expression

Raises:
ExpressionError: If the given expression does not adhere to the grammar.
TokenError: If the given expression contains nonvalid or unexpected tokens.

Returns:
Node: Root of the expression tree.
"""
return FormatStringExprParser_v2023_09().parse(expr)


class FormatStringExprParser_v2023_09:
"""
Parser used to build an AST of the currently supported operations.
Parser used to build an AST of format strings for the 2023-09 specification.
"""

def parse(self, expr: str) -> Node:
Expand Down
Loading