我正在尝试理解如何正确地使用panda和numpy的类型注解。我有一个DataFrame,其索引为dtype np.uint64
。我想编写一个函数,以如下方式返回此DataFrame的子集:
import numpy as np
import numpy.typing as npt
import pandas as pd
df = pd.DataFrame(dict(x=[1, 2, 3]), index=np.array([10, 20, 30], dtype="uint64"))
assert df.index.dtype == np.uint64
def get_NDArray(df: pd.DataFrame, key: npt.NDArray[np.uint64]):
df2 = df.loc[key]
reveal_type(df2)
return df2
对于注解key: npt.NDArray[np.uint64]
,它不起作用。在 pyright 中,df2
的推断类型是Series[Unkown]
,这是不正确的(应该是DataFrame
)。
error: Invalid index type "ndarray[Any, dtype[unsignedinteger[_64Bit]]]" for "_LocIndexerFrame"; expected type
"Union[slice, ndarray[Any, dtype[signedinteger[_64Bit]]], Index, List[int],
Series[int], Series[bool], ndarray[Any, dtype[bool_]], List[bool],
List[<nothing>], Tuple[Union[slice, ndarray[Any, dtype[signedinteger[_64Bit]]],
Index, List[int], Series[int], Series[bool], ndarray[Any, dtype[bool_]],
List[bool], List[<nothing>], Hashable], Union[List[<nothing>], slice,
Series[bool], Callable[..., Any]]]]"
我可以将key
的注解更改为key: np.ndarray
或key: npt.NDArray
,然后一切都正常工作,但我希望确保key
不是任意的np.ndarray
,而是np.ndarray
与dtype == 'np.uint64'
。我期望npt.NDArray[np.uint64]
正是应该允许这样做的工具,但它不起作用。我做错了什么吗?
1条答案
按热度按时间kokeuurv1#
查看
_LocIndexerFrame
的pandas-stubs
源代码(loc
属性返回的类型),可以看到__getitem__
方法接受(以及其他不相关的选项)IndexType
,这是np_ndarray_int64
并集的类型别名(以及其他选项),而np_ndarray_int64
恰好是numpy.typing.NDArray[np.int64]
的另一个类型别名。这对应于
mypy
的输出,其中提到了巨大的类型联合中的ndarray[Any, dtype[signedinteger[_64Bit]]]
类型。公平地说,因为这些包是如此的臃肿,正确的类型注解会导致到处都是巨大的畸形类型联合。所以挖掘它们或类型检查器输出它们并不那么容易。但那是另一回事了...
无论如何,您要使用的
NDArray[np.uint64]
不是NDArray[np.int64]
的子类型,因为np.uint64
不是np.int64
的子类型(并且ScalarType
是协变的)。DataFrame.loc
对象应使用有符号整数数组作为下标。我对panda了解不够,但我相信负整数作为
loc
的索引不是问题,所以这并不奇怪,现在,无论这是否意味着注解不完整(哦,天哪...),或者传递无符号整数类型实际上会导致一些意外的行为,我都不知道。这只是对您遇到的错误的解释。将
key
的注解更改为npt.NDArray[np.int64]
应该可以修复该错误。