Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 64 additions & 53 deletions vyper/vyper/ast/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tokenize
from decimal import Decimal
from functools import cached_property
from typing import Optional
from typing import Optional, TypeVar, cast

from vyper.ast import nodes as vy_ast
from vyper.ast.pre_parser import PreParser
Expand All @@ -20,6 +20,9 @@
python_ast.expr_context,
)

_AstT = TypeVar("_AstT", bound=python_ast.AST)
_DocstringNodeT = TypeVar("_DocstringNodeT", python_ast.Module, python_ast.FunctionDef)


def parse_to_ast(
vyper_source: str,
Expand Down Expand Up @@ -73,13 +76,13 @@ def _parse_to_ast(
"""
if "\x00" in vyper_source:
raise ParserException("No null bytes (\\x00) allowed in the source code.")
pre_parser = PreParser(is_interface)
pre_parser: PreParser = PreParser(is_interface)
pre_parser.parse(vyper_source)

try:
py_ast = python_ast.parse(pre_parser.reformatted_code)
py_ast: python_ast.Module = python_ast.parse(pre_parser.reformatted_code)
except SyntaxError as e:
offset = e.offset
offset: int | None = e.offset
if offset is not None:
# SyntaxError offset is 1-based, not 0-based (see:
# https://docs.python.org/3/library/exceptions.html#SyntaxError.offset)
Expand All @@ -89,10 +92,10 @@ def _parse_to_ast(
if e.lineno is not None: # help mypy
offset += pre_parser.adjustments.get((e.lineno, offset), 0)

new_e = SyntaxException(str(e), vyper_source, e.lineno, offset)
new_e: SyntaxException = SyntaxException(str(e), vyper_source, e.lineno, offset)

likely_errors = ("staticall", "staticcal")
tmp = str(new_e)
likely_errors: tuple[str, ...] = ("staticall", "staticcal")
tmp: str = str(new_e)
for s in likely_errors:
if s in tmp:
new_e._hint = "did you mean `staticcall`?"
Expand All @@ -117,7 +120,7 @@ def _parse_to_ast(
assert len(pre_parser.hex_string_locations) == 0

# Convert to Vyper AST.
module = vy_ast.get_node(py_ast)
module: vy_ast.VyperNode = vy_ast.get_node(py_ast)
assert isinstance(module, vy_ast.Module) # mypy hint
module.is_interface = is_interface

Expand Down Expand Up @@ -153,21 +156,24 @@ def annotate_python_ast(
-------
The annotated and optimized AST.
"""
visitor = AnnotatingVisitor(
visitor: AnnotatingVisitor = AnnotatingVisitor(
vyper_source, pre_parser, source_id, module_path=module_path, resolved_path=resolved_path
)
visitor.visit(parsed_ast)

return parsed_ast


def _deepcopy_ast(ast_node: python_ast.AST):
def _deepcopy_ast(ast_node: _AstT) -> _AstT:
# pickle roundtrip is faster than copy.deepcopy() here.
return pickle.loads(pickle.dumps(ast_node))


class AnnotatingVisitor(python_ast.NodeTransformer):
_source_code: str
_source_id: int
_module_path: Optional[str]
_resolved_path: Optional[str]
_pre_parser: PreParser
_parents: list[python_ast.AST]

Expand All @@ -178,7 +184,7 @@ def __init__(
source_id: int,
module_path: Optional[str] = None,
resolved_path: Optional[str] = None,
):
) -> None:
self._source_id = source_id
self._module_path = module_path
self._resolved_path = resolved_path
Expand All @@ -189,20 +195,20 @@ def __init__(
self.counter: int = 0

@cached_property
def source_lines(self):
def source_lines(self) -> list[str]:
return self._source_code.splitlines(keepends=True)

@cached_property
def line_offsets(self):
ofst = 0
def line_offsets(self) -> dict[int, int]:
ofst: int = 0
# ensure line_offsets has at least 1 entry for 0-line source
ret = {1: ofst}
ret: dict[int, int] = {1: ofst}
for lineno, line in enumerate(self.source_lines):
ret[lineno + 1] = ofst
ofst += len(line)
return ret

def generic_visit(self, node):
def generic_visit(self, node: _AstT) -> _AstT:
"""
Adds location info to all python ast nodes and replaces python ast nodes
that are singletons with a copy so that the location info will be unique,
Expand All @@ -222,11 +228,11 @@ def generic_visit(self, node):
# https://github.com/python/cpython/blob/62729d79206014886f5d/Lib/ast.py#L228
for field in LINE_INFO_FIELDS:
if len(self._parents) > 0:
parent = self._parents[-1]
val = getattr(node, field, None)
parent: python_ast.AST = self._parents[-1]
val: int | None = cast(int | None, getattr(node, field, None))
if val is None:
# try to get the field from the parent
val = getattr(parent, field)
val = cast(int, getattr(parent, field))
setattr(node, field, val)
else:
assert hasattr(node, field), node
Expand All @@ -238,37 +244,37 @@ def generic_visit(self, node):
self.counter += 1
node.ast_type = node.__class__.__name__

adjustments = self._pre_parser.adjustments
adjustments: dict[tuple[int, int], int] = self._pre_parser.adjustments

adj = adjustments.get((node.lineno, node.col_offset), 0)
adj: int = adjustments.get((node.lineno, node.col_offset), 0)
node.col_offset += adj

adj = adjustments.get((node.end_lineno, node.end_col_offset), 0)
node.end_col_offset += adj

start_pos = self.line_offsets[node.lineno] + node.col_offset
end_pos = self.line_offsets[node.end_lineno] + node.end_col_offset
start_pos: int = self.line_offsets[node.lineno] + node.col_offset
end_pos: int = self.line_offsets[node.end_lineno] + node.end_col_offset

node.src = f"{start_pos}:{end_pos-start_pos}:{self._source_id}"
node.node_source_code = self._source_code[start_pos:end_pos]

# keep track of the current path thru the AST
self._parents.append(node)
try:
node = super().generic_visit(node)
node = cast(_AstT, super().generic_visit(node))
finally:
self._parents.pop()

return node

def _visit_docstring(self, node):
def _visit_docstring(self, node: _DocstringNodeT) -> _DocstringNodeT:
"""
Move a node docstring from body to `doc_string` and annotate it as `DocStr`.
"""
self.generic_visit(node)

if node.body:
n = node.body[0]
n: python_ast.stmt = node.body[0]
if (
isinstance(n, python_ast.Expr)
and isinstance(n.value, python_ast.Constant)
Expand All @@ -281,7 +287,7 @@ def _visit_docstring(self, node):

return node

def visit_Module(self, node):
def visit_Module(self, node: python_ast.Module) -> python_ast.Module:
node.lineno = 1
node.col_offset = 0
node.end_lineno = max(1, len(self.source_lines))
Expand All @@ -299,10 +305,10 @@ def visit_Module(self, node):
node.source_id = self._source_id
return self._visit_docstring(node)

def visit_FunctionDef(self, node):
def visit_FunctionDef(self, node: python_ast.FunctionDef) -> python_ast.FunctionDef:
return self._visit_docstring(node)

def visit_ClassDef(self, node):
def visit_ClassDef(self, node: python_ast.ClassDef) -> python_ast.ClassDef:
"""
Convert the `ClassDef` node into a Vyper-specific node type.

Expand All @@ -315,13 +321,13 @@ def visit_ClassDef(self, node):
node.ast_type = self._pre_parser.keyword_translations[(node.lineno, node.col_offset)]
return node

def visit_For(self, node):
def visit_For(self, node: python_ast.For) -> python_ast.For:
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new return annotation `-> python_ast.For` is inconsistent with the function's actual return value: the final `return self.generic_visit(node)` returns `python_ast.AST` per the new `generic_visit` annotation on line 211, so mypy will flag this as an incompatible return type. Consider parameterizing `generic_visit` with a `TypeVar` (e.g. `_AstT`) so subclass methods can return their narrow types, or `cast(...)` the result here.

Severity: low


🤖 Was this useful? React with 👍 or 👎

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR Author Agent

Fixed in c6d510a by parameterizing `generic_visit` with `_AstT` and casting the `super().generic_visit(...)` result back to that type. Validation: `python -m py_compile vyper/vyper/ast/parse.py`; structural annotation check passed. (`python -m mypy ...` is unavailable here because mypy is not installed.)

"""
Visit a For node, splicing in the loop variable annotation provided by
the pre-parser
"""
key = (node.lineno, node.col_offset)
annotation_tokens = self._pre_parser.for_loop_annotations.pop(key)
key: tuple[int, int] = (node.lineno, node.col_offset)
annotation_tokens: list[tokenize.TokenInfo] = self._pre_parser.for_loop_annotations.pop(key)

if not annotation_tokens:
# a common case for people migrating to 0.4.0, provide a more
Expand Down Expand Up @@ -350,19 +356,21 @@ def visit_For(self, node):
# in a bit, but for now lets us keep the line/col offset, and
# *also* gives us a valid AST. it doesn't matter what the dummy
# target name is, since it gets removed in a few lines.
annotation_str = tokenize.untokenize(annotation_tokens)
annotation_str: str = tokenize.untokenize(annotation_tokens)
annotation_str = "dummy_target:" + annotation_str

try:
fake_node = python_ast.parse(annotation_str).body[0]
fake_stmt: python_ast.stmt = python_ast.parse(annotation_str).body[0]
assert isinstance(fake_stmt, python_ast.AnnAssign)
# do we need to fix location info here?
fake_node = _deepcopy_ast(fake_node)
fake_node: python_ast.AnnAssign = _deepcopy_ast(fake_stmt)
except SyntaxError as e:
raise SyntaxException(
"invalid type annotation", self._source_code, node.lineno, node.col_offset
) from e
# block things like `for x: uint256 = 5 in ...`
if (value_node := fake_node.value) is not None:
value_node: python_ast.expr | None = fake_node.value
if value_node is not None:
raise SyntaxException(
"invalid type annotation",
self._source_code,
Expand All @@ -377,7 +385,7 @@ def visit_For(self, node):

return self.generic_visit(node)

def visit_Expr(self, node):
def visit_Expr(self, node: python_ast.Expr) -> python_ast.expr | python_ast.Expr:
"""
Convert the `Yield` node into a Vyper-specific node type.

Expand All @@ -393,34 +401,35 @@ def visit_Expr(self, node):
if isinstance(node.value, python_ast.Yield):
# CMC 2024-03-03 consider unremoving this from the enclosing Expr
node = node.value
key = (node.lineno, node.col_offset)
key: tuple[int, int] = (node.lineno, node.col_offset)
node.ast_type = self._pre_parser.keyword_translations[key]

return node

def visit_Await(self, node):
start_pos = node.lineno, node.col_offset
def visit_Await(self, node: python_ast.Await) -> python_ast.Await:
start_pos: tuple[int, int] = (node.lineno, node.col_offset)
self.generic_visit(node)
node.ast_type = self._pre_parser.keyword_translations[start_pos]
return node

def visit_Call(self, node):
def visit_Call(self, node: python_ast.Call) -> python_ast.Call:
# Convert structs declared as `Dict` node for vyper < 0.4.0 to kwargs
if len(node.args) == 1 and isinstance(node.args[0], python_ast.Dict):
msg = "Instantiating a struct using a dictionary is deprecated "
msg: str = "Instantiating a struct using a dictionary is deprecated "
msg += "as of v0.4.0 and will be disallowed in a future release. "
msg += "Use kwargs instead e.g. Foo(a=1, b=2)"

# add full_source_code so that str(VyperException(msg, node)) works
node.full_source_code = self._source_code
vyper_warn(Deprecation(msg, node))

dict_ = node.args[0]
kw_list = []
dict_: python_ast.Dict = node.args[0]
kw_list: list[python_ast.keyword] = []

assert len(dict_.keys) == len(dict_.values)
for key, value in zip(dict_.keys, dict_.values):
replacement_kw_node = python_ast.keyword(key.id, value)
assert isinstance(key, python_ast.Name)
replacement_kw_node: python_ast.keyword = python_ast.keyword(key.id, value)
# set locations
for attr in LINE_INFO_FIELDS:
setattr(replacement_kw_node, attr, getattr(key, attr))
Expand All @@ -433,7 +442,7 @@ def visit_Call(self, node):

return node

def visit_Constant(self, node):
def visit_Constant(self, node: python_ast.Constant) -> python_ast.Constant:
"""
Handle `Constant` when using Python >=3.8

Expand All @@ -448,7 +457,7 @@ def visit_Constant(self, node):
if node.value is None or isinstance(node.value, bool):
node.ast_type = "NameConstant"
elif isinstance(node.value, str):
key = (node.lineno, node.col_offset)
key: tuple[int, int] = (node.lineno, node.col_offset)
if key in self._pre_parser.hex_string_locations:
if len(node.value) % 2 != 0:
raise SyntaxException(
Expand All @@ -475,7 +484,7 @@ def visit_Constant(self, node):

return node

def visit_Num(self, node):
def visit_Num(self, node: python_ast.Constant) -> python_ast.Constant:
"""
Adjust numeric node class based on the value type.

Expand All @@ -486,7 +495,7 @@ def visit_Num(self, node):
"""
# modify vyper AST type according to the format of the literal value
self.generic_visit(node)
value = node.node_source_code
value: str = node.node_source_code

# ignore underscores in numeric literals (PEP 515)
value = value.replace("_", "")
Expand All @@ -505,7 +514,7 @@ def visit_Num(self, node):

elif value.lower()[:2] == "0b":
node.ast_type = "Bytes"
mod = (len(value) - 2) % 8
mod: int = (len(value) - 2) % 8
if mod:
raise SyntaxException(
f"Bit notation requires a multiple of 8 bits. {8-mod} bit(s) are missing.",
Expand All @@ -527,16 +536,18 @@ def visit_Num(self, node):

return node

def visit_UnaryOp(self, node):
def visit_UnaryOp(self, node: python_ast.UnaryOp) -> python_ast.expr:
"""
Adjust operand value and discard unary operations, where possible.

This is done so that negative decimal literals are accurately represented.
"""
self.generic_visit(node)

is_sub = isinstance(node.op, python_ast.USub)
is_num = hasattr(node.operand, "value") and isinstance(node.operand.value, (int, Decimal))
is_sub: bool = isinstance(node.op, python_ast.USub)
is_num: bool = hasattr(node.operand, "value") and isinstance(
node.operand.value, (int, Decimal)
)
if is_sub and is_num:
node.operand.value = 0 - node.operand.value
node.operand.col_offset = node.col_offset
Expand Down