python - Overload a method based on init variables - Stack Overflow


How can I overload the get_data method below so that it returns the correct type based on the value of data_type passed to __init__, instead of returning a union of both types?

from typing import Literal

DATA_TYPE = Literal["wood", "concrete"]

class WoodData: ...
class ConcreteData: ...

class Foo:
    def __init__(self, data_type: DATA_TYPE) -> None:
        self.data_type = data_type

    def get_data(self) -> WoodData | ConcreteData:
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()

I was thinking this could be done by making Foo generic, but I'm unsure of the implementation details.

I'd prefer not to pass WoodData/ConcreteData directly as a type parameter, because I have many methods whose return types depend on whether the init var is wood or concrete.

To illustrate that last point, I know I could add a generic that takes one of the two return types like so:

from typing import Literal

DATA_TYPE = Literal["wood", "concrete"]

class WoodData: ...
class ConcreteData: ...

class Foo[MY_RETURN_TYPE: WoodData | ConcreteData]:
    def __init__(self, data_type: DATA_TYPE) -> None:
        self.data_type = data_type

    def get_data(self) -> MY_RETURN_TYPE:
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()

But imagine I have many methods that return different types depending on the value of data_type. I don't want to add a type parameter for each of them; I'd rather overload the methods on the class and have the return types inferred accurately.

Lastly, I know I could split this into two separate subclasses, but I'd prefer to keep a single class if possible.

Asked Jan 30 at 20:39 by Matt, edited Jan 30 at 21:08
  • Can you elaborate on your last paragraph? I'm not sure I follow. – juanpa.arrivillaga, Jan 30 at 20:53
  • @juanpa.arrivillaga I'll make an edit. – Matt, Jan 30 at 20:54
  • Are you looking to do Either = WoodData | ConcreteData? – JonSG, Jan 30 at 20:55
  • @Matt that sounds like a way to go, why not like this? What doesn't work? – STerliakov, Jan 30 at 21:34
  • Ah, that's fine then (but usually from __future__ import annotations is better than stringified annotations). – STerliakov, Jan 30 at 21:57
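Following the last comment's suggestion, here is a minimal sketch of the self-type overload pattern written with from __future__ import annotations, so the self annotations need no quotes. It uses the pre-3.12 TypeVar/Generic spelling of a constrained type parameter (equivalent to the class Foo[T: (...)] syntax) so it also runs on older Python versions; the class and method names are taken from the question.

```python
from __future__ import annotations  # annotations are stored as strings, not evaluated

from typing import Generic, Literal, TypeVar, overload


class WoodData: ...


class ConcreteData: ...


# Constrained type variable: T is solved to exactly one of the two literal types.
T = TypeVar("T", Literal["wood"], Literal["concrete"])


class Foo(Generic[T]):
    data_type: T

    def __init__(self, data_type: T) -> None:
        self.data_type = data_type

    # No quotes needed around Foo[...] thanks to the __future__ import.
    @overload
    def get_data(self: Foo[Literal["wood"]]) -> WoodData: ...
    @overload
    def get_data(self: Foo[Literal["concrete"]]) -> ConcreteData: ...
    @overload
    def get_data(self) -> WoodData | ConcreteData: ...
    def get_data(self):
        # Runtime dispatch; the overloads above only guide the type checker.
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()
```

At runtime only the last get_data definition exists; the @overload stubs matter only to type checkers such as mypy and pyright.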

1 Answer


OK, for this solution, make Foo generic over the literal type of data_type and, in each overload, annotate self with the parameterization that should select it. Both mypy and pyright give similar reveal_type output, for the base class as well as for a subclass:

from typing import Literal, overload, reveal_type  # reveal_type is importable at runtime on Python 3.11+


class WoodData: ...
class ConcreteData: ...

# T is constrained: it is solved to exactly Literal['wood'] or Literal['concrete']
class Foo[T: (Literal['wood'], Literal['concrete'])]:
    data_type: T
    def __init__(self, data_type: T) -> None:
        self.data_type = data_type
    # Each overload is selected by how self is parameterized (Foo[Literal[...]])
    @overload
    def get_data(self: "Foo[Literal['wood']]") -> WoodData:
        ...
    @overload
    def get_data(self: "Foo[Literal['concrete']]") -> ConcreteData:
        ...
    @overload
    def get_data(self) -> WoodData | ConcreteData:
        ...
    # Runtime implementation; the overloads above only guide the type checker
    def get_data(self):
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()
    @overload
    def bar(self: "Foo[Literal['wood']]") -> int:
        ...
    @overload
    def bar(self: "Foo[Literal['concrete']]") -> str:
        ...
    @overload
    def bar(self) -> int | str:
        ...
    def bar(self):
        if self.data_type == "wood":
            return 42
        return "42"

reveal_type(Foo('wood').get_data())  # mypy: Revealed type is "__main__.WoodData"
reveal_type(Foo('concrete').get_data())  # mypy: Revealed type is "__main__.ConcreteData"
reveal_type(Foo('wood').bar())  # mypy: Revealed type is "builtins.int"
reveal_type(Foo('concrete').bar())  # mypy: Revealed type is "builtins.str"

class Bar[T:(Literal['wood'], Literal['concrete'])](Foo[T]):
    pass
# works with inheritance too
reveal_type(Bar('wood').get_data())  # mypy: Revealed type is "__main__.WoodData"
reveal_type(Bar('concrete').get_data())  # mypy: Revealed type is "__main__.ConcreteData"
reveal_type(Bar('wood').bar())  # mypy: Revealed type is "builtins.int"
reveal_type(Bar('concrete').bar())  # mypy: Revealed type is "builtins.str"

However, mypy won't type-check the body of the implementation, and pyright seems to report spurious errors in the body...
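For the mypy part, a likely cause is that the implementation def get_data(self): is unannotated: by default mypy skips the bodies of unannotated functions unless --check-untyped-defs is enabled. A minimal sketch of the workaround, assuming that is the cause, is to give the implementation an explicit return annotation. It reuses the bar method from the answer, written with the TypeVar/Generic spelling so it runs before Python 3.12; it does not address the pyright errors.

```python
from __future__ import annotations

from typing import Generic, Literal, TypeVar, overload

# Constrained type variable, equivalent to class Foo[T: (Literal['wood'], Literal['concrete'])]
T = TypeVar("T", Literal["wood"], Literal["concrete"])


class Foo(Generic[T]):
    def __init__(self, data_type: T) -> None:
        self.data_type = data_type

    @overload
    def bar(self: Foo[Literal["wood"]]) -> int: ...
    @overload
    def bar(self: Foo[Literal["concrete"]]) -> str: ...
    @overload
    def bar(self) -> int | str: ...
    # Annotated implementation: because it has a return annotation,
    # mypy type-checks this body without --check-untyped-defs.
    def bar(self) -> int | str:
        if self.data_type == "wood":
            return 42
        return "42"
```

The implementation's return type is the union, which covers every overload's return type, so the overloads still give callers the narrow int or str.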
