diff --git a/pandas-stubs/io/parsers/readers.pyi b/pandas-stubs/io/parsers/readers.pyi index 691b56fea..3dbabb652 100644 --- a/pandas-stubs/io/parsers/readers.pyi +++ b/pandas-stubs/io/parsers/readers.pyi @@ -27,6 +27,7 @@ from pandas._typing import ( DtypeArg, DtypeBackend, FilePath, + HashableT, ListLikeHashable, ReadCsvBuffer, StorageOptions, @@ -44,7 +45,7 @@ def read_csv( header: int | Sequence[int] | Literal["infer"] | None = ..., names: ListLikeHashable | None = ..., index_col: int | str | Sequence[str | int] | Literal[False] | None = ..., - usecols: UsecolsArgType = ..., + usecols: UsecolsArgType[HashableT] = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., converters: ( @@ -108,7 +109,7 @@ def read_csv( header: int | Sequence[int] | Literal["infer"] | None = ..., names: ListLikeHashable | None = ..., index_col: int | str | Sequence[str | int] | Literal[False] | None = ..., - usecols: UsecolsArgType = ..., + usecols: UsecolsArgType[HashableT] = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., converters: ( @@ -172,7 +173,7 @@ def read_csv( header: int | Sequence[int] | Literal["infer"] | None = ..., names: ListLikeHashable | None = ..., index_col: int | str | Sequence[str | int] | Literal[False] | None = ..., - usecols: UsecolsArgType = ..., + usecols: UsecolsArgType[HashableT] = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., converters: ( @@ -236,7 +237,7 @@ def read_table( header: int | Sequence[int] | Literal["infer"] | None = ..., names: ListLikeHashable | None = ..., index_col: int | str | Sequence[str | int] | Literal[False] | None = ..., - usecols: UsecolsArgType = ..., + usecols: UsecolsArgType[HashableT] = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., converters: ( @@ -300,7 +301,7 @@ def read_table( header: int | Sequence[int] | Literal["infer"] | None = ..., names: ListLikeHashable | None = ..., index_col: int | str | Sequence[str | int] | Literal[False] | None = ..., - usecols: UsecolsArgType = ..., + usecols: UsecolsArgType[HashableT] = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., converters: ( @@ -364,7 +365,7 @@ def read_table( header: int | Sequence[int] | Literal["infer"] | None = ..., names: ListLikeHashable | None = ..., index_col: int | str | Sequence[str | int] | Literal[False] | None = ..., - usecols: UsecolsArgType = ..., + usecols: UsecolsArgType[HashableT] = ..., dtype: DtypeArg | defaultdict | None = ..., engine: CSVEngine | None = ..., converters: ( diff --git a/tests/test_io.py b/tests/test_io.py index 32df677ba..6bfdd5734 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -575,6 +575,11 @@ def test_read_csv(): DataFrame, ) + def cols(x: str) -> bool: + return x in ["a", "b"] + + pd.read_csv(path, usecols=cols) + def test_read_csv_iterator(): with ensure_clean() as path: @@ -727,6 +732,11 @@ def test_types_read_csv() -> None: pd.read_csv(path, names="abcd") # type: ignore[call-overload] # pyright: ignore[reportArgumentType] pd.read_csv(path, usecols="abcd") # type: ignore[call-overload] # pyright: ignore[reportArgumentType] + def cols2(x: set[float]) -> bool: + return sum(x) < 1.0 + + pd.read_csv("file.csv", usecols=cols2) # type: ignore[type-var] # pyright: ignore[reportArgumentType] + tfr1 = pd.read_csv(path, nrows=2, iterator=True, chunksize=3) check(assert_type(tfr1, TextFileReader), TextFileReader) tfr1.close()