diff --git a/.gitignore b/.gitignore index 3f321a6..bb96855 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Project specific folders +stubs/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index f32f3c4..a4d733f 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -119,19 +119,20 @@ class ScopeType(enum.StrEnum): # docstub: on +# TODO use `libcst.metadata.ScopeProvider` instead @dataclass(slots=True, frozen=True) class _Scope: """""" type: ScopeType - node: cst.CSTNode = None + node: cst.CSTNode | None = None @property - def has_self_or_cls(self): + def has_self_or_cls(self) -> bool: return self.type in {ScopeType.METHOD, ScopeType.CLASSMETHOD} @property - def is_method(self): + def is_method(self) -> bool: return self.type in { ScopeType.METHOD, ScopeType.CLASSMETHOD, @@ -139,10 +140,21 @@ def is_method(self): } @property - def is_class_init(self): + def is_class_init(self) -> bool: out = self.is_method and self.node.name.value == "__init__" return out + @property + def is_dataclass(self) -> bool: + if cstm.matches(self.node, cstm.ClassDef()): + # Determine if dataclass + decorators = cstm.findall(self.node, cstm.Decorator()) + is_dataclass = any( + cstm.findall(d, cstm.Name("dataclass")) for d in decorators + ) + return is_dataclass + return False + def _get_docstring_node(node): """Extract the node with the docstring from a definition. @@ -672,16 +684,27 @@ def leave_AnnAssign(self, original_node, updated_node): updated_node : cst.AnnAssign """ name = updated_node.target.value - is_type_alias = cstm.matches( - updated_node.annotation, cstm.Annotation(cstm.Name("TypeAlias")) - ) - is__all__ = cstm.matches(updated_node.target, cstm.Name("__all__")) - # Remove value if not type alias or __all__ - if updated_node.value is not None and not is_type_alias and not is__all__: - updated_node = updated_node.with_changes( - value=None, equal=cst.MaybeSentinel.DEFAULT + if updated_node.value is not None: + is_type_alias = cstm.matches( + updated_node.annotation, cstm.Annotation(cstm.Name("TypeAlias")) ) + is__all__ = cstm.matches(updated_node.target, cstm.Name("__all__")) + is_dataclass = self._scope_stack[-1].is_dataclass + is_classvar = any( + cstm.findall(updated_node.annotation, cstm.Name("ClassVar")) + ) + + # Replace with ellipses if dataclass + if is_dataclass and not is_classvar: + updated_node = updated_node.with_changes( + value=cst.Ellipsis(), equal=cst.MaybeSentinel.DEFAULT + ) + # Remove value if not type alias or __all__ + elif not is_type_alias and not is__all__: + updated_node = updated_node.with_changes( + value=None, equal=cst.MaybeSentinel.DEFAULT + ) # Replace with type annotation from docstring, if available pytypes = self._pytypes_stack[-1] diff --git a/stubtest_allow.txt b/stubtest_allow.txt index de954fc..016a47f 100644 --- a/stubtest_allow.txt +++ b/stubtest_allow.txt @@ -2,6 +2,3 @@ docstub\._version\..* docstub\..*\.__match_args__$ docstub._cache.FuncSerializer.__type_params__ docstub._cli.main -docstub._config.Config.__init__ -docstub._docstrings.Annotation.__init__ -docstub._stubs._Scope.__init__ diff --git a/tests/test_stubs.py b/tests/test_stubs.py index cdc86f2..454f519 100644 --- a/tests/test_stubs.py +++ b/tests/test_stubs.py @@ -394,3 +394,39 @@ class Foo: # remove these empty lines from the result too result = dedent(result) assert expected == result + + @pytest.mark.parametrize("decorator", ["dataclass", "dataclasses.dataclass"]) + def test_dataclass(self, decorator): + source = dedent( + f""" + @{decorator} + class Foo: + a: float + b: int = 3 + c: str = None + _: KW_ONLY + d: dict[str, Any] = field(default_factory=dict) + e: InitVar[tuple] = tuple() + f: ClassVar + g: ClassVar[float] + h: Final[ClassVar[int]] = 1 + """ + ) + expected = dedent( + f""" + @{decorator} + class Foo: + a: float + b: int = ... + c: str = ... + _: KW_ONLY + d: dict[str, Any] = ... + e: InitVar[tuple] = ... + f: ClassVar + g: ClassVar[float] + h: Final[ClassVar[int]] + """ + ) + transformer = Py2StubTransformer() + result = transformer.python_to_stub(source) + assert expected == result