Skip to content

Fix annotations of str methods that accept regular expressions #1278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions pandas-stubs/core/strings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,17 @@ class StringMethods(
) -> _T_STR: ...
@overload
def split(
self, pat: str = ..., *, n: int = ..., expand: Literal[True], regex: bool = ...
self,
pat: str | re.Pattern[str] = ...,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*,
n: int = ...,
expand: Literal[True],
regex: bool = ...,
) -> _T_EXPANDING: ...
@overload
def split(
self,
pat: str = ...,
pat: str | re.Pattern[str] = ...,
*,
n: int = ...,
expand: Literal[False] = ...,
Expand Down Expand Up @@ -133,11 +138,15 @@ class StringMethods(
regex: bool = ...,
) -> _T_BOOL: ...
def match(
self, pat: str, case: bool = ..., flags: int = ..., na: Any = ...
self,
pat: str | re.Pattern[str],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only one of the changed methods that is not documented as accepting a compiled regex (see https://pandas.pydata.org/docs/reference/api/pandas.Series.str.match.html), but that omission is an oversight in the documentation. See pandas-dev/pandas#61879.

case: bool = ...,
flags: int = ...,
na: Any = ...,
) -> _T_BOOL: ...
def replace(
self,
pat: str,
pat: str | re.Pattern[str],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

repl: str | Callable[[re.Match[str]], str],
n: int = ...,
case: bool | None = ...,
Expand Down Expand Up @@ -180,18 +189,26 @@ class StringMethods(
def count(self, pat: str, flags: int = ...) -> _T_INT: ...
def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ...
def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ...
def findall(self, pat: str, flags: int = ...) -> _T_LIST_STR: ...
def findall(self, pat: str | re.Pattern[str], flags: int = ...) -> _T_LIST_STR: ...
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@overload
def extract(
self, pat: str, flags: int = ..., *, expand: Literal[True] = ...
self,
pat: str | re.Pattern[str],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

flags: int = ...,
*,
expand: Literal[True] = ...,
) -> pd.DataFrame: ...
@overload
def extract(self, pat: str, flags: int, expand: Literal[False]) -> _T_OBJECT: ...
def extract(
self, pat: str | re.Pattern[str], flags: int, expand: Literal[False]
) -> _T_OBJECT: ...
@overload
def extract(
self, pat: str, flags: int = ..., *, expand: Literal[False]
self, pat: str | re.Pattern[str], flags: int = ..., *, expand: Literal[False]
) -> _T_OBJECT: ...
def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ...
def extractall(
self, pat: str | re.Pattern[str], flags: int = ...
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

) -> pd.DataFrame: ...
def find(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> _T_STR: ...
Expand All @@ -214,7 +231,11 @@ class StringMethods(
def isnumeric(self) -> _T_BOOL: ...
def isdecimal(self) -> _T_BOOL: ...
def fullmatch(
self, pat: str, case: bool = ..., flags: int = ..., na: Any = ...
self,
pat: str | re.Pattern[str],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

case: bool = ...,
flags: int = ...,
na: Any = ...,
) -> _T_BOOL: ...
def removeprefix(self, prefix: str) -> _T_STR: ...
def removesuffix(self, suffix: str) -> _T_STR: ...
47 changes: 47 additions & 0 deletions tests/test_string_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pandas as pd
import pytest
from typing_extensions import assert_type

from tests import (
Expand Down Expand Up @@ -44,6 +45,7 @@ def test_string_accessors_boolean_series():
_check(assert_type(s.str.endswith("e"), "pd.Series[bool]"))
_check(assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]"))
_check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"))
_check(assert_type(s.str.fullmatch(re.compile(r"apple")), "pd.Series[bool]"))
_check(assert_type(s.str.isalnum(), "pd.Series[bool]"))
_check(assert_type(s.str.isalpha(), "pd.Series[bool]"))
_check(assert_type(s.str.isdecimal(), "pd.Series[bool]"))
Expand All @@ -54,6 +56,7 @@ def test_string_accessors_boolean_series():
_check(assert_type(s.str.istitle(), "pd.Series[bool]"))
_check(assert_type(s.str.isupper(), "pd.Series[bool]"))
_check(assert_type(s.str.match("pp"), "pd.Series[bool]"))
_check(assert_type(s.str.match(re.compile(r"pp")), "pd.Series[bool]"))


def test_string_accessors_boolean_index():
Expand All @@ -72,6 +75,7 @@ def test_string_accessors_boolean_index():
_check(assert_type(idx.str.endswith("e"), np_ndarray_bool))
_check(assert_type(idx.str.endswith(("e", "f")), np_ndarray_bool))
_check(assert_type(idx.str.fullmatch("apple"), np_ndarray_bool))
_check(assert_type(idx.str.fullmatch(re.compile(r"apple")), np_ndarray_bool))
_check(assert_type(idx.str.isalnum(), np_ndarray_bool))
_check(assert_type(idx.str.isalpha(), np_ndarray_bool))
_check(assert_type(idx.str.isdecimal(), np_ndarray_bool))
Expand All @@ -82,6 +86,7 @@ def test_string_accessors_boolean_index():
_check(assert_type(idx.str.istitle(), np_ndarray_bool))
_check(assert_type(idx.str.isupper(), np_ndarray_bool))
_check(assert_type(idx.str.match("pp"), np_ndarray_bool))
_check(assert_type(idx.str.match(re.compile(r"pp")), np_ndarray_bool))


def test_string_accessors_integer_series():
Expand All @@ -94,6 +99,10 @@ def test_string_accessors_integer_series():
_check(assert_type(s.str.count("pp"), "pd.Series[int]"))
_check(assert_type(s.str.len(), "pd.Series[int]"))

# unlike findall, find doesn't accept a compiled pattern
with pytest.raises(TypeError):
s.str.find(re.compile(r"p")) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]


def test_string_accessors_integer_index():
idx = pd.Index(DATA)
Expand All @@ -105,6 +114,10 @@ def test_string_accessors_integer_index():
_check(assert_type(idx.str.count("pp"), "pd.Index[int]"))
_check(assert_type(idx.str.len(), "pd.Index[int]"))

# unlike findall, find doesn't accept a compiled pattern
with pytest.raises(TypeError):
idx.str.find(re.compile(r"p")) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]


def test_string_accessors_string_series():
s = pd.Series(DATA)
Expand All @@ -123,6 +136,9 @@ def test_string_accessors_string_series():
_check(assert_type(s.str.removesuffix("e"), "pd.Series[str]"))
_check(assert_type(s.str.repeat(2), "pd.Series[str]"))
_check(assert_type(s.str.replace("a", "X"), "pd.Series[str]"))
_check(
assert_type(s.str.replace(re.compile(r"a"), "X", regex=True), "pd.Series[str]")
)
_check(assert_type(s.str.rjust(80), "pd.Series[str]"))
_check(assert_type(s.str.rstrip(), "pd.Series[str]"))
_check(assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]"))
Expand Down Expand Up @@ -158,6 +174,9 @@ def test_string_accessors_string_index():
_check(assert_type(idx.str.removesuffix("e"), "pd.Index[str]"))
_check(assert_type(idx.str.repeat(2), "pd.Index[str]"))
_check(assert_type(idx.str.replace("a", "X"), "pd.Index[str]"))
_check(
assert_type(idx.str.replace(re.compile(r"a"), "X", regex=True), "pd.Index[str]")
)
_check(assert_type(idx.str.rjust(80), "pd.Index[str]"))
_check(assert_type(idx.str.rstrip(), "pd.Index[str]"))
_check(assert_type(idx.str.slice_replace(0, 2, "XX"), "pd.Index[str]"))
Expand Down Expand Up @@ -190,29 +209,49 @@ def test_string_accessors_list_series():
s = pd.Series(DATA)
_check = functools.partial(check, klass=pd.Series, dtype=list)
_check(assert_type(s.str.findall("pp"), "pd.Series[list[str]]"))
_check(assert_type(s.str.findall(re.compile(r"pp")), "pd.Series[list[str]]"))
_check(assert_type(s.str.split("a"), "pd.Series[list[str]]"))
_check(assert_type(s.str.split(re.compile(r"a")), "pd.Series[list[str]]"))
# GH 194
_check(assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]"))
_check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]"))
_check(assert_type(s.str.rsplit("a", expand=False), "pd.Series[list[str]]"))

# rsplit doesn't accept a compiled pattern:
# it doesn't raise at runtime, but every result is NaN
bad_rsplit_result = s.str.rsplit(
re.compile(r"a") # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
)
assert bad_rsplit_result.isna().all()


def test_string_accessors_list_index():
idx = pd.Index(DATA)
_check = functools.partial(check, klass=pd.Index, dtype=list)
_check(assert_type(idx.str.findall("pp"), "pd.Index[list[str]]"))
_check(assert_type(idx.str.findall(re.compile(r"pp")), "pd.Index[list[str]]"))
_check(assert_type(idx.str.split("a"), "pd.Index[list[str]]"))
_check(assert_type(idx.str.split(re.compile(r"a")), "pd.Index[list[str]]"))
# GH 194
_check(assert_type(idx.str.split("a", expand=False), "pd.Index[list[str]]"))
_check(assert_type(idx.str.rsplit("a"), "pd.Index[list[str]]"))
_check(assert_type(idx.str.rsplit("a", expand=False), "pd.Index[list[str]]"))

# rsplit doesn't accept a compiled pattern:
# it doesn't raise at runtime, but every result is NaN
bad_rsplit_result = idx.str.rsplit(
re.compile(r"a") # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
)
assert bad_rsplit_result.isna().all()


def test_string_accessors_expanding_series():
s = pd.Series(["a1", "b2", "c3"])
_check = functools.partial(check, klass=pd.DataFrame)
_check(assert_type(s.str.extract(r"([ab])?(\d)"), pd.DataFrame))
_check(assert_type(s.str.extract(re.compile(r"([ab])?(\d)")), pd.DataFrame))
_check(assert_type(s.str.extractall(r"([ab])?(\d)"), pd.DataFrame))
_check(assert_type(s.str.extractall(re.compile(r"([ab])?(\d)")), pd.DataFrame))
_check(assert_type(s.str.get_dummies(), pd.DataFrame))
_check(assert_type(s.str.partition("p"), pd.DataFrame))
_check(assert_type(s.str.rpartition("p"), pd.DataFrame))
Expand All @@ -231,7 +270,15 @@ def test_string_accessors_expanding_index():

# These ones are the odd ones out?
check(assert_type(idx.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame)
check(
assert_type(idx.str.extractall(re.compile(r"([ab])?(\d)")), pd.DataFrame),
pd.DataFrame,
)
check(assert_type(idx.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame)
check(
assert_type(idx.str.extract(re.compile(r"([ab])?(\d)")), pd.DataFrame),
pd.DataFrame,
)


def test_series_overloads_partition():
Expand Down
Loading