1from collections import OrderedDict
2from decimal import Decimal
3import re
4
5from .exceptions import JsonSchemaValueException, JsonSchemaValuesException, JsonSchemaDefinitionException
6from .indent import indent
7from .ref_resolver import RefResolver
8
9
def enforce_list(variable):
    """
    Return ``variable`` itself when it is already a ``list``; otherwise
    return a new one-element list containing it.

    Only real ``list`` instances (and subclasses) are passed through, any
    other value - including tuples and strings - gets wrapped.
    """
    return variable if isinstance(variable, list) else [variable]
14
15
16# pylint: disable=too-many-instance-attributes,too-many-public-methods
class CodeGenerator:
    """
    This class is not supposed to be used directly. Anything
    inside of this class can be changed without notice.

    This class generates code of validation function from JSON
    schema object as string. Example:

    .. code-block:: python

        CodeGenerator(json_schema_definition).func_code
    """

    INDENT = 4  # spaces

    def __init__(self, definition, resolver=None, detailed_exceptions=True, fast_fail=True):
        """
        :param definition: root JSON schema definition to generate code for.
        :param resolver: reference resolver; when ``None`` one is created
            with ``RefResolver.from_schema(definition, store={})``.
        :param detailed_exceptions: when true, generated exceptions carry
            ``value``, ``name``, ``definition`` and ``rule`` (see ``exc``).
        :param fast_fail: when true, generated code raises on the first
            problem; otherwise problems are appended to an ``errors`` list
            and raised together as ``JsonSchemaValuesException``.
        """
        self._code = []
        self._compile_regexps = {}
        self._custom_formats = {}
        self._detailed_exceptions = detailed_exceptions
        self._fast_fail = fast_fail

        # Any extra library should be here to be imported only once.
        # Lines are imports to be printed in the file and objects
        # key-value pair to pass to compile function directly.
        self._extra_imports_lines = [
            "from decimal import Decimal",
        ]
        self._extra_imports_objects = {
            "Decimal": Decimal,
        }

        self._variables = set()
        self._indent = 0
        self._indent_last_line = None
        self._variable = None
        self._variable_name = None
        self._root_definition = definition
        self._definition = None

        # map schema URIs to validation function names for functions
        # that are not yet generated, but need to be generated
        self._needed_validation_functions = {}
        # validation function names that are already done
        self._validation_functions_done = set()

        if resolver is None:
            resolver = RefResolver.from_schema(definition, store={})
        self._resolver = resolver

        # add main function to `self._needed_validation_functions`
        self._needed_validation_functions[self._resolver.get_uri()] = self._resolver.get_scope_name()

        # Maps a JSON schema keyword to the zero-argument callable that
        # generates code for it; iterated in insertion order by
        # `run_generate_functions`. Left empty here.
        self._json_keywords_to_function = OrderedDict()

    @property
    def func_code(self):
        """
        Returns generated code of whole validation function as string.
        """
        self._generate_func_code()

        return '\n'.join(self._code)

    @property
    def global_state(self):
        """
        Returns global variables for generating function from ``func_code``. Includes
        compiled regular expressions and imports, so it does not have to do it every
        time when validation function is called.
        """
        self._generate_func_code()

        return dict(
            **self._extra_imports_objects,
            REGEX_PATTERNS=self._compile_regexps,
            re=re,
            JsonSchemaValueException=JsonSchemaValueException,
            JsonSchemaValuesException=JsonSchemaValuesException,
        )

    @property
    def global_state_code(self):
        """
        Returns global variables for generating function from ``func_code`` as code.
        Includes compiled regular expressions and imports.

        ``import re`` and the ``REGEX_PATTERNS`` assignment are emitted only
        when at least one regular expression was collected.
        """
        self._generate_func_code()

        has_regexps = bool(self._compile_regexps)

        lines = list(self._extra_imports_lines)
        if has_regexps:
            lines.append('import re')
        lines += [
            'from fastjsonschema import JsonSchemaValueException, JsonSchemaValuesException',
            '',
            '',
        ]
        if has_regexps:
            lines += [
                'REGEX_PATTERNS = ' + serialize_regexes(self._compile_regexps),
                '',
            ]
        return '\n'.join(lines)

    def _generate_func_code(self):
        """Generate the code once; later calls are no-ops while ``_code`` is non-empty."""
        if not self._code:
            self.generate_func_code()

    def generate_func_code(self):
        """
        Creates base code of validation function and calls helper
        for creating code by definition.
        """
        self.l('NoneType = type(None)')
        # Generate parts that are referenced and not yet generated
        while self._needed_validation_functions:
            # During generation of validation function, could be needed to generate
            # new one that is added again to `_needed_validation_functions`.
            # Therefore usage of while instead of for loop.
            uri, name = self._needed_validation_functions.popitem()
            self.generate_validation_function(uri, name)

    def generate_validation_function(self, uri, name):
        """
        Generate validation function for given uri with given name
        """
        self._validation_functions_done.add(uri)
        self.l('')
        with self._resolver.resolving(uri) as definition:
            with self.l('def {}(data, custom_formats={{}}, name_prefix=None):', name):
                if not self._fast_fail:
                    # Collect-all mode: errors are gathered and raised together.
                    self.l('errors = []')
                self.generate_func_code_block(definition, 'data', 'data', clear_variables=True)
                if not self._fast_fail:
                    self.l('if errors: raise JsonSchemaValuesException(errors)')
                self.l('return data')

    def generate_func_code_block(self, definition, variable, variable_name, clear_variables=False):
        """
        Creates validation rules for current definition.

        While the rules are generated ``definition``, ``variable`` and
        ``variable_name`` become the current context; with ``clear_variables``
        the set of already created helper variables starts empty as well.
        The previous context is restored afterwards, also when generation
        raises, so the generator is not left pointing at a nested definition.

        Returns the number of validation rules generated as code.
        """
        backup = self._definition, self._variable, self._variable_name
        backup_variables = self._variables
        self._definition, self._variable, self._variable_name = definition, variable, variable_name
        if clear_variables:
            self._variables = set()

        try:
            return self._generate_func_code_block(definition)
        finally:
            self._definition, self._variable, self._variable_name = backup
            if clear_variables:
                self._variables = backup_variables

    def _generate_func_code_block(self, definition):
        """
        Dispatch generation for ``definition`` and return the number of rules.

        Raises ``JsonSchemaDefinitionException`` when ``definition`` is not a dict.
        """
        if not isinstance(definition, dict):
            raise JsonSchemaDefinitionException("definition must be an object")
        if '$ref' in definition:
            # needed because ref overrides any sibling keywords
            self.generate_ref()
            # `generate_ref` emits exactly one call line and returns nothing;
            # report it as one rule so the result is always an int.
            return 1
        return self.run_generate_functions(definition)

    def run_generate_functions(self, definition):
        """Returns the number of generate functions that were executed."""
        count = 0
        for key, func in self._json_keywords_to_function.items():
            if key in definition:
                func()
                count += 1
        return count

    def generate_ref(self):
        """
        Ref can be link to remote or local definition.

        .. code-block:: python

            {'$ref': 'http://json-schema.org/draft-04/schema#'}
            {
                'properties': {
                    'foo': {'type': 'integer'},
                    'bar': {'$ref': '#/properties/foo'}
                }
            }
        """
        with self._resolver.in_scope(self._definition['$ref']):
            name = self._resolver.get_scope_name()
            uri = self._resolver.get_uri()
            if uri not in self._validation_functions_done:
                # Schedule the referenced schema's function for generation.
                self._needed_validation_functions[uri] = name
            # call validation function
            assert self._variable_name.startswith("data")
            # Drop the leading "data" so the callee's name_prefix can replace it.
            path = self._variable_name[4:]
            name_arg = '(name_prefix or "data") + "{}"'.format(path)
            if '{' in name_arg:
                # The path still holds placeholders to be filled by the
                # generated code from its local variables.
                name_arg = name_arg + '.format(**locals())'
            self.l('{}({variable}, custom_formats, {name_arg})', name, name_arg=name_arg)

    # pylint: disable=invalid-name
    @indent
    def l(self, line, *args, **kwds):
        """
        Short-cut of line. Used for inserting line. It's formatted with parameters
        ``variable``, ``variable_name`` (as ``name`` for short-cut), all keys from
        current JSON schema ``definition`` and also passed arguments in ``args``
        and named ``kwds``.

        .. code-block:: python

            self.l('if {variable} not in {enum}: raise JsonSchemaValueException("Wrong!")')

        When you want to indent block, use it as context manager. For example:

        .. code-block:: python

            with self.l('if {variable} not in {enum}:'):
                self.l('raise JsonSchemaValueException("Wrong!")')
        """
        spaces = ' ' * self.INDENT * self._indent

        name = self._variable_name
        if name:
            # Add name_prefix to the name when it is being outputted.
            assert name.startswith('data')
            name = '" + (name_prefix or "data") + "' + name[4:]
            if '{' in name:
                name = name + '".format(**locals()) + "'

        context = dict(
            self._definition if self._definition and self._definition is not True else {},
            variable=self._variable,
            name=name,
            **kwds
        )
        line = line.format(*args, **context)
        # Keep every generated statement on a single physical line.
        line = line.replace('\n', '\\n').replace('\r', '\\r')
        self._code.append(spaces + line)
        return line

    def e(self, string):
        """
        Short-cut of escape. Used for inserting user values into a string message.
        Only double quotes are escaped.

        .. code-block:: python

            self.l('raise JsonSchemaValueException("Variable: {}")', self.e(variable))
        """
        return str(string).replace('"', '\\"')

    def exc(self, msg, *args, append_to_msg=None, rule=None):
        """
        Short-cut for creating raising exception in the code.

        ``msg`` is a message template formatted by ``l`` with ``args``.
        ``append_to_msg`` is source code of an expression that is concatenated
        to the message in the generated code. ``rule`` is the schema keyword
        being checked; its value in the current definition (escaped by ``e``)
        is available to the template as ``{definition_rule}``.

        With ``fast_fail`` the generated line raises, otherwise it appends the
        exception to ``errors``. Without detailed exceptions only the message
        is passed to the exception.
        """
        if not self._detailed_exceptions:
            if self._fast_fail:
                self.l('raise JsonSchemaValueException("'+msg+'")', *args)
            else:
                self.l('errors.append(JsonSchemaValueException("'+msg+'"))', *args)
            return

        arg = '"'+msg+'"'
        if append_to_msg:
            arg += ' + (' + append_to_msg + ')'
        # pylint: disable=line-too-long
        msg = (
            'raise JsonSchemaValueException('+arg+', value={variable}, name="{name}", definition={definition}, rule={rule})'
            if self._fast_fail else
            'errors.append(JsonSchemaValueException('+arg+', value={variable}, name="{name}", definition={definition}, rule={rule}))'
        )
        definition = self._expand_refs(self._definition)
        definition_rule = self.e(definition.get(rule) if isinstance(definition, dict) else None)
        self.l(msg, *args, definition=repr(definition), rule=repr(rule), definition_rule=definition_rule)

    def _expand_refs(self, definition):
        """
        Return ``definition`` with ``$ref`` objects replaced by what the
        resolver yields for them.

        Lists and dicts are walked recursively; a dict with a string ``$ref``
        is replaced as a whole by the resolved schema (which is not expanded
        further); any other value is returned unchanged.
        """
        if isinstance(definition, list):
            return [self._expand_refs(v) for v in definition]
        if not isinstance(definition, dict):
            return definition
        if "$ref" in definition and isinstance(definition["$ref"], str):
            with self._resolver.resolving(definition["$ref"]) as schema:
                return schema
        return {k: self._expand_refs(v) for k, v in definition.items()}

    def _create_variable_once(self, suffix, code_line):
        """
        Emit ``code_line`` unless the helper variable ``{variable}{suffix}``
        was already created in the current variable scope.
        """
        variable_name = '{}{}'.format(self._variable, suffix)
        if variable_name in self._variables:
            return
        self._variables.add(variable_name)
        self.l(code_line)

    def create_variable_with_length(self):
        """
        Append code for creating variable with length of that variable
        (for example length of list or dictionary) with name ``{variable}_len``.
        It can be called several times and always it's done only when that variable
        still does not exist.
        """
        self._create_variable_once('_len', '{variable}_len = len({variable})')

    def create_variable_keys(self):
        """
        Append code for creating variable with keys of that variable (dictionary)
        with a name ``{variable}_keys``. Similar to `create_variable_with_length`.
        """
        self._create_variable_once('_keys', '{variable}_keys = set({variable}.keys())')

    def create_variable_is_list(self):
        """
        Append code for creating variable with bool if it's instance of list
        or tuple with a name ``{variable}_is_list``. Similar to
        `create_variable_with_length`.
        """
        self._create_variable_once('_is_list', '{variable}_is_list = isinstance({variable}, (list, tuple))')

    def create_variable_is_dict(self):
        """
        Append code for creating variable with bool if it's instance of dict
        with a name ``{variable}_is_dict``. Similar to `create_variable_with_length`.
        """
        self._create_variable_once('_is_dict', '{variable}_is_dict = isinstance({variable}, dict)')
350
351
def serialize_regexes(patterns_dict):
    """
    Render a mapping of keys to compiled regular expressions as the source
    code of a dict literal, one ``key: re.compile(...)`` entry per line.
    """
    # The literal is assembled by hand because `pprint.pformat` was
    # causing errors, especially with big regexes.
    entries = []
    for key, compiled in patterns_dict.items():
        entries.append(repr(key) + ": " + repr_regex(compiled))
    joined_entries = ",\n ".join(entries)
    return '{\n ' + joined_entries + "\n}"
360
361
def repr_regex(regex):
    """
    Return source code of a ``re.compile(...)`` call for the compiled
    pattern ``regex``.

    Of the flags set on ``regex`` only those named below are written out,
    joined with ``|`` in this fixed order; when none of them is set the
    flags argument is left out entirely.
    """
    flag_names = ("A", "I", "DEBUG", "L", "M", "S", "X")
    active = []
    for flag_name in flag_names:
        if regex.flags & getattr(re, flag_name):
            active.append("re." + flag_name)
    flags_argument = ""
    if active:
        flags_argument = ", " + " | ".join(active)
    return "re.compile({!r}{})".format(regex.pattern, flags_argument)