Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/op_def_library.py: 13%

448 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16"""Class to hold a library of OpDefs and use it to create Brain operations.""" 

17 

18from google.protobuf import text_format 

19from tensorflow.core.config import flags 

20from tensorflow.core.framework import attr_value_pb2 

21from tensorflow.core.framework import tensor_pb2 

22from tensorflow.core.framework import tensor_shape_pb2 

23from tensorflow.core.framework import types_pb2 

24from tensorflow.python.framework import dtypes 

25from tensorflow.python.framework import op_callbacks 

26from tensorflow.python.framework import op_def_library_pybind 

27from tensorflow.python.framework import op_def_registry 

28from tensorflow.python.framework import ops 

29from tensorflow.python.framework import tensor_shape 

30from tensorflow.python.platform import tf_logging as logging 

31from tensorflow.python.util import _pywrap_utils 

32from tensorflow.python.util import compat 

33from tensorflow.python.util import tf_contextlib 

34 

35 

36def _Attr(op_def, name): 

37 for attr in op_def.attr: 

38 if attr.name == name: 

39 return attr 

40 raise TypeError(f"Inconsistent OpDef for '{op_def.name}', missing attr " 

41 f"'{name}'") 

42 

43 

44def _AttrValue(attr_protos, name, op_type_name): 

45 if name in attr_protos: 

46 return attr_protos[name] 

47 raise TypeError(f"Inconsistent OpDef for '{op_type_name}', missing attr " 

48 f"'{name}' from '{attr_protos}'.") 

49 

50 

51def _SatisfiesTypeConstraint(dtype, attr_def, param_name): 

52 if attr_def.HasField("allowed_values"): 

53 allowed_list = attr_def.allowed_values.list.type 

54 allowed_values = ", ".join(dtypes.as_dtype(x).name for x in allowed_list) 

55 if dtype not in allowed_list: 

56 raise TypeError( 

57 f"Value passed to parameter '{param_name}' has DataType " 

58 f"{dtypes.as_dtype(dtype).name} not in list of allowed values: " 

59 f"{allowed_values}") 

60 

61 

62def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name): 

63 if attr_def.has_minimum and length < attr_def.minimum: 

64 raise ValueError(f"Attr '{param_name}' of '{op_type_name}' Op passed list " 

65 f"of length {length} less than minimum " 

66 f"{attr_def.minimum}.") 

67 

68 

69def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name): 

70 if value not in attr_def.allowed_values.list.s: 

71 allowed_values = '", "'.join( 

72 map(compat.as_text, attr_def.allowed_values.list.s)) 

73 raise ValueError(f"Attr '{arg_name}' of '{op_type_name}' Op passed string " 

74 f"'{compat.as_text(value)}' not in: \"{allowed_values}\".") 

75 

76 

77def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name): 

78 if value < attr_def.minimum: 

79 raise ValueError(f"Attr '{arg_name}' of '{op_type_name}' Op passed {value} " 

80 f"less than minimum {attr_def.minimum}.") 

81 

82 

83def _IsListParameter(arg): 

84 if arg.number_attr: 

85 return True 

86 elif arg.type_list_attr: 

87 return True 

88 return False 

89 

90 

def _NumTypeFields(arg):
  """Counts how many of `arg`'s type-specifying fields are set.

  The fields considered are a concrete `type` (anything but DT_INVALID),
  `type_attr`, and `type_list_attr`.
  """
  is_set = (
      arg.type != types_pb2.DT_INVALID,
      bool(arg.type_attr),
      bool(arg.type_list_attr),
  )
  return sum(is_set)

97 

98 

99def _IsListValue(v): 

100 return isinstance(v, (list, tuple)) 

101 

102 

def _Flatten(l):
  """Flattens one level of nesting: [1, 2, [3, 4], [5]] -> [1, 2, 3, 4, 5].

  Elements recognized by `_IsListValue` (lists and tuples) are expanded in
  place; all other elements are kept as-is. Deeper nesting is not flattened.
  """
  flat = []
  for item in l:
    if _IsListValue(item):
      flat.extend(item)
    else:
      flat.append(item)
  return flat

109 

110 

111def _Restructure(l, structure): 

112 """Returns the elements of list l structured according to the given structure. 

113 

114 A structure is represented by a list whose elements are either 

115 `None` or a non-negative integer. `None` corresponds to a single 

116 element in the output list, and an integer N corresponds to a nested 

117 list of length N. 

118 

119 The function returns a data structure whose shape is given by 

120 `structure`, and whose elements are taken from `l`. If `structure` 

121 is a singleton, the function returns the single data structure 

122 implied by the 0th element of `structure`. For example: 

123 

124 _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None]) 

125 -> ["foo", ["bar", "baz"], "qux"] 

126 

127 _Restructure(["foo"], [None]) -> "foo" 

128 

129 _Restructure(["foo"], [1]) -> ["foo"] 

130 

131 _Restructure([], [0]) -> [] 

132 

133 Args: 

134 l: A list. 

135 structure: A list whose elements are either `None` or a non-negative 

136 integer. 

137 

138 Returns: 

139 The elements of `l`, restructured according to `structure`. If 

140 `structure` is a list of length 1, this function returns the 

141 single data structure implied by `structure[0]`. 

142 

143 """ 

144 result = [] 

145 current_index = 0 

146 for element in structure: 

147 if element is None: 

148 result.append(l[current_index]) 

149 current_index += 1 

150 else: 

151 result.append(l[current_index:current_index+element]) 

152 current_index += element 

153 

154 if len(result) == 1: 

155 return result[0] 

156 else: 

157 return tuple(result) 

158 

159 

def _MakeFloat(v, arg_name):
  """Converts `v` (an instance of `compat.real_types`) to a Python float.

  Raises:
    TypeError: If `v` is not an instance of `compat.real_types`.
  """
  if isinstance(v, compat.real_types):
    return float(v)
  raise TypeError(f"Expected float for argument '{arg_name}' not {repr(v)}.")

164 

165 

166def _MakeInt(v, arg_name): 

167 if isinstance(v, str): 

168 raise TypeError(f"Expected int for argument '{arg_name}' not {repr(v)}.") 

169 try: 

170 return int(v) 

171 except (ValueError, TypeError): 

172 raise TypeError(f"Expected int for argument '{arg_name}' not {repr(v)}.") 

173 

174 

def _MakeStr(v, arg_name):
  """Converts a bytes or text value to bytes.

  Raises:
    TypeError: If `v` is not an instance of `compat.bytes_or_text_types`.
  """
  if isinstance(v, compat.bytes_or_text_types):
    # Text values are encoded so the result is always bytes.
    return compat.as_bytes(v)
  raise TypeError(f"Expected string for argument '{arg_name}' not {repr(v)}.")

179 

180 

181def _MakeBool(v, arg_name): 

182 if not isinstance(v, bool): 

183 raise TypeError(f"Expected bool for argument '{arg_name}' not {repr(v)}.") 

184 return v 

185 

186 

def _MakeType(v, arg_name):
  """Returns the DataType enum of the base dtype of `dtypes.as_dtype(v)`.

  Raises:
    TypeError: If `v` cannot be converted with `dtypes.as_dtype`.
  """
  try:
    base = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError(f"Expected DataType for argument '{arg_name}' not "
                    f"{repr(v)}.")
  return base.as_datatype_enum

194 

195 

def _MakeShape(v, arg_name):
  """Converts `v` into a TensorShapeProto.

  Args:
    v: A TensorShapeProto, or anything `tensor_shape.as_shape` accepts
      (e.g. a list of ints or a `tensor_shape.TensorShape`).
    arg_name: Argument name, used in error messages.

  Returns:
    A TensorShapeProto. A TensorShapeProto input is returned as-is.

  Raises:
    TypeError: If conversion raises TypeError or ValueError.
  """
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    # Log a single warning if any dimension carries a name.
    if any(d.name for d in v.dim):
      logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                      str(v))
    return v
  try:
    return tensor_shape.as_shape(v).as_proto()
  except (TypeError, ValueError) as e:
    raise TypeError(f"Error converting {repr(v)} (arg name = {arg_name}) to a "
                    f"TensorShape: {e}")

219 

220 

def _MakeTensor(v, arg_name):
  """Returns `v` if it is a TensorProto; no conversion is attempted.

  Raises:
    TypeError: If `v` is not a `tensor_pb2.TensorProto`.
  """
  if not isinstance(v, tensor_pb2.TensorProto):
    raise TypeError(
        f"Don't know how to convert {repr(v)} to a TensorProto for argument "
        f"'{arg_name}'")
  return v

228 

229 

def _MakeFunc(v, arg_name):
  """Converts `v` into a NameAttrList naming a function.

  Accepts a NameAttrList (returned unchanged), a bytes/text function name,
  or an object with an `add_to_graph` method.

  Raises:
    TypeError: If `v` is none of the accepted kinds.
  """
  if isinstance(v, attr_value_pb2.NameAttrList):
    return v
  if isinstance(v, compat.bytes_or_text_types):
    return attr_value_pb2.NameAttrList(name=v)
  if not hasattr(v, "add_to_graph"):
    raise TypeError(f"Don't know how to convert {repr(v)} to a func for "
                    f"argument {arg_name}")
  # Objects exposing `add_to_graph` (presumably function definitions) are
  # added to the current default graph before being referenced.
  v.add_to_graph(ops.get_default_graph())
  if hasattr(v, "_as_name_attr_list"):
    return v._as_name_attr_list  # pylint: disable=protected-access
  return attr_value_pb2.NameAttrList(name=v.name)

246 

247 

# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _MaybeColocateWith(inputs):
  """A context manager for (maybe) colocating with a list of input tensors.

  Colocation is requested with every element of `inputs`, by nesting one
  `ops.colocate_with` context per element; an empty list is a no-op.

  Args:
    inputs: A list of `Tensor` or `Operation` objects.

  Returns:
    A context manager.
  """
  if inputs:
    # NOTE(mrry): The `ops.colocate_with()` function accepts only a single
    # op or tensor, so we create one context manager per element in the list.
    head, rest = inputs[0], inputs[1:]
    with ops.colocate_with(head), _MaybeColocateWith(rest):
      yield
  else:
    yield
# pylint: enable=g-doc-return-or-yield

267 

268 

def apply_op(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Add a node invoking a registered Op to a graph.

  Example usage:
    # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
    # will convert to a Tensor.
    op_def_library.apply_op("op", input1=input1, input2=input2)
    # Can specify a node name.
    op_def_library.apply_op("op", input1=input1, name="node_name")
    # Must use keyword arguments, with the names specified in the OpDef.
    op_def_library.apply_op("op", input_name=input, attr_name=attr)

  All attrs must either be inferred from an input or specified.
  (If inferred, the attr must not be specified.) If an attr has a default
  value specified in the Op's OpDef, then you may pass None as the value
  of that attr to get the default.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op.
    **keywords: input Tensor and attr arguments specified by name, and optional
      parameters to pass when constructing the Operation.

  Returns:
    The Tensor(s) representing the output of the operation, or the Operation
    itself if there are no outputs (or, for a stateful op, if its only
    output is an empty list).

  Raises:
    RuntimeError: On some errors.
    TypeError: On some errors.
    ValueError: On some errors.
  """
  output_structure, is_stateful, op, outputs = _apply_op_helper(
      op_type_name, name, **keywords)
  if not output_structure:
    return op
  res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
  # A stateful op whose restructured result is an empty list is returned as
  # the Operation itself (presumably so the side-effecting op is not lost).
  is_empty_list = isinstance(res, list) and not res
  return op if (is_empty_list and is_stateful) else res

311 

312 

# This is temporary Python/C++ code duplication until all of it can be ported
# over to C++.
# LINT.IfChange
def _ExtractAttrProto(op_type_name, op_def, attrs, attr_protos):
  """Encodes and validates every attr of `op_def` into `attr_protos`.

  For use in _apply_op_helper. For each AttrDef in `op_def.attr`, the
  Python value `attrs[name]` is encoded with `value_to_attr_value` and
  checked against the AttrDef's list-length minimum, allowed strings, int
  minimum and allowed types. A `None` value for an attr that has a default
  is replaced by a copy of that default (without further validation).

  Args:
    op_type_name: Op name, for error messages.
    op_def: The OpDef of the op being created.
    attrs: Dict mapping every attr name of `op_def` to its Python value.
    attr_protos: Dict filled in place, mapping attr name to AttrValue.

  Raises:
    KeyError: If an attr of `op_def` is missing from `attrs`.
    TypeError: If a value cannot be encoded or has a disallowed type.
    ValueError: If a value violates a length, string or minimum constraint.
  """
  for attr_def in op_def.attr:
    key = attr_def.name
    value = attrs[key]

    if attr_def.HasField("default_value") and value is None:
      attr_value = attr_value_pb2.AttrValue()
      attr_value.CopyFrom(attr_def.default_value)
      attr_protos[key] = attr_value
      continue

    attr_value = value_to_attr_value(value, attr_def.type, key)
    if attr_def.type.startswith("list("):
      _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
    if attr_def.HasField("allowed_values"):
      if attr_def.type == "string":
        _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
                                           op_type_name)
      elif attr_def.type == "list(string)":
        # Distinct loop names keep the outer `value` from being clobbered.
        for encoded_str in attr_value.list.s:
          _SatisfiesAllowedStringsConstraint(encoded_str, attr_def, key,
                                             op_type_name)
    if attr_def.has_minimum and attr_def.type == "int":
      _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key, op_type_name)
    if attr_def.type == "type":
      _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
    if attr_def.type == "list(type)":
      for dtype_enum in attr_value.list.type:
        _SatisfiesTypeConstraint(dtype_enum, attr_def, key)

    attr_protos[key] = attr_value

347 

348 

def _ExtractOutputStructure(op_type_name, op_def, attr_protos,
                            output_structure):
  """Appends one `_Restructure` entry per output arg to `output_structure`.

  For use in _apply_op_helper. An output with a `number_attr` contributes
  that attr's int value; one with a `type_list_attr` contributes the length
  of that type list; any other output contributes `None` (a single tensor).

  Raises:
    TypeError: If a referenced attr is missing from `attr_protos`.
  """
  for arg in op_def.output_arg:
    if arg.number_attr:
      entry = _AttrValue(attr_protos, arg.number_attr, op_type_name).i
    elif arg.type_attr:
      # Single tensor; the lookup only verifies the type attr is present.
      _AttrValue(attr_protos, arg.type_attr, op_type_name)
      entry = None
    elif arg.type_list_attr:
      type_list = _AttrValue(attr_protos, arg.type_list_attr, op_type_name)
      entry = len(type_list.list.type)
    else:
      entry = None
    output_structure.append(entry)

364 

365 

def _CanExtractAttrsFastPath(op_def, keywords):
  """Returns whether the fast path for _apply_op_helper is applicable.

  Requires every input arg of `op_def` to be supplied in `keywords` as an
  `ops.Tensor`, and no attr to be of type `func` or `list(func)`.
  """
  inputs_are_tensors = all(
      isinstance(keywords.get(input_arg.name), ops.Tensor)
      for input_arg in op_def.input_arg)
  if not inputs_are_tensors:
    return False
  func_types = ("func", "list(func)")
  return not any(attr_def.type in func_types for attr_def in op_def.attr)

380 

381 

382def _CheckOpDeprecation(op_type_name, op_def, producer): 

383 """Checks if the op is deprecated.""" 

384 deprecation_version = op_def.deprecation.version 

385 if deprecation_version and producer >= deprecation_version: 

386 raise NotImplementedError( 

387 f"Op {op_type_name} is not available in GraphDef version {producer}. " 

388 f"It has been removed in version {deprecation_version}. " 

389 f"{op_def.deprecation.explanation}.") 

390 

391 

392def _ExtractDefaultTypesAndAllowedTypes(op_def, default_type_attr_map, 

393 allowed_list_attr_map): 

394 """Extracts the `default_type_attr_map` and `allowed_list_attr_map`.""" 

395 # TODO(b/31302892): Currently the defaults don't work in the right 

396 # way if you have two inputs, one of whose type resolution depends 

397 # on the other. Handling this will require restructuring this code 

398 # significantly. 

399 for attr_def in op_def.attr: 

400 if attr_def.type != "type": 

401 continue 

402 key = attr_def.name 

403 if attr_def.HasField("default_value"): 

404 default_type_attr_map[key] = dtypes.as_dtype( 

405 attr_def.default_value.type) 

406 if attr_def.HasField("allowed_values"): 

407 allowed_list_attr_map[key] = attr_def.allowed_values.list.type 

408 

409 

def _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                           keywords, default_type_attr_map, attrs, inputs,
                           input_types):
  """Extracts `attrs`, `inputs`, and `input_types` in _apply_op_helper.

  For each input arg of `op_def`, pops its value from `keywords` (also
  accepting the name with a trailing underscore), converts it to tensor(s),
  and checks or infers the number/type attrs that the input depends on.

  Args:
    op_type_name: Op name, for error messages.
    op_def: The OpDef of the op being created.
    allowed_list_attr_map: Dict: type-attr name -> allowed DataType list.
    keywords: Dict of keyword arguments; consumed input entries are popped.
    default_type_attr_map: Dict: type-attr name -> default `DType`.
    attrs: Dict filled in place with inferred attr values.
    inputs: List extended in place with the converted input tensors.
    input_types: List extended in place with the inputs' dtypes (the
      tensors' own dtypes for ref inputs, base dtypes otherwise).

  Raises:
    TypeError: If an input is missing, is not a list where one is required,
      or has types inconsistent with the OpDef or with earlier inputs.
    ValueError: If a list length mismatches or is below its minimum, or an
      input cannot be converted to a tensor.
  """
  inferred_from = {}
  for input_arg in op_def.input_arg:
    input_name = input_arg.name
    if input_name in keywords:
      values = keywords.pop(input_name)
    elif input_name + "_" in keywords:
      # Handle the case where the name is a keyword or built-in
      # for Python so we use the name + _ instead.
      input_name += "_"
      values = keywords.pop(input_name)
    else:
      raise TypeError(f"No argument for input {input_name} found in {op_def}")

    # Goals:
    # * Convert values to Tensors if it contains constants.
    # * Verify that values is a list if that matches the input_arg's
    #   type.
    # * If the input_arg's type is determined by attrs, either set
    #   those attrs and validate those attr values are legal (if
    #   they have not yet been set) or validate the input matches
    #   the type indicated by the attrs (if they have already been
    #   inferred via an earlier input).
    # * If the input_arg has an explicit type, make sure the input
    #   conforms.

    if _IsListParameter(input_arg):
      if not _IsListValue(values):
        raise TypeError(
            f"Expected list for '{input_name}' argument to '{op_type_name}' "
            f"Op, not {values}.")
      # In cases where we expect all elements of the list to have the
      # same dtype, try to cast non-Tensor elements to that type.
      dtype = None
      default_dtype = None
      if input_arg.type != types_pb2.DT_INVALID:
        dtype = input_arg.type
      elif input_arg.number_attr:
        if input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        else:
          for t in values:
            if isinstance(t, ops.Tensor):
              dtype = t.dtype
              break

        # dtype still not found, prefer using the default dtype
        # from the attr.
        if dtype is None and input_arg.type_attr in default_type_attr_map:
          default_dtype = default_type_attr_map[input_arg.type_attr]

      try:
        if not input_arg.is_ref and dtype:
          dtype = dtypes.as_dtype(dtype).base_dtype
        values = ops.internal_convert_n_to_tensor(
            values,
            name=input_arg.name,
            dtype=dtype if dtype else None,
            preferred_dtype=default_dtype,
            as_ref=input_arg.is_ref)
        all_types = set(v.dtype.base_dtype for v in values)
        if input_arg.number_attr and len(all_types) > 1:
          # All types should match.
          raise TypeError(f"Not all types matched for {input_arg.name} for "
                          f"{op_type_name}. Got {all_types}")
      except (TypeError, ValueError):
        # What types does the conversion function think values have?
        observed_types = []
        for value in values:
          try:
            converted_value = ops.convert_to_tensor(
                value, as_ref=input_arg.is_ref)
            observed_types.append(converted_value.dtype.base_dtype.name)
          except (TypeError, ValueError):
            observed_types.append("<NOT CONVERTIBLE TO TENSOR>")
        observed = ", ".join(observed_types)

        prefix = ("Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                  (input_name, op_type_name, observed))
        if input_arg.number_attr:
          # `dtype` may still be a raw DataType enum here (ref inputs skip
          # the `as_dtype` conversion above), so normalize before `.name`.
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError(f"{prefix} that do not match expected type "
                            f"{dtypes.as_dtype(dtype).name}.")
          elif input_arg.type_attr in attrs:
            raise TypeError(f"{prefix} that do not match type "
                            f"{dtypes.as_dtype(dtype).name} "
                            "inferred from earlier arguments.")
          else:
            raise TypeError(f"{prefix} that don't all match.")
        else:
          raise TypeError(f"{prefix} that are invalid. Tensors: {values}")

      types = [x.dtype for x in values]
      inputs.extend(values)
    else:
      # In cases where we have an expected type, try to convert non-Tensor
      # arguments to that type.
      dtype = None
      default_dtype = None
      allowed_list = None
      if input_arg.type != types_pb2.DT_INVALID:
        dtype = input_arg.type
      elif input_arg.type_attr in attrs:
        dtype = attrs[input_arg.type_attr]
      elif input_arg.type_attr in default_type_attr_map:
        # The dtype could not be inferred solely from the inputs,
        # so we prefer the attr's default, so code that adds a new attr
        # with a default is backwards compatible.
        default_dtype = default_type_attr_map[input_arg.type_attr]
        allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

      try:
        # First see if we can get a valid dtype with the default conversion
        # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
        # not list allowed dtypes, in which case we should skip this.
        if dtype is None and allowed_list:
          inferred = None
          try:
            inferred = ops.convert_to_tensor(
                values, name=input_arg.name, as_ref=input_arg.is_ref)
          except TypeError:
            # When converting a python object such as a list of Dimensions, we
            # need a dtype to be specified, thus tensor conversion may throw
            # an exception which we will ignore and try again below.
            pass

          # If we did not match an allowed dtype, try again with the default
          # dtype. This could be because we have an empty tensor and thus we
          # picked the wrong type.
          if inferred is not None and inferred.dtype in allowed_list:
            values = inferred
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        else:
          values = ops.convert_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype,
              as_ref=input_arg.is_ref,
              preferred_dtype=default_dtype)
      except TypeError as err:
        if dtype is None:
          raise
        else:
          raise TypeError(
              f"Expected {dtypes.as_dtype(dtype).name} passed to parameter "
              f"'{input_arg.name}' of op '{op_type_name}', got "
              f"{repr(values)} of type '{type(values).__name__}' instead. "
              f"Error: {err}")
      except ValueError:
        # What type does convert_to_tensor think it has?
        try:
          observed = ops.convert_to_tensor(
              values, as_ref=input_arg.is_ref).dtype.name
        except ValueError as err:
          raise ValueError(
              f"Tried to convert '{input_name}' to a tensor and failed. "
              f"Error: {err}")
        prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                  (input_name, op_type_name, observed))
        if input_arg.type != types_pb2.DT_INVALID:
          raise TypeError(f"{prefix} expected type of "
                          f"{dtypes.as_dtype(input_arg.type).name}.")
        else:
          # Update the maps with the default, if needed.
          k = input_arg.type_attr
          if k in default_type_attr_map:
            if k not in attrs:
              attrs[k] = default_type_attr_map[k]
              if k not in inferred_from:
                inferred_from[k] = "Default in OpDef"

          raise TypeError(
              f"{prefix} type "
              f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
              f"argument '{inferred_from[input_arg.type_attr]}'.")

      types = [values.dtype]
      inputs.append(values)
    base_types = [x.base_dtype for x in types]

    if input_arg.number_attr:
      # <number-attr> * <type> or <number-attr> * <type-attr>
      if input_arg.number_attr in attrs:
        if len(values) != attrs[input_arg.number_attr]:
          raise ValueError(
              f"List argument '{input_name}' to '{op_type_name}' Op with "
              f"length {len(values)} must match length "
              f"{attrs[input_arg.number_attr]} of argument "
              f"'{inferred_from[input_arg.number_attr]}'.")
      else:
        attrs[input_arg.number_attr] = len(values)
        inferred_from[input_arg.number_attr] = input_name
        num_attr = _Attr(op_def, input_arg.number_attr)
        if num_attr.has_minimum and len(values) < num_attr.minimum:
          raise ValueError(
              f"List argument '{input_name}' to '{op_type_name}' Op with "
              f"length {len(values)} shorter than minimum length "
              f"{num_attr.minimum}.")
      # All tensors must have the same base type.
      if any(bt != base_types[0] for bt in base_types):
        raise TypeError(
            f"All tensors passed to '{input_name}' of '{op_type_name}' Op "
            f"must have the same type. Got {base_types} instead.")
      if input_arg.type != types_pb2.DT_INVALID:
        # <number-attr> * <type> case
        if base_types and base_types[0] != input_arg.type:
          assert False, "Unreachable"
      elif input_arg.type_attr in attrs:
        # <number-attr> * <type-attr> case, where <type-attr> already
        # has an inferred value.
        if base_types and base_types[0] != attrs[input_arg.type_attr]:
          assert False, "Unreachable"
      else:
        # <number-attr> * <type-attr> case, where we are now setting
        # the <type-attr> based on this input
        if not base_types:
          # If it's in default_type_attr_map, then wait to set it
          # (in "process remaining attrs", below).
          if input_arg.type_attr not in default_type_attr_map:
            raise TypeError(
                "Don't know how to infer type variable from empty input "
                f"list passed to input '{input_name}' of '{op_type_name}' "
                "Op.")
        else:
          attrs[input_arg.type_attr] = base_types[0]
          inferred_from[input_arg.type_attr] = input_name
          type_attr = _Attr(op_def, input_arg.type_attr)
          _SatisfiesTypeConstraint(
              base_types[0], type_attr, param_name=input_name)
    elif input_arg.type_attr:
      # <type-attr>
      attr_value = base_types[0]
      if input_arg.type_attr in attrs:
        if attrs[input_arg.type_attr] != attr_value:
          raise TypeError(
              f"Input '{input_name}' of '{op_type_name}' Op has type "
              f"{dtypes.as_dtype(attr_value).name} that does not match type "
              f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
              f"argument '{inferred_from[input_arg.type_attr]}'.")
      else:
        for base_type in base_types:
          _SatisfiesTypeConstraint(
              base_type,
              _Attr(op_def, input_arg.type_attr),
              param_name=input_name)
        attrs[input_arg.type_attr] = attr_value
        inferred_from[input_arg.type_attr] = input_name
    elif input_arg.type_list_attr:
      # <type-list-attr>
      attr_value = base_types
      if input_arg.type_list_attr in attrs:
        if attrs[input_arg.type_list_attr] != attr_value:
          actual_types = ", ".join(dtypes.as_dtype(x).name for x in attr_value)
          expected_types = ", ".join(
              dtypes.as_dtype(x).name for x in attrs[input_arg.type_list_attr])
          raise TypeError(
              f"Input '{input_name}' of '{op_type_name}' Op has type list of "
              f"{actual_types} that does not match type list {expected_types}"
              f" of argument '{inferred_from[input_arg.type_list_attr]}'.")
      else:
        for base_type in base_types:
          _SatisfiesTypeConstraint(
              base_type,
              _Attr(op_def, input_arg.type_list_attr),
              param_name=input_name)
        attrs[input_arg.type_list_attr] = attr_value
        inferred_from[input_arg.type_list_attr] = input_name
    else:
      # single Tensor with specified type
      if base_types[0] != input_arg.type:
        assert False, "Unreachable"

    if input_arg.is_ref:
      if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
        raise TypeError(
            f"'{op_type_name}' Op requires that input '{input_name}' be a "
            "mutable tensor (e.g.: a tf.Variable)")
      input_types.extend(types)
    else:
      input_types.extend(base_types)

697 

698 

699def _ExtractRemainingAttrs(op_type_name, op_def, keywords, 

700 default_type_attr_map, attrs): 

701 """Extracts the remaining attributes into `attrs` in _apply_op_helper.""" 

702 for attr in op_def.attr: 

703 # Skip attrs that have already had their values inferred 

704 if attr.name in attrs: 

705 if attr.name in keywords: 

706 raise TypeError( 

707 f"Should not specify value for inferred attr '{attr.name}' for " 

708 f"{op_type_name}.") 

709 continue 

710 if attr.name in keywords: 

711 attrs[attr.name] = keywords.pop(attr.name) 

712 elif attr.name + "_" in keywords: 

713 # Attrs whose names match Python keywords have an extra '_' 

714 # appended, so we must check for that as well. 

715 attrs[attr.name] = keywords.pop(attr.name + "_") 

716 elif attr.name in default_type_attr_map: 

717 attrs[attr.name] = default_type_attr_map[attr.name] 

718 else: 

719 raise TypeError(f"No argument found for attr {attr.name} for " 

720 f"{op_type_name}") 

721 

722 

def _GetOpDef(op_type_name, keywords):
  """Returns the OpDef, Graph and Producer. For use in _apply_op_helper.

  Args:
    op_type_name: Name of a registered op.
    keywords: Dict of the op's keyword arguments. Its values, with one level
      of list/tuple nesting flattened, are used to select the graph.

  Returns:
    A tuple `(op_def, graph, producer)`, where `producer` is
    `graph.graph_def_versions.producer`.

  Raises:
    RuntimeError: If `op_type_name` is not registered, or the graph cannot
      be determined from the inputs.
  """
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    producer = g.graph_def_versions.producer
    # pylint: enable=protected-access
  except AssertionError as e:
    # Python 3 exceptions have no `.message` attribute; format the exception
    # itself so this raises RuntimeError rather than AttributeError.
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e}") from e

  return op_def, g, producer

741 

742 

743def _CheckAllInputsUsed(op_type_name, keywords): 

744 """Ensures all inputs passed into _apply_op_helper were used.""" 

745 if keywords: 

746 all_keywords = ", ".join(sorted(keywords.keys())) 

747 raise TypeError(f"{op_type_name} got unexpected keyword arguments: " 

748 f"{all_keywords}.") 

749 

750 

def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op.

  Builds the op in the graph determined from its inputs. Inputs and attrs
  are processed by `op_def_library_pybind.process_inputs` when
  `_CanExtractAttrsFastPath` allows it and the graph-building-optimization
  flag is enabled; otherwise the Python `_Extract*` helpers are used.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: Optional name for the created op; defaults to `op_type_name`.
    **keywords: Input and attr arguments, by name.

  Returns:
    A tuple `(output_structure, is_stateful, op, outputs)`: the structure
    list consumed by `_Restructure`, the OpDef's `is_stateful` flag, the
    created Operation, and its outputs (replaced by the op callbacks'
    results when those return something other than None).
  """

  op_def, g, producer = _GetOpDef(op_type_name, keywords)
  name = name if name else op_type_name

  attrs, attr_protos = {}, {}
  default_type_attr_map, allowed_list_attr_map = {}, {}
  inputs, input_types, output_structure = [], [], []
  # True while the Python extraction path (below) must be used.
  fallback = True

  if (_CanExtractAttrsFastPath(op_def, keywords) and
      flags.config().graph_building_optimization.value()):
    fallback = False
    attr_protos, inputs, input_types, output_structure = (
        op_def_library_pybind.process_inputs(op_type_name, producer, keywords))

  # The Python-side deprecation check and type-attr maps are only needed on
  # the fallback path.
  if fallback:
    _CheckOpDeprecation(op_type_name, op_def, producer)
    _ExtractDefaultTypesAndAllowedTypes(op_def, default_type_attr_map,
                                        allowed_list_attr_map)

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  with g.as_default(), ops.name_scope(name) as scope:
    if fallback:
      _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                             keywords, default_type_attr_map, attrs, inputs,
                             input_types)
      _ExtractRemainingAttrs(op_type_name, op_def, keywords,
                             default_type_attr_map, attrs)
      _ExtractAttrProto(op_type_name, op_def, attrs, attr_protos)
      del attrs  # attrs is no longer authoritative, use attr_protos instead
      _ExtractOutputStructure(op_type_name, op_def, attr_protos,
                              output_structure)
      _CheckAllInputsUsed(op_type_name, keywords)

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` itself can be decoupled, which lets the
    # `op_callbacks` function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs

813 

814 

def value_to_attr_value(value, attr_type, arg_name):  # pylint: disable=invalid-name
  """Encodes a Python value as an `AttrValue` proto message.

  Args:
    value: The value to convert. Must be a list or tuple when `attr_type`
      is a `list(...)` type.
    attr_type: The value type (string) -- see the AttrValue proto definition
      for valid strings.
    arg_name: Argument name (for error messages).

  Returns:
    An AttrValue proto message that encodes `value`.

  Raises:
    TypeError: If `value` cannot be encoded as `attr_type`, or `attr_type`
      is not recognized.
  """
  # Element type name -> (AttrValue field name, converter).
  converters = {
      "string": ("s", _MakeStr),
      "int": ("i", _MakeInt),
      "float": ("f", _MakeFloat),
      "bool": ("b", _MakeBool),
      "type": ("type", _MakeType),
      "shape": ("shape", _MakeShape),
      "tensor": ("tensor", _MakeTensor),
      "func": ("func", _MakeFunc),
  }
  # Singular sub-message fields cannot be assigned; they are filled with
  # CopyFrom instead.
  message_fields = ("shape", "tensor", "func")

  attr_value = attr_value_pb2.AttrValue()

  is_list = attr_type.startswith("list(")
  if is_list and not _IsListValue(value):
    raise TypeError(f"Expected list for attr {arg_name}, obtained "
                    f"{type(value).__name__} instead.")

  if not is_list:
    element_type = attr_type
  elif attr_type.endswith(")"):
    element_type = attr_type[len("list("):-1]
  else:
    element_type = None

  entry = converters.get(element_type)
  if entry is None:
    raise TypeError(f"Unrecognized Attr type {attr_type} for {arg_name}.")
  field, convert = entry

  if is_list:
    getattr(attr_value.list, field).extend(
        [convert(x, arg_name) for x in value])
  elif field in message_fields:
    getattr(attr_value, field).CopyFrom(convert(value, arg_name))
  else:
    setattr(attr_value, field, convert(value, arg_name))
  return attr_value
# LINT.ThenChange(//tensorflow/python/framework/op_def_library_pybind.cc)

870 

871 

# The following symbols are used by op_def_util.cc.
for _registered_name, _registered_obj in (
    ("tf.dtypes.DType", dtypes.DType),
    ("tf.dtypes.as_dtype", dtypes.as_dtype),
    ("tf.TensorShape", tensor_shape.TensorShape),
    ("tf.as_shape", tensor_shape.as_shape),
    ("tf.TensorProto", tensor_pb2.TensorProto),
    ("text_format.Parse", text_format.Parse),
    ("tf.convert_to_tensor", ops.convert_to_tensor),
):
  _pywrap_utils.RegisterPyObject(_registered_name, _registered_obj)
# Keep the module namespace free of the loop variables.
del _registered_name, _registered_obj