Skip to content

Add direct support for dataclasses #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Project specific folders
stubs/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
47 changes: 35 additions & 12 deletions src/docstub/_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,30 +119,42 @@ class ScopeType(enum.StrEnum):
# docstub: on


# TODO use `libcst.metadata.ScopeProvider` instead
@dataclass(slots=True, frozen=True)
class _Scope:
""""""

type: ScopeType
node: cst.CSTNode = None
node: cst.CSTNode | None = None

@property
def has_self_or_cls(self):
def has_self_or_cls(self) -> bool:
return self.type in {ScopeType.METHOD, ScopeType.CLASSMETHOD}

@property
def is_method(self):
def is_method(self) -> bool:
return self.type in {
ScopeType.METHOD,
ScopeType.CLASSMETHOD,
ScopeType.STATICMETHOD,
}

@property
def is_class_init(self):
def is_class_init(self) -> bool:
out = self.is_method and self.node.name.value == "__init__"
return out

@property
def is_dataclass(self) -> bool:
if cstm.matches(self.node, cstm.ClassDef()):
# Determine if dataclass
decorators = cstm.findall(self.node, cstm.Decorator())
is_dataclass = any(
cstm.findall(d, cstm.Name("dataclass")) for d in decorators
)
return is_dataclass
return False


def _get_docstring_node(node):
"""Extract the node with the docstring from a definition.
Expand Down Expand Up @@ -672,16 +684,27 @@ def leave_AnnAssign(self, original_node, updated_node):
updated_node : cst.AnnAssign
"""
name = updated_node.target.value
is_type_alias = cstm.matches(
updated_node.annotation, cstm.Annotation(cstm.Name("TypeAlias"))
)
is__all__ = cstm.matches(updated_node.target, cstm.Name("__all__"))

# Remove value if not type alias or __all__
if updated_node.value is not None and not is_type_alias and not is__all__:
updated_node = updated_node.with_changes(
value=None, equal=cst.MaybeSentinel.DEFAULT
if updated_node.value is not None:
is_type_alias = cstm.matches(
updated_node.annotation, cstm.Annotation(cstm.Name("TypeAlias"))
)
is__all__ = cstm.matches(updated_node.target, cstm.Name("__all__"))
is_dataclass = self._scope_stack[-1].is_dataclass
is_classvar = any(
cstm.findall(updated_node.annotation, cstm.Name("ClassVar"))
)

# Replace with ellipses if dataclass
if is_dataclass and not is_classvar:
updated_node = updated_node.with_changes(
value=cst.Ellipsis(), equal=cst.MaybeSentinel.DEFAULT
)
# Remove value if not type alias or __all__
elif not is_type_alias and not is__all__:
updated_node = updated_node.with_changes(
value=None, equal=cst.MaybeSentinel.DEFAULT
)

# Replace with type annotation from docstring, if available
pytypes = self._pytypes_stack[-1]
Expand Down
3 changes: 0 additions & 3 deletions stubtest_allow.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,3 @@ docstub\._version\..*
docstub\..*\.__match_args__$
docstub._cache.FuncSerializer.__type_params__
docstub._cli.main
docstub._config.Config.__init__
docstub._docstrings.Annotation.__init__
docstub._stubs._Scope.__init__
36 changes: 36 additions & 0 deletions tests/test_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,3 +394,39 @@ class Foo:
# remove these empty lines from the result too
result = dedent(result)
assert expected == result

@pytest.mark.parametrize("decorator", ["dataclass", "dataclasses.dataclass"])
def test_dataclass(self, decorator):
source = dedent(
f"""
@{decorator}
class Foo:
a: float
b: int = 3
c: str = None
_: KW_ONLY
d: dict[str, Any] = field(default_factory=dict)
e: InitVar[tuple] = tuple()
f: ClassVar
g: ClassVar[float]
h: Final[ClassVar[int]] = 1
"""
)
expected = dedent(
f"""
@{decorator}
class Foo:
a: float
b: int = ...
c: str = ...
_: KW_ONLY
d: dict[str, Any] = ...
e: InitVar[tuple] = ...
f: ClassVar
g: ClassVar[float]
h: Final[ClassVar[int]]
"""
)
transformer = Py2StubTransformer()
result = transformer.python_to_stub(source)
assert expected == result