Typing an overloaded decorator wrapped in `functools.partial`
I am trying to correctly type an overloaded decorator that gets wrapped with `functools.partial`:
from functools import partial
from typing import Any, Callable, Optional, Union, overload
# Alias for "any callable, any signature, any return type".
AnyCallable = Callable[..., Any]


class Wrapped:
    """Holds a function together with the ``foo``/``bar`` options it was wrapped with.

    Args:
        func: The wrapped callable.
        foo: String option supplied when the function was wrapped.
        bar: Boolean option supplied when the function was wrapped.
    """

    def __init__(self, func: AnyCallable, foo: str, bar: bool) -> None:
        # Keep the arguments so the wrapper is inspectable instead of
        # silently discarding them.
        self.func = func
        self.foo = foo
        self.bar = bar
@overload
def create_wrapped(foo: str, func: AnyCallable) -> Wrapped:
    ...


@overload
def create_wrapped(foo: str, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
    ...


def create_wrapped(
    foo: str,
    func: Optional[AnyCallable] = None,
    *,
    bar: bool = True,
) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
    """Wrap *func* in a ``Wrapped``, or return a decorator that does so.

    Args:
        foo: Forwarded to ``Wrapped``.
        func: The function to wrap. When omitted, a decorator is returned
            instead, so the function can be used as ``@create_wrapped(foo, bar=...)``.
        bar: Keyword-only flag forwarded to ``Wrapped``.

    Returns:
        A ``Wrapped`` instance when *func* is given, otherwise a one-argument
        decorator that produces a ``Wrapped``.
    """

    def wrapper(func_: AnyCallable) -> Wrapped:
        return Wrapped(func_, foo, bar)

    # No function yet: the caller is configuring the decorator, so hand
    # back the decorator itself.
    if func is None:
        return wrapper
    return wrapper(func)
baz = partial(create_wrapped, "baz")


@baz
def func_1() -> None:
    pass


# Without a workaround, mypy reports '"Wrapped" not callable' for this
# decorator: the partial object loses create_wrapped's overloads.
@baz(bar=False)
def func_2() -> None:
    pass
The code runs correctly, but mypy reports the following error for the `@baz(bar=False)` line (reported as line 47 in my file):
47: error: "Wrapped" not callable
This indicates that the actual argument types are lost when `partial` is applied. `@baz(bar=False)` should match the second overload, because it is equivalent to `@create_wrapped("baz", bar=False)`, which type-checks without any issue.
I'm not sure how else I could annotate this. In fact, I couldn't come up with any way to make mypy stop complaining. Even if I were fine with not having proper types for the decorator, I would then get an `Untyped decorator makes function untyped` error.
Solution 1:[1]
mypy does not currently infer the correct type of a partially applied function; see https://github.com/python/mypy/issues/1484.
You can work around this by using `typing.cast` to cast the result of the `partial(...)` call to a `Protocol` whose overloaded `__call__` signatures describe the partially applied function.
from functools import partial
from typing import Any, Callable, Optional, Protocol, Union, overload, cast
# Alias for "any callable, any signature, any return type".
AnyCallable = Callable[..., Any]


class Wrapped:
    """Holds a function together with the ``foo``/``bar`` options it was wrapped with.

    Args:
        func: The wrapped callable.
        foo: String option supplied when the function was wrapped.
        bar: Boolean option supplied when the function was wrapped.
    """

    def __init__(self, func: AnyCallable, foo: str, bar: bool) -> None:
        # Keep the arguments so the wrapper is inspectable instead of
        # silently discarding them.
        self.func = func
        self.foo = foo
        self.bar = bar
@overload
def create_wrapped(foo: str, func: AnyCallable) -> Wrapped:
    ...


@overload
def create_wrapped(foo: str, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
    ...


def create_wrapped(
    foo: str,
    func: Optional[AnyCallable] = None,
    *,
    bar: bool = True,
) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
    """Wrap *func* in a ``Wrapped``, or return a decorator that does so.

    Args:
        foo: Forwarded to ``Wrapped``.
        func: The function to wrap. When omitted, a decorator is returned
            instead, so the function can be used as ``@create_wrapped(foo, bar=...)``.
        bar: Keyword-only flag forwarded to ``Wrapped``.

    Returns:
        A ``Wrapped`` instance when *func* is given, otherwise a one-argument
        decorator that produces a ``Wrapped``.
    """

    def wrapper(func_: AnyCallable) -> Wrapped:
        return Wrapped(func_, foo, bar)

    # No function yet: the caller is configuring the decorator, so hand
    # back the decorator itself.
    if func is None:
        return wrapper
    return wrapper(func)
class partial_create_wrapped(Protocol):
    """Call signatures of ``partial(create_wrapped, <foo>)``.

    mypy cannot infer the type of a partially applied function
    (python/mypy#1484). This Protocol therefore restates the overloads of
    ``create_wrapped`` with the leading ``foo`` argument already bound.
    """

    # baz(bar=...) -> decorator
    @overload
    def __call__(self, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
        ...

    # baz(func) -> Wrapped (plain @baz usage)
    @overload
    def __call__(self, func: AnyCallable) -> Wrapped:
        ...
# mypy cannot infer the type of a partially applied function
# (python/mypy#1484), so cast the partial to a Protocol that restates the
# overloads. cast() has no runtime effect.
baz = cast(partial_create_wrapped, partial(create_wrapped, "baz"))


@baz
def func_1() -> None:
    pass


@baz(bar=False)
def func_2() -> None:
    pass
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Paweł Rubin |
