Issue
I have the following simple function:
import sklearn.metrics

def f1(y_true, y_pred):
    return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}
According to the scikit-learn documentation, the arguments to f1_score can have the following types:
- y_true: 1d array-like, or label indicator array / sparse matrix
- y_pred: 1d array-like, or label indicator array / sparse matrix
and the return value is of type:
- float or array of float, shape = [n_unique_labels]
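The union in that return type is what makes the annotation awkward. A minimal numpy-only sketch follows. The helper f1_sketch and its per_label flag are hypothetical, standing in for f1_score with average="binary" versus average=None; this is not scikit-learn's implementation. It shows how one function can return either a single float or an array with one value per unique label:

```python
from typing import List, Union

import numpy as np
import numpy.typing as npt


def f1_sketch(
    y_true: npt.ArrayLike, y_pred: npt.ArrayLike, per_label: bool = False
) -> Union[float, npt.NDArray[np.float64]]:
    """Hypothetical stand-in for f1_score, NOT scikit-learn's implementation.

    Returns a single float (F1 of label 1) by default, or one F1 value
    per unique label, shape (n_unique_labels,), when per_label is True.
    """
    t = np.asarray(y_true)
    p = np.asarray(y_pred)
    labels = np.unique(np.concatenate([t, p]))
    scores: List[float] = []
    for label in labels:
        tp = np.sum((t == label) & (p == label))  # true positives
        fp = np.sum((t != label) & (p == label))  # false positives
        fn = np.sum((t == label) & (p != label))  # false negatives
        denom = 2 * tp + fp + fn
        scores.append(float(2 * tp / denom) if denom else 0.0)
    per_label_scores = np.array(scores, dtype=np.float64)
    if per_label:
        return per_label_scores
    # Assumes label 1 is present, mirroring the binary default pos_label=1
    return float(per_label_scores[labels == 1][0])


print(f1_sketch([1, 0, 1, 0], [1, 0, 1, 1]))                  # 0.8
print(f1_sketch([1, 0, 1, 0], [1, 0, 1, 1], per_label=True))  # one value per label
```

A caller that needs a plain float has to narrow such a result first, for example with isinstance, before mypy accepts float-only operations on it.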
How do I add type hints to this function so that mypy doesn't complain?
I tried variations of the following:
Array1D = NewType('Array1D', Union[np.ndarray, List[np.float64]])

def f1(y_true: Union[List[float], Array1D], y_pred: Union[List[float], Array1D]) -> Dict[str, Union[List[float], Array1D]]:
    return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}
but that gave errors.
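The post does not show the errors, but one likely source is the NewType line: mypy requires the second argument to NewType to be a subclassable type, and a Union is not one, so NewType('Array1D', Union[...]) is rejected. A plain type alias is accepted instead. The sketch below assumes a hypothetical helper, accuracy_percent, in place of f1_score so that it runs with numpy alone:

```python
from typing import Dict, List, Union

import numpy as np

# A plain type alias rather than NewType: mypy accepts a Union here,
# whereas NewType needs a subclassable type as its second argument.
Array1D = Union[np.ndarray, List[float]]


def accuracy_percent(y_true: Array1D, y_pred: Array1D) -> Dict[str, float]:
    """Hypothetical helper used in place of f1_score for illustration."""
    t = np.asarray(y_true)
    p = np.asarray(y_pred)
    # Fraction of positions where prediction matches truth, as a percentage
    return {"accuracy": 100 * float(np.mean(t == p))}


print(accuracy_percent([1.0, 0.0, 1.0, 0.0], [1.0, 0.0, 1.0, 1.0]))  # {'accuracy': 75.0}
print(accuracy_percent(np.array([1, 0, 1, 0]), np.array([1, 0, 1, 1])))  # {'accuracy': 75.0}
```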
Solution
This is the approach I use to avoid similar mypy issues. It takes advantage of the numpy.typing module introduced in NumPy 1.20. The ArrayLike type covers List[float], so there is no need to cover it explicitly.
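A small sketch of what that coverage means in practice: a single parameter annotated npt.ArrayLike accepts a list, a tuple, a scalar, or an existing ndarray, and np.asarray normalizes all of them. The helper as_float_array is hypothetical, and the npt.NDArray return annotation assumes NumPy 1.21 or later:

```python
import numpy as np
import numpy.typing as npt


def as_float_array(x: npt.ArrayLike) -> npt.NDArray[np.float64]:
    """Hypothetical helper: accepts anything numpy.typing.ArrayLike covers."""
    # np.asarray converts lists, tuples, scalars and ndarrays to an ndarray
    return np.asarray(x, dtype=np.float64)


print(as_float_array([1.0, 0.0, 1.0]))      # list of floats
print(as_float_array((1, 0, 1)))            # tuple of ints
print(as_float_array(np.array([1, 0, 1])))  # existing ndarray
```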
Running mypy v0.971 with numpy v1.23.1 on this shows no issues.
from typing import List, Dict

import numpy as np
import numpy.typing as npt
import sklearn.metrics


def f1(y_true: npt.ArrayLike, y_pred: npt.ArrayLike) -> Dict[str, npt.ArrayLike]:
    return {"f1": 100 * sklearn.metrics.f1_score(y_true, y_pred)}


# Plain Python lists are accepted as ArrayLike...
y_true_list: List[float] = [1, 0, 1, 0]
y_pred_list: List[float] = [1, 0, 1, 1]
# ...and so are numpy arrays
y_true_np: npt.ArrayLike = np.array(y_true_list)
y_pred_np: npt.ArrayLike = np.array(y_pred_list)

# Both input types give the same result
assert f1(y_true_list, y_pred_list) == f1(y_true_np, y_pred_np)
Answered By - Bryan Dannowitz