1import copy
2import re
3import warnings
4
5from collections import OrderedDict
6
7from collections.abc import MutableMapping
8from werkzeug.utils import cached_property
9
10from .mask import Mask
11from .errors import abort
12
13from jsonschema import Draft4Validator
14from jsonschema.validators import validator_for
15from jsonschema.exceptions import ValidationError
16
17from .utils import not_none
18from ._http import HTTPStatus
19
20
# Matches a validation message of the form "'<name>' is a required property"
# (with an optional leading ``u`` before the quoted name) and captures <name>.
RE_REQUIRED = re.compile(
    r"u?\'(?P<name>.*)\' is a required property",
    flags=re.IGNORECASE | re.UNICODE,
)
22
23
def instance(cls):
    """
    Return an instance for ``cls``.

    A class is instantiated with no arguments; anything else is returned as is.
    """
    return cls() if isinstance(cls, type) else cls
28
29
class ModelBase(object):
    """
    Handles validation and swagger style inheritance for both subclasses.
    Subclass must define `_schema` attribute.

    :param str name: The model public name
    """

    def __init__(self, name, *args, **kwargs):
        super(ModelBase, self).__init__(*args, **kwargs)
        self.__apidoc__ = {"name": name}
        self.name = name
        self.__parents__ = []

        def instance_inherit(name, *parents):
            return self.__class__.inherit(name, self, *parents)

        # Shadow the classmethod on the instance so that calling
        # ``model.inherit(...)`` passes ``model`` itself as the first parent.
        self.inherit = instance_inherit

    @property
    def ancestors(self):
        """
        Return the ancestors tree

        :returns: a set with this model's name and the names of all its
            direct and indirect parents
        """
        ancestors = [p.ancestors for p in self.__parents__]
        return {self.name}.union(*ancestors)

    def get_parent(self, name):
        """
        Find this model or one of its ancestors by name.

        :param str name: The name of the model to look for
        :returns: the model named ``name`` (possibly ``self``)
        :raises ValueError: if no model in the parents tree has this name
        """
        if self.name == name:
            return self

        for parent in self.__parents__:
            # A parent signals "not in my subtree" by raising ValueError,
            # so keep looking through the remaining parents instead of
            # letting the first miss abort the whole search.
            try:
                found = parent.get_parent(name)
            except ValueError:
                continue
            # Compare against None: dict-based models without fields are falsy.
            if found is not None:
                return found

        raise ValueError("Parent " + name + " not found")

    @property
    def __schema__(self):
        """
        Return the schema of this model.

        When the model has parents, the schema is an ``allOf`` composition
        referencing each parent definition followed by the model's own schema.
        """
        schema = self._schema

        if self.__parents__:
            refs = [
                {"$ref": "#/definitions/{0}".format(parent.name)}
                for parent in self.__parents__
            ]

            return {"allOf": refs + [schema]}
        else:
            return schema

    @classmethod
    def inherit(cls, name, *parents):
        """
        Inherit this model (use the Swagger composition pattern aka. allOf)

        :param str name: The new model name
        :param parents: The parent models followed by the new model own
            definition: the last positional argument is handed to the
            constructor, all the preceding ones are recorded as parents
        """
        model = cls(name, parents[-1])
        model.__parents__ = parents[:-1]
        return model

    @staticmethod
    def _contains_ref(schema):
        """Tell whether a ``"$ref"`` key appears anywhere in a nested schema."""
        # Walk the structure instead of serializing it: schemas may carry
        # values that are not JSON serializable.
        pending = [schema]
        seen = set()
        while pending:
            node = pending.pop()
            if id(node) in seen:
                continue
            if isinstance(node, dict):
                seen.add(id(node))
                if "$ref" in node:
                    return True
                pending.extend(node.values())
            elif isinstance(node, (list, tuple)):
                seen.add(id(node))
                pending.extend(node)
        return False

    @staticmethod
    def _collect_definitions(registry):
        """Merge the ``definitions`` of every document found in ``registry``."""
        definitions = {}
        for uri in registry:
            resource = registry[uri]
            # NOTE(review): registry entries may be wrapper objects exposing
            # the document as ``.contents`` rather than plain dicts - confirm
            # against the registry type actually handed in by callers.
            contents = getattr(resource, "contents", resource)
            if isinstance(contents, dict) and "definitions" in contents:
                definitions.update(contents["definitions"])
        return definitions

    def validate(self, data, resolver=None, format_checker=None):
        """
        Validate ``data`` against this model schema.

        On failure, :func:`abort` is called with a ``BAD_REQUEST`` status and
        a mapping of every validation error (see :meth:`format_error`).

        :param data: The payload to validate
        :param resolver: Either an object exposing ``resolve`` (handled as a
            legacy reference resolver), a registry of schema documents, or None
        :param format_checker: Optional format checker given to the validator
        """
        # The property may rebuild the schema on each access: evaluate it once.
        schema = self.__schema__

        # For backward compatibility, resolver can be either a RefResolver or a Registry
        if resolver is not None and hasattr(resolver, "resolve"):
            # Old RefResolver
            validator = Draft4Validator(
                schema, resolver=resolver, format_checker=format_checker
            )
        else:
            # New Registry or None
            schema_to_validate = schema
            if resolver is not None and self._contains_ref(schema):
                # Inline the registry definitions so that local
                # "#/definitions/..." references can be resolved
                definitions = self._collect_definitions(resolver)
                if definitions:
                    schema_to_validate = {
                        "$id": "http://localhost/schema.json",
                        "definitions": definitions,
                        **schema,
                    }

            ValidatorClass = validator_for(schema_to_validate)
            if resolver is not None:
                validator = ValidatorClass(
                    schema_to_validate, registry=resolver, format_checker=format_checker
                )
            else:
                validator = ValidatorClass(
                    schema_to_validate, format_checker=format_checker
                )

        try:
            validator.validate(data)
        except ValidationError:
            abort(
                HTTPStatus.BAD_REQUEST,
                message="Input payload validation failed",
                errors=dict(self.format_error(e) for e in validator.iter_errors(data)),
            )

    def format_error(self, error):
        """
        Turn a validation error into a ``(key, message)`` pair.

        :param error: An object exposing ``path``, ``validator`` and ``message``
        :returns: a tuple made of the dotted path to the faulty value and
            the error message
        """
        path = list(error.path)
        if error.validator == "required":
            # The missing property is not part of the error path: take it
            # from the message, when the message has the expected form.
            match = RE_REQUIRED.match(error.message)
            if match is not None:
                path.append(match.group("name"))
        key = ".".join(str(p) for p in path)
        return key, error.message

    def __unicode__(self):
        return "Model({name},{{{fields}}})".format(
            name=self.name, fields=",".join(self.keys())
        )

    __str__ = __unicode__
158
159
class RawModel(ModelBase):
    """
    A thin wrapper on ordered fields dict to store API doc metadata.
    Can also be used for response marshalling.

    :param str name: The model public name
    :param str mask: an optional default model mask
    :param bool strict: validation should raise error when there is param not provided in schema
    """

    # Mapping type used to accumulate fields (see ``_schema`` and ``clone``)
    wrapper = dict

    def __init__(self, name, *args, **kwargs):
        self.__mask__ = kwargs.pop("mask", None)
        self.__strict__ = kwargs.pop("strict", False)
        if self.__mask__ and not isinstance(self.__mask__, Mask):
            self.__mask__ = Mask(self.__mask__)
        super(RawModel, self).__init__(name, *args, **kwargs)

        def instance_clone(name, *parents):
            return self.__class__.clone(name, self, *parents)

        # Shadow the classmethod on the instance so that calling
        # ``model.clone(...)`` also duplicates the fields of ``model`` itself.
        self.clone = instance_clone

    @property
    def _schema(self):
        """
        Build the schema definition of this model from its fields.

        Each field contributes its own ``__schema__`` as a property; fields
        flagged ``required`` are listed (sorted) under ``required`` and the
        name of a field flagged as ``discriminator`` is stored under
        ``discriminator``.
        """
        properties = self.wrapper()
        required = set()
        discriminator = None
        for name, field in self.items():
            field = instance(field)
            properties[name] = field.__schema__
            if field.required:
                required.add(name)
            if getattr(field, "discriminator", False):
                discriminator = name

        definition = {
            "required": sorted(required) or None,
            "properties": properties,
            "discriminator": discriminator,
            "x-mask": str(self.__mask__) if self.__mask__ else None,
            "type": "object",
        }

        if self.__strict__:
            definition["additionalProperties"] = False

        # presumably ``not_none`` drops the entries left to None - confirm in utils
        return not_none(definition)

    @cached_property
    def resolved(self):
        """
        Resolve real fields before submitting them to marshal
        """
        # Duplicate fields
        resolved = copy.deepcopy(self)

        # Recursively copy parent fields if necessary
        for parent in self.__parents__:
            resolved.update(parent.resolved)

        # Handle discriminator
        candidates = [f for f in resolved.values() if getattr(f, "discriminator", None)]
        # Ensure there is only one discriminator
        if len(candidates) > 1:
            raise ValueError("There can only be one discriminator by schema")
        # Ensure discriminator always output the model name
        elif len(candidates) == 1:
            candidates[0].default = self.name

        return resolved

    def extend(self, name, fields):
        """
        Extend this model (Duplicate all fields)

        :param str name: The new model name
        :param dict fields: The new model extra fields

        :deprecated: since 0.9. Use :meth:`clone` instead.
        """
        warnings.warn(
            "extend is deprecated, use clone instead",
            DeprecationWarning,
            stacklevel=2,
        )
        if isinstance(fields, (list, tuple)):
            return self.clone(name, *fields)
        else:
            return self.clone(name, fields)

    @classmethod
    def clone(cls, name, *parents):
        """
        Clone these models (Duplicate all fields)

        It can be used from the class

        >>> model = Model.clone(name, fields_1, fields_2)

        or from an Instanciated model

        >>> new_model = model.clone(name, fields_1, fields_2)

        :param str name: The new model name
        :param dict parents: The new model extra fields
        """
        fields = cls.wrapper()
        for parent in parents:
            fields.update(copy.deepcopy(parent))
        return cls(name, fields)

    def __deepcopy__(self, memo):
        """
        Return a copy of this model with deep-copied fields.

        The mask, the strict flag and the parents are shared with the
        original rather than copied.
        """
        obj = self.__class__(
            self.name,
            [],
            mask=self.__mask__,
            strict=self.__strict__,
        )
        # Register the copy before copying the fields so that a field
        # referring back to this model reuses it instead of recursing forever.
        memo[id(self)] = obj
        obj.update((key, copy.deepcopy(value, memo)) for key, value in self.items())
        obj.__parents__ = self.__parents__
        return obj
282
283
class Model(RawModel, dict, MutableMapping):
    """
    Model keeping its fields in a plain ``dict``.

    Stores API doc metadata and can also be used for response marshalling.

    :param str name: The model public name
    :param str mask: an optional default model mask
    """
294
295
class OrderedModel(RawModel, OrderedDict, MutableMapping):
    """
    Model keeping its fields in an ``OrderedDict``.

    Stores API doc metadata and can also be used for response marshalling.

    :param str name: The model public name
    :param str mask: an optional default model mask
    """

    # Fields accumulated by the inherited helpers use an ordered mapping too
    wrapper = OrderedDict
306
307
class SchemaModel(ModelBase):
    """
    Model documented directly by a json schema instead of fields.

    :param str name: The model public name
    :param dict schema: The json schema we are documenting
    """

    def __init__(self, name, schema=None):
        super().__init__(name)
        # Any falsy schema (None, empty dict...) is replaced by a new empty dict
        self._schema = schema if schema else {}

    def __unicode__(self):
        template = "SchemaModel({name},{schema})"
        return template.format(schema=self._schema, name=self.name)

    __str__ = __unicode__