Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]!: make @external modifier optional in .vyi files #4178

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tests/functional/codegen/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,3 +695,34 @@ def test_call(a: address, b: {type_str}) -> {type_str}:
make_file("jsonabi.json", json.dumps(convert_v1_abi(abi)))
c3 = get_contract(code, input_bundle=input_bundle)
assert c3.test_call(c1.address, value) == value


def test_interface_function_without_visibility(make_input_bundle, get_contract):
interface_code = """
def foo() -> uint256:
...

@external
def bar() -> uint256:
...
"""

code = """
import a as FooInterface

implements: FooInterface

@external
def foo() -> uint256:
return 1

@external
def bar() -> uint256:
return 1
"""

input_bundle = make_input_bundle({"a.vyi": interface_code})

c = get_contract(code, input_bundle=input_bundle)

assert c.foo() == c.bar() == 1
78 changes: 78 additions & 0 deletions tests/functional/syntax/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,3 +484,81 @@ def baz():
"""

assert compiler.compile_code(code, input_bundle=input_bundle) is not None


invalid_visibility_code = [
"""
import foo as Foo
implements: Foo
@external
def foobar():
pass
""",
"""
import foo as Foo
implements: Foo
@internal
def foobar():
pass
""",
"""
import foo as Foo
implements: Foo
def foobar():
pass
""",
]


@pytest.mark.parametrize("code", invalid_visibility_code)
def test_internal_visibility_in_interface(make_input_bundle, code):
interface_code = """
@internal
def foobar():
...
"""

input_bundle = make_input_bundle({"foo.vyi": interface_code})

with pytest.raises(FunctionDeclarationException) as e:
compiler.compile_code(code, input_bundle=input_bundle)

assert e.value._message == "Interface functions can only be marked as `@external`"


external_visibility_interface = [
"""
@external
def foobar():
...
def bar():
...
""",
"""
def foobar():
...
@external
def bar():
...
""",
]


@pytest.mark.parametrize("iface", external_visibility_interface)
def test_internal_implemenatation_of_external_interface(make_input_bundle, iface):
input_bundle = make_input_bundle({"foo.vyi": iface})

code = """
import foo as Foo
implements: Foo
@internal
def foobar():
pass
def bar():
pass
"""

with pytest.raises(InterfaceViolation) as e:
compiler.compile_code(code, input_bundle=input_bundle)

assert e.value.message == "Contract does not implement all interface functions: bar(), foobar()"
8 changes: 4 additions & 4 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __init__(
self._imported_modules: dict[PurePath, vy_ast.VyperNode] = {}

# keep track of exported functions to prevent duplicate exports
self._exposed_functions: dict[ContractFunctionT, vy_ast.VyperNode] = {}
self._all_functions: dict[ContractFunctionT, vy_ast.VyperNode] = {}

self._events: list[EventT] = []

Expand Down Expand Up @@ -414,7 +414,7 @@ def visit_ImplementsDecl(self, node):
raise StructureException(msg, node.annotation, hint=hint)

# grab exposed functions
funcs = self._exposed_functions
funcs = {fn_t: node for fn_t, node in self._all_functions.items() if fn_t.is_external}
type_.validate_implements(node, funcs)

node._metadata["interface_type"] = type_
Expand Down Expand Up @@ -608,10 +608,10 @@ def _self_t(self):
def _add_exposed_function(self, func_t, node, relax=True):
# call this before self._self_t.typ.add_member() for exception raising
# priority
if not relax and (prev_decl := self._exposed_functions.get(func_t)) is not None:
if not relax and (prev_decl := self._all_functions.get(func_t)) is not None:
raise StructureException("already exported!", node, prev_decl=prev_decl)

self._exposed_functions[func_t] = node
self._all_functions[func_t] = node

def visit_VariableDecl(self, node):
# postcondition of VariableDecl.validate
Expand Down
38 changes: 29 additions & 9 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,23 @@ def from_vyi(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef)

if nonreentrant:
raise FunctionDeclarationException("`@nonreentrant` not allowed in interfaces", funcdef)
# TODO: refactor so parse_decorators returns the AST location
decorator = next(d for d in funcdef.decorator_list if d.id == "nonreentrant")
cyberthirst marked this conversation as resolved.
Show resolved Hide resolved
raise FunctionDeclarationException(
"`@nonreentrant` not allowed in interfaces", decorator
)

# it's redundant to specify visibility in vyi - always should be external
if function_visibility is None:
function_visibility = FunctionVisibility.EXTERNAL

if function_visibility != FunctionVisibility.EXTERNAL:
nonexternal = next(
d for d in funcdef.decorator_list if d.id in FunctionVisibility.values()
)
raise FunctionDeclarationException(
"Interface functions can only be marked as `@external`", nonexternal
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the error message can be confusing since it could imply that payable is disallowed (which is not). @cyberthirst how about:

Interface functions' visibility can only be marked as `@external`

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, we've improved the reporting a bit, so the message refers to the corresponding source construct

@internal
@payable
def bar():
    ...

would yield:

vyper.exceptions.FunctionDeclarationException: Interface functions can only be marked as `@external`

  contract "tests/custom/i.vyi:1", function "bar", line 1:1 
  ---> 1 @internal
  --------^
       2 @payable

would you say that it's still confusing?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally think error messages should be as precise as possible and thus I still think we should mention somehow that it's a visibility decorator.

)

if funcdef.name == "__init__":
raise FunctionDeclarationException("Constructors cannot appear in interfaces", funcdef)
Expand Down Expand Up @@ -381,6 +397,10 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
"""
function_visibility, state_mutability, nonreentrant = _parse_decorators(funcdef)

# it's redundant to specify internal visibility - it's implied by not being external
if function_visibility is None:
function_visibility = FunctionVisibility.INTERNAL

positional_args, keyword_args = _parse_args(funcdef)

return_type = _parse_return_type(funcdef)
Expand Down Expand Up @@ -419,6 +439,10 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
raise FunctionDeclarationException(
"Constructor may not use default arguments", funcdef.args.defaults[0]
)
if nonreentrant:
decorator = next(d for d in funcdef.decorator_list if d.id == "nonreentrant")
msg = "`@nonreentrant` decorator disallowed on `__init__`"
raise FunctionDeclarationException(msg, decorator)

return cls(
funcdef.name,
Expand Down Expand Up @@ -495,6 +519,8 @@ def implements(self, other: "ContractFunctionT") -> bool:
if not self.is_external: # pragma: nocover
raise CompilerPanic("unreachable!")

assert self.visibility == other.visibility

arguments, return_type = self._iface_sig
other_arguments, other_return_type = other._iface_sig

Expand Down Expand Up @@ -700,7 +726,7 @@ def _parse_return_type(funcdef: vy_ast.FunctionDef) -> Optional[VyperType]:

def _parse_decorators(
funcdef: vy_ast.FunctionDef,
) -> tuple[FunctionVisibility, StateMutability, bool]:
) -> tuple[Optional[FunctionVisibility], StateMutability, bool]:
function_visibility = None
state_mutability = None
nonreentrant_node = None
Expand All @@ -719,10 +745,6 @@ def _parse_decorators(
if nonreentrant_node is not None:
raise StructureException("nonreentrant decorator is already set", nonreentrant_node)

if funcdef.name == "__init__":
msg = "`@nonreentrant` decorator disallowed on `__init__`"
raise FunctionDeclarationException(msg, decorator)

nonreentrant_node = decorator

elif isinstance(decorator, vy_ast.Name):
Expand All @@ -733,6 +755,7 @@ def _parse_decorators(
decorator,
hint="only one visibility decorator is allowed per function",
)

function_visibility = FunctionVisibility(decorator.id)

elif StateMutability.is_valid_value(decorator.id):
Expand All @@ -755,9 +778,6 @@ def _parse_decorators(
else:
raise StructureException("Bad decorator syntax", decorator)

if function_visibility is None:
function_visibility = FunctionVisibility.INTERNAL

if state_mutability is None:
# default to nonpayable
state_mutability = StateMutability.NONPAYABLE
Expand Down
3 changes: 3 additions & 0 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def _ctor_modifiability_for_call(self, node: vy_ast.Call, modifiability: Modifia
def validate_implements(
self, node: vy_ast.ImplementsDecl, functions: dict[ContractFunctionT, vy_ast.VyperNode]
) -> None:
# only external functions can implement interfaces
fns_by_name = {fn_t.name: fn_t for fn_t in functions.keys()}

unimplemented = []
Expand All @@ -116,7 +117,9 @@ def _is_function_implemented(fn_name, fn_type):
return False

to_compare = fns_by_name[fn_name]
assert to_compare.is_external
assert isinstance(to_compare, ContractFunctionT)
assert isinstance(fn_type, ContractFunctionT)

return to_compare.implements(fn_type)

Expand Down
Loading