Skip to content

Commit 31b0413

Browse files
authored
stubtest: Improve heuristics for determining whether global-namespace names are imported (#14270)
Stubtest currently has both false-positives and false-negatives when it comes to verifying constants in the global namespace of a module. This PR fixes the false positive by using `inspect.getsourcelines()` to dynamically retrieve the module source code. It then uses `symtable` to analyse that source code to gather a list of names which are known to be imported. The PR fixes the false negative by only using the `__module__` heuristic on objects which are callable. The vast majority of callable objects will be types or functions. For these objects, the `__module__` attribute will give a good indication of whether the object originates from another module or not; for other objects, it's less useful.
1 parent 2514610 commit 31b0413

File tree

2 files changed

+47
-6
lines changed

2 files changed

+47
-6
lines changed

mypy/stubtest.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import pkgutil
1717
import re
18+
import symtable
1819
import sys
1920
import traceback
2021
import types
@@ -283,6 +284,36 @@ def _verify_exported_names(
283284
)
284285

285286

287+
def _get_imported_symbol_names(runtime: types.ModuleType) -> frozenset[str] | None:
288+
"""Retrieve the names in the global namespace which are known to be imported.
289+
290+
1). Use inspect to retrieve the source code of the module
291+
2). Use symtable to parse the source and retrieve names that are known to be imported
292+
from other modules.
293+
294+
If either of the above steps fails, return `None`.
295+
296+
Note that if a set of names is returned,
297+
it won't include names imported via `from foo import *` imports.
298+
"""
299+
try:
300+
source = inspect.getsource(runtime)
301+
except (OSError, TypeError, SyntaxError):
302+
return None
303+
304+
if not source.strip():
305+
# The source code for the module was an empty file,
306+
# no point in parsing it with symtable
307+
return frozenset()
308+
309+
try:
310+
module_symtable = symtable.symtable(source, runtime.__name__, "exec")
311+
except SyntaxError:
312+
return None
313+
314+
return frozenset(sym.get_name() for sym in module_symtable.get_symbols() if sym.is_imported())
315+
316+
286317
@verify.register(nodes.MypyFile)
287318
def verify_mypyfile(
288319
stub: nodes.MypyFile, runtime: MaybeMissing[types.ModuleType], object_path: list[str]
@@ -312,15 +343,22 @@ def verify_mypyfile(
312343
if not o.module_hidden and (not is_probably_private(m) or hasattr(runtime, m))
313344
}
314345

346+
imported_symbols = _get_imported_symbol_names(runtime)
347+
315348
def _belongs_to_runtime(r: types.ModuleType, attr: str) -> bool:
316349
obj = getattr(r, attr)
317-
try:
318-
obj_mod = getattr(obj, "__module__", None)
319-
except Exception:
350+
if isinstance(obj, types.ModuleType):
320351
return False
321-
if obj_mod is not None:
322-
return bool(obj_mod == r.__name__)
323-
return not isinstance(obj, types.ModuleType)
352+
if callable(obj):
353+
try:
354+
obj_mod = getattr(obj, "__module__", None)
355+
except Exception:
356+
return False
357+
if obj_mod is not None:
358+
return bool(obj_mod == r.__name__)
359+
if imported_symbols is not None:
360+
return attr not in imported_symbols
361+
return True
324362

325363
runtime_public_contents = (
326364
runtime_all_as_set

mypy/test/teststubtest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,9 @@ def test_missing_no_runtime_all(self) -> Iterator[Case]:
10821082
yield Case(stub="", runtime="import sys", error=None)
10831083
yield Case(stub="", runtime="def g(): ...", error="g")
10841084
yield Case(stub="", runtime="CONSTANT = 0", error="CONSTANT")
1085+
yield Case(stub="", runtime="import re; constant = re.compile('foo')", error="constant")
1086+
yield Case(stub="", runtime="from json.scanner import NUMBER_RE", error=None)
1087+
yield Case(stub="", runtime="from string import ascii_letters", error=None)
10851088

10861089
@collect_cases
10871090
def test_non_public_1(self) -> Iterator[Case]:

0 commit comments

Comments
 (0)