1"""Pluggable schema validator for pydantic."""
2
3from __future__ import annotations
4
import functools
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, get_args

from pydantic_core import CoreConfig, CoreSchema, SchemaValidator, ValidationError
from typing_extensions import ParamSpec
11
12if TYPE_CHECKING:
13 from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath
14
15
16P = ParamSpec('P')
17R = TypeVar('R')
18Event = Literal['on_validate_python', 'on_validate_json', 'on_validate_strings']
19events: list[Event] = list(Event.__args__) # type: ignore
20
21
def create_schema_validator(
    schema: CoreSchema,
    schema_type: Any,
    schema_type_module: str,
    schema_type_name: str,
    schema_kind: SchemaKind,
    config: CoreConfig | None = None,
    plugin_settings: dict[str, Any] | None = None,
    _use_prebuilt: bool = True,
) -> SchemaValidator | PluggableSchemaValidator:
    """Build the validator for `schema`, routing it through plugins when any are installed.

    Returns:
        A `PluggableSchemaValidator` when `get_plugins()` yields any plugins, otherwise a plain
        `SchemaValidator` built from `schema` and `config`.
    """
    from . import SchemaTypePath
    from ._loader import get_plugins

    installed_plugins = get_plugins()
    if not installed_plugins:
        # Fast path: no plugins, so no wrapping is needed.
        return SchemaValidator(schema, config, _use_prebuilt=_use_prebuilt)

    type_path = SchemaTypePath(schema_type_module, schema_type_name)
    return PluggableSchemaValidator(
        schema,
        schema_type,
        type_path,
        schema_kind,
        config,
        installed_plugins,
        plugin_settings or {},
        _use_prebuilt=_use_prebuilt,
    )
54
55
class PluggableSchemaValidator:
    """A `SchemaValidator` wrapper that lets plugins observe validation calls.

    `validate_python`, `validate_json` and `validate_strings` are wrapped with the event handlers
    that each plugin returns from `new_schema_validator`. Any other attribute is looked up on the
    wrapped `SchemaValidator`.
    """

    __slots__ = '_schema_validator', 'validate_json', 'validate_python', 'validate_strings'

    def __init__(
        self,
        schema: CoreSchema,
        schema_type: Any,
        schema_type_path: SchemaTypePath,
        schema_kind: SchemaKind,
        config: CoreConfig | None,
        plugins: Iterable[PydanticPluginProtocol],
        plugin_settings: dict[str, Any],
        _use_prebuilt: bool = True,
    ) -> None:
        """Create the underlying validator and collect every plugin's event handlers.

        Args:
            schema: The core schema to validate against.
            schema_type: The type the schema was built for; passed through to plugins.
            schema_type_path: Module and name of `schema_type`; passed through to plugins.
            schema_kind: The kind of schema; passed through to plugins.
            config: Core config for the validator, or `None`.
            plugins: Plugins whose `new_schema_validator` is asked for handlers.
            plugin_settings: Settings passed through to each plugin.
            _use_prebuilt: Forwarded to `SchemaValidator`.

        Raises:
            TypeError: If calling a plugin's `new_schema_validator` or unpacking its result raises
                `TypeError`. The new message names the offending plugin and chains the original error.
        """
        self._schema_validator = SchemaValidator(schema, config, _use_prebuilt=_use_prebuilt)

        python_event_handlers: list[BaseValidateHandlerProtocol] = []
        json_event_handlers: list[BaseValidateHandlerProtocol] = []
        strings_event_handlers: list[BaseValidateHandlerProtocol] = []
        for plugin in plugins:
            try:
                # Each plugin returns one optional handler per validation mode: python, json, strings.
                p, j, s = plugin.new_schema_validator(
                    schema, schema_type, schema_type_path, schema_kind, config, plugin_settings
                )
            except TypeError as e:  # pragma: no cover
                raise TypeError(f'Error using plugin `{plugin.__module__}:{plugin.__class__.__name__}`: {e}') from e
            # A `None` handler means the plugin does not observe that validation mode.
            if p is not None:
                python_event_handlers.append(p)
            if j is not None:
                json_event_handlers.append(j)
            if s is not None:
                strings_event_handlers.append(s)

        self.validate_python = build_wrapper(self._schema_validator.validate_python, python_event_handlers)
        self.validate_json = build_wrapper(self._schema_validator.validate_json, json_event_handlers)
        self.validate_strings = build_wrapper(self._schema_validator.validate_strings, strings_event_handlers)

    def __getattr__(self, name: str) -> Any:
        """Forward lookups not found on this object to the wrapped `SchemaValidator`."""
        if name == '_schema_validator':
            # Reached only when the slot is unset, e.g. on an instance created via `__new__` by
            # `copy`/`pickle` before its state is restored. Without this guard, the lookup below
            # would re-enter `__getattr__` forever. It would raise `RecursionError` instead of
            # `AttributeError`, breaking `hasattr` and `copy.copy`.
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {name!r}')
        return getattr(self._schema_validator, name)
97
98
def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandlerProtocol]) -> Callable[P, R]:
    """Wrap a validation function so that plugin event handlers observe each call.

    Args:
        func: The validation callable to wrap.
        event_handlers: Handler objects supplied by plugins for this validation mode.

    Returns:
        `func` itself when `event_handlers` is empty. Otherwise a wrapper that calls each
        implemented `on_enter` hook with the call's arguments and then `func`. It then calls the
        `on_success` hooks with the result, the `on_error` hooks with a `ValidationError`, or the
        `on_exception` hooks with any other `Exception`. Errors are re-raised after their hooks run.
    """
    if not event_handlers:
        return func

    def _implemented(method_name: str) -> tuple[Callable[..., Any], ...]:
        # Hooks named `method_name`, skipping handlers that don't implement it themselves.
        return tuple(
            getattr(handler, method_name) for handler in event_handlers if filter_handlers(handler, method_name)
        )

    enter_hooks = _implemented('on_enter')
    success_hooks = _implemented('on_success')
    error_hooks = _implemented('on_error')
    exception_hooks = _implemented('on_exception')

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        for hook in enter_hooks:
            hook(*args, **kwargs)

        try:
            outcome = func(*args, **kwargs)
        except ValidationError as validation_error:
            for hook in error_hooks:
                hook(validation_error)
            raise
        except Exception as other_error:
            for hook in exception_hooks:
                hook(other_error)
            raise

        # Success hooks run outside the `try`, so an exception raised by one of them
        # is not passed to the error/exception hooks.
        for hook in success_hooks:
            hook(outcome)
        return outcome

    return wrapper
129
130
def filter_handlers(handler_cls: BaseValidateHandlerProtocol, method_name: str) -> bool:
    """Report whether `handler_cls` provides its own implementation of `method_name`.

    A hook is skipped when it is missing (or set to `None`). It is also skipped when its
    `__module__` is `pydantic.plugin`, meaning it is the protocol's default picked up through
    runtime inheritance rather than code written by the plugin.

    Args:
        handler_cls: A handler object supplied by a plugin.
        method_name: The hook to look up, e.g. `'on_enter'`.

    Returns:
        `True` if the hook should be called, `False` otherwise.
    """
    hook = getattr(handler_cls, method_name, None)
    # Short-circuits before touching `__module__` when the hook is absent.
    return hook is not None and hook.__module__ != 'pydantic.plugin'