How can I overload the get_data method below to return the correct type based on the init value of data_type, instead of returning a union of both types?
from typing import Literal

DATA_TYPE = Literal["wood", "concrete"]

class WoodData: ...
class ConcreteData: ...

class Foo:
    def __init__(self, data_type: DATA_TYPE) -> None:
        self.data_type = data_type

    def get_data(self) -> WoodData | ConcreteData:
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()
I was thinking this could be done by making Foo generic, but I'm unsure of the implementation details.
I'd prefer not to pass WoodData/ConcreteData directly as a generic parameter, because I have many methods whose return types depend on whether the init var is wood or concrete.
To illustrate that last point, I know I could add a generic that takes one of the two return types like so:
from typing import Literal

DATA_TYPE = Literal["wood", "concrete"]

class WoodData: ...
class ConcreteData: ...

class Foo[MY_RETURN_TYPE: WoodData | ConcreteData]:
    def __init__(self, data_type: DATA_TYPE) -> None:
        self.data_type = data_type

    def get_data(self) -> MY_RETURN_TYPE:
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()
But imagine I have tons of methods conditionally returning different types based on the value of data_type. I don't want to specify each of these as a generic; I'd rather overload the methods on the class and have the return types inferred accurately.
Lastly, I know I could split this into two separate subclasses, but it would be nice to keep them as one class if possible.
OK, for this solution you annotate self in each overload with the specialized generic type you want. mypy and pyright give similar outputs for reveal_type, and it works both for the base class and for a subclass that passes its type parameter through to Foo:
from typing import Literal, overload, reveal_type

class WoodData: ...
class ConcreteData: ...

class Foo[T: (Literal['wood'], Literal['concrete'])]:
    data_type: T

    def __init__(self, data_type: T) -> None:
        self.data_type = data_type

    @overload
    def get_data(self: "Foo[Literal['wood']]") -> WoodData: ...
    @overload
    def get_data(self: "Foo[Literal['concrete']]") -> ConcreteData: ...
    @overload
    def get_data(self) -> WoodData | ConcreteData: ...
    def get_data(self):
        if self.data_type == "wood":
            return WoodData()
        return ConcreteData()

    @overload
    def bar(self: "Foo[Literal['wood']]") -> int: ...
    @overload
    def bar(self: "Foo[Literal['concrete']]") -> str: ...
    @overload
    def bar(self) -> int | str: ...
    def bar(self):
        if self.data_type == "wood":
            return 42
        return "42"

reveal_type(Foo('wood').get_data())      # Revealed type is "__main__.WoodData"
reveal_type(Foo('concrete').get_data())  # Revealed type is "__main__.ConcreteData"
reveal_type(Foo('wood').bar())           # Revealed type is "builtins.int"
reveal_type(Foo('concrete').bar())       # Revealed type is "builtins.str"

class Bar[T: (Literal['wood'], Literal['concrete'])](Foo[T]):
    pass

# works with inheritance too
reveal_type(Bar('wood').get_data())      # Revealed type is "__main__.WoodData"
reveal_type(Bar('concrete').get_data())  # Revealed type is "__main__.ConcreteData"
reveal_type(Bar('wood').bar())           # Revealed type is "builtins.int"
reveal_type(Bar('concrete').bar())       # Revealed type is "builtins.str"
However, mypy won't type check the body of the implementation (it is unannotated, and mypy skips unannotated function bodies unless --check-untyped-defs is enabled), and pyright seems to report spurious errors for the body.
Either = WoodData | ConcreteData? – JonSG Commented Jan 30 at 20:55
from __future__ import annotations is better than stringified annotations) – STerliakov Commented Jan 30 at 21:57
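The suggestion in the last comment can be sketched as follows: with from __future__ import annotations, annotations are stored as strings and not evaluated at definition time, so the self annotations can refer to Foo unquoted even though the class is not fully defined yet. This is a minimal sketch that reuses the bar method from the answer; as an assumption for portability it again uses TypeVar/Generic rather than the 3.12 syntax.

```python
from __future__ import annotations  # annotations are kept as strings, not evaluated

from typing import Generic, Literal, TypeVar, overload

T = TypeVar("T", Literal["wood"], Literal["concrete"])


class Foo(Generic[T]):
    def __init__(self, data_type: T) -> None:
        self.data_type = data_type

    # No quotes needed around Foo[...]: the name is only resolved
    # later by a type checker (or typing.get_type_hints), not at def time.
    @overload
    def bar(self: Foo[Literal["wood"]]) -> int: ...
    @overload
    def bar(self: Foo[Literal["concrete"]]) -> str: ...

    def bar(self) -> int | str:
        if self.data_type == "wood":
            return 42
        return "42"


print(Foo("wood").bar(), Foo("concrete").bar())  # prints: 42 42
```

Type checkers treat this the same as the quoted form; the difference is only that every annotation in the module is deferred, so none of them need individual quoting.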