Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py: 33%

2184 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Classes and functions used to construct graphs.""" 

16# pylint: disable=g-bad-name 

17import collections 

18import copy 

19import re 

20import sys 

21import threading 

22import types 

23from typing import Optional 

24from absl import app 

25 

26import numpy as np 

27 

28from tensorflow.core.framework import attr_value_pb2 

29from tensorflow.core.framework import full_type_pb2 

30from tensorflow.core.framework import function_pb2 

31from tensorflow.core.framework import graph_pb2 

32from tensorflow.core.framework import node_def_pb2 

33from tensorflow.core.framework import op_def_pb2 

34from tensorflow.core.framework import versions_pb2 

35from tensorflow.core.protobuf import config_pb2 

36# pywrap_tensorflow must be imported first to avoid protobuf issues. 

37# (b/143110113) 

38# pylint: disable=invalid-import-order,g-bad-import-order,unused-import 

39from tensorflow.python import pywrap_tensorflow 

40from tensorflow.python import pywrap_tfe 

41# pylint: enable=invalid-import-order,g-bad-import-order,unused-import 

42from tensorflow.python import tf2 

43from tensorflow.python.client import pywrap_tf_session 

44from tensorflow.python.eager import context 

45from tensorflow.python.eager import core 

46from tensorflow.python.eager import monitoring 

47from tensorflow.python.eager import record 

48from tensorflow.python.framework import c_api_util 

49from tensorflow.python.framework import composite_tensor 

50from tensorflow.python.framework import device as pydev 

51from tensorflow.python.framework import dtypes 

52from tensorflow.python.framework import errors 

53from tensorflow.python.framework import op_callbacks 

54from tensorflow.python.framework import registry 

55from tensorflow.python.framework import stack 

56from tensorflow.python.framework import tensor_conversion_registry 

57from tensorflow.python.framework import tensor_shape 

58from tensorflow.python.framework import tensor_util 

59from tensorflow.python.framework import traceable_stack 

60from tensorflow.python.framework import versions 

61from tensorflow.python.ops import control_flow_util 

62from tensorflow.python.ops import handle_data_util 

63from tensorflow.python.platform import tf_logging as logging 

64from tensorflow.python.profiler import trace as profiler_trace 

65from tensorflow.python.types import core as core_tf_types 

66from tensorflow.python.types import internal 

67from tensorflow.python.util import compat 

68from tensorflow.python.util import decorator_utils 

69from tensorflow.python.util import deprecation 

70from tensorflow.python.util import function_utils 

71from tensorflow.python.util import lock_util 

72from tensorflow.python.util import object_identity 

73from tensorflow.python.util import tf_contextlib 

74from tensorflow.python.util import tf_stack 

75from tensorflow.python.util import traceback_utils 

76from tensorflow.python.util.compat import collections_abc 

77from tensorflow.python.util.deprecation import deprecated_args 

78from tensorflow.python.util.lazy_loader import LazyLoader 

79from tensorflow.python.util.tf_export import kwarg_only 

80from tensorflow.python.util.tf_export import tf_export 

81 

82# TODO(b/218887885): Loaded lazily due to a circular dependency with this file. 

83tensor_spec = LazyLoader( 

84 "tensor_spec", globals(), 

85 "tensorflow.python.framework.tensor_spec") 

86ag_ctx = LazyLoader( 

87 "ag_ctx", globals(), 

88 "tensorflow.python.autograph.core.ag_ctx") 

89 

90 

91# Temporary global switches determining if we should enable the work-in-progress 

92# calls to the C API. These will be removed once all functionality is supported. 

93_USE_C_API = True 

94_USE_C_SHAPES = True 

95 

96_api_usage_gauge = monitoring.BoolGauge( 

97 "/tensorflow/api/ops_eager_execution", 

98 "Whether ops.enable_eager_execution() is called.") 

99 

100_tensor_equality_api_usage_gauge = monitoring.BoolGauge( 

101 "/tensorflow/api/enable_tensor_equality", 

102 "Whether ops.enable_tensor_equality() is called.") 

103 

104_control_flow_api_gauge = monitoring.BoolGauge( 

105 "/tensorflow/api/enable_control_flow_v2", 

106 "Whether enable_control_flow_v2() is called.") 

107 

108_tf_function_api_gauge = monitoring.BoolGauge( 

109 "/tensorflow/api/tf_function", 

110 "Whether tf.function() is used.") 

111 

112# pylint: disable=protected-access 

113_DTYPES_INTERN_TABLE = dtypes._INTERN_TABLE 

114# pylint: enable=protected-access 

115 

116 

117def tensor_id(tensor): 

118 """Returns a unique identifier for this Tensor.""" 

119 return tensor._id # pylint: disable=protected-access 

120 

121 

122class _UserDeviceSpec(object): 

123 """Store user-specified device and provide computation of merged device.""" 

124 

125 def __init__(self, device_name_or_function): 

126 self._device_name_or_function = device_name_or_function 

127 self.display_name = str(self._device_name_or_function) 

128 self.function = device_name_or_function 

129 self.raw_string = None 

130 

131 if isinstance(device_name_or_function, pydev.MergeDevice): 

132 self.is_null_merge = device_name_or_function.is_null_merge 

133 

134 elif callable(device_name_or_function): 

135 self.is_null_merge = False 

136 dev_func = self._device_name_or_function 

137 func_name = function_utils.get_func_name(dev_func) 

138 func_code = function_utils.get_func_code(dev_func) 

139 if func_code: 

140 fname = func_code.co_filename 

141 lineno = func_code.co_firstlineno 

142 else: 

143 fname = "unknown" 

144 lineno = -1 

145 self.display_name = "%s<%s, %d>" % (func_name, fname, lineno) 

146 

147 elif device_name_or_function is None: 

148 # NOTE(taylorrobie): This MUST be False. None signals a break in the 

149 # device stack, so `is_null_merge` must be False for such a case to 

150 # allow callers to safely skip over null merges without missing a None. 

151 self.is_null_merge = False 

152 

153 else: 

154 self.raw_string = device_name_or_function 

155 self.function = pydev.merge_device(device_name_or_function) 

156 self.is_null_merge = self.function.is_null_merge 

157 

158 # We perform this check in __init__ because it is of non-trivial cost, 

159 # and self.string_merge is typically called many times. 

160 self.fast_string_merge = isinstance(self.function, pydev.MergeDevice) 

161 

162 def string_merge(self, node_def): 

163 if self.fast_string_merge: 

164 return self.function.shortcut_string_merge(node_def) 

165 

166 return compat.as_str(_device_string(self.function(node_def))) 

167 
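# Illustrative sketch (not part of the original module; names are hypothetical):
# a "device function" is any callable mapping an op/NodeDef to a device string,
# typically supplied through `Graph.device` or `tf.device`:
#
#   import tensorflow as tf
#
#   def _place_matmuls_on_gpu(op):
#     # Hypothetical policy: MatMul nodes on GPU:0, everything else on CPU.
#     return "/device:GPU:0" if op.type == "MatMul" else "/device:CPU:0"
#
#   g = tf.Graph()
#   with g.as_default(), g.device(_place_matmuls_on_gpu):
#     c = tf.matmul(tf.ones([2, 2]), tf.ones([2, 2]))
#
# _UserDeviceSpec wraps either such a callable or a plain device string and
# precomputes whether the fast MergeDevice path in string_merge() applies.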

168 

169class NullContextmanager(object): 

170 

171 def __init__(self, *args, **kwargs): 

172 pass 

173 

174 def __enter__(self): 

175 pass 

176 

177 def __exit__(self, type_arg, value_arg, traceback_arg): 

178 return False # False values do not suppress exceptions 

179 
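# Illustrative sketch (an assumption about usage, not from the original file):
# NullContextmanager is a no-op stand-in used where a real context manager is
# optional, e.g.:
#
#   cm = device_cm if device_cm is not None else NullContextmanager()
#   with cm:
#     ...  # body runs unchanged; exceptions still propagate (__exit__ is False)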

180 

181def _override_helper(clazz_object, operator, func): 

182 """Overrides (string) operator on Tensors to call func. 

183 

184 Args: 

185 clazz_object: the class to override for; either Tensor or SparseTensor. 

186 operator: the string name of the operator to override. 

187 func: the function that replaces the overridden operator. 

188 

189 Raises: 

190 ValueError: If operator is not allowed to be overwritten. 

191 """ 

192 if operator not in Tensor.OVERLOADABLE_OPERATORS: 

193 raise ValueError(f"Overriding {operator} is disallowed. " 

194 f"Allowed operators are {Tensor.OVERLOADABLE_OPERATORS}.") 

195 setattr(clazz_object, operator, func) 

196 
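# Illustrative sketch (hypothetical names, not part of this module): operator
# overloads are installed through this helper. For example, a math module could
# register an addition override roughly as follows:
#
#   def _add_dispatch(x, y, name=None):
#     return gen_math_ops.add_v2(x, y, name=name)   # assumes gen_math_ops import
#
#   Tensor._override_operator("__add__", _add_dispatch)
#
# Attempting to override an operator outside OVERLOADABLE_OPERATORS (for
# example "__len__") raises the ValueError above.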

197 

198def _as_graph_element(obj): 

199 """Convert `obj` to a graph element if possible, otherwise return `None`. 

200 

201 Args: 

202 obj: Object to convert. 

203 

204 Returns: 

205 The result of `obj._as_graph_element()` if that method is available; 

206 otherwise `None`. 

207 """ 

208 conv_fn = getattr(obj, "_as_graph_element", None) 

209 if conv_fn and callable(conv_fn): 

210 return conv_fn() 

211 return None 

212 

213 

214# Deprecated - do not use. 

215# This API exists to avoid breaking estimator and tensorflow-mesh, which depend on this 

216# internal API. The stub should be safe to use after TF 2.3 is released. 

217def is_dense_tensor_like(t): 

218 return isinstance(t, core_tf_types.Tensor) 

219 

220 

221def uid(): 

222 """A unique (within this program execution) integer.""" 

223 return pywrap_tfe.TFE_Py_UID() 

224 

225 

226def numpy_text(tensor, is_repr=False): 

227 """Human readable representation of a tensor's numpy value.""" 

228 if tensor.dtype.is_numpy_compatible: 

229 # pylint: disable=protected-access 

230 text = repr(tensor._numpy()) if is_repr else str(tensor._numpy()) 

231 # pylint: enable=protected-access 

232 else: 

233 text = "<unprintable>" 

234 if "\n" in text: 

235 text = "\n" + text 

236 return text 

237 

238 

239def value_text(tensor, is_repr=False): 

240 """Either the NumPy value or a custom TensorFlow formatting of `tensor`. 

241 

242 Custom formatting is used for custom device tensors, e.g. parallel tensors 

243 with multiple components on different devices. 

244 

245 Args: 

246 tensor: The tensor to format. 

247 is_repr: Controls the style/verbosity of formatting. 

248 

249 Returns: 

250 The formatted tensor. 

251 """ 

252 # pylint: disable=protected-access # friend access 

253 if tensor._prefer_custom_summarizer(): 

254 text = tensor._summarize_value() 

255 # pylint: enable=protected-access 

256 if is_repr: 

257 text = "value=" + text 

258 else: 

259 text = numpy_text(tensor, is_repr=is_repr) 

260 if is_repr: 

261 text = "numpy=" + text 

262 return text 

263 

264 

265@tf_export(v1=["enable_tensor_equality"]) 

266def enable_tensor_equality(): 

267 """Compare Tensors with element-wise comparison and thus be unhashable. 

268 

269 Comparing tensors element-wise allows comparisons such as 

270 tf.Variable(1.0) == 1.0. Element-wise equality implies that tensors are 

271 unhashable. Thus tensors can no longer be directly used in sets or as a key in 

272 a dictionary. 

273 """ 

274 logging.vlog(1, "Enabling tensor equality") 

275 _tensor_equality_api_usage_gauge.get_cell().set(True) 

276 Tensor._USE_EQUALITY = True # pylint: disable=protected-access 

277 

278 

279@tf_export(v1=["disable_tensor_equality"]) 

280def disable_tensor_equality(): 

281 """Compare Tensors by their id and be hashable. 

282 

283 This is a legacy behaviour of TensorFlow and is highly discouraged. 

284 """ 

285 logging.vlog(1, "Disabling tensor equality") 

286 _tensor_equality_api_usage_gauge.get_cell().set(False) 

287 Tensor._USE_EQUALITY = False # pylint: disable=protected-access 

288 
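# Illustrative sketch of the behavioral difference (assumes the TF1 compat API
# and eager execution; not part of the original file):
#
#   import tensorflow as tf
#
#   tf.compat.v1.enable_tensor_equality()
#   print(tf.constant([1, 2]) == tf.constant([1, 2]))  # element-wise: [True, True]
#   # hash(tf.constant(1))                             # raises TypeError
#
#   tf.compat.v1.disable_tensor_equality()
#   print(tf.constant([1, 2]) == tf.constant([1, 2]))  # id-based: False
#   # hash(tf.constant(1))                             # works again (id-based)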

289 

290# Tensor subclassing has historically been a mess. 

291# 

292# There is no "Tensor" base class for Graph & Eager tensors. Instead, when we 

293# introduced EagerTensor, we had it subclass the graph "Tensor" class, and 

294# override a bunch of behavior. Introducing a proper subclassing relationship 

295# is complicated because many users check for type(t) == Tensor or isinstance. 

296# 

297# This is done internally for "bad" reasons as a way to separate out Graph and 

298# Eager tensors, or subclasses which "look like" Tensor, e.g. distribute.Value. 

299# 

300# For now, we work around this by deferring initialization of graph tensors to 

301# a separate `_init` method. `GraphTensor` is a free function, not a class, that 

302# returns a Tensor object. 

303# 

304# b(XXX) -- fix type(t) == Tensor checks in the code base 

305@tf_export("Tensor", "experimental.numpy.ndarray", v1=["Tensor"]) 

306class Tensor( 

307 pywrap_tf_session.PyTensor, internal.NativeObject, core_tf_types.Symbol 

308): 

309 """A `tf.Tensor` represents a multidimensional array of elements. 

310 

311 All elements are of a single known data type. 

312 

313 When writing a TensorFlow program, the main object that is 

314 manipulated and passed around is the `tf.Tensor`. 

315 

316 A `tf.Tensor` has the following properties: 

317 

318 * a single data type (float32, int32, or string, for example) 

319 * a shape 

320 

321 TensorFlow supports eager execution and graph execution. In eager 

322 execution, operations are evaluated immediately. In graph 

323 execution, a computational graph is constructed for later 

324 evaluation. 

325 

326 TensorFlow defaults to eager execution. In the example below, the 

327 matrix multiplication results are calculated immediately. 

328 

329 >>> # Compute some values using a Tensor 

330 >>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]]) 

331 >>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]]) 

332 >>> e = tf.matmul(c, d) 

333 >>> print(e) 

334 tf.Tensor( 

335 [[1. 3.] 

336 [3. 7.]], shape=(2, 2), dtype=float32) 

337 

338 Note that during eager execution, you may discover your `Tensors` are actually 

339 of type `EagerTensor`. This is an internal detail, but it does give you 

340 access to a useful function, `numpy`: 

341 

342 >>> type(e) 

343 <class '...ops.EagerTensor'> 

344 >>> print(e.numpy()) 

345 [[1. 3.] 

346 [3. 7.]] 

347 

348 In TensorFlow, `tf.function`s are a common way to define graph execution. 

349 

350 A Tensor's shape (that is, the rank of the Tensor and the size of 

351 each dimension) may not always be fully known. In `tf.function` 

352 definitions, the shape may only be partially known. 

353 

354 Most operations produce tensors of fully-known shapes if the shapes of their 

355 inputs are also fully known, but in some cases it's only possible to find the 

356 shape of a tensor at execution time. 

357 

358 A number of specialized tensors are available: see `tf.Variable`, 

359 `tf.constant`, `tf.placeholder`, `tf.sparse.SparseTensor`, and 

360 `tf.RaggedTensor`. 

361 

362 Caution: when constructing a tensor from a numpy array or pandas dataframe 

363 the underlying buffer may be re-used: 

364 

365 ```python 

366 a = np.array([1, 2, 3]) 

367 b = tf.constant(a) 

368 a[0] = 4 

369 print(b) # tf.Tensor([4 2 3], shape=(3,), dtype=int64) 

370 ``` 

371 

372 Note: this is an implementation detail that is subject to change and users 

373 should not rely on this behaviour. 

374 

375 For more on Tensors, see the [guide](https://tensorflow.org/guide/tensor). 

376 """ 

377 # List of Python operators that we allow to override. 

378 OVERLOADABLE_OPERATORS = { 

379 # Binary. 

380 "__add__", 

381 "__radd__", 

382 "__sub__", 

383 "__rsub__", 

384 "__mul__", 

385 "__rmul__", 

386 "__div__", 

387 "__rdiv__", 

388 "__truediv__", 

389 "__rtruediv__", 

390 "__floordiv__", 

391 "__rfloordiv__", 

392 "__mod__", 

393 "__rmod__", 

394 "__lt__", 

395 "__le__", 

396 "__gt__", 

397 "__ge__", 

398 "__ne__", 

399 "__eq__", 

400 "__and__", 

401 "__rand__", 

402 "__or__", 

403 "__ror__", 

404 "__xor__", 

405 "__rxor__", 

406 "__getitem__", 

407 "__pow__", 

408 "__rpow__", 

409 # Unary. 

410 "__invert__", 

411 "__neg__", 

412 "__abs__", 

413 "__matmul__", 

414 "__rmatmul__" 

415 } 

416 

417 # Whether to allow hashing or numpy-style equality 

418 _USE_EQUALITY = tf2.enabled() 

419 

420 def __getattr__(self, name): 

421 if name in {"T", "astype", "ravel", "transpose", "reshape", "clip", "size", 

422 "tolist", "data"}: 

423 # TODO(wangpeng): Export the enable_numpy_behavior knob 

424 raise AttributeError( 

425 f"{type(self).__name__} object has no attribute '{name}'. " + """ 

426 If you are looking for numpy-related methods, please run the following: 

427 from tensorflow.python.ops.numpy_ops import np_config 

428 np_config.enable_numpy_behavior() 

429 """) 

430 self.__getattribute__(name) 

431 
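  # Illustrative sketch (assumes TF 2.x; not part of the original file): after
  # enabling the NumPy behavior knob referenced in the error message above,
  # these attributes become available on tf.Tensor:
  #
  #   from tensorflow.python.ops.numpy_ops import np_config
  #   np_config.enable_numpy_behavior()
  #
  #   t = tf.constant([[1, 2], [3, 4]])
  #   t.T            # transpose, no longer raises AttributeError
  #   t.reshape(4)   # NumPy-style reshape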

432 @property 

433 def dtype(self): 

434 """The `DType` of elements in this tensor.""" 

435 return self._dtype 

436 

437 @property 

438 def name(self): 

439 """The string name of this tensor.""" 

440 if self._name is None: 

441 assert self._op.name 

442 self._name = "%s:%d" % (self._op.name, self.value_index) 

443 return self._name 

444 

445 @property 

446 def shape(self): 

447 """Returns a `tf.TensorShape` that represents the shape of this tensor. 

448 

449 >>> t = tf.constant([1,2,3,4,5]) 

450 >>> t.shape 

451 TensorShape([5]) 

452 

453 `tf.Tensor.shape` is equivalent to `tf.Tensor.get_shape()`. 

454 

455 In a `tf.function` or when building a model using 

456 `tf.keras.Input`, they return the build-time shape of the 

457 tensor, which may be partially unknown. 

458 

459 A `tf.TensorShape` is not a tensor. Use `tf.shape(t)` to get a tensor 

460 containing the shape, calculated at runtime. 

461 

462 See `tf.Tensor.get_shape()`, and `tf.TensorShape` for details and examples. 

463 """ 

464 if self._shape_val is None: 

465 dims, unknown_shape = self._shape 

466 if unknown_shape: 

467 self._shape_val = tensor_shape.unknown_shape() 

468 else: 

469 self._shape_val = tensor_shape.TensorShape(dims) 

470 return self._shape_val 

471 

472 @property 

473 def ndim(self): 

474 return self.shape.rank 

475 

476 def _disallow_when_autograph_unavailable(self, task): 

477 raise errors.OperatorNotAllowedInGraphError( 

478 f"{task} is not allowed: AutoGraph is unavailable in this runtime. See" 

479 " https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code" 

480 " for more information.") 

481 

482 def _disallow_when_autograph_disabled(self, task): 

483 raise errors.OperatorNotAllowedInGraphError( 

484 f"{task} is not allowed: AutoGraph is disabled in this function." 

485 " Try decorating it directly with @tf.function.") 

486 

487 def _disallow_when_autograph_enabled(self, task): 

488 raise errors.OperatorNotAllowedInGraphError( 

489 f"{task} is not allowed: AutoGraph did convert this function. This" 

490 " might indicate you are trying to use an unsupported feature.") 

491 

492 def _disallow_in_graph_mode(self, task): 

493 raise errors.OperatorNotAllowedInGraphError( 

494 f"{task} is not allowed in Graph execution. Use Eager execution or" 

495 " decorate this function with @tf.function.") 

496 

497 def _disallow_bool_casting(self): 

498 if not ag_ctx.INSPECT_SOURCE_SUPPORTED: 

499 self._disallow_when_autograph_unavailable( 

500 "Using a symbolic `tf.Tensor` as a Python `bool`") 

501 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED: 

502 self._disallow_when_autograph_disabled( 

503 "Using a symbolic `tf.Tensor` as a Python `bool`") 

504 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED: 

505 self._disallow_when_autograph_enabled( 

506 "Using a symbolic `tf.Tensor` as a Python `bool`") 

507 else: 

508 # Default: V1-style Graph execution. 

509 self._disallow_in_graph_mode( 

510 "Using a symbolic `tf.Tensor` as a Python `bool`") 

511 

512 def _disallow_iteration(self): 

513 if not ag_ctx.INSPECT_SOURCE_SUPPORTED: 

514 self._disallow_when_autograph_unavailable( 

515 "Iterating over a symbolic `tf.Tensor`") 

516 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED: 

517 self._disallow_when_autograph_disabled( 

518 "Iterating over a symbolic `tf.Tensor`") 

519 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED: 

520 self._disallow_when_autograph_enabled( 

521 "Iterating over a symbolic `tf.Tensor`") 

522 else: 

523 # Default: V1-style Graph execution. 

524 self._disallow_in_graph_mode("Iterating over a symbolic `tf.Tensor`") 

525 

526 def __iter__(self): 

527 if not context.executing_eagerly(): 

528 self._disallow_iteration() 

529 

530 shape = self._shape_tuple() 

531 if shape is None: 

532 raise TypeError("Cannot iterate over a tensor with unknown shape.") 

533 if not shape: 

534 raise TypeError("Cannot iterate over a scalar tensor.") 

535 if shape[0] is None: 

536 raise TypeError( 

537 "Cannot iterate over a tensor with unknown first dimension.") 

538 return _TensorIterator(self, shape[0]) 

539 

540 def _shape_as_list(self): 

541 if self.shape.ndims is not None: 

542 return [dim.value for dim in self.shape.dims] 

543 else: 

544 return None 

545 

546 def _shape_tuple(self): 

547 shape = self._shape_as_list() 

548 if shape is None: 

549 return None 

550 return tuple(shape) 

551 

552 def _record_tape(self, capture): 

553 """Connect this graph tensor with capture for gradients calculation.""" 

554 record.record_operation( 

555 "captured_value", 

556 [self], [capture], 

557 backward_function=lambda x: [x], 

558 forward_function=lambda x: [x]) 

559 

560 def get_shape(self): 

561 """Returns a `tf.TensorShape` that represents the shape of this tensor. 

562 

563 In eager execution the shape is always fully-known. 

564 

565 >>> a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) 

566 >>> print(a.shape) 

567 (2, 3) 

568 

569 `tf.Tensor.get_shape()` is equivalent to `tf.Tensor.shape`. 

570 

571 

572 When executing in a `tf.function` or building a model using 

573 `tf.keras.Input`, `Tensor.shape` may return a partial shape (including 

574 `None` for unknown dimensions). See `tf.TensorShape` for more details. 

575 

576 >>> inputs = tf.keras.Input(shape = [10]) 

577 >>> # Unknown batch size 

578 >>> print(inputs.shape) 

579 (None, 10) 

580 

581 The shape is computed using shape inference functions that are 

582 registered for each `tf.Operation`. 

583 

584 The returned `tf.TensorShape` is determined at *build* time, without 

585 executing the underlying kernel. It is not a `tf.Tensor`. If you need a 

586 shape *tensor*, either convert the `tf.TensorShape` to a `tf.constant`, or 

587 use the `tf.shape(tensor)` function, which returns the tensor's shape at 

588 *execution* time. 

589 

590 This is useful for debugging and providing early errors. For 

591 example, when tracing a `tf.function`, no ops are being executed, shapes 

592 may be unknown (See the [Concrete Functions 

593 Guide](https://www.tensorflow.org/guide/concrete_function) for details). 

594 

595 >>> @tf.function 

596 ... def my_matmul(a, b): 

597 ... result = a@b 

598 ... # the `print` executes during tracing. 

599 ... print("Result shape: ", result.shape) 

600 ... return result 

601 

602 The shape inference functions propagate shapes to the extent possible: 

603 

604 >>> f = my_matmul.get_concrete_function( 

605 ... tf.TensorSpec([None,3]), 

606 ... tf.TensorSpec([3,5])) 

607 Result shape: (None, 5) 

608 

609 Tracing may fail if a shape mismatch can be detected: 

610 

611 >>> cf = my_matmul.get_concrete_function( 

612 ... tf.TensorSpec([None,3]), 

613 ... tf.TensorSpec([4,5])) 

614 Traceback (most recent call last): 

615 ... 

616 ValueError: Dimensions must be equal, but are 3 and 4 for 'matmul' (op: 

617 'MatMul') with input shapes: [?,3], [4,5]. 

618 

619 In some cases, the inferred shape may have unknown dimensions. If 

620 the caller has additional information about the values of these 

621 dimensions, `tf.ensure_shape` or `Tensor.set_shape()` can be used to augment 

622 the inferred shape. 

623 

624 >>> @tf.function 

625 ... def my_fun(a): 

626 ... a = tf.ensure_shape(a, [5, 5]) 

627 ... # the `print` executes during tracing. 

628 ... print("Result shape: ", a.shape) 

629 ... return a 

630 

631 >>> cf = my_fun.get_concrete_function( 

632 ... tf.TensorSpec([None, None])) 

633 Result shape: (5, 5) 

634 

635 Returns: 

636 A `tf.TensorShape` representing the shape of this tensor. 

637 

638 """ 

639 return self.shape 

640 

641 def set_shape(self, shape): 

642 """Updates the shape of this tensor. 

643 

644 Note: It is recommended to use `tf.ensure_shape` instead of 

645 `Tensor.set_shape`, because `tf.ensure_shape` provides better checking for 

646 programming errors and can create guarantees for compiler 

647 optimization. 

648 

649 With eager execution this operates as a shape assertion. 

650 Here the shapes match: 

651 

652 >>> t = tf.constant([[1,2,3]]) 

653 >>> t.set_shape([1, 3]) 

654 

655 Passing a `None` in the new shape allows any value for that axis: 

656 

657 >>> t.set_shape([1,None]) 

658 

659 An error is raised if an incompatible shape is passed. 

660 

661 >>> t.set_shape([1,5]) 

662 Traceback (most recent call last): 

663 ... 

664 ValueError: Tensor's shape (1, 3) is not compatible with supplied 

665 shape [1, 5] 

666 

667 When executing in a `tf.function`, or building a model using 

668 `tf.keras.Input`, `Tensor.set_shape` will *merge* the given `shape` with 

669 the current shape of this tensor, and set the tensor's shape to the 

670 merged value (see `tf.TensorShape.merge_with` for details): 

671 

672 >>> t = tf.keras.Input(shape=[None, None, 3]) 

673 >>> print(t.shape) 

674 (None, None, None, 3) 

675 

676 Dimensions set to `None` are not updated: 

677 

678 >>> t.set_shape([None, 224, 224, None]) 

679 >>> print(t.shape) 

680 (None, 224, 224, 3) 

681 

682 The main use case for this is to provide additional shape information 

683 that cannot be inferred from the graph alone. 

684 

685 For example, if you know all the images in a dataset have shape [28,28,3] you 

686 can set it with `set_shape`: 

687 

688 >>> @tf.function 

689 ... def load_image(filename): 

690 ... raw = tf.io.read_file(filename) 

691 ... image = tf.image.decode_png(raw, channels=3) 

692 ... # the `print` executes during tracing. 

693 ... print("Initial shape: ", image.shape) 

694 ... image.set_shape([28, 28, 3]) 

695 ... print("Final shape: ", image.shape) 

696 ... return image 

697 

698 Trace the function, see the [Concrete Functions 

699 Guide](https://www.tensorflow.org/guide/concrete_function) for details. 

700 

701 >>> cf = load_image.get_concrete_function( 

702 ... tf.TensorSpec([], dtype=tf.string)) 

703 Initial shape: (None, None, 3) 

704 Final shape: (28, 28, 3) 

705 

706 Similarly, the `tf.io.parse_tensor` function could return a tensor with 

707 any shape, even one whose `tf.rank` is unknown. If you know that all your 

708 serialized tensors will be 2d, set it with `set_shape`: 

709 

710 >>> @tf.function 

711 ... def my_parse(string_tensor): 

712 ... result = tf.io.parse_tensor(string_tensor, out_type=tf.float32) 

713 ... # the `print` executes during tracing. 

714 ... print("Initial shape: ", result.shape) 

715 ... result.set_shape([None, None]) 

716 ... print("Final shape: ", result.shape) 

717 ... return result 

718 

719 Trace the function 

720 

721 >>> concrete_parse = my_parse.get_concrete_function( 

722 ... tf.TensorSpec([], dtype=tf.string)) 

723 Initial shape: <unknown> 

724 Final shape: (None, None) 

725 

726 Make sure it works: 

727 

728 >>> t = tf.ones([5,3], dtype=tf.float32) 

729 >>> serialized = tf.io.serialize_tensor(t) 

730 >>> print(serialized.dtype) 

731 <dtype: 'string'> 

732 >>> print(serialized.shape) 

733 () 

734 >>> t2 = concrete_parse(serialized) 

735 >>> print(t2.shape) 

736 (5, 3) 

737 

738 Caution: `set_shape` ensures that the applied shape is compatible with 

739 the existing shape, but it does not check at runtime. Setting 

740 incorrect shapes can result in inconsistencies between the 

741 statically-known graph and the runtime value of tensors. For runtime 

742 validation of the shape, use `tf.ensure_shape` instead; it also updates 

743 the statically-known `shape` of the tensor. 

744 

745 >>> # Serialize a rank-3 tensor 

746 >>> t = tf.ones([5,5,5], dtype=tf.float32) 

747 >>> serialized = tf.io.serialize_tensor(t) 

748 >>> # The function still runs, even though it `set_shape([None,None])` 

749 >>> t2 = concrete_parse(serialized) 

750 >>> print(t2.shape) 

751 (5, 5, 5) 

752 

753 Args: 

754 shape: A `TensorShape` representing the shape of this tensor, a 

755 `TensorShapeProto`, a list, a tuple, or None. 

756 

757 Raises: 

758 ValueError: If `shape` is not compatible with the current shape of 

759 this tensor. 

760 """ 

761 # Reset cached shape. 

762 self._shape_val = None 

763 

764 # We want set_shape to be reflected in the C API graph for when we run it. 

765 if not isinstance(shape, tensor_shape.TensorShape): 

766 shape = tensor_shape.TensorShape(shape) 

767 dim_list = [] 

768 if shape.dims is None: 

769 unknown_shape = True 

770 else: 

771 unknown_shape = False 

772 for dim in shape.dims: 

773 if dim.value is None: 

774 dim_list.append(-1) 

775 else: 

776 dim_list.append(dim.value) 

777 self._set_shape(dim_list, unknown_shape) 

778 

779 def _as_node_def_input(self): 

780 """Return a value to use for the NodeDef "input" attribute. 

781 

782 The returned string can be used in a NodeDef "input" attribute 

783 to indicate that the NodeDef uses this Tensor as input. 

784 

785 Raises: 

786 ValueError: if this Tensor's Operation does not have a name. 

787 

788 Returns: 

789 a string. 

790 """ 

791 assert self._op.name 

792 if self.value_index == 0: 

793 return self._op.name 

794 else: 

795 return "%s:%d" % (self._op.name, self.value_index) 

796 

797 def __str__(self): 

798 return "Tensor(\"%s\"%s%s%s)" % ( 

799 self.name, 

800 (", shape=%s" % 

801 self.get_shape()) if self.get_shape().ndims is not None else "", 

802 (", dtype=%s" % self._dtype.name) if self._dtype else "", 

803 (", device=%s" % self.device) if self.device else "") 

804 

805 def __repr__(self): 

806 return "<tf.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.get_shape(), 

807 self._dtype.name) 

808 

809 def __hash__(self): 

810 g = getattr(self, "graph", None) 

811 if (Tensor._USE_EQUALITY and (g is None or g.building_function)): 

812 raise TypeError("Tensor is unhashable. " 

813 "Instead, use tensor.ref() as the key.") 

814 else: 

815 return id(self) 

816 

817 def __copy__(self): 

818 # TODO(b/77597810): get rid of Tensor copies. 

819 cls = self.__class__ 

820 result = cls.__new__(cls) 

821 result._init(self.op, self.value_index, self.dtype) 

822 result.__dict__.update(self.__dict__) 

823 return result 

824 

825 # NOTE(mrry): This enables the Tensor's overloaded "right" binary 

826 # operators to run when the left operand is an ndarray, because it 

827 # accords the Tensor class higher priority than an ndarray, or a 

828 # numpy matrix. 

829 # TODO(mrry): Convert this to using numpy's __numpy_ufunc__ 

830 # mechanism, which allows more control over how Tensors interact 

831 # with ndarrays. 

832 __array_priority__ = 100 

833 

834 def __array__(self, dtype=None): 

835 del dtype 

836 raise NotImplementedError( 

837 f"Cannot convert a symbolic tf.Tensor ({self.name}) to a numpy array." 

838 f" This error may indicate that you're trying to pass a Tensor to" 

839 f" a NumPy call, which is not supported.") 

840 

841 def __len__(self): 

842 raise TypeError(f"len is not well defined for a symbolic Tensor " 

843 f"({self.name}). Please call `x.shape` rather than " 

844 f"`len(x)` for shape information.") 

845 

846 # TODO(mdan): This convoluted machinery is hard to maintain. Clean up. 

847 @staticmethod 

848 def _override_operator(operator, func): 

849 _override_helper(Tensor, operator, func) 

850 

851 def __bool__(self): 

852 """Dummy method to prevent a tensor from being used as a Python `bool`. 

853 

854 This overload raises a `TypeError` when the user inadvertently 

855 treats a `Tensor` as a boolean (most commonly in an `if` or `while` 

856 statement), in code that was not converted by AutoGraph. For example: 

857 

858 ```python 

859 if tf.constant(True): # Will raise. 

860 # ... 

861 

862 if tf.constant(5) < tf.constant(7): # Will raise. 

863 # ... 

864 ``` 

865 

866 Raises: 

867 `TypeError`. 

868 """ 

869 self._disallow_bool_casting() 

870 

871 def __nonzero__(self): 

872 """Dummy method to prevent a tensor from being used as a Python `bool`. 

873 

874 This is the Python 2.x counterpart to `__bool__()` above. 

875 

876 Raises: 

877 `TypeError`. 

878 """ 

879 self._disallow_bool_casting() 

880 

881 def eval(self, feed_dict=None, session=None): 

882 """Evaluates this tensor in a `Session`. 

883 

884 Note: If you are not using `compat.v1` libraries, you should not need this, 

885 (or `feed_dict` or `Session`). In eager execution (or within `tf.function`) 

886 you do not need to call `eval`. 

887 

888 Calling this method will execute all preceding operations that 

889 produce the inputs needed for the operation that produces this 

890 tensor. 

891 

892 *N.B.* Before invoking `Tensor.eval()`, its graph must have been 

893 launched in a session, and either a default session must be 

894 available, or `session` must be specified explicitly. 

895 

896 Args: 

897 feed_dict: A dictionary that maps `Tensor` objects to feed values. See 

898 `tf.Session.run` for a description of the valid feed values. 

899 session: (Optional.) The `Session` to be used to evaluate this tensor. If 

900 none, the default session will be used. 

901 

902 Returns: 

903 A numpy array corresponding to the value of this tensor. 

904 """ 

905 return _eval_using_default_session(self, feed_dict, self.graph, session) 

906 

907 @deprecation.deprecated(None, "Use ref() instead.") 

908 def experimental_ref(self): 

909 return self.ref() 

910 

911 def ref(self): 

912 # tf.Variable also has the same ref() API. If you update the 

913 # documentation here, please update tf.Variable.ref() as well. 

914 """Returns a hashable reference object to this Tensor. 

915 

916 The primary use case for this API is to put tensors in a set/dictionary. 

917 We can't put tensors in a set/dictionary as `tensor.__hash__()` is no longer 

918 available starting in TensorFlow 2.0. 

919 

920 The following will raise an exception starting in TensorFlow 2.0: 

921 

922 >>> x = tf.constant(5) 

923 >>> y = tf.constant(10) 

924 >>> z = tf.constant(10) 

925 >>> tensor_set = {x, y, z} 

926 Traceback (most recent call last): 

927 ... 

928 TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key. 

929 >>> tensor_dict = {x: 'five', y: 'ten'} 

930 Traceback (most recent call last): 

931 ... 

932 TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key. 

933 

934 Instead, we can use `tensor.ref()`. 

935 

936 >>> tensor_set = {x.ref(), y.ref(), z.ref()} 

937 >>> x.ref() in tensor_set 

938 True 

939 >>> tensor_dict = {x.ref(): 'five', y.ref(): 'ten', z.ref(): 'ten'} 

940 >>> tensor_dict[y.ref()] 

941 'ten' 

942 

943 Also, the reference object provides `.deref()` function that returns the 

944 original Tensor. 

945 

946 >>> x = tf.constant(5) 

947 >>> x.ref().deref() 

948 <tf.Tensor: shape=(), dtype=int32, numpy=5> 

949 """ 

950 return object_identity.Reference(self) 

951 

952 def __tf_tracing_type__(self, signature_context): 

953 if self.dtype == dtypes.resource or self.dtype == dtypes.variant: 

954 handle_data = handle_data_util.get_handle_data(self) 

955 dtype = dtypes.DType(self.dtype._type_enum, handle_data) 

956 else: 

957 dtype = self.dtype 

958 spec = tensor_spec.TensorSpec(self.shape, dtype) 

959 return spec 

960 

961 def __tf_tensor__( 

962 self, dtype: Optional[dtypes.DType] = None, name: Optional[str] = None 

963 ) -> "Tensor": 

964 if dtype is not None and not dtype.is_compatible_with(self.dtype): 

965 raise ValueError( 

966 _add_error_prefix( 

967 f"Tensor conversion requested dtype {dtype.name} " 

968 f"for Tensor with dtype {self.dtype.name}: {self!r}", 

969 name=name)) 

970 return self 

971 

972 

973def GraphTensor(op, value_index, dtype): 

974 """Creates a new `Tensor` in a graph. 

975 

976 Args: 

977 op: An `Operation`. `Operation` that computes this tensor. 

978 value_index: An `int`. Index of the operation's endpoint that produces this 

979 tensor. 

980 dtype: A `DType`. Type of elements stored in this tensor. 

981 

982 Returns: 

983 A Tensor object. 

984 

985 Raises: 

986 TypeError: If the op is not an `Operation`. 

987 """ 

988 self = Tensor() 

989 # pylint: disable=protected-access 

990 self._init(op, value_index, dtype) 

991 self._dtype = dtypes.as_dtype(dtype) 

992 

993 # This will be set by self.shape(). 

994 self._shape_val = None 

995 self._name = None 

996 self._id = uid() 

997 # pylint: enable=protected-access 

998 return self 

999 
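# Illustrative sketch (internal usage inferred from context; names assumed):
# GraphTensor is called when an Operation materializes its output endpoints,
# roughly one call per output slot:
#
#   # for each output slot i of a newly created op:
#   #   tensor_i = GraphTensor(op, i, output_dtypes[i])
#
# User code should not call it directly; tensors are obtained from op.outputs.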

1000 

1001def _create_graph_constant( 

1002 value, dtype, shape, name, verify_shape, allow_broadcast 

1003): 

1004 """Create a graph constant and invoke constant callbacks.""" 

1005 g = get_default_graph() 

1006 tensor_value = attr_value_pb2.AttrValue() 

1007 tensor_value.tensor.CopyFrom( 

1008 tensor_util.make_tensor_proto( 

1009 value, dtype=dtype, shape=shape, verify_shape=verify_shape, 

1010 allow_broadcast=allow_broadcast)) 

1011 dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype) 

1012 attrs = {"value": tensor_value, "dtype": dtype_value} 

1013 const_tensor = g._create_op_internal( # pylint: disable=protected-access 

1014 "Const", [], [dtype_value.type], attrs=attrs, name=name).outputs[0] 

1015 

1016 if op_callbacks.should_invoke_op_callbacks(): 

1017 # TODO(b/147670703): Once the special-op creation code paths 

1018 # are unified, remove this `if` block. 

1019 callback_outputs = op_callbacks.invoke_op_callbacks( 

1020 "Const", tuple(), attrs, (const_tensor,), op_name=name, graph=g) 

1021 if callback_outputs is not None: 

1022 [const_tensor] = callback_outputs 

1023 return const_tensor 

1024 
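# Illustrative sketch (assumes the public TF 2.x API; not part of this file):
# when building a graph, tf.constant produces a "Const" node through a path
# equivalent to _create_graph_constant:
#
#   import tensorflow as tf
#
#   g = tf.Graph()
#   with g.as_default():
#     c = tf.constant([1.0, 2.0], name="my_const")
#   print(c.op.type)   # -> "Const"
#   print(c.op.name)   # -> "my_const"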

1025 

1026class _EagerTensorBase(Tensor, internal.NativeObject, core_tf_types.Value): 

1027 """Base class for EagerTensor.""" 

1028 

1029 # __complex__, __int__, __float__ and __index__ may copy the tensor to CPU and 

1030 # only work for scalars; values are cast as per numpy. 

1031 def __complex__(self): 

1032 return complex(self._numpy()) 

1033 

1034 def __int__(self): 

1035 return int(self._numpy()) 

1036 

1037 def __long__(self): 

1038 return long(self._numpy()) 

1039 

1040 def __float__(self): 

1041 return float(self._numpy()) 

1042 

1043 def __index__(self): 

1044 return self._numpy().__index__() 

1045 

1046 def __bool__(self): 

1047 return bool(self._numpy()) 

1048 

1049 __nonzero__ = __bool__ 

1050 

1051 def __format__(self, format_spec): 

1052 if self._prefer_custom_summarizer(): 

1053 return self._summarize_value().__format__(format_spec) 

1054 elif self.dtype.is_numpy_compatible: 

1055 # Not numpy_text here, otherwise the __format__ behaves differently. 

1056 return self._numpy().__format__(format_spec) 

1057 else: 

1058 return "<unprintable>".__format__(format_spec) 

1059 

1060 def __reduce__(self): 

1061 return convert_to_tensor, (self._numpy(),) 

1062 

1063 def __copy__(self): 

1064 # Eager Tensors are immutable so it's safe to return themselves as a copy. 

1065 return self 

1066 

1067 def __deepcopy__(self, memo): 

1068 # Eager Tensors are immutable so it's safe to return themselves as a copy. 

1069 del memo 

1070 return self 

1071 

1072 def __str__(self): 

1073 return "tf.Tensor(%s, shape=%s, dtype=%s)" % ( 

1074 value_text(self, is_repr=False), self.shape, self.dtype.name) 

1075 

1076 def __repr__(self): 

1077 return "<tf.Tensor: shape=%s, dtype=%s, %s>" % ( 

1078 self.shape, self.dtype.name, value_text(self, is_repr=True)) 

1079 

1080 def __len__(self): 

1081 """Returns the length of the first dimension in the Tensor.""" 

1082 if not self.shape.ndims: 

1083 raise TypeError("Scalar tensor has no `len()`") 

1084 # pylint: disable=protected-access 

1085 try: 

1086 return self._shape_tuple()[0] 

1087 except core._NotOkStatusException as e: 

1088 raise core._status_to_exception(e) from None 

1089 

1090 def __array__(self, dtype=None): 

1091 a = self._numpy() 

1092 if not dtype: 

1093 return a 

1094 

1095 return np.array(a, dtype=dtype) 

1096 

1097 def __hash__(self) -> int: 

1098 # EagerTensors are never hashable. 

1099 raise TypeError("Tensor is unhashable. " 

1100 "Instead, use tensor.ref() as the key.") 

1101 

1102 def _numpy_internal(self): 

1103 raise NotImplementedError() 

1104 

1105 def _numpy(self): 

1106 try: 

1107 return self._numpy_internal() 

1108 except core._NotOkStatusException as e: # pylint: disable=protected-access 

1109 raise core._status_to_exception(e) from None # pylint: disable=protected-access 

1110 

1111 @property 

1112 def dtype(self): 

1113 # Note: using the intern table directly here as this is 

1114 # performance-sensitive in some models. 

1115 return dtypes._INTERN_TABLE[self._datatype_enum()] # pylint: disable=protected-access 

1116 

1117 def numpy(self): 

1118 """Copy of the contents of this Tensor into a NumPy array or scalar. 

1119 

1120 Unlike NumPy arrays, Tensors are immutable, so this method has to copy 

1121 the contents to ensure safety. Use `memoryview` to get a readonly 

1122 view of the contents without doing a copy: 

1123 

1124 >>> t = tf.constant([42]) 

1125 >>> np.array(memoryview(t)) 

1126 array([42], dtype=int32) 

1127 

1128 Note that `memoryview` is only zero-copy for Tensors on CPU. If a Tensor 

1129 is on GPU, it will have to be transferred to CPU first in order for 

1130 `memoryview` to work. 

1131 

1132 Returns: 

1133 A NumPy array of the same shape and dtype or a NumPy scalar, if this 

1134 Tensor has rank 0. 

1135 

1136 Raises: 

1137 ValueError: If the dtype of this Tensor does not have a compatible 

1138 NumPy dtype. 

1139 """ 

1140 # TODO(slebedev): Consider avoiding a copy for non-CPU or remote tensors. 

1141 maybe_arr = self._numpy() # pylint: disable=protected-access 

1142 return maybe_arr.copy() if isinstance(maybe_arr, np.ndarray) else maybe_arr 

1143 

1144 @property 

1145 def backing_device(self): 

1146 """Returns the name of the device holding this tensor's memory. 

1147 

1148 `.backing_device` is usually the same as `.device`, which returns 

1149 the device on which the kernel of the operation that produced this tensor 

1150 ran. However, some operations can produce tensors on a different device 

1151 (e.g., an operation that executes on the GPU but produces output tensors 

1152 in host memory). 

1153 """ 

1154 raise NotImplementedError() 

1155 

1156 def _datatype_enum(self): 

1157 raise NotImplementedError() 

1158 

1159 def _shape_tuple(self): 

1160 """The shape of this Tensor, as a tuple. 

1161 

1162 This is more performant than tuple(shape().as_list()) as it avoids 

1163 two list and one object creation. Marked private for now as from an API 

1164 perspective, it would be better to have a single performant way of 

1165 getting a shape rather than exposing shape() and shape_tuple() 

1166 (and heaven forbid, shape_list() etc. as well!). Punting on that for now, 

1167 but ideally one would work things out and remove the need for this method. 

1168 

1169 Returns: 

1170 tuple with the shape. 

1171 """ 

1172 raise NotImplementedError() 

1173 

1174 def _rank(self): 

1175 """Integer rank of this Tensor. 

1176 

1177 Unlike regular Tensors, the rank is always known for EagerTensors. 

1178 

1179 This is more performant than len(self._shape_tuple()) 

1180 

1181 Returns: 

1182 Integer rank 

1183 """ 

1184 raise NotImplementedError() 

1185 

1186 def _num_elements(self): 

1187 """Number of elements of this Tensor. 

1188 

1189 Unlike regular Tensors, the number of elements is always known for 

1190 EagerTensors. 

1191 

1192 This is more performant than tensor.shape.num_elements 

1193 

1194 Returns: 

1195 Long - num elements in the tensor 

1196 """ 

1197 raise NotImplementedError() 

1198 

1199 def _copy_to_device(self, device_name): # pylint: disable=redefined-outer-name 

1200 raise NotImplementedError() 

1201 

1202 @staticmethod 

1203 def _override_operator(name, func): 

1204 setattr(_EagerTensorBase, name, func) 

1205 

1206 def _copy_nograd(self, ctx=None, device_name=None): 

1207 """Copies tensor to dest device, but doesn't record the operation.""" 

1208 # Creates a new tensor on the dest device. 

1209 if ctx is None: 

1210 ctx = context.context() 

1211 if device_name is None: 

1212 device_name = ctx.device_name 

1213 # pylint: disable=protected-access 

1214 try: 

1215 ctx.ensure_initialized() 

1216 new_tensor = self._copy_to_device(device_name) 

1217 except core._NotOkStatusException as e: 

1218 raise core._status_to_exception(e) from None 

1219 return new_tensor 

1220 

1221 def _copy(self, ctx=None, device_name=None): 

1222 """Copies tensor to dest device.""" 

1223 new_tensor = self._copy_nograd(ctx, device_name) 

1224 # Record the copy on tape and define backprop copy as well. 

1225 if context.executing_eagerly(): 

1226 self_device = self.device 

1227 

1228 def grad_fun(dresult): 

1229 return [ 

1230 dresult._copy(device_name=self_device) 

1231 if hasattr(dresult, "_copy") else dresult 

1232 ] 

1233 

1234 record.record_operation("_copy", [new_tensor], [self], grad_fun) 

1235 return new_tensor 

1236 # pylint: enable=protected-access 

1237 

1238 @property 

1239 def shape(self): 

1240 if self._tensor_shape is None: # pylint: disable=access-member-before-definition 

1241 # pylint: disable=protected-access 

1242 try: 

1243 # `_tensor_shape` is declared and defined in the definition of 

1244 # `EagerTensor`, in C. 

1245 self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple()) 

1246 except core._NotOkStatusException as e: 

1247 raise core._status_to_exception(e) from None 

1248 

1249 return self._tensor_shape 

1250 

1251 def get_shape(self): 

1252 """Alias of Tensor.shape.""" 

1253 return self.shape 

1254 

1255 def _shape_as_list(self): 

1256 """The shape of the tensor as a list.""" 

1257 return list(self._shape_tuple()) 

1258 

1259 @deprecation.deprecated( 

1260 None, "Use tf.identity with explicit device placement instead.") 

1261 def cpu(self): 

1262 """A copy of this Tensor with contents backed by host memory.""" 

1263 return self._copy(context.context(), "CPU:0") 

1264 

1265 @deprecation.deprecated(None, "Use tf.identity instead.") 

1266 def gpu(self, gpu_index=0): 

1267 """A copy of this Tensor with contents backed by memory on the GPU. 

1268 

1269 Args: 

1270 gpu_index: Identifies which GPU to place the contents of the returned 

1271 Tensor on. 

1272 

1273 Returns: 

1274 A GPU-memory backed Tensor object initialized with the same contents 

1275 as this Tensor. 

1276 """ 

1277 return self._copy(context.context(), "GPU:" + str(gpu_index)) 

1278 

1279 def set_shape(self, shape): 

1280 if not self.shape.is_compatible_with(shape): 

1281 raise ValueError(f"Tensor's shape {self.shape} is not compatible " 

1282 f"with supplied shape {shape}.") 

1283 

1284 # Methods not supported / implemented for Eager Tensors. 

1285 @property 

1286 def op(self): 

1287 raise AttributeError( 

1288 "Tensor.op is undefined when eager execution is enabled.") 

1289 

1290 @property 

1291 def graph(self): 

1292 raise AttributeError( 

1293 "Tensor.graph is undefined when eager execution is enabled.") 

1294 

1295 @property 

1296 def name(self): 

1297 raise AttributeError( 

1298 "Tensor.name is undefined when eager execution is enabled.") 

1299 

1300 @property 

1301 def value_index(self): 

1302 raise AttributeError( 

1303 "Tensor.value_index is undefined when eager execution is enabled.") 

1304 

1305 def consumers(self): 

1306 raise NotImplementedError( 

1307 "Tensor.consumers is undefined when eager execution is enabled.") 

1308 

1309 def _add_consumer(self, consumer): 

1310 raise NotImplementedError( 

1311 "_add_consumer not supported when eager execution is enabled.") 

1312 

1313 def _as_node_def_input(self): 

1314 raise NotImplementedError( 

1315 "_as_node_def_input not supported when eager execution is enabled.") 

1316 

1317 def _as_tf_output(self): 

1318 raise NotImplementedError( 

1319 "_as_tf_output not supported when eager execution is enabled.") 

1320 

1321 def eval(self, feed_dict=None, session=None): 

1322 raise NotImplementedError( 

1323 "eval is not supported when eager execution is enabled, " 

1324 "is .numpy() what you're looking for?") 

1325 

1326 def __tf_tensor__( 

1327 self, dtype: Optional[dtypes.DType] = None, name: Optional[str] = None 

1328 ) -> Tensor: 

1329 if not context.executing_eagerly(): 

1330 graph = get_default_graph() 

1331 if not graph.building_function: 

1332 raise RuntimeError( 

1333 _add_error_prefix( 

1334 "Attempting to capture an EagerTensor without " 

1335 "building a function.", 

1336 name=name)) 

1337 return graph.capture(self, name=name) 

1338 return super().__tf_tensor__(dtype, name) 

1339 

1340 def _capture_as_const(self, name): 

1341 """Capture the EagerTensor to a graph constant tensor.""" 

1342 with control_dependencies(None): 

1343 constant_value = tensor_util.constant_value(self) 

1344 if constant_value is None: 

1345 # Some eager tensors, e.g. parallel tensors, are not convertible to 

1346 # a single constant. Return None in this case and the caller graph 

1347 # would create a placeholder instead. 

1348 return None 

1349 

1350 const_tensor = _create_graph_constant( 

1351 constant_value, dtype=self.dtype, shape=self.shape, name=name, 

1352 verify_shape=False, allow_broadcast=True) 

1353 return const_tensor 

1354 

1355 

1356# This call creates an EagerTensor class, as a subclass of _EagerTensorBase, and 

1357# registers it with the current module. 

1358# It is exposed as an __internal__ api for now (b/171081052), though we 

1359# expect it to be eventually covered by tf Tensor types and typing. 

1360EagerTensor = tf_export("__internal__.EagerTensor", v1=[])( 

1361 pywrap_tfe.TFE_Py_InitEagerTensor(_EagerTensorBase)) 

1362 

1363 

1364def _add_error_prefix(msg, *, name=None): 

1365 return msg if name is None else f"{name}: {msg}" 

1366 

1367 

1368def pack_eager_tensors(tensors, ctx=None): 

1369 """Pack multiple `EagerTensor`s of the same dtype and shape. 

1370 

1371 Args: 

1372 tensors: a list of EagerTensors to pack. 

1373 ctx: context.context(). 

1374 

1375 Returns: 

1376 A packed EagerTensor. 

1377 """ 

1378 if not isinstance(tensors, list): 

1379 raise TypeError(f"tensors must be a list, but got a {type(tensors)}") 

1380 

1381 if not tensors: 

1382 raise ValueError("Cannot pack an empty list of tensors.") 

1383 

1384 dtype = tensors[0].dtype 

1385 shape = tensors[0].shape 

1386 handle_data = tensors[0]._handle_data # pylint: disable=protected-access 

1387 is_resource = dtype == dtypes.resource 

1388 for i in range(len(tensors)): 

1389 t = tensors[i] 

1390 if not isinstance(t, EagerTensor): 

1391 raise TypeError(f"All tensors being packed must be EagerTensor. " 

1392 f"Found an item of type {type(t)}.") 

1393 

1394 if t.dtype != dtype: 

1395 raise ValueError( 

1396 f"All tensors being packed should have the same dtype {dtype}, " 

1397 f"but the {i}-th tensor is of dtype {t.dtype}") 

1398 if t.shape != shape: 

1399 raise ValueError( 

1400 f"All tensors being packed should have the same shape {shape}, " 

1401 f"but the {i}-th tensor is of shape {t.shape}") 

1402 # pylint: disable=protected-access 

1403 if is_resource and t._handle_data != handle_data: 

1404 raise ValueError( 

1405 f"All tensors being packed should have the same handle data " 

1406 f"{handle_data}, " 

1407 f"but the {i}-th tensor is of handle data {t._handle_data}") 

1408 # pylint: enable=protected-access 

1409 

1410 if ctx is None: 

1411 ctx = context.context() 

1412 

1413 # Propagate handle data for resource variables 

1414 packed_tensor = ctx.pack_eager_tensors(tensors) 

1415 if handle_data is not None: 

1416 packed_tensor._handle_data = handle_data # pylint: disable=protected-access 

1417 

1418 def grad_fun(_): 

1419 raise ValueError( 

1420 "Computing gradients through pack_eager_tensors is not supported.") 

1421 

1422 record.record_operation("pack_eager_tensors", [packed_tensor], tensors, 

1423 grad_fun) 

1424 

1425 return packed_tensor 

1426 
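# Illustrative sketch (hedged; the component tensors t0/t1 are hypothetical):
# pack a list of eager tensors of identical dtype and shape, e.g. one per
# device, into a single packed EagerTensor:
#
#   components = [t0, t1]            # EagerTensors of identical dtype and shape
#   packed = pack_eager_tensors(components)
#
# Gradients cannot flow through the packed result (see grad_fun above), and a
# dtype or shape mismatch in `components` raises ValueError.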

1427 

1428@profiler_trace.trace_wrapper("convert_to_tensor") 

1429def convert_to_tensor( 

1430 value, 

1431 dtype=None, 

1432 name=None, 

1433 as_ref=False, 

1434 preferred_dtype=None, 

1435 dtype_hint=None, 

1436 # TODO(b/268347915): Remove argument. 

1437 ctx=None, # pylint: disable=unused-argument 

1438 accepted_result_types=(Tensor,), 

1439): 

1440 """Implementation of the public convert_to_tensor.""" 

1441 # TODO(b/142518781): Fix all call-sites and remove redundant arg 

1442 preferred_dtype = preferred_dtype or dtype_hint 

1443 return tensor_conversion_registry.convert( 

1444 value, dtype, name, as_ref, preferred_dtype, accepted_result_types 

1445 ) 

1446 
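# Illustrative sketch (public API usage; not part of this file): the exported
# tf.convert_to_tensor ultimately forwards to this implementation, e.g.:
#
#   import numpy as np
#   import tensorflow as tf
#
#   t1 = tf.convert_to_tensor([[1, 2], [3, 4]])                   # int32 Tensor
#   t2 = tf.convert_to_tensor(np.zeros((2, 2)), dtype=tf.float32)
#   t3 = tf.convert_to_tensor(7, dtype_hint=tf.float64)           # soft preference
#
# dtype is a hard requirement (a mismatch raises), while dtype_hint /
# preferred_dtype is only a soft preference dropped when conversion is not
# possible.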

1447 

1448internal_convert_to_tensor = convert_to_tensor 

1449 

1450 

1451def internal_convert_n_to_tensor(values, 

1452 dtype=None, 

1453 name=None, 

1454 as_ref=False, 

1455 preferred_dtype=None, 

1456 # TODO(b/268347915): Remove argument. 

1457 ctx=None): # pylint: disable=unused-argument 

1458 """Converts `values` to a list of `Tensor` objects. 

1459 

1460 Args: 

1461 values: A list of objects that can be consumed by `tf.convert_to_tensor()`. 

1462 dtype: (Optional.) The required `DType` of the returned `Tensor` objects. 

1463 name: (Optional.) A name prefix used when a new `Tensor` is created, in 

1464 which case element `i` will be given the name `name + '_' + i`. 

1465 as_ref: True if the caller wants the results as ref tensors. 

1466 preferred_dtype: Optional element type for the returned tensors, used when 

1467 dtype is None. In some cases, a caller may not have a dtype in mind when 

1468 converting to a tensor, so preferred_dtype can be used as a soft 

1469 preference. If the conversion to `preferred_dtype` is not possible, this 

1470 argument has no effect. 

1471 ctx: Unused. Present for API backwards compatibility. 

1472 

1473 Returns: 

1474 A list of `Tensor` and/or `IndexedSlices` objects. 

1475 

1476 Raises: 

1477 TypeError: If no conversion function is registered for an element in 

1478 `values`. 

1479 RuntimeError: If a registered conversion function returns an invalid 

1480 value. 

1481 """ 

1482 if not isinstance(values, collections_abc.Sequence): 

1483 raise TypeError("values must be a sequence.") 

1484 ret = [] 

1485 for i, value in enumerate(values): 

1486 n = None if name is None else "%s_%d" % (name, i) 

1487 ret.append( 

1488 convert_to_tensor( 

1489 value, 

1490 dtype=dtype, 

1491 name=n, 

1492 as_ref=as_ref, 

1493 preferred_dtype=preferred_dtype)) 

1494 return ret 

1495 

1496 

1497def convert_n_to_tensor(values, dtype=None, name=None, preferred_dtype=None): 

1498 """Converts `values` to a list of `Tensor` objects. 

1499 

1500 Args: 

1501 values: A list of objects that can be consumed by `tf.convert_to_tensor()`. 

1502 dtype: (Optional.) The required `DType` of the returned `Tensor` objects. 

1503 name: (Optional.) A name prefix used when a new `Tensor` is created, in 

1504 which case element `i` will be given the name `name + '_' + i`. 

1505 preferred_dtype: Optional element type for the returned tensors, used when 

1506 dtype is None. In some cases, a caller may not have a dtype in mind when 

1507 converting to a tensor, so preferred_dtype can be used as a soft 

1508 preference. If the conversion to `preferred_dtype` is not possible, this 

1509 argument has no effect. 

1510 

1511 Returns: 

1512 A list of `Tensor` and/or `IndexedSlices` objects. 

1513 

1514 Raises: 

1515 TypeError: If no conversion function is registered for an element in 

1516 `values`. 

1517 RuntimeError: If a registered conversion function returns an invalid 

1518 value. 

1519 """ 

1520 return internal_convert_n_to_tensor( 

1521 values=values, 

1522 dtype=dtype, 

1523 name=name, 

1524 preferred_dtype=preferred_dtype, 

1525 as_ref=False) 

1526 

1527 

1528def convert_to_tensor_or_composite(value, dtype=None, name=None): 

1529 """Converts the given object to a `Tensor` or `CompositeTensor`. 

1530 

1531 If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it 

1532 is converted to a `Tensor` using `convert_to_tensor()`. 

1533 

1534 Args: 

1535 value: A `CompositeTensor` or an object that can be consumed by 

1536 `convert_to_tensor()`. 

1537 dtype: (Optional.) The required `DType` of the returned `Tensor` or 

1538 `CompositeTensor`. 

1539 name: (Optional.) A name to use if a new `Tensor` is created. 

1540 

1541 Returns: 

1542 A `Tensor` or `CompositeTensor`, based on `value`. 

1543 

1544 Raises: 

1545 ValueError: If `dtype` does not match the element type of `value`. 

1546 """ 

1547 return internal_convert_to_tensor_or_composite( 

1548 value=value, dtype=dtype, name=name, as_ref=False) 

1549 
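# --- Editorial usage sketch (illustrative only; not part of the original file) ---
# `convert_to_tensor_or_composite` returns CompositeTensors (e.g. SparseTensor
# or RaggedTensor) unmodified and converts everything else with
# `convert_to_tensor()`. The helper below only exercises the dense path, since
# the composite types are defined in modules that import this one.
def _example_convert_to_tensor_or_composite():
  dense = convert_to_tensor_or_composite([1.0, 2.0], dtype=dtypes.float32)
  return dense  # a float32 Tensor; a SparseTensor input would be returned as-is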

1550 

1551def internal_convert_to_tensor_or_composite(value, 

1552 dtype=None, 

1553 name=None, 

1554 as_ref=False): 

1555 """Converts the given object to a `Tensor` or `CompositeTensor`. 

1556 

1557 If `value` is a `CompositeTensor` it is returned unmodified. Otherwise, it 

1558 is converted to a `Tensor` using `convert_to_tensor()`. 

1559 

1560 Args: 

1561 value: A `CompositeTensor`, or an object that can be consumed by 

1562 `convert_to_tensor()`. 

1563 dtype: (Optional.) The required `DType` of the returned `Tensor` or 

1564 `CompositeTensor`. 

1565 name: (Optional.) A name to use if a new `Tensor` is created. 

1566 as_ref: True if the caller wants the results as ref tensors. 

1567 

1568 Returns: 

1569 A `Tensor` or `CompositeTensor`, based on `value`. 

1570 

1571 Raises: 

1572 ValueError: If `dtype` does not match the element type of `value`. 

1573 """ 

1574 if isinstance(value, composite_tensor.CompositeTensor): 

1575 value_dtype = getattr(value, "dtype", None) 

1576 if dtype and not dtypes.as_dtype(dtype).is_compatible_with(value_dtype): 

1577 raise ValueError(f"Tensor conversion dtype mismatch. " 

1578 f"Requested dtype is {dtypes.as_dtype(dtype).name}, " 

1579 f"Tensor has dtype {value.dtype.name}: {value!r}") 

1580 return value 

1581 else: 

1582 return convert_to_tensor( 

1583 value, 

1584 dtype=dtype, 

1585 name=name, 

1586 as_ref=as_ref, 

1587 accepted_result_types=(Tensor, composite_tensor.CompositeTensor)) 

1588 

1589 

1590def internal_convert_n_to_tensor_or_composite(values, 

1591 dtype=None, 

1592 name=None, 

1593 as_ref=False): 

1594 """Converts `values` to a list of `Tensor` or `CompositeTensor` objects. 

1595 

1596 Any `CompositeTensor` objects in `values` are returned unmodified. 

1597 

1598 Args: 

1599 values: A list of `None`, `CompositeTensor`, or objects that can be consumed 

1600 by `convert_to_tensor()`. 

1601 dtype: (Optional.) The required `DType` of the returned `Tensor`s or 

1602 `CompositeTensor`s. 

1603 name: (Optional.) A name prefix to use when a new `Tensor` is created, in 

1604 which case element `i` will be given the name `name + '_' + i`. 

1605 as_ref: True if the caller wants the results as ref tensors. 

1606 

1607 Returns: 

1608 A list of `Tensor`, `CompositeTensor`, and/or `None` objects. 

1609 

1610 Raises: 

1611 TypeError: If no conversion function is registered for an element in 

1612 `values`. 

1613 RuntimeError: If a registered conversion function returns an invalid 

1614 value. 

1615 """ 

1616 if not isinstance(values, collections_abc.Sequence): 

1617 raise TypeError("values must be a sequence.") 

1618 ret = [] 

1619 for i, value in enumerate(values): 

1620 if value is None: 

1621 ret.append(value) 

1622 else: 

1623 n = None if name is None else "%s_%d" % (name, i) 

1624 ret.append( 

1625 internal_convert_to_tensor_or_composite( 

1626 value, dtype=dtype, name=n, as_ref=as_ref)) 

1627 return ret 

1628 

1629 

1630def convert_n_to_tensor_or_composite(values, dtype=None, name=None): 

1631 """Converts `values` to a list of `Tensor` or `CompositeTensor` objects. 

1632 

1633 Any `CompositeTensor` objects in `values` are returned unmodified. 

1634 

1635 Args: 

1636 values: A list of `None`, `CompositeTensor`, or objects that can be 

1637 consumed by `convert_to_tensor()`. 

1638 dtype: (Optional.) The required `DType` of the returned `Tensor`s or 

1639 `CompositeTensor`s. 

1640 name: (Optional.) A name prefix to use when a new `Tensor` is created, in 

1641 which case element `i` will be given the name `name + '_' + i`. 

1642 

1643 Returns: 

1644 A list of `Tensor` and/or `CompositeTensor` objects. 

1645 

1646 Raises: 

1647 TypeError: If no conversion function is registered for an element in 

1648 `values`. 

1649 RuntimeError: If a registered conversion function returns an invalid 

1650 value. 

1651 """ 

1652 return internal_convert_n_to_tensor_or_composite( 

1653 values=values, dtype=dtype, name=name, as_ref=False) 

1654 

1655 

1656def _device_string(dev_spec): 

1657 if pydev.is_device_spec(dev_spec): 

1658 return dev_spec.to_string() 

1659 else: 

1660 return dev_spec 

1661 
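# --- Editorial usage sketch (illustrative only; not part of the original file) ---
# `_device_string` normalizes either a DeviceSpec or a raw string to a device
# string; the spec value below is hypothetical.
def _example_device_string():
  spec = pydev.DeviceSpec.from_string("/job:worker/device:CPU:0")
  rendered = _device_string(spec)  # rendered via DeviceSpec.to_string()
  passthrough = _device_string("/device:GPU:1")  # plain strings pass through
  return rendered, passthrough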

1662 

1663def _NodeDef(op_type, name, attrs=None): 

1664 """Create a NodeDef proto. 

1665 

1666 Args: 

1667 op_type: Value for the "op" attribute of the NodeDef proto. 

1668 name: Value for the "name" attribute of the NodeDef proto. 

1669 attrs: Dictionary where the key is the attribute name (a string) 

1670 and the value is the respective "attr" attribute of the NodeDef proto (an 

1671 AttrValue). 

1672 

1673 Returns: 

1674 A node_def_pb2.NodeDef protocol buffer. 

1675 """ 

1676 node_def = node_def_pb2.NodeDef(op=compat.as_bytes(op_type), 

1677 name=compat.as_bytes(name)) 

1678 if attrs: 

1679 for k, v in attrs.items(): 

1680 node_def.attr[k].CopyFrom(v) 

1681 return node_def 

1682 
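# --- Editorial usage sketch (illustrative only; not part of the original file) ---
# `_NodeDef` only fills in the proto fields it is given; the op type, node name
# and attr name below are hypothetical.
def _example_node_def():
  attr = attr_value_pb2.AttrValue(i=7)
  node = _NodeDef("ExampleOp", "example_node", attrs={"example_attr": attr})
  return node  # a node_def_pb2.NodeDef with op, name and a single attr set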

1683 

1684# Copied from core/framework/node_def_util.cc 

1685# TODO(mrry,josh11b): Consolidate this validation in C++ code. 

1686_VALID_OP_NAME_REGEX = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.\\/>-]*$") 

1687_VALID_SCOPE_NAME_REGEX = re.compile(r"^[A-Za-z0-9_.\\/>-]*$") 

1688 
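# --- Editorial sketch (illustrative only; not part of the original file) ---
# Op names must start with a letter, digit or dot, while scope names may be
# empty; both otherwise allow letters, digits and the characters _ . \ / > -
def _example_name_validation():
  assert _VALID_OP_NAME_REGEX.match("layer_1/MatMul")
  assert not _VALID_OP_NAME_REGEX.match("_starts_with_underscore")
  assert _VALID_SCOPE_NAME_REGEX.match("")  # the empty scope name is valid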

1689 

1690@tf_export("__internal__.create_c_op", v1=[]) 

1691@traceback_utils.filter_traceback 

1692def _create_c_op(graph, 

1693 node_def, 

1694 inputs, 

1695 control_inputs, 

1696 op_def=None, 

1697 extract_traceback=True): 

1698 """Creates a TF_Operation. 

1699 

1700 Args: 

1701 graph: a `Graph`. 

1702 node_def: `node_def_pb2.NodeDef` for the operation to create. 

1703 inputs: A flattened list of `Tensor`s. This function handles grouping 

1704 tensors into lists as per attributes in the `node_def`. 

1705 control_inputs: A list of `Operation`s to set as control dependencies. 

1706 op_def: Optional. `op_def_pb2.OpDef` for the operation to create. If not 

1707 specified, it is looked up from the `graph` using `node_def.op`. 

1708 extract_traceback: if True, extract the current Python traceback to the 

1709 TF_Operation. 

1710 

1711 Returns: 

1712 A wrapped TF_Operation*. 

1713 """ 

1714 if op_def is None: 

1715 op_def = graph.op_def_for_type(node_def.op) # pylint: disable=protected-access 

1716 # TODO(skyewm): op_def_library.apply_op() flattens the incoming inputs. 

1717 # Refactor so we don't have to do this here. 

1718 inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.attr) 

1719 # pylint: disable=protected-access 

1720 with graph._c_graph.get() as c_graph: 

1721 op_desc = pywrap_tf_session.TF_NewOperation(c_graph, 

1722 compat.as_str(node_def.op), 

1723 compat.as_str(node_def.name)) 

1724 if node_def.device: 

1725 pywrap_tf_session.TF_SetDevice(op_desc, compat.as_str(node_def.device)) 

1726 # Add inputs 

1727 for op_input in inputs: 

1728 if isinstance(op_input, (list, tuple)): 

1729 pywrap_tf_session.TF_AddInputList(op_desc, 

1730 [t._as_tf_output() for t in op_input]) 

1731 else: 

1732 pywrap_tf_session.TF_AddInput(op_desc, op_input._as_tf_output()) 

1733 

1734 # Add control inputs 

1735 for control_input in control_inputs: 

1736 pywrap_tf_session.TF_AddControlInput(op_desc, control_input._c_op) 

1737 # pylint: enable=protected-access 

1738 

1739 # Add attrs 

1740 for name, attr_value in node_def.attr.items(): 

1741 serialized = attr_value.SerializeToString() 

1742 # TODO(skyewm): this creates and deletes a new TF_Status for every attr. 

1743 # It might be worth creating a convenient way to re-use the same status. 

1744 pywrap_tf_session.TF_SetAttrValueProto(op_desc, compat.as_str(name), 

1745 serialized) 

1746 

1747 try: 

1748 c_op = pywrap_tf_session.TF_FinishOperation(op_desc) 

1749 except errors.InvalidArgumentError as e: 

1750 # Convert to ValueError for backwards compatibility. 

1751 raise ValueError(e.message) 

1752 

1753 # Record the current Python stack trace as the creating stacktrace of this 

1754 # TF_Operation. 

1755 if extract_traceback: 

1756 tf_stack.extract_stack_for_op(c_op, stacklevel=3) 

1757 

1758 return c_op 

1759 

1760 

1761@tf_export("Operation") 

1762class Operation(pywrap_tf_session.PyOperation): 

1763 """Represents a graph node that performs computation on tensors. 

1764 

1765 An `Operation` is a node in a `tf.Graph` that takes zero or more `Tensor` 

1766 objects as input, and produces zero or more `Tensor` objects as output. 

1767 Objects of type `Operation` are created by calling a Python op constructor 

1768 (such as `tf.matmul`) within a `tf.function` or under a `tf.Graph.as_default` 

1769 context manager. 

1770 

1771 For example, within a `tf.function`, `c = tf.matmul(a, b)` creates an 

1772 `Operation` of type "MatMul" that takes tensors `a` and `b` as input, and 

1773 produces `c` as output. 

1774 

1775 If a `tf.compat.v1.Session` is used, an `Operation` of a `tf.Graph` can be 

1776 executed by passing it to `tf.Session.run`. `op.run()` is a shortcut for 

1777 calling `tf.compat.v1.get_default_session().run(op)`. 

1778 """ 

1779 

1780 @classmethod 

1781 def from_node_def( 

1782 cls, 

1783 node_def, 

1784 g, 

1785 inputs=None, 

1786 output_types=None, 

1787 control_inputs=None, 

1788 input_types=None, 

1789 original_op=None, 

1790 op_def=None, 

1791 ): 

1792 r"""Creates an `Operation`. 

1793 

1794 NOTE: This constructor validates the name of the `Operation` (passed 

1795 as `node_def.name`). Valid `Operation` names match the following 

1796 regular expression: 

1797 

1798 [A-Za-z0-9.][A-Za-z0-9_.\\-/]* 

1799 

1800 Args: 

1801 node_def: `node_def_pb2.NodeDef`. `NodeDef` for the `Operation`. Used for 

1802 attributes of `node_def_pb2.NodeDef`, typically `name`, `op`, and 

1803 `device`. The `input` attribute is irrelevant here as it will be 

1804 computed when generating the model. 

1805 g: `Graph`. The parent graph. 

1806 inputs: list of `Tensor` objects. The inputs to this `Operation`. 

1807 output_types: list of `DType` objects. List of the types of the `Tensors` 

1808 computed by this operation. The length of this list indicates the 

1809 number of output endpoints of the `Operation`. 

1810 control_inputs: list of operations or tensors from which to have a control 

1811 dependency. 

1812 input_types: List of `DType` objects representing the types of the tensors 

1813 accepted by the `Operation`. By default uses `[x.dtype.base_dtype for x 

1814 in inputs]`. Operations that expect reference-typed inputs must specify 

1815 these explicitly. 

1816 original_op: Optional. Used to associate the new `Operation` with an 

1817 existing `Operation` (for example, a replica with the op that was 

1818 replicated). 

1819 op_def: Optional. The `op_def_pb2.OpDef` proto that describes the op type 

1820 that this `Operation` represents. 

1821 

1822 Raises: 

1823 TypeError: if control inputs are not Operations or Tensors, 

1824 or if `node_def` is not a `NodeDef`, 

1825 or if `g` is not a `Graph`, 

1826 or if `inputs` are not tensors, 

1827 or if `inputs` and `input_types` are incompatible. 

1828 ValueError: if the `node_def` name is not valid. 

1829 

1830 Returns: 

1831 Operation object. 

1832 """ 

1833 if not isinstance(g, Graph): 

1834 raise TypeError(f"Argument g must be a Graph. " 

1835 f"Received an instance of type {type(g)}") 

1836 

1837 if not isinstance(node_def, node_def_pb2.NodeDef): 

1838 raise TypeError(f"Argument node_def must be a NodeDef. " 

1839 f"Received an instance of type: {type(node_def)}.") 

1840 if node_def.ByteSize() >= (1 << 31) or node_def.ByteSize() < 0: 

1841 raise ValueError( 

1842 f"Cannot create a tensor proto whose content is larger than 2GB. " 

1843 f"Size of tensor is {node_def.ByteSize()} bytes.") 

1844 

1845 # TODO(mdan): This does not belong here. Graph::AddNode should handle it. 

1846 if not _VALID_OP_NAME_REGEX.match(node_def.name): 

1847 raise ValueError( 

1848 f"`{node_def.name}` is not a valid node name. " 

1849 f"Accepted names conform to Regex /{_VALID_OP_NAME_REGEX}/") 

1850 

1851 # FIXME(b/225400189): output_types is unused. Consider remove it from 

1852 # the argument list. 

1853 del output_types 

1854 

1855 if inputs is None: 

1856 inputs = [] 

1857 elif not isinstance(inputs, list): 

1858 raise TypeError(f"Argument inputs shall be a list of Tensors. " 

1859 f"Received an instance of type {type(inputs)}") 

1860 for a in inputs: 

1861 if not isinstance(a, Tensor): 

1862 raise TypeError(f"Items of argument inputs shall be Tensor. " 

1863 f"Received an instance of type {type(a)}.") 

1864 if input_types is None: 

1865 input_types = [i.dtype.base_dtype for i in inputs] 

1866 else: 

1867 if not all( 

1868 x.is_compatible_with(i.dtype) for i, x in zip(inputs, input_types)): 

1869 raise TypeError("In op '%s', input types (%s) are not compatible " 

1870 "with expected types (%s)" % 

1871 (node_def.name, [i.dtype for i in inputs], input_types)) 

1872 

1873 # Build the list of control inputs. 

1874 control_input_ops = [] 

1875 if control_inputs: 

1876 for c in control_inputs: 

1877 control_op = None 

1878 if isinstance(c, Operation): 

1879 control_op = c 

1880 elif isinstance(c, (Tensor, internal.IndexedSlices)): 

1881 control_op = c.op 

1882 else: 

1883 raise TypeError(f"Control input must be an Operation, " 

1884 f"a Tensor, or IndexedSlices. " 

1885 f"Received an instance of type {type(c)}.") 

1886 control_input_ops.append(control_op) 

1887 

1888 # Initialize c_op from node_def and other inputs 

1889 c_op = _create_c_op(g, node_def, inputs, control_input_ops, op_def=op_def) 

1890 self = Operation(c_op, GraphTensor) 

1891 self._init(g) 

1892 

1893 self._original_op = original_op 

1894 

1895 # Post process for control flows. 

1896 self._control_flow_post_processing(input_tensors=inputs) 

1897 

1898 # Removes this frame from the Python traceback. 

1899 # We adjust stacklevel directly to avoid triggering serialization. 

1900 if self.traceback is not None: 

1901 self.traceback._stacklevel += 1 # pylint: disable=protected-access 

1902 

1903 return self 

1904 

1905 @classmethod 

1906 def _from_c_op(cls, c_op, g): 

1907 """Create an Operation from a TF_Operation. 

1908 

1909 For internal use only: This is useful for creating Operation for ops 

1910 indirectly created by C API methods, e.g. the ops created by 

1911 TF_ImportGraphDef. 

1912 

1913 Args: 

1914 c_op: a TF_Operation. 

1915 g: A Graph. 

1916 

1917 Returns: 

1918 an Operation object. 

1919 """ 

1920 self = Operation(c_op, GraphTensor) 

1921 self._init(g) 

1922 return self 

1923 

1924 def _init(self, graph): 

1925 """Initializes Operation from a TF_Operation.""" 

1926 self.graph = graph 

1927 self._original_op = None 

1928 

1929 # This will be set by self.inputs. 

1930 self._inputs_val = None 

1931 

1932 # List of _UserDevSpecs holding code location of device context manager 

1933 # invocations and the users original argument to them. 

1934 self._device_code_locations = None 

1935 # Dict mapping op name to file and line information for op colocation 

1936 # context managers. 

1937 self._colocation_code_locations = None 

1938 self._control_flow_context = self.graph._get_control_flow_context() # pylint: disable=protected-access 

1939 

1940 # Gradient function for this op. There are three ways to specify gradient 

1941 # function, and first available gradient gets used, in the following order. 

1942 # 1. self._gradient_function 

1943 # 2. Gradient name registered by "_gradient_op_type" attribute. 

1944 # 3. Gradient name registered by op.type. 

1945 self._gradient_function = None 

1946 

1947 self._init_outputs() 

1948 self._id_value = self.graph._add_op(self) # pylint: disable=protected-access 

1949 

1950 def _control_flow_post_processing(self, input_tensors=None): 

1951 """Add this op to its control flow context. 

1952 

1953 This may add new ops and change this op's inputs. self.inputs must be 

1954 available before calling this method. 

1955 

1956 Args: 

1957 input_tensors: (Optional.) A list of `Tensors` corresponding to the inputs 

1958 of this op, which should be equivalent to `self.inputs`. Pass this 

1959 argument to avoid evaluating `self.inputs` unnecessarily. 

1960 """ 

1961 if input_tensors is None: 

1962 input_tensors = self.inputs 

1963 for input_tensor in input_tensors: 

1964 control_flow_util.CheckInputFromValidContext(self, input_tensor.op) 

1965 if self._control_flow_context is not None: 

1966 self._control_flow_context.AddOp(self) 

1967 

1968 def colocation_groups(self): 

1969 """Returns the list of colocation groups of the op.""" 

1970 default_colocation_group = [compat.as_bytes("loc:@%s" % self.name)] 

1971 try: 

1972 class_attr = self.get_attr("_class") 

1973 except ValueError: 

1974 # This op has no explicit colocation group, so it is itself its 

1975 # own root of a colocation group. 

1976 return default_colocation_group 

1977 

1978 attr_groups = [ 

1979 class_name for class_name in class_attr 

1980 if class_name.startswith(b"loc:@") 

1981 ] 

1982 

1983 # If there are no colocation groups in the explicit _class field, 

1984 # return the default colocation group. 

1985 return attr_groups if attr_groups else default_colocation_group 

1986 

1987 def values(self): 

1988 """DEPRECATED: Use outputs.""" 

1989 return tuple(self.outputs) 

1990 

1991 def _get_control_flow_context(self): 

1992 """Returns the control flow context of this op. 

1993 

1994 Returns: 

1995 A context object. 

1996 """ 

1997 return self._control_flow_context 

1998 

1999 def _set_control_flow_context(self, ctx): 

2000 """Sets the current control flow context of this op. 

2001 

2002 Args: 

2003 ctx: a context object. 

2004 """ 

2005 self._control_flow_context = ctx 

2006 

2007 @property 

2008 def _id(self): 

2009 """The unique integer id of this operation.""" 

2010 return self._id_value 

2011 

2012 @property 

2013 def device(self): 

2014 """The name of the device to which this op has been assigned, if any. 

2015 

2016 Returns: 

2017 The string name of the device to which this op has been 

2018 assigned, or an empty string if it has not been assigned to a 

2019 device. 

2020 """ 

2021 return pywrap_tf_session.TF_OperationDevice(self._c_op) 

2022 

2023 @property 

2024 def _device_assignments(self): 

2025 """Code locations for device context managers active at op creation. 

2026 

2027 This property will return a list of traceable_stack.TraceableObject 

2028 instances where .obj is a string representing the assigned device 

2029 (or information about the function that would be applied to this op 

2030 to compute the desired device) and the filename and lineno members 

2031 record the location of the relevant device context manager. 

2032 

2033 For example, suppose file_a contained these lines: 

2034 

2035 file_a.py: 

2036 15: with tf.device('/gpu:0'): 

2037 16: node_b = tf.constant(4, name='NODE_B') 

2038 

2039 Then a TraceableObject t_obj representing the device context manager 

2040 would have these member values: 

2041 

2042 t_obj.obj -> '/gpu:0' 

2043 t_obj.filename = 'file_a.py' 

2044 t_obj.lineno = 15 

2045 

2046 and node_b.op._device_assignments would return the list [t_obj]. 

2047 

2048 Returns: 

2049 [str: traceable_stack.TraceableObject, ...] as per this method's 

2050 description, above. 

2051 """ 

2052 return self._device_code_locations or [] 

2053 

2054 @property 

2055 def _colocation_dict(self): 

2056 """Code locations for colocation context managers active at op creation. 

2057 

2058 This property will return a dictionary for which the keys are nodes with 

2059 which this Operation is colocated, and for which the values are 

2060 traceable_stack.TraceableObject instances. The TraceableObject instances 

2061 record the location of the relevant colocation context manager but have the 

2062 "obj" field set to None to prevent leaking private data. 

2063 

2064 For example, suppose file_a contained these lines: 

2065 

2066 file_a.py: 

2067 14: node_a = tf.constant(3, name='NODE_A') 

2068 15: with tf.compat.v1.colocate_with(node_a): 

2069 16: node_b = tf.constant(4, name='NODE_B') 

2070 

2071 Then a TraceableObject t_obj representing the colocation context manager 

2072 would have these member values: 

2073 

2074 t_obj.obj -> None 

2075 t_obj.filename = 'file_a.py' 

2076 t_obj.lineno = 15 

2077 

2078 and node_b.op._colocation_dict would return the dictionary 

2079 

2080 { 'NODE_A': t_obj } 

2081 

2082 Returns: 

2083 {str: traceable_stack.TraceableObject} as per this method's description, 

2084 above. 

2085 """ 

2086 locations_dict = self._colocation_code_locations or {} 

2087 return locations_dict.copy() 

2088 

2089 @property 

2090 def _output_types(self): 

2091 """List this operation's output types. 

2092 

2093 Returns: 

2094 List of the types of the Tensors computed by this operation. 

2095 Each element in the list is an integer whose value is one of 

2096 the TF_DataType enums defined in pywrap_tf_session.h 

2097 The length of this list indicates the number of output endpoints 

2098 of the operation. 

2099 """ 

2100 num_outputs = pywrap_tf_session.TF_OperationNumOutputs(self._c_op) 

2101 output_types = [ 

2102 int(pywrap_tf_session.TF_OperationOutputType(self._tf_output(i))) 

2103 for i in range(num_outputs) 

2104 ] 

2105 

2106 return output_types 

2107 

2108 def _set_device(self, device): # pylint: disable=redefined-outer-name 

2109 """Set the device of this operation. 

2110 

2111 Args: 

2112 device: string or device. The device to set. 

2113 """ 

2114 self._set_device_from_string(compat.as_str(_device_string(device))) 

2115 

2116 def _update_input(self, index, tensor): 

2117 """Update the input to this operation at the given index. 

2118 

2119 NOTE: This is for TF internal use only. Please don't use it. 

2120 

2121 Args: 

2122 index: the index of the input to update. 

2123 tensor: the Tensor to be used as the input at the given index. 

2124 

2125 Raises: 

2126 TypeError: if tensor is not a Tensor, 

2127 or if input tensor type is not convertible to dtype. 

2128 ValueError: if the Tensor is from a different graph. 

2129 """ 

2130 if not isinstance(tensor, Tensor): 

2131 raise TypeError("tensor must be a Tensor: %s" % tensor) 

2132 

2133 _assert_same_graph(self, tensor) 

2134 

2135 # Reset cached inputs. 

2136 self._inputs_val = None 

2137 with self.graph._c_graph.get() as c_graph: # pylint: disable=protected-access 

2138 pywrap_tf_session.UpdateEdge( 

2139 c_graph, 

2140 tensor._as_tf_output(), # pylint: disable=protected-access 

2141 self._tf_input(index)) 

2142 

2143 def _add_while_inputs(self, tensors): 

2144 """See AddWhileInputHack in python_api.h. 

2145 

2146 NOTE: This is for TF internal use only. Please don't use it. 

2147 

2148 Args: 

2149 tensors: list of Tensors 

2150 

2151 Raises: 

2152 TypeError: if tensor is not a Tensor, 

2153 or if input tensor type is not convertible to dtype. 

2154 ValueError: if the Tensor is from a different graph. 

2155 """ 

2156 with self.graph._c_graph.get() as c_graph: # pylint: disable=protected-access 

2157 for tensor in tensors: 

2158 if not isinstance(tensor, Tensor): 

2159 raise TypeError("tensor must be a Tensor: %s" % tensor) 

2160 _assert_same_graph(self, tensor) 

2161 

2162 # Reset cached inputs. 

2163 self._inputs_val = None 

2164 pywrap_tf_session.AddWhileInputHack( 

2165 c_graph, # pylint: disable=protected-access 

2166 tensor._as_tf_output(), # pylint: disable=protected-access 

2167 self._c_op) 

2168 

2169 def __str__(self): 

2170 return str(self.node_def) 

2171 

2172 def __repr__(self): 

2173 return "<tf.Operation '%s' type=%s>" % (self.name, self.type) 

2174 

2175 def __tf_tensor__(self, dtype=None, name=None): 

2176 """Raises a helpful error.""" 

2177 raise TypeError("can't convert Operation '{}' to Tensor".format(self.name)) 

2178 

2179 @property 

2180 def inputs(self): 

2181 """The sequence of `Tensor` objects representing the data inputs of this op.""" 

2182 if self._inputs_val is None: 

2183 # pylint: disable=protected-access 

2184 self._inputs_val = tuple( 

2185 self.graph._get_tensor_by_tf_output(i) 

2186 for i in pywrap_tf_session.GetOperationInputs(self._c_op)) 

2187 # pylint: enable=protected-access 

2188 return self._inputs_val 

2189 

2190 @property 

2191 def _input_types(self): 

2192 num_inputs = pywrap_tf_session.TF_OperationNumInputs(self._c_op) 

2193 input_types = [ 

2194 dtypes.as_dtype( 

2195 pywrap_tf_session.TF_OperationInputType(self._tf_input(i))) 

2196 for i in range(num_inputs) 

2197 ] 

2198 return input_types 

2199 

2200 @property 

2201 def traceback(self): 

2202 """Returns the call stack from when this operation was constructed.""" 

2203 # FIXME(b/225423591): This object contains a dangling reference if _c_op 

2204 # goes out of scope. 

2205 return pywrap_tf_session.TF_OperationGetStackTrace(self._c_op) 

2206 

2207 @property 

2208 def node_def(self): 

2209 return node_def_pb2.NodeDef.FromString(self._node_def) 

2210 

2211 @property 

2212 def op_def(self): 

2213 return op_def_pb2.OpDef.FromString(self._op_def) 

2214 

2215 def _set_attr(self, attr_name, attr_value): 

2216 """Private method used to set an attribute in the node_def.""" 

2217 buf = pywrap_tf_session.TF_NewBufferFromString( 

2218 compat.as_bytes(attr_value.SerializeToString())) 

2219 try: 

2220 self._set_attr_with_buf(attr_name, buf) 

2221 finally: 

2222 pywrap_tf_session.TF_DeleteBuffer(buf) 

2223 

2224 def _set_attr_with_buf(self, attr_name, attr_buf): 

2225 """Set an attr in the node_def with a pre-allocated buffer.""" 

2226 with self.graph._c_graph.get() as c_graph: # pylint: disable=protected-access 

2227 # pylint: disable=protected-access 

2228 pywrap_tf_session.SetAttr(c_graph, self._c_op, attr_name, attr_buf) 

2229 # pylint: enable=protected-access 

2230 

2231 def _set_func_attr(self, attr_name, func_name): 

2232 """Private method used to set a function attribute in the node_def.""" 

2233 func = attr_value_pb2.NameAttrList(name=func_name) 

2234 self._set_attr(attr_name, attr_value_pb2.AttrValue(func=func)) 

2235 

2236 def _set_func_list_attr(self, attr_name, func_names): 

2237 """Private method used to set a list(function) attribute in the node_def.""" 

2238 funcs = [attr_value_pb2.NameAttrList(name=func_name) 

2239 for func_name in func_names] 

2240 funcs_list = attr_value_pb2.AttrValue.ListValue(func=funcs) 

2241 self._set_attr(attr_name, attr_value_pb2.AttrValue(list=funcs_list)) 

2242 

2243 def _set_type_list_attr(self, attr_name, types): 

2244 """Private method used to set a list(type) attribute in the node_def.""" 

2245 if not types: 

2246 return 

2247 if isinstance(types[0], dtypes.DType): 

2248 types = [dt.as_datatype_enum for dt in types] 

2249 types_list = attr_value_pb2.AttrValue.ListValue(type=types) 

2250 self._set_attr(attr_name, attr_value_pb2.AttrValue(list=types_list)) 

2251 

2252 def _set_shape_list_attr(self, attr_name, shapes): 

2253 """Private method used to set a list(shape) attribute in the node_def.""" 

2254 shapes = [s.as_proto() for s in shapes] 

2255 shapes_list = attr_value_pb2.AttrValue.ListValue(shape=shapes) 

2256 self._set_attr(attr_name, attr_value_pb2.AttrValue(list=shapes_list)) 

2257 

2258 def _clear_attr(self, attr_name): 

2259 """Private method used to clear an attribute in the node_def.""" 

2260 with self.graph._c_graph.get() as c_graph: # pylint: disable=protected-access 

2261 # pylint: disable=protected-access 

2262 pywrap_tf_session.ClearAttr(c_graph, self._c_op, attr_name) 

2263 # pylint: enable=protected-access 

2264 

2265 def get_attr(self, name): 

2266 """Returns the value of the attr of this op with the given `name`. 

2267 

2268 Args: 

2269 name: The name of the attr to fetch. 

2270 

2271 Returns: 

2272 The value of the attr, as a Python object. 

2273 

2274 Raises: 

2275 ValueError: If this op does not have an attr with the given `name`. 

2276 """ 

2277 fields = ("s", "i", "f", "b", "type", "shape", "tensor", "func") 

2278 try: 

2279 with c_api_util.tf_buffer() as buf: 

2280 pywrap_tf_session.TF_OperationGetAttrValueProto(self._c_op, name, buf) 

2281 data = pywrap_tf_session.TF_GetBuffer(buf) 

2282 except errors.InvalidArgumentError as e: 

2283 # Convert to ValueError for backwards compatibility. 

2284 raise ValueError(e.message) 

2285 x = attr_value_pb2.AttrValue() 

2286 x.ParseFromString(data) 

2287 

2288 oneof_value = x.WhichOneof("value") 

2289 if oneof_value is None: 

2290 return [] 

2291 if oneof_value == "list": 

2292 for f in fields: 

2293 if getattr(x.list, f): 

2294 if f == "type": 

2295 return [dtypes.as_dtype(t) for t in x.list.type] 

2296 else: 

2297 return list(getattr(x.list, f)) 

2298 return [] 

2299 if oneof_value == "type": 

2300 return dtypes.as_dtype(x.type) 

2301 assert oneof_value in fields, "Unsupported field type in " + str(x) 

2302 return getattr(x, oneof_value) 

2303 

2304 def _get_attr_type(self, name): 

2305 """Returns the `DType` value of the attr of this op with the given `name`.""" 

2306 try: 

2307 dtype_enum = pywrap_tf_session.TF_OperationGetAttrType(self._c_op, name) 

2308 return _DTYPES_INTERN_TABLE[dtype_enum] 

2309 except errors.InvalidArgumentError as e: 

2310 # Convert to ValueError for backwards compatibility. 

2311 raise ValueError(e.message) 

2312 

2313 def _get_attr_bool(self, name): 

2314 """Returns the `bool` value of the attr of this op with the given `name`.""" 

2315 try: 

2316 return pywrap_tf_session.TF_OperationGetAttrBool(self._c_op, name) 

2317 except errors.InvalidArgumentError as e: 

2318 # Convert to ValueError for backwards compatibility. 

2319 raise ValueError(e.message) 

2320 

2321 def _get_attr_int(self, name): 

2322 """Returns the `int` value of the attr of this op with the given `name`.""" 

2323 try: 

2324 return pywrap_tf_session.TF_OperationGetAttrInt(self._c_op, name) 

2325 except errors.InvalidArgumentError as e: 

2326 # Convert to ValueError for backwards compatibility. 

2327 raise ValueError(e.message) 

2328 

2329 def experimental_set_type(self, type_proto): 

2330 """Sets the corresponding node's `experimental_type` field. 

2331 

2332 See the description of `NodeDef.experimental_type` for more info. 

2333 

2334 Args: 

2335 type_proto: A FullTypeDef proto message. The root type_id of this object 

2336 must be `TFT_PRODUCT`, even for ops which only have a single return 

2337 value. 

2338 """ 

2339 with self.graph._c_graph.get() as c_graph: # pylint: disable=protected-access 

2340 if (type_proto.type_id 

2341 not in (full_type_pb2.TFT_UNSET, full_type_pb2.TFT_PRODUCT)): 

2342 raise ValueError(f"Error setting the type of {self.name}: " 

2343 "expected TFT_UNSET or TFT_PRODUCT, got " 

2344 f"{type_proto.type_id}.") 

2345 pywrap_tf_session.SetFullType(c_graph, self._c_op, 

2346 type_proto.SerializeToString()) # pylint:disable=protected-access 

2347 

2348 def run(self, feed_dict=None, session=None): 

2349 """Runs this operation in a `Session`. 

2350 

2351 Calling this method will execute all preceding operations that 

2352 produce the inputs needed for this operation. 

2353 

2354 *N.B.* Before invoking `Operation.run()`, its graph must have been 

2355 launched in a session, and either a default session must be 

2356 available, or `session` must be specified explicitly. 

2357 

2358 Args: 

2359 feed_dict: A dictionary that maps `Tensor` objects to feed values. See 

2360 `tf.Session.run` for a description of the valid feed values. 

2361 session: (Optional.) The `Session` to be used to run to this operation. If 

2362 none, the default session will be used. 

2363 """ 

2364 _run_using_default_session(self, feed_dict, self.graph, session) 

2365 

2366 

2367# TODO(b/185395742): Clean up usages of _gradient_registry 

2368gradient_registry = _gradient_registry = registry.Registry("gradient") 

2369 

2370 

2371@tf_export("RegisterGradient") 

2372class RegisterGradient(object): 

2373 """A decorator for registering the gradient function for an op type. 

2374 

2375 This decorator is only used when defining a new op type. For an op 

2376 with `m` inputs and `n` outputs, the gradient function is a function 

2377 that takes the original `Operation` and `n` `Tensor` objects 

2378 (representing the gradients with respect to each output of the op), 

2379 and returns `m` `Tensor` objects (representing the partial gradients 

2380 with respect to each input of the op). 

2381 

2382 For example, assuming that operations of type `"Sub"` take two 

2383 inputs `x` and `y`, and return a single output `x - y`, the 

2384 following gradient function would be registered: 

2385 

2386 ```python 

2387 @tf.RegisterGradient("Sub") 

2388 def _sub_grad(unused_op, grad): 

2389 return grad, tf.negative(grad) 

2390 ``` 

2391 

2392 The decorator argument `op_type` is the string type of an 

2393 operation. This corresponds to the `OpDef.name` field for the proto 

2394 that defines the operation. 

2395 """ 

2396 

2397 __slots__ = ["_op_type"] 

2398 

2399 def __init__(self, op_type): 

2400 """Creates a new decorator with `op_type` as the Operation type. 

2401 

2402 Args: 

2403 op_type: The string type of an operation. This corresponds to the 

2404 `OpDef.name` field for the proto that defines the operation. 

2405 

2406 Raises: 

2407 TypeError: If `op_type` is not string. 

2408 """ 

2409 if not isinstance(op_type, str): 

2410 raise TypeError("op_type must be a string") 

2411 self._op_type = op_type 

2412 

2413 def __call__(self, f): 

2414 """Registers the function `f` as gradient function for `op_type`.""" 

2415 gradient_registry.register(f, self._op_type) 

2416 return f 

2417 
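# --- Editorial usage sketch (illustrative only; not part of the original file) ---
# Pairing RegisterGradient with Graph.gradient_override_map: the gradient is
# registered under a new name, and the graph is told to use that name for
# "Identity" ops created inside the scope. All names here are hypothetical.
def _example_gradient_override():
  @RegisterGradient("ExampleIdentityGrad")
  def _example_identity_grad(unused_op, grad):  # pylint: disable=unused-variable
    return grad
  g = get_default_graph()
  with g.gradient_override_map({"Identity": "ExampleIdentityGrad"}):
    pass  # ops created here look up "ExampleIdentityGrad" for their gradient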

2418 

2419@deprecation.deprecated_endpoints("NotDifferentiable", "NoGradient") 

2420@tf_export("no_gradient", v1=["no_gradient", "NotDifferentiable", "NoGradient"]) 

2421def no_gradient(op_type): 

2422 """Specifies that ops of type `op_type` are not differentiable. 

2423 

2424 This function should *not* be used for operations that have a 

2425 well-defined gradient that is not yet implemented. 

2426 

2427 This function is only used when defining a new op type. It may be 

2428 used for ops such as `tf.size()` that are not differentiable. For 

2429 example: 

2430 

2431 ```python 

2432 tf.no_gradient("Size") 

2433 ``` 

2434 

2435 The gradient computed for 'op_type' will then propagate zeros. 

2436 

2437 For ops that have a well-defined gradient but are not yet implemented, 

2438 no declaration should be made, and an error *must* be thrown if 

2439 an attempt to request its gradient is made. 

2440 

2441 Args: 

2442 op_type: The string type of an operation. This corresponds to the 

2443 `OpDef.name` field for the proto that defines the operation. 

2444 

2445 Raises: 

2446 TypeError: If `op_type` is not a string. 

2447 

2448 """ 

2449 if not isinstance(op_type, str): 

2450 raise TypeError("op_type must be a string") 

2451 gradient_registry.register(None, op_type) 

2452 

2453 

2454# Aliases for the old names, will be eventually removed. 

2455NoGradient = no_gradient 

2456NotDifferentiable = no_gradient 

2457 

2458 

2459def get_gradient_function(op): 

2460 """Returns the function that computes gradients for "op".""" 

2461 if not op.inputs: 

2462 return None 

2463 

2464 gradient_function = op._gradient_function # pylint: disable=protected-access 

2465 if gradient_function: 

2466 return gradient_function 

2467 

2468 try: 

2469 op_type = op.get_attr("_gradient_op_type") 

2470 except ValueError: 

2471 op_type = op.type 

2472 return gradient_registry.lookup(op_type) 

2473 

2474 

2475def set_shape_and_handle_data_for_outputs(_): 

2476 """No op. TODO(b/74620627): Remove this.""" 

2477 pass 

2478 

2479 

2480class OpStats(object): 

2481 """A holder for statistics about an operator. 

2482 

2483 This class holds information about the resource requirements for an op, 

2484 including the size of its weight parameters on-disk and how many FLOPS it 

2485 requires to execute forward inference. 

2486 

2487 If you define a new operation, you can create a function that will return a 

2488 set of information about its usage of the CPU and disk space when serialized. 

2489 The function itself takes a Graph object that's been set up so you can call 

2490 methods like get_tensor_by_name to help calculate the results, and a NodeDef 

2491 argument. 

2492 

2493 """ 

2494 

2495 __slots__ = ["_statistic_type", "_value"] 

2496 

2497 def __init__(self, statistic_type, value=None): 

2498 """Sets up the initial placeholders for the statistics.""" 

2499 self.statistic_type = statistic_type 

2500 self.value = value 

2501 

2502 @property 

2503 def statistic_type(self): 

2504 return self._statistic_type 

2505 

2506 @statistic_type.setter 

2507 def statistic_type(self, statistic_type): 

2508 self._statistic_type = statistic_type 

2509 

2510 @property 

2511 def value(self): 

2512 return self._value 

2513 

2514 @value.setter 

2515 def value(self, value): 

2516 self._value = value 

2517 

2518 def __iadd__(self, other): 

2519 if other.statistic_type != self.statistic_type: 

2520 raise ValueError("Can't add an OpStat of type %s to one of %s." % 

2521 (self.statistic_type, other.statistic_type)) 

2522 if self.value is None: 

2523 self.value = other.value 

2524 elif other.value is not None: 

2525 self._value += other.value 

2526 return self 

2527 
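# --- Editorial usage sketch (illustrative only; not part of the original file) ---
# Accumulating OpStats of the same statistic type with "+="; the numbers are
# hypothetical. A None value leaves the running total unchanged, and mixing
# statistic types raises ValueError.
def _example_op_stats_accumulation():
  total = OpStats("flops", 0)
  total += OpStats("flops", 128)
  total += OpStats("flops", None)
  return total.value  # -> 128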

2528 

2529_stats_registry = registry.Registry("statistical functions") 

2530 

2531 

2532class RegisterStatistics(object): 

2533 """A decorator for registering the statistics function for an op type. 

2534 

2535 This decorator can be defined for an op type so that it gives a 

2536 report on the resources used by an instance of an operator, in the 

2537 form of an OpStats object. 

2538 

2539 Well-known types of statistics include these so far: 

2540 

2541 - flops: When running a graph, the bulk of the computation happens doing 

2542 numerical calculations like matrix multiplications. This type allows a node 

2543 to return how many floating-point operations it takes to complete. The 

2544 total number of FLOPs for a graph is a good guide to its expected latency. 

2545 

2546 You can add your own statistics just by picking a new type string, registering 

2547 functions for the ops you care about, and then calling get_stats_for_node_def. 

2548 

2549 If a statistic for an op is registered multiple times, a KeyError will be 

2550 raised. 

2551 

2552 Since the statistics are counted on a per-op basis, they are not suitable for 

2553 model parameters (capacity), which are expected to be counted only once, even 

2554 if they are shared by multiple ops (e.g. RNN). 

2555 

2556 For example, you can define a new metric called doohickey for a Foo operation 

2557 by placing this in your code: 

2558 

2559 ```python 

2560 @ops.RegisterStatistics("Foo", "doohickey") 

2561 def _calc_foo_bojangles(unused_graph, unused_node_def): 

2562 return ops.OpStats("doohickey", 20) 

2563 ``` 

2564 

2565 Then in client code you can retrieve the value by making this call: 

2566 

2567 ```python 

2568 doohickey = ops.get_stats_for_node_def(graph, node_def, "doohickey") 

2569 ``` 

2570 

2571 If the NodeDef is for an op with a registered doohickey function, you'll get 

2572 back the calculated amount in doohickey.value, or None if it's not defined. 

2573 

2574 """ 

2575 

2576 __slots__ = ["_op_type", "_statistic_type"] 

2577 

2578 def __init__(self, op_type, statistic_type): 

2579 """Saves the `op_type` as the `Operation` type.""" 

2580 if not isinstance(op_type, str): 

2581 raise TypeError("op_type must be a string.") 

2582 if "," in op_type: 

2583 raise TypeError("op_type must not contain a comma.") 

2584 self._op_type = op_type 

2585 if not isinstance(statistic_type, str): 

2586 raise TypeError("statistic_type must be a string.") 

2587 if "," in statistic_type: 

2588 raise TypeError("statistic_type must not contain a comma.") 

2589 self._statistic_type = statistic_type 

2590 

2591 def __call__(self, f): 

2592 """Registers "f" as the statistics function for "op_type".""" 

2593 _stats_registry.register(f, self._op_type + "," + self._statistic_type) 

2594 return f 

2595 

2596 

2597def get_stats_for_node_def(graph, node, statistic_type): 

2598 """Looks up the node's statistics function in the registry and calls it. 

2599 

2600 This function takes a Graph object and a NodeDef from a GraphDef, and if 

2601 there's an associated statistics method, calls it and returns a result. If no 

2602 function has been registered for the particular node type, it returns an empty 

2603 statistics object. 

2604 

2605 Args: 

2606 graph: A Graph object that's been set up with the node's graph. 

2607 node: A NodeDef describing the operator. 

2608 statistic_type: A string identifying the statistic we're interested in. 

2609 

2610 Returns: 

2611 An OpStats object containing information about resource usage. 

2612 """ 

2613 

2614 try: 

2615 stats_func = _stats_registry.lookup(node.op + "," + statistic_type) 

2616 result = stats_func(graph, node) 

2617 except LookupError: 

2618 result = OpStats(statistic_type) 

2619 return result 

2620 

2621 

2622def name_from_scope_name(name): 

2623 """Returns the name of an op given the name of its scope. 

2624 

2625 Args: 

2626 name: the name of the scope. 

2627 

2628 Returns: 

2629 the name of the op (equal to scope name minus any trailing slash). 

2630 """ 

2631 return name[:-1] if (name and name[-1] == "/") else name 

2632 
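# --- Editorial sketch (illustrative only; not part of the original file) ---
# `name_from_scope_name` simply strips the trailing "/" that marks a scope.
def _example_name_from_scope_name():
  assert name_from_scope_name("outer/inner/") == "outer/inner"
  assert name_from_scope_name("plain_name") == "plain_name"
  assert name_from_scope_name("") == ""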

2633 

2634_MUTATION_LOCK_GROUP = 0 

2635_SESSION_RUN_LOCK_GROUP = 1 

2636 

2637 

2638@tf_contextlib.contextmanager 

2639def resource_creator_scope(resource_type, resource_creator): 

2640 with get_default_graph()._resource_creator_scope(resource_type, # pylint: disable=protected-access 

2641 resource_creator): 

2642 yield 

2643 

2644 

2645@tf_export("Graph") 

2646class Graph(pywrap_tf_session.PyGraph): 

2647 """A TensorFlow computation, represented as a dataflow graph. 

2648 

2649 Graphs are used by `tf.function`s to represent the function's computations. 

2650 Each graph contains a set of `tf.Operation` objects, which represent units of 

2651 computation; and `tf.Tensor` objects, which represent the units of data that 

2652 flow between operations. 

2653 

2654 ### Using graphs directly (deprecated) 

2655 

2656 A `tf.Graph` can be constructed and used directly without a `tf.function`, as 

2657 was required in TensorFlow 1, but this is deprecated and it is recommended to 

2658 use a `tf.function` instead. If a graph is directly used, other deprecated 

2659 TensorFlow 1 classes are also required to execute the graph, such as a 

2660 `tf.compat.v1.Session`. 

2661 

2662 A default graph can be registered with the `tf.Graph.as_default` context 

2663 manager. Then, operations will be added to the graph instead of being executed 

2664 eagerly. For example: 

2665 

2666 ```python 

2667 g = tf.Graph() 

2668 with g.as_default(): 

2669 # Define operations and tensors in `g`. 

2670 c = tf.constant(30.0) 

2671 assert c.graph is g 

2672 ``` 

2673 

2674 `tf.compat.v1.get_default_graph()` can be used to obtain the default graph. 

2675 

2676 Important note: This class *is not* thread-safe for graph construction. All 

2677 operations should be created from a single thread, or external 

2678 synchronization must be provided. Unless otherwise specified, all methods 

2679 are not thread-safe. 

2680 

2681 A `Graph` instance supports an arbitrary number of "collections" 

2682 that are identified by name. For convenience when building a large 

2683 graph, collections can store groups of related objects: for 

2684 example, the `tf.Variable` uses a collection (named 

2685 `tf.GraphKeys.GLOBAL_VARIABLES`) for 

2686 all variables that are created during the construction of a graph. The caller 

2687 may define additional collections by specifying a new name. 

2688 """ 

2689 

2690 def __init__(self): 

2691 """Creates a new, empty Graph.""" 

2692 super().__init__() 

2693 # Protects core state that can be returned via public accessors. 

2694 # Thread-safety is provided on a best-effort basis to support buggy 

2695 # programs, and is not guaranteed by the public `tf.Graph` API. 

2696 # 

2697 # NOTE(mrry): This does not protect the various stacks. A warning will 

2698 # be reported if these are used from multiple threads 

2699 self._lock = threading.RLock() 

2700 # The group lock synchronizes Session.run calls with methods that create 

2701 # and mutate ops (e.g. Graph.create_op()). This synchronization is 

2702 # necessary because it's illegal to modify an operation after it's been run. 

2703 # The group lock allows any number of threads to mutate ops at the same time 

2704 # but if any modification is going on, all Session.run calls have to wait. 

2705 # Similarly, if one or more Session.run calls are going on, all mutate ops 

2706 # have to wait until all Session.run calls have finished. 

2707 self._group_lock = lock_util.GroupLock(num_groups=2) 

2708 # Maps a name used in the graph to the next id to use for that name. 

2709 self._names_in_use = {} 

2710 self._stack_state_is_thread_local = False 

2711 self._thread_local = threading.local() 

2712 # Functions that will be applied to choose a device if none is specified. 

2713 # In TF2.x or after switch_to_thread_local(), 

2714 # self._thread_local._device_function_stack is used instead. 

2715 self._graph_device_function_stack = traceable_stack.TraceableStack() 

2716 # Default original_op applied to new ops. 

2717 self._default_original_op = None 

2718 # Current control flow context. It could be either CondContext or 

2719 # WhileContext defined in ops/control_flow_ops.py 

2720 self._control_flow_context = None 

2721 # A new node will depend on the union of all of the nodes in the stack. 

2722 # In TF2.x or after switch_to_thread_local(), 

2723 # self._thread_local._control_dependencies_stack is used instead. 

2724 self._graph_control_dependencies_stack = [] 

2725 # Arbitrary collections of objects. 

2726 self._collections = {} 

2727 # The graph-level random seed 

2728 self._seed = None 

2729 # A dictionary of attributes that should be applied to all ops. 

2730 self._attr_scope_map = {} 

2731 # A map from op type to the kernel label that should be used. 

2732 self._op_to_kernel_label_map = {} 

2733 # A map from op type to an alternative op type that should be used when 

2734 # computing gradients. 

2735 self._gradient_override_map = {} 

2736 # A map from op type to a gradient function that should be used instead. 

2737 self._gradient_function_map = {} 

2738 # True if the graph is considered "finalized". In that case no 

2739 # new operations can be added. 

2740 self._finalized = False 

2741 # Functions defined in the graph 

2742 self._functions = collections.OrderedDict() 

2743 # Default GraphDef versions 

2744 self._graph_def_versions = versions_pb2.VersionDef( 

2745 producer=versions.GRAPH_DEF_VERSION, 

2746 min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER) 

2747 self._building_function = False 

2748 # Stack of colocate_with ops. In TF2.x or after switch_to_thread_local(), 

2749 # self._thread_local._colocation_stack is used instead. 

2750 self._graph_colocation_stack = traceable_stack.TraceableStack() 

2751 # Set of tensors that are dangerous to feed! 

2752 self._unfeedable_tensors = object_identity.ObjectIdentitySet() 

2753 # Set of operations that are dangerous to fetch! 

2754 self._unfetchable_ops = set() 

2755 # A map of tensor handle placeholder to tensor dtype. 

2756 self._handle_feeders = {} 

2757 # A map from tensor handle to its read op. 

2758 self._handle_readers = {} 

2759 # A map from tensor handle to its move op. 

2760 self._handle_movers = {} 

2761 # A map from tensor handle to its delete op. 

2762 self._handle_deleters = {} 

2763 # Allow optimizers and other objects to pseudo-uniquely key graphs (this key 

2764 # will be shared when defining function graphs, for example, so optimizers 

2765 # being called inside function definitions behave as if they were seeing the 

2766 # actual outside graph). 

2767 self._graph_key = "graph-key-%d/" % (uid(),) 

2768 # A string with the last reduction method passed to 

2769 # losses.compute_weighted_loss(), or None. This is required only for 

2770 # backward compatibility with Estimator and optimizer V1 use cases. 

2771 self._last_loss_reduction = None 

2772 # Flag that is used to indicate whether loss has been scaled by optimizer. 

2773 # If this flag has been set, then estimator uses it to scale loss back 

2774 # before reporting. This is required only for backward compatibility with 

2775 # Estimator and optimizer V1 use cases. 

2776 self._is_loss_scaled_by_optimizer = False 

2777 self._container = "" 

2778 

2779 # The current AutomaticControlDependencies context manager. 

2780 self.experimental_acd_manager = None 

2781 # Set to True if this graph is being built in an 

2782 # AutomaticControlDependencies context. 

2783 # Deprecated: use acd_manager instead. 

2784 self._add_control_dependencies = False 

2785 

2786 # Cache for OpDef protobufs retrieved via the C API. 

2787 self._op_def_cache = {} 

2788 # Cache for constant results of `broadcast_gradient_args()`. The keys are 

2789 # tuples of fully-defined shapes: (x_shape_tuple, y_shape_tuple), and the 

2790 # values are tuples of reduction indices: (rx, ry). 

2791 self._bcast_grad_args_cache = {} 

2792 # Cache for constant results of `reduced_shape()`. The keys are pairs of 

2793 # tuples: (input_shape_tuple, reduction_indices_tuple), and the values 

2794 # are pairs of tuples: (output_shape_kept_dims, tile_scaling). 

2795 self._reduced_shape_cache = {} 

2796 

2797 if tf2.enabled(): 

2798 self.switch_to_thread_local() 

2799 

2800 # `Graph` now _is_ the C graph, but we have many places that manually attempt 

2801 # to manipulate the _c_graph object. Leave these accessors here until these 

2802 # are cleaned up. 

2803 @property 

2804 def _c_graph(self): 

2805 return self 

2806 

2807 def __enter__(self): 

2808 return self 

2809 

2810 def __exit__(self, *args): 

2811 return 

2812 

2813 def get(self): 

2814 return self 

2815 

2816 # Note: this method is private because the API of tf.Graph() is public and 

2817 # frozen, and this functionality is still not ready for public visibility. 

2818 @tf_contextlib.contextmanager 

2819 def _variable_creator_scope(self, creator, priority=100): 

2820 """Scope which defines a variable creation function. 

2821 

2822 Args: 

2823 creator: A callable taking `next_creator` and `kwargs`. See the 

2824 `tf.variable_creator_scope` docstring. 

2825 priority: Creators with a higher `priority` are called first. Within the 

2826 same priority, creators are called inner-to-outer. 

2827 

2828 Yields: 

2829 `_variable_creator_scope` is a context manager with a side effect, but 

2830 doesn't return a value. 

2831 

2832 Raises: 

2833 RuntimeError: If variable creator scopes are not properly nested. 

2834 """ 

2835 # This step keeps a reference to the existing stack, and it also initializes 

2836 # self._thread_local._variable_creator_stack if it doesn't exist yet. 

2837 old = self._variable_creator_stack 

2838 new = list(old) 

2839 new.append((priority, creator)) 

2840 # Sorting is stable, so we'll put higher-priority creators later in the list 

2841 # but otherwise maintain registration order. 

2842 new.sort(key=lambda item: item[0]) 

2843 self._thread_local._variable_creator_stack = new # pylint: disable=protected-access 

2844 try: 

2845 yield 

2846 finally: 

2847 if self._thread_local._variable_creator_stack is not new: # pylint: disable=protected-access 

2848 raise RuntimeError( 

2849 "Exiting variable_creator_scope without proper nesting.") 

2850 self._thread_local._variable_creator_stack = old # pylint: disable=protected-access 

2851 

2852 # TODO(b/192405401): unify resource_creator_scope with variable_creator_scope. 

2853 # pylint: disable=protected-access 

2854 @tf_contextlib.contextmanager 

2855 def _resource_creator_scope(self, resource_type, creator): 

2856 """Scope which defines a resource creation function used by some resource. 

2857 

2858 The resource should be a subclass of CapturableResource with a class method 

2859 `cls._resource_type`, the output of which is what the `resource_type` 

2860 argument should be. By default, `cls._resource_type` returns the class name, 

2861 `cls.__name__`. Given a scope, creators being added with the same 

2862 `resource_type` argument will be composed together to apply to all classes 

2863 with this `_resource_type`. 

2864 

2865 

2866 `creator` is expected to be a function with the following signature: 

2867 

2868 ``` 

2869 def resource_creator(next_creator, *a, **kwargs) 

2870 ``` 

2871 

2872 The creator is supposed to eventually call the next_creator to create an 

2873 instance if it does want to create an instance and not call 

2874 the class initialization method directly. This helps make creators 

2875 composable. A creator may choose to create multiple instances, return 

2876 already existing instances, or simply register that an instance was created 

2877 and defer to the next creator in line. Creators can also modify keyword 

2878 arguments seen by the next creators. 

2879 

2880 Valid keyword arguments in `kwargs` depend on the specific resource 

2881 class. For StaticHashTable, this may be: 

2882 * initializer: The table initializer to use. 

2883 * default_value: The value to use if a key is missing in the table. 

2884 * name: Optional name for the table, default to None. 

2885 

2886 

2887 Args: 

2888 resource_type: the output of the resource class's `_resource_type` method. 

2889 creator: the passed creator for the resource. 

2890 

2891 Yields: 

2892 A scope in which the creator is active 

2893 

2894 Raises: 

2895 RuntimeError: If resource_creator_scope is exited without proper nesting. 

2896 """ 

2897 # This step keeps a reference to the existing stack, and it also initializes 

2898 # self._thread_local._variable_creator_stack if it doesn't exist yet. 

2899 old = self._resource_creator_stack 

2900 new = copy.deepcopy(old) 

2901 if isinstance(resource_type, (list, tuple)): 

2902 for r in resource_type: 

2903 new[r].append(creator) 

2904 else: 

2905 new[resource_type].append(creator) 

2906 self._thread_local._resource_creator_stack = new 

2907 try: 

2908 yield 

2909 finally: 

2910 if self._thread_local._resource_creator_stack is not new: 

2911 raise RuntimeError( 

2912 "Exiting resource_creator_scope without proper nesting.") 

2913 self._thread_local._resource_creator_stack = old 

2914 

2915 @property 

2916 def _resource_creator_stack(self): 

2917 if not hasattr(self._thread_local, "_resource_creator_stack"): 

2918 self._thread_local._resource_creator_stack = collections.defaultdict(list) 

2919 return self._thread_local._resource_creator_stack 

2920 

2921 @_resource_creator_stack.setter 

2922 def _resource_creator_stack(self, resource_creator_stack): 

2923 self._thread_local._resource_creator_stack = resource_creator_stack 

2924 # pylint: enable=protected-access 

2925 

2926 # Note: this method is private because the API of tf.Graph() is public and 

2927 # frozen, and this functionality is still not ready for public visibility. 

2928 @property 

2929 def _variable_creator_stack(self): 

2930 if not hasattr(self._thread_local, "_variable_creator_stack"): 

2931 self._thread_local._variable_creator_stack = [] # pylint: disable=protected-access 

2932 

2933 # This previously returned a copy of the stack instead of the stack itself, 

2934 # to guard against accidental mutation. Consider, however, code that wants 

2935 # to save and restore the variable creator stack: 

2936 # def f(): 

2937 # original_stack = graph._variable_creator_stack 

2938 # graph._variable_creator_stack = new_stack 

2939 # ... # Some code 

2940 # graph._variable_creator_stack = original_stack 

2941 # 

2942 # And let's say you have some code that calls this function with some 

2943 # variable_creator: 

2944 # def g(): 

2945 # with variable_scope.variable_creator_scope(creator): 

2946 # f() 

2947 # When exiting the variable creator scope, it would see a different stack 

2948 # object than it expected, leading to an "Exiting variable_creator_scope 

2949 # without proper nesting" error. 

2950 return self._thread_local._variable_creator_stack # pylint: disable=protected-access 

2951 

2952 @_variable_creator_stack.setter 

2953 def _variable_creator_stack(self, variable_creator_stack): 

2954 self._thread_local._variable_creator_stack = variable_creator_stack # pylint: disable=protected-access 

2955 

2956 def _check_not_finalized(self): 

2957 """Check if the graph is finalized. 

2958 

2959 Raises: 

2960 RuntimeError: If the graph is finalized. 

2961 """ 

2962 if self._finalized: 

2963 raise RuntimeError("Graph is finalized and cannot be modified.") 

2964 

2965 @property 

2966 def graph_def_versions(self): 

2967 # pylint: disable=line-too-long 

2968 """The GraphDef version information of this graph. 

2969 

2970 For details on the meaning of each version, see 

2971 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto). 

2972 

2973 Returns: 

2974 A `VersionDef`. 

2975 """ 

2976 return versions_pb2.VersionDef.FromString(self._version_def) 

2977 

2978 @property 

2979 def seed(self): 

2980 """The graph-level random seed of this graph.""" 

2981 return self._seed 

2982 

2983 @seed.setter 

2984 def seed(self, seed): 

2985 self._seed = seed 

2986 

2987 @property 

2988 def finalized(self): 

2989 """True if this graph has been finalized.""" 

2990 return self._finalized 

2991 

2992 def finalize(self): 

2993 """Finalizes this graph, making it read-only. 

2994 

2995 After calling `g.finalize()`, no new operations can be added to 

2996 `g`. This method is used to ensure that no operations are added 

2997 to a graph when it is shared between multiple threads, for example 

2998 when using a `tf.compat.v1.train.QueueRunner`. 

2999 """ 

3000 self._finalized = True 

3001 
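# Editor's illustrative sketch (not part of the original ops.py source):
# finalizing a graph makes any further op creation raise RuntimeError. Assumes
# the TF1-style API via `tensorflow.compat.v1`.

import tensorflow.compat.v1 as tf_v1

g = tf_v1.Graph()
with g.as_default():
  tf_v1.constant(1.0, name="c")
g.finalize()
assert g.finalized
try:
  with g.as_default():
    tf_v1.constant(2.0)  # Graph is finalized and cannot be modified.
except RuntimeError:
  pass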

3002 def _unsafe_unfinalize(self): 

3003 """Opposite of `finalize`. 

3004 

3005 Internal interface. 

3006 

3007 NOTE: Unfinalizing a graph could have a negative impact on performance, 

3008 especially in a multi-threaded environment. Unfinalizing a graph 

3009 when it is in use by a Session may lead to undefined behavior. Ensure 

3010 that all sessions using a graph are closed before calling this method. 

3011 """ 

3012 self._finalized = False 

3013 

3014 def _get_control_flow_context(self): 

3015 """Returns the current control flow context. 

3016 

3017 Returns: 

3018 A context object. 

3019 """ 

3020 return self._control_flow_context 

3021 

3022 def _set_control_flow_context(self, ctx): 

3023 """Sets the current control flow context. 

3024 

3025 Args: 

3026 ctx: a context object. 

3027 """ 

3028 self._control_flow_context = ctx 

3029 

3030 def _copy_functions_to_graph_def(self, graph_def, starting_bytesize): 

3031 """If this graph contains functions, copy them to `graph_def`.""" 

3032 bytesize = starting_bytesize 

3033 for f in self._functions.values(): 

3034 bytesize += f.cached_definition.ByteSize() 

3035 if bytesize >= (1 << 31) or bytesize < 0: 

3036 raise ValueError("GraphDef cannot be larger than 2GB.") 

3037 graph_def.library.function.extend([f.cached_definition]) 

3038 if getattr(f, "grad_func_name", None): 

3039 grad_def = function_pb2.GradientDef() 

3040 grad_def.function_name = f.name 

3041 grad_def.gradient_func = f.grad_func_name 

3042 graph_def.library.gradient.extend([grad_def]) 

3043 

3044 def _as_graph_def(self, from_version=None, add_shapes=False): 

3045 # pylint: disable=line-too-long 

3046 """Returns a serialized `GraphDef` representation of this graph. 

3047 

3048 The serialized `GraphDef` can be imported into another `Graph` 

3049 (using `tf.import_graph_def`) or used with the 

3050 [C++ Session API](https://chromium.googlesource.com/external/github.com/tensorflow/tensorflow/+/r0.10/tensorflow/g3doc/api_docs/cc/index.md). 

3051 

3052 This method is thread-safe. 

3053 

3054 Args: 

3055 from_version: Optional. If this is set, returns a `GraphDef` containing 

3056 only the nodes that were added to this graph since its `version` 

3057 property had the given value. 

3058 add_shapes: If true, adds an "_output_shapes" list attr to each node with 

3059 the inferred shapes of each of its outputs. 

3060 

3061 Returns: 

3062 A tuple containing a 

3063 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) 

3064 protocol buffer, and the version of the graph to which that 

3065 `GraphDef` corresponds. 

3066 

3067 Raises: 

3068 ValueError: If the `graph_def` would be too large. 

3069 

3070 """ 

3071 # pylint: enable=line-too-long 

3072 with self._lock: 

3073 with c_api_util.tf_buffer() as buf: 

3074 with self._c_graph.get() as c_graph: 

3075 pywrap_tf_session.TF_GraphToGraphDef(c_graph, buf) 

3076 data = pywrap_tf_session.TF_GetBuffer(buf) 

3077 graph = graph_pb2.GraphDef() 

3078 graph.ParseFromString(compat.as_bytes(data)) 

3079 # Strip the experimental library field iff it's empty. 

3080 if not graph.library.function: 

3081 graph.ClearField("library") 

3082 

3083 if add_shapes: 

3084 for node in graph.node: 

3085 op = self._get_operation_by_name(node.name) 

3086 if op.outputs: 

3087 node.attr["_output_shapes"].list.shape.extend( 

3088 [output.get_shape().as_proto() for output in op.outputs]) 

3089 for function_def in graph.library.function: 

3090 defined_function = self._functions[function_def.signature.name] 

3091 try: 

3092 func_graph = defined_function.graph 

3093 except AttributeError: 

3094 # _DefinedFunction doesn't have a graph, _EagerDefinedFunction 

3095 # does. Both rely on ops.py, so we can't really isinstance check 

3096 # them. 

3097 continue 

3098 input_shapes = function_def.attr["_input_shapes"] 

3099 try: 

3100 func_graph_inputs = func_graph.inputs 

3101 except AttributeError: 

3102 continue 

3103 # TODO(b/141471245): Fix the inconsistency when inputs of func graph 

3104 # are appended during gradient computation of while/cond. 

3105 assert len(input_shapes.list.shape) in [0, len(func_graph_inputs)] 

3106 # If the function_def has inputs already filled out, skip this step. 

3107 if not input_shapes.list.shape: 

3108 for input_tensor, arg_def in zip(func_graph_inputs, 

3109 function_def.signature.input_arg): 

3110 input_shapes.list.shape.add().CopyFrom( 

3111 input_tensor.get_shape().as_proto()) 

3112 if input_tensor.dtype == dtypes.resource: 

3113 _copy_handle_data_to_arg_def(input_tensor, arg_def) 

3114 

3115 for output_tensor, arg_def in zip(func_graph.outputs, 

3116 function_def.signature.output_arg): 

3117 if output_tensor.dtype == dtypes.resource: 

3118 _copy_handle_data_to_arg_def(output_tensor, arg_def) 

3119 

3120 for node in function_def.node_def: 

3121 try: 

3122 op = func_graph.get_operation_by_name(node.name) 

3123 except KeyError: 

3124 continue 

3125 outputs = op.outputs 

3126 

3127 if op.type == "StatefulPartitionedCall": 

3128 # Filter out any extra outputs (possibly added by function 

3129 # backpropagation rewriting). 

3130 num_outputs = len(node.attr["Tout"].list.type) 

3131 outputs = outputs[:num_outputs] 

3132 

3133 node.attr["_output_shapes"].list.shape.extend( 

3134 [output.get_shape().as_proto() for output in outputs]) 

3135 

3136 return graph, self.version 

3137 

3138 def as_graph_def(self, from_version=None, add_shapes=False): 

3139 # pylint: disable=line-too-long 

3140 """Returns a serialized `GraphDef` representation of this graph. 

3141 

3142 The serialized `GraphDef` can be imported into another `Graph` 

3143 (using `tf.import_graph_def`) or used with the 

3144 [C++ Session API](../../api_docs/cc/index.md). 

3145 

3146 This method is thread-safe. 

3147 

3148 Args: 

3149 from_version: Optional. If this is set, returns a `GraphDef` containing 

3150 only the nodes that were added to this graph since its `version` 

3151 property had the given value. 

3152 add_shapes: If true, adds an "_output_shapes" list attr to each node with 

3153 the inferred shapes of each of its outputs. 

3154 

3155 Returns: 

3156 A 

3157 [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto) 

3158 protocol buffer. 

3159 

3160 Raises: 

3161 ValueError: If the `graph_def` would be too large. 

3162 """ 

3163 # pylint: enable=line-too-long 

3164 result, _ = self._as_graph_def(from_version, add_shapes) 

3165 return result 

3166 
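# Editor's illustrative sketch (not part of the original ops.py source):
# serializing a graph to a `GraphDef` and importing it into another graph, as
# described in the docstring above. Assumes the TF1-style API.

import tensorflow.compat.v1 as tf_v1

g = tf_v1.Graph()
with g.as_default():
  tf_v1.constant(3.0, name="c")
graph_def = g.as_graph_def(add_shapes=True)

g2 = tf_v1.Graph()
with g2.as_default():
  tf_v1.import_graph_def(graph_def, name="")
assert g2.get_tensor_by_name("c:0").dtype == tf_v1.float32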

3167 def _is_function(self, name): 

3168 """Tests whether 'name' is registered in this graph's function library. 

3169 

3170 Args: 

3171 name: string op name. 

3172 

3173 Returns: 

3174 bool indicating whether or not 'name' is registered in function library. 

3175 """ 

3176 return compat.as_str(name) in self._functions 

3177 

3178 def _get_function(self, name): 

3179 """Returns the function definition for 'name'. 

3180 

3181 Args: 

3182 name: string function name. 

3183 

3184 Returns: 

3185 The function def proto. 

3186 """ 

3187 return self._functions.get(compat.as_str(name), None) 

3188 

3189 def _add_function_recursive(self, function, overwrite=False): 

3190 """Adds function to the graph including other functions in its graph.""" 

3191 

3192 if self._is_function(function.name): 

3193 if overwrite: 

3194 self._remove_function(function.name) 

3195 self._add_function(function) 

3196 else: 

3197 self._add_function(function) 

3198 

3199 if hasattr(function, "graph"): 

3200 for f in function.graph._functions.values(): # pylint: disable=protected-access 

3201 if self._is_function(f.name): 

3202 if overwrite: 

3203 self._remove_function(f.name) 

3204 self._add_function(f) 

3205 else: 

3206 self._add_function(f) 

3207 

3208 def _add_function(self, function): 

3209 """Adds a function to the graph. 

3210 

3211 After the function has been added, you can call the function by 

3212 passing the function name in place of an op name to 

3213 `Graph.create_op()`. 

3214 

3215 Args: 

3216 function: A `_DefinedFunction` object. 

3217 

3218 Raises: 

3219 ValueError: if another function is defined with the same name. 

3220 """ 

3221 self._check_not_finalized() 

3222 

3223 name = function.name 

3224 # Sanity checks on gradient definition for deprecated _DefinedFunction. 

3225 if getattr(function, "grad_func_name", None) and getattr( 

3226 function, "python_grad_func", None 

3227 ): 

3228 raise ValueError("Gradient defined twice for function %s" % name) 

3229 

3230 # Add function to graph 

3231 # pylint: disable=protected-access 

3232 with self._c_graph.get() as c_graph: 

3233 with function._c_func.get() as func: 

3234 if getattr(function, "_grad_func", None): 

3235 # For deprecated _DefinedFunction. 

3236 with function._grad_func._c_func.get() as gradient: 

3237 pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, gradient) 

3238 else: 

3239 pywrap_tf_session.TF_GraphCopyFunction(c_graph, func, None) 

3240 # pylint: enable=protected-access 

3241 

3242 self._functions[compat.as_str(name)] = function 

3243 

3244 # Need a new-enough consumer to support the functions we add to the graph. 

3245 if self._graph_def_versions.min_consumer < 12: 

3246 self._graph_def_versions.min_consumer = 12 

3247 

3248 def _remove_function(self, name): 

3249 self._check_not_finalized() 

3250 if not self._is_function(name): 

3251 raise ValueError(f"Function {name!r} is not found in {self!r}.") 

3252 

3253 with self._c_graph.get() as c_graph: 

3254 pywrap_tf_session.TF_GraphRemoveFunction(c_graph, compat.as_bytes(name)) 

3255 del self._functions[compat.as_str(name)] 

3256 

3257 @property 

3258 def building_function(self): 

3259 """Returns True iff this graph represents a function.""" 

3260 return self._building_function 

3261 

3262 # Helper functions to create operations. 

3263 @deprecated_args(None, 

3264 "Shapes are always computed; don't use the compute_shapes " 

3265 "as it has no effect.", "compute_shapes") 

3266 @traceback_utils.filter_traceback 

3267 def create_op( 

3268 self, 

3269 op_type, 

3270 inputs, 

3271 dtypes=None, # pylint: disable=redefined-outer-name 

3272 input_types=None, 

3273 name=None, 

3274 attrs=None, 

3275 op_def=None, 

3276 compute_shapes=True, 

3277 compute_device=True): 

3278 """Creates an `Operation` in this graph. 

3279 

3280 This is a low-level interface for creating an `Operation`. Most 

3281 programs will not call this method directly, and instead use the 

3282 Python op constructors, such as `tf.constant()`, which add ops to 

3283 the default graph. 

3284 

3285 Args: 

3286 op_type: The `Operation` type to create. This corresponds to the 

3287 `OpDef.name` field for the proto that defines the operation. 

3288 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 

3289 dtypes: (Optional) A list of `DType` objects that will be the types of the 

3290 tensors that the operation produces. 

3291 input_types: (Optional.) A list of `DType`s that will be the types of the 

3292 tensors that the operation consumes. By default, uses the base `DType` 

3293 of each input in `inputs`. Operations that expect reference-typed inputs 

3294 must specify `input_types` explicitly. 

3295 name: (Optional.) A string name for the operation. If not specified, a 

3296 name is generated based on `op_type`. 

3297 attrs: (Optional.) A dictionary where the key is the attribute name (a 

3298 string) and the value is the respective `attr` attribute of the 

3299 `NodeDef` proto that will represent the operation (an `AttrValue` 

3300 proto). 

3301 op_def: (Optional.) The `OpDef` proto that describes the `op_type` that 

3302 the operation will have. 

3303 compute_shapes: (Optional.) Deprecated. Has no effect (shapes are always 

3304 computed). 

3305 compute_device: (Optional.) If True, device functions will be executed to 

3306 compute the device property of the Operation. 

3307 

3308 Raises: 

3309 TypeError: if any of the inputs is not a `Tensor`. 

3310 ValueError: if colocation conflicts with existing device assignment. 

3311 

3312 Returns: 

3313 An `Operation` object. 

3314 """ 

3315 del compute_shapes 

3316 for idx, a in enumerate(inputs): 

3317 if not isinstance(a, Tensor): 

3318 raise TypeError("Input #%d is not a tensor: %s" % (idx, a)) 

3319 return self._create_op_internal(op_type, inputs, dtypes, input_types, name, 

3320 attrs, op_def, compute_device) 

3321 
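# Editor's illustrative sketch (not part of the original ops.py source): the
# low-level `create_op` path for a simple op with no inputs or outputs. Most
# programs should use the Python op constructors instead, as noted above.

import tensorflow.compat.v1 as tf_v1

g = tf_v1.Graph()
no_op = g.create_op("NoOp", inputs=[], name="my_no_op")
assert no_op.type == "NoOp"
assert no_op.name == "my_no_op"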

3322 def _create_op_internal( 

3323 self, 

3324 op_type, 

3325 inputs, 

3326 dtypes=None, # pylint: disable=redefined-outer-name 

3327 input_types=None, 

3328 name=None, 

3329 attrs=None, 

3330 op_def=None, 

3331 compute_device=True): 

3332 """Creates an `Operation` in this graph. 

3333 

3334 Implements `Graph.create_op()` without the overhead of the deprecation 

3335 wrapper. 

3336 

3337 Args: 

3338 op_type: The `Operation` type to create. This corresponds to the 

3339 `OpDef.name` field for the proto that defines the operation. 

3340 inputs: A list of `Tensor` objects that will be inputs to the `Operation`. 

3341 dtypes: (Optional) A list of `DType` objects that will be the types of the 

3342 tensors that the operation produces. 

3343 input_types: (Optional.) A list of `DType`s that will be the types of the 

3344 tensors that the operation consumes. By default, uses the base `DType` 

3345 of each input in `inputs`. Operations that expect reference-typed inputs 

3346 must specify `input_types` explicitly. 

3347 name: (Optional.) A string name for the operation. If not specified, a 

3348 name is generated based on `op_type`. 

3349 attrs: (Optional.) A dictionary where the key is the attribute name (a 

3350 string) and the value is the respective `attr` attribute of the 

3351 `NodeDef` proto that will represent the operation (an `AttrValue` 

3352 proto). 

3353 op_def: (Optional.) The `OpDef` proto that describes the `op_type` that 

3354 the operation will have. 

3355 compute_device: (Optional.) If True, device functions will be executed to 

3356 compute the device property of the Operation. 

3357 

3358 Raises: 

3359 ValueError: if colocation conflicts with existing device assignment. 

3360 

3361 Returns: 

3362 An `Operation` object. 

3363 """ 

3364 self._check_not_finalized() 

3365 if name is None: 

3366 name = op_type 

3367 # If a name ends with a '/' it is a "name scope" and we use it as-is, 

3368 # after removing the trailing '/'. 

3369 if name and name[-1] == "/": 

3370 name = name_from_scope_name(name) 

3371 else: 

3372 name = self.unique_name(name) 

3373 

3374 node_def = _NodeDef(op_type, name, attrs) 

3375 

3376 input_ops = set(t.op for t in inputs) 

3377 control_inputs = self._control_dependencies_for_inputs(input_ops) 

3378 # _create_op_helper mutates the new Operation. `_mutation_lock` ensures a 

3379 # Session.run call cannot occur between creating and mutating the op. 

3380 with self._mutation_lock(): 

3381 ret = Operation.from_node_def( 

3382 node_def, 

3383 self, 

3384 inputs=inputs, 

3385 output_types=dtypes, 

3386 control_inputs=control_inputs, 

3387 input_types=input_types, 

3388 original_op=self._default_original_op, 

3389 op_def=op_def, 

3390 ) 

3391 self._create_op_helper(ret, compute_device=compute_device) 

3392 return ret 

3393 

3394 def _create_op_from_tf_operation(self, c_op, compute_device=True): 

3395 """Creates an `Operation` in this graph from the supplied TF_Operation. 

3396 

3397 This method is like create_op() except the new Operation is constructed 

3398 using `c_op`. The returned Operation will have `c_op` as its _c_op 

3399 field. This is used to create Operation objects around TF_Operations created 

3400 indirectly by the C API (e.g. by TF_ImportGraphDef, TF_FinishWhile). 

3401 

3402 This function does not call Operation._control_flow_post_processing or 

3403 Graph._control_dependencies_for_inputs (since the inputs may not be 

3404 available yet). The caller is responsible for calling these methods. 

3405 

3406 Args: 

3407 c_op: a wrapped TF_Operation 

3408 compute_device: (Optional.) If True, device functions will be executed to 

3409 compute the device property of the Operation. 

3410 

3411 Returns: 

3412 An `Operation` object. 

3413 """ 

3414 self._check_not_finalized() 

3415 ret = Operation._from_c_op(c_op=c_op, g=self) # pylint: disable=protected-access 

3416 # If a name_scope was created with ret.name but no nodes were created in it, 

3417 # the name will still appear in _names_in_use even though the name hasn't 

3418 # been used. This is ok, just leave _names_in_use as-is in this case. 

3419 # TODO(skyewm): make the C API guarantee no name conflicts. 

3420 name_key = ret.name.lower() 

3421 if name_key not in self._names_in_use: 

3422 self._names_in_use[name_key] = 1 

3423 self._create_op_helper(ret, compute_device=compute_device) 

3424 return ret 

3425 

3426 def _create_op_helper(self, op, compute_device=True): 

3427 """Common logic for creating an op in this graph.""" 

3428 # Apply any additional attributes requested. Do not overwrite any existing 

3429 # attributes. 

3430 for key, value in self._attr_scope_map.items(): 

3431 try: 

3432 op.get_attr(key) 

3433 except ValueError: 

3434 if callable(value): 

3435 value = value(op.node_def) 

3436 if not isinstance(value, (type(None), attr_value_pb2.AttrValue)): 

3437 raise TypeError( 

3438 "Callable for scope map key '%s' must return either None or " 

3439 "an AttrValue protocol buffer; but it returned: %s" % 

3440 (key, value)) 

3441 if value: 

3442 op._set_attr(key, value) # pylint: disable=protected-access 

3443 

3444 # Apply a kernel label if one has been specified for this op type. 

3445 try: 

3446 kernel_label = self._op_to_kernel_label_map[op.type] 

3447 op._set_attr("_kernel", # pylint: disable=protected-access 

3448 attr_value_pb2.AttrValue(s=compat.as_bytes(kernel_label))) 

3449 except KeyError: 

3450 pass 

3451 

3452 op._gradient_function = self._gradient_function_map.get(op.type) # pylint: disable=protected-access 

3453 

3454 # Apply the overriding op type for gradients if one has been specified for 

3455 # this op type. 

3456 try: 

3457 mapped_op_type = self._gradient_override_map[op.type] 

3458 op._set_attr("_gradient_op_type", # pylint: disable=protected-access 

3459 attr_value_pb2.AttrValue(s=compat.as_bytes(mapped_op_type))) 

3460 except KeyError: 

3461 pass 

3462 

3463 self._record_op_seen_by_control_dependencies(op) 

3464 

3465 if compute_device: 

3466 self._apply_device_functions(op) 

3467 

3468 # Snapshot the colocation stack metadata before we might generate error 

3469 # messages using it. Note that this snapshot depends on the actual stack 

3470 # and is independent of the op's _class attribute. 

3471 # pylint: disable=protected-access 

3472 op._colocation_code_locations = self._snapshot_colocation_stack_metadata() 

3473 # pylint: enable=protected-access 

3474 

3475 if self._colocation_stack: 

3476 all_colocation_groups = [] 

3477 is_device_set = False 

3478 for colocation_op in self._colocation_stack.peek_objs(): 

3479 try: 

3480 all_colocation_groups.extend(colocation_op.colocation_groups()) 

3481 except AttributeError: 

3482 pass 

3483 if colocation_op.device and not is_device_set: 

3484 # pylint: disable=protected-access 

3485 op._set_device(colocation_op.device) 

3486 # pylint: enable=protected-access 

3487 is_device_set = True 

3488 

3489 all_colocation_groups = sorted(set(all_colocation_groups)) 

3490 # pylint: disable=protected-access 

3491 op._set_attr( 

3492 "_class", 

3493 attr_value_pb2.AttrValue( 

3494 list=attr_value_pb2.AttrValue.ListValue(s=all_colocation_groups))) 

3495 # pylint: enable=protected-access 

3496 

3497 # Sets "container" attribute if 

3498 # (1) self._container is not None 

3499 # (2) "is_stateful" is set in OpDef 

3500 # (3) "container" attribute is in OpDef 

3501 # (4) "container" attribute is not already set 

3502 if self._container and op._is_stateful: # pylint: disable=protected-access 

3503 try: 

3504 container_attr = op.get_attr("container") 

3505 except ValueError: 

3506 # "container" attribute is not in OpDef 

3507 pass 

3508 else: 

3509 if not container_attr: 

3510 op._set_attr("container", attr_value_pb2.AttrValue( # pylint: disable=protected-access 

3511 s=compat.as_bytes(self._container))) 

3512 

3513 def _add_new_tf_operations(self, compute_devices=True): 

3514 """Creates `Operations` in this graph for any new TF_Operations. 

3515 

3516 This is useful for when TF_Operations are indirectly created by the C API 

3517 outside of the Operation constructor (e.g. by TF_ImportGraphDef, 

3518 TF_FinishWhile). This ensures there are corresponding Operations for all 

3519 TF_Operations in the underlying TF_Graph. 

3520 

3521 Args: 

3522 compute_devices: (Optional.) If True, device functions will be executed to 

3523 compute the device properties of each new Operation. 

3524 

3525 Returns: 

3526 A list of the new `Operation` objects. 

3527 """ 

3528 self._check_not_finalized() 

3529 

3530 # Create all Operation objects before accessing their inputs since an op may 

3531 # be created before its inputs. 

3532 new_ops = [ 

3533 self._create_op_from_tf_operation(c_op, compute_device=compute_devices) 

3534 for c_op in self.new_operations() 

3535 ] 

3536 

3537 # pylint: disable=protected-access 

3538 for op in new_ops: 

3539 new_control_inputs = self._control_dependencies_for_inputs(op.inputs) 

3540 op._add_control_inputs(new_control_inputs) 

3541 op._control_flow_post_processing() 

3542 # pylint: enable=protected-access 

3543 

3544 return new_ops 

3545 

3546 def as_graph_element(self, obj, allow_tensor=True, allow_operation=True): 

3547 """Returns the object referred to by `obj`, as an `Operation` or `Tensor`. 

3548 

3549 This function validates that `obj` represents an element of this 

3550 graph, and gives an informative error message if it is not. 

3551 

3552 This function is the canonical way to get/validate an object of 

3553 one of the allowed types from an external argument reference in the 

3554 Session API. 

3555 

3556 This method may be called concurrently from multiple threads. 

3557 

3558 Args: 

3559 obj: A `Tensor`, an `Operation`, or the name of a tensor or operation. Can 

3560 also be any object with an `_as_graph_element()` method that returns a 

3561 value of one of these types. Note: `_as_graph_element` will be called 

3562 inside the graph's lock and so may not modify the graph. 

3563 allow_tensor: If true, `obj` may refer to a `Tensor`. 

3564 allow_operation: If true, `obj` may refer to an `Operation`. 

3565 

3566 Returns: 

3567 The `Tensor` or `Operation` in the Graph corresponding to `obj`. 

3568 

3569 Raises: 

3570 TypeError: If `obj` is not of a type that can be converted 

3571 to one of the allowed types. 

3572 ValueError: If `obj` is of an appropriate type but invalid. For 

3573 example, an invalid string. 

3574 KeyError: If `obj` is not an object in the graph. 

3575 """ 

3576 if self._finalized: 

3577 return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 

3578 

3579 with self._lock: 

3580 return self._as_graph_element_locked(obj, allow_tensor, allow_operation) 

3581 
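# Editor's illustrative sketch (not part of the original ops.py source):
# resolving names and objects to graph elements with `as_graph_element`.
# Assumes the TF1-style API.

import tensorflow.compat.v1 as tf_v1

g = tf_v1.Graph()
with g.as_default():
  c = tf_v1.constant(1.0, name="c")
assert g.as_graph_element("c:0") is c                      # Tensor by name.
assert g.as_graph_element("c") is c.op                     # Operation by name.
assert g.as_graph_element(c, allow_operation=False) is c   # Already a Tensor.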

3582 def _as_graph_element_locked(self, obj, allow_tensor, allow_operation): 

3583 """See `Graph.as_graph_element()` for details.""" 

3584 # The vast majority of this function is figuring 

3585 # out what an API user might be doing wrong, so 

3586 # that we can give helpful error messages. 

3587 # 

3588 # Ideally, it would be nice to split it up, but we 

3589 # need context to generate nice error messages. 

3590 

3591 if allow_tensor and allow_operation: 

3592 types_str = "Tensor or Operation" 

3593 elif allow_tensor: 

3594 types_str = "Tensor" 

3595 elif allow_operation: 

3596 types_str = "Operation" 

3597 else: 

3598 raise ValueError("allow_tensor and allow_operation can't both be False.") 

3599 

3600 temp_obj = _as_graph_element(obj) 

3601 if temp_obj is not None: 

3602 obj = temp_obj 

3603 

3604 # If obj appears to be a name... 

3605 if isinstance(obj, compat.bytes_or_text_types): 

3606 name = compat.as_str(obj) 

3607 

3608 if ":" in name and allow_tensor: 

3609 # Looks like a Tensor name and can be a Tensor. 

3610 try: 

3611 op_name, out_n = name.split(":") 

3612 out_n = int(out_n) 

3613 except: 

3614 raise ValueError("The name %s looks like a Tensor name, but is " 

3615 "not a valid one. Tensor names must be of the " 

3616 "form \"<op_name>:<output_index>\"." % repr(name)) 

3617 try: 

3618 op = self._get_operation_by_name(op_name) 

3619 except KeyError as exc: 

3620 raise KeyError( 

3621 "The name %s refers to a Tensor which does not " 

3622 "exist. The operation, %s, does not exist in the " 

3623 "graph." % (repr(name), repr(op_name)) 

3624 ) from exc 

3625 

3626 try: 

3627 return op.outputs[out_n] 

3628 except: 

3629 raise KeyError("The name %s refers to a Tensor which does not " 

3630 "exist. The operation, %s, exists but only has " 

3631 "%s outputs." % 

3632 (repr(name), repr(op_name), len(op.outputs))) 

3633 

3634 elif ":" in name and not allow_tensor: 

3635 # Looks like a Tensor name but can't be a Tensor. 

3636 raise ValueError("Name %s appears to refer to a Tensor, not a %s." % 

3637 (repr(name), types_str)) 

3638 

3639 elif ":" not in name and allow_operation: 

3640 # Looks like an Operation name and can be an Operation. 

3641 try: 

3642 op = self._get_operation_by_name(name) 

3643 except KeyError as exc: 

3644 raise KeyError( 

3645 "The name %s refers to an Operation not in the graph." 

3646 % repr(name) 

3647 ) from exc 

3648 return op 

3649 

3650 elif ":" not in name and not allow_operation: 

3651 # Looks like an Operation name but can't be an Operation. 

3652 try: 

3653 op = self._get_operation_by_name(name) 

3654 # Yep, it's an Operation name 

3655 err_msg = ("The name %s refers to an Operation, not a %s." % 

3656 (repr(name), types_str)) 

3657 except KeyError: 

3658 err_msg = ("The name %s looks like an (invalid) Operation name, " 

3659 "not a %s." % (repr(name), types_str)) 

3660 err_msg += (" Tensor names must be of the form " 

3661 "\"<op_name>:<output_index>\".") 

3662 raise ValueError(err_msg) 

3663 

3664 elif isinstance(obj, Tensor) and allow_tensor: 

3665 # Actually obj is just the object it's referring to. 

3666 if obj.graph is not self: 

3667 raise ValueError("Tensor %s is not an element of this graph." % obj) 

3668 return obj 

3669 elif isinstance(obj, Operation) and allow_operation: 

3670 # Actually obj is just the object it's referring to. 

3671 if obj.graph is not self: 

3672 raise ValueError("Operation %s is not an element of this graph." % obj) 

3673 return obj 

3674 else: 

3675 # We give up! 

3676 raise TypeError("Can not convert a %s into a %s." % 

3677 (type(obj).__name__, types_str)) 

3678 

3679 def get_operation_by_name(self, name): 

3680 """Returns the `Operation` with the given `name`. 

3681 

3682 This method may be called concurrently from multiple threads. 

3683 

3684 Args: 

3685 name: The name of the `Operation` to return. 

3686 

3687 Returns: 

3688 The `Operation` with the given `name`. 

3689 

3690 Raises: 

3691 TypeError: If `name` is not a string. 

3692 KeyError: If `name` does not correspond to an operation in this graph. 

3693 """ 

3694 

3695 if not isinstance(name, str): 

3696 raise TypeError("Operation names are strings (or similar), not %s." % 

3697 type(name).__name__) 

3698 return self.as_graph_element(name, allow_tensor=False, allow_operation=True) 

3699 

3700 def _get_operation_by_tf_operation(self, tf_oper): 

3701 op_name = pywrap_tf_session.TF_OperationName(tf_oper) 

3702 return self._get_operation_by_name(op_name) 

3703 

3704 def get_tensor_by_name(self, name): 

3705 """Returns the `Tensor` with the given `name`. 

3706 

3707 This method may be called concurrently from multiple threads. 

3708 

3709 Args: 

3710 name: The name of the `Tensor` to return. 

3711 

3712 Returns: 

3713 The `Tensor` with the given `name`. 

3714 

3715 Raises: 

3716 TypeError: If `name` is not a string. 

3717 KeyError: If `name` does not correspond to a tensor in this graph. 

3718 """ 

3719 # Names should be strings. 

3720 if not isinstance(name, str): 

3721 raise TypeError("Tensor names are strings (or similar), not %s." % 

3722 type(name).__name__) 

3723 return self.as_graph_element(name, allow_tensor=True, allow_operation=False) 

3724 

3725 def _get_tensor_by_tf_output(self, tf_output): 

3726 """Returns the `Tensor` representing `tf_output`. 

3727 

3728 Note that there is only one such `Tensor`, i.e. multiple calls to this 

3729 function with the same TF_Output value will always return the same `Tensor` 

3730 object. 

3731 

3732 Args: 

3733 tf_output: A wrapped `TF_Output` (the C API equivalent of `Tensor`). 

3734 

3735 Returns: 

3736 The `Tensor` that represents `tf_output`. 

3737 """ 

3738 op = self._get_operation_by_tf_operation(tf_output.oper) 

3739 return op.outputs[tf_output.index] 

3740 

3741 def op_def_for_type(self, type): # pylint: disable=redefined-builtin 

3742 """Returns the `OpDef` proto for `type`. `type` is a string.""" 

3743 # NOTE: No locking is required because the lookup and insertion operations 

3744 # on Python dictionaries are atomic. 

3745 try: 

3746 return self._op_def_cache[type] 

3747 except KeyError: 

3748 self._op_def_cache[type] = op_def_pb2.OpDef.FromString( 

3749 self._op_def_for_type(type) 

3750 ) 

3751 return self._op_def_cache[type] 

3752 
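# Editor's illustrative sketch (not part of the original ops.py source):
# looking up the registered `OpDef` for an op type string.

import tensorflow.compat.v1 as tf_v1

g = tf_v1.Graph()
matmul_def = g.op_def_for_type("MatMul")
assert matmul_def.name == "MatMul"
assert [arg.name for arg in matmul_def.input_arg] == ["a", "b"]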

3753 def as_default(self): 

3754 """Returns a context manager that makes this `Graph` the default graph. 

3755 

3756 This method should be used if you want to create multiple graphs 

3757 in the same process. For convenience, a global default graph is 

3758 provided, and all ops will be added to this graph if you do not 

3759 create a new graph explicitly. 

3760 

3761 Use this method with the `with` keyword to specify that ops created within 

3762 the scope of a block should be added to this graph. In this case, once 

3763 the scope of the `with` is exited, the previous default graph is set again 

3764 as default. There is a stack, so it's ok to have multiple nested levels 

3765 of `as_default` calls. 

3766 

3767 The default graph is a property of the current thread. If you 

3768 create a new thread, and wish to use the default graph in that 

3769 thread, you must explicitly add a `with g.as_default():` in that 

3770 thread's function. 

3771 

3772 The following code examples are equivalent: 

3773 

3774 ```python 

3775 # 1. Using Graph.as_default(): 

3776 g = tf.Graph() 

3777 with g.as_default(): 

3778 c = tf.constant(5.0) 

3779 assert c.graph is g 

3780 

3781 # 2. Constructing and making default: 

3782 with tf.Graph().as_default() as g: 

3783 c = tf.constant(5.0) 

3784 assert c.graph is g 

3785 ``` 

3786 

3787 If eager execution is enabled, ops created under this context manager will be 

3788 added to the graph instead of executed eagerly. 

3789 

3790 Returns: 

3791 A context manager for using this graph as the default graph. 

3792 """ 

3793 return _default_graph_stack.get_controller(self) 

3794 

3795 @property 

3796 def collections(self): 

3797 """Returns the names of the collections known to this graph.""" 

3798 return list(self._collections) 

3799 

3800 def add_to_collection(self, name, value): 

3801 """Stores `value` in the collection with the given `name`. 

3802 

3803 Note that collections are not sets, so it is possible to add a value to 

3804 a collection several times. 

3805 

3806 Args: 

3807 name: The key for the collection. The `GraphKeys` class contains many 

3808 standard names for collections. 

3809 value: The value to add to the collection. 

3810 """ # pylint: disable=g-doc-exception 

3811 self._check_not_finalized() 

3812 with self._lock: 

3813 if name not in self._collections: 

3814 self._collections[name] = [value] 

3815 else: 

3816 self._collections[name].append(value) 

3817 

3818 def add_to_collections(self, names, value): 

3819 """Stores `value` in the collections given by `names`. 

3820 

3821 Note that collections are not sets, so it is possible to add a value to 

3822 a collection several times. This function makes sure that duplicates in 

3823 `names` are ignored, but it will not check for pre-existing membership of 

3824 `value` in any of the collections in `names`. 

3825 

3826 `names` can be any iterable, but if `names` is a string, it is treated as a 

3827 single collection name. 

3828 

3829 Args: 

3830 names: The keys for the collections to add to. The `GraphKeys` class 

3831 contains many standard names for collections. 

3832 value: The value to add to the collections. 

3833 """ 

3834 # Make sure names are unique, but treat strings as a single collection name 

3835 names = (names,) if isinstance(names, str) else set(names) 

3836 for name in names: 

3837 self.add_to_collection(name, value) 

3838 

3839 def get_collection_ref(self, name): 

3840 """Returns a list of values in the collection with the given `name`. 

3841 

3842 If the collection exists, this returns the list itself, which can 

3843 be modified in place to change the collection. If the collection does 

3844 not exist, it is created as an empty list and the list is returned. 

3845 

3846 This is different from `get_collection()` which always returns a copy of 

3847 the collection list if it exists and never creates an empty collection. 

3848 

3849 Args: 

3850 name: The key for the collection. For example, the `GraphKeys` class 

3851 contains many standard names for collections. 

3852 

3853 Returns: 

3854 The list of values in the collection with the given `name`, or an empty 

3855 list if no value has been added to that collection. 

3856 """ # pylint: disable=g-doc-exception 

3857 with self._lock: 

3858 coll_list = self._collections.get(name, None) 

3859 if coll_list is None: 

3860 coll_list = [] 

3861 self._collections[name] = coll_list 

3862 return coll_list 

3863 

3864 def get_collection(self, name, scope=None): 

3865 """Returns a list of values in the collection with the given `name`. 

3866 

3867 This is different from `get_collection_ref()`, which always returns the 

3868 actual collection list if it exists: this method returns a new copy of 

3869 the list each time it is called. 

3870 

3871 Args: 

3872 name: The key for the collection. For example, the `GraphKeys` class 

3873 contains many standard names for collections. 

3874 scope: (Optional.) A string. If supplied, the resulting list is filtered 

3875 to include only items whose `name` attribute matches `scope` using 

3876 `re.match`. Items without a `name` attribute are never returned if a 

3877 scope is supplied. The choice of `re.match` means that a `scope` without 

3878 special tokens filters by prefix. 

3879 

3880 Returns: 

3881 The list of values in the collection with the given `name`, or 

3882 an empty list if no value has been added to that collection. The 

3883 list contains the values in the order under which they were 

3884 collected. 

3885 """ # pylint: disable=g-doc-exception 

3886 with self._lock: 

3887 collection = self._collections.get(name, None) 

3888 if collection is None: 

3889 return [] 

3890 if scope is None: 

3891 return list(collection) 

3892 else: 

3893 c = [] 

3894 regex = re.compile(scope) 

3895 for item in collection: 

3896 try: 

3897 if regex.match(item.name): 

3898 c.append(item) 

3899 except AttributeError: 

3900 # Collection items with no name are ignored. 

3901 pass 

3902 return c 

3903 
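# Editor's illustrative sketch (not part of the original ops.py source): the
# difference between `get_collection` (returns a copy) and `get_collection_ref`
# (returns the mutable list itself).

import tensorflow.compat.v1 as tf_v1

g = tf_v1.Graph()
g.add_to_collection("my_things", "a")
g.add_to_collections(["my_things", "other"], "b")

snapshot = g.get_collection("my_things")      # ["a", "b"], a new list.
ref = g.get_collection_ref("my_things")
ref.append("c")                               # Mutates the collection in place.
assert g.get_collection("my_things") == ["a", "b", "c"]
assert snapshot == ["a", "b"]                 # The earlier copy is unaffected.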

3904 def get_all_collection_keys(self): 

3905 """Returns a list of collections used in this graph.""" 

3906 with self._lock: 

3907 return [x for x in self._collections if isinstance(x, str)] 

3908 

3909 def clear_collection(self, name): 

3910 """Clears all values in a collection. 

3911 

3912 Args: 

3913 name: The key for the collection. The `GraphKeys` class contains many 

3914 standard names for collections. 

3915 """ 

3916 self._check_not_finalized() 

3917 with self._lock: 

3918 if name in self._collections: 

3919 del self._collections[name] 

3920 

3921 @tf_contextlib.contextmanager 

3922 def _original_op(self, op): 

3923 """Python 'with' handler to help annotate ops with their originator. 

3924 

3925 An op may have an 'original_op' property that indicates the op on which 

3926 it was based. For example, a replica op is based on the op that was 

3927 replicated and a gradient op is based on the op that was differentiated. 

3928 

3929 All ops created in the scope of this 'with' handler will have 

3930 the given 'op' as their original op. 

3931 

3932 Args: 

3933 op: The Operation that all ops created in this scope will have as their 

3934 original op. 

3935 

3936 Yields: 

3937 Nothing. 

3938 """ 

3939 old_original_op = self._default_original_op 

3940 self._default_original_op = op 

3941 try: 

3942 yield 

3943 finally: 

3944 self._default_original_op = old_original_op 

3945 

3946 @property 

3947 def _name_stack(self): 

3948 # This may be called from a thread where name_stack doesn't yet exist. 

3949 if not hasattr(self._thread_local, "_name_stack"): 

3950 self._thread_local._name_stack = "" 

3951 return self._thread_local._name_stack 

3952 

3953 @_name_stack.setter 

3954 def _name_stack(self, name_stack): 

3955 self._thread_local._name_stack = name_stack 

3956 

3957 # pylint: disable=g-doc-return-or-yield,line-too-long 

3958 @tf_contextlib.contextmanager 

3959 def name_scope(self, name): 

3960 """Returns a context manager that creates hierarchical names for operations. 

3961 

3962 A graph maintains a stack of name scopes. A `with name_scope(...):` 

3963 statement pushes a new name onto the stack for the lifetime of the context. 

3964 

3965 The `name` argument will be interpreted as follows: 

3966 

3967 * A string (not ending with '/') will create a new name scope, in which 

3968 `name` is appended to the prefix of all operations created in the 

3969 context. If `name` has been used before, it will be made unique by 

3970 calling `self.unique_name(name)`. 

3971 * A scope previously captured from a `with g.name_scope(...) as 

3972 scope:` statement will be treated as an "absolute" name scope, which 

3973 makes it possible to re-enter existing scopes. 

3974 * A value of `None` or the empty string will reset the current name scope 

3975 to the top-level (empty) name scope. 

3976 

3977 For example: 

3978 

3979 ```python 

3980 with tf.Graph().as_default() as g: 

3981 c = tf.constant(5.0, name="c") 

3982 assert c.op.name == "c" 

3983 c_1 = tf.constant(6.0, name="c") 

3984 assert c_1.op.name == "c_1" 

3985 

3986 # Creates a scope called "nested" 

3987 with g.name_scope("nested") as scope: 

3988 nested_c = tf.constant(10.0, name="c") 

3989 assert nested_c.op.name == "nested/c" 

3990 

3991 # Creates a nested scope called "inner". 

3992 with g.name_scope("inner"): 

3993 nested_inner_c = tf.constant(20.0, name="c") 

3994 assert nested_inner_c.op.name == "nested/inner/c" 

3995 

3996 # Create a nested scope called "inner_1". 

3997 with g.name_scope("inner"): 

3998 nested_inner_1_c = tf.constant(30.0, name="c") 

3999 assert nested_inner_1_c.op.name == "nested/inner_1/c" 

4000 

4001 # Treats `scope` as an absolute name scope, and 

4002 # switches to the "nested/" scope. 

4003 with g.name_scope(scope): 

4004 nested_d = tf.constant(40.0, name="d") 

4005 assert nested_d.op.name == "nested/d" 

4006 

4007 with g.name_scope(""): 

4008 e = tf.constant(50.0, name="e") 

4009 assert e.op.name == "e" 

4010 ``` 

4011 

4012 The name of the scope itself can be captured by `with 

4013 g.name_scope(...) as scope:`, which stores the name of the scope 

4014 in the variable `scope`. This value can be used to name an 

4015 operation that represents the overall result of executing the ops 

4016 in a scope. For example: 

4017 

4018 ```python 

4019 inputs = tf.constant(...) 

4020 with g.name_scope('my_layer') as scope: 

4021 weights = tf.Variable(..., name="weights") 

4022 biases = tf.Variable(..., name="biases") 

4023 affine = tf.matmul(inputs, weights) + biases 

4024 output = tf.nn.relu(affine, name=scope) 

4025 ``` 

4026 

4027 NOTE: This constructor validates the given `name`. Valid scope 

4028 names match one of the following regular expressions: 

4029 

4030 [A-Za-z0-9.][A-Za-z0-9_.\\-/]* (for scopes at the root) 

4031 [A-Za-z0-9_.\\-/]* (for other scopes) 

4032 

4033 Args: 

4034 name: A name for the scope. 

4035 

4036 Returns: 

4037 A context manager that installs `name` as a new name scope. 

4038 

4039 Raises: 

4040 ValueError: If `name` is not a valid scope name, according to the rules 

4041 above. 

4042 """ 

4043 if name: 

4044 if isinstance(name, compat.bytes_or_text_types): 

4045 name = compat.as_str(name) 

4046 

4047 if self._name_stack: 

4048 # Scopes created in a nested scope may have initial characters 

4049 # that are illegal as the initial character of an op name 

4050 # (viz. '-', '\', '/', and '_'). 

4051 if not _VALID_SCOPE_NAME_REGEX.match(name): 

4052 raise ValueError( 

4053 f"'{name}' is not a valid scope name. A scope name has to match " 

4054 f"the following pattern: {_VALID_SCOPE_NAME_REGEX.pattern}") 

4055 else: 

4056 # Scopes created in the root must match the more restrictive 

4057 # op name regex, which constrains the initial character. 

4058 if not _VALID_OP_NAME_REGEX.match(name): 

4059 raise ValueError( 

4060 f"'{name}' is not a valid root scope name. A root scope name has " 

4061 f"to match the following pattern: {_VALID_OP_NAME_REGEX.pattern}") 

4062 old_stack = self._name_stack 

4063 if not name: # Both for name=None and name="" we re-set to empty scope. 

4064 new_stack = "" 

4065 returned_scope = "" 

4066 elif name[-1] == "/": 

4067 new_stack = name_from_scope_name(name) 

4068 returned_scope = name 

4069 else: 

4070 new_stack = self.unique_name(name) 

4071 returned_scope = new_stack + "/" 

4072 self._name_stack = new_stack 

4073 try: 

4074 yield returned_scope 

4075 finally: 

4076 self._name_stack = old_stack 

4077 

4078 # pylint: enable=g-doc-return-or-yield,line-too-long 

4079 

4080 def unique_name(self, name, mark_as_used=True): 

4081 """Return a unique operation name for `name`. 

4082 

4083 Note: You rarely need to call `unique_name()` directly. Most of 

4084 the time you just need to create `with g.name_scope()` blocks to 

4085 generate structured names. 

4086 

4087 `unique_name` is used to generate structured names, separated by 

4088 `"/"`, to help identify operations when debugging a graph. 

4089 Operation names are displayed in error messages reported by the 

4090 TensorFlow runtime, and in various visualization tools such as 

4091 TensorBoard. 

4092 

4093 If `mark_as_used` is set to `True`, which is the default, a new 

4094 unique name is created and marked as in use. If it's set to `False`, 

4095 the unique name is returned without actually being marked as used. 

4096 This is useful when the caller simply wants to know what the name 

4097 to be created will be. 

4098 

4099 Args: 

4100 name: The name for an operation. 

4101 mark_as_used: Whether to mark this name as being used. 

4102 

4103 Returns: 

4104 A string to be passed to `create_op()` that will be used 

4105 to name the operation being created. 

4106 """ 

4107 if self._name_stack: 

4108 name = self._name_stack + "/" + name 

4109 

4110 # For the sake of checking for names in use, we treat names as case 

4111 # insensitive (e.g. foo = Foo). 

4112 name_key = name.lower() 

4113 i = self._names_in_use.get(name_key, 0) 

4114 # Increment the number for "name_key". 

4115 if mark_as_used: 

4116 self._names_in_use[name_key] = i + 1 

4117 if i > 0: 

4118 base_name_key = name_key 

4119 # Make sure the composed name key is not already used. 

4120 while name_key in self._names_in_use: 

4121 name_key = "%s_%d" % (base_name_key, i) 

4122 i += 1 

4123 # Mark the composed name_key as used in case someone wants 

4124 # to call unique_name("name_1"). 

4125 if mark_as_used: 

4126 self._names_in_use[name_key] = 1 

4127 

4128 # Return the new name with the original capitalization of the given name. 

4129 name = "%s_%d" % (name, i - 1) 

4130 return name 

4131 
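# Editor's illustrative sketch (not part of the original ops.py source): how
# `unique_name` suffixes repeated names and prefixes the current name scope.

import tensorflow.compat.v1 as tf_v1

g = tf_v1.Graph()
assert g.unique_name("foo") == "foo"
assert g.unique_name("foo") == "foo_1"        # Second use gets a suffix.
with g.name_scope("outer"):
  assert g.unique_name("foo") == "outer/foo"  # Scope prefix is applied.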

4132 def get_name_scope(self): 

4133 """Returns the current name scope. 

4134 

4135 For example: 

4136 

4137 ```python 

4138 with tf.name_scope('scope1'): 

4139 with tf.name_scope('scope2'): 

4140 print(tf.compat.v1.get_default_graph().get_name_scope()) 

4141 ``` 

4142 would print the string `scope1/scope2`. 

4143 

4144 Returns: 

4145 A string representing the current name scope. 

4146 """ 

4147 return self._name_stack 

4148 

4149 @tf_contextlib.contextmanager 

4150 def _colocate_with_for_gradient(self, op, gradient_uid, 

4151 ignore_existing=False): 

4152 with self.colocate_with(op, ignore_existing): 

4153 if gradient_uid is not None: 

4154 ctx = _get_enclosing_context(self) 

4155 if ctx is not None: 

4156 ctx.EnterGradientColocation(op, gradient_uid) 

4157 try: 

4158 yield 

4159 finally: 

4160 ctx.ExitGradientColocation(op, gradient_uid) 

4161 else: 

4162 yield 

4163 else: 

4164 yield 

4165 

4166 @tf_contextlib.contextmanager 

4167 def colocate_with(self, op, ignore_existing=False): 

4168 """Returns a context manager that specifies an op to colocate with. 

4169 

4170 Note: this function is not for public use, only for internal libraries. 

4171 

4172 For example: 

4173 

4174 ```python 

4175 a = tf.Variable([1.0]) 

4176 with g.colocate_with(a): 

4177 b = tf.constant(1.0) 

4178 c = tf.add(a, b) 

4179 ``` 

4180 

4181 `b` and `c` will always be colocated with `a`, no matter where `a` 

4182 is eventually placed. 

4183 

4184 **NOTE** Using a colocation scope resets any existing device constraints. 

4185 

4186 If `op` is `None` then `ignore_existing` must be `True` and the new 

4187 scope resets all colocation and device constraints. 

4188 

4189 Args: 

4190 op: The op to colocate all created ops with, or `None`. 

4191 ignore_existing: If true, only applies colocation of this op within the 

4192 context, rather than applying all colocation properties on the stack. 

4193 If `op` is `None`, this value must be `True`. 

4194 

4195 Raises: 

4196 ValueError: if op is None but ignore_existing is False. 

4197 

4198 Yields: 

4199 A context manager that specifies the op with which to colocate 

4200 newly created ops. 

4201 """ 

4202 if op is None and not ignore_existing: 

4203 raise ValueError("Trying to reset colocation (op is None) but " 

4204 "ignore_existing is not True") 

4205 op, device_only_candidate = _op_to_colocate_with(op, self) 

4206 

4207 # By default, colocate_with resets the device function stack, 

4208 # since colocate_with is typically used in specific internal 

4209 # library functions where colocation is intended to be "stronger" 

4210 # than device functions. 

4211 # 

4212 # In the future, a caller may specify that device_functions win 

4213 # over colocation, in which case we can add support. 

4214 device_fn_tmp = self._device_function_stack 

4215 self._device_function_stack = traceable_stack.TraceableStack() 

4216 

4217 if ignore_existing: 

4218 current_stack = self._colocation_stack 

4219 self._colocation_stack = traceable_stack.TraceableStack() 

4220 

4221 if op is not None: 

4222 # offset refers to the stack frame used for storing code location. 

4223 # We use 4, the sum of 1 to use our caller's stack frame and 3 

4224 # to jump over layers of context managers above us. 

4225 self._colocation_stack.push_obj(op, offset=4) 

4226 if device_only_candidate is not None: 

4227 self._colocation_stack.push_obj(device_only_candidate, offset=4) 

4228 elif not ignore_existing: 

4229 raise ValueError("Trying to reset colocation (op is None) but " 

4230 "ignore_existing is not True") 

4231 try: 

4232 yield 

4233 finally: 

4234 # Restore device function stack 

4235 self._device_function_stack = device_fn_tmp 

4236 if op is not None: 

4237 self._colocation_stack.pop_obj() 

4238 if device_only_candidate is not None: 

4239 self._colocation_stack.pop_obj() 

4240 

4241 # Reset the colocation stack if requested. 

4242 if ignore_existing: 

4243 self._colocation_stack = current_stack 

4244 

4245 def _add_device_to_stack(self, device_name_or_function, offset=0): 

4246 """Add device to stack manually, separate from a context manager.""" 

4247 total_offset = 1 + offset 

4248 spec = _UserDeviceSpec(device_name_or_function) 

4249 self._device_function_stack.push_obj(spec, offset=total_offset) 

4250 return spec 

4251 

4252 @tf_contextlib.contextmanager 

4253 def device(self, device_name_or_function): 

4254 # pylint: disable=line-too-long 

4255 """Returns a context manager that specifies the default device to use. 

4256 

4257 The `device_name_or_function` argument may either be a device name 

4258 string, a device function, or None: 

4259 

4260 * If it is a device name string, all operations constructed in 

4261 this context will be assigned to the device with that name, unless 

4262 overridden by a nested `device()` context. 

4263 * If it is a function, it will be treated as a function from 

4264 Operation objects to device name strings, and invoked each time 

4265 a new Operation is created. The Operation will be assigned to 

4266 the device with the returned name. 

4267 * If it is None, all `device()` invocations from the enclosing context 

4268 will be ignored. 

4269 

4270 For information about the valid syntax of device name strings, see 

4271 the documentation in 

4272 [`DeviceNameUtils`](https://www.tensorflow.org/code/tensorflow/core/util/device_name_utils.h). 

4273 

4274 For example: 

4275 

4276 ```python 

4277 with g.device('/device:GPU:0'): 

4278 # All operations constructed in this context will be placed 

4279 # on GPU 0. 

4280 with g.device(None): 

4281 # All operations constructed in this context will have no 

4282 # assigned device. 

4283 

4284 # Defines a function from `Operation` to device string. 

4285 def matmul_on_gpu(n): 

4286 if n.type == "MatMul": 

4287 return "/device:GPU:0" 

4288 else: 

4289 return "/cpu:0" 

4290 

4291 with g.device(matmul_on_gpu): 

4292 # All operations of type "MatMul" constructed in this context 

4293 # will be placed on GPU 0; all other operations will be placed 

4294 # on CPU 0. 

4295 ``` 

4296 

4297 **N.B.** The device scope may be overridden by op wrappers or 

4298 other library code. For example, a variable assignment op 

4299 `v.assign()` must be colocated with the `tf.Variable` `v`, and 

4300 incompatible device scopes will be ignored. 

4301 

4302 Args: 

4303 device_name_or_function: The device name or function to use in the 

4304 context. 

4305 

4306 Yields: 

4307 A context manager that specifies the default device to use for newly 

4308 created ops. 

4309 

4310 Raises: 

4311 RuntimeError: If device scopes are not properly nested. 

4312 """ 

4313 self._add_device_to_stack(device_name_or_function, offset=2) 

4314 old_top_of_stack = self._device_function_stack.peek_top_obj() 

4315 try: 

4316 yield 

4317 finally: 

4318 new_top_of_stack = self._device_function_stack.peek_top_obj() 

4319 if old_top_of_stack is not new_top_of_stack: 

4320 raise RuntimeError("Exiting device scope without proper scope nesting.") 

4321 self._device_function_stack.pop_obj() 

4322 

4323 def _apply_device_functions(self, op): 

4324 """Applies the current device function stack to the given operation.""" 

4325 # Apply any device functions in LIFO order, so that the most recently 

4326 # pushed function has the first chance to apply a device to the op. 

4327 # We apply here because the result can depend on the Operation's 

4328 # signature, which is computed in the Operation constructor. 

4329 # pylint: disable=protected-access 

4330 prior_device_string = None 

4331 for device_spec in self._device_function_stack.peek_objs(): 

4332 if device_spec.is_null_merge: 

4333 continue 

4334 

4335 if device_spec.function is None: 

4336 break 

4337 

4338 device_string = device_spec.string_merge(op) 

4339 

4340 # Take advantage of the fact that None is a singleton and Python interns 

4341 # strings, since identity checks are faster than equality checks. 

4342 if device_string is not prior_device_string: 

4343 op._set_device_from_string(device_string) 

4344 prior_device_string = device_string 

4345 op._device_code_locations = self._snapshot_device_function_stack_metadata() 

4346 # pylint: enable=protected-access 

4347 

4348 # pylint: disable=g-doc-return-or-yield 

4349 @tf_contextlib.contextmanager 

4350 def container(self, container_name): 

4351 """Returns a context manager that specifies the resource container to use. 

4352 

4353 Stateful operations, such as variables and queues, can maintain their 

4354 states on devices so that they can be shared by multiple processes. 

4355 A resource container is a string name under which these stateful 

4356 operations are tracked. These resources can be released or cleared 

4357 with `tf.Session.reset()`. 

4358 

4359 For example: 

4360 

4361 ```python 

4362 with g.container('experiment0'): 

4363 # All stateful Operations constructed in this context will be placed 

4364 # in resource container "experiment0". 

4365 v1 = tf.Variable([1.0]) 

4366 v2 = tf.Variable([2.0]) 

4367 with g.container("experiment1"): 

4368 # All stateful Operations constructed in this context will be 

4369 # placed in resource container "experiment1". 

4370 v3 = tf.Variable([3.0]) 

4371 q1 = tf.queue.FIFOQueue(10, tf.float32) 

4372 # All stateful Operations constructed in this context will be 

4373 # created in resource container "experiment0". 

4374 v4 = tf.Variable([4.0]) 

4375 q2 = tf.queue.FIFOQueue(20, tf.float32) 

4376 with g.container(""): 

4377 # All stateful Operations constructed in this context will be 

4378 # placed in the default resource container. 

4379 v5 = tf.Variable([5.0]) 

4380 q3 = tf.queue.FIFOQueue(30, tf.float32) 

4381 

4382 # Resets container "experiment0", after which the state of v1, v2, v4, q2 

4383 # will become undefined (such as uninitialized). 

4384 tf.Session.reset(target, ["experiment0"]) 

4385 ``` 

4386 

4387 Args: 

4388 container_name: container name string. 

4389 

4390 Returns: 

4391 A context manager for defining resource containers for stateful ops, 

4392 yields the container name. 

4393 """ 

4394 original_container = self._container 

4395 self._container = container_name 

4396 try: 

4397 yield self._container 

4398 finally: 

4399 self._container = original_container 

4400 

4401 # pylint: enable=g-doc-return-or-yield 

4402 

4403 class _ControlDependenciesController(object): 

4404 """Context manager for `control_dependencies()`.""" 

4405 

4406 def __init__(self, graph, control_inputs): 

4407 """Create a new `_ControlDependenciesController`. 

4408 

4409 A `_ControlDependenciesController` is the context manager for 

4410 `with tf.control_dependencies()` blocks. These normally nest, 

4411 as described in the documentation for `control_dependencies()`. 

4412 

4413 The `control_inputs` argument lists control dependencies that must be 

4414 added to the current set of control dependencies. Because of 

4415 uniquification the set can be empty even if the caller passed a list of 

4416 ops. The special value `None` indicates that we want to start a new 

4417 empty set of control dependencies instead of extending the current set. 

4418 

4419 In that case we also clear the current control flow context, which is an 

4420 additional mechanism to add control dependencies. 

4421 

4422 Args: 

4423 graph: The graph that this controller is managing. 

4424 control_inputs: List of ops to use as control inputs in addition to the 

4425 current control dependencies. None to indicate that the dependencies 

4426 should be cleared. 

4427 """ 

4428 self._graph = graph 

4429 if control_inputs is None: 

4430 self._control_inputs_val = [] 

4431 self._new_stack = True 

4432 else: 

4433 self._control_inputs_val = control_inputs 

4434 self._new_stack = False 

4435 self._seen_nodes = set() 

4436 self._old_stack = None 

4437 self._old_control_flow_context = None 

4438 

4439 # pylint: disable=protected-access 

4440 

4441 def __enter__(self): 

4442 if self._new_stack: 

4443 # Clear the control_dependencies stack. 

4444 self._old_stack = self._graph._control_dependencies_stack 

4445 self._graph._control_dependencies_stack = [] 

4446 # Clear the control_flow_context too. 

4447 self._old_control_flow_context = self._graph._get_control_flow_context() 

4448 self._graph._set_control_flow_context(None) 

4449 self._graph._push_control_dependencies_controller(self) 

4450 

4451 def __exit__(self, unused_type, unused_value, unused_traceback): 

4452 self._graph._pop_control_dependencies_controller(self) 

4453 if self._new_stack: 

4454 self._graph._control_dependencies_stack = self._old_stack 

4455 self._graph._set_control_flow_context(self._old_control_flow_context) 

4456 

4457 # pylint: enable=protected-access 

4458 

4459 @property 

4460 def control_inputs(self): 

4461 return self._control_inputs_val 

4462 

4463 def add_op(self, op): 

4464 if isinstance(op, Tensor): 

4465 op = op.ref() 

4466 self._seen_nodes.add(op) 

4467 

4468 def op_in_group(self, op): 

4469 if isinstance(op, Tensor): 

4470 op = op.ref() 

4471 return op in self._seen_nodes 

4472 

4473 def _push_control_dependencies_controller(self, controller): 

4474 self._control_dependencies_stack.append(controller) 

4475 

4476 def _pop_control_dependencies_controller(self, controller): 

4477 assert self._control_dependencies_stack[-1] is controller 

4478 self._control_dependencies_stack.pop() 

4479 

4480 def _current_control_dependencies(self): 

4481 ret = set() 

4482 for controller in self._control_dependencies_stack: 

4483 for op in controller.control_inputs: 

4484 ret.add(op) 

4485 return ret 

4486 

4487 def _control_dependencies_for_inputs(self, input_ops): 

4488 """For an op that takes `input_ops` as inputs, compute control inputs. 

4489 

4490 The returned control dependencies should yield an execution that 

4491 is equivalent to adding all control inputs in 

4492 self._control_dependencies_stack to a newly created op. However, 

4493 this function attempts to prune the returned control dependencies 

4494 by observing that nodes created within the same `with 

4495 control_dependencies(...):` block may have data dependencies that make 

4496 the explicit approach redundant. 

4497 

4498 Args: 

4499 input_ops: The data input ops for an op to be created. 

4500 

4501 Returns: 

4502 A list of control inputs for the op to be created. 

4503 """ 

4504 ret = [] 

4505 for controller in self._control_dependencies_stack: 

4506 # If any of the input_ops already depends on the inputs from controller, 

4507 # we say that the new op is dominated (by that input), and we therefore 

4508 # do not need to add control dependencies for this controller's inputs. 

4509 dominated = False 

4510 for op in input_ops: 

4511 if controller.op_in_group(op): 

4512 dominated = True 

4513 break 

4514 if not dominated: 

4515 # Don't add a control input if we already have a data dependency on it. 

4516 # NOTE(mrry): We do not currently track transitive data dependencies, 

4517 # so we may add redundant control inputs. 

4518 ret.extend(c for c in controller.control_inputs if c not in input_ops) 

4519 return ret 

4520 
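A minimal sketch of the pruning described above: an op whose data inputs were created inside the same `control_dependencies` block is already ordered after the control inputs, so no explicit control edge is added, while an independent op gets one (illustrative names, not part of ops.py):

```python
import tensorflow.compat.v1 as tf

g = tf.Graph()
with g.as_default():
  a = tf.constant(1.0)
  with g.control_dependencies([a]):
    b = a + 1.0           # already ordered after `a` through its data inputs
    c = tf.constant(2.0)  # independent of `a`, needs an explicit control edge

print(b.op.control_inputs)  # []
print(c.op.control_inputs)  # [<tf.Operation 'Const' ...>]
```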

4521 def _record_op_seen_by_control_dependencies(self, op): 

4522 """Record that the given op depends on all registered control dependencies. 

4523 

4524 Args: 

4525 op: An Operation. 

4526 """ 

4527 for controller in self._control_dependencies_stack: 

4528 controller.add_op(op) 

4529 

4530 def control_dependencies(self, control_inputs): 

4531 """Returns a context manager that specifies control dependencies. 

4532 

4533 Use with the `with` keyword to specify that all operations constructed 

4534 within the context should have control dependencies on 

4535 `control_inputs`. For example: 

4536 

4537 ```python 

4538 with g.control_dependencies([a, b, c]): 

4539 # `d` and `e` will only run after `a`, `b`, and `c` have executed. 

4540 d = ... 

4541 e = ... 

4542 ``` 

4543 

4544 Multiple calls to `control_dependencies()` can be nested, and in 

4545 that case a new `Operation` will have control dependencies on the union 

4546 of `control_inputs` from all active contexts. 

4547 

4548 ```python 

4549 with g.control_dependencies([a, b]): 

4550 # Ops constructed here run after `a` and `b`. 

4551 with g.control_dependencies([c, d]): 

4552 # Ops constructed here run after `a`, `b`, `c`, and `d`. 

4553 ``` 

4554 

4555 You can pass None to clear the control dependencies: 

4556 

4557 ```python 

4558 with g.control_dependencies([a, b]): 

4559 # Ops constructed here run after `a` and `b`. 

4560 with g.control_dependencies(None): 

4561 # Ops constructed here run normally, not waiting for either `a` or `b`. 

4562 with g.control_dependencies([c, d]): 

4563 # Ops constructed here run after `c` and `d`, also not waiting 

4564 # for either `a` or `b`. 

4565 ``` 

4566 

4567 *N.B.* The control dependencies context applies *only* to ops that 

4568 are constructed within the context. Merely using an op or tensor 

4569 in the context does not add a control dependency. The following 

4570 example illustrates this point: 

4571 

4572 ```python 

4573 # WRONG 

4574 def my_func(pred, tensor): 

4575 t = tf.matmul(tensor, tensor) 

4576 with tf.control_dependencies([pred]): 

4577 # The matmul op is created outside the context, so no control 

4578 # dependency will be added. 

4579 return t 

4580 

4581 # RIGHT 

4582 def my_func(pred, tensor): 

4583 with tf.control_dependencies([pred]): 

4584 # The matmul op is created in the context, so a control dependency 

4585 # will be added. 

4586 return tf.matmul(tensor, tensor) 

4587 ``` 

4588 

4589 Also note that though execution of ops created under this scope will trigger 

4590 execution of the dependencies, the ops created under this scope might still 

4591 be pruned from a normal TensorFlow graph. For example, in the following 

4592 snippet of code the dependencies are never executed: 

4593 

4594 ```python 

4595 loss = model.loss() 

4596 with tf.control_dependencies(dependencies): 

4597 loss = loss + tf.constant(1) # note: dependencies ignored in the 

4598 # backward pass 

4599 return tf.gradients(loss, model.variables) 

4600 ``` 

4601 

4602 This is because evaluating the gradient graph does not require evaluating 

4603 the constant(1) op created in the forward pass. 

4604 

4605 Args: 

4606 control_inputs: A list of `Operation` or `Tensor` objects which must be 

4607 executed or computed before running the operations defined in the 

4608 context. Can also be `None` to clear the control dependencies. 

4609 

4610 Returns: 

4611 A context manager that specifies control dependencies for all 

4612 operations constructed within the context. 

4613 

4614 Raises: 

4615 TypeError: If `control_inputs` is not a list of `Operation` or 

4616 `Tensor` objects. 

4617 """ 

4618 if control_inputs is None: 

4619 return self._ControlDependenciesController(self, None) 

4620 # First convert the inputs to ops, and deduplicate them. 

4621 # NOTE(mrry): Other than deduplication, we do not currently track direct 

4622 # or indirect dependencies between control_inputs, which may result in 

4623 # redundant control inputs. 

4624 control_ops = [] 

4625 current = self._current_control_dependencies() 

4626 for c in control_inputs: 

4627 # The hasattr(c, "_handle") check is designed to match ResourceVariables. 

4628 # This is so control dependencies on a variable or on an unread variable 

4629 # don't trigger reads. 

4630 if (isinstance(c, internal.IndexedSlices) or 

4631 (hasattr(c, "_handle") and hasattr(c, "op"))): 

4632 c = c.op 

4633 c = self.as_graph_element(c) 

4634 if isinstance(c, Tensor): 

4635 c = c.op 

4636 elif not isinstance(c, Operation): 

4637 raise TypeError("Control input must be Operation or Tensor: %s" % c) 

4638 if c not in current: 

4639 control_ops.append(c) 

4640 current.add(c) 

4641 # Mark this op with an attribute indicating that it is used as a manual 

4642 # control dep, in order to allow tracking how commonly manual control 

4643 # deps are used in graphs run through the MLIR Bridge. See 

4644 # go/manual-control-dependencies-bridge for details. 

4645 # pylint: disable=protected-access 

4646 c._set_attr("_has_manual_control_dependencies", 

4647 attr_value_pb2.AttrValue(b=True)) 

4648 # pylint: enable=protected-access 

4649 return self._ControlDependenciesController(self, control_ops) 

4650 

4651 # pylint: disable=g-doc-return-or-yield 

4652 @tf_contextlib.contextmanager 

4653 def _attr_scope(self, attr_map): 

4654 """EXPERIMENTAL: A context manager for setting attributes on operators. 

4655 

4656 This context manager can be used to add additional 

4657 attributes to operators within the scope of the context. 

4658 

4659 For example: 

4660 

4661 with ops.Graph().as_default() as g: 

4662 f_1 = Foo() # No extra attributes 

4663 with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=False)}): 

4664 f_2 = Foo() # Additional attribute _a=False 

4665 with g._attr_scope({"_a": tf.attr_value_pb2.AttrValue(b=True)}): 

4666 f_3 = Foo() # Additional attribute _a=True 

4667 with g._attr_scope({"_a": None}): 

4668 f_4 = Foo() # No additional attributes. 

4669 

4670 Args: 

4671 attr_map: A dictionary mapping attr name strings to AttrValue protocol 

4672 buffers or None. 

4673 

4674 Returns: 

4675 A context manager that sets the given attributes for one or more 

4676 ops created in that context. 

4677 

4678 Raises: 

4679 TypeError: If attr_map is not a dictionary mapping 

4680 strings to AttrValue protobufs. 

4681 """ 

4682 if not isinstance(attr_map, dict): 

4683 raise TypeError("attr_map must be a dictionary mapping " 

4684 "strings to AttrValue protocol buffers") 

4685 # The saved_attrs dictionary stores any currently-set labels that 

4686 # will be overridden by this context manager. 

4687 saved_attrs = {} 

4688 # Install the given attribute 

4689 for name, attr in attr_map.items(): 

4690 if not (isinstance(name, str) and 

4691 (isinstance(attr, (type(None), attr_value_pb2.AttrValue)) or 

4692 callable(attr))): 

4693 raise TypeError("attr_map must be a dictionary mapping " 

4694 "strings to AttrValue protocol buffers or " 

4695 "callables that emit AttrValue protocol buffers") 

4696 try: 

4697 saved_attrs[name] = self._attr_scope_map[name] 

4698 except KeyError: 

4699 pass 

4700 if attr is None: 

4701 del self._attr_scope_map[name] 

4702 else: 

4703 self._attr_scope_map[name] = attr 

4704 try: 

4705 yield # The code within the context runs here. 

4706 finally: 

4707 # Remove the attributes set for this context, and restore any saved 

4708 # attributes. 

4709 for name, attr in attr_map.items(): 

4710 try: 

4711 self._attr_scope_map[name] = saved_attrs[name] 

4712 except KeyError: 

4713 del self._attr_scope_map[name] 

4714 

4715 # pylint: enable=g-doc-return-or-yield 

4716 

4717 # pylint: disable=g-doc-return-or-yield 

4718 @tf_contextlib.contextmanager 

4719 def _kernel_label_map(self, op_to_kernel_label_map): 

4720 """EXPERIMENTAL: A context manager for setting kernel labels. 

4721 

4722 This context manager can be used to select particular 

4723 implementations of kernels within the scope of the context. 

4724 

4725 For example: 

4726 

4727 with ops.Graph().as_default() as g: 

4728 f_1 = Foo() # Uses the default registered kernel for the Foo op. 

4729 with g._kernel_label_map({"Foo": "v_2"}): 

4730 f_2 = Foo() # Uses the registered kernel with label "v_2" 

4731 # for the Foo op. 

4732 with g._kernel_label_map({"Foo": "v_3"}): 

4733 f_3 = Foo() # Uses the registered kernel with label "v_3" 

4734 # for the Foo op. 

4735 with g._kernel_label_map({"Foo": ""}): 

4736 f_4 = Foo() # Uses the default registered kernel 

4737 # for the Foo op. 

4738 

4739 Args: 

4740 op_to_kernel_label_map: A dictionary mapping op type strings to kernel 

4741 label strings. 

4742 

4743 Returns: 

4744 A context manager that sets the kernel label to be used for one or more 

4745 ops created in that context. 

4746 

4747 Raises: 

4748 TypeError: If op_to_kernel_label_map is not a dictionary mapping 

4749 strings to strings. 

4750 """ 

4751 if not isinstance(op_to_kernel_label_map, dict): 

4752 raise TypeError("op_to_kernel_label_map must be a dictionary mapping " 

4753 "strings to strings") 

4754 # The saved_labels dictionary stores any currently-set labels that 

4755 # will be overridden by this context manager. 

4756 saved_labels = {} 

4757 # Install the given label 

4758 for op_type, label in op_to_kernel_label_map.items(): 

4759 if not (isinstance(op_type, str) and 

4760 isinstance(label, str)): 

4761 raise TypeError("op_to_kernel_label_map must be a dictionary mapping " 

4762 "strings to strings") 

4763 try: 

4764 saved_labels[op_type] = self._op_to_kernel_label_map[op_type] 

4765 except KeyError: 

4766 pass 

4767 self._op_to_kernel_label_map[op_type] = label 

4768 try: 

4769 yield # The code within the context runs here. 

4770 finally: 

4771 # Remove the labels set for this context, and restore any saved labels. 

4772 for op_type, label in op_to_kernel_label_map.items(): 

4773 try: 

4774 self._op_to_kernel_label_map[op_type] = saved_labels[op_type] 

4775 except KeyError: 

4776 del self._op_to_kernel_label_map[op_type] 

4777 

4778 # pylint: enable=g-doc-return-or-yield 

4779 

4780 @tf_contextlib.contextmanager 

4781 def _override_gradient_function(self, gradient_function_map): 

4782 """Specify gradient function for the given op type.""" 

4783 

4784 # This is an internal API and we don't need nested context for this. 

4785 # TODO(mdan): make it a proper context manager. 

4786 assert not self._gradient_function_map 

4787 self._gradient_function_map = gradient_function_map 

4788 try: 

4789 yield 

4790 finally: 

4791 self._gradient_function_map = {} 

4792 

4793 # pylint: disable=g-doc-return-or-yield 

4794 @tf_contextlib.contextmanager 

4795 def gradient_override_map(self, op_type_map): 

4796 """EXPERIMENTAL: A context manager for overriding gradient functions. 

4797 

4798 This context manager can be used to override the gradient function 

4799 that will be used for ops within the scope of the context. 

4800 

4801 For example: 

4802 

4803 ```python 

4804 @tf.RegisterGradient("CustomSquare") 

4805 def _custom_square_grad(op, grad): 

4806 # ... 

4807 

4808 with tf.Graph().as_default() as g: 

4809 c = tf.constant(5.0) 

4810 s_1 = tf.square(c) # Uses the default gradient for tf.square. 

4811 with g.gradient_override_map({"Square": "CustomSquare"}): 

4812 s_2 = tf.square(c) # Uses _custom_square_grad to compute the 

4813 # gradient of s_2. 

4814 ``` 

4815 

4816 Args: 

4817 op_type_map: A dictionary mapping op type strings to alternative op type 

4818 strings. 

4819 

4820 Returns: 

4821 A context manager that sets the alternative op type to be used for one 

4822 or more ops created in that context. 

4823 

4824 Raises: 

4825 TypeError: If `op_type_map` is not a dictionary mapping strings to 

4826 strings. 

4827 """ 

4828 if not isinstance(op_type_map, dict): 

4829 raise TypeError("op_type_map must be a dictionary mapping " 

4830 "strings to strings") 

4831 # The saved_mappings dictionary stores any currently-set mappings that 

4832 # will be overridden by this context manager. 

4833 saved_mappings = {} 

4834 # Install the given override mapping 

4835 for op_type, mapped_op_type in op_type_map.items(): 

4836 if not (isinstance(op_type, str) and 

4837 isinstance(mapped_op_type, str)): 

4838 raise TypeError("op_type_map must be a dictionary mapping " 

4839 "strings to strings") 

4840 try: 

4841 saved_mappings[op_type] = self._gradient_override_map[op_type] 

4842 except KeyError: 

4843 pass 

4844 self._gradient_override_map[op_type] = mapped_op_type 

4845 try: 

4846 yield # The code within the context runs here. 

4847 finally: 

4848 # Remove the mappings set for this context, and restore any saved mappings. 

4849 for op_type, mapped_op_type in op_type_map.items(): 

4850 try: 

4851 self._gradient_override_map[op_type] = saved_mappings[op_type] 

4852 except KeyError: 

4853 del self._gradient_override_map[op_type] 

4854 

4855 # pylint: enable=g-doc-return-or-yield 

4856 

4857 def prevent_feeding(self, tensor): 

4858 """Marks the given `tensor` as unfeedable in this graph.""" 

4859 self._unfeedable_tensors.add(tensor) 

4860 

4861 def is_feedable(self, tensor): 

4862 """Returns `True` if and only if `tensor` is feedable.""" 

4863 return tensor not in self._unfeedable_tensors 

4864 

4865 def prevent_fetching(self, op): 

4866 """Marks the given `op` as unfetchable in this graph.""" 

4867 self._unfetchable_ops.add(op) 

4868 

4869 def is_fetchable(self, tensor_or_op): 

4870 """Returns `True` if and only if `tensor_or_op` is fetchable.""" 

4871 if isinstance(tensor_or_op, Tensor): 

4872 return tensor_or_op.op not in self._unfetchable_ops 

4873 else: 

4874 return tensor_or_op not in self._unfetchable_ops 

4875 
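A minimal sketch of the feed/fetch bookkeeping above, which control-flow constructs use to protect their internal tensors (illustrative names, not part of ops.py):

```python
import tensorflow.compat.v1 as tf

g = tf.Graph()
with g.as_default():
  x = tf.constant(3.0)
  y = x * 2.0
  g.prevent_feeding(x)      # `x` can no longer be overridden via feed_dict
  g.prevent_fetching(y.op)  # `y.op` can no longer be fetched by Session.run

assert not g.is_feedable(x)
assert g.is_feedable(y)
assert not g.is_fetchable(y)      # accepts a Tensor or an Operation
assert not g.is_fetchable(y.op)
```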

4876 def switch_to_thread_local(self): 

4877 """Make device, colocation and dependencies stacks thread-local. 

4878 

4879 Device, colocation and dependencies stacks are not thread-local by default. 

4880 If multiple threads access them, then the state is shared. This means that 

4881 one thread may affect the behavior of another thread. 

4882 

4883 After this method is called, the stacks become thread-local. If multiple 

4884 threads access them, then the state is not shared. Each thread uses its own 

4885 value; a thread doesn't affect other threads by mutating such a stack. 

4886 

4887 The initial value for every thread's stack is set to the current value 

4888 of the stack when `switch_to_thread_local()` was first called. 

4889 """ 

4890 if not self._stack_state_is_thread_local: 

4891 self._stack_state_is_thread_local = True 

4892 
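A minimal sketch of the thread-local behavior: after `switch_to_thread_local()`, a device scope entered on the main thread does not leak into ops built on a worker thread (illustrative names, not part of ops.py):

```python
import threading
import tensorflow.compat.v1 as tf

g = tf.Graph()
g.switch_to_thread_local()
devices = []

def build_on_worker():
  with g.as_default():
    # The worker thread starts from the stack as it was when
    # switch_to_thread_local() was called, i.e. with no device scope.
    devices.append(tf.constant(1.0).op.device)

with g.as_default(), g.device("/device:CPU:0"):
  worker = threading.Thread(target=build_on_worker)
  worker.start()
  worker.join()
  devices.append(tf.constant(2.0).op.device)

print(devices)  # ['', '/device:CPU:0']
```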

4893 @property 

4894 def _device_function_stack(self): 

4895 if self._stack_state_is_thread_local: 

4896 # This may be called from a thread where device_function_stack doesn't yet 

4897 # exist. 

4898 # pylint: disable=protected-access 

4899 if not hasattr(self._thread_local, "_device_function_stack"): 

4900 stack_copy_for_this_thread = self._graph_device_function_stack.copy() 

4901 self._thread_local._device_function_stack = stack_copy_for_this_thread 

4902 return self._thread_local._device_function_stack 

4903 # pylint: enable=protected-access 

4904 else: 

4905 return self._graph_device_function_stack 

4906 

4907 @property 

4908 def _device_functions_outer_to_inner(self): 

4909 user_device_specs = self._device_function_stack.peek_objs() 

4910 device_functions = [spec.function for spec in user_device_specs] 

4911 device_functions_outer_to_inner = list(reversed(device_functions)) 

4912 return device_functions_outer_to_inner 

4913 

4914 def _snapshot_device_function_stack_metadata(self): 

4915 """Return device function stack as a list of TraceableObjects. 

4916 

4917 Returns: 

4918 [traceable_stack.TraceableObject, ...] where each TraceableObject's .obj 

4919 member is a displayable name for the user's argument to Graph.device, and 

4920 the filename and lineno members point to the code location where 

4921 Graph.device was called directly or indirectly by the user. 

4922 """ 

4923 snapshot = [] 

4924 for obj in self._device_function_stack.peek_traceable_objs(): 

4925 obj_copy = obj.copy_metadata() 

4926 obj_copy.obj = obj.obj.display_name 

4927 snapshot.append(obj_copy) 

4928 return snapshot 

4929 

4930 @_device_function_stack.setter 

4931 def _device_function_stack(self, device_function_stack): 

4932 if self._stack_state_is_thread_local: 

4933 # pylint: disable=protected-access 

4934 self._thread_local._device_function_stack = device_function_stack 

4935 # pylint: enable=protected-access 

4936 else: 

4937 self._graph_device_function_stack = device_function_stack 

4938 

4939 @property 

4940 def _colocation_stack(self): 

4941 """Return thread-local copy of colocation stack.""" 

4942 if self._stack_state_is_thread_local: 

4943 # This may be called from a thread where colocation_stack doesn't yet 

4944 # exist. 

4945 # pylint: disable=protected-access 

4946 if not hasattr(self._thread_local, "_colocation_stack"): 

4947 stack_copy_for_this_thread = self._graph_colocation_stack.copy() 

4948 self._thread_local._colocation_stack = stack_copy_for_this_thread 

4949 return self._thread_local._colocation_stack 

4950 # pylint: enable=protected-access 

4951 else: 

4952 return self._graph_colocation_stack 

4953 

4954 def _snapshot_colocation_stack_metadata(self): 

4955 """Return colocation stack metadata as a dictionary.""" 

4956 return { 

4957 traceable_obj.obj.name: traceable_obj.copy_metadata() 

4958 for traceable_obj in self._colocation_stack.peek_traceable_objs() 

4959 } 

4960 

4961 @_colocation_stack.setter 

4962 def _colocation_stack(self, colocation_stack): 

4963 if self._stack_state_is_thread_local: 

4964 # pylint: disable=protected-access 

4965 self._thread_local._colocation_stack = colocation_stack 

4966 # pylint: enable=protected-access 

4967 else: 

4968 self._graph_colocation_stack = colocation_stack 

4969 

4970 @property 

4971 def _control_dependencies_stack(self): 

4972 if self._stack_state_is_thread_local: 

4973 # This may be called from a thread where control_dependencies_stack 

4974 # doesn't yet exist. 

4975 if not hasattr(self._thread_local, "_control_dependencies_stack"): 

4976 self._thread_local._control_dependencies_stack = ( 

4977 self._graph_control_dependencies_stack[:]) 

4978 return self._thread_local._control_dependencies_stack 

4979 else: 

4980 return self._graph_control_dependencies_stack 

4981 

4982 @_control_dependencies_stack.setter 

4983 def _control_dependencies_stack(self, control_dependencies): 

4984 if self._stack_state_is_thread_local: 

4985 self._thread_local._control_dependencies_stack = control_dependencies 

4986 else: 

4987 self._graph_control_dependencies_stack = control_dependencies 

4988 

4989 @property 

4990 def _distribution_strategy_stack(self): 

4991 """A stack to maintain distribution strategy context for each thread.""" 

4992 if not hasattr(self._thread_local, "_distribution_strategy_stack"): 

4993 self._thread_local._distribution_strategy_stack = [] # pylint: disable=protected-access 

4994 return self._thread_local._distribution_strategy_stack # pylint: disable=protected-access 

4995 

4996 @_distribution_strategy_stack.setter 

4997 def _distribution_strategy_stack(self, _distribution_strategy_stack): 

4998 self._thread_local._distribution_strategy_stack = ( # pylint: disable=protected-access 

4999 _distribution_strategy_stack) 

5000 

5001 @property 

5002 def _global_distribute_strategy_scope(self): 

5003 """For implementing `tf.distribute.set_strategy()`.""" 

5004 if not hasattr(self._thread_local, "distribute_strategy_scope"): 

5005 self._thread_local.distribute_strategy_scope = None 

5006 return self._thread_local.distribute_strategy_scope 

5007 

5008 @_global_distribute_strategy_scope.setter 

5009 def _global_distribute_strategy_scope(self, distribute_strategy_scope): 

5010 self._thread_local.distribute_strategy_scope = (distribute_strategy_scope) 

5011 

5012 def _mutation_lock(self): 

5013 """Returns a lock to guard code that creates & mutates ops. 

5014 

5015 See the comment for self._group_lock for more info. 

5016 """ 

5017 return self._group_lock.group(_MUTATION_LOCK_GROUP) 

5018 

5019 def _session_run_lock(self): 

5020 """Returns a lock to guard code for Session.run. 

5021 

5022 See the comment for self._group_lock for more info. 

5023 """ 

5024 return self._group_lock.group(_SESSION_RUN_LOCK_GROUP) 

5025 

5026 

5027# TODO(agarwal): currently device directives in an outer eager scope will not 

5028# apply to inner graph mode code. Fix that. 

5029 

5030 

5031@tf_export(v1=["device"]) 

5032def device(device_name_or_function): 

5033 """Wrapper for `Graph.device()` using the default graph. 

5034 

5035 See `tf.Graph.device` for more details. 

5036 

5037 Args: 

5038 device_name_or_function: The device name or function to use in the context. 

5039 

5040 Returns: 

5041 A context manager that specifies the default device to use for newly 

5042 created ops. 

5043 

5044 Raises: 

5045 RuntimeError: If eager execution is enabled and a function is passed in. 

5046 """ 

5047 if context.executing_eagerly(): 

5048 if callable(device_name_or_function): 

5049 raise RuntimeError( 

5050 "tf.device does not support functions when eager execution " 

5051 "is enabled.") 

5052 return context.device(device_name_or_function) 

5053 elif executing_eagerly_outside_functions(): 

5054 @tf_contextlib.contextmanager 

5055 def combined(device_name_or_function): 

5056 with get_default_graph().device(device_name_or_function): 

5057 if not callable(device_name_or_function): 

5058 with context.device(device_name_or_function): 

5059 yield 

5060 else: 

5061 yield 

5062 return combined(device_name_or_function) 

5063 else: 

5064 return get_default_graph().device(device_name_or_function) 

5065 
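A minimal sketch of the eager path of this wrapper: device strings are handed to `context.device()`, while device functions are rejected (illustrative values, not part of ops.py):

```python
import tensorflow.compat.v1 as tf  # eager execution is on by default in TF2

with tf.device("/device:CPU:0"):
  x = tf.constant([1.0, 2.0])
print(x.device)  # e.g. "/job:localhost/replica:0/task:0/device:CPU:0"

try:
  with tf.device(lambda op: "/device:CPU:0"):
    pass
except RuntimeError as e:
  print(e)  # functions are not supported when eager execution is enabled
```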

5066 

5067@tf_export("device", v1=[]) 

5068def device_v2(device_name): 

5069 """Specifies the device for ops created/executed in this context. 

5070 

5071 This function specifies the device to be used for ops created/executed in a 

5072 particular context. Nested contexts will inherit and also create/execute 

5073 their ops on the specified device. If a specific device is not required, 

5074 consider not using this function so that a device can be automatically 

5075 assigned. In general the use of this function is optional. `device_name` can 

5076 be fully specified, as in "/job:worker/task:1/device:cpu:0", or partially 

5077 specified, containing only a subset of the "/"-separated fields. Any fields 

5078 which are specified will override device annotations from outer scopes. 

5079 

5080 For example: 

5081 

5082 ```python 

5083 with tf.device('/job:foo'): 

5084 # ops created here have devices with /job:foo 

5085 with tf.device('/job:bar/task:0/device:gpu:2'): 

5086 # ops created here have the fully specified device above 

5087 with tf.device('/device:gpu:1'): 

5088 # ops created here have the device '/job:foo/device:gpu:1' 

5089 ``` 

5090 

5091 Args: 

5092 device_name: The device name to use in the context. 

5093 

5094 Returns: 

5095 A context manager that specifies the default device to use for newly 

5096 created ops. 

5097 

5098 Raises: 

5099 RuntimeError: If a function is passed in. 

5100 """ 

5101 if callable(device_name): 

5102 raise RuntimeError("tf.device does not support functions.") 

5103 return device(device_name) 

5104 

5105 

5106@tf_export(v1=["container"]) 

5107def container(container_name): 

5108 """Wrapper for `Graph.container()` using the default graph. 

5109 

5110 Args: 

5111 container_name: The container string to use in the context. 

5112 

5113 Returns: 

5114 A context manager that specifies the default container to use for newly 

5115 created stateful ops. 

5116 """ 

5117 return get_default_graph().container(container_name) 

5118 

5119 

5120def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False): 

5121 if context.executing_eagerly(): 

5122 if op is not None: 

5123 if not hasattr(op, "device"): 

5124 op = convert_to_tensor(op) 

5125 return device(op.device) 

5126 else: 

5127 return NullContextmanager() 

5128 else: 

5129 default_graph = get_default_graph() 

5130 if isinstance(op, EagerTensor): 

5131 if default_graph.building_function: 

5132 return default_graph.device(op.device) 

5133 else: 

5134 raise ValueError("Encountered an Eager-defined Tensor during graph " 

5135 "construction, but a function was not being built.") 

5136 return default_graph._colocate_with_for_gradient( 

5137 op, gradient_uid=gradient_uid, ignore_existing=ignore_existing) 

5138 

5139 

5140 # Internal interface to colocate_with. colocate_with has been deprecated from 

5141 # the public API. There are still a few internal uses of colocate_with. Add an 

5142 # internal-only API for those uses to avoid the deprecation warning. 

5143def colocate_with(op, ignore_existing=False): 

5144 return _colocate_with_for_gradient(op, None, ignore_existing=ignore_existing) 

5145 

5146 

5147@deprecation.deprecated( 

5148 date=None, instructions="Colocations handled automatically by placer.") 

5149@tf_export(v1=["colocate_with"]) 

5150def _colocate_with(op, ignore_existing=False): 

5151 return colocate_with(op, ignore_existing) 

5152 

5153 

5154@tf_export("control_dependencies") 

5155def control_dependencies(control_inputs): 

5156 """Wrapper for `Graph.control_dependencies()` using the default graph. 

5157 

5158 See `tf.Graph.control_dependencies` for more details. 

5159 

5160 In TensorFlow 2 with eager and/or Autograph, you should not need this method 

5161 most of the time, as ops execute in the expected order thanks to automatic 

5162 control dependencies. Only use it to manually control ordering, for example as 

5163 a workaround to known issues such as `tf.function` with `tf.debugging.assert*` 

5164 and `tf.py_function`. 

5165 For example: 

5166 

5167 >>> @tf.function( 

5168 ... input_signature=[tf.TensorSpec([None, None], tf.float32), 

5169 ... tf.TensorSpec([None, None], tf.float32)]) 

5170 ... def my_assert_func_1(x, bias): 

5171 ... # `tf.function` attempts to execute `tf.math.add` in parallel to 

5172 ... # `assert_equal`. As a result an error can get raised from `tf.math.add` 

5173 ... # without triggering the assertion error. 

5174 ... tf.assert_equal(tf.shape(x)[1], 

5175 ... tf.shape(bias)[1], 

5176 ... message='bad shape') 

5177 ... return x + bias 

5178 

5179 >>> # Error raised in either `add` or `assert` 

5180 >>> my_assert_func_1(tf.ones((2, 5)), tf.ones((2, 7))) 

5181 Traceback (most recent call last): 

5182 ... 

5183 InvalidArgumentError: ... 

5184 

5185 

5186 >>> @tf.function( 

5187 ... input_signature=[tf.TensorSpec([None, None], tf.float32), 

5188 ... tf.TensorSpec([None, None], tf.float32)]) 

5189 ... def my_assert_func_2(x, bias): 

5190 ... with tf.control_dependencies( 

5191 ... [tf.assert_equal(tf.shape(x)[1], 

5192 ... tf.shape(bias)[1], 

5193 ... message='bad shape')]): 

5194 ... return x + bias 

5195 

5196 >>> # Error raised in `assert` 

5197 >>> my_assert_func_2(tf.ones((2, 5)), tf.ones((2, 7))) 

5198 Traceback (most recent call last): 

5199 ... 

5200 InvalidArgumentError: ... 

5201 

5202 When eager execution is enabled, any callable object in the `control_inputs` 

5203 list will be called. 

5204 

5205 Args: 

5206 control_inputs: A list of `Operation` or `Tensor` objects which must be 

5207 executed or computed before running the operations defined in the context. 

5208 Can also be `None` to clear the control dependencies. If eager execution 

5209 is enabled, any callable object in the `control_inputs` list will be 

5210 called. 

5211 

5212 Returns: 

5213 A context manager that specifies control dependencies for all 

5214 operations constructed within the context. 

5215 """ 

5216 if context.executing_eagerly(): 

5217 if control_inputs: 

5218 # Execute any pending callables. 

5219 for control in control_inputs: 

5220 if callable(control): 

5221 control() 

5222 return NullContextmanager() 

5223 else: 

5224 return get_default_graph().control_dependencies(control_inputs) 

5225 

5226# TODO(b/271463878): Remove in favor of direct references to `stack`. 

5227get_default_session = stack.get_default_session 

5228 

5229 

5230def _eval_using_default_session(tensors, feed_dict, graph, session=None): 

5231 """Uses the default session to evaluate one or more tensors. 

5232 

5233 Args: 

5234 tensors: A single Tensor, or a list of Tensor objects. 

5235 feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, 

5236 numpy ndarrays, TensorProtos, or strings. 

5237 graph: The graph in which the tensors are defined. 

5238 session: (Optional) A different session to use to evaluate "tensors". 

5239 

5240 Returns: 

5241 Either a single numpy ndarray if "tensors" is a single tensor; or a list 

5242 of numpy ndarrays that each correspond to the respective element in 

5243 "tensors". 

5244 

5245 Raises: 

5246 ValueError: If no default session is available; the default session 

5247 does not have "graph" as its graph; or if "session" is specified, 

5248 and it does not have "graph" as its graph. 

5249 """ 

5250 if session is None: 

5251 session = stack.get_default_session() 

5252 if session is None: 

5253 raise ValueError("Cannot evaluate tensor using `eval()`: No default " 

5254 "session is registered. Use `with " 

5255 "sess.as_default()` or pass an explicit session to " 

5256 "`eval(session=sess)`") 

5257 if session.graph is not graph: 

5258 raise ValueError("Cannot use the default session to evaluate tensor: " 

5259 "the tensor's graph is different from the session's " 

5260 "graph. Pass an explicit session to " 

5261 "`eval(session=sess)`.") 

5262 else: 

5263 if session.graph is not graph: 

5264 raise ValueError("Cannot use the given session to evaluate tensor: " 

5265 "the tensor's graph is different from the session's " 

5266 "graph.") 

5267 return session.run(tensors, feed_dict) 

5268 
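A minimal sketch of how this helper is reached in practice: `Tensor.eval()` uses the default session installed by a `with Session()` block (or `Session.as_default()`), with an explicit `session=` argument as the alternative (illustrative values, not part of ops.py):

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

c = tf.constant(7.0)
with tf.Session() as sess:       # installs `sess` as the default session
  print(c.eval())                # 7.0, evaluated via the default session
  print(c.eval(session=sess))    # equivalent, with an explicit session
```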

5269 

5270def _run_using_default_session(operation, feed_dict, graph, session=None): 

5271 """Uses the default session to run "operation". 

5272 

5273 Args: 

5274 operation: The Operation to be run. 

5275 feed_dict: A dictionary that maps Tensor objects (or tensor names) to lists, 

5276 numpy ndarrays, TensorProtos, or strings. 

5277 graph: The graph in which "operation" is defined. 

5278 session: (Optional) A different session to use to run "operation". 

5279 

5280 Raises: 

5281 ValueError: If no default session is available; the default session 

5282 does not have "graph" as its graph; or if "session" is specified, 

5283 and it does not have "graph" as its graph. 

5284 """ 

5285 if session is None: 

5286 session = stack.get_default_session() 

5287 if session is None: 

5288 raise ValueError("Cannot execute operation using `run()`: No default " 

5289 "session is registered. Use `with " 

5290 "sess.as_default():` or pass an explicit session to " 

5291 "`run(session=sess)`") 

5292 if session.graph is not graph: 

5293 raise ValueError("Cannot use the default session to execute operation: " 

5294 "the operation's graph is different from the " 

5295 "session's graph. Pass an explicit session to " 

5296 "run(session=sess).") 

5297 else: 

5298 if session.graph is not graph: 

5299 raise ValueError("Cannot use the given session to execute operation: " 

5300 "the operation's graph is different from the session's " 

5301 "graph.") 

5302 session.run(operation, feed_dict) 

5303 

5304 

5305class _DefaultGraphStack(stack.DefaultStack): # pylint: disable=protected-access 

5306 """A thread-local stack of objects for providing an implicit default graph.""" 

5307 

5308 def __init__(self): 

5309 super(_DefaultGraphStack, self).__init__() 

5310 self._global_default_graph = None 

5311 

5312 def get_default(self): 

5313 """Override that returns a global default if the stack is empty.""" 

5314 if self.stack: 

5315 return self.stack[-1] 

5316 elif self._global_default_graph: 

5317 return self._global_default_graph 

5318 else: 

5319 self._global_default_graph = Graph() 

5320 return self._global_default_graph 

5321 

5322 def _GetGlobalDefaultGraph(self): 

5323 if self._global_default_graph is None: 

5324 # TODO(mrry): Perhaps log that the default graph is being used, or 

5325 # provide some other feedback to prevent confusion when the global 

5326 # default graph and an explicit graph are combined in the 

5327 # same process. 

5328 self._global_default_graph = Graph() 

5329 return self._global_default_graph 

5330 

5331 def reset(self): 

5332 super(_DefaultGraphStack, self).reset() 

5333 self._global_default_graph = None 

5334 

5335 @tf_contextlib.contextmanager 

5336 def get_controller(self, default): 

5337 context.context().context_switches.push(default.building_function, 

5338 default.as_default, 

5339 default._device_function_stack) 

5340 try: 

5341 with super(_DefaultGraphStack, 

5342 self).get_controller(default) as g, context.graph_mode(): 

5343 yield g 

5344 finally: 

5345 # If an exception is raised here it may be hiding a related exception in 

5346 # the try-block (just above). 

5347 context.context().context_switches.pop() 

5348 

5349 

5350_default_graph_stack = _DefaultGraphStack() 

5351 

5352 

5353# Shared helper used in init_scope and executing_eagerly_outside_functions 

5354# to obtain the outermost context that is not building a function, and the 

5355 # innermost non-empty device stack. 

5356def _get_outer_context_and_inner_device_stack(): 

5357 """Get the outermost context not building a function.""" 

5358 default_graph = get_default_graph() 

5359 outer_context = None 

5360 innermost_nonempty_device_stack = default_graph._device_function_stack # pylint: disable=protected-access 

5361 

5362 if not _default_graph_stack.stack: 

5363 # If the default graph stack is empty, then we cannot be building a 

5364 # function. Install the global graph (which, in this case, is also the 

5365 # default graph) as the outer context. 

5366 if default_graph.building_function: 

5367 raise RuntimeError("The global graph is building a function.") 

5368 outer_context = default_graph.as_default 

5369 else: 

5370 # Find a context that is not building a function. 

5371 for stack_entry in reversed(context.context().context_switches.stack): 

5372 if not innermost_nonempty_device_stack: 

5373 innermost_nonempty_device_stack = stack_entry.device_stack 

5374 if not stack_entry.is_building_function: 

5375 outer_context = stack_entry.enter_context_fn 

5376 break 

5377 

5378 if outer_context is None: 

5379 # As a last resort, obtain the global default graph; this graph doesn't 

5380 # necessarily live on the graph stack (and hence it doesn't necessarily 

5381 # live on the context stack), but it is stored in the graph stack's 

5382 # encapsulating object. 

5383 outer_context = _default_graph_stack._GetGlobalDefaultGraph().as_default # pylint: disable=protected-access 

5384 

5385 if outer_context is None: 

5386 # Sanity check; this shouldn't be triggered. 

5387 raise RuntimeError("All graphs are building functions, and no " 

5388 "eager context was previously active.") 

5389 

5390 return outer_context, innermost_nonempty_device_stack 

5391 

5392 

5393# pylint: disable=g-doc-return-or-yield,line-too-long 

5394@tf_export("init_scope") 

5395@tf_contextlib.contextmanager 

5396def init_scope(): 

5397 """A context manager that lifts ops out of control-flow scopes and function-building graphs. 

5398 

5399 There is often a need to lift variable initialization ops out of control-flow 

5400 scopes, function-building graphs, and gradient tapes. Entering an 

5401 `init_scope` is a mechanism for satisfying these desiderata. In particular, 

5402 entering an `init_scope` has three effects: 

5403 

5404 (1) All control dependencies are cleared the moment the scope is entered; 

5405 this is equivalent to entering the context manager returned from 

5406 `control_dependencies(None)`, which has the side-effect of exiting 

5407 control-flow scopes like `tf.cond` and `tf.while_loop`. 

5408 

5409 (2) All operations that are created while the scope is active are lifted 

5410 into the lowest context on the `context_stack` that is not building a 

5411 graph function. Here, a context is defined as either a graph or an eager 

5412 context. Every context switch, i.e., every installation of a graph as 

5413 the default graph and every switch into eager mode, is logged in a 

5414 thread-local stack called `context_switches`; the log entry for a 

5415 context switch is popped from the stack when the context is exited. 

5416 Entering an `init_scope` is equivalent to crawling up 

5417 `context_switches`, finding the first context that is not building a 

5418 graph function, and entering it. A caveat is that if graph mode is 

5419 enabled but the default graph stack is empty, then entering an 

5420 `init_scope` will simply install a fresh graph as the default one. 

5421 

5422 (3) The gradient tape is paused while the scope is active. 

5423 

5424 When eager execution is enabled, code inside an init_scope block runs with 

5425 eager execution enabled even when tracing a `tf.function`. For example: 

5426 

5427 ```python 

5428 tf.compat.v1.enable_eager_execution() 

5429 

5430 @tf.function 

5431 def func(): 

5432 # A function constructs TensorFlow graphs, 

5433 # it does not execute eagerly. 

5434 assert not tf.executing_eagerly() 

5435 with tf.init_scope(): 

5436 # Initialization runs with eager execution enabled 

5437 assert tf.executing_eagerly() 

5438 ``` 

5439 

5440 Raises: 

5441 RuntimeError: if graph state is incompatible with this initialization. 

5442 """ 

5443 # pylint: enable=g-doc-return-or-yield,line-too-long 

5444 

5445 if context.executing_eagerly(): 

5446 # Fastpath. 

5447 with record.stop_recording(): 

5448 yield 

5449 else: 

5450 # Retrieve the active name scope: entering an `init_scope` preserves 

5451 # the name scope of the current context. 

5452 scope = get_default_graph().get_name_scope() 

5453 if scope and scope[-1] != "/": 

5454 # Names that end with trailing slashes are treated by `name_scope` as 

5455 # absolute. 

5456 scope = scope + "/" 

5457 

5458 outer_context, innermost_nonempty_device_stack = ( 

5459 _get_outer_context_and_inner_device_stack()) 

5460 

5461 outer_graph = None 

5462 outer_device_stack = None 

5463 try: 

5464 with outer_context(), name_scope( 

5465 scope, skip_on_eager=False), control_dependencies( 

5466 None), record.stop_recording(): 

5467 context_manager = NullContextmanager 

5468 context_manager_input = None 

5469 if not context.executing_eagerly(): 

5470 # The device stack is preserved when lifting into a graph. Eager 

5471 # execution doesn't implement device stacks and in particular it 

5472 # doesn't support device functions, so in general it's not possible 

5473 # to do the same when lifting into the eager context. 

5474 outer_graph = get_default_graph() 

5475 outer_device_stack = outer_graph._device_function_stack # pylint: disable=protected-access 

5476 outer_graph._device_function_stack = innermost_nonempty_device_stack # pylint: disable=protected-access 

5477 elif innermost_nonempty_device_stack is not None: 

5478 for device_spec in innermost_nonempty_device_stack.peek_objs(): 

5479 if device_spec.function is None: 

5480 break 

5481 if device_spec.raw_string: 

5482 context_manager = context.device 

5483 context_manager_input = device_spec.raw_string 

5484 break 

5485 # It is currently not possible to have a device function in V2, 

5486 # but in V1 we are unable to apply device functions in eager mode. 

5487 # This means that we will silently skip some of the entries on the 

5488 # device stack in V1 + eager mode. 

5489 

5490 with context_manager(context_manager_input): 

5491 yield 

5492 finally: 

5493 # If an exception is raised here it may be hiding a related exception in 

5494 # try-block (just above). 

5495 if outer_graph is not None: 

5496 outer_graph._device_function_stack = outer_device_stack # pylint: disable=protected-access 

5497 
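A common use of `init_scope` is lifting one-time state creation out of a `tf.function` trace so it runs eagerly exactly once; a minimal sketch (illustrative names, not part of ops.py):

```python
import tensorflow as tf

counter = None

@tf.function
def increment():
  global counter
  with tf.init_scope():   # runs eagerly, outside the function-building graph
    if counter is None:
      counter = tf.Variable(0)
  counter.assign_add(1)
  return counter.read_value()

print(increment().numpy())  # 1
print(increment().numpy())  # 2
```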

5498 

5499@tf_export(v1=["executing_eagerly_outside_functions"]) 

5500def executing_eagerly_outside_functions(): 

5501 """Returns True if executing eagerly, even if inside a graph function. 

5502 

5503 This function will check the outermost context for the program and see if 

5504 it is in eager mode. It is useful compared to `tf.executing_eagerly()`, 

5505 which checks the current context and will return `False` within a 

5506 `tf.function` body. It can be used to build libraries that behave differently 

5507 in the eager runtime and the (deprecated) v1 session runtime. 

5508 

5509 Example: 

5510 

5511 >>> tf.compat.v1.enable_eager_execution() 

5512 >>> @tf.function 

5513 ... def func(): 

5514 ... # A function constructs TensorFlow graphs, it does not execute eagerly, 

5515 ... # but the outermost context is still eager. 

5516 ... assert not tf.executing_eagerly() 

5517 ... return tf.compat.v1.executing_eagerly_outside_functions() 

5518 >>> func() 

5519 <tf.Tensor: shape=(), dtype=bool, numpy=True> 

5520 

5521 Returns: 

5522 boolean, whether the outermost context is in eager mode. 

5523 """ 

5524 if context.executing_eagerly(): 

5525 return True 

5526 else: 

5527 outer_context, _ = _get_outer_context_and_inner_device_stack() 

5528 with outer_context(): 

5529 return context.executing_eagerly() 

5530 

5531 

5532@tf_export("inside_function", v1=[]) 

5533def inside_function(): 

5534 """Indicates whether the caller code is executing inside a `tf.function`. 

5535 

5536 Returns: 

5537 Boolean, True if the caller code is executing inside a `tf.function` 

5538 rather than eagerly. 

5539 

5540 Example: 

5541 

5542 >>> tf.inside_function() 

5543 False 

5544 >>> @tf.function 

5545 ... def f(): 

5546 ... print(tf.inside_function()) 

5547 >>> f() 

5548 True 

5549 """ 

5550 return get_default_graph().building_function 

5551 

5552 

5553@tf_export(v1=["enable_eager_execution"]) 

5554def enable_eager_execution(config=None, device_policy=None, 

5555 execution_mode=None): 

5556 """Enables eager execution for the lifetime of this program. 

5557 

5558 Eager execution provides an imperative interface to TensorFlow. With eager 

5559 execution enabled, TensorFlow functions execute operations immediately (as 

5560 opposed to adding to a graph to be executed later in a 

5561 `tf.compat.v1.Session`) and return concrete values (as opposed to symbolic 

5562 references to a node in a 

5563 computational graph). 

5564 

5565 For example: 

5566 

5567 ```python 

5568 tf.compat.v1.enable_eager_execution() 

5569 

5570 # After eager execution is enabled, operations are executed as they are 

5571 # defined and Tensor objects hold concrete values, which can be accessed as 

5572 # numpy.ndarray`s through the numpy() method. 

5573 assert tf.multiply(6, 7).numpy() == 42 

5574 ``` 

5575 

5576 Eager execution cannot be enabled after TensorFlow APIs have been used to 

5577 create or execute graphs. It is typically recommended to invoke this function 

5578 at program startup and not in a library (as most libraries should be usable 

5579 both with and without eager execution). 

5580 

5581 @compatibility(TF2) 

5582 This function is not necessary if you are using TF2. Eager execution is 

5583 enabled by default. 

5584 @end_compatibility 

5585 

5586 Args: 

5587 config: (Optional.) A `tf.compat.v1.ConfigProto` to use to configure the 

5588 environment in which operations are executed. Note that 

5589 `tf.compat.v1.ConfigProto` is also used to configure graph execution (via 

5590 `tf.compat.v1.Session`) and many options within `tf.compat.v1.ConfigProto` 

5591 are not implemented (or are irrelevant) when eager execution is enabled. 

5592 device_policy: (Optional.) Policy controlling how operations requiring 

5593 inputs on a specific device (e.g., GPU 0) handle inputs on a different 

5594 device (e.g. GPU 1 or CPU). When set to None, an appropriate value will 

5595 be picked automatically. The value picked may change between TensorFlow 

5596 releases. 

5597 Valid values: 

5598 - DEVICE_PLACEMENT_EXPLICIT: raises an error if the 

5599 placement is not correct. 

5600 - DEVICE_PLACEMENT_WARN: copies the tensors which are not 

5601 on the right device but logs a warning. 

5602 - DEVICE_PLACEMENT_SILENT: silently copies the tensors. 

5603 Note that this may hide performance problems as there is no notification 

5604 provided when operations are blocked on the tensor being copied between 

5605 devices. 

5606 - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies 

5607 int32 tensors, raising errors on the other ones. 

5608 execution_mode: (Optional.) Policy controlling how dispatched operations are 

5609 actually executed. When set to None, an appropriate value will be picked 

5610 automatically. The value picked may change between TensorFlow releases. 

5611 Valid values: 

5612 - SYNC: executes each operation synchronously. 

5613 - ASYNC: executes each operation asynchronously. These 

5614 operations may return "non-ready" handles. 

5615 

5616 Raises: 

5617 ValueError: If eager execution is enabled after creating/executing a 

5618 TensorFlow graph, or if options provided conflict with a previous call 

5619 to this function. 

5620 """ 

5621 _api_usage_gauge.get_cell().set(True) 

5622 logging.vlog(1, "Enabling eager execution") 

5623 if context.default_execution_mode != context.EAGER_MODE: 

5624 return enable_eager_execution_internal( 

5625 config=config, 

5626 device_policy=device_policy, 

5627 execution_mode=execution_mode, 

5628 server_def=None) 

5629 

5630 

5631@tf_export(v1=["disable_eager_execution"]) 

5632def disable_eager_execution(): 

5633 """Disables eager execution. 

5634 

5635 This function can only be called before any Graphs, Ops, or Tensors have been 

5636 created. 

5637 

5638 @compatibility(TF2) 

5639 This function is not necessary if you are using TF2. Eager execution is 

5640 enabled by default. If you want to use Graph mode please consider 

5641 [tf.function](https://www.tensorflow.org/api_docs/python/tf/function). 

5642 @end_compatibility 

5643 """ 

5644 _api_usage_gauge.get_cell().set(False) 

5645 logging.vlog(1, "Disabling eager execution") 

5646 context.default_execution_mode = context.GRAPH_MODE 

5647 c = context.context_safe() 

5648 if c is not None: 

5649 c._thread_local_data.is_eager = False # pylint: disable=protected-access 

5650 

5651 

5652def enable_eager_execution_internal(config=None, 

5653 device_policy=None, 

5654 execution_mode=None, 

5655 server_def=None): 

5656 """Enables eager execution for the lifetime of this program. 

5657 

5658 Most of the doc string for enable_eager_execution is relevant here as well. 

5659 

5660 Args: 

5661 config: See enable_eager_execution doc string 

5662 device_policy: See enable_eager_execution doc string 

5663 execution_mode: See enable_eager_execution doc string 

5664 server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution on 

5665 remote devices. GrpcServers need to be started by creating an identical 

5666 server_def to this, and setting the appropriate task_indexes, so that the 

5667 servers can communicate. It will then be possible to execute operations on 

5668 remote devices. 

5669 

5670 Raises: 

5671 ValueError: If `device_policy` or `execution_mode` is invalid, if a graph 

5672 has already been created, or if the options conflict with an active eager context. 

5673 """ 

5674 if config is not None and not isinstance(config, config_pb2.ConfigProto): 

5675 raise TypeError("config must be a tf.ConfigProto, but got %s" % 

5676 type(config)) 

5677 if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT, 

5678 context.DEVICE_PLACEMENT_WARN, 

5679 context.DEVICE_PLACEMENT_SILENT, 

5680 context.DEVICE_PLACEMENT_SILENT_FOR_INT32): 

5681 raise ValueError("device_policy must be one of None, DEVICE_PLACEMENT_*") 

5682 if execution_mode not in (None, context.SYNC, context.ASYNC): 

5683 raise ValueError("execution_mode must be one of None, SYNC, " "ASYNC") 

5684 if context.default_execution_mode == context.GRAPH_MODE: 

5685 graph_mode_has_been_used = ( 

5686 _default_graph_stack._global_default_graph is not None) # pylint: disable=protected-access 

5687 if graph_mode_has_been_used: 

5688 raise ValueError( 

5689 "tf.enable_eager_execution must be called at program startup.") 

5690 context.default_execution_mode = context.EAGER_MODE 

5691 # pylint: disable=protected-access 

5692 with context._context_lock: 

5693 if context._context is None: 

5694 context._set_context_locked(context.Context( 

5695 config=config, 

5696 device_policy=device_policy, 

5697 execution_mode=execution_mode, 

5698 server_def=server_def)) 

5699 elif ((config is not None and config is not context._context._config) or 

5700 (device_policy is not None and 

5701 device_policy is not context._context._device_policy) or 

5702 (execution_mode is not None and 

5703 execution_mode is not context._context._execution_mode)): 

5704 raise ValueError( 

5705 "Trying to change the options of an active eager" 

5706 " execution. Context config: %s, specified config:" 

5707 " %s. Context device policy: %s, specified device" 

5708 " policy: %s. Context execution mode: %s, " 

5709 " specified execution mode %s." % 

5710 (context._context._config, config, context._context._device_policy, 

5711 device_policy, context._context._execution_mode, execution_mode)) 

5712 else: 

5713 # We already created everything, so update the thread local data. 

5714 context._context._thread_local_data.is_eager = True 

5715 

5716 # Monkey patch to get rid of an unnecessary conditional since the context is 

5717 # now initialized. 

5718 context.context = context.context_safe 

5719 

5720 

5721def eager_run(main=None, argv=None): 

5722 """Runs the program with an optional main function and argv list. 

5723 

5724 The program will run with eager execution enabled. 

5725 

5726 Example: 

5727 ```python 

5728 import tensorflow as tf 

5729 # Import subject to future changes: 

5730 from tensorflow.contrib.eager.python import tfe 

5731 def main(_): 

5732 u = tf.constant(6.0) 

5733 v = tf.constant(7.0) 

5734 print(u * v) 

5735 

5736 if __name__ == "__main__": 

5737 tfe.run() 

5738 ``` 

5739 

5740 Args: 

5741 main: the main function to run. 

5742 argv: the arguments to pass to it. 

5743 """ 

5744 enable_eager_execution() 

5745 app.run(main, argv) 

5746 

5747 

5748@tf_export(v1=["reset_default_graph"]) 

5749def reset_default_graph(): 

5750 """Clears the default graph stack and resets the global default graph. 

5751 

5752 NOTE: The default graph is a property of the current thread. This 

5753 function applies only to the current thread. Calling this function while 

5754 a `tf.compat.v1.Session` or `tf.compat.v1.InteractiveSession` is active will 

5755 result in undefined 

5756 behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects 

5757 after calling this function will result in undefined behavior. 

5758 

5759 @compatibility(TF2) 

5760 `reset_default_graph` does not work with either eager execution or 

5761 `tf.function`, and you should not invoke it directly. To migrate code that 

5762 uses Graph-related functions to TF2, rewrite the code without them. See the 

5763 [migration guide](https://www.tensorflow.org/guide/migrate) for more 

5764 detail about the behavioral and semantic changes between TensorFlow 1 and 

5765 TensorFlow 2. 

5766 @end_compatibility 

5767 

5768 Raises: 

5769 AssertionError: If this function is called within a nested graph. 

5770 """ 

5771 if not _default_graph_stack.is_cleared(): 

5772 raise AssertionError("Do not use tf.reset_default_graph() to clear " 

5773 "nested graphs. If you need a cleared graph, " 

5774 "exit the nesting and create a new graph.") 

5775 _default_graph_stack.reset() 
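
A minimal TF1-style sketch (illustrative, not from this module): resetting installs a fresh global default graph for the current thread.

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

a = tf.constant(1.0)                  # created in the current default graph
g_old = tf.get_default_graph()

tf.reset_default_graph()              # must not be called inside a nested graph
g_new = tf.get_default_graph()

assert g_new is not g_old             # a brand-new Graph object
# `a` still references g_old; per the docstring it should no longer be used.
```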

5776 

5777 

5778@tf_export(v1=["get_default_graph"]) 

5779def get_default_graph(): 

5780 """Returns the default graph for the current thread. 

5781 

5782 The returned graph will be the innermost graph on which a 

5783 `Graph.as_default()` context has been entered, or a global default 

5784 graph if none has been explicitly created. 

5785 

5786 NOTE: The default graph is a property of the current thread. If you 

5787 create a new thread, and wish to use the default graph in that 

5788 thread, you must explicitly add a `with g.as_default():` in that 

5789 thread's function. 

5790 

5791 @compatibility(TF2) 

5792 `get_default_graph` does not work with either eager execution or 

5793 `tf.function`, and you should not invoke it directly. To migrate code that 

5794 uses Graph-related functions to TF2, rewrite the code without them. See the 

5795 [migration guide](https://www.tensorflow.org/guide/migrate) for more 

5796 detail about the behavioral and semantic changes between TensorFlow 1 and 

5797 TensorFlow 2. 

5798 @end_compatibility 

5799 

5800 Returns: 

5801 The default `Graph` being used in the current thread. 

5802 """ 

5803 return _default_graph_stack.get_default() 
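
For example (an illustrative sketch), the innermost `as_default()` graph wins, and leaving the context restores the thread's global default:

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
  c = tf.constant(1.0)                           # a symbolic tensor in `g`
  assert tf.compat.v1.get_default_graph() is g
  assert c.graph is g

# Outside the context the thread-wide global default graph is returned again.
assert tf.compat.v1.get_default_graph() is not g
```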

5804 

5805 

5806def has_default_graph(): 

5807 """Returns True if there is a default graph.""" 

5808 return len(_default_graph_stack.stack) >= 1 

5809 

5810 

5811# Exported due to b/171079555 

5812@tf_export("__internal__.get_name_scope", v1=[]) 

5813def get_name_scope(): 

5814 """Returns the current name scope in the default_graph. 

5815 

5816 For example: 

5817 

5818 ```python 

5819 with tf.name_scope('scope1'): 

5820 with tf.name_scope('scope2'): 

5821 print(tf.get_name_scope()) 

5822 ``` 

5823 would print the string `scope1/scope2`. 

5824 

5825 Returns: 

5826 A string representing the current name scope. 

5827 """ 

5828 if context.executing_eagerly(): 

5829 return context.context().scope_name.rstrip("/") 

5830 return get_default_graph().get_name_scope() 

5831 

5832 

5833def _assert_same_graph(original_item, item): 

5834 """Fail if the 2 items are from different graphs. 

5835 

5836 Args: 

5837 original_item: Original item to check against. 

5838 item: Item to check. 

5839 

5840 Raises: 

5841 ValueError: if graphs do not match. 

5842 """ 

5843 original_graph = getattr(original_item, "graph", None) 

5844 graph = getattr(item, "graph", None) 

5845 if original_graph and graph and original_graph is not graph: 

5846 raise ValueError( 

5847 "%s must be from the same graph as %s (graphs are %s and %s)." % 

5848 (item, original_item, graph, original_graph)) 

5849 

5850 

5851def _get_graph_from_inputs(op_input_list, graph=None): 

5852 """Returns the appropriate graph to use for the given inputs. 

5853 

5854 This library method provides a consistent algorithm for choosing the graph 

5855 in which an Operation should be constructed: 

5856 

5857 1. If the default graph is being used to construct a function, we 

5858 use the default graph. 

5859 2. If the "graph" is specified explicitly, we validate that all of the inputs 

5860 in "op_input_list" are compatible with that graph. 

5861 3. Otherwise, we attempt to select a graph from the first Operation- 

5862 or Tensor-valued input in "op_input_list", and validate that all other 

5863 such inputs are in the same graph. 

5864 4. If the graph was not specified and it could not be inferred from 

5865 "op_input_list", we attempt to use the default graph. 

5866 

5867 Args: 

5868 op_input_list: A list of inputs to an operation, which may include `Tensor`, 

5869 `Operation`, and other objects that may be converted to a graph element. 

5870 graph: (Optional) The explicit graph to use. 

5871 

5872 Raises: 

5873 TypeError: If op_input_list is not a list or tuple, or if graph is not a 

5874 Graph. 

5875 ValueError: If a graph is explicitly passed and not all inputs are from it, 

5876 or if the inputs are from multiple graphs, or we could not find a graph 

5877 and there was no default graph. 

5878 

5879 Returns: 

5880 The appropriate graph to use for the given inputs. 

5881 

5882 """ 

5883 current_default_graph = get_default_graph() 

5884 if current_default_graph.building_function: 

5885 return current_default_graph 

5886 

5887 op_input_list = tuple(op_input_list) # Handle generators correctly 

5888 if graph and not isinstance(graph, Graph): 

5889 raise TypeError("Input graph needs to be a Graph: %s" % (graph,)) 

5890 

5891 # 1. We validate that all of the inputs are from the same graph. This is 

5892 # either the supplied graph parameter, or the first one selected from one 

5893 # either the supplied graph parameter, or the first one selected from one of 

5894 # that input in original_graph_element so we can provide a more 

5895 # informative error if a mismatch is found. 

5896 original_graph_element = None 

5897 for op_input in op_input_list: 

5898 # Determine if this is a valid graph_element. 

5899 # TODO(josh11b): Note that we exclude subclasses of Tensor. Need to clean this 

5900 # up. 

5901 graph_element = None 

5902 if isinstance(op_input, (Operation, internal.NativeObject)) and ( 

5903 (not isinstance(op_input, Tensor)) or type(op_input) == Tensor # pylint: disable=unidiomatic-typecheck 

5904 ): 

5905 graph_element = op_input 

5906 else: 

5907 graph_element = _as_graph_element(op_input) 

5908 

5909 if graph_element is not None: 

5910 if not graph: 

5911 original_graph_element = graph_element 

5912 graph = getattr(graph_element, "graph", None) 

5913 elif original_graph_element is not None: 

5914 _assert_same_graph(original_graph_element, graph_element) 

5915 elif graph_element.graph is not graph: 

5916 raise ValueError("%s is not from the passed-in graph." % graph_element) 

5917 

5918 # 2. If all else fails, we use the default graph, which is always there. 

5919 return graph or current_default_graph 
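
A short sketch (two throwaway graphs, hypothetical names) of how the inference and validation rules above surface to users:

```python
import tensorflow as tf

g1, g2 = tf.Graph(), tf.Graph()
with g1.as_default():
  a = tf.constant(1.0)
with g2.as_default():
  b = tf.constant(2.0)

with g1.as_default():
  c = a + tf.constant(3.0)     # graph inferred from `a` (rule 3): lives in g1
  assert c.graph is g1

  try:
    _ = a + b                  # inputs from different graphs
  except ValueError as e:
    print(e)                   # "... must be from the same graph as ..."
```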

5920 

5921 

5922@tf_export(v1=["GraphKeys"]) 

5923class GraphKeys(object): 

5924 """Standard names to use for graph collections. 

5925 

5926 The standard library uses various well-known names to collect and 

5927 retrieve values associated with a graph. For example, the 

5928 `tf.Optimizer` subclasses default to optimizing the variables 

5929 collected under `tf.GraphKeys.TRAINABLE_VARIABLES` if none is 

5930 specified, but it is also possible to pass an explicit list of 

5931 variables. 

5932 

5933 The following standard keys are defined: 

5934 

5935 * `GLOBAL_VARIABLES`: the default collection of `Variable` objects, shared 

5936 across a distributed environment (model variables are a subset of these). See 

5937 `tf.compat.v1.global_variables` 

5938 for more details. 

5939 Commonly, all `TRAINABLE_VARIABLES` variables will be in `MODEL_VARIABLES`, 

5940 and all `MODEL_VARIABLES` variables will be in `GLOBAL_VARIABLES`. 

5941 * `LOCAL_VARIABLES`: the subset of `Variable` objects that are local to each 

5942 machine. Usually used for temporary variables, like counters. 

5943 * `MODEL_VARIABLES`: the subset of `Variable` objects that are used in the 

5944 model for inference (feed forward). 

5945 * `TRAINABLE_VARIABLES`: the subset of `Variable` objects that will 

5946 be trained by an optimizer. See 

5947 `tf.compat.v1.trainable_variables` 

5948 for more details. 

5949 * `SUMMARIES`: the summary `Tensor` objects that have been created in the 

5950 graph. See 

5951 `tf.compat.v1.summary.merge_all` 

5952 for more details. 

5953 * `QUEUE_RUNNERS`: the `QueueRunner` objects that are used to 

5954 produce input for a computation. See 

5955 `tf.compat.v1.train.start_queue_runners` 

5956 for more details. 

5957 * `MOVING_AVERAGE_VARIABLES`: the subset of `Variable` objects that will also 

5958 keep moving averages. See 

5959 `tf.compat.v1.moving_average_variables` 

5960 for more details. 

5961 * `REGULARIZATION_LOSSES`: regularization losses collected during graph 

5962 construction. 

5963 

5964 The following standard keys are _defined_, but their collections are **not** 

5965 automatically populated as many of the others are: 

5966 

5967 * `WEIGHTS` 

5968 * `BIASES` 

5969 * `ACTIVATIONS` 

5970 """ 

5971 

5972 # Key to collect Variable objects that are global (shared across machines). 

5973 # Default collection for all variables, except local ones. 

5974 GLOBAL_VARIABLES = "variables" 

5975 # Key to collect local variables that are local to the machine and are not 

5976 # saved/restored. 

5977 LOCAL_VARIABLES = "local_variables" 

5978 # Key to collect local variables which are used to accumulate internal state 

5979 # to be used in tf.metrics.*. 

5980 METRIC_VARIABLES = "metric_variables" 

5981 # Key to collect model variables defined by layers. 

5982 MODEL_VARIABLES = "model_variables" 

5983 # Key to collect Variable objects that will be trained by the 

5984 # optimizers. 

5985 TRAINABLE_VARIABLES = "trainable_variables" 

5986 # Key to collect summaries. 

5987 SUMMARIES = "summaries" 

5988 # Key to collect QueueRunners. 

5989 QUEUE_RUNNERS = "queue_runners" 

5990 # Key to collect table initializers. 

5991 TABLE_INITIALIZERS = "table_initializer" 

5992 # Key to collect asset filepaths. An asset represents an external resource 

5993 # like a vocabulary file. 

5994 ASSET_FILEPATHS = "asset_filepaths" 

5995 # Key to collect Variable objects that keep moving averages. 

5996 MOVING_AVERAGE_VARIABLES = "moving_average_variables" 

5997 # Key to collect regularization losses at graph construction. 

5998 REGULARIZATION_LOSSES = "regularization_losses" 

5999 # Key to collect concatenated sharded variables. 

6000 CONCATENATED_VARIABLES = "concatenated_variables" 

6001 # Key to collect savers. 

6002 SAVERS = "savers" 

6003 # Key to collect weights 

6004 WEIGHTS = "weights" 

6005 # Key to collect biases 

6006 BIASES = "biases" 

6007 # Key to collect activations 

6008 ACTIVATIONS = "activations" 

6009 # Key to collect update_ops 

6010 UPDATE_OPS = "update_ops" 

6011 # Key to collect losses 

6012 LOSSES = "losses" 

6013 # Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. 

6014 SAVEABLE_OBJECTS = "saveable_objects" 

6015 # Key to collect all shared resources used by the graph which need to be 

6016 # initialized once per cluster. 

6017 RESOURCES = "resources" 

6018 # Key to collect all shared resources used in this graph which need to be 

6019 # initialized once per session. 

6020 LOCAL_RESOURCES = "local_resources" 

6021 # Trainable resource-style variables. 

6022 TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables" 

6023 

6024 # Key to indicate various ops. 

6025 INIT_OP = "init_op" 

6026 LOCAL_INIT_OP = "local_init_op" 

6027 READY_OP = "ready_op" 

6028 READY_FOR_LOCAL_INIT_OP = "ready_for_local_init_op" 

6029 SUMMARY_OP = "summary_op" 

6030 GLOBAL_STEP = "global_step" 

6031 

6032 # Used to count the number of evaluations performed during a single evaluation 

6033 # run. 

6034 EVAL_STEP = "eval_step" 

6035 TRAIN_OP = "train_op" 

6036 

6037 # Key for control flow context. 

6038 COND_CONTEXT = "cond_context" 

6039 WHILE_CONTEXT = "while_context" 

6040 

6041 # Used to store v2 summary names. 

6042 _SUMMARY_COLLECTION = "_SUMMARY_V2" 

6043 

6044 # List of all collections that keep track of variables. 

6045 _VARIABLE_COLLECTIONS = [ 

6046 GLOBAL_VARIABLES, 

6047 LOCAL_VARIABLES, 

6048 METRIC_VARIABLES, 

6049 MODEL_VARIABLES, 

6050 TRAINABLE_VARIABLES, 

6051 MOVING_AVERAGE_VARIABLES, 

6052 CONCATENATED_VARIABLES, 

6053 TRAINABLE_RESOURCE_VARIABLES, 

6054 ] 

6055 

6056 # Key for streaming model ports. 

6057 # NOTE(yuanbyu): internal and experimental. 

6058 _STREAMING_MODEL_PORTS = "streaming_model_ports" 

6059 

6060 @decorator_utils.classproperty 

6061 @deprecation.deprecated(None, "Use `tf.GraphKeys.GLOBAL_VARIABLES` instead.") 

6062 def VARIABLES(cls): # pylint: disable=no-self-argument 

6063 return cls.GLOBAL_VARIABLES 
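
An illustrative TF1-style snippet showing how these keys are typically consumed; `tf.compat.v1.get_variable` registers the variable in the standard collections automatically:

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

with tf.variable_scope("model"):
  w = tf.get_variable("w", shape=[2, 2])   # trainable by default

trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assert any(v is w for v in trainable)
assert any(v is w for v in global_vars)
```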

6064 

6065 

6066def dismantle_graph(graph): 

6067 """Cleans up reference cycles from a `Graph`. 

6068 

6069 Helpful for making sure the garbage collector doesn't need to run after a 

6070 temporary `Graph` is no longer needed. 

6071 

6072 Args: 

6073 graph: A `Graph` object to destroy. Neither it nor any of its ops are usable 

6074 after this function runs. 

6075 """ 

6076 graph._functions.clear() # pylint: disable=protected-access 

6077 graph.Dismantle() 

6078 

6079 

6080@tf_export(v1=["add_to_collection"]) 

6081def add_to_collection(name, value): 

6082 """Wrapper for `Graph.add_to_collection()` using the default graph. 

6083 

6084 See `tf.Graph.add_to_collection` 

6085 for more details. 

6086 

6087 Args: 

6088 name: The key for the collection. For example, the `GraphKeys` class 

6089 contains many standard names for collections. 

6090 value: The value to add to the collection. 

6091 

6092 @compatibility(eager) 

6093 Collections are only supported in eager when variables are created inside 

6094 an EagerVariableStore (e.g. as part of a layer or template). 

6095 @end_compatibility 

6096 """ 

6097 get_default_graph().add_to_collection(name, value) 

6098 

6099 

6100@tf_export(v1=["add_to_collections"]) 

6101def add_to_collections(names, value): 

6102 """Wrapper for `Graph.add_to_collections()` using the default graph. 

6103 

6104 See `tf.Graph.add_to_collections` 

6105 for more details. 

6106 

6107 Args: 

6108 names: The key for the collections. The `GraphKeys` class contains many 

6109 standard names for collections. 

6110 value: The value to add to the collections. 

6111 

6112 @compatibility(eager) 

6113 Collections are only supported in eager when variables are created inside 

6114 an EagerVariableStore (e.g. as part of a layer or template). 

6115 @end_compatibility 

6116 """ 

6117 get_default_graph().add_to_collections(names, value) 

6118 

6119 

6120@tf_export(v1=["get_collection_ref"]) 

6121def get_collection_ref(key): 

6122 """Wrapper for `Graph.get_collection_ref()` using the default graph. 

6123 

6124 See `tf.Graph.get_collection_ref` 

6125 for more details. 

6126 

6127 Args: 

6128 key: The key for the collection. For example, the `GraphKeys` class contains 

6129 many standard names for collections. 

6130 

6131 Returns: 

6132 The list of values in the collection with the given `key`, or an empty 

6133 list if no value has been added to that collection. Note that this returns 

6134 the collection list itself, which can be modified in place to change the 

6135 collection. 

6136 

6137 @compatibility(eager) 

6138 Collections are not supported when eager execution is enabled. 

6139 @end_compatibility 

6140 """ 

6141 return get_default_graph().get_collection_ref(key) 

6142 

6143 

6144@tf_export(v1=["get_collection"]) 

6145def get_collection(key, scope=None): 

6146 """Wrapper for `Graph.get_collection()` using the default graph. 

6147 

6148 See `tf.Graph.get_collection` 

6149 for more details. 

6150 

6151 Args: 

6152 key: The key for the collection. For example, the `GraphKeys` class contains 

6153 many standard names for collections. 

6154 scope: (Optional.) If supplied, the resulting list is filtered to include 

6155 only items whose `name` attribute matches using `re.match`. Items without 

6156 a `name` attribute are never returned if a scope is supplied and the 

6157 choice of `re.match` means that a `scope` without special tokens filters 

6158 by prefix. 

6159 

6160 Returns: 

6161 The list of values in the collection with the given `key`, or 

6162 an empty list if no value has been added to that collection. The 

6163 list contains the values in the order under which they were 

6164 collected. 

6165 

6166 @compatibility(eager) 

6167 Collections are not supported when eager execution is enabled. 

6168 @end_compatibility 

6169 """ 

6170 return get_default_graph().get_collection(key, scope) 
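
To see the three collection wrappers together, here is a hedged sketch using a hypothetical collection name `"my_losses"` (any string key works):

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

with tf.name_scope("tower_0"):
  loss_a = tf.constant(1.0, name="loss")        # name: "tower_0/loss:0"
tf.add_to_collection("my_losses", loss_a)

loss_b = tf.constant(2.0, name="other_loss")
tf.add_to_collection("my_losses", loss_b)

# get_collection returns a copy, optionally filtered by a scope prefix.
assert tf.get_collection("my_losses") == [loss_a, loss_b]
assert tf.get_collection("my_losses", scope="tower_0") == [loss_a]

# get_collection_ref returns the underlying list, so in-place edits stick.
tf.get_collection_ref("my_losses").pop()        # drop loss_b in place
assert tf.get_collection("my_losses") == [loss_a]
```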

6171 

6172 

6173def get_all_collection_keys(): 

6174 """Returns a list of collections used in the default graph.""" 

6175 return get_default_graph().get_all_collection_keys() 

6176 

6177 

6178def name_scope(name, default_name=None, values=None, skip_on_eager=True): 

6179 """Internal-only entry point for `name_scope*`. 

6180 

6181 Internal ops do not use the public API and instead rely on 

6182 `ops.name_scope` regardless of the execution mode. This function 

6183 dispatches to the correct `name_scope*` implementation based on 

6184 the arguments provided and the current mode. Specifically, 

6185 

6186 * if `values` contains a graph tensor, `Graph.name_scope` is used; 

6187 * `name_scope_v1` is used in graph mode; 

6188 * `name_scope_v2` -- in eager mode. 

6189 

6190 Args: 

6191 name: The name argument that is passed to the op function. 

6192 default_name: The default name to use if the `name` argument is `None`. 

6193 values: The list of `Tensor` arguments that are passed to the op function. 

6194 skip_on_eager: Indicates to return NullContextmanager if executing eagerly. 

6195 By default this is True since naming tensors and operations in eager mode 

6196 has little use and causes unnecessary performance overhead. However, it is 

6197 important to preserve variable names since they are often useful for 

6198 debugging and saved models. 

6199 

6200 Returns: 

6201 `name_scope*` context manager. 

6202 """ 

6203 if not context.executing_eagerly(): 

6204 return internal_name_scope_v1(name, default_name, values) 

6205 

6206 if skip_on_eager: 

6207 return NullContextmanager() 

6208 

6209 name = default_name if name is None else name 

6210 if values: 

6211 # The presence of a graph tensor in `values` overrides the context. 

6212 # TODO(slebedev): this is Keras-specific and should be removed. 

6213 # pylint: disable=unidiomatic-typecheck 

6214 graph_value = next((value for value in values if type(value) == Tensor), 

6215 None) 

6216 # pylint: enable=unidiomatic-typecheck 

6217 if graph_value is not None: 

6218 return graph_value.graph.name_scope(name) 

6219 

6220 return name_scope_v2(name or "") 

6221 

6222 

6223class internal_name_scope_v1(object): # pylint: disable=invalid-name 

6224 """Graph-only version of `name_scope_v1`.""" 

6225 

6226 @property 

6227 def name(self): 

6228 return self._name 

6229 

6230 def __init__(self, name, default_name=None, values=None): 

6231 """Initialize the context manager. 

6232 

6233 Args: 

6234 name: The name argument that is passed to the op function. 

6235 default_name: The default name to use if the `name` argument is `None`. 

6236 values: The list of `Tensor` arguments that are passed to the op function. 

6237 

6238 Raises: 

6239 TypeError: if `default_name` is passed in but not a string. 

6240 """ 

6241 if not (default_name is None or isinstance(default_name, str)): 

6242 raise TypeError( 

6243 "`default_name` type (%s) is not a string type. You likely meant to " 

6244 "pass this into the `values` kwarg." % type(default_name)) 

6245 self._name = default_name if name is None else name 

6246 self._default_name = default_name 

6247 self._values = values 

6248 

6249 def __enter__(self): 

6250 """Start the scope block. 

6251 

6252 Returns: 

6253 The scope name. 

6254 

6255 Raises: 

6256 ValueError: if neither `name` nor `default_name` is provided 

6257 but `values` are. 

6258 """ 

6259 if self._name is None and self._values is not None: 

6260 # We only raise an error if values is not None (provided) because 

6261 # currently tf.name_scope(None) (with values=None) is sometimes used as 

6262 # an idiom to reset to top scope. 

6263 raise ValueError( 

6264 "At least one of name (%s) and default_name (%s) must be provided." 

6265 % (self._name, self._default_name)) 

6266 

6267 g = get_default_graph() 

6268 if self._values and not g.building_function: 

6269 # Specialize based on the knowledge that `_get_graph_from_inputs()` 

6270 # ignores `inputs` when building a function. 

6271 g_from_inputs = _get_graph_from_inputs(self._values) 

6272 if g_from_inputs is not g: 

6273 g = g_from_inputs 

6274 self._g_manager = g.as_default() 

6275 self._g_manager.__enter__() 

6276 else: 

6277 self._g_manager = None 

6278 else: 

6279 self._g_manager = None 

6280 

6281 try: 

6282 self._name_scope = g.name_scope(self._name) 

6283 return self._name_scope.__enter__() 

6284 except: 

6285 if self._g_manager is not None: 

6286 self._g_manager.__exit__(*sys.exc_info()) 

6287 raise 

6288 

6289 def __exit__(self, *exc_info): 

6290 self._name_scope.__exit__(*exc_info) 

6291 if self._g_manager is not None: 

6292 self._g_manager.__exit__(*exc_info) 

6293 

6294 

6295# Named like a function for backwards compatibility with the 

6296# @tf_contextlib.contextmanager version, which was switched to a class to avoid 

6297# some object creation overhead. 

6298@tf_export(v1=["name_scope"]) 

6299class name_scope_v1(object): # pylint: disable=invalid-name 

6300 """A context manager for use when defining a Python op. 

6301 

6302 This context manager validates that the given `values` are from the 

6303 same graph, makes that graph the default graph, and pushes a 

6304 name scope in that graph (see 

6305 `tf.Graph.name_scope` 

6306 for more details on that). 

6307 

6308 For example, to define a new Python op called `my_op`: 

6309 

6310 ```python 

6311 def my_op(a, b, c, name=None): 

6312 with tf.name_scope(name, "MyOp", [a, b, c]) as scope: 

6313 a = tf.convert_to_tensor(a, name="a") 

6314 b = tf.convert_to_tensor(b, name="b") 

6315 c = tf.convert_to_tensor(c, name="c") 

6316 # Define some computation that uses `a`, `b`, and `c`. 

6317 return foo_op(..., name=scope) 

6318 ``` 

6319 """ 

6320 

6321 __slots__ = ["_name", "_name_scope"] 

6322 

6323 @property 

6324 def name(self): 

6325 return self._name 

6326 

6327 def __init__(self, name, default_name=None, values=None): 

6328 """Initialize the context manager. 

6329 

6330 Args: 

6331 name: The name argument that is passed to the op function. 

6332 default_name: The default name to use if the `name` argument is `None`. 

6333 values: The list of `Tensor` arguments that are passed to the op function. 

6334 

6335 Raises: 

6336 TypeError: if `default_name` is passed in but not a string. 

6337 """ 

6338 self._name_scope = name_scope( 

6339 name, default_name, values, skip_on_eager=False) 

6340 self._name = default_name if name is None else name 

6341 

6342 def __enter__(self): 

6343 return self._name_scope.__enter__() 

6344 

6345 def __exit__(self, *exc_info): 

6346 return self._name_scope.__exit__(*exc_info) 

6347 

6348 

6349@tf_export("get_current_name_scope", v1=[]) 

6350def get_current_name_scope(): 

6351 """Returns current full name scope specified by `tf.name_scope(...)`s. 

6352 

6353 For example, 

6354 ```python 

6355 with tf.name_scope("outer"): 

6356 tf.get_current_name_scope() # "outer" 

6357 

6358 with tf.name_scope("inner"): 

6359 tf.get_current_name_scope() # "outer/inner" 

6360 ``` 

6361 

6362 In other words, `tf.get_current_name_scope()` returns the name prefix that 

6363 will be prepended to an op's name if an op is created at that point. 

6364 

6365 Note that `@tf.function` resets the name scope stack as shown below. 

6366 

6367 ``` 

6368 with tf.name_scope("outer"): 

6369 

6370 @tf.function 

6371 def foo(x): 

6372 with tf.name_scope("inner"): 

6373 return tf.add(x, x) # Op name is "inner/Add", not "outer/inner/Add" 

6374 ``` 

6375 """ 

6376 

6377 ctx = context.context() 

6378 if ctx.executing_eagerly(): 

6379 return ctx.scope_name.rstrip("/") 

6380 else: 

6381 return get_default_graph().get_name_scope() 

6382 

6383 

6384@tf_export("name_scope", v1=[]) 

6385class name_scope_v2(object): 

6386 """A context manager for use when defining a Python op. 

6387 

6388 This context manager pushes a name scope, which will make the name of all 

6389 operations added within it have a prefix. 

6390 

6391 For example, to define a new Python op called `my_op`: 

6392 

6393 ```python 

6394 def my_op(a, b, c, name=None): 

6395 with tf.name_scope("MyOp") as scope: 

6396 a = tf.convert_to_tensor(a, name="a") 

6397 b = tf.convert_to_tensor(b, name="b") 

6398 c = tf.convert_to_tensor(c, name="c") 

6399 # Define some computation that uses `a`, `b`, and `c`. 

6400 return foo_op(..., name=scope) 

6401 ``` 

6402 

6403 When executed, the Tensors `a`, `b`, `c`, will have names `MyOp/a`, `MyOp/b`, 

6404 and `MyOp/c`. 

6405 

6406 Inside a `tf.function`, if the scope name already exists, the name will be 

6407 made unique by appending `_n`. For example, calling `my_op` the second time 

6408 will generate `MyOp_1/a`, etc. 

6409 """ 

6410 

6411 __slots__ = ["_name", "_exit_fns"] 

6412 

6413 def __init__(self, name): 

6414 """Initialize the context manager. 

6415 

6416 Args: 

6417 name: The prefix to use on all names created within the name scope. 

6418 

6419 Raises: 

6420 ValueError: If name is not a string. 

6421 """ 

6422 if not isinstance(name, str): 

6423 raise ValueError("name for name_scope must be a string.") 

6424 self._name = name 

6425 self._exit_fns = [] 

6426 

6427 @property 

6428 def name(self): 

6429 return self._name 

6430 

6431 def __enter__(self): 

6432 """Start the scope block. 

6433 

6434 Returns: 

6435 The scope name. 

6436 """ 

6437 ctx = context.context() 

6438 if ctx.executing_eagerly(): 

6439 # Names are not auto-incremented in eager mode. 

6440 # A trailing slash breaks out of nested name scopes, indicating a 

6441 # fully specified scope name, for compatibility with Graph.name_scope. 

6442 # This also prevents auto-incrementing. 

6443 old_name = ctx.scope_name 

6444 name = self._name 

6445 if not name: 

6446 scope_name = "" 

6447 elif name[-1] == "/": 

6448 scope_name = name 

6449 elif old_name: 

6450 scope_name = old_name + name + "/" 

6451 else: 

6452 scope_name = name + "/" 

6453 ctx.scope_name = scope_name 

6454 

6455 def _restore_name_scope(*_): 

6456 ctx.scope_name = old_name 

6457 

6458 self._exit_fns.append(_restore_name_scope) 

6459 else: 

6460 scope = get_default_graph().name_scope(self._name) 

6461 scope_name = scope.__enter__() 

6462 self._exit_fns.append(scope.__exit__) 

6463 return scope_name 

6464 

6465 def __exit__(self, type_arg, value_arg, traceback_arg): 

6466 self._exit_fns.pop()(type_arg, value_arg, traceback_arg) 

6467 return False # False values do not suppress exceptions 

6468 

6469 def __getstate__(self): 

6470 return self._name, self._exit_fns 

6471 

6472 def __setstate__(self, state): 

6473 self._name = state[0] 

6474 self._exit_fns = state[1] 
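
The trailing-slash handling in `__enter__` means a captured scope string can be re-entered verbatim instead of being nested or uniquified. A small sketch inside a `tf.function`:

```python
import tensorflow as tf

@tf.function
def f(x):
  with tf.name_scope("outer") as outer:   # outer == "outer/"
    y = x + 1                             # op names start with "outer/"
  with tf.name_scope(outer):              # trailing "/" re-enters that scope
    z = y * 2                             # also named "outer/..."
  return z

f(tf.constant(1))
```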

6475 

6476 

6477def strip_name_scope(name, export_scope): 

6478 """Removes name scope from a name. 

6479 

6480 Args: 

6481 name: A `string` name. 

6482 export_scope: Optional `string`. Name scope to remove. 

6483 

6484 Returns: 

6485 Name with name scope removed, or the original name if export_scope 

6486 is None. 

6487 """ 

6488 if export_scope: 

6489 if export_scope[-1] == "/": 

6490 export_scope = export_scope[:-1] 

6491 

6492 try: 

6493 # Strips export_scope/, export_scope///, 

6494 # ^export_scope/, loc:@export_scope/. 

6495 str_to_replace = r"([\^]|loc:@|^)" + export_scope + r"[\/]+(.*)" 

6496 return re.sub(str_to_replace, r"\1\2", compat.as_str(name), count=1) 

6497 except TypeError as e: 

6498 # If the name is not of a type we can process, simply return it. 

6499 logging.warning(e) 

6500 return name 

6501 else: 

6502 return name 

6503 

6504 

6505def prepend_name_scope(name, import_scope): 

6506 """Prepends name scope to a name. 

6507 

6508 Args: 

6509 name: A `string` name. 

6510 import_scope: Optional `string`. Name scope to add. 

6511 

6512 Returns: 

6513 Name with name scope added, or the original name if import_scope 

6514 is None. 

6515 """ 

6516 if import_scope: 

6517 if import_scope[-1] == "/": 

6518 import_scope = import_scope[:-1] 

6519 

6520 try: 

6521 str_to_replace = r"([\^]|loc:@|^)(.*)" 

6522 return re.sub(str_to_replace, r"\1" + import_scope + r"/\2", 

6523 compat.as_str(name)) 

6524 except TypeError as e: 

6525 # If the name is not of a type we can process, simply return it. 

6526 logging.warning(e) 

6527 return name 

6528 else: 

6529 return name 
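
Worked examples for the two helpers above (assumed inputs; these are internal functions used when exporting/importing meta graphs):

```python
from tensorflow.python.framework import ops

# strip_name_scope also handles the "^" control-input and "loc:@" colocation
# prefixes; prepend_name_scope is its inverse.
assert ops.strip_name_scope("export/weights:0", "export") == "weights:0"
assert ops.strip_name_scope("^export/init_op", "export/") == "^init_op"
assert ops.strip_name_scope("other/w:0", "export") == "other/w:0"   # no match

assert ops.prepend_name_scope("weights:0", "import") == "import/weights:0"
assert ops.prepend_name_scope("loc:@init_op", "import") == "loc:@import/init_op"
```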

6530 

6531 

6532# pylint: disable=g-doc-return-or-yield 

6533# pylint: disable=not-context-manager 

6534@tf_export(v1=["op_scope"]) 

6535@tf_contextlib.contextmanager 

6536def op_scope(values, name, default_name=None): 

6537 """DEPRECATED. Same as name_scope above, just different argument order.""" 

6538 logging.warn("tf.op_scope(values, name, default_name) is deprecated," 

6539 " use tf.name_scope(name, default_name, values)") 

6540 with name_scope(name, default_name=default_name, values=values) as scope: 

6541 yield scope 

6542 

6543 

6544_proto_function_registry = registry.Registry("proto functions") 

6545 

6546 

6547def register_proto_function(collection_name, 

6548 proto_type=None, 

6549 to_proto=None, 

6550 from_proto=None): 

6551 """Registers `to_proto` and `from_proto` functions for collection_name. 

6552 

6553 `to_proto` function converts a Python object to the corresponding protocol 

6554 buffer, and returns the protocol buffer. 

6555 

6556 `from_proto` function converts a protocol buffer into a Python object, and 

6557 returns the object. 

6558 

6559 Args: 

6560 collection_name: Name of the collection. 

6561 proto_type: Protobuf type, such as `saver_pb2.SaverDef`, 

6562 `variable_pb2.VariableDef`, `queue_runner_pb2.QueueRunnerDef`. 

6563 to_proto: Function that implements Python object to protobuf conversion. 

6564 from_proto: Function that implements protobuf to Python object conversion. 

6565 """ 

6566 if to_proto and not callable(to_proto): 

6567 raise TypeError("to_proto must be callable.") 

6568 if from_proto and not callable(from_proto): 

6569 raise TypeError("from_proto must be callable.") 

6570 

6571 _proto_function_registry.register((proto_type, to_proto, from_proto), 

6572 collection_name) 

6573 

6574 

6575def get_collection_proto_type(collection_name): 

6576 """Returns the proto_type for collection_name.""" 

6577 try: 

6578 return _proto_function_registry.lookup(collection_name)[0] 

6579 except LookupError: 

6580 return None 

6581 

6582 

6583def get_to_proto_function(collection_name): 

6584 """Returns the to_proto function for collection_name.""" 

6585 try: 

6586 return _proto_function_registry.lookup(collection_name)[1] 

6587 except LookupError: 

6588 return None 

6589 

6590 

6591def get_from_proto_function(collection_name): 

6592 """Returns the from_proto function for collection_name.""" 

6593 try: 

6594 return _proto_function_registry.lookup(collection_name)[2] 

6595 except LookupError: 

6596 return None 
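
A hedged sketch of the registration round-trip, using a hypothetical collection name `"my_items"` and stub converters:

```python
from tensorflow.python.framework import ops

def to_proto(obj, export_scope=None):      # stub converter for illustration
  return None

def from_proto(proto, import_scope=None):  # stub converter for illustration
  return None

ops.register_proto_function(
    "my_items", proto_type=None, to_proto=to_proto, from_proto=from_proto)

assert ops.get_to_proto_function("my_items") is to_proto
assert ops.get_from_proto_function("my_items") is from_proto
assert ops.get_collection_proto_type("unregistered") is None  # lookup miss
```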

6597 

6598 

6599def _op_to_colocate_with(v, graph): 

6600 """Operation object corresponding to v to use for colocation constraints.""" 

6601 if v is None: 

6602 return None, None 

6603 if isinstance(v, Operation): 

6604 return v, None 

6605 

6606 # We always want to colocate with the reference op. 

6607 # When 'v' is a ResourceVariable, the reference op is the handle creating op. 

6608 # 

6609 # What this should be is: 

6610 # if isinstance(v, ResourceVariable): 

6611 # return v.handle.op, v 

6612 # However, that would require a circular import dependency. 

6613 # As of October 2018, there were attempts underway to remove 

6614 # colocation constraints altogether. Assuming that will 

6615 # happen soon, perhaps this hack to work around the circular 

6616 # import dependency is acceptable. 

6617 if hasattr(v, "handle") and isinstance(v.handle, Tensor): 

6618 device_only_candidate = lambda: None 

6619 device_only_candidate.device = v.device 

6620 device_only_candidate.name = v.name 

6621 if graph.building_function: 

6622 return graph.capture(v.handle).op, device_only_candidate 

6623 else: 

6624 return v.handle.op, device_only_candidate 

6625 if isinstance(v, EagerTensor) and not context.executing_eagerly(): 

6626 return convert_to_tensor(v, as_ref=True).op, None 

6627 elif isinstance(v, internal.NativeObject): 

6628 return v.op, None 

6629 else: 

6630 return convert_to_tensor(v, as_ref=True).op, None 

6631 

6632 

6633# Helper functions for op wrapper modules generated by `python_op_gen`. 

6634 

6635 

6636def to_raw_op(f): 

6637 """Make a given op wrapper function `f` raw. 

6638 

6639 Raw op wrappers can only be called with keyword arguments. 

6640 

6641 Args: 

6642 f: An op wrapper function to make raw. 

6643 

6644 Returns: 

6645 Raw `f`. 

6646 """ 

6647 # Copy `f` to get a new `__dict__`, otherwise `tf_export` will fail 

6648 # due to double-registration. 

6649 f = types.FunctionType(f.__code__, f.__globals__, f.__name__, f.__defaults__, 

6650 f.__closure__) 

6651 return kwarg_only(f) 
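
As an illustration (hypothetical wrapper `my_op`), the returned raw op only accepts keyword arguments; passing anything positionally is expected to raise a `TypeError`:

```python
from tensorflow.python.framework import ops

def my_op(x, name=None):
  return x

raw_my_op = ops.to_raw_op(my_op)

raw_my_op(x=1.0)        # OK: keyword arguments only
try:
  raw_my_op(1.0)        # positional call is rejected by the kwarg-only wrapper
except TypeError as e:
  print(e)
```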

6652 

6653 

6654def raise_from_not_ok_status(e, name): 

6655 e.message += (" name: " + str(name if name is not None else "")) 

6656 raise core._status_to_exception(e) from None # pylint: disable=protected-access 

6657 

6658 

6659def add_exit_callback_to_default_func_graph(fn): 

6660 """Add a callback to run when the default function graph goes out of scope. 

6661 

6662 Usage: 

6663 

6664 ```python 

6665 @tf.function 

6666 def fn(x, v): 

6667 expensive = expensive_object(v) 

6668 add_exit_callback_to_default_func_graph(lambda: expensive.release()) 

6669 return g(x, expensive) 

6670 

6671 fn(x=tf.constant(...), v=...) 

6672 # `expensive` has been released. 

6673 ``` 

6674 

6675 Args: 

6676 fn: A callable that takes no arguments and whose output is ignored. 

6677 To be executed when exiting func graph scope. 

6678 

6679 Raises: 

6680 RuntimeError: If executed when the current default graph is not a FuncGraph, 

6681 or not currently executing in function creation mode (e.g., if inside 

6682 an init_scope). 

6683 """ 

6684 default_graph = get_default_graph() 

6685 if not default_graph._building_function: # pylint: disable=protected-access 

6686 raise RuntimeError( 

6687 "Cannot add scope exit callbacks when not building a function. " 

6688 "Default graph: {}".format(default_graph)) 

6689 default_graph._add_scope_exit_callback(fn) # pylint: disable=protected-access 

6690 

6691 

6692def _reconstruct_sequence_inputs(op_def, inputs, attrs): 

6693 """Regroups a flat list of input tensors into scalar and sequence inputs. 

6694 

6695 Args: 

6696 op_def: The `op_def_pb2.OpDef` (for knowing the input types) 

6697 inputs: a list of input `Tensor`s to the op. 

6698 attrs: mapping from attr name to `attr_value_pb2.AttrValue` (these define 

6699 how long each sequence is) 

6700 

6701 Returns: 

6702 A list of `Tensor`s (corresponding to scalar inputs) and lists of 

6703 `Tensor`s (corresponding to sequence inputs). 

6704 """ 

6705 grouped_inputs = [] 

6706 i = 0 

6707 for input_arg in op_def.input_arg: 

6708 if input_arg.number_attr: 

6709 input_len = attrs[input_arg.number_attr].i 

6710 is_sequence = True 

6711 elif input_arg.type_list_attr: 

6712 input_len = len(attrs[input_arg.type_list_attr].list.type) 

6713 is_sequence = True 

6714 else: 

6715 input_len = 1 

6716 is_sequence = False 

6717 

6718 if is_sequence: 

6719 grouped_inputs.append(inputs[i:i + input_len]) 

6720 else: 

6721 grouped_inputs.append(inputs[i]) 

6722 i += input_len 

6723 

6724 assert i == len(inputs) 

6725 return grouped_inputs 
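
For instance, with a hypothetical `OpDef` ("FakeOp") that has one scalar input and one length-`N` sequence input, the flat input list is regrouped as follows:

```python
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.framework import ops

op_def = op_def_pb2.OpDef(name="FakeOp")
op_def.input_arg.add(name="x")                          # scalar input
op_def.input_arg.add(name="values", number_attr="N")    # sequence of length N

attrs = {"N": attr_value_pb2.AttrValue(i=2)}
flat_inputs = ["t0", "t1", "t2"]                        # stand-ins for Tensors

grouped = ops._reconstruct_sequence_inputs(op_def, flat_inputs, attrs)
assert grouped == ["t0", ["t1", "t2"]]
```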

6726 

6727 

6728_numpy_style_type_promotion = False 

6729 

6730 

6731def enable_numpy_style_type_promotion(): 

6732 """If called, follows NumPy's rules for type promotion. 

6733 

6734 Used for enabling NumPy behavior on methods for TF NumPy. 

6735 """ 

6736 global _numpy_style_type_promotion 

6737 _numpy_style_type_promotion = True 

6738 

6739 

6740_numpy_style_slicing = False 

6741 

6742 

6743def enable_numpy_style_slicing(): 

6744 """If called, follows NumPy's rules for slicing Tensors. 

6745 

6746 Used for enabling NumPy behavior on slicing for TF NumPy. 

6747 """ 

6748 global _numpy_style_slicing 

6749 _numpy_style_slicing = True 

6750 

6751 

6752class _TensorIterator(object): 

6753 """Iterates over the leading dim of a Tensor. Performs no error checks.""" 

6754 

6755 __slots__ = ["_tensor", "_index", "_limit"] 

6756 

6757 def __init__(self, tensor, dim0): 

6758 self._tensor = tensor 

6759 self._index = 0 

6760 self._limit = dim0 

6761 

6762 def __iter__(self): 

6763 return self 

6764 

6765 def __next__(self): 

6766 if self._index == self._limit: 

6767 raise StopIteration 

6768 result = self._tensor[self._index] 

6769 self._index += 1 

6770 return result 

6771 

6772 next = __next__ # python2.x compatibility. 
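
A tiny sketch of the iterator in eager mode (internal helper; the caller supplies `dim0`, and no bounds checking is done beyond it):

```python
import tensorflow as tf
from tensorflow.python.framework import ops

t = tf.constant([[1, 2], [3, 4], [5, 6]])
for row in ops._TensorIterator(t, 3):
  print(row.numpy())    # [1 2], then [3 4], then [5 6]
```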

6773 

6774 

6775def set_int_list_attr(op, attr_name, ints): 

6776 """TF internal method used to set a list(int) attribute in the node_def.""" 

6777 ints_list = attr_value_pb2.AttrValue.ListValue(i=ints) 

6778 op._set_attr(attr_name, attr_value_pb2.AttrValue(list=ints_list)) # pylint:disable=protected-access 

6779 

6780 

6781def _get_enclosing_context(graph): 

6782 # pylint: disable=protected-access 

6783 if graph is None: 

6784 return None 

6785 

6786 if graph._control_flow_context is not None: 

6787 return graph._control_flow_context 

6788 

6789 if graph.building_function and hasattr(graph, "outer_graph"): 

6790 return _get_enclosing_context(graph.outer_graph) 

6791 

6792 

6793# TODO(b/271463878): Remove in favor of direct references to `handle_data_util`. 

6794get_resource_handle_data = handle_data_util.get_resource_handle_data 

6795 

6796 

6797def _copy_handle_data_to_arg_def(tensor, arg_def): 

6798 handle_data = handle_data_util.get_resource_handle_data(tensor) 

6799 if handle_data.shape_and_type: 

6800 shape_and_type = handle_data.shape_and_type[0] 

6801 proto = arg_def.handle_data.add() 

6802 proto.dtype = shape_and_type.dtype 

6803 proto.shape.CopyFrom(handle_data.shape_and_type[0].shape) 

6804 

6805 

6806# This will be replaced by a concrete implementation in a future CL. 

6807@tf_export("__internal__.SymbolicTensor") 

6808class SymbolicTensor(object): 

6809 """Stub class for symbolic tensors.""" 

6810 

6811 

6812@tf_export("is_symbolic_tensor", v1=["is_symbolic_tensor"]) 

6813def is_symbolic_tensor(tensor): 

6814 """Test if `tensor` is a symbolic Tensor. 

6815 

6816 Args: 

6817 tensor: a tensor-like object 

6818 

6819 Returns: 

6820 True if `tensor` is a symbolic tensor (not an eager tensor). 

6821 """ 

6822 return type(tensor) == Tensor # pylint: disable=unidiomatic-typecheck
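
For example, eager tensors are excluded while graph-built tensors qualify:

```python
import tensorflow as tf

print(tf.is_symbolic_tensor(tf.constant(1.0)))   # False: an EagerTensor

with tf.Graph().as_default():
  c = tf.constant(1.0)                           # a graph (symbolic) Tensor
  print(tf.is_symbolic_tensor(c))                # True
```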