Cache results of static analysis #15

Merged: 5 commits, Oct 7, 2024
37 changes: 35 additions & 2 deletions src/docstub/_analysis.py
@@ -2,15 +2,16 @@

import builtins
import collections.abc
import json
import logging
import re
import typing
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
from pathlib import Path

import libcst as cst

-from ._utils import accumulate_qualname, module_name_from_path
+from ._utils import accumulate_qualname, module_name_from_path, pyfile_checksum

logger = logging.getLogger(__name__)

@@ -260,6 +261,38 @@ def common_known_imports():


class TypeCollector(cst.CSTVisitor):
"""Collect types from a given Python file.

Examples
--------
>>> types = TypeCollector.collect(__file__)
>>> types[f"{__name__}.TypeCollector"]
<KnownImport 'from docstub._analysis import TypeCollector'>
"""

class ImportSerializer:
"""Implements the `FuncSerializer` protocol to cache `TypeCollector.collect`."""

suffix = ".json"
encoding = "utf-8"

def hash_args(self, path: Path) -> str:
"""Compute a unique hash from the path passed to `TypeCollector.collect`."""
key = pyfile_checksum(path)
return key

def serialize(self, data: dict[str, KnownImport]) -> bytes:
"""Serialize results from `TypeCollector.collect`."""
primitives = {qualname: asdict(imp) for qualname, imp in data.items()}
raw = json.dumps(primitives, separators=(",", ":")).encode(self.encoding)
return raw

def deserialize(self, raw: bytes) -> dict[str, KnownImport]:
"""Deserialize results from `TypeCollector.collect`."""
primitives = json.loads(raw.decode(self.encoding))
data = {qualname: KnownImport(**kw) for qualname, kw in primitives.items()}
return data

@classmethod
def collect(cls, file):
"""Collect importable type annotations in given file.
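
The serializer above is deliberately small: `TypeCollector.collect` returns a flat `dict[str, KnownImport]`, which survives a JSON round-trip via `asdict`. A minimal sketch of that round-trip, using a stand-in dataclass since `KnownImport`'s exact fields aren't part of this diff:

import json
from dataclasses import asdict, dataclass

@dataclass
class KnownImport:  # stand-in for docstub's KnownImport; the real fields may differ
    import_name: str | None = None
    import_path: str | None = None

data = {"numpy.ndarray": KnownImport(import_name="ndarray", import_path="numpy")}
raw = json.dumps(
    {qualname: asdict(imp) for qualname, imp in data.items()}, separators=(",", ":")
).encode("utf-8")
restored = {
    qualname: KnownImport(**kw) for qualname, kw in json.loads(raw.decode("utf-8")).items()
}
assert restored == data  # lossless round-trip, so cache hits reproduce collect() exactly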
141 changes: 141 additions & 0 deletions src/docstub/_cache.py
@@ -0,0 +1,141 @@
import logging
from functools import cached_property
from typing import Protocol

logger = logging.getLogger(__name__)


CACHEDIR_TAG_CONTENT = """\
Signature: 8a477f597d28d172789f06886806bc55\
# This file is a cache directory tag automatically created by docstub.\n"
# For information about cache directory tags see https://bford.info/cachedir/\n"
"""


def _directory_size(path):
"""Estimate total size of a directory's content in bytes.

Parameters
----------
path : Path

Returns
-------
total_bytes : int
Total size of all objects in bytes.
"""
if not path.is_dir():
msg = f"{path} doesn't exist, can't determine size"
raise FileNotFoundError(msg)
files = path.rglob("*")
total_bytes = sum(f.stat().st_size for f in files)
return total_bytes


def create_cache(path):
    """Create a cache directory.

    Parameters
    ----------
    path : Path
        Directory of the cache. The directory and its parents are created if
        they don't exist yet.
    """
    path.mkdir(parents=True, exist_ok=True)

    cachedir_tag_path = path / "CACHEDIR.TAG"
    if not cachedir_tag_path.is_file():
        with open(cachedir_tag_path, "w") as fp:
            fp.write(CACHEDIR_TAG_CONTENT)

    gitignore_path = path / ".gitignore"
    gitignore_content = "# Automatically created by docstub.\n*\n"
    if not gitignore_path.is_file():
        with open(gitignore_path, "w") as fp:
            fp.write(gitignore_content)
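
For reference, after `create_cache(Path(".docstub_cache"))` the directory should contain two marker files; a sketch of the expected layout:

.docstub_cache/
├── CACHEDIR.TAG   # tells backup tools to skip this tree
└── .gitignore     # ignores everything, so the cache never gets committed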


class FuncSerializer[T](Protocol):
"""Defines an interface to serialize and deserialize results of a function.

This interface is used by `FileCache` to cache results of a

Attributes
----------
suffix :
A suffix corresponding to the format of the serialized data, e.g. ".json".
"""

suffix: str

def hash_args(self, *args, **kwargs) -> str:
"""Compute a unique hash from the arguments passed to a function."""

def serialize(self, data: T) -> bytes:
"""Serialize results of a function from `T` to bytes."""

def deserialize(self, raw: bytes) -> T:
"""Deserialize results of a function from bytes back to `T`."""


class FileCache:
"""Cache results from a function call as a files on disk.

This class can cache results of a function to the disk. A unique key is
generated from the arguments to the function, and the result is cached
inside a file named after this key.
"""

def __init__(self, *, func, serializer, cache_dir, name):
"""
Parameters
----------
func : callable
The function whose output shall be cached.
serializer : FuncSerializer
            An interface that matches the given `func`. It must implement the
            `FuncSerializer` protocol.
cache_dir : Path
The directory of the cache.
name : str
A unique name to separate parallel caches inside `cache_dir`.
"""
self.func = func
self.serializer = serializer
self._cache_dir = cache_dir
self.name = name

@cached_property
def named_cache_dir(self):
"""Path to the named subdirectory inside the cache.

Warns when cache size exceeds 512 MiB.
"""
cache_dir = self._cache_dir
create_cache(cache_dir)
if _directory_size(cache_dir) > 512 * 1024**2:
logger.warning("cache size at %r exceeds 512 MiB", cache_dir)
_named_cache_dir = cache_dir / self.name
_named_cache_dir.mkdir(parents=True, exist_ok=True)
return _named_cache_dir

def __call__(self, *args, **kwargs):
"""Call the wrapped `func` and cache each result in a file."""
key = self.serializer.hash_args(*args, **kwargs)
entry_path = self.named_cache_dir / f"{key}{self.serializer.suffix}"
if entry_path.is_file():
with entry_path.open("rb") as fp:
raw = fp.read()
data = self.serializer.deserialize(raw)
else:
data = self.func(*args, **kwargs)
raw = self.serializer.serialize(data)
with entry_path.open("xb") as fp:
fp.write(raw)
return data
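
To see how `FuncSerializer` and `FileCache` fit together, here is a minimal sketch with a toy serializer for plain-string results (the names below are illustrative, not part of docstub):

from pathlib import Path

class TextSerializer:
    """Toy serializer satisfying the `FuncSerializer` protocol for str results."""

    suffix = ".txt"

    def hash_args(self, x: int) -> str:
        return str(x)

    def serialize(self, data: str) -> bytes:
        return data.encode("utf-8")

    def deserialize(self, raw: bytes) -> str:
        return raw.decode("utf-8")

cached = FileCache(
    func=lambda x: f"result-{x}",
    serializer=TextSerializer(),
    cache_dir=Path(".docstub_cache"),
    name="demo",
)
assert cached(1) == "result-1"  # first call computes and writes demo/1.txt
assert cached(1) == "result-1"  # second call is read back from the cache file

Note the exclusive `"xb"` mode when a new entry is written: if two processes race to create the same file, the second write fails loudly instead of silently overwriting it.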
67 changes: 58 additions & 9 deletions src/docstub/_cli.py
@@ -1,5 +1,7 @@
import logging
import sys
import time
from contextlib import contextmanager
from pathlib import Path

import click
@@ -10,6 +12,7 @@
TypeCollector,
common_known_imports,
)
from ._cache import FileCache
from ._config import Config
from ._stubs import Py2StubTransformer, walk_source, walk_source_and_targets
from ._version import __version__
Expand All @@ -26,7 +29,7 @@ def _load_configuration(config_path=None):

Returns
-------
-    config : dict[str, Any]
+    config : ~.Config
"""
config = Config.from_toml(Config.DEFAULT_CONFIG_PATH)

@@ -65,6 +68,58 @@ def _setup_logging(*, verbose):
)


def _build_import_map(config, source_dir):
"""Build a map of known imports.

Parameters
----------
config : ~.Config
source_dir : Path

Returns
-------
imports : dict[str, ~.KnownImport]
"""
known_imports = common_known_imports()

collect_cached_types = FileCache(
func=TypeCollector.collect,
serializer=TypeCollector.ImportSerializer(),
cache_dir=Path.cwd() / ".docstub_cache",
name=f"{__version__}/collected_types",
)
for source_path in walk_source(source_dir):
logger.info("collecting types in %s", source_path)
known_imports_in_source = collect_cached_types(source_path)
known_imports.update(known_imports_in_source)

known_imports.update(KnownImport.many_from_config(config.known_imports))

return known_imports
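
With the arguments used above, each cache entry lands at a path like `.docstub_cache/<docstub version>/collected_types/<checksum>.json`, so every docstub release writes into its own namespace and a version bump naturally invalidates older entries.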


@contextmanager
def report_execution_time():
start = time.time()
try:
yield
finally:
stop = time.time()
total_seconds = stop - start

        hours, remainder = divmod(total_seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        hours, minutes = int(hours), int(minutes)

        formatted_duration = f"{seconds:.3f} s"
        if minutes:
            formatted_duration = f"{minutes} min {formatted_duration}"
        if hours:
            formatted_duration = f"{hours} h {formatted_duration}"

        click.echo()
        click.echo(f"Finished in {formatted_duration}")


@click.command()
@click.version_option(__version__)
@click.argument("source_dir", type=click.Path(exists=True, file_okay=False))
@@ -82,19 +137,13 @@ def _setup_logging(*, verbose):
)
@click.option("-v", "--verbose", count=True, help="Log more details.")
@click.help_option("-h", "--help")
@report_execution_time()
def main(source_dir, out_dir, config_path, verbose):
_setup_logging(verbose=verbose)

source_dir = Path(source_dir)
config = _load_configuration(config_path)

-    # Build map of known imports
-    known_imports = common_known_imports()
-    for source_path in walk_source(source_dir):
-        logger.info("collecting types in %s", source_path)
-        known_imports_in_source = TypeCollector.collect(source_path)
-        known_imports.update(known_imports_in_source)
-    known_imports.update(KnownImport.many_from_config(config.known_imports))
+    known_imports = _build_import_map(config, source_dir)

inspector = StaticInspector(
source_pkgs=[source_dir.parent.resolve()], known_imports=known_imports
23 changes: 23 additions & 0 deletions src/docstub/_utils.py
@@ -4,6 +4,7 @@
from functools import lru_cache
from pathlib import Path
from textwrap import indent
from zlib import crc32

import click

@@ -105,6 +106,28 @@ def module_name_from_path(path):
return name


def pyfile_checksum(path):
"""Compute a unique key for a Python file.

    The key takes into account the given `path`, the file's relative position
    inside a Python package, and the file's content.

Parameters
----------
path : Path

Returns
-------
key : str
"""
module_name = module_name_from_path(path).encode()
absolute_path = str(path.resolve()).encode()
with open(path, "rb") as fp:
content = fp.read()
    key = str(crc32(content + module_name + absolute_path))
return key
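
A quick illustration of the invalidation behavior, using the same ingredients as `pyfile_checksum` (inline byte strings stand in for a real file):

from zlib import crc32

# Content, module name, and absolute path all feed into the key ...
key_before = str(crc32(b"x = 1\n" + b"pkg.mod" + b"/repo/pkg/mod.py"))
key_after = str(crc32(b"x = 2\n" + b"pkg.mod" + b"/repo/pkg/mod.py"))

# ... so changing any of them yields a new key and a fresh cache entry.
assert key_before != key_after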


@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
class ContextFormatter:
"""Format messages in context of a location in a file.