From 93bcd38791de466da6e8d453535f90b1369d7641 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Tue, 22 Jul 2025 22:05:09 +0200 Subject: [PATCH 1/5] Collect imports as scoped types "from imports" are collected as types within the scope of the containing module. E.g. "from pathlib import Path" will allow using "Path" inside the modules scope. "import" without "from" are collected as scoped type prefixes. Meaning if a module contains "import scipy as sp", "sp." can be used as a valid type prefix in that module. --- docs/user_guide.md | 8 +- examples/example_pkg-stubs/__init__.pyi | 3 + examples/example_pkg-stubs/_basic.pyi | 12 +- examples/example_pkg/__init__.py | 4 + examples/example_pkg/_basic.py | 17 +++ pyproject.toml | 10 +- src/docstub/_analysis.py | 150 +++++++++++++++++++----- src/docstub/_cli.py | 34 +++--- src/docstub/_docstrings.py | 8 +- src/docstub/_stubs.py | 2 +- src/docstub/_utils.py | 2 +- tests/test_analysis.py | 86 ++++++++++++-- 12 files changed, 262 insertions(+), 74 deletions(-) diff --git a/docs/user_guide.md b/docs/user_guide.md index ac3a5da..1f46e37 100644 --- a/docs/user_guide.md +++ b/docs/user_guide.md @@ -111,8 +111,9 @@ To translate a type from a docstring into a valid type annotation, docstub needs Out of the box, docstub will know about builtin types such as `int` or `bool` that don't need an import, and types in `typing`, `collections.abc` from Python's standard library. It will source these from the Python environment it is installed in. In addition to that, docstub will collect all types in the package directory you are running it on. +This also includes imported types, which you can then use within the scope of the module that imports them. -However, if you want to use types from third-party libraries you can tell docstub about them in a configuration file. +However, you can also tell docstub directly about external types in a configuration file. 
Docstub will look for a `pyproject.toml` or `docstub.toml` in the current working directory. Or, you can point docstub at TOML file(s) explicitly using the `--config` option. In these configuration file(s) you can declare external types directly with @@ -134,8 +135,9 @@ ski = "skimage" which will enable any type that is prefixed with `ski.` or `sklearn.tree.`, e.g. `ski.transform.AffineTransform` or `sklearn.tree.DecisionTreeClassifier`. -In both of these cases, docstub doesn't check that these types actually exist. -Testing the generated stubs with a type checker is recommended. +> [!IMPORTANT] +> Docstub doesn't check that types actually exist or if a symbol is a valid type. +> We always recommend validating the generated stubs with a full type checker! > [!TIP] > Docstub currently collects types statically. diff --git a/examples/example_pkg-stubs/__init__.pyi b/examples/example_pkg-stubs/__init__.pyi index 5a5bef1..c39e273 100644 --- a/examples/example_pkg-stubs/__init__.pyi +++ b/examples/example_pkg-stubs/__init__.pyi @@ -10,3 +10,6 @@ __all__ = [ class CustomException(Exception): pass + +class AnotherType: + pass diff --git a/examples/example_pkg-stubs/_basic.pyi b/examples/example_pkg-stubs/_basic.pyi index 7cdc627..65cf3c0 100644 --- a/examples/example_pkg-stubs/_basic.pyi +++ b/examples/example_pkg-stubs/_basic.pyi @@ -1,15 +1,16 @@ # File generated with docstub -import configparser import logging from collections.abc import Sequence +from configparser import ConfigParser +from configparser import ConfigParser as Cfg from typing import Any, Literal, Self, Union from _typeshed import Incomplete -from . import CustomException +from . import AnotherType, CustomException -logger: Incomplete +logger: logging.Logger __all__ = [ "func_empty", @@ -39,6 +40,7 @@ def func_use_from_elsewhere( a3: ExampleClass.NestedClass, a4: ExampleClass.NestedClass, ) -> tuple[CustomException, ExampleClass.NestedClass]: ... 
+def func_use_from_import(a1: AnotherType, a2: Cfg) -> None: ... class ExampleClass: @@ -58,6 +60,6 @@ class ExampleClass: @some_property.setter def some_property(self, value: str) -> None: ... @classmethod - def method_returning_cls(cls, config: configparser.ConfigParser) -> Self: ... + def method_returning_cls(cls, config: ConfigParser) -> Self: ... @classmethod - def method_returning_cls2(cls, config: configparser.ConfigParser) -> Self: ... + def method_returning_cls2(cls, config: ConfigParser) -> Self: ... diff --git a/examples/example_pkg/__init__.py b/examples/example_pkg/__init__.py index ac61e3d..f32d938 100644 --- a/examples/example_pkg/__init__.py +++ b/examples/example_pkg/__init__.py @@ -11,3 +11,7 @@ class CustomException(Exception): pass + + +class AnotherType: + pass diff --git a/examples/example_pkg/_basic.py b/examples/example_pkg/_basic.py index 4f12dd0..58d4842 100644 --- a/examples/example_pkg/_basic.py +++ b/examples/example_pkg/_basic.py @@ -1,12 +1,19 @@ """Basic docstring examples. Docstrings, including module-level ones, are stripped. + +Attributes +---------- +logger : logging.Logger """ # Existing imports are preserved import logging +from configparser import ConfigParser as Cfg # noqa: F401 from typing import Literal +from . import AnotherType # noqa: F401 + # Assign-statements are preserved logger = logging.getLogger(__name__) # Inline comments are stripped @@ -88,6 +95,16 @@ def func_use_from_elsewhere(a1, a2, a3, a4): """ +def func_use_from_import(a1, a2): + """Check using symbols made available in this module with from imports. + + Parameters + ---------- + a1 : AnotherType + a2 : Cfg + """ + + class ExampleClass: """Dummy. 
diff --git a/pyproject.toml b/pyproject.toml index a0824dc..834e9f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,14 +119,8 @@ run.source = ["docstub"] ".*maintenance.*" = "Maintenance" -[tool.docstub.types] -Path = "pathlib" - -[tool.docstub.type_prefixes] -re = "re" -cst = "libcst" -lark = "lark" -numpydoc = "numpydoc" +[tool.docstub.type_nicknames] +Path = "pathlib.Path" [tool.mypy] diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index a5ff9f3..146c5cf 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -197,10 +197,14 @@ def __post_init__(self): def __repr__(self) -> str: if self.builtin_name: - info = f"{self.target} (builtin)" + kwargs = f"builtin_name={self.builtin_name!r}" else: - info = f"{self.format_import()!r}" - out = f"<{type(self).__name__} {info}>" + kwargs = ( + f"import_path={self.import_path!r}, " + f"import_name={self.import_name!r}, " + f"import_alias={self.import_alias!r}" + ) + out = f"{type(self).__name__}({kwargs})" return out def __str__(self) -> str: @@ -291,14 +295,30 @@ def common_known_types(): return known_imports +@dataclass(slots=True, kw_only=True) +class TypeCollectionResult: + types: dict[str, KnownImport] + type_prefixes: dict[str, KnownImport] + + @classmethod + def serialize(cls, result): + pass + + @classmethod + def deserialize(cls, result): + pass + + class TypeCollector(cst.CSTVisitor): """Collect types from a given Python file. 
Examples -------- - >>> types = TypeCollector.collect(__file__) + >>> types, prefixes = TypeCollector.collect(__file__) >>> types[f"{__name__}.TypeCollector"] + >>> prefixes["logging"] + """ class ImportSerializer: @@ -312,17 +332,45 @@ def hash_args(self, path: Path) -> str: key = pyfile_checksum(path) return key - def serialize(self, data: dict[str, KnownImport]) -> bytes: - """Serialize results from `TypeCollector.collect`.""" - primitives = {qualname: asdict(imp) for qualname, imp in data.items()} - raw = json.dumps(primitives, separators=(",", ":")).encode(self.encoding) + def serialize(self, data): + """Serialize results from `TypeCollector.collect`. + + Parameters + ---------- + data : tuple[dict[str, KnownImport], dict[str, KnownImport]] + + Returns + ------- + raw : bytes + """ + primitives = {} + for name, table in zip(["types", "type_prefixes"], data, strict=False): + primitives[name] = {key: asdict(imp) for key, imp in table.items()} + raw = json.dumps(primitives, separators=(",", ":"), indent=1).encode( + self.encoding + ) return raw - def deserialize(self, raw: bytes) -> dict[str, KnownImport]: - """Deserialize results from `TypeCollector.collect`.""" + def deserialize(self, raw): + """Deserialize results from `TypeCollector.collect`. 
+ + Parameters + ---------- + raw : bytes + + Returns + ------- + types : dict[str, KnownImport] + type_prefixes : dict[str, KnownImport] + """ primitives = json.loads(raw.decode(self.encoding)) - data = {qualname: KnownImport(**kw) for qualname, kw in primitives.items()} - return data + + def deserialize_table(table): + return {key: KnownImport(**kw) for key, kw in table.items()} + + types = deserialize_table(primitives["types"]) + type_prefixes = deserialize_table(primitives["type_prefixes"]) + return types, type_prefixes @classmethod def collect(cls, file): @@ -334,7 +382,8 @@ def collect(cls, file): Returns ------- - collected : dict[str, KnownImport] + types : dict[str, KnownImport] + type_prefixes : dict[str, KnownImport] """ file = Path(file) with file.open("r") as fo: @@ -343,7 +392,7 @@ def collect(cls, file): tree = cst.parse_module(source) collector = cls(module_name=module_name_from_path(file)) tree.visit(collector) - return collector.known_imports + return collector.types, collector.type_prefixes def __init__(self, *, module_name): """Initialize type collector. 
@@ -354,7 +403,8 @@ def __init__(self, *, module_name): """ self.module_name = module_name self._stack = [] - self.known_imports = {} + self.types = {} + self.type_prefixes = {} def visit_ClassDef(self, node: cst.ClassDef) -> bool: self._stack.append(node.name.value) @@ -388,6 +438,44 @@ def visit_AnnAssign(self, node: cst.AnnAssign) -> bool: self._collect_type_annotation(stack) return False + def visit_ImportFrom(self, node: cst.ImportFrom) -> bool: + """Collect "from import" targets as usable types.""" + if cstm.matches(node.names, cstm.ImportStar()): + return False + + if node.module: + from_names = cstm.findall(node.module, cstm.Name()) + from_names = [n.value for n in from_names] + else: + from_names = [] + + for import_alias in node.names: + asname = import_alias.evaluated_alias + name = import_alias.evaluated_name + + if not node.relative: + key = ".".join([*from_names, name]) + known_import = KnownImport( + import_path=".".join(from_names), import_name=name + ) + self.types[key] = known_import + + scoped_import = KnownImport(builtin_name=asname or name) + self.types[f"{self.module_name}:{asname or name}"] = scoped_import + + return False + + def visit_Import(self, node: cst.Import) -> bool: + for import_alias in node.names: + asname = import_alias.evaluated_alias + name = import_alias.evaluated_name + target = asname or name + + known_import = KnownImport(builtin_name=asname or name) + self.type_prefixes[f"{self.module_name}:{target}"] = known_import + + return False + def _collect_type_annotation(self, stack): """Collect an importable type annotation. 
@@ -398,7 +486,7 @@ def _collect_type_annotation(self, stack): """ qualname = ".".join([self.module_name, *stack]) known_import = KnownImport(import_path=self.module_name, import_name=stack[0]) - self.known_imports[qualname] = known_import + self.types[qualname] = known_import class TypeMatcher: @@ -411,7 +499,7 @@ class TypeMatcher: type_nicknames : dict[str, str] successful_queries : int unknown_qualnames : list - current_module : Path | None + current_file : Path | None Examples -------- @@ -435,13 +523,14 @@ def __init__( type_prefixes : dict[str, KnownImport] type_nicknames : dict[str, str] """ - self.types = types or common_known_types() + + self.types = common_known_types() | (types or {}) self.type_prefixes = type_prefixes or {} self.type_nicknames = type_nicknames or {} self.successful_queries = 0 self.unknown_qualnames = [] - self.current_module = None + self.current_file = None def match(self, search_name): """Search for a known annotation name. @@ -459,6 +548,8 @@ def match(self, search_name): type_name = None type_origin = None + module = module_name_from_path(self.current_file) if self.current_file else None + if search_name.startswith("~."): # Sphinx like matching with abbreviated name pattern = search_name.replace(".", r"\.") @@ -466,7 +557,10 @@ def match(self, search_name): regex = re.compile(pattern + "$") # Might be slow, but works for now matches = { - key: value for key, value in self.types.items() if regex.match(key) + key: value + for key, value in self.types.items() + if regex.match(key) + if ":" not in key } if len(matches) > 1: shortest_key = sorted(matches.keys(), key=lambda x: len(x))[0] @@ -475,7 +569,7 @@ def match(self, search_name): logger.warning( "%r in %s matches multiple types %r, using %r", search_name, - self.current_module or "", + self.current_file or "", matches.keys(), shortest_key, ) @@ -486,17 +580,16 @@ def match(self, search_name): logger.debug( "couldn't match %r in %s", search_name, - self.current_module or "", + 
self.current_file or "", ) # Replace alias search_name = self.type_nicknames.get(search_name, search_name) - if type_origin is None and self.current_module: - # Try scope of current module - module_name = module_name_from_path(self.current_module) - try_qualname = f"{module_name}.{search_name}" - type_origin = self.types.get(try_qualname) + if type_origin is None and module: + # Look for matching type in current module + type_origin = self.types.get(f"{module}:{search_name}") + type_origin = self.types.get(f"{module}.{search_name}", type_origin) if type_origin: type_name = search_name @@ -507,7 +600,8 @@ def match(self, search_name): if type_origin is None: # Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a') for partial_qualname in reversed(accumulate_qualname(search_name)): - type_origin = self.type_prefixes.get(partial_qualname) + type_origin = self.type_prefixes.get(f"{module}:{partial_qualname}") + type_origin = self.type_prefixes.get(partial_qualname, type_origin) if type_origin: type_name = search_name break diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 3bcbf76..f2ce941 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -21,7 +21,7 @@ walk_source_and_targets, ) from ._stubs import Py2StubTransformer, try_format_stub -from ._utils import ErrorReporter, GroupedErrorReporter +from ._utils import ErrorReporter, GroupedErrorReporter, module_name_from_path from ._version import __version__ logger = logging.getLogger(__name__) @@ -79,7 +79,7 @@ def _setup_logging(*, verbose): ) -def _collect_types(root_path, *, ignore=()): +def _collect_type_info(root_path, *, ignore=()): """Collect types. 
Parameters @@ -94,22 +94,29 @@ def _collect_types(root_path, *, ignore=()): Returns ------- types : dict[str, ~.KnownImport] + type_prefixes : dict[str, ~.KnownImport] """ types = common_known_types() + type_prefixes = {} - collect_cached_types = FileCache( - func=TypeCollector.collect, - serializer=TypeCollector.ImportSerializer(), - cache_dir=Path.cwd() / ".docstub_cache", - name=f"{__version__}/collected_types", - ) if root_path.is_dir(): for source_path in walk_python_package(root_path, ignore=ignore): + + module = module_name_from_path(source_path) + module = module.replace(".", "/") + collect_cached_types = FileCache( + func=TypeCollector.collect, + serializer=TypeCollector.ImportSerializer(), + cache_dir=Path.cwd() / ".docstub_cache", + name=f"{__version__}/{module}", + ) + logger.info("collecting types in %s", source_path) - types_in_source = collect_cached_types(source_path) - types.update(types_in_source) + types_in_file, prefixes_in_file = collect_cached_types(source_path) + types.update(types_in_file) + type_prefixes.update(prefixes_in_file) - return types + return types, type_prefixes @contextmanager @@ -228,14 +235,13 @@ def run(root_path, out_dir, config_paths, ignore, group_errors, allow_errors, ve config = _load_configuration(config_paths) config = config.merge(Config(ignore_files=list(ignore))) - types = common_known_types() - types |= _collect_types(root_path, ignore=config.ignore_files) + types, type_prefixes = _collect_type_info(root_path, ignore=config.ignore_files) types |= { type_name: KnownImport(import_path=module, import_name=type_name) for type_name, module in config.types.items() } - type_prefixes = { + type_prefixes |= { prefix: ( KnownImport(import_name=module, import_alias=prefix) if module != prefix diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index 1591cd4..b1ba2ac 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -555,11 +555,11 @@ def _uncombine_numpydoc_params(params): Parameters 
---------- - params : list[numpydoc.docsrape.Parameter] + params : list[npds.Parameter] Yields ------ - param : numpydoc.docscrape.Parameter + param : npds.Parameter """ for param in params: if "," in param.name: @@ -791,11 +791,11 @@ def _handle_missing_whitespace(self, param): Parameters ---------- - param : numpydoc.docscrape.Parameter + param : npds.Parameter Returns ------- - param : numpydoc.docscrape.Parameter + param : npds.Parameter """ if ":" in param.name and param.type == "": msg = "Possibly missing whitespace between parameter and colon in docstring" diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index e64d49c..ab27a3f 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -318,7 +318,7 @@ def current_source(self, value): # TODO pass current_source directly when using the transformer / matcher # instead of assigning it here! if self.transformer is not None and self.transformer.matcher is not None: - self.transformer.matcher.current_module = value + self.transformer.matcher.current_file = value @property def is_inside_function_def(self): diff --git a/src/docstub/_utils.py b/src/docstub/_utils.py index bbd55bd..462a646 100644 --- a/src/docstub/_utils.py +++ b/src/docstub/_utils.py @@ -65,7 +65,7 @@ def escape_qualname(name): return qualname -@lru_cache(maxsize=10) +@lru_cache(maxsize=100) def module_name_from_path(path): """Find the full name of a module within its package from its file path. 
diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 3188495..729201d 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -49,7 +49,6 @@ def _module_factory(src, module_name): class Test_TypeCollector: - def test_classes(self, module_factory): module_path = module_factory( src=dedent( @@ -61,14 +60,15 @@ class NestedClass: ), module_name="sub.module", ) - imports = TypeCollector.collect(file=module_path) - assert len(imports) == 2 - assert imports["sub.module.TopLevelClass"] == KnownImport( + types, prefixes = TypeCollector.collect(file=module_path) + assert prefixes == {} + assert len(types) == 2 + assert types["sub.module.TopLevelClass"] == KnownImport( import_path="sub.module", import_name="TopLevelClass" ) # The import for the nested class should still use only the top-level # class as an import target - assert imports["sub.module.TopLevelClass.NestedClass"] == KnownImport( + assert types["sub.module.TopLevelClass.NestedClass"] == KnownImport( import_path="sub.module", import_name="TopLevelClass" ) @@ -77,9 +77,10 @@ class NestedClass: ) def test_type_alias(self, module_factory, src): module_path = module_factory(src=src, module_name="sub.module") - imports = TypeCollector.collect(file=module_path) - assert len(imports) == 1 - assert imports == { + types, prefixes = TypeCollector.collect(file=module_path) + assert prefixes == {} + assert len(types) == 1 + assert types == { "sub.module.alias_name": KnownImport( import_path="sub.module", import_name="alias_name" ) @@ -97,8 +98,73 @@ def test_type_alias(self, module_factory, src): ) def test_ignores_assigns(self, module_factory, src): module_path = module_factory(src=src, module_name="sub.module") - imports = TypeCollector.collect(file=module_path) - assert len(imports) == 0 + types, prefixes = TypeCollector.collect(file=module_path) + assert prefixes == {} + assert len(types) == 0 + + def test_from_import(self, module_factory): + src = dedent( + """ + from calendar import gregorian + 
from calendar.gregorian import August as Aug, December + """ + ) + + module_path = module_factory(src=src, module_name="sub.module") + types, prefixes = TypeCollector.collect(file=module_path) + + assert prefixes == {} + assert types == { + "calendar.gregorian": KnownImport( + import_path="calendar", import_name="gregorian" + ), + "calendar.gregorian.August": KnownImport( + import_path="calendar.gregorian", import_name="August" + ), + "calendar.gregorian.December": KnownImport( + import_path="calendar.gregorian", import_name="December" + ), + "sub.module:gregorian": KnownImport(builtin_name="gregorian"), + "sub.module:Aug": KnownImport(builtin_name="Aug"), + "sub.module:December": KnownImport(builtin_name="December"), + } + + def test_relative_import(self, module_factory): + src = dedent( + """ + from . import January + from .. import August as Aug, December + from ..calendar import September + """ + ) + module_path = module_factory(src=src, module_name="sub.module") + types, prefixes = TypeCollector.collect(file=module_path) + assert prefixes == {} + assert types == { + "sub.module:January": KnownImport(builtin_name="January"), + "sub.module:Aug": KnownImport(builtin_name="Aug"), + "sub.module:December": KnownImport(builtin_name="December"), + "sub.module:September": KnownImport(builtin_name="September"), + } + + def test_imports(self, module_factory): + src = dedent( + """ + import calendar + import drinks as dr + import calendar.gregorian as greg + """ + ) + + module_path = module_factory(src=src, module_name="sub.module") + types, prefixes = TypeCollector.collect(file=module_path) + assert types == {} + assert len(prefixes) == 3 + assert prefixes == { + "sub.module:calendar": KnownImport(builtin_name="calendar"), + "sub.module:dr": KnownImport(builtin_name="dr"), + "sub.module:greg": KnownImport(builtin_name="greg"), + } class Test_TypeMatcher: From 41ff30cf05b546154dfbdde1ac622d8fddf48f5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Wed, 
23 Jul 2025 15:45:48 +0200 Subject: [PATCH 2/5] Rename `KnownImport` to `PyImport` And also improve the class representation and various other things. --- src/docstub/_analysis.py | 292 +++++++++++++++++-------------------- src/docstub/_cli.py | 12 +- src/docstub/_docstrings.py | 34 ++--- src/docstub/_stubs.py | 10 +- tests/test_analysis.py | 88 ++++++----- tests/test_docstrings.py | 46 ++---- 6 files changed, 217 insertions(+), 265 deletions(-) diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py index 146c5cf..d2bdba9 100644 --- a/src/docstub/_analysis.py +++ b/src/docstub/_analysis.py @@ -50,30 +50,38 @@ def _shared_leading_qualname(*qualnames): @dataclass(slots=True, frozen=True) -class KnownImport: - """Import information associated with a single known type annotation. +class PyImport: + """Information to construct an import statement for any Python object. Attributes ---------- - import_path + from_ : Dotted names after "from". - import_name + import_ : Dotted names after "import". - import_alias + as_ : Name (without ".") after "as". - builtin_name - Names an object that's builtin and doesn't need an import. + implicit : + Describes an object that doesn't need an import statement and is + implicitly available. This may be a builtin or an object that is known + to be available in a given scope. E.g. it may have already been + imported. Examples -------- - >>> KnownImport(import_path="numpy", import_name="uint8", import_alias="ui8") - + >>> str(PyImport(from_="numpy", import_="uint8", as_="ui8")) + 'from numpy import uint8 as ui8' + + >>> str(PyImport(implicit="int")) + Traceback (most recent call last): + ... 
+ RuntimeError: cannot import implicit object: 'int' """ - import_name: str | None = None - import_path: str | None = None - import_alias: str | None = None - builtin_name: str | None = None + import_: str | None = None + from_: str | None = None + as_: str | None = None + implicit: str | None = None @classmethod @cache @@ -85,125 +93,77 @@ def typeshed_Incomplete(cls): Returns ------- - import : KnownImport + import : PyImport The import corresponding to ``from _typeshed import Incomplete``. References ---------- .. [1] https://typing.readthedocs.io/en/latest/guides/writing_stubs.html#incomplete-stubs """ - import_ = cls(import_path="_typeshed", import_name="Incomplete") + import_ = cls(from_="_typeshed", import_="Incomplete") return import_ - @classmethod - def one_from_config(cls, name, *, info): - """Create one KnownImport from the configuration format. - - Parameters - ---------- - name : str - info : dict[{"from", "import", "as", "is_builtin"}, str] - - Returns - ------- - TypeImport : Self - """ - assert not (info.keys() - {"from", "import", "as", "is_builtin"}) - - if info.get("is_builtin"): - known_import = cls(builtin_name=name) - else: - import_name = name - if "import" in info: - import_name = info["import"] - - known_import = cls( - import_name=import_name, - import_path=info.get("from"), - import_alias=info.get("as"), - ) - if not name.startswith(known_import.target): - raise ValueError( - f"{name!r} doesn't start with {known_import.target!r}", - ) - - return known_import - - @classmethod - def many_from_config(cls, mapping): - """Create many KnownImports from the configuration format. 
- - Parameters - ---------- - mapping : dict[str, dict[{"from", "import", "as", "is_builtin"}, str]] - - Returns - ------- - known_imports : dict[str, Self] - """ - known_imports = { - name: cls.one_from_config(name, info=info) for name, info in mapping.items() - } - return known_imports - def format_import(self, relative_to=None): - if self.builtin_name: - msg = "cannot import builtin" + if self.implicit: + msg = f"cannot import implicit object: {self.implicit!r}" raise RuntimeError(msg) - out = f"import {self.import_name}" + out = f"import {self.import_}" - import_path = self.import_path + import_path = self.from_ if import_path: if relative_to: shared = _shared_leading_qualname(relative_to, import_path) if shared == import_path: import_path = "." else: - import_path = self.import_path.replace(shared, "") + import_path = self.from_.replace(shared, "") out = f"from {import_path} {out}" - if self.import_alias: - out = f"{out} as {self.import_alias}" + if self.as_: + out = f"{out} as {self.as_}" return out @property def target(self) -> str: - if self.import_alias: - out = self.import_alias - elif self.import_name: - out = self.import_name - elif self.builtin_name: - out = self.builtin_name + if self.as_: + out = self.as_ + elif self.import_: + out = self.import_ + elif self.implicit: + # Account for scoped form "some_module_scope:target" + out = self.implicit.split(":")[-1] else: raise RuntimeError("cannot determine import target") return out @property def has_import(self): - return self.builtin_name is None + return self.implicit is None def __post_init__(self): - if self.builtin_name is not None: + if self.implicit is not None: if ( - self.import_name is not None - or self.import_alias is not None - or self.import_path is not None + self.import_ is not None + or self.as_ is not None + or self.from_ is not None ): - raise ValueError("builtin cannot contain import information") - elif self.import_name is None: - raise ValueError("non builtin must at least define an 
`import_name`") - if self.import_alias is not None and "." in self.import_alias: - raise ValueError("`import_alias` can't contain a '.'") + raise ValueError("implicit import cannot contain import information") + elif self.import_ is None: + raise ValueError("must set at least one parameter: `import_` or `implicit`") + if self.as_ is not None and "." in self.as_: + raise ValueError("parameter `as_` can't contain a '.'") def __repr__(self) -> str: - if self.builtin_name: - kwargs = f"builtin_name={self.builtin_name!r}" + if self.implicit: + kwargs = f"implicit={self.implicit!r}" else: - kwargs = ( - f"import_path={self.import_path!r}, " - f"import_name={self.import_name!r}, " - f"import_alias={self.import_alias!r}" - ) + kwargs = [ + f"from_={self.from_!r}" if self.from_ else None, + f"import_={self.import_!r}" if self.import_ else None, + f"as_={self.as_!r}" if self.as_ else None, + ] + kwargs = [arg for arg in kwargs if arg is not None] + kwargs = ", ".join(kwargs) out = f"{type(self).__name__}({kwargs})" return out @@ -231,27 +191,37 @@ def _is_type(value): def _builtin_types(): - """Return known imports for all builtins (in the current runtime). + """Return known imports for all builtins in the current runtime. Returns ------- - known_imports : dict[str, KnownImport] + types : dict[str, PyImport] """ known_builtins = set(dir(builtins)) - known_imports = {} + types = {} for name in known_builtins: if name.startswith("_"): continue value = getattr(builtins, name) if not _is_type(value): continue - known_imports[name] = KnownImport(builtin_name=name) + types[name] = PyImport(implicit=name) - return known_imports + return types def _runtime_types_in_module(module_name): + """Return types of a module in the current runtime. 
+ + Parameters + ---------- + module_name : str + + Returns + ------- + types : dict[str, PyImport] + """ module = importlib.import_module(module_name) types = {} for name in module.__all__: @@ -261,44 +231,44 @@ def _runtime_types_in_module(module_name): if not _is_type(value): continue - import_ = KnownImport(import_path=module_name, import_name=name) - types[name] = import_ - types[f"{module_name}.{name}"] = import_ + py_import = PyImport(from_=module_name, import_=name) + types[name] = py_import + types[f"{module_name}.{name}"] = py_import return types def common_known_types(): - """Return known imports for commonly supported types. + """Return commonly supported types. This includes builtin types, and types from the `typing` or `collections.abc` module. Returns ------- - known_imports : dict[str, KnownImport] + py_imports : dict[str, PyImport] Examples -------- >>> types = common_known_types() >>> types["str"] - + PyImport(implicit='str') >>> types["Iterable"] - + PyImport(from_='collections.abc', import_='Iterable') >>> types["collections.abc.Iterable"] - + PyImport(from_='collections.abc', import_='Iterable') """ - known_imports = _builtin_types() - known_imports |= _runtime_types_in_module("typing") + types = _builtin_types() + types |= _runtime_types_in_module("typing") # Overrides containers from typing - known_imports |= _runtime_types_in_module("collections.abc") - return known_imports + types |= _runtime_types_in_module("collections.abc") + return types @dataclass(slots=True, kw_only=True) class TypeCollectionResult: - types: dict[str, KnownImport] - type_prefixes: dict[str, KnownImport] + types: dict[str, PyImport] + type_prefixes: dict[str, PyImport] @classmethod def serialize(cls, result): @@ -316,9 +286,13 @@ class TypeCollector(cst.CSTVisitor): -------- >>> types, prefixes = TypeCollector.collect(__file__) >>> types[f"{__name__}.TypeCollector"] - - >>> prefixes["logging"] - + PyImport(from_='docstub._analysis', import_='TypeCollector') + + >>> 
from pathlib import Path + >>> from docstub._utils import module_name_from_path + >>> module = module_name_from_path(Path(__file__)) + >>> prefixes[f"{module}:logging"] + PyImport(implicit='...:logging') """ class ImportSerializer: @@ -337,7 +311,7 @@ def serialize(self, data): Parameters ---------- - data : tuple[dict[str, KnownImport], dict[str, KnownImport]] + data : tuple[dict[str, PyImport], dict[str, PyImport]] Returns ------- @@ -360,13 +334,13 @@ def deserialize(self, raw): Returns ------- - types : dict[str, KnownImport] - type_prefixes : dict[str, KnownImport] + types : dict[str, PyImport] + type_prefixes : dict[str, PyImport] """ primitives = json.loads(raw.decode(self.encoding)) def deserialize_table(table): - return {key: KnownImport(**kw) for key, kw in table.items()} + return {key: PyImport(**kw) for key, kw in table.items()} types = deserialize_table(primitives["types"]) type_prefixes = deserialize_table(primitives["type_prefixes"]) @@ -382,8 +356,8 @@ def collect(cls, file): Returns ------- - types : dict[str, KnownImport] - type_prefixes : dict[str, KnownImport] + types : dict[str, PyImport] + type_prefixes : dict[str, PyImport] """ file = Path(file) with file.open("r") as fo: @@ -455,13 +429,12 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> bool: if not node.relative: key = ".".join([*from_names, name]) - known_import = KnownImport( - import_path=".".join(from_names), import_name=name - ) - self.types[key] = known_import + py_import = PyImport(from_=".".join(from_names), import_=name) + self.types[key] = py_import - scoped_import = KnownImport(builtin_name=asname or name) - self.types[f"{self.module_name}:{asname or name}"] = scoped_import + scoped_key = f"{self.module_name}:{asname or name}" + scoped_import = PyImport(implicit=scoped_key) + self.types[scoped_key] = scoped_import return False @@ -469,10 +442,9 @@ def visit_Import(self, node: cst.Import) -> bool: for import_alias in node.names: asname = import_alias.evaluated_alias name = 
import_alias.evaluated_name - target = asname or name - - known_import = KnownImport(builtin_name=asname or name) - self.type_prefixes[f"{self.module_name}:{target}"] = known_import + scoped_key = f"{self.module_name}:{asname or name}" + py_import = PyImport(implicit=scoped_key) + self.type_prefixes[scoped_key] = py_import return False @@ -485,8 +457,8 @@ def _collect_type_annotation(self, stack): A list of names that form the path to the collected type. """ qualname = ".".join([self.module_name, *stack]) - known_import = KnownImport(import_path=self.module_name, import_name=stack[0]) - self.types[qualname] = known_import + py_import = PyImport(from_=self.module_name, import_=stack[0]) + self.types[qualname] = py_import class TypeMatcher: @@ -494,8 +466,8 @@ class TypeMatcher: Attributes ---------- - types : dict[str, KnownImport] - type_prefixes : dict[str, KnownImport] + types : dict[str, PyImport] + type_prefixes : dict[str, PyImport] type_nicknames : dict[str, str] successful_queries : int unknown_qualnames : list @@ -506,7 +478,7 @@ class TypeMatcher: >>> from docstub._analysis import TypeMatcher, common_known_types >>> db = TypeMatcher() >>> db.match("Any") - ('Any', ) + ('Any', PyImport(from_='typing', import_='Any')) """ def __init__( @@ -519,8 +491,8 @@ def __init__( """ Parameters ---------- - types : dict[str, KnownImport] - type_prefixes : dict[str, KnownImport] + types : dict[str, PyImport] + type_prefixes : dict[str, PyImport] type_nicknames : dict[str, str] """ @@ -543,10 +515,10 @@ def match(self, search_name): Returns ------- type_name : str | None - type_origin : KnownImport | None + py_import : PyImport | None """ type_name = None - type_origin = None + py_import = None module = module_name_from_path(self.current_file) if self.current_file else None @@ -564,7 +536,7 @@ def match(self, search_name): } if len(matches) > 1: shortest_key = sorted(matches.keys(), key=lambda x: len(x))[0] - type_origin = matches[shortest_key] + py_import = 
matches[shortest_key] type_name = shortest_key logger.warning( "%r in %s matches multiple types %r, using %r", @@ -574,7 +546,7 @@ def match(self, search_name): shortest_key, ) elif len(matches) == 1: - type_name, type_origin = matches.popitem() + type_name, py_import = matches.popitem() else: search_name = search_name[2:] logger.debug( @@ -586,38 +558,38 @@ def match(self, search_name): # Replace alias search_name = self.type_nicknames.get(search_name, search_name) - if type_origin is None and module: + if py_import is None and module: # Look for matching type in current module - type_origin = self.types.get(f"{module}:{search_name}") - type_origin = self.types.get(f"{module}.{search_name}", type_origin) - if type_origin: + py_import = self.types.get(f"{module}:{search_name}") + py_import = self.types.get(f"{module}.{search_name}", py_import) + if py_import: type_name = search_name - if type_origin is None and search_name in self.types: + if py_import is None and search_name in self.types: type_name = search_name - type_origin = self.types[search_name] + py_import = self.types[search_name] - if type_origin is None: + if py_import is None: # Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a') for partial_qualname in reversed(accumulate_qualname(search_name)): - type_origin = self.type_prefixes.get(f"{module}:{partial_qualname}") - type_origin = self.type_prefixes.get(partial_qualname, type_origin) - if type_origin: + py_import = self.type_prefixes.get(f"{module}:{partial_qualname}") + py_import = self.type_prefixes.get(partial_qualname, py_import) + if py_import: type_name = search_name break if ( - type_origin is not None + py_import is not None and type_name is not None - and type_name != type_origin.target - and not type_name.startswith(type_origin.target) + and type_name != py_import.target + and not type_name.startswith(py_import.target) ): # Ensure that the annotation matches the import target - type_name = 
type_name[type_name.find(type_origin.target) :] + type_name = type_name[type_name.find(py_import.target) :] if type_name is not None: self.successful_queries += 1 else: self.unknown_qualnames.append(search_name) - return type_name, type_origin + return type_name, py_import diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index f2ce941..96328bc 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -8,7 +8,7 @@ import click from ._analysis import ( - KnownImport, + PyImport, TypeCollector, TypeMatcher, common_known_types, @@ -93,8 +93,8 @@ def _collect_type_info(root_path, *, ignore=()): Returns ------- - types : dict[str, ~.KnownImport] - type_prefixes : dict[str, ~.KnownImport] + types : dict[str, PyImport] + type_prefixes : dict[str, PyImport] """ types = common_known_types() type_prefixes = {} @@ -237,15 +237,15 @@ def run(root_path, out_dir, config_paths, ignore, group_errors, allow_errors, ve types, type_prefixes = _collect_type_info(root_path, ignore=config.ignore_files) types |= { - type_name: KnownImport(import_path=module, import_name=type_name) + type_name: PyImport(from_=module, import_=type_name) for type_name, module in config.types.items() } type_prefixes |= { prefix: ( - KnownImport(import_name=module, import_alias=prefix) + PyImport(import_=module, as_=prefix) if module != prefix - else KnownImport(import_name=prefix) + else PyImport(import_=prefix) ) for prefix, module in config.type_prefixes.items() } diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py index b1ba2ac..cf76085 100644 --- a/src/docstub/_docstrings.py +++ b/src/docstub/_docstrings.py @@ -15,7 +15,7 @@ # It should be possible to transform docstrings without matching to valid # types and imports. I think that could very well be done at a higher level, # e.g. in the stubs module. 
-from ._analysis import KnownImport, TypeMatcher +from ._analysis import PyImport, TypeMatcher from ._utils import DocstubError, ErrorReporter, escape_qualname logger = logging.getLogger(__name__) @@ -60,14 +60,14 @@ class Annotation: """Python-ready type annotation with attached import information.""" value: str - imports: frozenset[KnownImport] = field(default_factory=frozenset) + imports: frozenset[PyImport] = field(default_factory=frozenset) def __post_init__(self): object.__setattr__(self, "imports", frozenset(self.imports)) if "~" in self.value: raise ValueError(f"unexpected '~' in annotation value: {self.value}") for import_ in self.imports: - if not isinstance(import_, KnownImport): + if not isinstance(import_, PyImport): raise TypeError(f"unexpected type {type(import_)} in `imports`") def __str__(self) -> str: @@ -133,7 +133,7 @@ def as_generator(cls, *, yield_types, receive_types=(), return_types=()): value = f"{value}, {return_annotation.value}" value = f"Generator[{value}]" - imports |= {KnownImport(import_path="collections.abc", import_name="Generator")} + imports |= {PyImport(from_="collections.abc", import_="Generator")} generator = cls(value=value, imports=imports) return generator @@ -165,7 +165,7 @@ def _aggregate_annotations(*types): Returns ------- values : list[str] - imports : set[~.KnownImport] + imports : set[PyImport] """ values = [] imports = set() @@ -176,7 +176,7 @@ def _aggregate_annotations(*types): FallbackAnnotation = Annotation( - value="Incomplete", imports=frozenset([KnownImport.typeshed_Incomplete()]) + value="Incomplete", imports=frozenset([PyImport.typeshed_Incomplete()]) ) @@ -394,9 +394,9 @@ def natlang_literal(self, tree): ) if self.matcher is not None: - _, known_import = self.matcher.match("Literal") - if known_import: - self._collected_imports.add(known_import) + _, py_import = self.matcher.match("Literal") + if py_import: + self._collected_imports.add(py_import) return out def natlang_container(self, tree): @@ -524,13 
+524,13 @@ def _match_import(self, qualname, *, meta): Possibly modified or normalized qualname. """ if self.matcher is not None: - annotation_name, known_import = self.matcher.match(qualname) + annotation_name, py_import = self.matcher.match(qualname) else: annotation_name = None - known_import = None + py_import = None - if known_import and known_import.has_import: - self._collected_imports.add(known_import) + if py_import and py_import.has_import: + self._collected_imports.add(py_import) if annotation_name: matched_qualname = annotation_name @@ -538,10 +538,10 @@ def _match_import(self, qualname, *, meta): # Unknown qualname, alias to `Incomplete` self._unknown_qualnames.append((qualname, meta.start_pos, meta.end_pos)) matched_qualname = escape_qualname(qualname) - any_alias = KnownImport( - import_path="_typeshed", - import_name="Incomplete", - import_alias=matched_qualname, + any_alias = PyImport( + from_="_typeshed", + import_="Incomplete", + as_=matched_qualname, ) self._collected_imports.add(any_alias) return matched_qualname diff --git a/src/docstub/_stubs.py b/src/docstub/_stubs.py index ab27a3f..6edaef8 100644 --- a/src/docstub/_stubs.py +++ b/src/docstub/_stubs.py @@ -14,7 +14,7 @@ import libcst as cst import libcst.matchers as cstm -from ._analysis import KnownImport +from ._analysis import PyImport from ._docstrings import DocstringAnnotations, DoctypeTransformer from ._utils import ErrorReporter, module_name_from_path @@ -571,7 +571,7 @@ def leave_Param(self, original_node, updated_node): # Potentially use "Incomplete" except for first param in (class)methods elif not is_self_or_cls and updated_node.annotation is None: node_changes["annotation"] = self._Annotation_Incomplete - import_ = KnownImport.typeshed_Incomplete() + import_ = PyImport.typeshed_Incomplete() self._required_imports.add(import_) if node_changes: @@ -756,7 +756,7 @@ def leave_Module(self, original_node, updated_node): if self.current_source: current_module = 
module_name_from_path(self.current_source) required_imports = [ - imp for imp in required_imports if imp.import_path != current_module + imp for imp in required_imports if imp.from_ != current_module ] import_nodes = self._parse_imports( required_imports, current_module=current_module @@ -818,7 +818,7 @@ def _parse_imports(imports, *, current_module=None): Parameters ---------- - imports : set[~.KnownImport] + imports : set[PyImport] current_module : str, optional Returns @@ -912,7 +912,7 @@ def _create_annotated_assign(self, *, name, trailing_semicolon=False): self._required_imports |= pytype.imports else: annotation = self._Annotation_Incomplete - self._required_imports.add(KnownImport.typeshed_Incomplete()) + self._required_imports.add(PyImport.typeshed_Incomplete()) semicolon = ( cst.Semicolon(whitespace_after=cst.SimpleWhitespace(" ")) diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 729201d..145f46e 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -3,7 +3,7 @@ import pytest from docstub._analysis import ( - KnownImport, + PyImport, TypeCollector, TypeMatcher, ) @@ -12,7 +12,7 @@ class Test_KnownImport: def test_dot_in_alias(self): with pytest.raises(ValueError, match=r".*can't contain a '\.'"): - KnownImport(import_name="foo.bar.baz", import_alias="bar.baz") + PyImport(import_="foo.bar.baz", as_="bar.baz") @pytest.fixture @@ -63,13 +63,13 @@ class NestedClass: types, prefixes = TypeCollector.collect(file=module_path) assert prefixes == {} assert len(types) == 2 - assert types["sub.module.TopLevelClass"] == KnownImport( - import_path="sub.module", import_name="TopLevelClass" + assert types["sub.module.TopLevelClass"] == PyImport( + from_="sub.module", import_="TopLevelClass" ) # The import for the nested class should still use only the top-level # class as an import target - assert types["sub.module.TopLevelClass.NestedClass"] == KnownImport( - import_path="sub.module", import_name="TopLevelClass" + assert 
types["sub.module.TopLevelClass.NestedClass"] == PyImport( + from_="sub.module", import_="TopLevelClass" ) @pytest.mark.parametrize( @@ -81,9 +81,7 @@ def test_type_alias(self, module_factory, src): assert prefixes == {} assert len(types) == 1 assert types == { - "sub.module.alias_name": KnownImport( - import_path="sub.module", import_name="alias_name" - ) + "sub.module.alias_name": PyImport(from_="sub.module", import_="alias_name") } @pytest.mark.parametrize( @@ -115,18 +113,16 @@ def test_from_import(self, module_factory): assert prefixes == {} assert types == { - "calendar.gregorian": KnownImport( - import_path="calendar", import_name="gregorian" + "calendar.gregorian": PyImport(from_="calendar", import_="gregorian"), + "calendar.gregorian.August": PyImport( + from_="calendar.gregorian", import_="August" ), - "calendar.gregorian.August": KnownImport( - import_path="calendar.gregorian", import_name="August" + "calendar.gregorian.December": PyImport( + from_="calendar.gregorian", import_="December" ), - "calendar.gregorian.December": KnownImport( - import_path="calendar.gregorian", import_name="December" - ), - "sub.module:gregorian": KnownImport(builtin_name="gregorian"), - "sub.module:Aug": KnownImport(builtin_name="Aug"), - "sub.module:December": KnownImport(builtin_name="December"), + "sub.module:gregorian": PyImport(implicit="sub.module:gregorian"), + "sub.module:Aug": PyImport(implicit="sub.module:Aug"), + "sub.module:December": PyImport(implicit="sub.module:December"), } def test_relative_import(self, module_factory): @@ -141,10 +137,10 @@ def test_relative_import(self, module_factory): types, prefixes = TypeCollector.collect(file=module_path) assert prefixes == {} assert types == { - "sub.module:January": KnownImport(builtin_name="January"), - "sub.module:Aug": KnownImport(builtin_name="Aug"), - "sub.module:December": KnownImport(builtin_name="December"), - "sub.module:September": KnownImport(builtin_name="September"), + "sub.module:January": 
PyImport(implicit="sub.module:January"), + "sub.module:Aug": PyImport(implicit="sub.module:Aug"), + "sub.module:December": PyImport(implicit="sub.module:December"), + "sub.module:September": PyImport(implicit="sub.module:September"), } def test_imports(self, module_factory): @@ -161,24 +157,24 @@ def test_imports(self, module_factory): assert types == {} assert len(prefixes) == 3 assert prefixes == { - "sub.module:calendar": KnownImport(builtin_name="calendar"), - "sub.module:dr": KnownImport(builtin_name="dr"), - "sub.module:greg": KnownImport(builtin_name="greg"), + "sub.module:calendar": PyImport(implicit="sub.module:calendar"), + "sub.module:dr": PyImport(implicit="sub.module:dr"), + "sub.module:greg": PyImport(implicit="sub.module:greg"), } class Test_TypeMatcher: type_prefixes = { # noqa: RUF012 - "np": KnownImport(import_name="numpy", import_alias="np"), - "foo.bar.Baz": KnownImport(import_path="foo.bar", import_name="Baz"), + "np": PyImport(import_="numpy", as_="np"), + "foo.bar.Baz": PyImport(from_="foo.bar", import_="Baz"), } types = { # noqa: RUF012 - "dict": KnownImport(builtin_name="dict"), - "foo.bar": KnownImport(import_path="foo", import_name="bar"), - "foo.bar.Baz": KnownImport(import_path="foo.bar", import_name="Baz"), - "foo.bar.Baz.Bix": KnownImport(import_path="foo.bar", import_name="Baz"), - "foo.bar.Baz.Qux": KnownImport(import_path="foo", import_name="bar"), + "dict": PyImport(implicit="dict"), + "foo.bar": PyImport(from_="foo", import_="bar"), + "foo.bar.Baz": PyImport(from_="foo.bar", import_="Baz"), + "foo.bar.Baz.Bix": PyImport(from_="foo.bar", import_="Baz"), + "foo.bar.Baz.Qux": PyImport(from_="foo", import_="bar"), } # fmt: off @@ -212,16 +208,16 @@ class Test_TypeMatcher: def test_query_types(self, search_name, expected_name, expected_origin): db = TypeMatcher(types=self.types.copy()) - type_name, type_origin = db.match(search_name) + type_name, py_import = db.match(search_name) if expected_name is None and expected_origin is None: 
assert expected_name is type_name - assert expected_origin is type_origin + assert expected_origin is py_import else: assert type_name is not None - assert type_origin is not None - assert str(type_origin) == expected_origin - assert type_name.startswith(type_origin.target) + assert py_import is not None + assert str(py_import) == expected_origin + assert type_name.startswith(py_import.target) assert type_name == expected_name # fmt: on @@ -240,16 +236,16 @@ def test_query_types(self, search_name, expected_name, expected_origin): def test_query_prefix(self, search_name, expected_name, expected_origin): db = TypeMatcher(type_prefixes=self.type_prefixes.copy()) - type_name, type_origin = db.match(search_name) + type_name, py_import = db.match(search_name) if expected_name is None and expected_origin is None: assert expected_name is type_name - assert expected_origin is type_origin + assert expected_origin is py_import else: assert type_name is not None - assert type_origin is not None - assert str(type_origin) == expected_origin - assert type_name.startswith(type_origin.target) + assert py_import is not None + assert str(py_import) == expected_origin + assert type_name.startswith(py_import.target) assert type_name == expected_name # fmt: on @@ -264,8 +260,8 @@ def test_query_prefix(self, search_name, expected_name, expected_origin): ) def test_common_known_types(self, search_name, import_path): matcher = TypeMatcher() - type_name, type_origin = matcher.match(search_name) + type_name, py_import = matcher.match(search_name) assert type_name == search_name.split(".")[-1] - assert type_origin is not None - assert type_origin.import_path == import_path + assert py_import is not None + assert py_import.from_ == import_path diff --git a/tests/test_docstrings.py b/tests/test_docstrings.py index 057fa6d..bade1be 100644 --- a/tests/test_docstrings.py +++ b/tests/test_docstrings.py @@ -3,7 +3,7 @@ import lark import pytest -from docstub._analysis import KnownImport +from 
docstub._analysis import PyImport from docstub._docstrings import Annotation, DocstringAnnotations, DoctypeTransformer @@ -11,20 +11,18 @@ class Test_Annotation: def test_str(self): annotation = Annotation( value="Path", - imports=frozenset({KnownImport(import_name="Path", import_path="pathlib")}), + imports=frozenset({PyImport(import_="Path", from_="pathlib")}), ) assert str(annotation) == annotation.value def test_as_return_tuple(self): path_anno = Annotation( value="Path", - imports=frozenset({KnownImport(import_name="Path", import_path="pathlib")}), + imports=frozenset({PyImport(import_="Path", from_="pathlib")}), ) sequence_anno = Annotation( value="Sequence", - imports=frozenset( - {KnownImport(import_name="Sequence", import_path="collections.abc")} - ), + imports=frozenset({PyImport(import_="Sequence", from_="collections.abc")}), ) return_annotation = Annotation.many_as_tuple([path_anno, sequence_anno]) assert return_annotation.value == "tuple[Path, Sequence]" @@ -229,9 +227,7 @@ def test_unknown_name(self): annotation, unknown_names = transformer.doctype_to_annotation("a") assert annotation.value == "a" assert annotation.imports == { - KnownImport( - import_name="Incomplete", import_path="_typeshed", import_alias="a" - ) + PyImport(import_="Incomplete", from_="_typeshed", as_="a") } assert unknown_names == [("a", 0, 1)] @@ -241,9 +237,7 @@ def test_unknown_qualname(self): annotation, unknown_names = transformer.doctype_to_annotation("a.b") assert annotation.value == "a_b" assert annotation.imports == { - KnownImport( - import_name="Incomplete", import_path="_typeshed", import_alias="a_b" - ) + PyImport(import_="Incomplete", from_="_typeshed", as_="a_b") } assert unknown_names == [("a.b", 0, 3)] @@ -253,12 +247,8 @@ def test_multiple_unknown_names(self): annotation, unknown_names = transformer.doctype_to_annotation("a.b of c") assert annotation.value == "a_b[c]" assert annotation.imports == { - KnownImport( - import_name="Incomplete", 
import_path="_typeshed", import_alias="a_b" - ), - KnownImport( - import_name="Incomplete", import_path="_typeshed", import_alias="c" - ), + PyImport(import_="Incomplete", from_="_typeshed", as_="a_b"), + PyImport(import_="Incomplete", from_="_typeshed", as_="c"), } assert unknown_names == [("a.b", 0, 3), ("c", 7, 8)] @@ -294,9 +284,7 @@ def test_parameters(self, doctype, expected): assert len(annotations.parameters) == 2 assert annotations.parameters["a"].value == expected assert annotations.parameters["b"].value == "Incomplete" - assert annotations.parameters["b"].imports == { - KnownImport.typeshed_Incomplete() - } + assert annotations.parameters["b"].imports == {PyImport.typeshed_Incomplete()} @pytest.mark.parametrize( ("doctypes", "expected"), @@ -333,7 +321,7 @@ def test_yields(self, caplog): assert annotations.returns is not None assert annotations.returns.value == "Generator[tuple[int, str]]" assert annotations.returns.imports == { - KnownImport(import_path="collections.abc", import_name="Generator") + PyImport(from_="collections.abc", import_="Generator") } def test_receives(self, caplog): @@ -358,7 +346,7 @@ def test_receives(self, caplog): == "Generator[tuple[int, str], tuple[float, bytes]]" ) assert annotations.returns.imports == { - KnownImport(import_path="collections.abc", import_name="Generator") + PyImport(from_="collections.abc", import_="Generator") } def test_full_generator(self, caplog): @@ -386,7 +374,7 @@ def test_full_generator(self, caplog): "Generator[tuple[int, str], tuple[float, bytes], bool]" ) assert annotations.returns.imports == { - KnownImport(import_path="collections.abc", import_name="Generator") + PyImport(from_="collections.abc", import_="Generator") } def test_yields_and_returns(self, caplog): @@ -407,7 +395,7 @@ def test_yields_and_returns(self, caplog): assert annotations.returns is not None assert annotations.returns.value == ("Generator[tuple[int, str], None, bool]") assert annotations.returns.imports == { - 
KnownImport(import_path="collections.abc", import_name="Generator") + PyImport(from_="collections.abc", import_="Generator") } def test_duplicate_parameters(self, caplog): @@ -491,9 +479,5 @@ def test_combined_numpydoc_params(self): assert annotations.parameters["d"].value == "Incomplete" assert annotations.parameters["e"].value == "Incomplete" - assert annotations.parameters["d"].imports == { - KnownImport.typeshed_Incomplete() - } - assert annotations.parameters["e"].imports == { - KnownImport.typeshed_Incomplete() - } + assert annotations.parameters["d"].imports == {PyImport.typeshed_Incomplete()} + assert annotations.parameters["e"].imports == {PyImport.typeshed_Incomplete()} From e244b2e606c79156a06da4aff64aeda0874daf6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars=20Gr=C3=BCter?= Date: Wed, 23 Jul 2025 16:01:33 +0200 Subject: [PATCH 3/5] Test scoped matching --- tests/test_analysis.py | 54 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 145f46e..9f91f8c 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -265,3 +265,57 @@ def test_common_known_types(self, search_name, import_path): assert type_name == search_name.split(".")[-1] assert py_import is not None assert py_import.from_ == import_path + + def test_scoped_types(self, module_factory): + types = { + "sub.module:January": PyImport(implicit="sub.module:January"), + } + matcher = TypeMatcher(types=types) + + # Shouldn't match because the current module isn't set + type_name, py_import = matcher.match("January") + assert type_name is None + assert py_import is None + + # Set current module to something that doesn't match scope + module_path = module_factory(src="", module_name="other.module") + matcher.current_file = module_path + # Still shouldn't match because the current module doesn't match the scope + type_name, py_import = matcher.match("January") + assert type_name is None + assert py_import 
is None
+
+        # Set current module to match the scope
+        module_path = module_factory(src="", module_name="sub.module")
+        matcher.current_file = module_path
+        # Now we should find the type
+        type_name, py_import = matcher.match("January")
+        assert type_name == "January"
+        assert py_import == PyImport(implicit="sub.module:January")
+
+    def test_scoped_type_prefix(self, module_factory):
+        type_prefixes = {
+            "sub.module:cal": PyImport(implicit="sub.module:cal"),
+        }
+        matcher = TypeMatcher(type_prefixes=type_prefixes)
+
+        # Shouldn't match because the current module isn't set
+        type_name, py_import = matcher.match("cal.January")
+        assert type_name is None
+        assert py_import is None
+
+        # Set current module to something that doesn't match scope
+        module_path = module_factory(src="", module_name="other.module")
+        matcher.current_file = module_path
+        # Still shouldn't match because the current module doesn't match the scope
+        type_name, py_import = matcher.match("cal.January")
+        assert type_name is None
+        assert py_import is None
+
+        # Set current module to match the scope
+        module_path = module_factory(src="", module_name="sub.module")
+        matcher.current_file = module_path
+        # Now we should find the prefix
+        type_name, py_import = matcher.match("cal.January")
+        assert type_name == "cal.January"
+        assert py_import == PyImport(implicit="sub.module:cal")

From d1a84732559f00414f6a0decdf459598325db15d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lars=20Gr=C3=BCter?=
Date: Wed, 23 Jul 2025 16:25:58 +0200
Subject: [PATCH 4/5] Avoid collecting scoped import for "Literal"

`Literal` may already be imported manually within the module's own namespace.
In that case the matched `py_import` cannot be formatted as an import
statement and shouldn't be collected.
---
 src/docstub/_docstrings.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/docstub/_docstrings.py b/src/docstub/_docstrings.py
index cf76085..24289fb 100644
--- a/src/docstub/_docstrings.py
+++ b/src/docstub/_docstrings.py
@@ -395,7 +395,7 @@ def natlang_literal(self, tree):
 
         if self.matcher is not None:
             _, py_import = self.matcher.match("Literal")
-            if py_import:
+            if py_import and py_import.has_import:
                 self._collected_imports.add(py_import)
             return out
 
From 945b59e50606a658188053d0f19505794671b22f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lars=20Gr=C3=BCter?=
Date: Wed, 23 Jul 2025 17:33:26 +0200
Subject: [PATCH 5/5] Handle nested type nicknames

---
 pyproject.toml           |  4 +--
 src/docstub/_analysis.py | 68 ++++++++++++++++++++++++++++------------
 src/docstub/_cli.py      |  3 ++
 tests/test_analysis.py   | 34 ++++++++++++++++++++
 4 files changed, 87 insertions(+), 22 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 834e9f9..57a0855 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -119,8 +119,8 @@ run.source = ["docstub"]
 ".*maintenance.*" = "Maintenance"
 
 
-[tool.docstub.type_nicknames]
-Path = "pathlib.Path"
+[tool.docstub.types]
+Path = "pathlib"
 
 
 [tool.mypy]
diff --git a/src/docstub/_analysis.py b/src/docstub/_analysis.py
index d2bdba9..090ba9b 100644
--- a/src/docstub/_analysis.py
+++ b/src/docstub/_analysis.py
@@ -504,12 +504,39 @@ def __init__(
 
         self.current_file = None
 
-    def match(self, search_name):
+    def _resolve_nickname(self, name):
+        """Return intended name if `name` is a nickname.
+
+        Parameters
+        ----------
+        name : str
+
+        Returns
+        -------
+        resolved : str
+        """
+        original = name
+        resolved = name
+        for _ in range(1000):
+            name = self.type_nicknames.get(name)
+            if name is None:
+                break
+            resolved = name
+        else:
+            logger.warning(
+                "reached limit while resolving nicknames for %r in %s, using %r",
+                original,
+                self.current_file or "",
+                resolved,
+            )
+        return resolved
+
+    def match(self, search):
         """Search for a known annotation name.
 
Parameters ---------- - search_name : str + search : str current_module : Path, optional Returns @@ -517,14 +544,17 @@ def match(self, search_name): type_name : str | None py_import : PyImport | None """ + original_search = search type_name = None py_import = None module = module_name_from_path(self.current_file) if self.current_file else None - if search_name.startswith("~."): + search = self._resolve_nickname(search) + + if search.startswith("~."): # Sphinx like matching with abbreviated name - pattern = search_name.replace(".", r"\.") + pattern = search.replace(".", r"\.") pattern = pattern.replace("~", ".*") regex = re.compile(pattern + "$") # Might be slow, but works for now @@ -539,8 +569,9 @@ def match(self, search_name): py_import = matches[shortest_key] type_name = shortest_key logger.warning( - "%r in %s matches multiple types %r, using %r", - search_name, + "%r (original %r) in %s matches multiple types %r, using %r", + search, + original_search, self.current_file or "", matches.keys(), shortest_key, @@ -548,34 +579,31 @@ def match(self, search_name): elif len(matches) == 1: type_name, py_import = matches.popitem() else: - search_name = search_name[2:] + search = search[2:] logger.debug( "couldn't match %r in %s", - search_name, + search, self.current_file or "", ) - # Replace alias - search_name = self.type_nicknames.get(search_name, search_name) - if py_import is None and module: # Look for matching type in current module - py_import = self.types.get(f"{module}:{search_name}") - py_import = self.types.get(f"{module}.{search_name}", py_import) + py_import = self.types.get(f"{module}:{search}") + py_import = self.types.get(f"{module}.{search}", py_import) if py_import: - type_name = search_name + type_name = search - if py_import is None and search_name in self.types: - type_name = search_name - py_import = self.types[search_name] + if py_import is None and search in self.types: + type_name = search + py_import = self.types[search] if py_import is None: 
# Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a') - for partial_qualname in reversed(accumulate_qualname(search_name)): + for partial_qualname in reversed(accumulate_qualname(search)): py_import = self.type_prefixes.get(f"{module}:{partial_qualname}") py_import = self.type_prefixes.get(partial_qualname, py_import) if py_import: - type_name = search_name + type_name = search break if ( @@ -590,6 +618,6 @@ def match(self, search_name): if type_name is not None: self.successful_queries += 1 else: - self.unknown_qualnames.append(search_name) + self.unknown_qualnames.append(search) return type_name, py_import diff --git a/src/docstub/_cli.py b/src/docstub/_cli.py index 96328bc..0820c42 100644 --- a/src/docstub/_cli.py +++ b/src/docstub/_cli.py @@ -236,11 +236,14 @@ def run(root_path, out_dir, config_paths, ignore, group_errors, allow_errors, ve config = config.merge(Config(ignore_files=list(ignore))) types, type_prefixes = _collect_type_info(root_path, ignore=config.ignore_files) + + # Add declared types from configuration types |= { type_name: PyImport(from_=module, import_=type_name) for type_name, module in config.types.items() } + # Add declared type prefixes from configuration type_prefixes |= { prefix: ( PyImport(import_=module, as_=prefix) diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 9f91f8c..1050178 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -319,3 +319,37 @@ def test_scoped_type_prefix(self, module_factory): type_name, py_import = matcher.match("cal.January") assert type_name == "cal.January" assert py_import == PyImport(implicit="sub.module:cal") + + def test_nested_nicknames(self, caplog): + types = { + "Foo": PyImport(implicit="Foo"), + "Bar": PyImport(implicit="Bar"), + } + type_nicknames = { + "Foo": "~.Baz", + "~.Baz": "B.i.k", + "B.i.k": "Bar", + } + matcher = TypeMatcher(types=types, type_nicknames=type_nicknames) + + type_name, py_import = matcher.match("Foo") + assert type_name == "Bar" + 
assert py_import == PyImport(implicit="Bar") + + def test_nickname_infinite_loop(self, caplog): + types = { + "Foo": PyImport(implicit="Foo"), + "Bar": PyImport(implicit="Bar"), + } + type_nicknames = { + "Foo": "Bar", + "Bar": "Foo", + } + matcher = TypeMatcher(types=types, type_nicknames=type_nicknames) + + type_name, py_import = matcher.match("Foo") + assert len(caplog.records) == 1 + assert "reached limit while resolving nicknames" in caplog.text + + assert type_name == "Foo" + assert py_import == PyImport(implicit="Foo")