Issue
I have this function in pandas:
from typing import Union

import pandas as pd


def rolling_pd(
    dataf: pd.DataFrame,
    groupby_cols: Union[str, list],
    column: str,
    function: str = "mean",
    rolling_periods: int = 1,
    shift_periods: int = 1,
    *args,
    **kwargs,
) -> pd.Series:
    # Within each group: shift first, then aggregate over a rolling window.
    return dataf.groupby(groupby_cols)[column].transform(
        lambda d: (
            d.shift(shift_periods)
            .rolling(rolling_periods, min_periods=1)
            .agg(function, *args, **kwargs)
        )
    )
I want to do the same thing with polars, but I haven't managed to since I don't see a rolling method. Can you help me with this translation?
Solution
Let us consider the following data.
import polars as pl

df = pl.DataFrame({
    "group": ["a"] * 4 + ["b"] * 4,
    "value": [1, 2, 3, 4, 5, 6, 7, 8],
})
>>> df
shape: (8, 2)
┌───────┬───────┐
│ group ┆ value │
│ --- ┆ --- │
│ str ┆ i64 │
╞═══════╪═══════╡
│ a ┆ 1 │
│ a ┆ 2 │
│ a ┆ 3 │
│ a ┆ 4 │
│ b ┆ 5 │
│ b ┆ 6 │
│ b ┆ 7 │
│ b ┆ 8 │
└───────┴───────┘
For a fixed aggregation function, one could use the corresponding pl.Expr.rolling_* function. As an example (with rolling_periods = 2 and shift_periods = 1), consider pl.Expr.rolling_mean:
(
    df
    .with_columns(
        # .over("group") makes both the shift and the rolling mean run per group.
        pl.col("value").shift(1).rolling_mean(window_size=2, min_periods=1).over("group")
    )
)
Output.
shape: (8, 2)
┌───────┬───────┐
│ group ┆ value │
│ --- ┆ --- │
│ str ┆ f64 │
╞═══════╪═══════╡
│ a ┆ null │
│ a ┆ 1.0 │
│ a ┆ 1.5 │
│ a ┆ 2.5 │
│ b ┆ null │
│ b ┆ 5.0 │
│ b ┆ 5.5 │
│ b ┆ 6.5 │
└───────┴───────┘
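To make the numbers in this output easier to check by hand, here is a plain-Python sketch of what "shift, then rolling mean with min_periods=1" computes for one group. The helper name shifted_rolling_mean is made up for this illustration and is not part of pandas or polars.

```python
from statistics import mean


def shifted_rolling_mean(values, window=2, shift=1):
    """Shift `values` by `shift` rows, then take a rolling mean over `window`
    rows that needs only one non-null value (like min_periods=1)."""
    # Shifting moves every value down by `shift` rows and pads the top with None.
    shifted = ([None] * shift + list(values))[: len(values)]
    result = []
    for i in range(len(shifted)):
        # The window holds the current row plus the `window - 1` rows before it.
        seen = [v for v in shifted[max(0, i - window + 1): i + 1] if v is not None]
        result.append(mean(seen) if seen else None)
    return result


# Same data as the example frame, split by group.
groups = {"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}
reference = {name: shifted_rolling_mean(vals) for name, vals in groups.items()}
print(reference)
```

The first row of each group has nothing before it after the shift, which is why it stays null, and the second row averages a single value because one non-null entry is enough.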
However, this gets more complex when we want to apply an arbitrary aggregation function based on the corresponding name as in the pandas example provided in the question.
In this case, we can use pl.DataFrame.rolling. Unfortunately, rolling doesn't infer an index column, so we have to create it manually. Moreover, we can rely on Python's built-in getattr to obtain the aggregation function in the pl.Expr namespace by name.
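from typing import Union


def rolling_pl(
    data: pl.DataFrame,
    groupby_cols: Union[str, list],
    column: str,
    function: str = "mean",
    rolling_periods: int = 1,
    shift_periods: int = 1,
    *args,
    **kwargs,
) -> pl.DataFrame:
    # Returns a frame with the group columns and the aggregated column.
    # Note: newer polars releases rename pl.count() to pl.len() and the
    # `by` argument of rolling to `group_by`.
    return (
        data
        # Integer row index for rolling to use as its index column.
        .with_columns(pl.int_range(0, pl.count()).cast(pl.Int64).alias("index"))
        .rolling(
            index_column="index",
            period=f"{rolling_periods}i",
            closed="both",
            by=groupby_cols,
        )
        .agg(
            # Look the aggregation up by name on the shifted expression.
            getattr(pl.col(column).shift(shift_periods), function)(*args, **kwargs)
        )
        .drop("index")
    )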
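The combination of closed="both" and a period of rolling_periods rows is easy to misread, so here is a plain-Python model of it. It assumes the aggregation expression is evaluated separately inside each window, that the window for row t covers rows t - rolling_periods through t, and that the mean skips nulls. The function windowed_shift_mean is hypothetical and only mimics that behaviour for a single group.

```python
def windowed_shift_mean(values, rolling_periods=2, shift_periods=1):
    """Model of: window [t - rolling_periods, t], shift inside the window,
    then mean of the non-null values."""
    result = []
    for t in range(len(values)):
        # closed="both" keeps both ends, so up to rolling_periods + 1 rows.
        window = values[max(0, t - rolling_periods): t + 1]
        # The shift happens inside the window: pad the top, drop the tail.
        shifted = ([None] * shift_periods + list(window))[: len(window)]
        seen = [v for v in shifted if v is not None]
        result.append(sum(seen) / len(seen) if seen else None)
    return result


group_a = windowed_shift_mean([1, 2, 3, 4])
group_b = windowed_shift_mean([5, 6, 7, 8])
print(group_a)
print(group_b)
```

In this model the extra row admitted by closed="both" is exactly the row that a shift of 1 removes, so the values agree with the rolling_mean example for shift_periods = 1. With a larger shift the window keeps fewer than rolling_periods non-null values, so the period would presumably need to grow accordingly.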
Answered By - Hericks