Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions vyper/tests/functional/syntax/exceptions/test_syntax_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,28 @@ def foo():
assert str(e.value) == expected_error.strip()


@pytest.mark.parametrize(
"bad_keyword",
[
"satticcall",
"staitccall",
"staticacll",
"staitcall",
],
)
def test_bad_staticcall_keyword_variants(bad_keyword):
bad_code = f"""
from ethereum.ercs import IERC20Detailed

def foo():
staticcall ERC20(msg.sender).transfer(msg.sender, {bad_keyword} IERC20Detailed(msg.sender).decimals())
""" # noqa
with pytest.raises(SyntaxException) as e:
compile_code(bad_code)

assert e.value.hint == "did you mean `staticcall`?"


@pytest.mark.parametrize(
"bad_literal",
[
Expand Down
72 changes: 66 additions & 6 deletions vyper/vyper/ast/parse.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast as python_ast
import copy
import pickle
import re
import tokenize
from decimal import Decimal
from functools import cached_property
Expand All @@ -20,6 +21,70 @@
python_ast.expr_context,
)

_STATICCALL_KEYWORD = "staticcall"


def _adjacent_transpositions(value: str) -> tuple[str, ...]:
ret = []
for i in range(len(value) - 1):
if value[i] == value[i + 1]:
continue

chars = list(value)
chars[i], chars[i + 1] = chars[i + 1], chars[i]
ret.append("".join(chars))

return tuple(ret)


_STATICCALL_LIKELY_ERRORS = ("staticcal", *_adjacent_transpositions(_STATICCALL_KEYWORD))


def _levenshtein_distance(source: str, target: str) -> int:
if source == target:
return 0
if not source:
return len(target)
if not target:
return len(source)

previous = list(range(len(target) + 1))
for i, source_ch in enumerate(source, start=1):
current = [i]
for j, target_ch in enumerate(target, start=1):
substitution_cost = 0 if source_ch == target_ch else 1
current.append(
min(
previous[j] + 1,
current[j - 1] + 1,
previous[j - 1] + substitution_cost,
)
)
previous = current

return previous[-1]


def _maybe_staticcall_hint(vyper_source: str, lineno: Optional[int]) -> Optional[str]:
if lineno is None:
return None

lines = vyper_source.splitlines()
if lineno < 1 or lineno > len(lines):
return None

likely_errors = _STATICCALL_LIKELY_ERRORS
for token in re.findall(r"[A-Za-z_][A-Za-z0-9_]*", lines[lineno - 1]):
if token == _STATICCALL_KEYWORD:
continue
if token in likely_errors:
return "did you mean `staticcall`?"
if abs(len(token) - len(_STATICCALL_KEYWORD)) <= 2:
if _levenshtein_distance(token, _STATICCALL_KEYWORD) <= 2:
return "did you mean `staticcall`?"

return None


def parse_to_ast(
vyper_source: str,
Expand Down Expand Up @@ -91,12 +156,7 @@ def _parse_to_ast(

new_e = SyntaxException(str(e), vyper_source, e.lineno, offset)

likely_errors = ("staticall", "staticcal")
tmp = str(new_e)
for s in likely_errors:
if s in tmp:
new_e._hint = "did you mean `staticcall`?"
break
new_e._hint = _maybe_staticcall_hint(vyper_source, e.lineno)

raise new_e from None

Expand Down