Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/function_spec.py: 51%

214 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Defines an input type specification for tf.function.""" 

16 

17import functools 

18import inspect 

19from typing import Any, Dict, Tuple 

20 

21import numpy as np 

22import six 

23 

24from tensorflow.core.function import trace_type 

25from tensorflow.core.function.polymorphism import function_type as function_type_lib 

26from tensorflow.python.eager.polymorphic_function import composite_tensor_utils 

27from tensorflow.python.framework import composite_tensor 

28from tensorflow.python.framework import constant_op 

29from tensorflow.python.framework import ops 

30from tensorflow.python.framework import tensor_spec 

31from tensorflow.python.framework import type_spec 

32from tensorflow.python.ops import resource_variable_ops 

33from tensorflow.python.util import nest 

34 

35# Sentinel value used by with ConcreteFunction's structured signature to 

36# indicate that a non-tensor parameter should use the value that was 

37# specified when the concrete function was created. 

38BOUND_VALUE = object() 

39 

40 

41def to_fullargspec(function_type: function_type_lib.FunctionType, 

42 default_values: Dict[str, Any]) -> inspect.FullArgSpec: 

43 """Generates backwards compatible FullArgSpec from FunctionType.""" 

44 args = [] 

45 varargs = None 

46 varkw = None 

47 defaults = [] 

48 kwonlyargs = [] 

49 kwonlydefaults = {} 

50 

51 for parameter in function_type.parameters.values(): 

52 if parameter.kind in [ 

53 inspect.Parameter.POSITIONAL_ONLY, 

54 inspect.Parameter.POSITIONAL_OR_KEYWORD 

55 ]: 

56 args.append(parameter.name) 

57 if parameter.default is not inspect.Parameter.empty: 

58 defaults.append(default_values[parameter.name]) 

59 elif parameter.kind is inspect.Parameter.KEYWORD_ONLY: 

60 kwonlyargs.append(parameter.name) 

61 if parameter.default is not inspect.Parameter.empty: 

62 kwonlydefaults[parameter.name] = default_values[parameter.name] 

63 elif parameter.kind is inspect.Parameter.VAR_POSITIONAL: 

64 varargs = parameter.name 

65 elif parameter.kind is inspect.Parameter.VAR_KEYWORD: 

66 varkw = parameter.name 

67 

68 return inspect.FullArgSpec( 

69 args, 

70 varargs, 

71 varkw, 

72 tuple(defaults) if defaults else None, 

73 kwonlyargs, 

74 kwonlydefaults if kwonlydefaults else None, 

75 annotations={}) 

76 

77 

78def _to_default_values(fullargspec): 

79 """Returns default values from the function's inspected fullargspec.""" 

80 if fullargspec.defaults is not None: 

81 defaults = { 

82 name: value for name, value in zip( 

83 fullargspec.args[-len(fullargspec.defaults):], fullargspec.defaults) 

84 } 

85 else: 

86 defaults = {} 

87 

88 if fullargspec.kwonlydefaults is not None: 

89 defaults.update(fullargspec.kwonlydefaults) 

90 

91 defaults = { 

92 function_type_lib.sanitize_arg_name(name): value 

93 for name, value in defaults.items() 

94 } 

95 

96 return defaults 

97 

98 

99def to_function_type(fullargspec): 

100 """Generates FunctionType and default values from fullargspec.""" 

101 default_values = _to_default_values(fullargspec) 

102 parameters = [] 

103 

104 for arg in fullargspec.args: 

105 arg_name = function_type_lib.sanitize_arg_name(arg) 

106 parameters.append( 

107 function_type_lib.Parameter( 

108 arg_name, function_type_lib.Parameter.POSITIONAL_OR_KEYWORD, 

109 arg_name in default_values, None)) 

110 

111 if fullargspec.varargs is not None: 

112 parameters.append( 

113 function_type_lib.Parameter(fullargspec.varargs, 

114 function_type_lib.Parameter.VAR_POSITIONAL, 

115 False, None)) 

116 

117 for kwarg in fullargspec.kwonlyargs: 

118 parameters.append( 

119 function_type_lib.Parameter( 

120 function_type_lib.sanitize_arg_name(kwarg), 

121 function_type_lib.Parameter.KEYWORD_ONLY, kwarg in default_values, 

122 None)) 

123 

124 if fullargspec.varkw is not None: 

125 parameters.append( 

126 function_type_lib.Parameter(fullargspec.varkw, 

127 function_type_lib.Parameter.VAR_KEYWORD, 

128 False, None)) 

129 

130 return function_type_lib.FunctionType(parameters), default_values 

131 

132 

133def to_input_signature(function_type): 

134 """Extracts an input_signature from function_type instance.""" 

135 constrained_parameters = list(function_type.parameters.keys()) 

136 

137 # self does not have a constraint in input_signature 

138 if "self" in constrained_parameters: 

139 constrained_parameters.pop(0) 

140 

141 # There are no parameters to constrain. 

142 if not constrained_parameters: 

143 return tuple() 

144 

145 constraints = [] 

146 is_auto_constrained = False 

147 

148 for parameter_name in constrained_parameters: 

149 parameter = function_type.parameters[parameter_name] 

150 constraint = None 

151 if parameter.type_constraint: 

152 # Generate legacy constraint representation. 

153 constraint = parameter.type_constraint.placeholder_value( 

154 trace_type.InternalPlaceholderContext(unnest_only=True) 

155 ) 

156 if any( 

157 not isinstance(arg, tensor_spec.TensorSpec) 

158 for arg in nest.flatten([constraint], expand_composites=True)): 

159 # input_signature only supports contiguous TensorSpec composites 

160 is_auto_constrained = True 

161 break 

162 else: 

163 constraints.append(constraint) 

164 

165 # All constraints were generated by FunctionType 

166 if is_auto_constrained and not constraints: 

167 return tuple() 

168 

169 # If the list is empty then there was no input_signature specified. 

170 return tuple(constraints) if constraints else None 

171 

172 

173# TODO(b/214462107): Clean up and migrate to core/function when unblocked. 

174class FunctionSpec(object): 

175 """Specification of how to bind arguments to a function.""" 

176 

177 @classmethod 

178 def from_function_and_signature(cls, 

179 python_function, 

180 input_signature, 

181 is_pure=False, 

182 jit_compile=None): 

183 """Creates a FunctionSpec instance given a python function and signature. 

184 

185 Args: 

186 python_function: a function to inspect 

187 input_signature: a signature of the function (None, if variable) 

188 is_pure: if True all input arguments (including variables and constants) 

189 will be converted to tensors and no variable changes allowed. 

190 jit_compile: see `tf.function` 

191 

192 Returns: 

193 instance of FunctionSpec 

194 """ 

195 _validate_signature(input_signature) 

196 

197 function_type = function_type_lib.FunctionType.from_callable( 

198 python_function) 

199 default_values = function_type_lib.FunctionType.get_default_values( 

200 python_function) 

201 

202 if input_signature is not None: 

203 input_signature = tuple(input_signature) 

204 function_type = function_type_lib.add_type_constraints( 

205 function_type, input_signature, default_values) 

206 

207 # Get the function's name. Remove functools.partial wrappers if necessary. 

208 while isinstance(python_function, functools.partial): 

209 python_function = python_function.func 

210 name = getattr(python_function, "__name__", "f") 

211 

212 return FunctionSpec( 

213 function_type, 

214 default_values, 

215 is_pure=is_pure, 

216 jit_compile=jit_compile, 

217 name=name) 

218 

219 @classmethod 

220 def from_fullargspec_and_signature(cls, 

221 fullargspec, 

222 input_signature, 

223 is_pure=False, 

224 name=None, 

225 jit_compile=None): 

226 """Construct FunctionSpec from legacy FullArgSpec format.""" 

227 function_type, default_values = to_function_type(fullargspec) 

228 if input_signature: 

229 input_signature = tuple(input_signature) 

230 _validate_signature(input_signature) 

231 function_type = function_type_lib.add_type_constraints( 

232 function_type, input_signature, default_values) 

233 

234 return FunctionSpec(function_type, default_values, is_pure, 

235 name, jit_compile) 

236 

237 def __init__(self, 

238 function_type, 

239 default_values, 

240 is_pure=False, 

241 name=None, 

242 jit_compile=None): 

243 """Constructs a FunctionSpec describing a python function. 

244 

245 Args: 

246 function_type: A FunctionType describing the python function signature. 

247 default_values: Dictionary mapping parameter names to default values. 

248 is_pure: if True all input arguments (including variables and constants) 

249 will be converted to tensors and no variable changes allowed. 

250 name: Name of the function 

251 jit_compile: see `tf.function`. 

252 """ 

253 self._function_type = function_type 

254 self._default_values = default_values 

255 self._fullargspec = to_fullargspec(function_type, default_values) 

256 self._is_pure = is_pure 

257 self._jit_compile = jit_compile 

258 

259 # TODO(edloper): Include name when serializing for SavedModel? 

260 self._name = name or "f" 

261 self._input_signature = to_input_signature(function_type) 

262 

263 @property 

264 def default_values(self): 

265 """Returns dict mapping parameter names to default values.""" 

266 return self._default_values 

267 

268 @property 

269 def function_type(self): 

270 """Returns a FunctionType representing the Python function signature.""" 

271 return self._function_type 

272 

273 @property 

274 def fullargspec(self): 

275 return self._fullargspec 

276 

277 # TODO(fmuham): Replace usages with FunctionType and remove. 

278 @property 

279 def input_signature(self): 

280 return self._input_signature 

281 

282 # TODO(fmuham): Replace usages with FunctionType and remove. 

283 @property 

284 def flat_input_signature(self): 

285 return tuple(nest.flatten(self.input_signature, expand_composites=True)) 

286 

287 @property 

288 def is_pure(self): 

289 return self._is_pure 

290 

291 @property 

292 def jit_compile(self): 

293 return self._jit_compile 

294 

295 # TODO(fmuham): Replace usages and remove. 

296 @property 

297 def arg_names(self): 

298 return list( 

299 p.name 

300 for p in self.function_type.parameters.values() 

301 if ( 

302 p.kind is function_type_lib.Parameter.POSITIONAL_ONLY 

303 or p.kind is function_type_lib.Parameter.POSITIONAL_OR_KEYWORD 

304 ) 

305 ) 

306 

307 def make_canonicalized_monomorphic_type( 

308 self, 

309 args: Any, 

310 kwargs: Any, 

311 captures: Any = None, 

312 ) -> Tuple[function_type_lib.FunctionType, 

313 trace_type.InternalTracingContext]: 

314 """Generates function type given the function arguments.""" 

315 if captures is None: 

316 captures = dict() 

317 

318 kwargs = { 

319 function_type_lib.sanitize_arg_name(name): value 

320 for name, value in kwargs.items() 

321 } 

322 

323 _, function_type, type_context = ( 

324 function_type_lib.canonicalize_to_monomorphic( 

325 args, kwargs, self.default_values, captures, self.function_type 

326 ) 

327 ) 

328 

329 return function_type, type_context 

330 

331 def signature_summary(self, default_values=False): 

332 """Returns a string summarizing this function's signature. 

333 

334 Args: 

335 default_values: If true, then include default values in the signature. 

336 

337 Returns: 

338 A `string`. 

339 """ 

340 args = list(self._arg_names) 

341 if default_values: 

342 for (i, default) in self._arg_indices_to_default_values.items(): 

343 args[i] += "={}".format(default) 

344 if self._fullargspec.kwonlyargs: 

345 args.append("*") 

346 for arg_name in self._fullargspec.kwonlyargs: 

347 args.append(arg_name) 

348 if default_values and arg_name in self._fullargspec.kwonlydefaults: 

349 args[-1] += "={}".format(self._fullargspec.kwonlydefaults[arg_name]) 

350 return f"{self._name}({', '.join(args)})" 

351 

352 def canonicalize_function_inputs(self, args, kwargs): 

353 """Canonicalizes `args` and `kwargs`. 

354 

355 Canonicalize the inputs to the Python function using a `FunctionSpec` 

356 instance. In particular, we parse the varargs and kwargs that the 

357 original function was called with into a tuple corresponding to the 

358 Python function's positional (named) arguments and a dictionary 

359 corresponding to its kwargs. Missing default arguments are added. 

360 

361 If this `FunctionSpec` has an input signature, then it is used to convert 

362 arguments to tensors; otherwise, any inputs containing numpy arrays are 

363 converted to tensors. 

364 

365 Additionally, any inputs containing numpy arrays are converted to Tensors. 

366 

367 Args: 

368 args: The varargs this object was called with. 

369 kwargs: The keyword args this function was called with. 

370 

371 Returns: 

372 A canonicalized ordering of the inputs, as well as full and filtered 

373 (Tensors and Variables only) versions of their concatenated flattened 

374 representations, represented by a tuple in the form (args, kwargs, 

375 flat_args, filtered_flat_args). Here: `args` is a full list of bound 

376 arguments, and `kwargs` contains only true keyword arguments, as opposed 

377 to named arguments called in a keyword-like fashion. 

378 

379 Raises: 

380 ValueError: If a keyword in `kwargs` cannot be matched with a positional 

381 argument when an input signature is specified, or when the inputs 

382 do not conform to the input signature. 

383 """ 

384 if self.is_pure: 

385 args, kwargs = _convert_variables_to_tensors(args, kwargs) 

386 args, kwargs = self.bind_function_inputs(args, kwargs) 

387 filtered_flat_args = filter_function_inputs(args, kwargs) 

388 

389 return args, kwargs, filtered_flat_args 

390 

391 def bind_function_inputs(self, args, kwargs): 

392 """Bind `args` and `kwargs` into a canonicalized signature args, kwargs.""" 

393 sanitized_kwargs = { 

394 function_type_lib.sanitize_arg_name(k): v for k, v in kwargs.items() 

395 } 

396 if len(kwargs) != len(sanitized_kwargs): 

397 raise ValueError(f"Name collision after sanitization. Please rename " 

398 f"tf.function input parameters. Original: " 

399 f"{sorted(kwargs.keys())}, Sanitized: " 

400 f"{sorted(sanitized_kwargs.keys())}") 

401 

402 try: 

403 bound_arguments = self.function_type.bind_with_defaults( 

404 args, sanitized_kwargs, self.default_values) 

405 except Exception as e: 

406 raise TypeError( 

407 f"Binding inputs to tf.function `{self._name}` failed due to `{e}`. " 

408 f"Received args: {args} and kwargs: {sanitized_kwargs} for signature:" 

409 f" {self.function_type}." 

410 ) from e 

411 return bound_arguments.args, bound_arguments.kwargs 

412 

413 

414def _validate_signature(signature): 

415 """Checks the input_signature to be valid.""" 

416 if signature is None: 

417 return 

418 

419 if not isinstance(signature, (tuple, list)): 

420 raise TypeError("input_signature must be either a tuple or a list, got " 

421 f"{type(signature)}.") 

422 

423 # TODO(xjun): Allow VariableSpec once we figure out API for de-aliasing. 

424 variable_specs = _get_variable_specs(signature) 

425 if variable_specs: 

426 raise TypeError( 

427 f"input_signature doesn't support VariableSpec, got {variable_specs}") 

428 

429 if any(not isinstance(arg, tensor_spec.TensorSpec) 

430 for arg in nest.flatten(signature, expand_composites=True)): 

431 bad_args = [ 

432 arg for arg in nest.flatten(signature, expand_composites=True) 

433 if not isinstance(arg, tensor_spec.TensorSpec) 

434 ] 

435 raise TypeError("input_signature must be a possibly nested sequence of " 

436 f"TensorSpec objects, got invalid args {bad_args} with " 

437 f"types {list(six.moves.map(type, bad_args))}.") 

438 

439 

440def _to_tensor_or_tensor_spec(x): 

441 return (x if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)) else 

442 ops.convert_to_tensor(x)) 

443 

444 

445def _convert_variables_to_tensors(args, kwargs): 

446 args = [_to_tensor_or_tensor_spec(x) for x in args] 

447 kwargs = {kw: _to_tensor_or_tensor_spec(x) for kw, x in kwargs.items()} 

448 return tuple(args), kwargs 

449 

450 

451# TODO(fmuham): Migrate to use TraceType/FunctionType _to_tensors. 

452def filter_function_inputs(args, kwargs): 

453 """Filters and flattens args and kwargs.""" 

454 flat_inputs = composite_tensor_utils.flatten_with_variables( 

455 args) + composite_tensor_utils.flatten_with_variables(kwargs) 

456 

457 for index, flat_input in enumerate(flat_inputs): 

458 if hasattr(flat_input, "__array__") and not ( 

459 hasattr(flat_input, "_should_act_as_resource_variable") 

460 or isinstance( 

461 flat_input, 

462 ( 

463 ops.Tensor, 

464 resource_variable_ops.BaseResourceVariable, 

465 np.str_, 

466 type, 

467 composite_tensor.CompositeTensor, 

468 ), 

469 ) 

470 ): 

471 ndarray = flat_input.__array__() 

472 if not isinstance(ndarray, np.ndarray): 

473 raise TypeError(f"The output of __array__ must be an np.ndarray, " 

474 f"got {type(ndarray)} from {flat_input}.") 

475 flat_inputs[index] = constant_op.constant(ndarray) 

476 

477 filtered_flat_inputs = [ 

478 t for t in flat_inputs 

479 if isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable)) 

480 ] 

481 

482 return filtered_flat_inputs 

483 

484 

485def _get_variable_specs(args): 

486 """Returns `VariableSpecs` from `args`.""" 

487 variable_specs = [] 

488 for arg in nest.flatten(args): 

489 if not isinstance(arg, type_spec.TypeSpec): 

490 continue 

491 if isinstance(arg, resource_variable_ops.VariableSpec): 

492 variable_specs.append(arg) 

493 elif not isinstance(arg, tensor_spec.TensorSpec): 

494 # arg is a CompositeTensor spec. 

495 variable_specs.extend(_get_variable_specs(arg._component_specs)) # pylint: disable=protected-access 

496 return variable_specs 

497 

498 

499# TODO(fmuham): Replace usages with TraceType and remove. 

500def is_same_structure(structure1, structure2, check_values=False): 

501 """Check two structures for equality, optionally of types and of values.""" 

502 try: 

503 nest.assert_same_structure(structure1, structure2, expand_composites=True) 

504 except (ValueError, TypeError): 

505 return False 

506 if check_values: 

507 flattened1 = nest.flatten(structure1, expand_composites=True) 

508 flattened2 = nest.flatten(structure2, expand_composites=True) 

509 # First check the types to avoid AttributeErrors. 

510 if any(type(f1) is not type(f2) for f1, f2 in zip(flattened1, flattened2)): 

511 return False 

512 return flattened1 == flattened2 

513 return True