Typing an overloaded decorator wrapped in partial

I am trying to correctly type an overloaded decorator that gets wrapped with `functools.partial`:

from functools import partial
from typing import Any, Callable, Optional, Union, overload


AnyCallable = Callable[..., Any]  # any callable, whatever its signature


class Wrapped:
    """Minimal stand-in for the object produced by the decorator."""

    def __init__(self, func: AnyCallable, foo: str, bar: bool) -> None:
        """Receive the decorated callable and its options; nothing is kept."""


@overload
def create_wrapped(foo: str, func: AnyCallable) -> Wrapped:
    ...


@overload
def create_wrapped(foo: str, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
    ...


def create_wrapped(
    foo: str,
    func: Optional[AnyCallable] = None,
    *,
    bar: bool = True,
) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
    """Build ``Wrapped(func, foo, bar)``, or return a decorator that builds it.

    When *func* is supplied it is wrapped immediately; when it is omitted the
    caller gets back a one-argument decorator that performs the wrapping.
    """

    def make_wrapped(target: AnyCallable) -> Wrapped:
        # ``foo`` and ``bar`` are captured from the enclosing call.
        return Wrapped(target, foo, bar)

    if func is not None:
        return make_wrapped(func)
    return make_wrapped


# functools.partial pre-binds the first positional argument of create_wrapped
# (``foo``) to "baz".
baz = partial(create_wrapped, "baz")


# Bare-decorator form: at runtime this calls create_wrapped("baz", func_1).
@baz
def func_1() -> None:
    pass


# Decorator-factory form: at runtime this calls create_wrapped("baz", bar=False)
# and applies the returned wrapper to func_2. Per the question, this is the
# decorator mypy reports as '"Wrapped" not callable', i.e. mypy apparently
# types ``baz(bar=False)`` as returning Wrapped rather than a decorator.
@baz(bar=False)
def func_2() -> None:
    pass

The code runs correctly, but mypy reports

47: error: "Wrapped" not callable

which indicates that the argument types are lost when `partial` is applied. `@baz(bar=False)` should match the second overload, because it is equivalent to `@create_wrapped("baz", bar=False)`, which type-checks without any issue.

I'm not sure how else I could annotate this. In fact, I couldn't find any way to make mypy accept this code, even if I gave up on proper types for the decorator: in that case I get an `Untyped decorator makes function untyped` error instead.



Solution 1:[1]

mypy does not currently infer the type of a partially applied function correctly; see https://github.com/python/mypy/issues/1484.

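You can work around this by casting the result of the `partial` call to a `Protocol` that describes the partially applied call signatures. Note that `typing.cast` performs no check, so the Protocol has to be kept in sync with `create_wrapped` by hand.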

from functools import partial
from typing import Any, Callable, Optional, Protocol, Union, overload, cast


AnyCallable = Callable[..., Any]  # any callable, whatever its signature


class Wrapped:
    """Minimal stand-in for the object produced by the decorator."""

    def __init__(self, func: AnyCallable, foo: str, bar: bool) -> None:
        """Receive the decorated callable and its options; nothing is kept."""


# ``bar`` is accepted alongside ``func`` as well: the implementation honours it
# in both forms, so ``create_wrapped("x", f, bar=False)`` must type-check too.
@overload
def create_wrapped(foo: str, func: AnyCallable, *, bar: bool = ...) -> Wrapped:
    ...


@overload
def create_wrapped(foo: str, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
    ...


def create_wrapped(
    foo: str,
    func: Optional[AnyCallable] = None,
    *,
    bar: bool = True,
) -> Union[Wrapped, Callable[[AnyCallable], Wrapped]]:
    """Wrap *func* in a ``Wrapped``, or return a decorator that does so.

    Supports both decorator spellings:

    * ``create_wrapped(foo, func)`` wraps *func* immediately.
    * ``create_wrapped(foo, bar=...)`` returns a one-argument decorator.

    Args:
        foo: Passed through to ``Wrapped``.
        func: The callable to wrap; ``None`` selects the decorator-factory form.
        bar: Keyword-only option passed through to ``Wrapped``.

    Returns:
        A ``Wrapped`` when *func* is given, otherwise a callable that wraps
        its single argument.
    """

    def wrapper(func_: AnyCallable) -> Wrapped:
        # Closes over ``foo`` and ``bar`` from this call.
        return Wrapped(func_, foo, bar)

    if func is None:
        return wrapper

    return wrapper(func)


class partial_create_wrapped(Protocol):
    """Call signatures of ``partial(create_wrapped, foo)``.

    mypy does not infer these from ``functools.partial`` over an overloaded
    function, so they are written out by hand: they mirror the
    ``create_wrapped`` overloads with the bound ``foo`` argument removed.
    ``typing.cast`` does not verify them, so keep them in sync manually.
    """

    @overload
    def __call__(self, func: AnyCallable, *, bar: bool = ...) -> Wrapped:
        ...

    @overload
    def __call__(self, *, bar: bool = ...) -> Callable[[AnyCallable], Wrapped]:
        ...


# typing.cast performs no runtime or static check: it only tells mypy to treat
# the partial object as partial_create_wrapped. If create_wrapped's signature
# changes, that Protocol must be updated by hand.
baz = cast(partial_create_wrapped, partial(create_wrapped, "baz"))


# Matches the Protocol overload that takes ``func`` and returns Wrapped.
@baz
def func_1() -> None:
    pass


# Matches the keyword-only ``bar`` overload, so ``baz(bar=False)`` is typed as
# Callable[[AnyCallable], Wrapped] and can be applied as a decorator.
@baz(bar=False)
def func_2() -> None:
    pass

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Paweł Rubin