Coverage for /pythoncovmergedfiles/medio/medio/src/pydantic/pydantic/_internal/_core_utils.py: 49%

226 statements  

« prev     ^ index     » next       coverage.py v7.2.3, created at 2023-04-27 07:38 +0000

# TODO: Should we move WalkCoreSchema into pydantic_core proper?

2 

3from __future__ import annotations 

4 

5from typing import Any, Callable, Union, cast 

6 

7from pydantic_core import CoreSchema, CoreSchemaType, core_schema 

8from typing_extensions import TypeGuard, get_args 

9 

10from . import _repr 

11 

# Function-validator schemas that wrap an inner ``schema``: after, before and wrap validators.
FunctionSchemaWithInnerSchema = Union[
    core_schema.AfterValidatorFunctionSchema,
    core_schema.BeforeValidatorFunctionSchema,
    core_schema.WrapValidatorFunctionSchema,
]

# Every function-validator schema: the wrapping kinds above plus plain validators.
# ``typing.Union`` flattens the nested union, so the members (and their order) are the same as
# listing After, Before, Wrap and Plain directly.
AnyFunctionSchema = Union[
    FunctionSchemaWithInnerSchema,
    core_schema.PlainValidatorFunctionSchema,
]

# ``'type'`` values used by field entries (as opposed to full core schemas); see `is_core_schema`.
_CORE_SCHEMA_FIELD_TYPES = {'dataclass-field', 'typed-dict-field'}

27 

28 

def is_typed_dict_field(
    schema: CoreSchema | core_schema.TypedDictField | core_schema.DataclassField,
) -> TypeGuard[core_schema.TypedDictField]:
    """Return True when `schema` is a typed-dict field entry (its ``'type'`` is ``'typed-dict-field'``).

    Raises:
        KeyError: if `schema` has no ``'type'`` key.
    """
    schema_type = schema['type']
    return schema_type == 'typed-dict-field'

33 

34 

def is_dataclass_field(
    schema: CoreSchema | core_schema.TypedDictField | core_schema.DataclassField,
) -> TypeGuard[core_schema.DataclassField]:
    """Return True when `schema` is a dataclass field entry (its ``'type'`` is ``'dataclass-field'``).

    Raises:
        KeyError: if `schema` has no ``'type'`` key.
    """
    schema_type = schema['type']
    return schema_type == 'dataclass-field'

39 

40 

def is_core_schema(
    schema: CoreSchema | core_schema.TypedDictField | core_schema.DataclassField,
) -> TypeGuard[CoreSchema]:
    """Return True unless `schema` is a field entry, i.e. its ``'type'`` is in `_CORE_SCHEMA_FIELD_TYPES`.

    Raises:
        KeyError: if `schema` has no ``'type'`` key.
    """
    is_field_entry = schema['type'] in _CORE_SCHEMA_FIELD_TYPES
    return not is_field_entry

45 

46 

def is_function_with_inner_schema(
    schema: CoreSchema | core_schema.TypedDictField,
) -> TypeGuard[FunctionSchemaWithInnerSchema]:
    """Return True for function-validator schemas that wrap an inner ``schema``.

    Those are the ``'function-before'``, ``'function-after'`` and ``'function-wrap'`` types; any other
    type (including ``'function-plain'`` and field entries) yields False.

    Raises:
        KeyError: if `schema` has no ``'type'`` key.
    """
    # None of these 'function-*' types is also a field-entry type, so testing the type alone gives the
    # same answer as first confirming the value is a core schema.
    wrapping_function_types = ('function-before', 'function-after', 'function-wrap')
    return schema['type'] in wrapping_function_types

51 

52 

def is_list_like_schema_with_items_schema(
    schema: CoreSchema,
) -> TypeGuard[
    core_schema.ListSchema | core_schema.TupleVariableSchema | core_schema.SetSchema | core_schema.FrozenSetSchema
]:
    """Return True when `schema` is a ``'list'``, ``'tuple-variable'``, ``'set'`` or ``'frozenset'`` schema.

    Raises:
        KeyError: if `schema` has no ``'type'`` key.
    """
    list_like_types = ('list', 'tuple-variable', 'set', 'frozenset')
    return schema['type'] in list_like_types

59 

60 

def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str:
    """
    Produce the ref that pydantic_core's core schemas use to identify `type_`.

    The ref has the form ``'<module>.<qualname>:<id>'`` for the origin type, followed by
    ``'[<arg ref>,<arg ref>,...]'`` when there are type arguments.

    If `type_` has truthy ``__pydantic_generic_metadata__``, its ``'origin'`` and ``'args'`` entries
    take precedence whenever they are truthy.

    The `args_override` argument exists so that valid recursive references can be created for generic
    models without first creating a concrete class. It is used only when the metadata supplies no
    (truthy) ``'args'``.
    """
    origin = type_
    args = args_override or ()
    metadata = getattr(type_, '__pydantic_generic_metadata__', None)
    if metadata:
        # Fall back to the values computed above whenever the metadata leaves an entry empty.
        origin = metadata['origin'] or origin
        args = metadata['args'] or args

    module_name = getattr(origin, '__module__', '<No __module__>')
    qualname = getattr(origin, '__qualname__', f'<No __qualname__: {origin}>')
    base_ref = f'{module_name}.{qualname}:{id(origin)}'

    def _arg_ref(arg: Any) -> str:
        # String arguments are special-cased; this handling might become unnecessary if they were
        # wrapped in a ForwardRef at some point.
        if isinstance(arg, str):
            return f'{arg}:str-{id(arg)}'
        return f'{_repr.display_as_type(arg)}:{id(arg)}'

    arg_refs = [_arg_ref(arg) for arg in args]
    if not arg_refs:
        return base_ref
    return f'{base_ref}[{",".join(arg_refs)}]'

91 

92 

def consolidate_refs(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
    """
    Walk `schema` and replace every repeat occurrence of a ref with a definition-ref schema.

    The first schema seen with a given ref (in walk order) is kept. Every later schema with the same ref
    is replaced by ``{'type': 'definition-ref', 'schema_ref': ref}``. The replacement happens before
    recursion, so the children of a replaced schema are not visited.

    This relies on the assumption that any two schemas sharing a ref are interchangeable, so occurrences
    after the first can safely be swapped for a reference.

    Normally schemas with duplicate refs should not be produced at all. However, recursive models that
    refer to themselves more than once somewhere in their field hierarchy make it hard to avoid emitting
    several (identical) copies of the same ref'd schema. Removing those copies is safe because they all
    describe the same thing.

    There is one case where differing schemas deliberately share a ref: recursive generic models. As an
    implementation detail they emit a *non*-identical schema deeper in the tree under a reused ref. That
    schema is meant to be replaced by a recursive reference once the concrete generic parametrization is
    known.
    """
    seen_refs: set[str] = set()

    def _dedupe_ref(s: core_schema.CoreSchema) -> core_schema.CoreSchema:
        ref: str | None = s.get('ref')  # type: ignore[assignment]
        if not ref:
            return s
        if ref in seen_refs:
            return {'type': 'definition-ref', 'schema_ref': ref}
        seen_refs.add(ref)
        return s

    walker = WalkCoreSchema(_dedupe_ref, apply_before_recurse=True)
    return walker.walk(schema)

123 

124 

def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]:
    """
    Map each ref found in `schema` (at any depth) to the schema carrying it, skipping invalid schemas.

    A schema is skipped as invalid when its ``metadata`` is a dict containing an ``'invalid'`` key,
    whatever that key's value. This is equivalent to collecting every definition of a "valid" schema,
    but lets the same logic be reused to weed out "invalid" definitions. When several valid schemas
    share a ref, the one visited last wins.
    """
    definitions: dict[str, core_schema.CoreSchema] = {}

    def _record_if_valid(s: core_schema.CoreSchema) -> core_schema.CoreSchema:
        ref: str | None = s.get('ref')  # type: ignore[assignment]
        if not ref:
            return s
        metadata = s.get('metadata')
        if isinstance(metadata, dict) and 'invalid' in metadata:
            return s
        definitions[ref] = s
        return s

    WalkCoreSchema(_record_if_valid).walk(schema)
    return definitions

142 

143 

def remove_unnecessary_invalid_definitions(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
    """
    Drop invalid definitions whose ref is also provided by a valid schema elsewhere in `schema`.

    A definition counts as invalid when its ``metadata`` is a dict containing an ``'invalid'`` key.
    Only entries in the ``definitions`` list of ``'definitions'`` schemas are pruned. Each such schema
    is copied before its list is replaced, and all other schemas pass through the walker unchanged.
    """
    valid_refs = collect_definitions(schema).keys()

    def _is_shadowed_invalid(definition: Any) -> bool:
        metadata = definition.get('metadata')
        return (
            isinstance(metadata, dict)
            and 'invalid' in metadata
            and definition['ref'] in valid_refs  # type: ignore
        )

    def _prune_definitions(s: core_schema.CoreSchema) -> core_schema.CoreSchema:
        if s['type'] != 'definitions':
            return s
        pruned = s.copy()
        pruned['definitions'] = [d for d in s['definitions'] if not _is_shadowed_invalid(d)]
        return pruned

    return WalkCoreSchema(_prune_definitions).walk(schema)

170 

171 

def define_expected_missing_refs(
    schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
) -> core_schema.CoreSchema:
    """
    Add placeholder definitions for refs that are allowed to be missing from `schema` and actually are.

    Each ref in `allowed_missing_refs` that is not carried by any (nested) schema inside `schema` gets a
    placeholder ``none_schema`` with that ref and metadata
    ``{'pydantic_debug_missing_ref': True, 'invalid': True}``. `schema` is then wrapped in a definitions
    schema holding those placeholders. The placeholders are emitted in sorted ref order, so the result
    does not depend on set iteration order (which varies between runs for strings).

    Returns `schema` itself, unchanged, when `allowed_missing_refs` is empty or every allowed ref is
    already present.
    """
    if not allowed_missing_refs:
        # In this case there are no missing refs to potentially substitute, so there's no need to walk
        # the schema. This is a common case (hit for all non-generic models), so it's worth optimizing for.
        return schema

    present_refs: set[str] = set()

    def _record_refs(s: core_schema.CoreSchema) -> core_schema.CoreSchema:
        ref: str | None = s.get('ref')  # type: ignore[assignment]
        if ref:
            present_refs.add(ref)
        return s

    WalkCoreSchema(_record_refs).walk(schema)

    expected_missing_refs = allowed_missing_refs.difference(present_refs)
    if not expected_missing_refs:
        return schema

    definitions: list[core_schema.CoreSchema] = [
        # TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail
        core_schema.none_schema(ref=ref, metadata={'pydantic_debug_missing_ref': True, 'invalid': True})
        # Sorted so the generated schema is deterministic across interpreter runs.
        for ref in sorted(expected_missing_refs)
    ]
    return core_schema.definitions_schema(schema, definitions)

198 

199 

def collect_invalid_schemas(schema: core_schema.CoreSchema) -> list[core_schema.CoreSchema]:
    """
    Return every schema in `schema` (including `schema` itself) whose metadata marks it as invalid, in walk order.

    A schema is invalid when its ``metadata`` is a dict with a truthy ``'invalid'`` value. (This checks the
    value's truthiness, whereas `collect_definitions` only checks that the key is present.) A schema whose
    ``metadata`` is missing or is not a dict (e.g. ``None``) is treated as valid instead of raising.
    """
    invalid_schemas: list[core_schema.CoreSchema] = []

    def _collect_if_invalid(s: core_schema.CoreSchema) -> core_schema.CoreSchema:
        metadata = s.get('metadata')
        # Guard like the other helpers in this module: non-dict metadata has no `.get`.
        if isinstance(metadata, dict) and metadata.get('invalid'):
            invalid_schemas.append(s)
        return s

    WalkCoreSchema(_collect_if_invalid).walk(schema)
    return invalid_schemas

210 

211 

class WalkCoreSchema:
    """
    Transform a CoreSchema by recursively calling a function on it and on every nested CoreSchema.

    The function does not need to modify anything. It is still called on every nested schema, which makes
    this class useful for collecting information (e.g. refs) as well as for rewriting.

    Each schema is shallow-copied before the function or a handler sees it, and handlers store the walked
    children on that copy (copying field/parameter entries too). The walk itself therefore does not write
    into the caller's dicts.

    Dispatch is on the schema's ``'type'``: a type ``'foo-bar'`` is handled by ``handle_foo_bar_schema``
    when such a method exists, and by ``_handle_other_schemas`` otherwise.
    """

    def __init__(
        self, f: Callable[[core_schema.CoreSchema], core_schema.CoreSchema], apply_before_recurse: bool = True
    ):
        """
        Args:
            f: Called on every (nested) schema; its return value replaces that schema in the result.
            apply_before_recurse: If True (the default), `f` runs before a schema's children are walked,
                and the children walked are those of `f`'s return value. If False, the children are walked
                first and `f` receives the already-walked schema.
        """
        self.f = f

        self.apply_before_recurse = apply_before_recurse

        # Resolved once up front: schema type -> handler method.
        self._schema_type_to_method = self._build_schema_type_to_method()

    def _build_schema_type_to_method(self) -> dict[CoreSchemaType, Callable[[CoreSchema], CoreSchema]]:
        """Map each `CoreSchemaType` member to ``handle_<type>_schema`` (dashes become underscores),
        or to `_handle_other_schemas` when no such method is defined."""
        mapping: dict[CoreSchemaType, Callable[[CoreSchema], CoreSchema]] = {}
        for key in get_args(CoreSchemaType):
            method_name = f"handle_{key.replace('-', '_')}_schema"
            mapping[key] = getattr(self, method_name, self._handle_other_schemas)
        return mapping

    def walk(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """Return the transformed copy of `schema` (see the class docstring)."""
        return self._walk(schema)

    def _walk(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """Shallow-copy `schema`, then apply `self.f` before or after dispatching to the type's handler.

        Raises:
            KeyError: if the schema's ``'type'`` is not in the dispatch table.
        """
        schema = schema.copy()
        if self.apply_before_recurse:
            schema = self.f(schema)
        method = self._schema_type_to_method[schema['type']]
        schema = method(schema)
        if not self.apply_before_recurse:
            schema = self.f(schema)
        return schema

    def _handle_other_schemas(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
        """Fallback handler: walk the single nested ``'schema'`` key if present; nothing else is visited."""
        if 'schema' in schema:
            schema['schema'] = self._walk(schema['schema'])  # type: ignore
        return schema

    def _walk_items_schema(self, schema: Any) -> CoreSchema:
        """Walk the optional ``items_schema`` shared by list, set, frozenset, generator and tuple-variable."""
        if 'items_schema' in schema:
            schema['items_schema'] = self._walk(schema['items_schema'])
        return schema

    def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema) -> CoreSchema:
        """Walk every definition, then the inner ``schema``.

        Walked definitions that no longer carry a ``'ref'`` are dropped. If none remain and the
        definitions schema has exactly three keys (presumably only ``type``/``schema``/``definitions``),
        the walked inner schema is returned in place of the wrapper.
        """
        new_definitions = []
        for definition in schema['definitions']:
            updated_definition = self._walk(definition)
            if 'ref' in updated_definition:
                # If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions.
                # This is most likely to happen due to replacing something with a definition reference, in
                # which case it should certainly not go in the definitions list.
                new_definitions.append(updated_definition)
        new_inner_schema = self._walk(schema['schema'])

        if not new_definitions and len(schema) == 3:
            # This means we'd be returning a "trivial" definitions schema that just wrapped the inner schema.
            return new_inner_schema

        new_schema = schema.copy()
        new_schema['schema'] = new_inner_schema
        new_schema['definitions'] = new_definitions
        return new_schema

    def handle_list_schema(self, schema: core_schema.ListSchema) -> CoreSchema:
        """Walk the optional ``items_schema``."""
        return self._walk_items_schema(schema)

    def handle_set_schema(self, schema: core_schema.SetSchema) -> CoreSchema:
        """Walk the optional ``items_schema``."""
        return self._walk_items_schema(schema)

    def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema) -> CoreSchema:
        """Walk the optional ``items_schema``."""
        return self._walk_items_schema(schema)

    def handle_generator_schema(self, schema: core_schema.GeneratorSchema) -> CoreSchema:
        """Walk the optional ``items_schema``."""
        return self._walk_items_schema(schema)

    def handle_tuple_variable_schema(
        self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema
    ) -> CoreSchema:
        """Walk the optional single ``items_schema`` of a variable-length tuple schema."""
        # `cast` only narrows the annotated union for type checkers; it has no runtime effect.
        schema = cast(core_schema.TupleVariableSchema, schema)
        return self._walk_items_schema(schema)

    def handle_tuple_positional_schema(
        self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema
    ) -> CoreSchema:
        """Walk every positional ``items_schema`` entry and the optional ``extra_schema``."""
        schema = cast(core_schema.TuplePositionalSchema, schema)
        schema['items_schema'] = [self._walk(v) for v in schema['items_schema']]
        if 'extra_schema' in schema:
            schema['extra_schema'] = self._walk(schema['extra_schema'])
        return schema

    def handle_dict_schema(self, schema: core_schema.DictSchema) -> CoreSchema:
        """Walk the optional ``keys_schema`` and ``values_schema``."""
        if 'keys_schema' in schema:
            schema['keys_schema'] = self._walk(schema['keys_schema'])
        if 'values_schema' in schema:
            schema['values_schema'] = self._walk(schema['values_schema'])
        return schema

    def handle_function_schema(
        self,
        schema: AnyFunctionSchema,
    ) -> CoreSchema:
        """Walk the inner ``schema`` of before/after/wrap function schemas; other function schemas pass through.

        NOTE(review): dispatch looks up ``handle_<type>_schema``, so types such as ``'function-before'``
        resolve to ``handle_function_before_schema`` (not defined here) and fall back to
        `_handle_other_schemas`. This method is only selected if `CoreSchemaType` contains a bare
        ``'function'`` member -- verify whether it is still reachable.
        """
        if not is_function_with_inner_schema(schema):
            return schema
        schema['schema'] = self._walk(schema['schema'])
        return schema

    def handle_union_schema(self, schema: core_schema.UnionSchema) -> CoreSchema:
        """Walk every choice, preserving order."""
        schema['choices'] = [self._walk(v) for v in schema['choices']]
        return schema

    def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema) -> CoreSchema:
        """Walk every schema-valued choice; str/int choices are kept as-is."""
        new_choices: dict[str | int, str | int | CoreSchema] = {}
        for k, v in schema['choices'].items():
            new_choices[k] = v if isinstance(v, (str, int)) else self._walk(v)
        schema['choices'] = new_choices
        return schema

    def handle_chain_schema(self, schema: core_schema.ChainSchema) -> CoreSchema:
        """Walk every step, preserving order."""
        schema['steps'] = [self._walk(v) for v in schema['steps']]
        return schema

    def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema) -> CoreSchema:
        """Walk ``lax_schema`` then ``strict_schema``."""
        schema['lax_schema'] = self._walk(schema['lax_schema'])
        schema['strict_schema'] = self._walk(schema['strict_schema'])
        return schema

    def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> CoreSchema:
        """Walk the optional ``extra_validator`` and every field's ``schema``.

        Field entries are copied before their ``schema`` is replaced.
        """
        if 'extra_validator' in schema:
            schema['extra_validator'] = self._walk(schema['extra_validator'])
        replaced_fields: dict[str, core_schema.TypedDictField] = {}
        for k, v in schema['fields'].items():
            replaced_field = v.copy()
            replaced_field['schema'] = self._walk(v['schema'])
            replaced_fields[k] = replaced_field
        schema['fields'] = replaced_fields
        return schema

    def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema) -> CoreSchema:
        """Walk every field's ``schema``; field entries are copied before their ``schema`` is replaced."""
        replaced_fields: list[core_schema.DataclassField] = []
        for field in schema['fields']:
            replaced_field = field.copy()
            replaced_field['schema'] = self._walk(field['schema'])
            replaced_fields.append(replaced_field)
        schema['fields'] = replaced_fields
        return schema

    def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema) -> CoreSchema:
        """Walk every parameter's ``schema`` (parameter entries are copied first), then the optional
        ``var_args_schema`` and ``var_kwargs_schema``."""
        replaced_arguments_schema = []
        for param in schema['arguments_schema']:
            replaced_param = param.copy()
            replaced_param['schema'] = self._walk(param['schema'])
            replaced_arguments_schema.append(replaced_param)
        schema['arguments_schema'] = replaced_arguments_schema
        if 'var_args_schema' in schema:
            schema['var_args_schema'] = self._walk(schema['var_args_schema'])
        if 'var_kwargs_schema' in schema:
            schema['var_kwargs_schema'] = self._walk(schema['var_kwargs_schema'])
        return schema

    def handle_call_schema(self, schema: core_schema.CallSchema) -> CoreSchema:
        """Walk ``arguments_schema`` and the optional ``return_schema``."""
        schema['arguments_schema'] = self._walk(schema['arguments_schema'])
        if 'return_schema' in schema:
            schema['return_schema'] = self._walk(schema['return_schema'])
        return schema