From 5efb067bc38c7ec3b411b3ae79bb3ae569997574 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Fri, 6 Oct 2023 16:58:32 +0000 Subject: [PATCH] Validators refactor --- docs/python.rst | 16 +- openapi_spec_validator/__init__.py | 8 + openapi_spec_validator/__main__.py | 27 +- openapi_spec_validator/exceptions.py | 6 +- openapi_spec_validator/schemas/__init__.py | 10 + openapi_spec_validator/shortcuts.py | 44 +- openapi_spec_validator/validation/__init__.py | 67 ++- openapi_spec_validator/validation/caches.py | 65 +++ .../validation/decorators.py | 57 ++- openapi_spec_validator/validation/finders.py | 23 + openapi_spec_validator/validation/keywords.py | 424 ++++++++++++++++++ openapi_spec_validator/validation/proxies.py | 57 ++- .../validation/registries.py | 22 + openapi_spec_validator/validation/types.py | 5 + .../validation/validators.py | 421 +++++------------ tests/integration/conftest.py | 19 - tests/integration/test_main.py | 2 +- tests/integration/test_shortcuts.py | 29 +- .../integration/validation/test_exceptions.py | 66 ++- .../integration/validation/test_validators.py | 51 ++- 20 files changed, 941 insertions(+), 478 deletions(-) create mode 100644 openapi_spec_validator/validation/caches.py create mode 100644 openapi_spec_validator/validation/finders.py create mode 100644 openapi_spec_validator/validation/keywords.py create mode 100644 openapi_spec_validator/validation/registries.py create mode 100644 openapi_spec_validator/validation/types.py diff --git a/docs/python.rst b/docs/python.rst index bb13a26..f4b0d56 100644 --- a/docs/python.rst +++ b/docs/python.rst @@ -36,22 +36,22 @@ You can also validate spec from url: In order to explicitly validate a: -* Swagger / OpenAPI 2.0 spec, import ``openapi_v2_spec_validator`` -* OpenAPI 3.0 spec, import ``openapi_v30_spec_validator`` -* OpenAPI 3.1 spec, import ``openapi_v31_spec_validator`` +* Swagger / OpenAPI 2.0 spec, import ``OpenAPIV2SpecValidator`` +* OpenAPI 3.0 spec, import ``OpenAPIV30SpecValidator`` +* OpenAPI 3.1 spec, import ``OpenAPIV31SpecValidator`` -and pass the validator to ``validate_spec`` or ``validate_spec_url`` function: +and pass the validator class to ``validate_spec`` or ``validate_spec_url`` function: .. code:: python - validate_spec(spec_dict, validator=openapi_v31_spec_validator) + validate_spec(spec_dict, cls=OpenAPIV31SpecValidator) -You can also explicitly import ``openapi_v3_spec_validator`` which is a shortcut to the latest v3 release. +You can also explicitly import ``OpenAPIV3SpecValidator`` which is a shortcut to the latest v3 release. If you want to iterate through validation errors: .. code:: python - from openapi_spec_validator import openapi_v3_spec_validator + from openapi_spec_validator import OpenAPIV31SpecValidator - errors_iterator = openapi_v3_spec_validator.iter_errors(spec) + errors_iterator = OpenAPIV31SpecValidator(spec).iter_errors() diff --git a/openapi_spec_validator/__init__.py b/openapi_spec_validator/__init__.py index 105a2b1..b968e78 100644 --- a/openapi_spec_validator/__init__.py +++ b/openapi_spec_validator/__init__.py @@ -1,6 +1,10 @@ """OpenAPI spec validator module.""" from openapi_spec_validator.shortcuts import validate_spec from openapi_spec_validator.shortcuts import validate_spec_url +from openapi_spec_validator.validation import OpenAPIV2SpecValidator +from openapi_spec_validator.validation import OpenAPIV3SpecValidator +from openapi_spec_validator.validation import OpenAPIV30SpecValidator +from openapi_spec_validator.validation import OpenAPIV31SpecValidator from openapi_spec_validator.validation import openapi_v2_spec_validator from openapi_spec_validator.validation import openapi_v3_spec_validator from openapi_spec_validator.validation import openapi_v30_spec_validator @@ -17,6 +21,10 @@ "openapi_v3_spec_validator", "openapi_v30_spec_validator", "openapi_v31_spec_validator", + "OpenAPIV2SpecValidator", + "OpenAPIV3SpecValidator", + "OpenAPIV30SpecValidator", + "OpenAPIV31SpecValidator", "validate_spec", "validate_spec_url", ] diff --git a/openapi_spec_validator/__main__.py b/openapi_spec_validator/__main__.py index c058b76..8a1f26a 100644 --- a/openapi_spec_validator/__main__.py +++ b/openapi_spec_validator/__main__.py @@ -9,10 +9,10 @@ from openapi_spec_validator.readers import read_from_filename from openapi_spec_validator.readers import read_from_stdin -from openapi_spec_validator.validation import openapi_spec_validator_proxy -from openapi_spec_validator.validation import openapi_v2_spec_validator -from openapi_spec_validator.validation import openapi_v30_spec_validator -from openapi_spec_validator.validation import openapi_v31_spec_validator +from openapi_spec_validator.shortcuts import get_validator_cls +from openapi_spec_validator.validation import OpenAPIV2SpecValidator +from openapi_spec_validator.validation import OpenAPIV30SpecValidator +from openapi_spec_validator.validation import OpenAPIV31SpecValidator logger = logging.getLogger(__name__) logging.basicConfig( @@ -91,19 +91,22 @@ def main(args: Optional[Sequence[str]] = None) -> None: # choose the validator validators = { - "detect": openapi_spec_validator_proxy, - "2.0": openapi_v2_spec_validator, - "3.0": openapi_v30_spec_validator, - "3.1": openapi_v31_spec_validator, + "2.0": OpenAPIV2SpecValidator, + "3.0": OpenAPIV30SpecValidator, + "3.1": OpenAPIV31SpecValidator, # backward compatibility - "3.0.0": openapi_v30_spec_validator, - "3.1.0": openapi_v31_spec_validator, + "3.0.0": OpenAPIV30SpecValidator, + "3.1.0": OpenAPIV31SpecValidator, } - validator = validators[args_parsed.schema] + if args_parsed.schema == "detect": + validator_cls = get_validator_cls(spec) + else: + validator_cls = validators[args_parsed.schema] + validator = validator_cls(spec, base_uri=base_uri) # validate try: - validator.validate(spec, base_uri=base_uri) + validator.validate() except ValidationError as exc: print_validationerror(filename, exc, args_parsed.errors) sys.exit(1) diff --git a/openapi_spec_validator/exceptions.py b/openapi_spec_validator/exceptions.py index 6a62f4e..bcfc17a 100644 --- a/openapi_spec_validator/exceptions.py +++ b/openapi_spec_validator/exceptions.py @@ -1,2 +1,6 @@ -class OpenAPISpecValidatorError(Exception): +class OpenAPIError(Exception): + pass + + +class OpenAPISpecValidatorError(OpenAPIError): pass diff --git a/openapi_spec_validator/schemas/__init__.py b/openapi_spec_validator/schemas/__init__.py index ec1b287..8141788 100644 --- a/openapi_spec_validator/schemas/__init__.py +++ b/openapi_spec_validator/schemas/__init__.py @@ -1,6 +1,8 @@ """OpenAIP spec validator schemas module.""" from functools import partial +from jsonschema.validators import Draft4Validator +from jsonschema.validators import Draft202012Validator from lazy_object_proxy import Proxy from openapi_spec_validator.schemas.utils import get_schema_content @@ -17,3 +19,11 @@ # alias to the latest v3 version schema_v3 = schema_v31 + +get_openapi_v2_schema_validator = partial(Draft4Validator, schema_v2) +get_openapi_v30_schema_validator = partial(Draft4Validator, schema_v30) +get_openapi_v31_schema_validator = partial(Draft202012Validator, schema_v31) + +openapi_v2_schema_validator = Proxy(get_openapi_v2_schema_validator) +openapi_v30_schema_validator = Proxy(get_openapi_v30_schema_validator) +openapi_v31_schema_validator = Proxy(get_openapi_v31_schema_validator) diff --git a/openapi_spec_validator/shortcuts.py b/openapi_spec_validator/shortcuts.py index 121411f..77ee0d8 100644 --- a/openapi_spec_validator/shortcuts.py +++ b/openapi_spec_validator/shortcuts.py @@ -1,27 +1,55 @@ """OpenAPI spec validator shortcuts module.""" -from typing import Any -from typing import Hashable +import warnings from typing import Mapping from typing import Optional +from typing import Type from jsonschema_spec.handlers import all_urls_handler +from jsonschema_spec.typing import Schema -from openapi_spec_validator.validation import openapi_spec_validator_proxy +from openapi_spec_validator.validation import OpenAPIV2SpecValidator +from openapi_spec_validator.validation import OpenAPIV30SpecValidator +from openapi_spec_validator.validation import OpenAPIV31SpecValidator +from openapi_spec_validator.validation.finders import SpecFinder +from openapi_spec_validator.validation.finders import SpecVersion from openapi_spec_validator.validation.protocols import SupportsValidation +from openapi_spec_validator.validation.types import SpecValidatorType +from openapi_spec_validator.validation.validators import SpecValidator + +SPECS: Mapping[SpecVersion, SpecValidatorType] = { + SpecVersion("swagger", "2.0"): OpenAPIV2SpecValidator, + SpecVersion("openapi", "3.0"): OpenAPIV30SpecValidator, + SpecVersion("openapi", "3.1"): OpenAPIV31SpecValidator, +} + + +def get_validator_cls(spec: Schema) -> SpecValidatorType: + return SpecFinder(SPECS).find(spec) def validate_spec( - spec: Mapping[Hashable, Any], + spec: Schema, base_uri: str = "", - validator: SupportsValidation = openapi_spec_validator_proxy, + validator: Optional[SupportsValidation] = None, + cls: Optional[SpecValidatorType] = None, spec_url: Optional[str] = None, ) -> None: - return validator.validate(spec, base_uri=base_uri, spec_url=spec_url) + if validator is not None: + warnings.warn( + "validator parameter is deprecated. Use cls instead.", + DeprecationWarning, + ) + return validator.validate(spec, base_uri=base_uri, spec_url=spec_url) + if cls is None: + cls = get_validator_cls(spec) + v = cls(spec) + return v.validate() def validate_spec_url( spec_url: str, - validator: SupportsValidation = openapi_spec_validator_proxy, + validator: Optional[SupportsValidation] = None, + cls: Optional[Type[SpecValidator]] = None, ) -> None: spec = all_urls_handler(spec_url) - return validator.validate(spec, base_uri=spec_url) + return validate_spec(spec, base_uri=spec_url, validator=validator, cls=cls) diff --git a/openapi_spec_validator/validation/__init__.py b/openapi_spec_validator/validation/__init__.py index a889b96..3450616 100644 --- a/openapi_spec_validator/validation/__init__.py +++ b/openapi_spec_validator/validation/__init__.py @@ -1,19 +1,12 @@ -from functools import partial - -from jsonschema.validators import Draft4Validator -from jsonschema.validators import Draft202012Validator -from jsonschema_spec.handlers import default_handlers -from lazy_object_proxy import Proxy -from openapi_schema_validator import oas30_format_checker -from openapi_schema_validator import oas31_format_checker -from openapi_schema_validator.validators import OAS30Validator -from openapi_schema_validator.validators import OAS31Validator - -from openapi_spec_validator.schemas import schema_v2 -from openapi_spec_validator.schemas import schema_v30 -from openapi_spec_validator.schemas import schema_v31 from openapi_spec_validator.validation.proxies import DetectValidatorProxy -from openapi_spec_validator.validation.validators import SpecValidator +from openapi_spec_validator.validation.proxies import SpecValidatorProxy +from openapi_spec_validator.validation.validators import OpenAPIV2SpecValidator +from openapi_spec_validator.validation.validators import ( + OpenAPIV30SpecValidator, +) +from openapi_spec_validator.validation.validators import ( + OpenAPIV31SpecValidator, +) __all__ = [ "openapi_v2_spec_validator", @@ -21,46 +14,36 @@ "openapi_v30_spec_validator", "openapi_v31_spec_validator", "openapi_spec_validator_proxy", + "OpenAPIV2SpecValidator", + "OpenAPIV3SpecValidator", + "OpenAPIV30SpecValidator", + "OpenAPIV31SpecValidator", ] # v2.0 spec -get_openapi_v2_schema_validator = partial(Draft4Validator, schema_v2) -openapi_v2_schema_validator = Proxy(get_openapi_v2_schema_validator) -get_openapi_v2_spec_validator = partial( - SpecValidator, - openapi_v2_schema_validator, - OAS30Validator, - oas30_format_checker, - resolver_handlers=default_handlers, +openapi_v2_spec_validator = SpecValidatorProxy( + OpenAPIV2SpecValidator, + deprecated="openapi_v2_spec_validator", + use="OpenAPIV2SpecValidator", ) -openapi_v2_spec_validator = Proxy(get_openapi_v2_spec_validator) # v3.0 spec -get_openapi_v30_schema_validator = partial(Draft4Validator, schema_v30) -openapi_v30_schema_validator = Proxy(get_openapi_v30_schema_validator) -get_openapi_v30_spec_validator = partial( - SpecValidator, - openapi_v30_schema_validator, - OAS30Validator, - oas30_format_checker, - resolver_handlers=default_handlers, +openapi_v30_spec_validator = SpecValidatorProxy( + OpenAPIV30SpecValidator, + deprecated="openapi_v30_spec_validator", + use="OpenAPIV30SpecValidator", ) -openapi_v30_spec_validator = Proxy(get_openapi_v30_spec_validator) # v3.1 spec -get_openapi_v31_schema_validator = partial(Draft202012Validator, schema_v31) -openapi_v31_schema_validator = Proxy(get_openapi_v31_schema_validator) -get_openapi_v31_spec_validator = partial( - SpecValidator, - openapi_v31_schema_validator, - OAS31Validator, - oas31_format_checker, - resolver_handlers=default_handlers, +openapi_v31_spec_validator = SpecValidatorProxy( + OpenAPIV31SpecValidator, + deprecated="openapi_v31_spec_validator", + use="OpenAPIV31SpecValidator", ) -openapi_v31_spec_validator = Proxy(get_openapi_v31_spec_validator) # alias to the latest v3 version openapi_v3_spec_validator = openapi_v31_spec_validator +OpenAPIV3SpecValidator = OpenAPIV31SpecValidator # detect version spec openapi_spec_validator_proxy = DetectValidatorProxy( diff --git a/openapi_spec_validator/validation/caches.py b/openapi_spec_validator/validation/caches.py new file mode 100644 index 0000000..acc6b36 --- /dev/null +++ b/openapi_spec_validator/validation/caches.py @@ -0,0 +1,65 @@ +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import TypeVar + +T = TypeVar("T") + + +class CachedIterable(Iterable[T], Generic[T]): + """ + A cache-implementing wrapper for an iterator. + Note that this is class is `Iterable[T]` rather than `Iterator[T]`. + It should not be iterated by his own. + """ + + cache: List[T] + iter: Iterator[T] + completed: bool + + def __init__(self, it: Iterator[T]): + self.iter = iter(it) + self.cache = list() + self.completed = False + + def __iter__(self) -> Iterator[T]: + return CachedIterator(self) + + def __next__(self) -> T: + try: + item = next(self.iter) + except StopIteration: + self.completed = True + raise + else: + self.cache.append(item) + return item + + def __del__(self) -> None: + del self.cache + + +class CachedIterator(Iterator[T], Generic[T]): + """ + A cache-using wrapper for an iterator. + This class is only constructed by `CachedIterable` and cannot be used without it. + """ + + parent: CachedIterable[T] + position: int + + def __init__(self, parent: CachedIterable[T]): + self.parent = parent + self.position = 0 + + def __next__(self) -> T: + if self.position < len(self.parent.cache): + item = self.parent.cache[self.position] + elif self.parent.completed: + raise StopIteration + else: + item = next(self.parent) + + self.position += 1 + return item diff --git a/openapi_spec_validator/validation/decorators.py b/openapi_spec_validator/validation/decorators.py index 988b3b8..191c035 100644 --- a/openapi_spec_validator/validation/decorators.py +++ b/openapi_spec_validator/validation/decorators.py @@ -3,27 +3,54 @@ from functools import wraps from typing import Any from typing import Callable +from typing import Iterable from typing import Iterator -from typing import Type +from typing import TypeVar from jsonschema.exceptions import ValidationError +from openapi_spec_validator.validation.caches import CachedIterable +from openapi_spec_validator.validation.exceptions import OpenAPIValidationError + +Args = TypeVar("Args") +T = TypeVar("T") + log = logging.getLogger(__name__) -class ValidationErrorWrapper: - def __init__(self, error_class: Type[ValidationError]): - self.error_class = error_class +def wraps_errors( + func: Callable[..., Any] +) -> Callable[..., Iterator[ValidationError]]: + @wraps(func) + def wrapper(*args: Any, **kwds: Any) -> Iterator[ValidationError]: + errors = func(*args, **kwds) + for err in errors: + if not isinstance(err, OpenAPIValidationError): + # wrap other exceptions with library specific version + yield OpenAPIValidationError.create_from(err) + else: + yield err + + return wrapper + + +def wraps_cached_iter( + func: Callable[[Args], Iterator[T]] +) -> Callable[[Args], CachedIterable[T]]: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> CachedIterable[T]: + result = func(*args, **kwargs) + return CachedIterable(result) + + return wrapper + - def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: - @wraps(f) - def wrapper(*args: Any, **kwds: Any) -> Iterator[ValidationError]: - errors = f(*args, **kwds) - for err in errors: - if not isinstance(err, self.error_class): - # wrap other exceptions with library specific version - yield self.error_class.create_from(err) - else: - yield err +def unwraps_iter( + func: Callable[[Args], Iterable[T]] +) -> Callable[[Args], Iterator[T]]: + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Iterator[T]: + result = func(*args, **kwargs) + return iter(result) - return wrapper + return wrapper diff --git a/openapi_spec_validator/validation/finders.py b/openapi_spec_validator/validation/finders.py new file mode 100644 index 0000000..74d2573 --- /dev/null +++ b/openapi_spec_validator/validation/finders.py @@ -0,0 +1,23 @@ +from typing import Mapping +from typing import NamedTuple + +from jsonschema_spec.typing import Schema + +from openapi_spec_validator.validation.exceptions import ValidatorDetectError +from openapi_spec_validator.validation.types import SpecValidatorType + + +class SpecVersion(NamedTuple): + name: str + version: str + + +class SpecFinder: + def __init__(self, specs: Mapping[SpecVersion, SpecValidatorType]) -> None: + self.specs = specs + + def find(self, spec: Schema) -> SpecValidatorType: + for v, classes in self.specs.items(): + if v.name in spec and spec[v.name].startswith(v.version): + return classes + raise ValidatorDetectError("Spec schema version not detected") diff --git a/openapi_spec_validator/validation/keywords.py b/openapi_spec_validator/validation/keywords.py new file mode 100644 index 0000000..125d3ad --- /dev/null +++ b/openapi_spec_validator/validation/keywords.py @@ -0,0 +1,424 @@ +import string +from typing import TYPE_CHECKING +from typing import Any +from typing import Iterator +from typing import List +from typing import Optional +from typing import cast + +from jsonschema._format import FormatChecker +from jsonschema.exceptions import ValidationError +from jsonschema.protocols import Validator +from jsonschema_spec.paths import SchemaPath +from openapi_schema_validator import oas30_format_checker +from openapi_schema_validator import oas31_format_checker +from openapi_schema_validator.validators import OAS30Validator +from openapi_schema_validator.validators import OAS31Validator + +from openapi_spec_validator.validation.exceptions import ( + DuplicateOperationIDError, +) +from openapi_spec_validator.validation.exceptions import ExtraParametersError +from openapi_spec_validator.validation.exceptions import ( + ParameterDuplicateError, +) +from openapi_spec_validator.validation.exceptions import ( + UnresolvableParameterError, +) + +if TYPE_CHECKING: + from openapi_spec_validator.validation.registries import ( + KeywordValidatorRegistry, + ) + + +class KeywordValidator: + def __init__(self, registry: "KeywordValidatorRegistry"): + self.registry = registry + + +class ValueValidator(KeywordValidator): + value_validator_cls: Validator = NotImplemented + value_validator_format_checker: FormatChecker = NotImplemented + + def __call__( + self, schema: SchemaPath, value: Any + ) -> Iterator[ValidationError]: + with schema.resolve() as resolved: + value_validator = self.value_validator_cls( + resolved.contents, + _resolver=resolved.resolver, + format_checker=self.value_validator_format_checker, + ) + yield from value_validator.iter_errors(value) + + +class OpenAPIV30ValueValidator(ValueValidator): + value_validator_cls = OAS30Validator + value_validator_format_checker = oas30_format_checker + + +class OpenAPIV31ValueValidator(ValueValidator): + value_validator_cls = OAS31Validator + value_validator_format_checker = oas31_format_checker + + +class SchemaValidator(KeywordValidator): + def __init__(self, registry: "KeywordValidatorRegistry"): + super().__init__(registry) + + self.schema_ids_registry: Optional[List[int]] = [] + + @property + def default_validator(self) -> ValueValidator: + return cast(ValueValidator, self.registry["default"]) + + def __call__( + self, schema: SchemaPath, require_properties: bool = True + ) -> Iterator[ValidationError]: + if not hasattr(schema.content(), "__getitem__"): + return + + assert self.schema_ids_registry is not None + schema_id = id(schema.content()) + if schema_id in self.schema_ids_registry: + return + self.schema_ids_registry.append(schema_id) + + nested_properties = [] + if "allOf" in schema: + all_of = schema / "allOf" + for inner_schema in all_of: + yield from self( + inner_schema, + require_properties=False, + ) + if "properties" not in inner_schema: + continue + inner_schema_props = inner_schema / "properties" + inner_schema_props_keys = inner_schema_props.keys() + nested_properties += list(inner_schema_props_keys) + + if "anyOf" in schema: + any_of = schema / "anyOf" + for inner_schema in any_of: + yield from self( + inner_schema, + require_properties=False, + ) + + if "oneOf" in schema: + one_of = schema / "oneOf" + for inner_schema in one_of: + yield from self( + inner_schema, + require_properties=False, + ) + + if "not" in schema: + not_schema = schema / "not" + yield from self( + not_schema, + require_properties=False, + ) + + if "items" in schema: + array_schema = schema / "items" + yield from self( + array_schema, + require_properties=False, + ) + + if "properties" in schema: + props = schema / "properties" + for _, prop_schema in props.items(): + yield from self( + prop_schema, + require_properties=False, + ) + + required = schema.getkey("required", []) + properties = schema.get("properties", {}).keys() + if "allOf" in schema: + extra_properties = list( + set(required) - set(properties) - set(nested_properties) + ) + else: + extra_properties = list(set(required) - set(properties)) + + if extra_properties and require_properties: + yield ExtraParametersError( + f"Required list has not defined properties: {extra_properties}" + ) + + if "default" in schema: + default = schema["default"] + nullable = schema.get("nullable", False) + if default is not None or nullable is not True: + yield from self.default_validator(schema, default) + + +class SchemasValidator(KeywordValidator): + @property + def schema_validator(self) -> SchemaValidator: + return cast(SchemaValidator, self.registry["schema"]) + + def __call__(self, schemas: SchemaPath) -> Iterator[ValidationError]: + for _, schema in schemas.items(): + yield from self.schema_validator(schema) + + +class ParameterValidator(KeywordValidator): + @property + def schema_validator(self) -> SchemaValidator: + return cast(SchemaValidator, self.registry["schema"]) + + def __call__(self, parameter: SchemaPath) -> Iterator[ValidationError]: + if "schema" in parameter: + schema = parameter / "schema" + yield from self.schema_validator(schema) + + +class OpenAPIV2ParameterValidator(ParameterValidator): + @property + def default_validator(self) -> ValueValidator: + return cast(ValueValidator, self.registry["default"]) + + def __call__(self, parameter: SchemaPath) -> Iterator[ValidationError]: + yield from super().__call__(parameter) + + if "default" in parameter: + # only possible in swagger 2.0 + default = parameter.getkey("default") + if default is not None: + yield from self.default_validator(parameter, default) + + +class ParametersValidator(KeywordValidator): + @property + def parameter_validator(self) -> ParameterValidator: + return cast(ParameterValidator, self.registry["parameter"]) + + def __call__(self, parameters: SchemaPath) -> Iterator[ValidationError]: + seen = set() + for parameter in parameters: + yield from self.parameter_validator(parameter) + + key = (parameter["name"], parameter["in"]) + if key in seen: + yield ParameterDuplicateError( + f"Duplicate parameter `{parameter['name']}`" + ) + seen.add(key) + + +class MediaTypeValidator(KeywordValidator): + @property + def schema_validator(self) -> SchemaValidator: + return cast(SchemaValidator, self.registry["schema"]) + + def __call__( + self, mimetype: str, media_type: SchemaPath + ) -> Iterator[ValidationError]: + if "schema" in media_type: + schema = media_type / "schema" + yield from self.schema_validator(schema) + + +class ContentValidator(KeywordValidator): + @property + def media_type_validator(self) -> MediaTypeValidator: + return cast(MediaTypeValidator, self.registry["mediaType"]) + + def __call__(self, content: SchemaPath) -> Iterator[ValidationError]: + for mimetype, media_type in content.items(): + yield from self.media_type_validator(mimetype, media_type) + + +class ResponseValidator(KeywordValidator): + def __call__( + self, response_code: str, response: SchemaPath + ) -> Iterator[ValidationError]: + raise NotImplementedError + + +class OpenAPIV2ResponseValidator(ResponseValidator): + @property + def schema_validator(self) -> SchemaValidator: + return cast(SchemaValidator, self.registry["schema"]) + + def __call__( + self, response_code: str, response: SchemaPath + ) -> Iterator[ValidationError]: + # openapi 2 + if "schema" in response: + schema = response / "schema" + yield from self.schema_validator(schema) + + +class OpenAPIV3ResponseValidator(ResponseValidator): + @property + def content_validator(self) -> ContentValidator: + return cast(ContentValidator, self.registry["content"]) + + def __call__( + self, response_code: str, response: SchemaPath + ) -> Iterator[ValidationError]: + # openapi 3 + if "content" in response: + content = response / "content" + yield from self.content_validator(content) + + +class ResponsesValidator(KeywordValidator): + @property + def response_validator(self) -> ResponseValidator: + return cast(ResponseValidator, self.registry["response"]) + + def __call__(self, responses: SchemaPath) -> Iterator[ValidationError]: + for response_code, response in responses.items(): + yield from self.response_validator(response_code, response) + + +class OperationValidator(KeywordValidator): + def __init__(self, registry: "KeywordValidatorRegistry"): + super().__init__(registry) + + self.operation_ids_registry: Optional[List[str]] = [] + + @property + def responses_validator(self) -> ResponsesValidator: + return cast(ResponsesValidator, self.registry["responses"]) + + @property + def parameters_validator(self) -> ParametersValidator: + return cast(ParametersValidator, self.registry["parameters"]) + + def __call__( + self, + url: str, + name: str, + operation: SchemaPath, + path_parameters: Optional[SchemaPath], + ) -> Iterator[ValidationError]: + assert self.operation_ids_registry is not None + + operation_id = operation.getkey("operationId") + if ( + operation_id is not None + and operation_id in self.operation_ids_registry + ): + yield DuplicateOperationIDError( + f"Operation ID '{operation_id}' for '{name}' in '{url}' is not unique" + ) + self.operation_ids_registry.append(operation_id) + + if "responses" in operation: + responses = operation / "responses" + yield from self.responses_validator(responses) + + names = [] + + parameters = None + if "parameters" in operation: + parameters = operation / "parameters" + yield from self.parameters_validator(parameters) + names += list(self._get_path_param_names(parameters)) + + if path_parameters is not None: + names += list(self._get_path_param_names(path_parameters)) + + all_params = list(set(names)) + + for path in self._get_path_params_from_url(url): + if path not in all_params: + yield UnresolvableParameterError( + "Path parameter '{}' for '{}' operation in '{}' " + "was not resolved".format(path, name, url) + ) + return + + def _get_path_param_names(self, params: SchemaPath) -> Iterator[str]: + for param in params: + if param["in"] == "path": + yield param["name"] + + def _get_path_params_from_url(self, url: str) -> Iterator[str]: + formatter = string.Formatter() + path_params = [item[1] for item in formatter.parse(url)] + return filter(None, path_params) + + +class PathValidator(KeywordValidator): + OPERATIONS = [ + "get", + "put", + "post", + "delete", + "options", + "head", + "patch", + "trace", + ] + + @property + def parameters_validator(self) -> ParametersValidator: + return cast(ParametersValidator, self.registry["parameters"]) + + @property + def operation_validator(self) -> OperationValidator: + return cast(OperationValidator, self.registry["operation"]) + + def __call__( + self, url: str, path_item: SchemaPath + ) -> Iterator[ValidationError]: + parameters = None + if "parameters" in path_item: + parameters = path_item / "parameters" + yield from self.parameters_validator(parameters) + + for field_name, operation in path_item.items(): + if field_name not in self.OPERATIONS: + continue + + yield from self.operation_validator( + url, field_name, operation, parameters + ) + + +class PathsValidator(KeywordValidator): + @property + def path_validator(self) -> PathValidator: + return cast(PathValidator, self.registry["path"]) + + def __call__(self, paths: SchemaPath) -> Iterator[ValidationError]: + for url, path_item in paths.items(): + yield from self.path_validator(url, path_item) + + +class ComponentsValidator(KeywordValidator): + @property + def schemas_validator(self) -> SchemasValidator: + return cast(SchemasValidator, self.registry["schemas"]) + + def __call__(self, components: SchemaPath) -> Iterator[ValidationError]: + schemas = components.get("schemas", {}) + yield from self.schemas_validator(schemas) + + +class RootValidator(KeywordValidator): + @property + def paths_validator(self) -> PathsValidator: + return cast(PathsValidator, self.registry["paths"]) + + @property + def components_validator(self) -> ComponentsValidator: + return cast(ComponentsValidator, self.registry["components"]) + + def __call__(self, spec: SchemaPath) -> Iterator[ValidationError]: + if "paths" in spec: + paths = spec / "paths" + yield from self.paths_validator(paths) + if "components" in spec: + components = spec / "components" + yield from self.components_validator(components) diff --git a/openapi_spec_validator/validation/proxies.py b/openapi_spec_validator/validation/proxies.py index 372c6bf..1ab7185 100644 --- a/openapi_spec_validator/validation/proxies.py +++ b/openapi_spec_validator/validation/proxies.py @@ -1,4 +1,5 @@ """OpenAPI spec validator validation proxies module.""" +import warnings from typing import Any from typing import Hashable from typing import Iterator @@ -6,16 +7,62 @@ from typing import Optional from typing import Tuple +from jsonschema.exceptions import ValidationError +from jsonschema_spec.typing import Schema + from openapi_spec_validator.validation.exceptions import OpenAPIValidationError from openapi_spec_validator.validation.exceptions import ValidatorDetectError -from openapi_spec_validator.validation.validators import SpecValidator +from openapi_spec_validator.validation.types import SpecValidatorType + + +class SpecValidatorProxy: + def __init__( + self, + cls: SpecValidatorType, + deprecated: str = "SpecValidator", + use: Optional[str] = None, + ): + self.cls = cls + + self.deprecated = deprecated + self.use = use or self.cls.__name__ + + def validate( + self, + schema: Schema, + base_uri: str = "", + spec_url: Optional[str] = None, + ) -> None: + for err in self.iter_errors( + schema, + base_uri=base_uri, + spec_url=spec_url, + ): + raise err + + def is_valid(self, schema: Schema) -> bool: + error = next(self.iter_errors(schema), None) + return error is None + + def iter_errors( + self, + schema: Schema, + base_uri: str = "", + spec_url: Optional[str] = None, + ) -> Iterator[ValidationError]: + warnings.warn( + f"{self.deprecated} is deprecated. Use {self.use} instead.", + DeprecationWarning, + ) + validator = self.cls(schema, base_uri=base_uri, spec_url=spec_url) + return validator.iter_errors() class DetectValidatorProxy: - def __init__(self, choices: Mapping[Tuple[str, str], SpecValidator]): + def __init__(self, choices: Mapping[Tuple[str, str], SpecValidatorProxy]): self.choices = choices - def detect(self, instance: Mapping[Hashable, Any]) -> SpecValidator: + def detect(self, instance: Mapping[Hashable, Any]) -> SpecValidatorProxy: for (key, value), validator in self.choices.items(): if key in instance and instance[key].startswith(value): return validator @@ -44,6 +91,10 @@ def iter_errors( base_uri: str = "", spec_url: Optional[str] = None, ) -> Iterator[OpenAPIValidationError]: + warnings.warn( + "openapi_spec_validator_proxy is deprecated.", + DeprecationWarning, + ) validator = self.detect(instance) yield from validator.iter_errors( instance, base_uri=base_uri, spec_url=spec_url diff --git a/openapi_spec_validator/validation/registries.py b/openapi_spec_validator/validation/registries.py new file mode 100644 index 0000000..b9ddc5e --- /dev/null +++ b/openapi_spec_validator/validation/registries.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from typing import DefaultDict +from typing import Mapping +from typing import Type + +from openapi_spec_validator.validation.keywords import KeywordValidator + + +class KeywordValidatorRegistry(DefaultDict[str, KeywordValidator]): + def __init__( + self, keyword_validators: Mapping[str, Type[KeywordValidator]] + ): + super().__init__() + self.keyword_validators = keyword_validators + + def __missing__(self, keyword: str) -> KeywordValidator: + if keyword not in self.keyword_validators: + raise KeyError(keyword) + cls = self.keyword_validators[keyword] + self[keyword] = cls(self) + return self[keyword] diff --git a/openapi_spec_validator/validation/types.py b/openapi_spec_validator/validation/types.py new file mode 100644 index 0000000..90d83ba --- /dev/null +++ b/openapi_spec_validator/validation/types.py @@ -0,0 +1,5 @@ +from typing import Type + +from openapi_spec_validator.validation.validators import SpecValidator + +SpecValidatorType = Type[SpecValidator] diff --git a/openapi_spec_validator/validation/validators.py b/openapi_spec_validator/validation/validators.py index e7aa299..a789d16 100644 --- a/openapi_spec_validator/validation/validators.py +++ b/openapi_spec_validator/validation/validators.py @@ -1,352 +1,153 @@ """OpenAPI spec validator validation validators module.""" import logging -import string import warnings -from typing import Any +from functools import lru_cache from typing import Iterator from typing import List +from typing import Mapping from typing import Optional from typing import Type +from typing import cast -from jsonschema._format import FormatChecker from jsonschema.exceptions import ValidationError from jsonschema.protocols import Validator from jsonschema_spec.handlers import default_handlers from jsonschema_spec.paths import SchemaPath -from jsonschema_spec.typing import ResolverHandlers from jsonschema_spec.typing import Schema -from openapi_spec_validator.validation.decorators import ValidationErrorWrapper -from openapi_spec_validator.validation.exceptions import ( - DuplicateOperationIDError, -) -from openapi_spec_validator.validation.exceptions import ExtraParametersError -from openapi_spec_validator.validation.exceptions import OpenAPIValidationError -from openapi_spec_validator.validation.exceptions import ( - ParameterDuplicateError, -) -from openapi_spec_validator.validation.exceptions import ( - UnresolvableParameterError, +from openapi_spec_validator.schemas import openapi_v2_schema_validator +from openapi_spec_validator.schemas import openapi_v30_schema_validator +from openapi_spec_validator.schemas import openapi_v31_schema_validator +from openapi_spec_validator.validation import keywords +from openapi_spec_validator.validation.decorators import unwraps_iter +from openapi_spec_validator.validation.decorators import wraps_cached_iter +from openapi_spec_validator.validation.decorators import wraps_errors +from openapi_spec_validator.validation.registries import ( + KeywordValidatorRegistry, ) log = logging.getLogger(__name__) -wraps_errors = ValidationErrorWrapper(OpenAPIValidationError) - class SpecValidator: - OPERATIONS = [ - "get", - "put", - "post", - "delete", - "options", - "head", - "patch", - "trace", - ] + resolver_handlers = default_handlers + keyword_validators: Mapping[str, Type[keywords.KeywordValidator]] = { + "__root__": keywords.RootValidator, + } + root_keywords: List[str] = [] + schema_validator: Validator = NotImplemented def __init__( - self, - schema_validator: Validator, - value_validator_class: Type[Validator], - value_validator_format_checker: FormatChecker, - resolver_handlers: ResolverHandlers = default_handlers, - ): - self.schema_validator = schema_validator - self.value_validator_class = value_validator_class - self.value_validator_format_checker = value_validator_format_checker - self.resolver_handlers = resolver_handlers - - self.operation_ids_registry: Optional[List[str]] = None - self.schema_ids_registry: Optional[List[int]] = None - - def validate( self, schema: Schema, base_uri: str = "", spec_url: Optional[str] = None, ) -> None: - for err in self.iter_errors( - schema, - base_uri=base_uri, - spec_url=spec_url, - ): - raise err - - def is_valid(self, schema: Schema) -> bool: - error = next(self.iter_errors(schema), None) - return error is None - - @wraps_errors - def iter_errors( - self, - schema: Schema, - base_uri: str = "", - spec_url: Optional[str] = None, - ) -> Iterator[ValidationError]: + self.schema = schema if spec_url is not None: warnings.warn( "spec_url parameter is deprecated. " "Use base_uri instead.", DeprecationWarning, ) base_uri = spec_url + self.base_uri = base_uri - self.operation_ids_registry = [] - self.schema_ids_registry = [] - - yield from self.schema_validator.iter_errors(schema) - - spec = SchemaPath.from_dict( - schema, - base_uri=base_uri, + self.spec = SchemaPath.from_dict( + self.schema, + base_uri=self.base_uri, handlers=self.resolver_handlers, ) - if "paths" in spec: - paths = spec / "paths" - yield from self._iter_paths_errors(paths) - - if "components" in spec: - components = spec / "components" - yield from self._iter_components_errors(components) - - def _iter_paths_errors( - self, paths: SchemaPath - ) -> Iterator[ValidationError]: - for url, path_item in paths.items(): - yield from self._iter_path_errors(url, path_item) - - def _iter_path_errors( - self, url: str, path_item: SchemaPath - ) -> Iterator[ValidationError]: - parameters = None - if "parameters" in path_item: - parameters = path_item / "parameters" - yield from self._iter_parameters_errors(parameters) - - for field_name, operation in path_item.items(): - if field_name not in self.OPERATIONS: - continue - - yield from self._iter_operation_errors( - url, field_name, operation, parameters - ) - - def _iter_operation_errors( - self, - url: str, - name: str, - operation: SchemaPath, - path_parameters: Optional[SchemaPath], - ) -> Iterator[ValidationError]: - assert self.operation_ids_registry is not None - - operation_id = operation.getkey("operationId") - if ( - operation_id is not None - and operation_id in self.operation_ids_registry - ): - yield DuplicateOperationIDError( - f"Operation ID '{operation_id}' for '{name}' in '{url}' is not unique" - ) - self.operation_ids_registry.append(operation_id) - - if "responses" in operation: - responses = operation / "responses" - yield from self._iter_responses_errors(responses) - - names = [] - - parameters = None - if "parameters" in operation: - parameters = operation / "parameters" - yield from self._iter_parameters_errors(parameters) - names += list(self._get_path_param_names(parameters)) - - if path_parameters is not None: - names += list(self._get_path_param_names(path_parameters)) - - all_params = list(set(names)) - for path in self._get_path_params_from_url(url): - if path not in all_params: - yield UnresolvableParameterError( - "Path parameter '{}' for '{}' operation in '{}' " - "was not resolved".format(path, name, url) - ) - return - - def _iter_responses_errors( - self, responses: SchemaPath - ) -> Iterator[ValidationError]: - for response_code, response in responses.items(): - yield from self._iter_response_errors(response_code, response) - - def _iter_response_errors( - self, response_code: str, response: SchemaPath - ) -> Iterator[ValidationError]: - # openapi 2 - if "schema" in response: - schema = response / "schema" - yield from self._iter_schema_errors(schema) - # openapi 3 - if "content" in response: - content = response / "content" - yield from self._iter_content_errors(content) - - def _iter_content_errors( - self, content: SchemaPath - ) -> Iterator[ValidationError]: - for mimetype, media_type in content.items(): - yield from self._iter_media_type_errors(mimetype, media_type) - - def _iter_media_type_errors( - self, mimetype: str, media_type: SchemaPath - ) -> Iterator[ValidationError]: - if "schema" in media_type: - schema = media_type / "schema" - yield from self._iter_schema_errors(schema) - - def _get_path_param_names(self, params: SchemaPath) -> Iterator[str]: - for param in params: - if param["in"] == "path": - yield param["name"] - - def _get_path_params_from_url(self, url: str) -> Iterator[str]: - formatter = string.Formatter() - path_params = [item[1] for item in formatter.parse(url)] - return filter(None, path_params) - - def _iter_parameters_errors( - self, parameters: SchemaPath - ) -> Iterator[ValidationError]: - seen = set() - for parameter in parameters: - yield from self._iter_parameter_errors(parameter) - - key = (parameter["name"], parameter["in"]) - if key in seen: - yield ParameterDuplicateError( - f"Duplicate parameter `{parameter['name']}`" - ) - seen.add(key) - - def _iter_parameter_errors( - self, parameter: SchemaPath - ) -> Iterator[ValidationError]: - if "schema" in parameter: - schema = parameter / "schema" - yield from self._iter_schema_errors(schema) - - if "default" in parameter: - # only possible in swagger 2.0 - default = parameter.getkey("default") - if default is not None: - yield from self._iter_value_errors(parameter, default) - - def _iter_value_errors( - self, schema: SchemaPath, value: Any - ) -> Iterator[ValidationError]: - with schema.resolve() as resolved: - validator = self.value_validator_class( - resolved.contents, - _resolver=resolved.resolver, - format_checker=self.value_validator_format_checker, - ) - yield from validator.iter_errors(value) - - def _iter_schema_errors( - self, schema: SchemaPath, require_properties: bool = True - ) -> Iterator[ValidationError]: - if not hasattr(schema.content(), "__getitem__"): - return - - assert self.schema_ids_registry is not None - schema_id = id(schema.content()) - if schema_id in self.schema_ids_registry: - return - self.schema_ids_registry.append(schema_id) - - nested_properties = [] - if "allOf" in schema: - all_of = schema / "allOf" - for inner_schema in all_of: - yield from self._iter_schema_errors( - inner_schema, - require_properties=False, - ) - if "properties" not in inner_schema: - continue - inner_schema_props = inner_schema / "properties" - inner_schema_props_keys = inner_schema_props.keys() - nested_properties += list(inner_schema_props_keys) - - if "anyOf" in schema: - any_of = schema / "anyOf" - for inner_schema in any_of: - yield from self._iter_schema_errors( - inner_schema, - require_properties=False, - ) - - if "oneOf" in schema: - one_of = schema / "oneOf" - for inner_schema in one_of: - yield from self._iter_schema_errors( - inner_schema, - require_properties=False, - ) - - if "not" in schema: - not_schema = schema / "not" - yield from self._iter_schema_errors( - not_schema, - require_properties=False, - ) - - if "items" in schema: - array_schema = schema / "items" - yield from self._iter_schema_errors( - array_schema, - require_properties=False, - ) - - if "properties" in schema: - props = schema / "properties" - for _, prop_schema in props.items(): - yield from self._iter_schema_errors( - prop_schema, - require_properties=False, - ) + self.keyword_validators_registry = KeywordValidatorRegistry( + self.keyword_validators + ) - required = schema.getkey("required", []) - properties = schema.get("properties", {}).keys() - if "allOf" in schema: - extra_properties = list( - set(required) - set(properties) - set(nested_properties) - ) - else: - extra_properties = list(set(required) - set(properties)) + def validate(self) -> None: + for err in self.iter_errors(): + raise err - if extra_properties and require_properties: - yield ExtraParametersError( - f"Required list has not defined properties: {extra_properties}" - ) + def is_valid(self) -> bool: + error = next(self.iter_errors(), None) + return error is None - if "default" in schema: - default = schema["default"] - nullable = schema.get("nullable", False) - if default is not None or nullable is not True: - yield from self._iter_value_errors(schema, default) + @property + def root_validator(self) -> keywords.RootValidator: + return cast( + keywords.RootValidator, + self.keyword_validators_registry["__root__"], + ) - def _iter_components_errors( - self, components: SchemaPath - ) -> Iterator[ValidationError]: - schemas = components.get("schemas", {}) - yield from self._iter_schemas_errors(schemas) + @unwraps_iter + @lru_cache(maxsize=None) + @wraps_cached_iter + @wraps_errors + def iter_errors(self) -> Iterator[ValidationError]: + yield from self.schema_validator.iter_errors(self.schema) - def _iter_schemas_errors( - self, schemas: SchemaPath - ) -> Iterator[ValidationError]: - for _, schema in schemas.items(): - yield from self._iter_schema_errors(schema) + spec = SchemaPath.from_dict( + self.schema, + base_uri=self.base_uri, + handlers=self.resolver_handlers, + ) + yield from self.root_validator(spec) + + +class OpenAPIV2SpecValidator(SpecValidator): + schema_validator = openapi_v2_schema_validator + keyword_validators = { + "__root__": keywords.RootValidator, + "components": keywords.ComponentsValidator, + "default": keywords.OpenAPIV30ValueValidator, + "operation": keywords.OperationValidator, + "parameter": keywords.OpenAPIV2ParameterValidator, + "parameters": keywords.ParametersValidator, + "paths": keywords.PathsValidator, + "path": keywords.PathValidator, + "response": keywords.OpenAPIV2ResponseValidator, + "responses": keywords.ResponsesValidator, + "schema": keywords.SchemaValidator, + "schemas": keywords.SchemasValidator, + } + root_keywords = ["paths", "components"] + + +class OpenAPIV30SpecValidator(SpecValidator): + schema_validator = openapi_v30_schema_validator + keyword_validators = { + "__root__": keywords.RootValidator, + "components": keywords.ComponentsValidator, + "content": keywords.ContentValidator, + "default": keywords.OpenAPIV30ValueValidator, + "mediaType": keywords.MediaTypeValidator, + "operation": keywords.OperationValidator, + "parameter": keywords.ParameterValidator, + "parameters": keywords.ParametersValidator, + "paths": keywords.PathsValidator, + "path": keywords.PathValidator, + "response": keywords.OpenAPIV3ResponseValidator, + "responses": keywords.ResponsesValidator, + "schema": keywords.SchemaValidator, + "schemas": keywords.SchemasValidator, + } + root_keywords = ["paths", "components"] + + +class OpenAPIV31SpecValidator(SpecValidator): + schema_validator = openapi_v31_schema_validator + keyword_validators = { + "__root__": keywords.RootValidator, + "components": keywords.ComponentsValidator, + "content": keywords.ContentValidator, + "default": keywords.OpenAPIV31ValueValidator, + "mediaType": keywords.MediaTypeValidator, + "operation": keywords.OperationValidator, + "parameter": keywords.ParameterValidator, + "parameters": keywords.ParametersValidator, + "paths": keywords.PathsValidator, + "path": keywords.PathValidator, + "response": keywords.OpenAPIV3ResponseValidator, + "responses": keywords.ResponsesValidator, + "schema": keywords.SchemaValidator, + "schemas": keywords.SchemasValidator, + } + root_keywords = ["paths", "components"] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 2657e76..4f3cc08 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -6,10 +6,6 @@ from jsonschema_spec.handlers.file import FilePathHandler from jsonschema_spec.handlers.urllib import UrllibHandler -from openapi_spec_validator import openapi_v2_spec_validator -from openapi_spec_validator import openapi_v30_spec_validator -from openapi_spec_validator import openapi_v31_spec_validator - def spec_file_url(spec_file, schema="file"): directory = path.abspath(path.dirname(__file__)) @@ -40,18 +36,3 @@ def factory(): spec_from_file=spec_from_file, spec_from_url=spec_from_url, ) - - -@pytest.fixture -def validator_v2(): - return openapi_v2_spec_validator - - -@pytest.fixture -def validator_v30(): - return openapi_v30_spec_validator - - -@pytest.fixture -def validator_v31(): - return openapi_v31_spec_validator diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index 73a8d55..1527a83 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -150,7 +150,7 @@ def test_validation_error(capsys): @mock.patch( - "openapi_spec_validator.__main__.openapi_v30_spec_validator.validate", + "openapi_spec_validator.__main__.OpenAPIV30SpecValidator.validate", side_effect=Exception, ) def test_unknown_error(m_validate, capsys): diff --git a/tests/integration/test_shortcuts.py b/tests/integration/test_shortcuts.py index 7c69ce1..37ebded 100644 --- a/tests/integration/test_shortcuts.py +++ b/tests/integration/test_shortcuts.py @@ -1,5 +1,7 @@ import pytest +from openapi_spec_validator import OpenAPIV2SpecValidator +from openapi_spec_validator import OpenAPIV30SpecValidator from openapi_spec_validator import openapi_v2_spec_validator from openapi_spec_validator import openapi_v30_spec_validator from openapi_spec_validator import validate_spec @@ -40,10 +42,11 @@ def local_test_suite_file_path(self, test_file): def test_valid(self, factory, spec_file): spec_path = self.local_test_suite_file_path(spec_file) spec = factory.spec_from_file(spec_path) - spec_url = factory.spec_file_url(spec_path) validate_spec(spec) - validate_spec(spec, validator=openapi_v2_spec_validator) + validate_spec(spec, cls=OpenAPIV2SpecValidator) + with pytest.warns(DeprecationWarning): + validate_spec(spec, validator=openapi_v2_spec_validator) @pytest.mark.parametrize( "spec_file", @@ -56,7 +59,10 @@ def test_falied(self, factory, spec_file): spec = factory.spec_from_file(spec_path) with pytest.raises(OpenAPIValidationError): - validate_spec(spec, validator=openapi_v2_spec_validator) + validate_spec(spec, cls=OpenAPIV2SpecValidator) + with pytest.warns(DeprecationWarning): + with pytest.raises(OpenAPIValidationError): + validate_spec(spec, validator=openapi_v2_spec_validator) class TestLocalValidatev30Spec: @@ -78,7 +84,9 @@ def test_valid(self, factory, spec_file): validate_spec(spec) validate_spec(spec, spec_url=spec_url) - validate_spec(spec, validator=openapi_v30_spec_validator) + validate_spec(spec, cls=OpenAPIV30SpecValidator) + with pytest.warns(DeprecationWarning): + validate_spec(spec, validator=openapi_v30_spec_validator) @pytest.mark.parametrize( "spec_file", @@ -91,7 +99,10 @@ def test_falied(self, factory, spec_file): spec = factory.spec_from_file(spec_path) with pytest.raises(OpenAPIValidationError): - validate_spec(spec, validator=openapi_v30_spec_validator) + validate_spec(spec, cls=OpenAPIV30SpecValidator) + with pytest.warns(DeprecationWarning): + with pytest.raises(OpenAPIValidationError): + validate_spec(spec, validator=openapi_v30_spec_validator) @pytest.mark.network @@ -118,7 +129,9 @@ def test_valid(self, spec_file): spec_url = self.remote_test_suite_file_path(spec_file) validate_spec_url(spec_url) - validate_spec_url(spec_url, validator=openapi_v2_spec_validator) + validate_spec_url(spec_url, cls=OpenAPIV2SpecValidator) + with pytest.warns(DeprecationWarning): + validate_spec_url(spec_url, validator=openapi_v2_spec_validator) @pytest.mark.network @@ -145,4 +158,6 @@ def test_valid(self, spec_file): spec_url = self.remote_test_suite_file_path(spec_file) validate_spec_url(spec_url) - validate_spec_url(spec_url, validator=openapi_v30_spec_validator) + validate_spec_url(spec_url, cls=OpenAPIV30SpecValidator) + with pytest.warns(DeprecationWarning): + validate_spec_url(spec_url, validator=openapi_v30_spec_validator) diff --git a/tests/integration/validation/test_exceptions.py b/tests/integration/validation/test_exceptions.py index 129e0f1..687f85a 100644 --- a/tests/integration/validation/test_exceptions.py +++ b/tests/integration/validation/test_exceptions.py @@ -1,3 +1,5 @@ +from openapi_spec_validator import OpenAPIV2SpecValidator +from openapi_spec_validator import OpenAPIV30SpecValidator from openapi_spec_validator.validation.exceptions import ( DuplicateOperationIDError, ) @@ -9,10 +11,10 @@ class TestSpecValidatorIterErrors: - def test_empty(self, validator_v30): + def test_empty(self): spec = {} - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert errors_list[0].__class__ == OpenAPIValidationError @@ -22,20 +24,20 @@ def test_empty(self, validator_v30): assert errors_list[2].__class__ == OpenAPIValidationError assert errors_list[2].message == "'paths' is a required property" - def test_info_empty(self, validator_v30): + def test_info_empty(self): spec = { "openapi": "3.0.0", "info": {}, "paths": {}, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert errors_list[0].__class__ == OpenAPIValidationError assert errors_list[0].message == "'title' is a required property" - def test_minimalistic(self, validator_v30): + def test_minimalistic(self): spec = { "openapi": "3.0.0", "info": { @@ -45,12 +47,12 @@ def test_minimalistic(self, validator_v30): "paths": {}, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert errors_list == [] - def test_same_parameters_names(self, validator_v30): + def test_same_parameters_names(self): spec = { "openapi": "3.0.0", "info": { @@ -80,12 +82,12 @@ def test_same_parameters_names(self, validator_v30): }, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert errors_list == [] - def test_same_operation_ids(self, validator_v30): + def test_same_operation_ids(self): spec = { "openapi": "3.0.0", "info": { @@ -124,14 +126,14 @@ def test_same_operation_ids(self, validator_v30): }, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert len(errors_list) == 2 assert errors_list[0].__class__ == DuplicateOperationIDError assert errors_list[1].__class__ == DuplicateOperationIDError - def test_allow_allof_required_no_properties(self, validator_v30): + def test_allow_allof_required_no_properties(self): spec = { "openapi": "3.0.0", "info": { @@ -157,13 +159,11 @@ def test_allow_allof_required_no_properties(self, validator_v30): }, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert errors_list == [] - def test_allow_allof_when_required_is_linked_to_the_parent_object( - self, validator_v30 - ): + def test_allow_allof_when_required_is_linked_to_the_parent_object(self): spec = { "openapi": "3.0.1", "info": { @@ -198,11 +198,11 @@ def test_allow_allof_when_required_is_linked_to_the_parent_object( }, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert errors_list == [] - def test_extra_parameters_in_required(self, validator_v30): + def test_extra_parameters_in_required(self): spec = { "openapi": "3.0.0", "info": { @@ -222,7 +222,7 @@ def test_extra_parameters_in_required(self, validator_v30): }, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert errors_list[0].__class__ == ExtraParametersError @@ -230,7 +230,7 @@ def test_extra_parameters_in_required(self, validator_v30): "Required list has not defined properties: ['testparam1']" ) - def test_undocumented_parameter(self, validator_v30): + def test_undocumented_parameter(self): spec = { "openapi": "3.0.0", "info": { @@ -260,7 +260,7 @@ def test_undocumented_parameter(self, validator_v30): }, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert errors_list[0].__class__ == UnresolvableParameterError @@ -269,7 +269,7 @@ def test_undocumented_parameter(self, validator_v30): "'/test/{param1}/{param2}' was not resolved" ) - def test_default_value_wrong_type(self, validator_v30): + def test_default_value_wrong_type(self): spec = { "openapi": "3.0.0", "info": { @@ -287,7 +287,7 @@ def test_default_value_wrong_type(self, validator_v30): }, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert len(errors_list) == 1 @@ -296,7 +296,7 @@ def test_default_value_wrong_type(self, validator_v30): "'invaldtype' is not of type 'integer'" ) - def test_parameter_default_value_wrong_type(self, validator_v30): + def test_parameter_default_value_wrong_type(self): spec = { "openapi": "3.0.0", "info": { @@ -327,7 +327,7 @@ def test_parameter_default_value_wrong_type(self, validator_v30): }, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert len(errors_list) == 1 @@ -336,7 +336,7 @@ def test_parameter_default_value_wrong_type(self, validator_v30): "'invaldtype' is not of type 'integer'" ) - def test_parameter_default_value_wrong_type_swagger(self, validator_v2): + def test_parameter_default_value_wrong_type_swagger(self): spec = { "swagger": "2.0", "info": { @@ -365,7 +365,7 @@ def test_parameter_default_value_wrong_type_swagger(self, validator_v2): }, } - errors = validator_v2.iter_errors(spec) + errors = OpenAPIV2SpecValidator(spec).iter_errors() errors_list = list(errors) assert len(errors_list) == 1 @@ -374,7 +374,7 @@ def test_parameter_default_value_wrong_type_swagger(self, validator_v2): "'invaldtype' is not of type 'integer'" ) - def test_parameter_default_value_with_reference(self, validator_v30): + def test_parameter_default_value_with_reference(self): spec = { "openapi": "3.0.0", "info": { @@ -415,12 +415,12 @@ def test_parameter_default_value_with_reference(self, validator_v30): }, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert errors_list == [] - def test_parameter_custom_format_checker_not_found(self, validator_v30): + def test_parameter_custom_format_checker_not_found(self): spec = { "openapi": "3.0.0", "info": { @@ -451,14 +451,12 @@ def test_parameter_custom_format_checker_not_found(self, validator_v30): }, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert errors_list == [] - def test_parameter_default_value_custom_format_invalid( - self, validator_v30 - ): + def test_parameter_default_value_custom_format_invalid(self): from openapi_schema_validator import oas30_format_checker @oas30_format_checker.checks("custom") @@ -495,7 +493,7 @@ def validate(to_validate) -> bool: }, } - errors = validator_v30.iter_errors(spec) + errors = OpenAPIV30SpecValidator(spec).iter_errors() errors_list = list(errors) assert len(errors_list) == 1 diff --git a/tests/integration/validation/test_validators.py b/tests/integration/validation/test_validators.py index aef830b..0ff61c5 100644 --- a/tests/integration/validation/test_validators.py +++ b/tests/integration/validation/test_validators.py @@ -1,6 +1,9 @@ import pytest from referencing.exceptions import Unresolvable +from openapi_spec_validator import OpenAPIV2SpecValidator +from openapi_spec_validator import OpenAPIV30SpecValidator +from openapi_spec_validator import OpenAPIV31SpecValidator from openapi_spec_validator.validation.exceptions import OpenAPIValidationError @@ -16,12 +19,15 @@ def local_test_suite_file_path(self, test_file): "petstore.yaml", ], ) - def test_valid(self, factory, validator_v2, spec_file): + def test_valid(self, factory, spec_file): spec_path = self.local_test_suite_file_path(spec_file) spec = factory.spec_from_file(spec_path) spec_url = factory.spec_file_url(spec_path) + validator = OpenAPIV2SpecValidator(spec, base_uri=spec_url) - return validator_v2.validate(spec, spec_url=spec_url) + validator.validate() + + assert validator.is_valid() == True @pytest.mark.parametrize( "spec_file", @@ -29,13 +35,16 @@ def test_valid(self, factory, validator_v2, spec_file): "empty.yaml", ], ) - def test_validation_failed(self, factory, validator_v2, spec_file): + def test_validation_failed(self, factory, spec_file): spec_path = self.local_test_suite_file_path(spec_file) spec = factory.spec_from_file(spec_path) spec_url = factory.spec_file_url(spec_path) + validator = OpenAPIV2SpecValidator(spec, base_uri=spec_url) with pytest.raises(OpenAPIValidationError): - validator_v2.validate(spec, spec_url=spec_url) + validator.validate() + + assert validator.is_valid() == False @pytest.mark.parametrize( "spec_file", @@ -43,13 +52,13 @@ def test_validation_failed(self, factory, validator_v2, spec_file): "missing-reference.yaml", ], ) - def test_ref_failed(self, factory, validator_v2, spec_file): + def test_ref_failed(self, factory, spec_file): spec_path = self.local_test_suite_file_path(spec_file) spec = factory.spec_from_file(spec_path) spec_url = factory.spec_file_url(spec_path) with pytest.raises(Unresolvable): - validator_v2.validate(spec, spec_url=spec_url) + OpenAPIV2SpecValidator(spec, base_uri=spec_url).validate() class TestLocalOpenAPIv30Validator: @@ -68,12 +77,15 @@ def local_test_suite_file_path(self, test_file): "read-only-write-only.yaml", ], ) - def test_valid(self, factory, validator_v30, spec_file): + def test_valid(self, factory, spec_file): spec_path = self.local_test_suite_file_path(spec_file) spec = factory.spec_from_file(spec_path) spec_url = factory.spec_file_url(spec_path) + validator = OpenAPIV30SpecValidator(spec, base_uri=spec_url) - return validator_v30.validate(spec, spec_url=spec_url) + validator.validate() + + assert validator.is_valid() == True @pytest.mark.parametrize( "spec_file", @@ -81,13 +93,16 @@ def test_valid(self, factory, validator_v30, spec_file): "empty.yaml", ], ) - def test_failed(self, factory, validator_v30, spec_file): + def test_failed(self, factory, spec_file): spec_path = self.local_test_suite_file_path(spec_file) spec = factory.spec_from_file(spec_path) spec_url = factory.spec_file_url(spec_path) + validator = OpenAPIV30SpecValidator(spec, base_uri=spec_url) with pytest.raises(OpenAPIValidationError): - validator_v30.validate(spec, spec_url=spec_url) + validator.validate() + + assert validator.is_valid() == False @pytest.mark.parametrize( "spec_file", @@ -95,13 +110,13 @@ def test_failed(self, factory, validator_v30, spec_file): "property-missing-reference.yaml", ], ) - def test_ref_failed(self, factory, validator_v30, spec_file): + def test_ref_failed(self, factory, spec_file): spec_path = self.local_test_suite_file_path(spec_file) spec = factory.spec_from_file(spec_path) spec_url = factory.spec_file_url(spec_path) with pytest.raises(Unresolvable): - validator_v30.validate(spec, spec_url=spec_url) + OpenAPIV30SpecValidator(spec, base_uri=spec_url).validate() @pytest.mark.network @@ -124,11 +139,11 @@ def remote_test_suite_file_path(self, test_file): "api-with-examples.yaml", ], ) - def test_valid(self, factory, validator_v30, spec_file): + def test_valid(self, factory, spec_file): spec_url = self.remote_test_suite_file_path(spec_file) spec = factory.spec_from_url(spec_url) - return validator_v30.validate(spec, spec_url=spec_url) + OpenAPIV30SpecValidator(spec, base_uri=spec_url).validate() @pytest.mark.network @@ -159,13 +174,13 @@ def remote_test_suite_file_path(self, test_file): "valid_schema_types.yaml", ], ) - def test_valid(self, factory, validator_v31, spec_file): + def test_valid(self, factory, spec_file): spec_url = self.remote_test_suite_file_path( f"tests/v3.1/pass/{spec_file}" ) spec = factory.spec_from_url(spec_url) - return validator_v31.validate(spec, spec_url=spec_url) + OpenAPIV31SpecValidator(spec, base_uri=spec_url).validate() @pytest.mark.parametrize( "spec_file", @@ -177,11 +192,11 @@ def test_valid(self, factory, validator_v31, spec_file): "unknown_container.yaml", ], ) - def test_failed(self, factory, validator_v31, spec_file): + def test_failed(self, factory, spec_file): spec_url = self.remote_test_suite_file_path( f"tests/v3.1/fail/{spec_file}" ) spec = factory.spec_from_url(spec_url) with pytest.raises(OpenAPIValidationError): - validator_v31.validate(spec, spec_url=spec_url) + OpenAPIV31SpecValidator(spec, base_uri=spec_url).validate()