Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/flask_restx/model.py: 32%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import copy
2import re
3import warnings
5from collections import OrderedDict
7from collections.abc import MutableMapping
8from werkzeug.utils import cached_property
10from .mask import Mask
11from .errors import abort
13from jsonschema import Draft4Validator
14from jsonschema.validators import validator_for
15from jsonschema.exceptions import ValidationError
17from .utils import not_none
18from ._http import HTTPStatus
# Extracts the missing property name from a jsonschema "required" error message.
# jsonschema renders the name with repr(), which switches to double quotes when
# the name itself contains a single quote, so accept either quote style.
RE_REQUIRED = re.compile(
    r"u?(?P<quote>['\"])(?P<name>.*)(?P=quote) is a required property", re.I | re.U
)
def instance(cls):
    """Instantiate ``cls`` with no arguments if it is a class; otherwise return it as-is."""
    return cls() if isinstance(cls, type) else cls
class ModelBase(object):
    """
    Handles validation and swagger style inheritance for both subclasses.
    Subclass must define the `_schema` attribute.

    :param str name: The model public name
    """

    def __init__(self, name, *args, **kwargs):
        super(ModelBase, self).__init__(*args, **kwargs)
        self.__apidoc__ = {"name": name}
        self.name = name
        self.__parents__ = []

        # Shadow the ``inherit`` classmethod on the instance so that
        # ``model.inherit(name, ...)`` passes ``model`` as the first parent.
        def instance_inherit(name, *parents):
            return self.__class__.inherit(name, self, *parents)

        self.inherit = instance_inherit

    @property
    def ancestors(self):
        """
        Return the set of names of this model and all of its ancestors.
        """
        ancestors = [p.ancestors for p in self.__parents__]
        return set.union({self.name}, *ancestors)

    def get_parent(self, name):
        """
        Find the model named ``name`` among this model and its ancestors.

        Parents are searched depth-first, in declaration order.

        :param str name: The model name to look for
        :raises ValueError: if no model with that name is found
        """
        if self.name == name:
            return self
        for parent in self.__parents__:
            # A parent raises when the name is not in its own subtree; keep
            # searching the remaining parents instead of aborting. Returning
            # directly (rather than testing truthiness) also keeps empty,
            # falsy dict-backed models findable.
            try:
                return parent.get_parent(name)
            except ValueError:
                continue
        raise ValueError("Parent {0} not found".format(name))

    @property
    def __schema__(self):
        schema = self._schema

        if self.__parents__:
            refs = [
                {"$ref": "#/definitions/{0}".format(parent.name)}
                for parent in self.__parents__
            ]

            return {"allOf": refs + [schema]}
        else:
            return schema

    @classmethod
    def inherit(cls, name, *parents):
        """
        Inherit this model (use the Swagger composition pattern aka. allOf)

        :param str name: The new model name
        :param parents: The parent models, followed (last) by the new model's
            own extra fields
        """
        model = cls(name, parents[-1])
        model.__parents__ = parents[:-1]
        return model

    @staticmethod
    def _contains_ref(node):
        """Return True if the JSON-like structure ``node`` has a "$ref" key anywhere."""
        # Walk the structure instead of serializing it, so schemas holding
        # values json can't encode (e.g. possibly some field defaults) are safe.
        stack = [node]
        while stack:
            current = stack.pop()
            if isinstance(current, dict):
                if "$ref" in current:
                    return True
                stack.extend(current.values())
            elif isinstance(current, (list, tuple)):
                stack.extend(current)
        return False

    @staticmethod
    def _collect_definitions(registry):
        """Merge the ``definitions`` of every resource held by ``registry``."""
        definitions = {}
        for uri in registry:
            resource = registry[uri]
            # referencing.Registry yields Resource objects that wrap the JSON in
            # ``.contents``; a plain mapping of URI -> dict yields it directly.
            contents = getattr(resource, "contents", resource)
            if isinstance(contents, dict) and "definitions" in contents:
                definitions.update(contents["definitions"])
        return definitions

    def validate(self, data, resolver=None, format_checker=None):
        """
        Validate ``data`` against this model's schema.

        On failure, aborts with ``400 Bad Request`` and an ``errors`` mapping
        built by ``format_error`` from every validation error.

        :param data: The payload to validate
        :param resolver: For backward compatibility, either a legacy
            ``RefResolver`` (any object with a ``resolve`` attribute) or a
            registry mapping URIs to resources
        :param format_checker: An optional jsonschema format checker
        """
        schema = self.__schema__
        if resolver is not None and hasattr(resolver, "resolve"):
            # Old RefResolver: Draft4Validator still accepts it directly
            validator = Draft4Validator(
                schema, resolver=resolver, format_checker=format_checker
            )
        else:
            # New Registry or None
            if resolver is not None and self._contains_ref(schema):
                # Inline the registry's definitions so local
                # "#/definitions/..." references can be resolved.
                definitions = self._collect_definitions(resolver)
                if definitions:
                    schema = {
                        "$id": "http://localhost/schema.json",
                        "definitions": definitions,
                        **schema,
                    }

            # Swagger 2.0 models follow JSON Schema Draft 4 (as the legacy
            # branch above does); without a "$schema" key validator_for
            # would otherwise pick the latest draft.
            validator_class = validator_for(schema, default=Draft4Validator)
            if resolver is not None:
                validator = validator_class(
                    schema, registry=resolver, format_checker=format_checker
                )
            else:
                validator = validator_class(schema, format_checker=format_checker)

        try:
            validator.validate(data)
        except ValidationError:
            abort(
                HTTPStatus.BAD_REQUEST,
                message="Input payload validation failed",
                errors=dict(self.format_error(e) for e in validator.iter_errors(data)),
            )

    def format_error(self, error):
        """
        Convert a jsonschema validation error into a ``(key, message)`` pair.

        ``key`` is the dotted path to the offending value; for ``required``
        errors the missing property name is appended when it can be parsed.
        """
        path = list(error.path)
        if error.validator == "required":
            match = RE_REQUIRED.match(error.message)
            # Never let an unexpected message format crash error reporting
            if match:
                path.append(match.group("name"))
        key = ".".join(str(p) for p in path)
        return key, error.message

    def __unicode__(self):
        return "Model({name},{{{fields}}})".format(
            name=self.name, fields=",".join(self.keys())
        )

    __str__ = __unicode__
class RawModel(ModelBase):
    """
    A thin wrapper on a fields mapping to store API doc metadata.
    Can also be used for response marshalling.

    The ``wrapper`` class attribute is the mapping type used for generated
    schema properties and for clones (``dict`` here).

    :param str name: The model public name
    :param str mask: an optional default model mask
    :param bool strict: validation should raise error when there is param not provided in schema
    """

    wrapper = dict

    def __init__(self, name, *args, **kwargs):
        self.__mask__ = kwargs.pop("mask", None)
        self.__strict__ = kwargs.pop("strict", False)
        if self.__mask__ and not isinstance(self.__mask__, Mask):
            self.__mask__ = Mask(self.__mask__)
        super(RawModel, self).__init__(name, *args, **kwargs)

        # Shadow the ``clone`` classmethod on the instance so that
        # ``model.clone(name, ...)`` includes ``model``'s own fields first.
        def instance_clone(name, *parents):
            return self.__class__.clone(name, self, *parents)

        self.clone = instance_clone

    @property
    def _schema(self):
        """Build this model's own JSON schema (parents are not included)."""
        properties = self.wrapper()
        required = set()
        discriminator = None
        for name, field in self.items():
            field = instance(field)
            properties[name] = field.__schema__
            if field.required:
                required.add(name)
            if getattr(field, "discriminator", False):
                discriminator = name

        definition = {
            "required": sorted(required) or None,
            "properties": properties,
            "discriminator": discriminator,
            "x-mask": str(self.__mask__) if self.__mask__ else None,
            "type": "object",
        }

        if self.__strict__:
            definition["additionalProperties"] = False

        return not_none(definition)

    @cached_property
    def resolved(self):
        """
        Resolve real fields before submitting them to marshal
        """
        # Duplicate fields
        resolved = copy.deepcopy(self)

        # Recursively copy parent fields if necessary
        for parent in self.__parents__:
            resolved.update(parent.resolved)

        # Handle discriminator
        candidates = [
            (key, field)
            for key, field in resolved.items()
            if getattr(field, "discriminator", None)
        ]
        # Ensure there is only one discriminator
        if len(candidates) > 1:
            raise ValueError("There can only be one discriminator by schema")
        # Ensure discriminator always output the model name
        elif len(candidates) == 1:
            key, field = candidates[0]
            # Fields merged from parents are the very objects cached in
            # ``parent.resolved`` (and shared with sibling models); copy before
            # mutating so this model's name doesn't leak into them.
            field = copy.copy(field)
            field.default = self.name
            resolved[key] = field

        return resolved

    def extend(self, name, fields):
        """
        Extend this model (Duplicate all fields)

        :param str name: The new model name
        :param dict fields: The new model extra fields

        :deprecated: since 0.9. Use :meth:`clone` instead.
        """
        warnings.warn(
            "extend is deprecated, use clone instead",
            DeprecationWarning,
            stacklevel=2,
        )
        if isinstance(fields, (list, tuple)):
            return self.clone(name, *fields)
        else:
            return self.clone(name, fields)

    @classmethod
    def clone(cls, name, *parents):
        """
        Clone these models (Duplicate all fields)

        It can be used from the class

        >>> model = Model.clone("Name", fields_1, fields_2)

        or from an instantiated model

        >>> new_model = model.clone("Name", fields_1, fields_2)

        :param str name: The new model name
        :param dict parents: The new model extra fields
        """
        fields = cls.wrapper()
        for parent in parents:
            fields.update(copy.deepcopy(parent))
        return cls(name, fields)

    def __deepcopy__(self, memo):
        obj = self.__class__(
            self.name,
            [],
            mask=self.__mask__,
            strict=self.__strict__,
        )
        # Register the copy before copying fields so that reference cycles
        # leading back to this model resolve to the copy instead of recursing.
        memo[id(self)] = obj
        for key, value in self.items():
            obj[key] = copy.deepcopy(value, memo)
        obj.__parents__ = self.__parents__
        return obj
class Model(RawModel, dict, MutableMapping):
    """
    A ``dict``-backed model: a mapping of field names to fields that carries
    API documentation metadata and can be used for response marshalling.

    :param str name: The model public name
    :param str mask: an optional default model mask
    :param bool strict: reject payload keys not declared in the schema when validating
    """
class OrderedModel(RawModel, OrderedDict, MutableMapping):
    """
    An ``OrderedDict``-backed model that keeps its fields in insertion order,
    both when marshalling and in the generated schema ``properties``.

    :param str name: The model public name
    :param str mask: an optional default model mask
    """

    # Mapping type used for schema properties and for clones of this model
    wrapper = OrderedDict
class SchemaModel(ModelBase):
    """
    A model documented by a raw JSON schema instead of field declarations.

    :param str name: The model public name
    :param dict schema: The JSON schema being documented (an empty dict when omitted or falsy)
    """

    def __init__(self, name, schema=None):
        super(SchemaModel, self).__init__(name)
        self._schema = schema if schema else {}

    def __str__(self):
        return "SchemaModel({name},{schema})".format(
            name=self.name, schema=self._schema
        )

    __unicode__ = __str__