1import copy
2import re
3import warnings
4
5from collections import OrderedDict
6
7from collections.abc import MutableMapping
8from werkzeug.utils import cached_property
9
10from .mask import Mask
11from .errors import abort
12
13from jsonschema import Draft4Validator
14from jsonschema.exceptions import ValidationError
15
16from .utils import not_none
17from ._http import HTTPStatus
18
19
# Extracts the missing property name from a jsonschema "required" error
# message such as "'name' is a required property". The name appears to be
# rendered with repr(), so it may be single- or double-quoted (double quotes
# when the name itself contains a single quote); the optional "u" prefix
# covers Python 2 unicode reprs. The closing quote must match the opening one.
RE_REQUIRED = re.compile(
    r"u?(?P<quote>['\"])(?P<name>.*)(?P=quote) is a required property", re.I | re.U
)
21
22
def instance(cls):
    """
    Coerce ``cls`` into an instance.

    A class is instantiated with no arguments; any other object (already an
    instance) is handed back unchanged.
    """
    if not isinstance(cls, type):
        return cls
    return cls()
27
28
class ModelBase(object):
    """
    Handles validation and swagger style inheritance for both subclasses.
    Subclasses must provide a `_schema` attribute (or property) holding the
    model's own JSON schema.

    :param str name: The model public name
    """

    def __init__(self, name, *args, **kwargs):
        super(ModelBase, self).__init__(*args, **kwargs)
        self.__apidoc__ = {"name": name}
        self.name = name
        self.__parents__ = []

        def instance_inherit(name, *parents):
            return self.__class__.inherit(name, self, *parents)

        # Shadow the ``inherit`` classmethod on instances so that
        # ``model.inherit(name, ...)`` uses this model as the first parent.
        self.inherit = instance_inherit

    @property
    def ancestors(self):
        """
        Return the names of this model and of all its ancestors, recursively.

        :rtype: set
        """
        ancestors = [p.ancestors for p in self.__parents__]
        return set.union(set([self.name]), *ancestors)

    def get_parent(self, name):
        """
        Find the model called ``name`` in this model's inheritance tree.

        The model itself is checked first, then each parent branch is searched
        depth-first in declaration order.

        :param str name: The model name to look for
        :returns: the matching model (possibly ``self``)
        :raises ValueError: if no model in the tree has that name
        """
        if self.name == name:
            return self
        for parent in self.__parents__:
            try:
                return parent.get_parent(name)
            except ValueError:
                # Not in this branch: keep searching the remaining parents
                # instead of aborting the whole lookup.
                continue
        raise ValueError("Parent " + name + " not found")

    @property
    def __schema__(self):
        """
        The JSON schema for this model.

        Without parents this is the model's own ``_schema``. With parents it is
        an ``allOf`` combining a ``$ref`` to each parent definition followed by
        the model's own schema (Swagger composition).
        """
        schema = self._schema

        if self.__parents__:
            refs = [
                {"$ref": "#/definitions/{0}".format(parent.name)}
                for parent in self.__parents__
            ]

            return {"allOf": refs + [schema]}
        else:
            return schema

    @classmethod
    def inherit(cls, name, *parents):
        """
        Inherit this model (use the Swagger composition pattern aka. allOf)

        :param str name: The new model name
        :param parents: Parent models followed, as the LAST positional
            argument, by the new model's own fields (or schema). Every item but
            the last becomes a ``$ref`` parent of the new model.
        """
        model = cls(name, parents[-1])
        model.__parents__ = parents[:-1]
        return model

    def validate(self, data, resolver=None, format_checker=None):
        """
        Validate ``data`` against this model's schema with a Draft 4 validator.

        On failure, aborts with HTTP 400 and an ``errors`` mapping built by
        :meth:`format_error` from every validation error found in ``data``.
        """
        validator = Draft4Validator(
            self.__schema__, resolver=resolver, format_checker=format_checker
        )
        try:
            validator.validate(data)
        except ValidationError:
            abort(
                HTTPStatus.BAD_REQUEST,
                message="Input payload validation failed",
                errors=dict(self.format_error(e) for e in validator.iter_errors(data)),
            )

    def format_error(self, error):
        """
        Turn a jsonschema validation error into a ``(key, message)`` pair.

        ``key`` is the dotted path to the offending value. For ``required``
        errors the missing property name, parsed from the message, is appended.
        """
        path = list(error.path)
        if error.validator == "required":
            match = RE_REQUIRED.match(error.message)
            # The message format belongs to jsonschema; if it cannot be parsed,
            # report the error at the parent path instead of crashing with an
            # AttributeError on ``None.group``.
            if match:
                path.append(match.group("name"))
        key = ".".join(str(p) for p in path)
        return key, error.message

    def __unicode__(self):
        return "Model({name},{{{fields}}})".format(
            name=self.name, fields=",".join(self.keys())
        )

    __str__ = __unicode__
118
119
class RawModel(ModelBase):
    """
    A thin wrapper on a fields dict to store API doc metadata.
    Can also be used for response marshalling.

    :param str name: The model public name
    :param str mask: an optional default model mask
    :param bool strict: validation should raise error when there is param not provided in schema
    """

    # Mapping type used for generated schema ``properties`` and cloned fields.
    wrapper = dict

    def __init__(self, name, *args, **kwargs):
        self.__mask__ = kwargs.pop("mask", None)
        self.__strict__ = kwargs.pop("strict", False)
        if self.__mask__ and not isinstance(self.__mask__, Mask):
            self.__mask__ = Mask(self.__mask__)
        super(RawModel, self).__init__(name, *args, **kwargs)

        def instance_clone(name, *parents):
            return self.__class__.clone(name, self, *parents)

        # Shadow the ``clone`` classmethod on instances so that
        # ``model.clone(name, ...)`` starts from this model's fields.
        self.clone = instance_clone

    @property
    def _schema(self):
        """
        Build this model's own JSON object schema from its fields.

        Each field contributes its ``__schema__`` as a property; required and
        discriminator fields are collected. ``additionalProperties`` is set to
        ``False`` in strict mode. Entries left as ``None`` are presumably
        stripped by ``not_none`` (helper defined elsewhere).
        """
        properties = self.wrapper()
        required = set()
        discriminator = None
        for name, field in self.items():
            field = instance(field)
            properties[name] = field.__schema__
            if field.required:
                required.add(name)
            if getattr(field, "discriminator", False):
                discriminator = name

        definition = {
            "required": sorted(list(required)) or None,
            "properties": properties,
            "discriminator": discriminator,
            "x-mask": str(self.__mask__) if self.__mask__ else None,
            "type": "object",
        }

        if self.__strict__:
            definition["additionalProperties"] = False

        return not_none(definition)

    @cached_property
    def resolved(self):
        """
        Resolve real fields before submitting them to marshal
        """
        # Duplicate fields
        resolved = copy.deepcopy(self)

        # Recursively copy parent fields if necessary
        for parent in self.__parents__:
            resolved.update(parent.resolved)

        # Handle discriminator
        candidates = [f for f in resolved.values() if getattr(f, "discriminator", None)]
        # Ensure there is only one discriminator
        if len(candidates) > 1:
            raise ValueError("There can only be one discriminator by schema")
        # Ensure discriminator always output the model name
        elif len(candidates) == 1:
            candidates[0].default = self.name

        return resolved

    def extend(self, name, fields):
        """
        Extend this model (Duplicate all fields)

        :param str name: The new model name
        :param dict fields: The new model extra fields

        :deprecated: since 0.9. Use :meth:`clone` instead.
        """
        warnings.warn(
            "extend is is deprecated, use clone instead",
            DeprecationWarning,
            stacklevel=2,
        )
        if isinstance(fields, (list, tuple)):
            return self.clone(name, *fields)
        else:
            return self.clone(name, fields)

    @classmethod
    def clone(cls, name, *parents):
        """
        Clone these models (Duplicate all fields)

        It can be used from the class

        >>> model = Model.clone('Name', fields_1, fields_2)

        or from an Instanciated model (whose own fields are copied first)

        >>> new_model = model.clone('Name', fields_1, fields_2)

        :param str name: The new model name
        :param dict parents: The new model extra fields; later mappings
            override earlier ones on key collisions
        """
        fields = cls.wrapper()
        for parent in parents:
            fields.update(copy.deepcopy(parent))
        return cls(name, fields)

    def __deepcopy__(self, memo):
        """
        Deep-copy the fields while keeping name, mask and strict flag.

        The parent list is shared with the original, not copied.
        """
        obj = self.__class__(
            self.name,
            [],
            mask=self.__mask__,
            strict=self.__strict__,
        )
        # Register the copy before recursing into the field values, so a field
        # that refers back to this model (a reference cycle) resolves to the
        # copy instead of recursing forever.
        memo[id(self)] = obj
        obj.update((key, copy.deepcopy(value, memo)) for key, value in self.items())
        obj.__parents__ = self.__parents__
        return obj
242
243
class Model(RawModel, dict, MutableMapping):
    """
    Dict-backed model: maps field names to field definitions while carrying
    API doc metadata. Usable for response marshalling as well.

    :param str name: The model public name
    :param str mask: an optional default model mask
    :param bool strict: forbid properties not declared in the schema
    """
254
255
class OrderedModel(RawModel, OrderedDict, MutableMapping):
    """
    Model whose fields keep their insertion order (backed by ``OrderedDict``).
    Stores API doc metadata and can be used for response marshalling.

    :param str name: The model public name
    :param str mask: an optional default model mask
    """

    # Generated schema properties and cloned field sets use an ordered
    # mapping too, so field order is carried through.
    wrapper = OrderedDict
266
267
class SchemaModel(ModelBase):
    """
    Model described directly by a JSON schema instead of by fields.

    :param str name: The model public name
    :param dict schema: The JSON schema being documented; a missing or empty
        value becomes ``{}``
    """

    def __init__(self, name, schema=None):
        super(SchemaModel, self).__init__(name)
        # Any falsy schema (None, {}) is normalised to a fresh empty dict.
        self._schema = schema if schema else {}

    def __unicode__(self):
        template = "SchemaModel({name},{schema})"
        return template.format(schema=self._schema, name=self.name)

    __str__ = __unicode__