1"""
Common validator wrappers that provide a uniform interface over different JSON
schema validation libraries.
4"""
5
6# Copyright (c) Jupyter Development Team.
7# Distributed under the terms of the Modified BSD License.
8from __future__ import annotations
9
10import os
11
12import fastjsonschema
13import jsonschema
14from fastjsonschema import JsonSchemaException as _JsonSchemaException
15from jsonschema import Draft4Validator as _JsonSchemaValidator
16from jsonschema.exceptions import ErrorTree, ValidationError
17
# Public API of this module: the re-exported ``ValidationError``, the two
# validator wrapper classes, the selector function, and the list of valid
# validator names.
__all__ = [
    "ValidationError",
    "JsonSchemaValidator",
    "FastJsonSchemaValidator",
    "get_current_validator",
    "VALIDATORS",
]
25
26
class JsonSchemaValidator:
    """Validate data against a JSON schema using jsonschema's Draft 4 validator."""

    name = "jsonschema"

    def __init__(self, schema):
        """Store *schema* and build a Draft 4 jsonschema validator for it.

        ``_default_validator`` and ``_validator`` start out as the same object;
        subclasses may rebind ``_validator`` while keeping the jsonschema
        validator available as the default.
        """
        self._schema = schema
        draft4_validator = _JsonSchemaValidator(schema)
        self._default_validator = self._validator = draft4_validator

    def validate(self, data):
        """Validate *data*, letting jsonschema raise on the first failure."""
        checker = self._default_validator
        checker.validate(data)

    def iter_errors(self, data, schema=None):
        """Return an iterable of validation errors found in *data*.

        When *schema* is given, *data* is checked against that schema instead
        of the one passed to ``__init__``.
        """
        base = self._default_validator
        if schema is not None:
            # Newer jsonschema releases expose ``evolve`` to swap the schema;
            # older ones accept the schema as a second ``iter_errors`` argument.
            if hasattr(base, "evolve"):
                return base.evolve(schema=schema).iter_errors(data)
            return base.iter_errors(data, schema)
        return base.iter_errors(data)

    def error_tree(self, errors):
        """Wrap *errors* in a jsonschema ``ErrorTree`` for introspection."""
        tree = ErrorTree(errors=errors)
        return tree
53
54
class FastJsonSchemaValidator(JsonSchemaValidator):
    """A schema validator backed by a function compiled with fastjsonschema.

    The base class still builds a jsonschema validator; it is used whenever
    ``iter_errors`` is called with an explicit schema.
    """

    name = "fastjsonschema"

    def __init__(self, schema):
        """Build the base jsonschema validator, then compile *schema* with fastjsonschema."""
        super().__init__(schema)
        self._validator = fastjsonschema.compile(schema)

    @staticmethod
    def _to_validation_error(error):
        """Convert a fastjsonschema exception into a jsonschema ``ValidationError``.

        Only some fastjsonschema exception subclasses define ``path``, so fall
        back to an empty path rather than raising AttributeError while
        reporting a validation failure.
        """
        # NOTE(review): fastjsonschema's ``path`` appears to locate the failing
        # value in the instance rather than in the schema; it is stored as
        # ``schema_path`` to preserve existing caller-visible behavior.
        return ValidationError(str(error), schema_path=getattr(error, "path", ()))

    def validate(self, data):
        """Validate *data*.

        Raises:
            ValidationError: if the compiled validator rejects *data*; the
                original fastjsonschema exception is chained as the cause.
        """
        try:
            self._validator(data)
        except _JsonSchemaException as error:
            raise self._to_validation_error(error) from error

    def iter_errors(self, data, schema=None):
        """Return a list of validation errors found in *data*.

        With an explicit *schema*, validation is delegated to the jsonschema
        implementation in the base class. Otherwise the compiled validator
        raises on failure, so the returned list holds at most one error.
        """
        if schema is not None:
            return super().iter_errors(data, schema)

        try:
            self._validator(data)
        except _JsonSchemaException as error:
            return [self._to_validation_error(error)]
        return []

    def error_tree(self, errors):
        """Create an error tree for the errors (not supported for fastjsonschema).

        Raises:
            NotImplementedError: always.
        """
        # fastjsonschema's exceptions don't contain the same information that the jsonschema ValidationErrors
        # do. This method is primarily used for introspecting metadata schema failures so that we can strip
        # them if asked to do so in `nbformat.validate`.
        # Another way forward for compatibility: we could distill both validator errors into a custom collection
        # for this data. Since implementation details of ValidationError is used elsewhere, we would probably
        # just use this data for schema introspection.
        msg = "JSON schema error introspection not enabled for fastjsonschema"
        raise NotImplementedError(msg)
96
97
# Registry of supported backends as (name, backing module, wrapper class)
# triples; ``_validator_for_name`` searches this table.
_VALIDATOR_MAP = [
    ("fastjsonschema", fastjsonschema, FastJsonSchemaValidator),
    ("jsonschema", jsonschema, JsonSchemaValidator),
]
# Validator names accepted by ``_validator_for_name``.
VALIDATORS = [backend_name for backend_name, _module, _cls in _VALIDATOR_MAP]
103
104
def _validator_for_name(validator_name):
    """Return the validator wrapper class registered under *validator_name*.

    Raises:
        ValueError: if *validator_name* is not listed in ``VALIDATORS``, or if
            no ``_VALIDATOR_MAP`` entry with a usable module matches it.
    """
    if validator_name in VALIDATORS:
        not_found = object()
        # First registry entry whose module is usable and whose name matches.
        match = next(
            (
                validator_cls
                for name, module, validator_cls in _VALIDATOR_MAP
                if module and validator_name == name
            ),
            not_found,
        )
        if match is not not_found:
            return match
        msg = f"Missing validator for {validator_name!r}"
        raise ValueError(msg)

    msg = f"Invalid validator '{validator_name}' value!\nValid values are: {VALIDATORS}"
    raise ValueError(msg)
116
117
def get_current_validator():
    """Return the validator class selected by the ``NBFORMAT_VALIDATOR`` environment variable.

    When the variable is unset, ``"fastjsonschema"`` is used. Unknown names
    raise ``ValueError`` from ``_validator_for_name``.
    """
    return _validator_for_name(os.getenv("NBFORMAT_VALIDATOR", "fastjsonschema"))