Issue
I'd like to type-hint that a pandas DataFrame must have a DatetimeIndex. I was hoping there might be some way to do this with protocols, but it looks like there isn't. Something in the spirit of this:
from typing import Protocol

import pandas as pd

class TSFrame(Protocol):
    index: pd.DatetimeIndex

def test(df: TSFrame) -> None:
    # Do stuff with df.index.methods_supported_by_dtidx_only
    pass

nontsdf = pd.DataFrame()
tsdf = pd.DataFrame(index=pd.DatetimeIndex(pd.date_range("2022-01-01", "2022-01-02")))
test(nontsdf)  # goal is for my type checker to complain here
test(tsdf)  # and not complain here
My type checker instead complains in both cases. Confusingly, if I write an analogous test on a plain class where the protocol attribute is hinted as int, it complains in neither case, not even for the string value:
from typing import Any, Protocol

class IntWanted(Protocol):
    var: int

class TestClass:
    def __init__(self, var: Any) -> None:
        self.var = var

def foo(a: IntWanted) -> int:
    return a.var

good = TestClass(1)
bad = TestClass("x")
foo(good)
foo(bad)
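Roughly why the two experiments differ: a type checker compares the *declared* type of each protocol member. `TestClass.var` gets its type from a parameter annotated `Any`, and `Any` is compatible with `int`, so neither `foo` call is flagged; that `"x"` is a string is only known at runtime. `pd.DataFrame.index`, on the other hand, is annotated as a general `pd.Index`, which is not a `DatetimeIndex`, so both `test` calls are flagged no matter what index the frame actually holds. Making the protocol `runtime_checkable` (from the standard `typing` module) doesn't rescue it either, because `isinstance` against a protocol only checks that the members exist, not their types. A minimal sketch:

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class IntWanted(Protocol):
    var: int


class TestClass:
    def __init__(self, var: Any) -> None:
        self.var = var


good = TestClass(1)
bad = TestClass("x")

# Both checks pass: runtime protocol checks look for a `var` attribute,
# they never compare its value against the `int` annotation.
print(isinstance(good, IntWanted))
print(isinstance(bad, IntWanted))
```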
Other ways I can think of to treat these timeseries dataframes:
- Subclass DataFrame and add validation that the index is a DatetimeIndex. Convert every df I have into an instance of this class and type-hint that class everywhere. Would that let mypy know that its index has the attributes of a DatetimeIndex, though? I think not. E.g.:
import pandas as pd

class TSFrame(pd.DataFrame):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(self.index, pd.DatetimeIndex)
- Make an entirely new object that holds the df as one attribute and defines an index property returning a DatetimeIndex built from the df's index, so mypy knows what type that index has. This feels heavy.
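import attr
import pandas as pd

@attr.s(auto_attribs=True)
class TSFrame:
    df: pd.DataFrame

    def __attrs_post_init__(self) -> None:
        # Check the wrapped frame's own index: checking self.index would always
        # pass, since pd.DatetimeIndex(...) either converts or raises.
        assert isinstance(self.df.index, pd.DatetimeIndex)

    @property
    def index(self) -> pd.DatetimeIndex:
        return pd.DatetimeIndex(self.df.index)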
Ideas appreciated.
Solution
You can use pandera (and pandas-stubs) to get both a static check with mypy and a runtime check.
- Install pandera with its mypy extra:
pip install pandera[mypy]
- Create a mypy.ini file with:
[mypy]
plugins = pandera.mypy
demo.py:
import pandera as pa
import pandas as pd
from pandera.typing import Index, DataFrame

class TSFrame(pa.DataFrameModel):
    idx: Index[pa.Timestamp] = pa.Field(check_name=False)

@pa.check_types  # validated at runtime
def test(df: DataFrame[TSFrame]):  # checked statically by mypy
    pass

nontsdf = pd.DataFrame()
tsdf = DataFrame[TSFrame](index=pd.DatetimeIndex(pd.date_range("2022-01-01", "2022-01-02")))
test(nontsdf)
test(tsdf)
Usage:
[...]$ mypy demo.py
demo.py:14: error: Argument 1 to "test" has incompatible type "pandas.core.frame.DataFrame"; expected "pandera.typing.pandas.DataFrame[TSFrame]" [arg-type]
Found 1 error in 1 file (checked 1 source file)
[...]$ python demo.py
...
pandera.errors.SchemaError: error in check_types decorator of function 'test': expected series 'None' to have type datetime64[ns], got int64
Answered By - Corralien