Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/flask_restx/model.py: 32%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

149 statements  

1import copy 

2import re 

3import warnings 

4 

5from collections import OrderedDict 

6 

7from collections.abc import MutableMapping 

8from werkzeug.utils import cached_property 

9 

10from .mask import Mask 

11from .errors import abort 

12 

13from jsonschema import Draft4Validator 

14from jsonschema.validators import validator_for 

15from jsonschema.exceptions import ValidationError 

16 

17from .utils import not_none 

18from ._http import HTTPStatus 

19 

20 

# Extracts the quoted property name from a "'<name>' is a required property"
# message; a leading ``u`` before the opening quote is tolerated.
RE_REQUIRED = re.compile(
    r"u?\'(?P<name>.*)\' is a required property",
    re.IGNORECASE | re.UNICODE,
)

22 

23 

def instance(cls):
    """
    Return an instance for ``cls``.

    When ``cls`` is a class it is called with no arguments and the new
    object is returned; anything else is handed back unchanged.
    """
    return cls() if isinstance(cls, type) else cls

28 

29 

30class ModelBase(object): 

31 """ 

32 Handles validation and swagger style inheritance for both subclasses. 

33 Subclass must define `schema` attribute. 

34 

35 :param str name: The model public name 

36 """ 

37 

38 def __init__(self, name, *args, **kwargs): 

39 super(ModelBase, self).__init__(*args, **kwargs) 

40 self.__apidoc__ = {"name": name} 

41 self.name = name 

42 self.__parents__ = [] 

43 

44 def instance_inherit(name, *parents): 

45 return self.__class__.inherit(name, self, *parents) 

46 

47 self.inherit = instance_inherit 

48 

49 @property 

50 def ancestors(self): 

51 """ 

52 Return the ancestors tree 

53 """ 

54 ancestors = [p.ancestors for p in self.__parents__] 

55 return set.union(set([self.name]), *ancestors) 

56 

57 def get_parent(self, name): 

58 if self.name == name: 

59 return self 

60 else: 

61 for parent in self.__parents__: 

62 found = parent.get_parent(name) 

63 if found: 

64 return found 

65 raise ValueError("Parent " + name + " not found") 

66 

67 @property 

68 def __schema__(self): 

69 schema = self._schema 

70 

71 if self.__parents__: 

72 refs = [ 

73 {"$ref": "#/definitions/{0}".format(parent.name)} 

74 for parent in self.__parents__ 

75 ] 

76 

77 return {"allOf": refs + [schema]} 

78 else: 

79 return schema 

80 

81 @classmethod 

82 def inherit(cls, name, *parents): 

83 """ 

84 Inherit this model (use the Swagger composition pattern aka. allOf) 

85 :param str name: The new model name 

86 :param dict fields: The new model extra fields 

87 """ 

88 model = cls(name, parents[-1]) 

89 model.__parents__ = parents[:-1] 

90 return model 

91 

92 def validate(self, data, resolver=None, format_checker=None): 

93 # For backward compatibility, resolver can be either a RefResolver or a Registry 

94 if resolver is not None and hasattr(resolver, "resolve"): 

95 # Old RefResolver - convert to registry 

96 registry = None 

97 validator = Draft4Validator( 

98 self.__schema__, resolver=resolver, format_checker=format_checker 

99 ) 

100 else: 

101 # New Registry or None 

102 # If we have a registry, we need to create a schema that includes definitions 

103 schema_to_validate = self.__schema__ 

104 if resolver is not None: 

105 # Check if the schema has $ref that need to be resolved 

106 import json 

107 

108 schema_str = json.dumps(self.__schema__) 

109 if '"$ref"' in schema_str: 

110 # Create a schema with inline definitions from the registry 

111 definitions = {} 

112 for uri in resolver: 

113 resource = resolver[uri] 

114 if isinstance(resource, dict) and "definitions" in resource: 

115 definitions.update(resource["definitions"]) 

116 

117 if definitions: 

118 # Create a new schema that includes the definitions 

119 schema_to_validate = { 

120 "$id": "http://localhost/schema.json", 

121 "definitions": definitions, 

122 **self.__schema__, 

123 } 

124 

125 ValidatorClass = validator_for(schema_to_validate) 

126 if resolver is not None: 

127 validator = ValidatorClass( 

128 schema_to_validate, registry=resolver, format_checker=format_checker 

129 ) 

130 else: 

131 validator = ValidatorClass( 

132 schema_to_validate, format_checker=format_checker 

133 ) 

134 

135 try: 

136 validator.validate(data) 

137 except ValidationError: 

138 abort( 

139 HTTPStatus.BAD_REQUEST, 

140 message="Input payload validation failed", 

141 errors=dict(self.format_error(e) for e in validator.iter_errors(data)), 

142 ) 

143 

144 def format_error(self, error): 

145 path = list(error.path) 

146 if error.validator == "required": 

147 name = RE_REQUIRED.match(error.message).group("name") 

148 path.append(name) 

149 key = ".".join(str(p) for p in path) 

150 return key, error.message 

151 

152 def __unicode__(self): 

153 return "Model({name},{{{fields}}})".format( 

154 name=self.name, fields=",".join(self.keys()) 

155 ) 

156 

157 __str__ = __unicode__ 

158 

159 

class RawModel(ModelBase):
    """
    A thin wrapper on a fields mapping to store API doc metadata.
    Can also be used for response marshalling.

    This class uses mapping methods on itself (``items``, ``values``,
    ``update``) without defining them, so concrete subclasses are expected
    to mix in a mapping type.

    :param str name: The model public name
    :param str mask: an optional default model mask
    :param bool strict: validation should raise error when there is param not provided in schema
    """

    # Mapping type used to build the ``properties`` section of the schema
    # and to accumulate fields in :meth:`clone`.
    wrapper = dict

    def __init__(self, name, *args, **kwargs):
        # Pop our own options first so the mapping constructor further up
        # the chain does not receive them as items.
        self.__mask__ = kwargs.pop("mask", None)
        self.__strict__ = kwargs.pop("strict", False)
        if self.__mask__ and not isinstance(self.__mask__, Mask):
            self.__mask__ = Mask(self.__mask__)
        super(RawModel, self).__init__(name, *args, **kwargs)

        # Shadow the ``clone`` classmethod on the instance so that calling
        # ``model.clone(...)`` automatically includes ``model``'s fields.
        def instance_clone(name, *parents):
            return self.__class__.clone(name, self, *parents)

        self.clone = instance_clone

    @property
    def _schema(self):
        """
        Build the object schema from the model's fields.

        Collects each field's ``__schema__`` under ``properties``, the names
        of fields flagged ``required``, the name of the field flagged as
        discriminator (if any) and the mask; ``additionalProperties`` is set
        to ``False`` for strict models.
        """
        properties = self.wrapper()
        required = set()
        discriminator = None
        for name, field in self.items():
            # Fields may be declared as classes; work on an instance.
            field = instance(field)
            properties[name] = field.__schema__
            if field.required:
                required.add(name)
            if getattr(field, "discriminator", False):
                discriminator = name

        definition = {
            "required": sorted(required) or None,
            "properties": properties,
            "discriminator": discriminator,
            "x-mask": str(self.__mask__) if self.__mask__ else None,
            "type": "object",
        }

        if self.__strict__:
            definition["additionalProperties"] = False

        # presumably drops the entries left as None above - see utils.not_none
        return not_none(definition)

    @cached_property
    def resolved(self):
        """
        Resolve real fields before submitting them to marshal
        """
        # Duplicate fields
        resolved = copy.deepcopy(self)

        # Recursively copy parent fields if necessary
        for parent in self.__parents__:
            resolved.update(parent.resolved)

        # Handle discriminator
        candidates = [f for f in resolved.values() if getattr(f, "discriminator", None)]
        # Ensure there is only one discriminator
        if len(candidates) > 1:
            raise ValueError("There can only be one discriminator by schema")
        # Ensure discriminator always output the model name
        elif len(candidates) == 1:
            candidates[0].default = self.name

        return resolved

    def extend(self, name, fields):
        """
        Extend this model (Duplicate all fields)

        :param str name: The new model name
        :param dict fields: The new model extra fields (a list or tuple of
            mappings is also accepted and applied in order)

        :deprecated: since 0.9. Use :meth:`clone` instead.
        """
        warnings.warn(
            "extend is deprecated, use clone instead",
            DeprecationWarning,
            stacklevel=2,
        )
        if isinstance(fields, (list, tuple)):
            return self.clone(name, *fields)
        else:
            return self.clone(name, fields)

    @classmethod
    def clone(cls, name, *parents):
        """
        Clone these models (Duplicate all fields)

        It can be used from the class

        >>> model = Model.clone("NewModel", fields_1, fields_2)

        or from an instantiated model

        >>> new_model = model.clone("NewModel", fields_1, fields_2)

        :param str name: The new model name
        :param dict parents: The mappings whose fields are deep-copied into
            the new model, in order (later ones override earlier ones)
        """
        fields = cls.wrapper()
        for parent in parents:
            fields.update(copy.deepcopy(parent))
        return cls(name, fields)

    def __deepcopy__(self, memo):
        """
        Copy the model: field values are deep-copied, while the mask, the
        strict flag and the parents list are shared with the original.
        """
        obj = self.__class__(
            self.name,
            [(key, copy.deepcopy(value, memo)) for key, value in self.items()],
            mask=self.__mask__,
            strict=self.__strict__,
        )
        obj.__parents__ = self.__parents__
        return obj

282 

283 

class Model(RawModel, dict, MutableMapping):
    """
    ``RawModel`` backed by a plain ``dict``: the fields are stored as the
    dictionary's own items, together with the API doc metadata.
    Can also be used for response marshalling.

    :param str name: The model public name
    :param str mask: an optional default model mask
    """

294 

295 

class OrderedModel(RawModel, OrderedDict, MutableMapping):
    """
    ``RawModel`` backed by an ``OrderedDict``: the fields are stored as the
    ordered dictionary's own items, together with the API doc metadata.
    Can also be used for response marshalling.

    :param str name: The model public name
    :param str mask: an optional default model mask
    """

    # Build schema properties and cloned field sets with the same
    # ordered mapping type.
    wrapper = OrderedDict

306 

307 

class SchemaModel(ModelBase):
    """
    Model described directly by a JSON schema instead of a set of fields;
    stores the API doc metadata for that schema.

    :param str name: The model public name
    :param dict schema: The json schema we are documenting
    """

    def __init__(self, name, schema=None):
        super(SchemaModel, self).__init__(name)
        # A missing or empty schema is stored as a fresh empty dict.
        self._schema = schema if schema else {}

    def __unicode__(self):
        return f"SchemaModel({self.name},{self._schema})"

    __str__ = __unicode__

325 __str__ = __unicode__