1"""
2Module containing all code related to json schema validation.
3"""
4
5import contextlib
6import io
7import json
8import os
9import typing as t
10import urllib.parse
11import urllib.request
12from collections.abc import Mapping
13from copy import deepcopy
14
15import requests
16import yaml
17from jsonschema import Draft4Validator
18from jsonschema.exceptions import ValidationError
19from jsonschema.validators import extend
20from referencing import Registry, Resource
21from referencing.jsonschema import DRAFT4
22
23from .utils import deep_get
24
25
class ExtendedSafeLoader(yaml.SafeLoader):
    """YAML ``SafeLoader`` whose mappings always come back with ``str`` keys.

    YAML allows non-string mapping keys (ints, booleans, dates, ...), JSON does
    not. Every key is passed through ``str()`` so the loaded document is valid json.
    """

    def __init__(self, stream):
        # Swap the bound construct_mapping for the key-coercing wrapper, keeping
        # the stock implementation around so the wrapper can delegate to it.
        self.original_construct_mapping = self.construct_mapping
        self.construct_mapping = self.extended_construct_mapping
        super().__init__(stream)

    def extended_construct_mapping(self, node, deep=False):
        """Build the mapping with the stock constructor, then stringify its keys."""
        mapping = self.original_construct_mapping(node, deep)
        return {str(key): value for key, value in mapping.items()}
37
38
class FileHandler:
    """Handler to resolve file refs (``file:`` or scheme-less URIs) to parsed YAML/JSON."""

    def __call__(self, uri):
        """Load and parse the local document that ``uri`` points to.

        The file is opened in binary mode so PyYAML determines the encoding
        itself (UTF-8 by default, UTF-16 via BOM). Text mode would decode the
        file with the platform's locale encoding, which breaks non-ASCII specs
        on systems whose locale is not UTF-8.
        """
        filepath = self._uri_to_path(uri)
        with open(filepath, "rb") as fh:
            return yaml.load(fh, ExtendedSafeLoader)

    @staticmethod
    def _uri_to_path(uri):
        """Convert a file URI into an absolute, normalized local filesystem path.

        The URI's netloc is folded into a ``<sep><sep><netloc><sep>`` prefix
        before joining. NOTE(review): ``os.path.join`` discards that prefix
        whenever the decoded path is itself absolute, as it is for typical
        ``file:///...`` URIs.
        """
        parsed = urllib.parse.urlparse(uri)
        host = "{0}{0}{mnt}{0}".format(os.path.sep, mnt=parsed.netloc)
        return os.path.abspath(
            os.path.join(host, urllib.request.url2pathname(parsed.path))
        )
54
55
class URLHandler:
    """Handler to resolve url refs (``http``/``https``) to parsed YAML/JSON."""

    # Seconds to wait on the server (connect and read) before giving up.
    # ``requests`` waits indefinitely when no timeout is given, which would let
    # a single unresponsive ref host hang spec loading forever. Override on a
    # subclass or instance if a different limit is needed.
    timeout = 30

    def __call__(self, uri):
        """Fetch ``uri`` and parse the response body as YAML (a superset of JSON).

        :raises requests.HTTPError: if the server answers with a 4xx/5xx status.
        :raises requests.Timeout: if the server does not respond within ``timeout``.
        """
        response = requests.get(uri, timeout=self.timeout)
        response.raise_for_status()
        # Hand PyYAML the raw bytes so it decodes them per the YAML spec
        # (UTF-8, or UTF-16 via BOM). ``response.text`` would fall back to
        # requests' charset guessing (e.g. ISO-8859-1 for text/* without a
        # charset), mangling non-ASCII content.
        return yaml.load(response.content, ExtendedSafeLoader)
66
67
def resource_from_spec(spec: t.Dict[str, t.Any]) -> Resource:
    """Wrap a schema document in a `referencing.Resource`.

    JSON Schema draft 4 is passed as ``default_specification``, so it applies
    whenever the document itself does not identify its dialect.
    """
    resource = Resource.from_contents(spec, default_specification=DRAFT4)
    return resource
71
72
def retrieve(uri: str) -> Resource:
    """Load the document behind ``uri`` and wrap it as a `referencing.Resource`.

    Used as the ``retrieve`` callback of a `referencing.Registry`, which invokes
    it for every URI the registry does not already hold. Loading is dispatched
    on the URI scheme:

    * ``http`` / ``https``: fetched with `URLHandler`;
    * ``file`` or no scheme: read from disk with `FileHandler`;
    * anything else: read with ``urllib`` and parsed as UTF-8 JSON.
    """
    scheme = urllib.parse.urlsplit(uri).scheme

    if scheme in ("http", "https"):
        loader = URLHandler()
    elif scheme in ("file", ""):
        loader = FileHandler()
    else:  # pragma: no cover
        # Mirrors the default branch of jsonschema.RefResolver.resolve_remote(),
        # kept for backwards compatibility.
        def loader(target):
            with urllib.request.urlopen(target) as url:
                return json.loads(url.read().decode("utf-8"))

    return resource_from_spec(loader(uri))
89
90
def resolve_refs(spec, store=None, base_uri=""):
    """
    Resolve JSON references like {"$ref": <some URI>} in a spec.

    Optionally takes a store, which is a mapping from reference URLs to
    dereferenced objects. Prepopulating the store can avoid network calls.

    Same-document pointers (``"#/..."``) that exist in ``spec`` have their
    target merged into the referring node in place. Every other reference,
    including ``"#/..."`` pointers not found in ``spec``, is resolved through a
    ``referencing.Registry``. That registry is seeded with ``spec`` (under
    ``base_uri``) and ``store``, and falls back to ``retrieve`` for any other URI.

    :param spec: Specification to resolve. It is deep-copied first, so the
        caller's object is left untouched.
    :param store: Optional mapping of URI -> already-loaded document.
    :param base_uri: URI that ``spec`` itself is registered under.
    :return: The resolved copy of ``spec``.
    """
    spec = deepcopy(spec)
    store = store or {}
    registry = Registry(retrieve=retrieve).with_resources(
        (
            (base_uri, resource_from_spec(spec)),
            *((key, resource_from_spec(value)) for key, value in store.items()),
        )
    )

    def _ref_of(node):
        """Return the node's ``$ref`` string, or None if it is not a reference."""
        # A mapping may hold a non-string "$ref" (e.g. a property literally named
        # "$ref" inside "properties"); that is data, not a reference.
        if isinstance(node, Mapping):
            ref = node.get("$ref")
            if isinstance(ref, str):
                return ref
        return None

    def _pointer_keys(ref):
        """Split a ``"#/a/b"`` ref into its unescaped JSON pointer keys."""
        # RFC 6901: percent-decode the URI fragment, then "~1" -> "/" before "~0" -> "~".
        return [
            urllib.parse.unquote(token).replace("~1", "/").replace("~0", "~")
            for token in ref[2:].split("/")
        ]

    def _do_resolve(node, resolver):
        ref = _ref_of(node)
        if ref is not None:
            if ref.startswith("#/"):
                try:
                    retrieved = deep_get(spec, _pointer_keys(ref))
                except (LookupError, TypeError, ValueError):
                    # Not reachable in the root spec; let the registry resolve it
                    # (it may be local to the document `resolver` belongs to).
                    pass
                else:
                    # resolve known references: merge the target into this node
                    node.update(retrieved)
                    if _ref_of(retrieved) is not None:
                        # The target is itself a reference; follow the chain.
                        node = _do_resolve(node, resolver)
                    if _ref_of(node) is not None:
                        del node["$ref"]
                    return node
            # resolve external references
            resolved = resolver.lookup(ref)
            return _do_resolve(resolved.contents, resolved.resolver)
        if isinstance(node, Mapping):
            for key, value in node.items():
                node[key] = _do_resolve(value, resolver)
        elif isinstance(node, list):
            for index, item in enumerate(node):
                node[index] = _do_resolve(item, resolver)
        elif isinstance(node, tuple):
            # Tuples cannot be assigned into; build a resolved copy instead.
            node = tuple(_do_resolve(item, resolver) for item in node)
        return node

    return _do_resolve(spec, registry.resolver(base_uri))
131
132
def format_error_with_path(exception: ValidationError) -> str:
    """Build a `` - '<path>'`` suffix locating a `ValidationError` in the instance.

    The items of ``exception.path`` are stringified and joined with dots, so a
    path of ``['a', 0, 'b']`` yields `` - 'a.0.b'``. An empty string is returned
    when the joined path is empty.
    """
    dotted = ".".join(map(str, exception.path))
    if not dotted:
        return ""
    return " - '" + dotted + "'"
138
139
def allow_nullable(validation_fn: t.Callable) -> t.Callable:
    """Wrap a jsonschema keyword validator so ``None`` passes for nullable schemas.

    A ``None`` instance yields no errors when the schema has ``x-nullable`` set
    to exactly ``True`` or has a truthy ``nullable``. In every other case the
    errors of ``validation_fn`` are re-yielded unchanged.
    """

    def nullable_validation_fn(validator, to_validate, instance, schema):
        skip = instance is None and (
            schema.get("x-nullable") is True or bool(schema.get("nullable"))
        )
        if not skip:
            yield from validation_fn(validator, to_validate, instance, schema)

    return nullable_validation_fn
152
153
def validate_writeOnly(validator, wo, instance, schema):
    """Report an error for a value whose schema is marked write-only.

    Registered for the ``writeOnly`` / ``x-writeOnly`` keywords, so ``wo`` is
    that keyword's value. jsonschema calls a keyword validator whenever the
    keyword is present, whatever its value. Only a truthy value may therefore
    produce an error: an explicit ``writeOnly: false`` must not reject the value.
    """
    if wo:
        yield ValidationError("Property is write-only")
156
157
# Draft 4 ``type`` / ``enum`` keyword validators that accept ``None`` for
# schemas marked ``nullable`` / ``x-nullable``.
NullableTypeValidator = allow_nullable(Draft4Validator.VALIDATORS["type"])
NullableEnumValidator = allow_nullable(Draft4Validator.VALIDATORS["enum"])

# Keyword overrides shared by the request and response validators below.
_NULLABLE_OVERRIDES = {
    "type": NullableTypeValidator,
    "enum": NullableEnumValidator,
}

# Draft 4 validator with nullable-aware ``type`` and ``enum`` keywords.
Draft4RequestValidator = extend(Draft4Validator, dict(_NULLABLE_OVERRIDES))

# Same as the request validator, plus ``writeOnly`` / ``x-writeOnly`` routed to
# ``validate_writeOnly``.
Draft4ResponseValidator = extend(
    Draft4Validator,
    {
        **_NULLABLE_OVERRIDES,
        "writeOnly": validate_writeOnly,
        "x-writeOnly": validate_writeOnly,
    },
)