Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/context.py: 45%

1252 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""State management for eager execution.""" 

16 

17import collections 

18import contextlib 

19import copy 

20import gc 

21import itertools 

22import os 

23import random 

24import threading 

25 

26from absl import logging 

27import numpy as np 

28 

29from tensorflow.core.framework import function_pb2 

30from tensorflow.core.protobuf import config_pb2 

31from tensorflow.core.protobuf import rewriter_config_pb2 

32from tensorflow.python import pywrap_tfe 

33from tensorflow.python import tf2 

34from tensorflow.python.client import pywrap_tf_session 

35from tensorflow.python.eager import cancellation 

36from tensorflow.python.eager import execute 

37from tensorflow.python.eager import executor 

38from tensorflow.python.eager import monitoring 

39from tensorflow.python.framework import c_api_util 

40from tensorflow.python.framework import device as pydev 

41from tensorflow.python.framework import tfrt_utils 

42from tensorflow.python.util import compat 

43from tensorflow.python.util import function_utils 

44from tensorflow.python.util import is_in_graph_mode 

45from tensorflow.python.util import tf_contextlib 

46from tensorflow.python.util.deprecation import deprecated 

47from tensorflow.python.util.tf_export import tf_export 

48from tensorflow.tsl.protobuf import coordination_config_pb2 

49 

50GRAPH_MODE = 0 

51EAGER_MODE = 1 

52 

53default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE 

54 

55# Cache from (old_device_name, partial_new_device_name) -> (new_device_name, 

56# new_device_spec). 

57# Note that we do not protect this with a lock and instead rely on python's GIL 

58# and the idempotent nature of writes to provide thread safety. 

59_device_parsing_cache = {} 

60_starting_device_spec = pydev.DeviceSpec.from_string("") 

61 

62_MAXINT32 = 2**31 - 1 

63 

64DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT 

65DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN 

66DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT 

67DEVICE_PLACEMENT_SILENT_FOR_INT32 = ( 

68 pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) 

69 

70SYNC = 0 

71ASYNC = 1 

72 

73_KEEP_ALIVE_SECS = 600 

74 

75_python_eager_context_create_counter = monitoring.Counter( 

76 "/tensorflow/api/python/eager_context_create_counter", 

77 "Counter for number of eager contexts created in Python.") 

78 

79# Re-exporting through context. 

80is_tfrt_enabled = tfrt_utils.enabled 

81 

82# This flag and the associated environment var are transient and will eventually 

83# be removed, once this experiment is enabled by default. 

84_JIT_COMPILE_REWRITE_ENABLED = os.getenv("TF_JIT_COMPILE_REWRITE") == "1" 

85 

86 

87def run_eager_op_as_function_enabled(): 

88 return True 

89 

90 

91# This method should only be called after the context has been initialized.

92def enable_jit_compile_rewrite(): 

93 """Run jit_compile functions through rewrite pass. 

94 

95 This runs jit_compile functions through all of the multidevice function 

96 rewrite passes. 

97 """ 

98 global _JIT_COMPILE_REWRITE_ENABLED 

99 _JIT_COMPILE_REWRITE_ENABLED = True 

100 if context_safe() is not None: 

101 context_safe().jit_compile_rewrite = True 

102 

103 

104# This method should only be called after the context has been initialized. 

105def disable_jit_compile_rewrite(): 

106 global _JIT_COMPILE_REWRITE_ENABLED 

107 _JIT_COMPILE_REWRITE_ENABLED = False 

108 if context_safe() is not None: 

109 context_safe().jit_compile_rewrite = False 

110 

111 

112def jit_compile_rewrite_enabled(): 

113 if context_safe() is not None: 

114 return context_safe().jit_compile_rewrite 

115 return _JIT_COMPILE_REWRITE_ENABLED 

116 

117 

118# Expose it as internally public APIs for Keras use cases in b/171080602. 

119tf_export("__internal__.is_tfrt_enabled", v1=[])(is_tfrt_enabled) 

120 

121 

122class _EagerTensorCache(object): 

123 """Simple cache which evicts items based on length in a FIFO manner.""" 

124 

125 __slots__ = ["_data", "_max_items", "_max_tensor_size"] 

126 

127 def __init__(self, max_items=256, max_tensor_size=10000): 

128 self._data = collections.OrderedDict() 

129 self._max_items = max_items 

130 self._max_tensor_size = max_tensor_size 

131 

132 def put(self, key, value): 

133 if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access 

134 return 

135 

136 self._data[key] = value 

137 

138 if len(self._data) > self._max_items: 

139 self._data.popitem(last=False) 

140 

141 def get(self, key): 

142 return self._data.get(key, None) 

143 

144 def flush(self): 

145 self._data.clear() 

146 
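# A minimal sketch of the FIFO eviction implemented above. The helper below is
# illustrative only: `_FakeTensor` stands in for an EagerTensor (it just needs
# `_num_elements()`), and nothing in this module calls it.
def _example_eager_tensor_cache():
  """Illustrative sketch; not used by the module."""

  class _FakeTensor(object):

    def __init__(self, num_elements):
      self._num = num_elements

    def _num_elements(self):
      return self._num

  cache = _EagerTensorCache(max_items=2, max_tensor_size=10)
  cache.put("a", _FakeTensor(1))
  cache.put("b", _FakeTensor(2))
  cache.put("c", _FakeTensor(3))   # Exceeds max_items, so "a" is evicted (FIFO).
  assert cache.get("a") is None and cache.get("c") is not None
  cache.put("d", _FakeTensor(100))  # Larger than max_tensor_size: never cached.
  assert cache.get("d") is None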

147 

148class FunctionCallOptions: 

149 """Options applied at call sites of eager functions. 

150 

151 Eager functions are functions decorated with `tf.function` (formerly tf.contrib.eager.defun).

152 """ 

153 

154 __slots__ = ["_config_proto_serialized", "_executor_type"] 

155 

156 def __init__(self, executor_type=None, config_proto=None): 

157 """Constructor. 

158 

159 Args: 

160 executor_type: (optional) name of the executor to be used to execute the 

161 eager function. If None or an empty string, the default Tensorflow 

162 executor will be used. 

163 config_proto: (optional) a `config_pb2.ConfigProto` proto or a serialized 

164 string of that proto. The config used by Grappler when optimizing the 

165 function graph. Each concrete function is optimized the first time it is

166 called. Changing config_proto after the first call has no effect. If 

167 config_proto is None, an empty RewriterConfig will be used. 

168 """ 

169 self.config_proto_serialized = config_proto 

170 self.executor_type = executor_type 

171 

172 @property 

173 def executor_type(self): 

174 return self._executor_type 

175 

176 @executor_type.setter 

177 def executor_type(self, executor_type): 

178 self._executor_type = executor_type 

179 

180 @property 

181 def config_proto_serialized(self): 

182 return self._config_proto_serialized 

183 

184 @config_proto_serialized.setter 

185 def config_proto_serialized(self, config): 

186 if isinstance(config, config_pb2.ConfigProto): 

187 self._config_proto_serialized = config.SerializeToString( 

188 deterministic=True) 

189 elif isinstance(config, str): 

190 self._config_proto_serialized = config 

191 elif config is None: 

192 self._config_proto_serialized = ( 

193 config_pb2.ConfigProto().SerializeToString()) 

194 else: 

195 raise ValueError("the rewriter config must be either a " 

196 "config_pb2.ConfigProto, or a serialized string of that " 

197 "proto or None. got: {}".format(type(config))) 

198 

199 def as_attrs(self): 

200 if self.config_proto_serialized is None: 

201 config = function_utils.get_disabled_rewriter_config() 

202 else: 

203 config = self.config_proto_serialized 

204 executor_type = self.executor_type or "" 

205 

206 return {"executor_type": executor_type, "config_proto": config} 

207 
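# A short sketch of how `FunctionCallOptions` normalizes its inputs; the helper
# name is illustrative and nothing in this module calls it.
def _example_function_call_options():
  opts = FunctionCallOptions(
      executor_type="",  # Empty string selects the default TensorFlow executor.
      config_proto=config_pb2.ConfigProto(allow_soft_placement=True))
  # The proto is stored in serialized form regardless of how it was passed in.
  assert isinstance(opts.config_proto_serialized, bytes)
  attrs = opts.as_attrs()
  assert set(attrs) == {"executor_type", "config_proto"}
  return attrs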

208 

209# Map from context_id (an int) to _TensorCaches. 

210# Dicts are thread safe in CPython. 

211# TODO(iga): Remove this once TensorCaches are moved to C++. 

212_tensor_caches_map = {} 

213 

214 

215class _TensorCaches(threading.local): 

216 """Thread local tensor caches.""" 

217 

218 __slots__ = ["_ones_rank_cache", "_zeros_cache"] 

219 

220 def __init__(self): 

221 super().__init__() 

222 self._ones_rank_cache = None 

223 self._zeros_cache = None 

224 

225 @property 

226 def ones_rank_cache(self): 

227 if not self._ones_rank_cache: 

228 self._ones_rank_cache = _EagerTensorCache() 

229 return self._ones_rank_cache 

230 

231 @property 

232 def zeros_cache(self): 

233 if not self._zeros_cache: 

234 self._zeros_cache = _EagerTensorCache() 

235 return self._zeros_cache 

236 

237 

238ContextSwitch = collections.namedtuple( 

239 "ContextSwitch", 

240 ["is_building_function", "enter_context_fn", "device_stack"]) 

241 

242 

243# `_ContextSwitchStack` is a `threading.local` to match the semantics of 

244# `DefaultGraphStack`, which is also a `threading.local`.

245class _ContextSwitchStack(threading.local): 

246 """A thread-local stack of context switches.""" 

247 

248 def __init__(self, eager): 

249 super().__init__() 

250 self.stack = [] 

251 if eager: 

252 # Initialize the stack with a pointer to enter the eager context; this 

253 # ensures that the fact that eager execution was enabled is propagated 

254 # across threads, since (1) `enable_eager_execution` modifies a 

255 # process-level flag (`default_execution_mode`) and (2) `__init__` is 

256 # called each time a threading.local object is used in a separate thread. 

257 self.push( 

258 is_building_function=False, 

259 enter_context_fn=eager_mode, 

260 device_stack=None) 

261 

262 def push(self, is_building_function, enter_context_fn, device_stack): 

263 """Push metadata about a context switch onto the stack. 

264 

265 A context switch can take any one of the two forms: installing a graph as 

266 the default graph, or entering the eager context. For each context switch, 

267 we record whether or not the entered context is building a function. 

268 

269 Args: 

270 is_building_function: (bool.) Whether the context is building a function. 

271 enter_context_fn: (function.) A callable that executes the context switch. 

272 For example, `graph.as_default` or `eager_mode`. 

273 device_stack: If applicable, the device function stack for this graph. 

274 When breaking out of graphs in init_scope, the innermost nonempty device 

275 stack is used. Eager contexts put `None` here and the value is never 

276 used. 

277 """ 

278 

279 self.stack.append( 

280 ContextSwitch(is_building_function, enter_context_fn, device_stack)) 

281 

282 def pop(self): 

283 """Pop the stack.""" 

284 

285 self.stack.pop() 

286 

287 

288@tf_export("config.LogicalDevice") 

289class LogicalDevice( 

290 collections.namedtuple("LogicalDevice", ["name", "device_type"])): 

291 """Abstraction for a logical device initialized by the runtime. 

292 

293 A `tf.config.LogicalDevice` corresponds to an initialized logical device on a 

294 `tf.config.PhysicalDevice` or a remote device visible to the cluster. Tensors 

295 and operations can be placed on a specific logical device by calling 

296 `tf.device` with a specified `tf.config.LogicalDevice`. 

297 

298 Fields: 

299 name: The fully qualified name of the device. Can be used for Op or function 

300 placement. 

301 device_type: String declaring the type of device such as "CPU" or "GPU". 

302 """ 

303 pass 

304 

305 

306@tf_export("config.LogicalDeviceConfiguration", 

307 "config.experimental.VirtualDeviceConfiguration") 

308class LogicalDeviceConfiguration( 

309 collections.namedtuple("LogicalDeviceConfiguration", [ 

310 "memory_limit", "experimental_priority", "experimental_device_ordinal" 

311 ])): 

312 """Configuration class for a logical devices. 

313 

314 The class specifies the parameters to configure a `tf.config.PhysicalDevice` 

315 as it is initialized to a `tf.config.LogicalDevice` during runtime 

316 initialization. Not all fields are valid for all device types. 

317 

318 See `tf.config.get_logical_device_configuration` and 

319 `tf.config.set_logical_device_configuration` for usage examples. 

320 

321 Fields: 

322 memory_limit: (optional) Maximum memory (in MB) to allocate on the virtual 

323 device. Currently only supported for GPUs. 

324 experimental_priority: (optional) Priority to assign to a virtual device. 

325 Lower values have higher priorities and 0 is the default. 

326 Within a physical GPU, the GPU scheduler will prioritize ops on virtual 

327 devices with higher priority. Currently only supported for Nvidia GPUs. 

328 experimental_device_ordinal: (optional) Ordinal number to order the virtual 

329 device. 

330 A LogicalDevice with a lower ordinal number will receive a lower device id.

331 The physical device id and location in the list are used to break ties.

332 Currently only supported for Nvidia GPUs. 

333 """ 

334 

335 def __new__(cls, 

336 memory_limit=None, 

337 experimental_priority=None, 

338 experimental_device_ordinal=None): 

339 return super().__new__(cls, memory_limit, experimental_priority, 

340 experimental_device_ordinal) 

341 
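# A usage sketch of this configuration through the public `tf.config` API,
# assuming a machine with at least one GPU; it splits one physical GPU into
# two 1 GB logical devices:
#
#   import tensorflow as tf
#
#   gpus = tf.config.list_physical_devices("GPU")
#   if gpus:
#     tf.config.set_logical_device_configuration(
#         gpus[0],
#         [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
#          tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
#     logical_gpus = tf.config.list_logical_devices("GPU")  # Now two entries.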

342 

343@tf_export("config.PhysicalDevice") 

344class PhysicalDevice( 

345 collections.namedtuple("PhysicalDevice", ["name", "device_type"])): 

346 """Abstraction for a locally visible physical device. 

347 

348 TensorFlow can utilize various devices such as the CPU or multiple GPUs 

349 for computation. Before initializing a local device for use, the user can 

350 customize certain properties of the device such as its visibility or memory

351 configuration. 

352 

353 Once a visible `tf.config.PhysicalDevice` is initialized one or more 

354 `tf.config.LogicalDevice` objects are created. Use 

355 `tf.config.set_visible_devices` to configure the visibility of a physical 

356 device and `tf.config.set_logical_device_configuration` to configure multiple 

357 `tf.config.LogicalDevice` objects for a `tf.config.PhysicalDevice`. This is 

358 useful when separation between models is needed or to simulate a multi-device 

359 environment. 

360 

361 Fields: 

362 name: Unique identifier for device. 

363 device_type: String declaring the type of device such as "CPU" or "GPU". 

364 """ 

365 pass 

366 
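# A usage sketch of physical-device visibility through the public `tf.config`
# API; restricting TensorFlow to the first GPU is a common way to simulate a
# single-GPU environment:
#
#   import tensorflow as tf
#
#   physical_gpus = tf.config.list_physical_devices("GPU")
#   if physical_gpus:
#     tf.config.set_visible_devices(physical_gpus[:1], "GPU")
#     assert len(tf.config.get_visible_devices("GPU")) == 1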

367 

368class _AtomicCounter(object): 

369 """A simple atomic counter.""" 

370 

371 __slots__ = ["_value", "_lock"] 

372 

373 def __init__(self): 

374 self._value = 0 

375 self._lock = threading.Lock() 

376 

377 def increment_and_get(self): 

378 with self._lock: 

379 self._value += 1 

380 return self._value 

381 

382 

383_context_id_counter = _AtomicCounter() 

384 

385 

386class _TensorCacheDeleter(object): 

387 """Deletes tensor caches for a given context.""" 

388 

389 __slots__ = ["_context_id"] 

390 

391 def __init__(self, context_id): 

392 self._context_id = context_id 

393 

394 def __del__(self): 

395 if _tensor_caches_map is None: 

396 return 

397 if self._context_id in _tensor_caches_map: 

398 del _tensor_caches_map[self._context_id] 

399 

400 

401# TODO(agarwal): rename to EagerContext / EagerRuntime ? 

402# TODO(agarwal): consider keeping the corresponding Graph here. 

403class Context: 

404 """Environment in which eager operations execute.""" 

405 

406 # TODO(agarwal): create and link in some documentation for `execution_mode`. 

407 # pylint: disable=redefined-outer-name 

408 def __init__(self, 

409 config=None, 

410 device_policy=None, 

411 execution_mode=None, 

412 server_def=None): 

413 """Creates a new Context. 

414 

415 Args: 

416 config: (Optional.) A `ConfigProto` protocol buffer with configuration 

417 options for the Context. Note that a lot of these options may be 

418 currently unimplemented or irrelevant when eager execution is enabled. 

419 device_policy: (Optional.) What policy to use when trying to run an 

420 operation on a device with inputs which are not on that device. When set 

421 to None, an appropriate value will be picked automatically. The value 

422 picked may change between TensorFlow releases. Defaults to 

423 DEVICE_PLACEMENT_SILENT. 

424 Valid values: 

425 - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not 

426 correct. 

427 - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the right 

428 device but raises a warning. 

429 - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide 

430 performance problems. 

431 - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, 

432 raising errors on the other ones. 

433 execution_mode: (Optional.) Policy controlling how operations dispatched 

434 are actually executed. When set to None, an appropriate value will be 

435 picked automatically. The value picked may change between TensorFlow 

436 releases. 

437 Valid values: 

438 - SYNC: executes each operation synchronously. 

439 - ASYNC: executes each operation asynchronously. These operations may 

440 return "non-ready" handles. 

441 server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution 

442 on remote devices. GrpcServers need to be started by creating an 

443 identical server_def to this, and setting the appropriate task_indexes, 

444 so that the servers can communicate. It will then be possible to execute 

445 operations on remote devices. 

446 

447 Raises: 

448 ValueError: If execution_mode is not valid. 

449 """ 

450 # This _id is used only to index the tensor caches. 

451 # TODO(iga): Remove this when tensor caches are moved to C++. 

452 self._id = _context_id_counter.increment_and_get() 

453 self._tensor_cache_deleter = _TensorCacheDeleter(self._id) 

454 _tensor_caches_map[self._id] = _TensorCaches() 

455 

456 self._config = config 

457 self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData( 

458 self, 

459 is_eager=lambda: default_execution_mode == EAGER_MODE, 

460 device_spec=_starting_device_spec) 

461 self._context_switches = _ContextSwitchStack(self.executing_eagerly()) 

462 self._context_handle = None 

463 self._context_devices = None 

464 self._seed = None 

465 self._initialize_lock = threading.Lock() 

466 self._initialized = False 

467 if device_policy is None: 

468 device_policy = DEVICE_PLACEMENT_SILENT 

469 self._device_policy = device_policy 

470 self._mirroring_policy = None 

471 if execution_mode not in (None, SYNC, ASYNC): 

472 raise ValueError("execution_mode should be None/SYNC/ASYNC. Got %s" % 

473 execution_mode) 

474 if execution_mode is None: 

475 execution_mode = SYNC 

476 self._default_is_async = execution_mode == ASYNC 

477 self._use_tfrt = is_tfrt_enabled() 

478 self._jit_compile_rewrite = jit_compile_rewrite_enabled() 

479 self._server_def = server_def 

480 self._collective_ops_server_def = None 

481 self._collective_leader = None 

482 self._collective_scoped_allocator_enabled_ops = None 

483 self._collective_use_nccl_communication = None 

484 self._collective_device_filters = None 

485 self._coordination_service_config = None 

486 

487 self._device_lock = threading.Lock() 

488 self._physical_devices = None 

489 self._physical_device_to_index = None 

490 self._pluggable_devices = None 

491 self._visible_device_list = [] 

492 self._memory_growth_map = None 

493 self._virtual_device_map = {} 

494 

495 # Values set after construction 

496 self._optimizer_jit = None 

497 self._intra_op_parallelism_threads = None 

498 self._inter_op_parallelism_threads = None 

499 self._soft_device_placement = None 

500 self._log_device_placement = None 

501 self._operation_timeout_in_ms = None 

502 self._enable_mlir_graph_optimization = None 

503 self._optimizer_experimental_options = {} 

504 

505 _python_eager_context_create_counter.get_cell().increase_by(1) 

506 

507 self._is_global_context = False 

508 

509 # pylint: enable=redefined-outer-name 

510 
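# A minimal construction sketch using the constants defined earlier in this
# file; most code goes through the module-level `context()` helper defined
# later in this module rather than instantiating `Context` directly:
#
#   ctx = Context(device_policy=DEVICE_PLACEMENT_WARN, execution_mode=ASYNC)
#   ctx.ensure_initialized()          # Builds the underlying TFE context handle.
#   print(ctx.executing_eagerly(), ctx.is_async())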

511 def _set_global_seed(self, seed): 

512 """Set a global eager mode seed for random ops.""" 

513 self._seed = seed 

514 # `random.Random(seed)` needs `seed` to be hashable, while values of type 

515 # e.g. `np.int64` or `np.ndarray` are not. We use `int(...)` to convert them 

516 # to int. 

517 try: 

518 hash(seed) 

519 self._rng = random.Random(seed) 

520 except TypeError: 

521 seed = int(np.array(seed)) 

522 self._rng = random.Random(seed) 

523 # Also clear the kernel cache, to reset any existing seeds 

524 if self._context_handle is not None: 

525 pywrap_tfe.TFE_ContextClearCaches(self._context_handle) 

526 

527 def _internal_operation_seed(self): 

528 """Returns a fake operation seed. 

529 

530 In eager mode, the user shouldn't set or depend on the operation seed.

531 Here, we generate a random seed based on the global seed so that each

532 operation's randomness is different while still depending on the global seed.

533 

534 Returns: 

535 A fake operation seed based on global seed. 

536 """ 

537 return self._rng.randint(0, _MAXINT32) 

538 

539 def _initialize_logical_devices(self): 

540 """Helper to initialize devices.""" 

541 # Store list of devices 

542 logical_devices = [] 

543 context_devices = [] 

544 device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle) 

545 try: 

546 self._num_gpus = 0 

547 current_job, current_task = None, None 

548 server_def = self._server_def or self._collective_ops_server_def 

549 if server_def is not None: 

550 current_job, current_task = server_def.job_name, server_def.task_index 

551 for i in range(pywrap_tfe.TF_DeviceListCount(device_list)): 

552 dev_name = pywrap_tfe.TF_DeviceListName(device_list, i) 

553 context_devices.append(pydev.canonical_name(dev_name)) 

554 spec = pydev.DeviceSpec.from_string(dev_name) 

555 # If the job is localhost, we assume that the cluster has not yet been 

556 # configured and thus clear the job, replica & task. 

557 if spec.job == "localhost": 

558 spec = spec.replace(job=None, replica=None, task=None) 

559 logical_devices.append( 

560 LogicalDevice(name=spec.to_string(), device_type=spec.device_type)) 

561 dev_type = pywrap_tfe.TF_DeviceListType(device_list, i) 

562 if (dev_type == "GPU" and spec.job == current_job and 

563 spec.task == current_task): 

564 self._num_gpus += 1 

565 

566 finally: 

567 self._logical_devices = logical_devices 

568 self._context_devices = context_devices 

569 pywrap_tfe.TF_DeleteDeviceList(device_list) 

570 

571 def ensure_initialized(self): 

572 """Initialize handle and devices if not already done so.""" 

573 if self._initialized: 

574 return 

575 with self._initialize_lock: 

576 if self._initialized: 

577 return 

578 assert self._context_devices is None 

579 opts = pywrap_tfe.TFE_NewContextOptions() 

580 try: 

581 config_str = self.config.SerializeToString() 

582 pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str) 

583 if self._device_policy is not None: 

584 pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy( 

585 opts, self._device_policy) 

586 if self._mirroring_policy is not None: 

587 pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy( 

588 opts, self._mirroring_policy) 

589 if self._default_is_async == ASYNC: 

590 pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True) 

591 if self._use_tfrt is not None: 

592 pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt) 

593 pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction(opts, True) 

594 pywrap_tfe.TFE_ContextOptionsSetJitCompileRewrite( 

595 opts, self._jit_compile_rewrite) 

596 context_handle = pywrap_tfe.TFE_NewContext(opts) 

597 finally: 

598 pywrap_tfe.TFE_DeleteContextOptions(opts) 

599 assert not (self._server_def and self._collective_ops_server_def), ( 

600 "Cannot enable remote execution as well as collective ops at the " 

601 "moment. If this is important to you, please file an issue.") 

602 if self._server_def is not None: 

603 server_def_str = self._server_def.SerializeToString() 

604 pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS, 

605 server_def_str) 

606 elif self._collective_ops_server_def is not None: 

607 server_def_str = self._collective_ops_server_def.SerializeToString() 

608 pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str) 

609 

610 self._context_handle = context_handle 

611 self._initialize_logical_devices() 

612 self._initialized = True 

613 

614 if self._is_global_context: 

615 pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle) 

616 

617 def ensure_uninitialized(self): 

618 """Uninitialize handle and devices if not already done so.""" 

619 with self._initialize_lock: 

620 if not self._initialized: 

621 return 

622 self._context_devices = None 

623 self._logical_devices = None 

624 self._server_def = None 

625 self._initialized = False 

626 

627 if self._is_global_context: 

628 pywrap_tfe.TFE_Py_SetCEagerContext(None) 

629 

630 self._context_handle = None 

631 

632 def mark_as_global_context(self): 

633 # If the context was already initialized, publish it. Otherwise wait with 

634 # publication until it's initialized. 

635 if self._initialized: 

636 pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle) 

637 self._is_global_context = True 

638 

639 def _clear_caches(self): 

640 self.ones_rank_cache().flush() 

641 self.zeros_cache().flush() 

642 pywrap_tfe.TFE_ClearScalarCache() 

643 

644 def get_server_def(self): 

645 return self._server_def 

646 

647 def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): 

648 """Allow setting a server_def on the context. 

649 

650 When a server def is replaced, it effectively clears a bunch of caches 

651 within the context. If you attempt to use a tensor object that was pointing 

652 to a tensor on the remote device, it will raise an error. 

653 

654 Args: 

655 server_def: A tensorflow::ServerDef proto. Enables execution on remote 

656 devices. 

657 keep_alive_secs: Num. seconds after which the remote end will hang up. As 

658 long as the client is still alive, the server state for the context will 

659 be kept alive. If the client is killed (or there is some failure), the 

660 server will clean up its context keep_alive_secs after the final RPC it 

661 receives. 

662 

663 Raises: 

664 ValueError: if server_def is None. 

665 """ 

666 if not server_def: 

667 raise ValueError("server_def is None.") 

668 

669 self._server_def = server_def 

670 

671 if self._context_handle: 

672 server_def_str = server_def.SerializeToString() 

673 pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs, 

674 server_def_str) 

675 self._initialize_logical_devices() 

676 

677 # Clear all the caches in case there are remote tensors in them. 

678 self._clear_caches() 

679 
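# A hedged sketch of building a `ServerDef` for the call above; the single-task
# "worker" cluster layout is illustrative:
#
#   import tensorflow as tf
#
#   cluster = tf.train.ClusterSpec({"worker": ["localhost:2222"]})
#   server_def = tf.train.ServerDef(
#       cluster=cluster.as_cluster_def(),
#       job_name="worker",
#       task_index=0,
#       protocol="grpc")
#   context().set_server_def(server_def)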

680 def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): 

681 """Update a server_def on the context. 

682 

683 Args: 

684 server_def: A tensorflow::ServerDef proto. Enables execution on remote 

685 devices. 

686 keep_alive_secs: Num. seconds after which the remote end will hang up. As 

687 long as the client is still alive, the server state for the context will 

688 be kept alive. If the client is killed (or there is some failure), the 

689 server will clean up its context keep_alive_secs after the final RPC it 

690 receives. 

691 

692 Raises: 

693 ValueError: if server_def is None. 

694 """ 

695 if not server_def: 

696 raise ValueError("server_def is None.") 

697 

698 self._server_def = server_def 

699 

700 if self._context_handle: 

701 server_def_str = server_def.SerializeToString() 

702 pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle, 

703 keep_alive_secs, server_def_str) 

704 self._initialize_logical_devices() 

705 

706 self._clear_caches() 

707 

708 def check_alive(self, worker_name): 

709 """Checks whether a remote worker is alive or not. 

710 

711 Args: 

712 worker_name: a string representing the remote worker. It must be a fully 

713 specified name like "/job:worker/replica:0/task:0". 

714 

715 Returns: 

716 a boolean indicating whether the remote worker is alive or not. 

717 

718 Raises: 

719 ValueError: if context is not initialized. 

720 """ 

721 # TODO(yuefengz): support checking multiple workers. 

722 if self._context_handle: 

723 return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name) 

724 else: 

725 raise ValueError("Context is not initialized.") 

726 

727 def sync_executors(self): 

728 """Sync both local executors and the ones on remote workers. 

729 

730 In async execution mode, local function calls can return before the 

731 corresponding remote op/function execution requests are completed. Calling 

732 this method creates a synchronization barrier for remote executors. It only 

733 returns when all remote pending nodes are finished, potentially with errors 

734 if any remote executors are in error state. 

735 

736 Raises: 

737 ValueError: if context is not initialized. 

738 """ 

739 if self._context_handle: 

740 pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle) 

741 else: 

742 raise ValueError("Context is not initialized.") 

743 

744 def clear_executor_errors(self): 

745 """Clear errors in both local executors and remote workers. 

746 

747 After receiving errors from remote workers, additional requests on the fly 

748 could further taint the status on the remote workers due to the async nature 

749 of remote execution. Calling this method blocks while waiting for all pending

750 nodes in remote executors to finish and clears their error statuses.

751 

752 Raises: 

753 ValueError: if context is not initialized. 

754 """ 

755 if self._context_handle: 

756 pywrap_tfe.TFE_ContextClearExecutors(self._context_handle) 

757 else: 

758 raise ValueError("Context is not initialized.") 

759 

760 def configure_coordination_service(self, 

761 service_type, 

762 service_leader="", 

763 enable_health_check=True, 

764 cluster_register_timeout_in_ms=0, 

765 heartbeat_timeout_in_ms=0, 

766 shutdown_barrier_timeout_in_ms=0, 

767 coordinated_jobs=None, 

768 allow_new_incarnation_to_reconnect=False): 

769 """Enable distributed coordination service with specified configs.""" 

770 if self._context_handle: 

771 logging.warning("Configuring coordination service type may not be " 

772 "effective because the context is already initialized.") 

773 config = coordination_config_pb2.CoordinationServiceConfig() 

774 config.service_type = service_type 

775 if service_leader: 

776 config.service_leader = pydev.canonical_name(service_leader) 

777 config.enable_health_check = enable_health_check 

778 config.cluster_register_timeout_in_ms = cluster_register_timeout_in_ms 

779 config.heartbeat_timeout_in_ms = heartbeat_timeout_in_ms 

780 config.shutdown_barrier_timeout_in_ms = shutdown_barrier_timeout_in_ms 

781 config.allow_new_incarnation_to_reconnect = ( 

782 allow_new_incarnation_to_reconnect) 

783 if coordinated_jobs is not None: 

784 if isinstance(coordinated_jobs, list): 

785 config.coordinated_job_list.extend(coordinated_jobs) 

786 else: 

787 raise ValueError("`coordinated_jobs` must be list[CoordinatedJob] or " 

788 "None, but got: %s" % (coordinated_jobs,)) 

789 self._coordination_service_config = config 

790 
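# A hedged sketch of enabling the coordination service before initialization;
# the "standalone" service type and the leader job string are assumptions made
# for illustration:
#
#   ctx = context()
#   ctx.configure_coordination_service(
#       service_type="standalone",
#       service_leader="/job:worker/replica:0/task:0",
#       heartbeat_timeout_in_ms=10000)
#   ctx.ensure_initialized()  # The config is applied when the handle is built.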

791 @property 

792 def coordination_service(self): 

793 return self._coordination_service_config 

794 

795 def set_config_key_value(self, key, value): 

796 ensure_initialized() 

797 pywrap_tfe.TFE_InsertConfigKeyValue(self._context_handle, key, value) 

798 

799 # If `timeout_in_ms=0`, this will block until the key-value is set or the 

800 # worker shuts down. 

801 def get_config_key_value(self, key, timeout_in_ms=0): 

802 ensure_initialized() 

803 with c_api_util.tf_buffer() as buffer_: 

804 pywrap_tfe.TFE_GetConfigKeyValue(self._context_handle, key, 

805 timeout_in_ms, buffer_) 

806 value = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8") 

807 return value 

808 

809 def delete_config_key_value(self, key): 

810 ensure_initialized() 

811 pywrap_tfe.TFE_DeleteConfigKeyValue(self._context_handle, key) 

812 

813 def report_error_to_cluster(self, error_code, error_message): 

814 """Report error to other members in a multi-client cluster. 

815 

816 Args: 

817 error_code: a `tf.errors` error code. 

818 error_message: a string. The error message. 

819 """ 

820 if self._context_handle: 

821 pywrap_tfe.TFE_ReportErrorToCluster(self._context_handle, error_code, 

822 error_message) 

823 else: 

824 raise ValueError("Context is not initialized.") 

825 

826 def get_task_states(self, job_configs): 

827 """Get task states from the Coordination Service. 

828 

829 Args: 

830 job_configs: A list of tuples of job name and task number. 

831 

832 Returns: 

833 A list of TF_Status. 

834 """ 

835 if self._context_handle: 

836 job_names, task_nums = zip(*job_configs) 

837 return pywrap_tfe.TFE_GetTaskStates(self._context_handle, job_names, 

838 task_nums) 

839 else: 

840 raise ValueError("Context is not initialized.") 

841 

842 def wait_at_barrier(self, barrier_id, timeout_in_ms): 

843 """Blocks until all coordinated tasks are at the barrier. 

844 

845 The barrier may fail if it times out or if one of the tasks is unhealthy. 

846 

847 Args: 

848 barrier_id: Unique string identifying the barrier. 

849 timeout_in_ms: Duration before the barrier times out and fails. 

850 """ 

851 ensure_initialized() 

852 pywrap_tfe.TFE_WaitAtBarrier(self._context_handle, barrier_id, 

853 timeout_in_ms) 

854 

855 def clear_kernel_cache(self): 

856 """Clear kernel cache and reset all stateful kernels.""" 

857 if self._context_handle is not None: 

858 pywrap_tfe.TFE_ContextClearCaches(self._context_handle) 

859 

860 def enable_collective_ops(self, server_def): 

861 """Enable distributed collective ops with an appropriate server_def. 

862 

863 Args: 

864 server_def: A tensorflow::ServerDef proto. Enables execution on remote 

865 devices. 

866 

867 Raises: 

868 ValueError: if server_def is None. 

869 RuntimeError: if this method is not called at program startup. 

870 """ 

871 if not server_def: 

872 raise ValueError("server_def is None.") 

873 

874 self._collective_ops_server_def = server_def 

875 

876 # TODO(b/129298253): Allow creating datasets/tensors before enabling 

877 # collective ops. 

878 if self._context_handle is not None: 

879 logging.warning("Enabling collective ops after program startup may cause " 

880 "error when accessing previously created tensors.") 

881 with self._initialize_lock: 

882 assert self._initialized 

883 server_def_str = self._collective_ops_server_def.SerializeToString() 

884 pywrap_tfe.TFE_EnableCollectiveOps(self._context_handle, server_def_str) 

885 self._initialize_logical_devices() 

886 self._clear_caches() 

887 

888 def configure_collective_ops( 

889 self, 

890 collective_leader="", 

891 scoped_allocator_enabled_ops=("CollectiveReduce",), 

892 use_nccl_communication=False, 

893 device_filters=None): 

894 """Configure collective ops. 

895 

896 A collective group leader is necessary for collective ops to run; the other

897 configurations are mainly for performance.

898 

899 Args: 

900 collective_leader: a device string for collective leader, e.g. 

901 "/job:worker/replica:0/task:0"; empty string means local execution of 

902 collective ops. 

903 scoped_allocator_enabled_ops: a tuple or a list of op names for scoped 

904 allocator to run with. 

905 use_nccl_communication: whether to use nccl communication for collective 

906 ops. 

907 device_filters: a tuple or a list of device strings. If set, the

908 corresponding task can only see the devices filtered by these device filters.

909 

910 Raises: 

911 RuntimeError: if this method is not called at program startup. 

912 """ 

913 if self._collective_leader is not None: 

914 if (self._collective_leader != collective_leader or 

915 self._collective_scoped_allocator_enabled_ops != 

916 scoped_allocator_enabled_ops or 

917 self._collective_use_nccl_communication != use_nccl_communication or 

918 self._collective_device_filters != device_filters): 

919 raise ValueError("Collective ops are already configured.") 

920 else: 

921 return 

922 

923 if self._context_handle is not None: 

924 raise RuntimeError("Collective ops must be configured at program startup") 

925 

926 self._collective_leader = collective_leader 

927 self._collective_scoped_allocator_enabled_ops = scoped_allocator_enabled_ops 

928 self._collective_use_nccl_communication = use_nccl_communication 

929 self._collective_device_filters = device_filters 

930 
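# A hedged sketch of the collective-ops configuration above; the job and device
# strings are illustrative and must match the actual cluster:
#
#   ctx = context()
#   ctx.configure_collective_ops(
#       collective_leader="/job:worker/replica:0/task:0",
#       use_nccl_communication=False,
#       device_filters=["/job:worker/task:0"])
#   # Must run before the context handle is created, otherwise RuntimeError.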

931 def abort_collective_ops(self, code, message): 

932 """Abort the collective ops. 

933 

934 This is intended to be used when a peer failure is detected, which allows 

935 the user to handle the case instead of hanging. This aborts all on-going 

936 collectives. Afterwards, all subsequent collectives error immediately, and you

937 need to call reset_context() to use collectives again.

938 

939 Args: 

940 code: a `tf.errors` error code. 

941 message: a string. The error message. 

942 """ 

943 self.ensure_initialized() 

944 pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message) 

945 

946 def check_collective_ops_peer_health(self, task, timeout_in_ms): 

947 """Check collective peer health. 

948 

949 This probes each task to see if it is still alive. Note that a restarted

950 task is considered a different one, and it is treated as not healthy.

951 

952 This should only be used in multi-client multi-worker training.

953 

954 Args: 

955 task: a task string, must be in the format of /job:xxx/replica:0/task:N. 

956 timeout_in_ms: an integer, the timeout. If zero, there's no timeout. 

957 

958 Raises: 

959 tf.errors.UnavailableError: when a peer is down. 

960 tf.errors.FailedPreconditionError: when a peer is a different one from the 

961 one this task has talked to, e.g. the peer has restarted. 

962 tf.errors.InvalidArgumentError: when the task string is invalid. 

963 """ 

964 self.ensure_initialized() 

965 pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task, 

966 timeout_in_ms) 

967 

968 @property 

969 def _handle(self): 

970 if self._context_handle is None: 

971 raise AssertionError("Context must be initialized first.") 

972 

973 return self._context_handle 

974 

975 @property 

976 def _devices(self): 

977 if self._context_devices is None: 

978 raise AssertionError("Context must be initialized first.") 

979 

980 return self._context_devices 

981 

982 def __str__(self): 

983 if self._context_handle is None: 

984 return "Eager TensorFlow Context. Devices currently uninitialized." 

985 else: 

986 devices = self._devices 

987 lines = ["Eager TensorFlow Context with %d devices" % (len(devices))] 

988 for i, d in enumerate(devices): 

989 lines.append(" Device %d: %s" % (i, d)) 

990 return "\n".join(lines) 

991 

992 @tf_contextlib.contextmanager 

993 def _mode(self, mode): 

994 """A context manager to allow setting the mode to EAGER/GRAPH.""" 

995 ctx = self._thread_local_data 

996 old_is_eager = ctx.is_eager 

997 ctx.is_eager = mode == EAGER_MODE 

998 if mode == EAGER_MODE: 

999 # Entering graph mode does not provide us with sufficient information to 

1000 # record a context switch; graph-based context switches are only logged 

1001 # when a graph is registered as the default graph. 

1002 self.context_switches.push(False, eager_mode, None) 

1003 try: 

1004 yield 

1005 finally: 

1006 ctx.is_eager = old_is_eager 

1007 if mode == EAGER_MODE: 

1008 self.context_switches.pop() 

1009 

1010 def executing_eagerly(self): 

1011 """Returns True if current thread has eager executing enabled.""" 

1012 return self._thread_local_data.is_eager 

1013 

1014 def ones_rank_cache(self): 

1015 """Per-device cache for scalars.""" 

1016 return _tensor_caches_map[self._id].ones_rank_cache 

1017 

1018 def zeros_cache(self): 

1019 """Per-device cache for scalars.""" 

1020 return _tensor_caches_map[self._id].zeros_cache 

1021 

1022 @property 

1023 def scope_name(self): 

1024 """Returns scope name for the current thread.""" 

1025 return self._thread_local_data.scope_name 

1026 

1027 @scope_name.setter 

1028 def scope_name(self, s): 

1029 """Sets scope name for the current thread.""" 

1030 self._thread_local_data.scope_name = s 

1031 

1032 @property 

1033 def device_name(self): 

1034 """Returns the device name for the current thread.""" 

1035 return self._thread_local_data.device_name 

1036 

1037 @property 

1038 def device_spec(self): 

1039 """Returns the device spec for the current thread.""" 

1040 return self._thread_local_data.device_spec 

1041 

1042 def _set_device(self, device_name, device_spec): 

1043 self._thread_local_data.device_name = device_name 

1044 self._thread_local_data.device_spec = device_spec 

1045 

1046 def device(self, name): 

1047 """Context-manager to force placement of operations and Tensors on a device. 

1048 

1049 Args: 

1050 name: Name of the device or None to get default placement. 

1051 

1052 Returns: 

1053 Context manager that forces device placement. 

1054 

1055 Raises: 

1056 ValueError: If name is not a string or is an invalid device name. 

1057 RuntimeError: If device scopes are not properly nested. 

1058 """ 

1059 if isinstance(name, LogicalDevice): 

1060 name = name.name 

1061 elif pydev.is_device_spec(name): 

1062 name = name.to_string() 

1063 return _EagerDeviceContext(self, name) 

1064 
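# A short sketch of the device context manager above; the device string is
# illustrative and must name a device that actually exists:
#
#   import tensorflow as tf
#
#   ctx = context()
#   ctx.ensure_initialized()
#   with ctx.device("/job:localhost/replica:0/task:0/device:CPU:0"):
#     x = tf.constant([1.0, 2.0])   # Placed on the requested device.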

1065 def devices(self): 

1066 """List of the names of devices available to execute operations.""" 

1067 return self._devices 

1068 

1069 def host_address_space(self): 

1070 self.ensure_initialized() 

1071 with c_api_util.tf_buffer() as buffer_: 

1072 pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_) 

1073 address_space = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8") 

1074 return address_space 

1075 

1076 # TODO(fishx): remove this property. 

1077 @property 

1078 def execution_mode(self): 

1079 """Gets execution mode for current thread.""" 

1080 return ASYNC if self.is_async() else SYNC 

1081 

1082 @execution_mode.setter 

1083 def execution_mode(self, mode): 

1084 """Sets execution mode for current thread.""" 

1085 if mode not in (None, SYNC, ASYNC): 

1086 raise ValueError("Execution mode should be None/SYNC/ASYNC. Got %s" % 

1087 mode) 

1088 

1089 if mode is None: 

1090 mode = SYNC 

1091 

1092 enable_async = (mode == ASYNC) 

1093 if self.is_async() != enable_async: 

1094 # Only set the execution mode if the context has already been initialized 

1095 if self._context_handle is not None: 

1096 self.executor.wait() 

1097 executor_new = executor.new_executor(enable_async) 

1098 self._thread_local_data.executor = executor_new 

1099 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, 

1100 executor_new.handle()) 

1101 else: 

1102 self._default_is_async = enable_async 

1103 

1104 def is_async(self): 

1105 if self._context_handle is not None: 

1106 return self.executor.is_async() 

1107 else: 

1108 return self._default_is_async 

1109 

1110 @property 

1111 def executor(self): 

1112 self.ensure_initialized() 

1113 return executor.Executor( 

1114 pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle)) 

1115 

1116 @executor.setter 

1117 def executor(self, e): 

1118 self.ensure_initialized() 

1119 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle()) 

1120 

1121 @property 

1122 def config(self): 

1123 """Return the ConfigProto with all runtime deltas applied.""" 

1124 # Ensure physical devices have been discovered and config has been imported 

1125 self._initialize_physical_devices() 

1126 

1127 config = config_pb2.ConfigProto() 

1128 if self._config is not None: 

1129 config.CopyFrom(self._config) 

1130 

1131 if self._optimizer_jit is not None: 

1132 config.graph_options.optimizer_options.global_jit_level = ( 

1133 config_pb2.OptimizerOptions.ON_1 

1134 if self._optimizer_jit else config_pb2.OptimizerOptions.OFF) 

1135 if self._intra_op_parallelism_threads is not None: 

1136 config.intra_op_parallelism_threads = self._intra_op_parallelism_threads 

1137 if self._inter_op_parallelism_threads is not None: 

1138 config.inter_op_parallelism_threads = self._inter_op_parallelism_threads 

1139 

1140 if self._soft_device_placement is not None: 

1141 config.allow_soft_placement = self._soft_device_placement 

1142 else: 

1143 config.allow_soft_placement = self.executing_eagerly() 

1144 

1145 if self._log_device_placement is not None: 

1146 config.log_device_placement = self._log_device_placement 

1147 

1148 if self._operation_timeout_in_ms is not None: 

1149 config.operation_timeout_in_ms = self._operation_timeout_in_ms 

1150 

1151 is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled() 

1152 config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled 

1153 if (is_mlir_bridge_enabled == 

1154 config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED): 

1155 config.experimental.enable_mlir_bridge = True 

1156 

1157 if self._enable_mlir_graph_optimization is not None: 

1158 config.experimental.enable_mlir_graph_optimization = ( 

1159 self._enable_mlir_graph_optimization) 

1160 

1161 def rewriter_toggle(option): 

1162 toggle = self._optimizer_experimental_options.get(option, None) 

1163 if toggle is None: 

1164 return 

1165 

1166 setattr(config.graph_options.rewrite_options, option, 

1167 (rewriter_config_pb2.RewriterConfig.ON 

1168 if toggle else rewriter_config_pb2.RewriterConfig.OFF)) 

1169 

1170 def rewriter_bool(option): 

1171 toggle = self._optimizer_experimental_options.get(option, None) 

1172 if toggle is None: 

1173 return 

1174 

1175 setattr(config.graph_options.rewrite_options, option, toggle) 

1176 

1177 rewriter_toggle("layout_optimizer") 

1178 rewriter_toggle("constant_folding") 

1179 rewriter_toggle("shape_optimization") 

1180 rewriter_toggle("remapping") 

1181 rewriter_toggle("arithmetic_optimization") 

1182 rewriter_toggle("dependency_optimization") 

1183 rewriter_toggle("loop_optimization") 

1184 rewriter_toggle("function_optimization") 

1185 rewriter_toggle("debug_stripper") 

1186 rewriter_bool("disable_model_pruning") 

1187 rewriter_toggle("scoped_allocator_optimization") 

1188 rewriter_toggle("pin_to_host_optimization") 

1189 rewriter_toggle("implementation_selector") 

1190 rewriter_toggle("auto_mixed_precision") 

1191 rewriter_toggle("use_plugin_optimizers") 

1192 rewriter_bool("disable_meta_optimizer") 

1193 rewriter_toggle("auto_mixed_precision_onednn_bfloat16") 

1194 rewriter_toggle("auto_mixed_precision_mkl") 

1195 nodes = self._optimizer_experimental_options.get("min_graph_nodes", None) 

1196 if nodes is not None: 

1197 config.graph_options.rewrite_options.min_graph_nodes = nodes 

1198 

1199 # Compute device counts 

1200 config.device_count["CPU"] = 0 

1201 config.device_count["GPU"] = 0 

1202 for dev in self._physical_devices: 

1203 if dev not in self._visible_device_list: 

1204 continue 

1205 

1206 virtual_devices = self._virtual_device_map.get(dev) 

1207 if virtual_devices is None: 

1208 config.device_count[dev.device_type] += 1 

1209 else: 

1210 config.device_count[dev.device_type] += len(virtual_devices) 

1211 

1212 # Configure gpu_options 

1213 gpu_options = self._compute_gpu_options() 

1214 config.gpu_options.MergeFrom(gpu_options) 

1215 

1216 # Configure collective ops 

1217 if self._collective_leader: 

1218 config.experimental.collective_group_leader = self._collective_leader 

1219 if self._collective_scoped_allocator_enabled_ops: 

1220 rewrite_options = config.graph_options.rewrite_options 

1221 rewrite_options.scoped_allocator_optimization = ( 

1222 rewriter_config_pb2.RewriterConfig.ON) 

1223 del rewrite_options.scoped_allocator_opts.enable_op[:] 

1224 for op in self._collective_scoped_allocator_enabled_ops: 

1225 rewrite_options.scoped_allocator_opts.enable_op.append(op) 

1226 if self._collective_use_nccl_communication: 

1227 config.experimental.collective_nccl = True 

1228 if self._collective_device_filters: 

1229 del config.device_filters[:] 

1230 for f in self._collective_device_filters: 

1231 config.device_filters.append(f) 

1232 

1233 # Configure coordination service 

1234 if self._coordination_service_config: 

1235 config.experimental.coordination_config.CopyFrom( 

1236 self._coordination_service_config) 

1237 

1238 return config 

1239 

1240 def _compute_gpu_options(self): 

1241 """Build the GPUOptions proto.""" 

1242 visible_device_list = [] 

1243 virtual_devices = [] 

1244 gpu_index = -1 

1245 memory_growths = set() 

1246 gpu_devices = self.list_physical_devices("GPU") 

1247 pluggable_devices = self._pluggable_devices 

1248 compatible_devices = gpu_devices 

1249 for dev in pluggable_devices: 

1250 if dev not in gpu_devices: 

1251 compatible_devices.append(dev) 

1252 for dev in compatible_devices: 

1253 gpu_index += 1 

1254 

1255 if dev not in self._visible_device_list: 

1256 continue 

1257 

1258 growth = self._memory_growth_map[dev] 

1259 memory_growths.add(growth) 

1260 visible_device_list.append(str(gpu_index)) 

1261 

1262 if self._virtual_device_map: 

1263 vdevs = self._virtual_device_map.get(dev, []) 

1264 device_ordinals = [] 

1265 device_limits = [] 

1266 priority = [] 

1267 for virt_dev in vdevs: 

1268 if virt_dev.experimental_device_ordinal is not None: 

1269 device_ordinals.append(virt_dev.experimental_device_ordinal) 

1270 device_limits.append(virt_dev.memory_limit) 

1271 if virt_dev.experimental_priority is not None: 

1272 priority.append(virt_dev.experimental_priority) 

1273 # If priority is specified, it must be specified for all virtual 

1274 # devices. 

1275 if priority and len(device_limits) != len(priority): 

1276 raise ValueError("priority must be specified for all virtual devices") 

1277 # If device_ordinals is specified, it must be specified for all virtual 

1278 # devices. 

1279 if device_ordinals and len(device_limits) != len(device_ordinals): 

1280 raise ValueError( 

1281 "device_ordinals must be specified for all virtual devices") 

1282 

1283 virtual_devices.append( 

1284 config_pb2.GPUOptions.Experimental.VirtualDevices( 

1285 memory_limit_mb=device_limits, 

1286 priority=priority, 

1287 device_ordinal=device_ordinals)) 

1288 

1289 # Only compute growth if virtual devices have not been configured and we 

1290 # have GPUs 

1291 if not virtual_devices and memory_growths: 

1292 if len(memory_growths) > 1: 

1293 raise ValueError("Memory growth cannot differ between GPU devices") 

1294 allow_growth = memory_growths.pop() 

1295 else: 

1296 allow_growth = None 

1297 

1298 return config_pb2.GPUOptions( 

1299 allow_growth=allow_growth, 

1300 visible_device_list=",".join(visible_device_list), 

1301 experimental=config_pb2.GPUOptions.Experimental( 

1302 virtual_devices=virtual_devices)) 

1303 

1304 @property 

1305 def function_call_options(self): 

1306 """Returns function call options for current thread. 

1307 

1308 Note that the returned object is still referenced by the eager context. 

1309 

1310 Returns: the FunctionCallOptions for current thread. 

1311 """ 

1312 if self._thread_local_data.function_call_options is None: 

1313 config = self.config 

1314 

1315 # Default to soft placement for functions unless specified 

1316 if self._soft_device_placement is None: 

1317 config.allow_soft_placement = True 

1318 self._thread_local_data.function_call_options = FunctionCallOptions( 

1319 config_proto=config) 

1320 

1321 return self._thread_local_data.function_call_options 

1322 

1323 @function_call_options.setter 

1324 def function_call_options(self, options): 

1325 """Returns function call options for current thread.""" 

1326 self._thread_local_data.function_call_options = options 

1327 

1328 def num_gpus(self): 

1329 """The number of GPUs available to execute operations.""" 

1330 self.ensure_initialized() 

1331 return self._num_gpus 

1332 

1333 def add_c_function(self, c_func): 

1334 """Add a C API TF_Function to the context. 

1335 

1336 Once added, the function (identified by its name) can be executed like any 

1337 other operation. 

1338 

1339 Args: 

1340 c_func: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). 

1341 """ 

1342 self.ensure_initialized() 

1343 pywrap_tfe.TFE_ContextAddFunction(self._handle, c_func) 

1344 

1345 def get_c_function(self, name): 

1346 """Get a C API TF_Function from the context. 

1347 

1348 Args: 

1349 name: Name of the function to get. 

1350 

1351 Returns: 

1352 A ScopedTFFunction wrapping the C API TF_Function. 

1353 """ 

1354 self.ensure_initialized() 

1355 return c_api_util.ScopedTFFunction( 

1356 pywrap_tfe.TFE_ContextGetFunction(self._handle, name), name 

1357 ) 

1358 

1359 def add_function_def(self, fdef): 

1360 """Add a function definition to the context. 

1361 

1362 Once added, the function (identified by its name) can be executed like any 

1363 other operation. 

1364 

1365 Args: 

1366 fdef: A FunctionDef protocol buffer message. 

1367 """ 

1368 self.ensure_initialized() 

1369 fdef_string = fdef.SerializeToString() 

1370 pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string, 

1371 len(fdef_string)) 

1372 

1373 def get_function_def(self, name): 

1374 """Get a function definition from the context. 

1375 

1376 Args: 

1377 name: function signature name. 

1378 

1379 Returns: 

1380 The requested FunctionDef. 

1381 

1382 Raises: 

1383 tf.errors.NotFoundError: if name is not the name of a registered function. 

1384 """ 

1385 with c_api_util.tf_buffer() as buffer_: 

1386 pywrap_tfe.TFE_ContextGetFunctionDef(self._handle, name, buffer_) 

1387 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_) 

1388 function_def = function_pb2.FunctionDef() 

1389 function_def.ParseFromString(proto_data) 

1390 

1391 return function_def 

1392 
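# A hedged sketch of looking up a traced function's registered definition; the
# `function_def.signature.name` accessor on the concrete function is an
# assumption made for illustration:
#
#   import tensorflow as tf
#
#   @tf.function
#   def add_one(x):
#     return x + 1
#
#   cf = add_one.get_concrete_function(tf.TensorSpec([], tf.float32))
#   name = cf.function_def.signature.name
#   assert context().has_function(name)
#   fdef = context().get_function_def(name)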

1393 def is_custom_device(self, device_name): 

1394 """Calls TFE_IsCustomDevice. See the non-member function.""" 

1395 self.ensure_initialized() 

1396 return pywrap_tfe.TFE_Py_IsCustomDevice(self._handle, device_name) 

1397 

1398 def register_custom_device(self, device_capsule, device_name, 

1399 device_info_capsule): 

1400 """Calls TFE_RegisterCustomDevice. See the non-member function.""" 

1401 self.ensure_initialized() 

1402 pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule, 

1403 device_name, device_info_capsule) 

1404 

1405 def pack_eager_tensors(self, tensors): 

1406 """Pack multiple `EagerTensor`s of the same dtype and shape. 

1407 

1408 Args: 

1409 tensors: a list of EagerTensors to pack. 

1410 

1411 Returns: 

1412 A packed EagerTensor. 

1413 """ 

1414 self.ensure_initialized() 

1415 return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors) 

1416 

1417 def list_function_names(self): 

1418 """Get a list of names of registered functions. 

1419 

1420 Returns: 

1421 A set of names of all registered functions for the context. 

1422 """ 

1423 self.ensure_initialized() 

1424 return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle)) 

1425 

1426 def remove_function(self, name): 

1427 """Remove a function from the context. 

1428 

1429 Once removed, the function cannot be executed anymore. 

1430 

1431 Args: 

1432 name: function signature name. 

1433 """ 

1434 self.ensure_initialized() 

1435 pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name) 

1436 

1437 def has_function(self, name): 

1438 """Check if a function `name` is registered.""" 

1439 self.ensure_initialized() 

1440 return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name)) 

1441 

1442 @property 

1443 def function_scope_id(self): 

1444 """Returns an id that is unique to each scope holding functions.""" 

1445 return id(self._context_handle) 

1446 

1447 def call_function(self, name, tensor_inputs, num_outputs): 

1448 """Calls the function associated with the given name.""" 

1449 attrs = tuple( 

1450 itertools.chain( 

1451 *self.function_call_options.as_attrs().items() 

1452 ) 

1453 ) 

1454 

1455 cancellation_context = cancellation.context() 

1456 if cancellation_context is None: 

1457 outputs = execute.execute( 

1458 name.decode("utf-8"), 

1459 num_outputs=num_outputs, 

1460 inputs=tensor_inputs, 

1461 attrs=attrs, 

1462 ctx=self, 

1463 ) 

1464 else: 

1465 outputs = execute.execute_with_cancellation( 

1466 name.decode("utf-8"), 

1467 num_outputs=num_outputs, 

1468 inputs=tensor_inputs, 

1469 attrs=attrs, 

1470 ctx=self, 

1471 cancellation_manager=cancellation_context, 

1472 ) 

1473 # Empty list means no function outputs so return None 

1474 outputs = outputs or None 

1475 

1476 return outputs 

1477 

1478 def add_op_callback(self, callback): 

1479 """Add a post-op callback to the context. 

1480 

1481 A post-op callback is invoked immediately after an eager operation or 

1482 function has finished execution or after an op has been added to a graph, 

1483 providing access to the op's type, name, input and output tensors. Multiple 

1484 op callbacks can be added, in which case the callbacks will be invoked in 

1485 the order in which they are added. 

1486 

1487 Args: 

1488 callback: a callable of the signature `f(op_type, inputs, attrs, outputs, 

1489 op_name=None, graph=None)`. See doc strings in `op_callbacks.py` for 

1490 details on the function signature and its semantics. 

1491 """ 

1492 if callback not in self._thread_local_data.op_callbacks: 

1493 self._thread_local_data.op_callbacks.append(callback) 

1494 

1495 def remove_op_callback(self, callback): 

1496 """Remove an already-registered op callback. 

1497 

1498 Args: 

1499 callback: The op callback to be removed. 

1500 

1501 Raises: 

1502 KeyError: If `callback` is not already registered. 

1503 """ 

1504 if callback not in self._thread_local_data.op_callbacks: 

1505 raise KeyError("The specified op callback has not been registered, " 

1506 "and hence cannot be removed.") 

1507 del self._thread_local_data.op_callbacks[ 

1508 self._thread_local_data.op_callbacks.index(callback)] 

1509 
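  # Illustrative op callback matching the signature documented in
  # `add_op_callback` above; the logging body is an assumption:
  #
  #   def _log_op(op_type, inputs, attrs, outputs, op_name=None, graph=None):
  #     logging.info("ran %s with %d input(s)", op_type, len(inputs))
  #     return None  # returning None leaves the op's outputs unchanged
  #
  #   ctx = context()
  #   ctx.add_op_callback(_log_op)
  #   ...  # run some eager ops
  #   ctx.remove_op_callback(_log_op)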

1510 @property 

1511 def op_callbacks(self): 

1512 return self._thread_local_data.op_callbacks 

1513 

1514 @property 

1515 def invoking_op_callbacks(self): 

1516 return self._thread_local_data.invoking_op_callbacks 

1517 

1518 @invoking_op_callbacks.setter 

1519 def invoking_op_callbacks(self, value): 

1520 self._thread_local_data.invoking_op_callbacks = value 

1521 

1522 def _initialize_physical_devices(self, reinitialize=False): 

1523 """Gets local devices visible to the system. 

1524 

1525 Args: 

1526 reinitialize: If True, reinitializes self._physical_devices so that 

1527 dynamic registered devices will also be visible to the python front-end. 

1528 """ 

1529 # We lazily initialize self._physical_devices since we do not want to do this 

1530 # in the constructor, as the backend may not be initialized yet. 

1531 with self._device_lock: 

1532 if not reinitialize and self._physical_devices is not None: 

1533 return 

1534 

1535 devs = pywrap_tfe.TF_ListPhysicalDevices() 

1536 self._physical_devices = [ 

1537 PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1]) 

1538 for d in devs 

1539 ] 

1540 self._physical_device_to_index = { 

1541 p: i for i, p in enumerate(self._physical_devices) 

1542 } 

1543 # We maintain a separate list just so we can check whether the device in 

1544 # _physical_devices is a PluggableDevice. 

1545 pluggable_devs = pywrap_tfe.TF_ListPluggablePhysicalDevices() 

1546 self._pluggable_devices = [ 

1547 PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1]) 

1548 for d in pluggable_devs 

1549 ] 

1550 

1551 self._visible_device_list = list(self._physical_devices) 

1552 self._memory_growth_map = { 

1553 d: None 

1554 for d in self._physical_devices 

1555 if d.device_type == "GPU" or d in self._pluggable_devices 

1556 } 

1557 

1558 # Import device settings that may have been passed into the constructor 

1559 self._import_config() 

1560 

1561 def reinitialize_physical_devices(self): 

1562 """Gets local devices visible to the system.""" 

1563 # Reinitialize the physical device list after registering 

1564 # the pluggable device. 

1565 self._initialize_physical_devices(True) 

1566 

1567 def list_physical_devices(self, device_type=None): 

1568 """List local devices visible to the system. 

1569 

1570 This API allows a client to query the devices before they have been 

1571 initialized by the eager runtime. Additionally a user can filter by device 

1572 type, to get only CPUs or GPUs. 

1573 

1574 Args: 

1575 device_type: Optional device type to limit results to 

1576 

1577 Returns: 

1578 List of PhysicalDevice objects. 

1579 """ 

1580 self._initialize_physical_devices() 

1581 

1582 if device_type is None: 

1583 return list(self._physical_devices) 

1584 

1585 return [d for d in self._physical_devices if d.device_type == device_type] 

1586 
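  # Minimal sketch of the typical call path through the public
  # `tf.config.list_physical_devices` wrapper around this method:
  #
  #   import tensorflow as tf
  #   gpus = tf.config.list_physical_devices("GPU")    # filter by device type
  #   all_devices = tf.config.list_physical_devices()  # CPUs, GPUs, ...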

1587 def get_device_details(self, device): # pylint: disable=redefined-outer-name 

1588 """Returns details about a physical devices. 

1589 

1590 Args: 

1591 device: A `tf.config.PhysicalDevice` returned by 

1592 `tf.config.list_physical_devices` or `tf.config.get_visible_devices`. 

1593 

1594 Returns: 

1595 A dict with string keys. 

1596 """ 

1597 if not isinstance(device, PhysicalDevice): 

1598 raise ValueError("device must be a tf.config.PhysicalDevice, but got: " 

1599 "%s" % (device,)) 

1600 if (self._physical_device_to_index is None or 

1601 device not in self._physical_device_to_index): 

1602 raise ValueError("The PhysicalDevice must be one obtained from " 

1603 "calling `tf.config.list_physical_devices`, but got: " 

1604 "%s" % (device,)) 

1605 index = self._physical_device_to_index[device] 

1606 details = pywrap_tfe.TF_GetDeviceDetails(index) 

1607 

1608 # Change compute_capability from a string to a tuple 

1609 if "compute_capability" in details: 

1610 try: 

1611 major, minor = details["compute_capability"].split(".") 

1612 details["compute_capability"] = (int(major), int(minor)) 

1613 except ValueError: 

1614 raise RuntimeError("Device returned compute capability in an invalid " 

1615 "format: %s" % details["compute_capability"]) 

1616 return details 

1617 
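  # Hedged sketch via the public `tf.config.experimental.get_device_details`
  # wrapper; the "compute_capability" key is only present for CUDA GPUs:
  #
  #   import tensorflow as tf
  #   gpus = tf.config.list_physical_devices("GPU")
  #   if gpus:
  #     details = tf.config.experimental.get_device_details(gpus[0])
  #     print(details.get("device_name"), details.get("compute_capability"))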

1618 def _import_config(self): 

1619 """Import config if passed in during construction. 

1620 

1621 If Context was created with a ConfigProto such as when calling 

1622 tf.compat.v1.enable_eager_execution(), then we need to pull out the 

1623 various pieces we might be replacing and import them into our internal 

1624 class representation. 

1625 """ 

1626 if self._config is None: 

1627 return 

1628 

1629 num_cpus = self._config.device_count.get("CPU", 1) 

1630 if num_cpus != 1: 

1631 cpus = [d for d in self._physical_devices if d.device_type == "CPU"] 

1632 if num_cpus == 0: 

1633 self.set_visible_devices([], "CPU") 

1634 elif num_cpus > 1: 

1635 self.set_logical_device_configuration( 

1636 cpus[0], [LogicalDeviceConfiguration() for _ in range(num_cpus)]) 

1637 

1638 # Parse GPU options 

1639 gpus = [d for d in self._physical_devices if d.device_type == "GPU"] 

1640 

1641 # If there are no GPUs detected, simply ignore all the GPU options passed in 

1642 # rather than doing any validation checks. 

1643 if not gpus: 

1644 return 

1645 

1646 gpu_count = self._config.device_count.get("GPU", None) 

1647 

1648 visible_gpus = [] 

1649 # TODO(gjn): Handle importing existing virtual GPU configuration 

1650 visible_indices = self._config.gpu_options.visible_device_list 

1651 if visible_indices: 

1652 for index in visible_indices.split(","): 

1653 if int(index) >= len(gpus): 

1654 raise ValueError("Invalid visible device index: %s" % index) 

1655 visible_gpus.append(gpus[int(index)]) 

1656 else: 

1657 visible_gpus = gpus 

1658 

1659 if gpu_count is not None: 

1660 visible_gpus = visible_gpus[:gpu_count] 

1661 

1662 self.set_visible_devices(visible_gpus, "GPU") 

1663 

1664 def list_logical_devices(self, device_type=None): 

1665 """Return logical devices.""" 

1666 self.ensure_initialized() 

1667 if device_type is None: 

1668 return list(self._logical_devices) 

1669 

1670 return [d for d in self._logical_devices if d.device_type == device_type] 

1671 

1672 def get_visible_devices(self, device_type=None): 

1673 """Get the list of visible devices.""" 

1674 self._initialize_physical_devices() 

1675 

1676 if device_type is None: 

1677 return list(self._visible_device_list) 

1678 

1679 return [ 

1680 d for d in self._visible_device_list if d.device_type == device_type 

1681 ] 

1682 

1683 def set_visible_devices(self, devices, device_type=None): 

1684 """Set the list of visible devices.""" 

1685 self._initialize_physical_devices() 

1686 

1687 if not isinstance(devices, list): 

1688 devices = [devices] 

1689 

1690 for d in devices: 

1691 if d not in self._physical_devices: 

1692 raise ValueError("Unrecognized device: %s" % repr(d)) 

1693 if device_type is not None and d.device_type != device_type: 

1694 raise ValueError("Unrecognized device: %s" % repr(d)) 

1695 

1696 visible_device_list = [] 

1697 if device_type is not None: 

1698 visible_device_list = [ 

1699 d for d in self._visible_device_list if d.device_type != device_type 

1700 ] 

1701 

1702 visible_device_list += devices 

1703 

1704 if self._visible_device_list == visible_device_list: 

1705 return 

1706 

1707 if self._context_handle is not None: 

1708 raise RuntimeError( 

1709 "Visible devices cannot be modified after being initialized") 

1710 

1711 self._visible_device_list = visible_device_list 

1712 
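  # Hedged sketch using the public `tf.config.set_visible_devices` wrapper;
  # note the constraint above that visibility cannot change once the context
  # handle exists:
  #
  #   import tensorflow as tf
  #   gpus = tf.config.list_physical_devices("GPU")
  #   if gpus:
  #     tf.config.set_visible_devices(gpus[:1], "GPU")  # expose only one GPU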

1713 def get_memory_info(self, dev): 

1714 """Returns a dict of memory info for the device.""" 

1715 self._initialize_physical_devices() 

1716 self.ensure_initialized() 

1717 return pywrap_tfe.TFE_GetMemoryInfo(self._context_handle, dev) 

1718 

1719 def reset_memory_stats(self, dev): 

1720 """Resets the tracked memory stats for the device.""" 

1721 self._initialize_physical_devices() 

1722 self.ensure_initialized() 

1723 pywrap_tfe.TFE_ResetMemoryStats(self._context_handle, dev) 

1724 
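  # Hedged sketch of the memory-stat helpers through their public wrappers
  # (`tf.config.experimental.get_memory_info` / `reset_memory_stats`); the
  # device string "GPU:0" is an assumption:
  #
  #   import tensorflow as tf
  #   info = tf.config.experimental.get_memory_info("GPU:0")
  #   print(info["current"], info["peak"])  # bytes
  #   tf.config.experimental.reset_memory_stats("GPU:0")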

1725 def get_memory_growth(self, dev): 

1726 """Get if memory growth is enabled for a PhysicalDevice.""" 

1727 self._initialize_physical_devices() 

1728 

1729 if dev not in self._physical_devices: 

1730 raise ValueError("Unrecognized device: %s" % repr(dev)) 

1731 

1732 return self._memory_growth_map[dev] 

1733 

1734 def set_memory_growth(self, dev, enable): 

1735 """Set if memory growth should be enabled for a PhysicalDevice.""" 

1736 self._initialize_physical_devices() 

1737 

1738 if dev not in self._physical_devices: 

1739 raise ValueError("Unrecognized device: %s" % repr(dev)) 

1740 

1741 if dev in self._virtual_device_map: 

1742 raise ValueError( 

1743 "Cannot set memory growth on device when virtual devices configured") 

1744 

1745 if dev.device_type != "GPU" and dev not in self._pluggable_devices: 

1746 raise ValueError( 

1747 "Cannot set memory growth on non-GPU and non-Pluggable devices") 

1748 

1749 if self._memory_growth_map.get(dev) == enable: 

1750 return 

1751 

1752 if self._context_handle is not None: 

1753 raise RuntimeError( 

1754 "Physical devices cannot be modified after being initialized") 

1755 

1756 self._memory_growth_map[dev] = enable 

1757 
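  # Hedged sketch via `tf.config.experimental.set_memory_growth`; as enforced
  # above, it must run before the context is initialized and only applies to
  # GPU or pluggable devices:
  #
  #   import tensorflow as tf
  #   for gpu in tf.config.list_physical_devices("GPU"):
  #     tf.config.experimental.set_memory_growth(gpu, True)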

1758 def get_logical_device_configuration(self, dev): 

1759 """Get the virtual device configuration for a PhysicalDevice.""" 

1760 self._initialize_physical_devices() 

1761 

1762 if dev not in self._physical_devices: 

1763 raise ValueError("Unrecognized device: %s" % repr(dev)) 

1764 

1765 return self._virtual_device_map.get(dev) 

1766 

1767 def set_logical_device_configuration(self, dev, virtual_devices): 

1768 """Set the virtual device configuration for a PhysicalDevice.""" 

1769 self._initialize_physical_devices() 

1770 

1771 if dev not in self._physical_devices: 

1772 raise ValueError("Unrecognized device: %s" % repr(dev)) 

1773 

1774 if dev.device_type == "CPU": 

1775 for vdev in virtual_devices: 

1776 if vdev.memory_limit is not None: 

1777 raise ValueError("Setting memory limit on CPU virtual devices is " 

1778 "currently not supported") 

1779 if vdev.experimental_priority is not None: 

1780 raise ValueError("Setting experimental_priority on CPU virtual " 

1781 " devices is currently not supported") 

1782 if vdev.experimental_device_ordinal is not None: 

1783 raise ValueError("Setting experimental_device_ordinal on CPU virtual " 

1784 " devices is currently not supported") 

1785 elif dev.device_type == "GPU": 

1786 for vdev in virtual_devices: 

1787 if vdev.memory_limit is None: 

1788 raise ValueError( 

1789 "Setting memory limit is required for GPU virtual devices") 

1790 else: 

1791 raise ValueError("Virtual devices are not supported for %s" % 

1792 dev.device_type) 

1793 

1794 if self._virtual_device_map.get(dev) == virtual_devices: 

1795 return 

1796 

1797 if self._context_handle is not None: 

1798 raise RuntimeError( 

1799 "Virtual devices cannot be modified after being initialized") 

1800 

1801 self._virtual_device_map[dev] = virtual_devices 

1802 
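  # Hedged sketch using the public wrappers; the 1024 MB memory limits are
  # assumptions, and (per the checks above) GPU virtual devices require a
  # memory limit while CPU ones must not set one:
  #
  #   import tensorflow as tf
  #   gpus = tf.config.list_physical_devices("GPU")
  #   if gpus:
  #     tf.config.set_logical_device_configuration(
  #         gpus[0],
  #         [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
  #          tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
  #     logical_gpus = tf.config.list_logical_devices("GPU")  # now two devices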

1803 def set_logical_cpu_devices(self, num_cpus, prefix=""): 

1804 """Set virtual CPU devices in context. 

1805 

1806 If virtual CPU devices are already configured at context initialization 

1807 by tf.config.set_logical_device_configuration(), this method should not be 

1808 called. 

1809 

1810 Args: 

1811 num_cpus: Number of virtual CPUs. 

1812 prefix: Device name prefix. 

1813 

1814 Raises: 

1815 RuntimeError: If virtual CPUs are already configured at context 

1816 initialization. 

1817 """ 

1818 server_def = self._server_def or self._collective_ops_server_def 

1819 local_prefix = ["/device"] 

1820 if server_def is not None: 

1821 local_prefix.append("/job:%s/replica:0/task:%d" % (server_def.job_name, 

1822 server_def.task_index)) 

1823 logical_local_devices = [d for d in self.list_logical_devices("CPU") if 

1824 d.name.startswith(tuple(local_prefix))] 

1825 self.ensure_initialized() 

1826 # Error out if there are already multiple logical CPUs in the context. 

1827 if len(logical_local_devices) > 1: 

1828 raise RuntimeError("Virtual CPUs already set, cannot modify again.") 

1829 

1830 pywrap_tfe.TFE_SetLogicalCpuDevices(self._context_handle, num_cpus, prefix) 

1831 self._initialize_logical_devices() 

1832 

1833 def get_compiler_ir( 

1834 self, 

1835 device_name, 

1836 function_name, 

1837 flat_args, 

1838 captured_inputs, 

1839 stage="hlo", 

1840 ): 

1841 return pywrap_tfe.TF_GetCompilerIr( 

1842 self._context_handle, 

1843 function_name, 

1844 stage, 

1845 device_name, 

1846 flat_args, 

1847 captured_inputs, 

1848 ) 

1849 

1850 @deprecated( 

1851 None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True) 

1852 def enable_xla_devices(self): 

1853 """Enables XLA:CPU and XLA:GPU devices registration.""" 

1854 pywrap_tfe.TF_EnableXlaDevices() 

1855 

1856 @property 

1857 def enable_mlir_bridge(self): 

1858 return pywrap_tfe.TF_IsMlirBridgeEnabled() 

1859 

1860 @property 

1861 def enable_mlir_graph_optimization(self): 

1862 return self._enable_mlir_graph_optimization 

1863 

1864 @enable_mlir_bridge.setter 

1865 def enable_mlir_bridge(self, enabled): 

1866 pywrap_tfe.TF_EnableMlirBridge(enabled) 

1867 self._thread_local_data.function_call_options = None 

1868 

1869 @enable_mlir_graph_optimization.setter 

1870 def enable_mlir_graph_optimization(self, enabled): 

1871 self._enable_mlir_graph_optimization = enabled 

1872 self._thread_local_data.function_call_options = None 

1873 

1874 @property 

1875 def optimizer_jit(self): 

1876 level = self.config.graph_options.optimizer_options.global_jit_level 

1877 return (level == config_pb2.OptimizerOptions.ON_1 or 

1878 level == config_pb2.OptimizerOptions.ON_2) 

1879 

1880 @optimizer_jit.setter 

1881 def optimizer_jit(self, enabled): 

1882 self._optimizer_jit = enabled 

1883 

1884 self._thread_local_data.function_call_options = None 

1885 

1886 def get_optimizer_experimental_options(self): 

1887 """Get experimental options for the optimizer. 

1888 

1889 Returns: 

1890 Dictionary of current option values 

1891 """ 

1892 rewrite_options = self.config.graph_options.rewrite_options 

1893 options = {} 

1894 

1895 def rewriter_toggle(option): 

1896 attr = getattr(rewrite_options, option) 

1897 if attr != 0: 

1898 options[option] = (attr == rewriter_config_pb2.RewriterConfig.ON) 

1899 

1900 def rewriter_bool(option): 

1901 options[option] = getattr(rewrite_options, option) 

1902 

1903 rewriter_toggle("layout_optimizer") 

1904 rewriter_toggle("constant_folding") 

1905 rewriter_toggle("shape_optimization") 

1906 rewriter_toggle("remapping") 

1907 rewriter_toggle("arithmetic_optimization") 

1908 rewriter_toggle("dependency_optimization") 

1909 rewriter_toggle("loop_optimization") 

1910 rewriter_toggle("function_optimization") 

1911 rewriter_toggle("debug_stripper") 

1912 rewriter_bool("disable_model_pruning") 

1913 rewriter_toggle("scoped_allocator_optimization") 

1914 rewriter_toggle("pin_to_host_optimization") 

1915 rewriter_toggle("implementation_selector") 

1916 rewriter_toggle("auto_mixed_precision") 

1917 rewriter_toggle("use_plugin_optimizers") 

1918 rewriter_bool("disable_meta_optimizer") 

1919 rewriter_toggle("auto_mixed_precision_onednn_bfloat16") 

1920 rewriter_toggle("auto_mixed_precision_mkl") 

1921 

1922 if rewrite_options.min_graph_nodes != 0: 

1923 options["min_graph_nodes"] = rewrite_options.min_graph_nodes 

1924 

1925 return options 

1926 

1927 def set_optimizer_experimental_options(self, options): 

1928 """Set experimental options for the optimizer. 

1929 

1930 Args: 

1931 options: Dictionary of options to modify 

1932 """ 

1933 self._optimizer_experimental_options.update(options) 

1934 

1935 self._thread_local_data.function_call_options = None 

1936 
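  # Hedged sketch of toggling grappler options through the public
  # `tf.config.optimizer` wrappers; the specific options chosen here are
  # illustrative assumptions:
  #
  #   import tensorflow as tf
  #   tf.config.optimizer.set_experimental_options(
  #       {"layout_optimizer": True, "constant_folding": False})
  #   print(tf.config.optimizer.get_experimental_options())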

1937 @property 

1938 def intra_op_parallelism_threads(self): 

1939 return self.config.intra_op_parallelism_threads 

1940 

1941 @intra_op_parallelism_threads.setter 

1942 def intra_op_parallelism_threads(self, num_threads): 

1943 if self._intra_op_parallelism_threads == num_threads: 

1944 return 

1945 

1946 if self._context_handle is not None: 

1947 raise RuntimeError( 

1948 "Intra op parallelism cannot be modified after initialization.") 

1949 

1950 self._intra_op_parallelism_threads = num_threads 

1951 

1952 @property 

1953 def inter_op_parallelism_threads(self): 

1954 return self.config.inter_op_parallelism_threads 

1955 

1956 @inter_op_parallelism_threads.setter 

1957 def inter_op_parallelism_threads(self, num_threads): 

1958 if self._inter_op_parallelism_threads == num_threads: 

1959 return 

1960 

1961 if self._context_handle is not None: 

1962 raise RuntimeError( 

1963 "Inter op parallelism cannot be modified after initialization.") 

1964 

1965 self._inter_op_parallelism_threads = num_threads 

1966 

1967 @property 

1968 def soft_device_placement(self): 

1969 return self.config.allow_soft_placement 

1970 

1971 @soft_device_placement.setter 

1972 def soft_device_placement(self, enable): 

1973 if self._context_handle is not None: 

1974 pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable) 

1975 

1976 self._soft_device_placement = enable 

1977 self._thread_local_data.function_call_options = None 

1978 

1979 @property 

1980 def log_device_placement(self): 

1981 return self.config.log_device_placement 

1982 

1983 @log_device_placement.setter 

1984 def log_device_placement(self, enable): 

1985 if self._context_handle is not None: 

1986 pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable) 

1987 

1988 self._log_device_placement = enable 

1989 self._thread_local_data.function_call_options = None 

1990 

1991 @property 

1992 def jit_compile_rewrite(self): 

1993 return self._jit_compile_rewrite 

1994 

1995 @jit_compile_rewrite.setter 

1996 def jit_compile_rewrite(self, enable): 

1997 if self._context_handle is not None: 

1998 pywrap_tfe.TFE_ContextSetJitCompileRewrite(self._handle, enable) 

1999 self._jit_compile_rewrite = enable 

2000 

2001 @property 

2002 def device_policy(self): 

2003 # Only get the policy from the context if it has already been initialized 

2004 if self._context_handle is not None: 

2005 return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle) 

2006 

2007 return self._device_policy 

2008 

2009 @device_policy.setter 

2010 def device_policy(self, policy): 

2011 if policy is None: 

2012 policy = DEVICE_PLACEMENT_SILENT 

2013 

2014 if self._device_policy != policy: 

2015 self._device_policy = policy 

2016 

2017 # Only set the policy if the context has already been initialized 

2018 if self._context_handle is not None: 

2019 pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy( 

2020 self._handle, self._device_policy) 

2021 

2022 @property 

2023 def use_tfrt(self): 

2024 return self._use_tfrt 

2025 

2026 @use_tfrt.setter 

2027 def use_tfrt(self, tfrt): 

2028 """Sets whether to use TFRT.""" 

2029 if not isinstance(tfrt, bool): 

2030 raise ValueError("Expecting a boolean but got %s" % type(tfrt)) 

2031 

2032 if self._use_tfrt != tfrt: 

2033 if self._initialized: 

2034 raise ValueError("use_tfrt should be set before being initialized.") 

2035 self._use_tfrt = tfrt 

2036 

2037 @property 

2038 def operation_timeout_in_ms(self): 

2039 return self.config.operation_timeout_in_ms 

2040 

2041 @operation_timeout_in_ms.setter 

2042 def operation_timeout_in_ms(self, timeout_in_ms): 

2043 if self._operation_timeout_in_ms == timeout_in_ms: 

2044 return 

2045 

2046 if self._context_handle is not None: 

2047 raise RuntimeError( 

2048 "Operation timeout cannot be modified after initialization.") 

2049 

2050 self._operation_timeout_in_ms = timeout_in_ms 

2051 

2052 def enable_run_metadata(self): 

2053 """Enables tracing of op execution via RunMetadata. 

2054 

2055 To retrieve the accumulated metadata call context.export_run_metadata() 

2056 and to stop tracing call context.disable_run_metadata(). 

2057 """ 

2058 self.ensure_initialized() 

2059 pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle) 

2060 

2061 def disable_run_metadata(self): 

2062 """Disables tracing of op execution via RunMetadata.""" 

2063 if not self._context_handle: 

2064 return 

2065 pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle) 

2066 

2067 def enable_graph_collection(self): 

2068 """Enables graph collection of executed functions. 

2069 

2070 To retrieve the accumulated graphs call context.export_run_metadata() 

2071 and to stop collecting graphs call context.disable_graph_collection(). 

2072 """ 

2073 self.ensure_initialized() 

2074 pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle) 

2075 

2076 def disable_graph_collection(self): 

2077 """Disables graph collection of executed functions.""" 

2078 if not self._context_handle: 

2079 return 

2080 pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle) 

2081 

2082 def export_run_metadata(self): 

2083 """Returns a RunMetadata proto with accumulated information. 

2084 

2085 The returned protocol buffer contains information since the most recent call 

2086 to either enable_run_metadata or export_run_metadata. 

2087 

2088 Returns: 

2089 A RunMetadata protocol buffer, or None if not enabled. 

2090 """ 

2091 if not self._context_handle: 

2092 return None 

2093 with c_api_util.tf_buffer() as buffer_: 

2094 pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_) 

2095 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_) 

2096 run_metadata = config_pb2.RunMetadata() 

2097 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 

2098 return run_metadata 

2099 

2100 @property 

2101 def context_switches(self): 

2102 """Returns a stack of context switches.""" 

2103 return self._context_switches 

2104 

2105 

2106class _EagerDeviceContext(object): 

2107 """Context-manager forcing placement of ops and Tensors on a device.""" 

2108 

2109 __slots__ = ["_device_name", "_ctx", "_stack"] 

2110 

2111 def __init__(self, ctx, device_name): 

2112 self._device_name = device_name 

2113 self._ctx = ctx 

2114 self._stack = [] 

2115 

2116 # TODO(b/189233748): Consolidate the device string parsing logic with 

2117 # tensorflow/core/util/device_name_utils.cc. 

2118 def __enter__(self): 

2119 ctx = self._ctx 

2120 old_device_name = ctx.device_name 

2121 old_device_spec = ctx.device_spec 

2122 new_device_name = self._device_name 

2123 cache_key = (old_device_name, new_device_name) 

2124 try: 

2125 new_device_name, new_device_spec = _device_parsing_cache[cache_key] 

2126 except TypeError: 

2127 # Error while trying to compute the cache key. 

2128 raise ValueError("Expecting a string device name. Got %s(%s)" % 

2129 (type(new_device_name), new_device_name)) 

2130 except KeyError: 

2131 # Handle a cache miss. 

2132 if new_device_name is not None: 

2133 if not isinstance(new_device_name, str): 

2134 raise ValueError("Expecting a string device name. Got %s(%s)" % 

2135 (type(new_device_name), new_device_name)) 

2136 device_spec = pydev.DeviceSpec.from_string(new_device_name) 

2137 if old_device_name: 

2138 new_device_spec = copy.copy(old_device_spec) 

2139 else: 

2140 ctx.ensure_initialized() 

2141 new_device_spec = pydev.DeviceSpec.from_string( 

2142 ctx._context_devices[0]) # pylint: disable=protected-access 

2143 new_device_spec = new_device_spec.make_merged_spec(device_spec) 

2144 else: 

2145 new_device_spec = pydev.DeviceSpec.from_string("") 

2146 new_device_name = new_device_spec.to_string() 

2147 _device_parsing_cache[cache_key] = (new_device_name, new_device_spec) 

2148 

2149 ctx._set_device(new_device_name, new_device_spec) # pylint: disable=protected-access 

2150 self._stack.append((old_device_name, old_device_spec, new_device_spec)) 

2151 

2152 def __exit__(self, *ex_info): 

2153 ctx = self._ctx 

2154 old_device_name, old_device_spec, new_device_spec = self._stack[-1] 

2155 if ctx.device_spec is not new_device_spec: 

2156 raise RuntimeError("Exiting device scope without proper scope nesting") 

2157 del self._stack[-1] 

2158 ctx._set_device(old_device_name, old_device_spec) # pylint: disable=protected-access 

2159 

2160 

2161# Do not change directly. 

2162_context = None 

2163_context_lock = threading.Lock() 

2164 

2165 

2166def _set_context_locked(ctx): 

2167 global _context 

2168 pywrap_tfe.TFE_Py_SetEagerContext(ctx) 

2169 ctx.mark_as_global_context() 

2170 _context = ctx 

2171 

2172 

2173def _set_context(ctx): 

2174 with _context_lock: 

2175 _set_context_locked(ctx) 

2176 

2177 

2178def _create_context(): 

2179 with _context_lock: 

2180 if _context is None: 

2181 ctx = Context() 

2182 _set_context_locked(ctx) 

2183 

2184 

2185def _reset_context(): 

2186 """Clears and re-initializes the singleton context. 

2187 

2188 Should only be used for testing. 

2189 """ 

2190 global _context 

2191 global _device_parsing_cache 

2192 

2193 # Garbage collect and clear scalar cache to avoid Tensor from current context 

2194 # polluting next context. 

2195 gc.collect() 

2196 pywrap_tfe.TFE_ClearScalarCache() 

2197 with _context_lock: 

2198 if _context is not None: 

2199 _context._clear_caches() 

2200 _context = None 

2201 _create_context() 

2202 _device_parsing_cache = {} 

2203 

2204 

2205def _reset_jit_compiler_flags(): 

2206 """Clears and re-initializes the TF JIT compiler flags. 

2207 

2208 Should only be used for testing. 

2209 """ 

2210 pywrap_tfe.TF_ResetJitCompilerFlags() 

2211 

2212 

2213def context(): 

2214 """Returns a singleton context object.""" 

2215 if _context is None: 

2216 _create_context() 

2217 return _context 

2218 

2219 

2220def context_safe(): 

2221 """Returns current context (or None if one hasn't been initialized).""" 

2222 return _context 

2223 

2224 

2225def ensure_initialized(): 

2226 """Initialize the context.""" 

2227 context().ensure_initialized() 

2228 

2229 

2230def initialize_logical_devices(): 

2231 """Initialize the virtual devices.""" 

2232 context()._initialize_logical_devices() # pylint: disable=protected-access 

2233 

2234 

2235def set_global_seed(seed): 

2236 """Sets the eager mode seed.""" 

2237 context()._set_global_seed(seed) # pylint: disable=protected-access 

2238 

2239 

2240def global_seed(): 

2241 """Returns the eager mode seed.""" 

2242 return context()._seed # pylint: disable=protected-access 

2243 

2244 

2245def internal_operation_seed(): 

2246 """Returns the operation seed generated based on global seed.""" 

2247 return context()._internal_operation_seed() # pylint: disable=protected-access 

2248 

2249 

2250@tf_export("executing_eagerly", v1=[]) 

2251def executing_eagerly(): 

2252 """Checks whether the current thread has eager execution enabled. 

2253 

2254 Eager execution is enabled by default and this API returns `True` 

2255 in most cases. However, this API might return `False` in the following use 

2256 cases. 

2257 

2258 * Executing inside `tf.function`, unless under `tf.init_scope` or 

2259 `tf.config.run_functions_eagerly(True)` has previously been called. 

2260 * Executing inside a transformation function for `tf.dataset`. 

2261 * `tf.compat.v1.disable_eager_execution()` is called. 

2262 

2263 General case: 

2264 

2265 >>> print(tf.executing_eagerly()) 

2266 True 

2267 

2268 Inside `tf.function`: 

2269 

2270 >>> @tf.function 

2271 ... def fn(): 

2272 ... with tf.init_scope(): 

2273 ... print(tf.executing_eagerly()) 

2274 ... print(tf.executing_eagerly()) 

2275 >>> fn() 

2276 True 

2277 False 

2278 

2279 Inside `tf.function` after `tf.config.run_functions_eagerly(True)` is called: 

2280 

2281 >>> tf.config.run_functions_eagerly(True) 

2282 >>> @tf.function 

2283 ... def fn(): 

2284 ... with tf.init_scope(): 

2285 ... print(tf.executing_eagerly()) 

2286 ... print(tf.executing_eagerly()) 

2287 >>> fn() 

2288 True 

2289 True 

2290 >>> tf.config.run_functions_eagerly(False) 

2291 

2292 Inside a transformation function for `tf.dataset`: 

2293 

2294 >>> def data_fn(x): 

2295 ... print(tf.executing_eagerly()) 

2296 ... return x 

2297 >>> dataset = tf.data.Dataset.range(100) 

2298 >>> dataset = dataset.map(data_fn) 

2299 False 

2300 

2301 Returns: 

2302 `True` if the current thread has eager execution enabled. 

2303 """ 

2304 ctx = context_safe() 

2305 if ctx is None: 

2306 return default_execution_mode == EAGER_MODE 

2307 

2308 return ctx.executing_eagerly() 

2309 

2310 

2311@tf_export(v1=["executing_eagerly"]) 

2312def executing_eagerly_v1(): 

2313 """Checks whether the current thread has eager execution enabled. 

2314 

2315 Eager execution is typically enabled via 

2316 `tf.compat.v1.enable_eager_execution`, but may also be enabled within the 

2317 context of a Python function via tf.contrib.eager.py_func. 

2318 

2319 When eager execution is enabled, returns `True` in most cases. However, 

2320 this API might return `False` in the following use cases. 

2321 

2322 * Executing inside `tf.function`, unless under `tf.init_scope` or 

2323 `tf.config.run_functions_eagerly(True)` has previously been called. 

2324 * Executing inside a transformation function for `tf.dataset`. 

2325 * `tf.compat.v1.disable_eager_execution()` is called. 

2326 

2327 >>> tf.compat.v1.enable_eager_execution() 

2328 

2329 General case: 

2330 

2331 >>> print(tf.executing_eagerly()) 

2332 True 

2333 

2334 Inside `tf.function`: 

2335 

2336 >>> @tf.function 

2337 ... def fn(): 

2338 ... with tf.init_scope(): 

2339 ... print(tf.executing_eagerly()) 

2340 ... print(tf.executing_eagerly()) 

2341 >>> fn() 

2342 True 

2343 False 

2344 

2345 Inside `tf.function` 

2346 after `tf.config.run_functions_eagerly(True)` is called: 

2347 

2348 >>> tf.config.run_functions_eagerly(True) 

2349 >>> @tf.function 

2350 ... def fn(): 

2351 ... with tf.init_scope(): 

2352 ... print(tf.executing_eagerly()) 

2353 ... print(tf.executing_eagerly()) 

2354 >>> fn() 

2355 True 

2356 True 

2357 >>> tf.config.run_functions_eagerly(False) 

2358 

2359 Inside a transformation function for `tf.dataset`: 

2360 

2361 >>> def data_fn(x): 

2362 ... print(tf.executing_eagerly()) 

2363 ... return x 

2364 >>> dataset = tf.data.Dataset.range(100) 

2365 >>> dataset = dataset.map(data_fn) 

2366 False 

2367 

2368 Returns: 

2369 `True` if the current thread has eager execution enabled. 

2370 """ 

2371 return executing_eagerly() 

2372 

2373 

2374def in_eager_mode(): 

2375 """Use executing_eagerly() instead. This function will be removed.""" 

2376 return executing_eagerly() 

2377 

2378 

2379def anonymous_name(): 

2380 """Returns the anonymous shared name. 

2381 

2382 In eager mode we create anonymous resources to avoid spurious sharing issues. 

2383 The runtime generates a unique name on our behalf when the reserved 

2384 anonymous shared name is used as a shared name. 

2385 

2386 Returns: 

2387 The anonymous shared name. 

2388 """ 

2389 

2390 # The magic value is defined as 

2391 # `tensorflow::ResourceHandle::ANONYMOUS_NAME` in C++. 

2392 return "cd2c89b7-88b7-44c8-ad83-06c2a9158347" 

2393 

2394 

2395def graph_mode(): 

2396 """Context-manager to disable eager execution for the current thread.""" 

2397 return context()._mode(GRAPH_MODE) # pylint: disable=protected-access 

2398 

2399 

2400# Used by b/167638505 for keras backend API and Lambda layer. 

2401@tf_export("__internal__.eager_context.eager_mode", v1=[]) 

2402def eager_mode(): 

2403 """Context-manager to enable eager execution for the current thread.""" 

2404 return context()._mode(EAGER_MODE) # pylint: disable=protected-access 

2405 

2406 

2407def scope_name(): 

2408 """Name of the current scope.""" 

2409 return context().scope_name 

2410 

2411 

2412def device(name): 

2413 """Context-manager to force placement of operations and Tensors on a device. 

2414 

2415 Example: 

2416 ```python 

2417 with tf.device('gpu:0'): 

2418 with tf.device('cpu:0'): 

2419 shape = tf.constant([], dtype=tf.int32) 

2420 x = tf.random.truncated_normal(shape, tf.float32) 

2421 ``` 

2422 will ensure that the `shape` Tensor is on CPU but the `truncated_normal` 

2423 operation runs on GPU 0. 

2424 

2425 Args: 

2426 name: Name of the device (see context().devices()), or None to perform 

2427 automatic placement. 

2428 

2429 Returns: 

2430 Context manager for setting the device. 

2431 """ 

2432 ensure_initialized() 

2433 return context().device(name) 

2434 

2435 

2436# Expose some properties of Context as internally public APIs (b/160348781). 

2437@tf_export("__internal__.eager_context.get_config", v1=[]) 

2438def get_config(): 

2439 """Get the ConfigProto of Context. 

2440 

2441 Returns: 

2442 The ConfigProto of Context. 

2443 """ 

2444 return context().config 

2445 

2446 

2447@tf_export("__internal__.eager_context.get_device_name", v1=[]) 

2448def get_device_name(): 

2449 """Get the device name for the current thread. 

2450 

2451 Returns: 

2452 The device name for the current thread. 

2453 """ 

2454 return context().device_name 

2455 

2456 

2457@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[]) 

2458def set_soft_device_placement(enabled): 

2459 """Set if soft device placements should be allowed. 

2460 

2461 Args: 

2462 enabled: Whether to enable soft device placement. 

2463 """ 

2464 context().soft_device_placement = enabled 

2465 

2466 

2467@tf_export("__internal__.eager_context.get_executor", v1=[]) 

2468def get_executor(): 

2469 """Get the Executor of the current thread. 

2470 

2471 Returns: 

2472 The Executor of the current thread. 

2473 """ 

2474 return context().executor 

2475 

2476 

2477@tf_export("debugging.get_log_device_placement") 

2478def get_log_device_placement(): 

2479 """Get if device placements are logged. 

2480 

2481 Returns: 

2482 If device placements are logged. 

2483 """ 

2484 return context().log_device_placement 

2485 

2486 

2487@tf_export("debugging.set_log_device_placement") 

2488def set_log_device_placement(enabled): 

2489 """Turns logging for device placement decisions on or off. 

2490 

2491 Operations execute on a particular device, producing and consuming tensors on 

2492 that device. This may change the performance of the operation or require 

2493 TensorFlow to copy data to or from an accelerator, so knowing where operations 

2494 execute is useful for debugging performance issues. 

2495 

2496 For more advanced profiling, use the [TensorFlow 

2497 profiler](https://www.tensorflow.org/guide/profiler). 

2498 

2499 Device placement for operations is typically controlled by a `tf.device` 

2500 scope, but there are exceptions, for example operations on a `tf.Variable` 

2501 which follow the initial placement of the variable. Turning off soft device 

2502 placement (with `tf.config.set_soft_device_placement`) provides more explicit 

2503 control. 

2504 

2505 >>> tf.debugging.set_log_device_placement(True) 

2506 >>> tf.ones([]) 

2507 >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:GPU:0 

2508 >>> with tf.device("CPU"): 

2509 ... tf.ones([]) 

2510 >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:CPU:0 

2511 >>> tf.debugging.set_log_device_placement(False) 

2512 

2513 Turning on `tf.debugging.set_log_device_placement` also logs the placement of 

2514 ops inside `tf.function` when the function is called. 

2515 

2516 Args: 

2517 enabled: Whether to enable device placement logging. 

2518 """ 

2519 context().log_device_placement = enabled 

2520 

2521 

2522@tf_contextlib.contextmanager 

2523def device_policy(policy): 

2524 """Context manager for setting device placement policy for current thread.""" 

2525 ctx = context() 

2526 old_policy = ctx.device_policy 

2527 try: 

2528 ctx.device_policy = policy 

2529 yield 

2530 finally: 

2531 ctx.device_policy = old_policy 

2532 
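# Hedged usage sketch for the `device_policy` context manager, using the
# placement constants defined at the top of this module:
#
#   with device_policy(DEVICE_PLACEMENT_SILENT):
#     ...  # ops here may be silently copied across devices as needed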

2533 

2534def set_execution_mode(mode): 

2535 """Sets execution mode for the current thread.""" 

2536 context().execution_mode = mode 

2537 

2538 

2539# TODO(fishx): remove this method. 

2540@tf_contextlib.contextmanager 

2541def execution_mode(mode): 

2542 """Context manager for setting execution mode for current thread.""" 

2543 if mode is None: 

2544 yield 

2545 else: 

2546 ctx = context() 

2547 executor_new = executor.new_executor(mode == ASYNC) 

2548 executor_old = ctx.executor 

2549 try: 

2550 executor_old.wait() 

2551 ctx.executor = executor_new 

2552 yield 

2553 finally: 

2554 ctx.executor = executor_old 

2555 executor_new.wait() 

2556 

2557 

2558@tf_contextlib.contextmanager 

2559def executor_scope(e): 

2560 """Context manager for changing executor for current thread. 

2561 

2562 Args: 

2563 e: An Executor to execute eager ops under this scope. Setting it to None will 

2564 switch back to use the default executor for the context. 

2565 

2566 Yields: 

2567 Context manager for setting the executor for current thread. 

2568 """ 

2569 ctx = context() 

2570 executor_old = ctx.executor 

2571 try: 

2572 ctx.executor = e 

2573 yield 

2574 finally: 

2575 ctx.executor = executor_old 

2576 
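# Hedged sketch for `executor_scope`; `executor.new_executor` is the factory
# used elsewhere in this module, and enabling async mode here is an assumption:
#
#   async_executor = executor.new_executor(True)  # enable_async=True
#   with executor_scope(async_executor):
#     ...  # eager ops on this thread run through async_executor
#   async_executor.wait()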

2577 

2578@tf_export("experimental.function_executor_type") 

2579@tf_contextlib.contextmanager 

2580def function_executor_type(executor_type): 

2581 """Context manager for setting the executor of eager defined functions. 

2582 

2583 Eager defined functions are functions decorated by tf.contrib.eager.defun. 

2584 

2585 Args: 

2586 executor_type: a string for the name of the executor to be used to execute 

2587 functions defined by tf.contrib.eager.defun. 

2588 

2589 Yields: 

2590 Context manager for setting the executor of eager defined functions. 

2591 """ 

2592 current_options = context().function_call_options 

2593 old_options = copy.copy(current_options) 

2594 try: 

2595 current_options.executor_type = executor_type 

2596 yield 

2597 finally: 

2598 context().function_call_options = old_options 

2599 
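# Hedged sketch for `function_executor_type`; "SINGLE_THREADED_EXECUTOR" is an
# executor-type string commonly used in TensorFlow tests and is an assumption
# here, as is the traced function `f`:
#
#   import tensorflow as tf
#
#   @tf.function
#   def f(x):
#     return x + 1
#
#   with function_executor_type("SINGLE_THREADED_EXECUTOR"):
#     f(tf.constant(1))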

2600 

2601def is_async(): 

2602 """Returns true if current thread is in async mode.""" 

2603 return context().is_async() 

2604 

2605 

2606def num_gpus(): 

2607 """Get the number of available GPU devices. 

2608 

2609 Returns: 

2610 The number of available GPU devices. 

2611 """ 

2612 return context().num_gpus() 

2613 

2614 

2615def enable_run_metadata(): 

2616 """Enables tracing of op execution via RunMetadata. 

2617 

2618 To retrieve the accumulated metadata call context.export_run_metadata() 

2619 and to stop tracing call context.disable_run_metadata(). 

2620 """ 

2621 context().enable_run_metadata() 

2622 

2623 

2624def disable_run_metadata(): 

2625 """Disables tracing of op execution via RunMetadata.""" 

2626 context().disable_run_metadata() 

2627 

2628 

2629def enable_graph_collection(): 

2630 """Enables graph collection of executed functions. 

2631 

2632 To retrieve the accumulated graphs call context.export_run_metadata() 

2633 and to stop collecting graphs call context.disable_graph_collection(). 

2634 """ 

2635 context().enable_graph_collection() 

2636 

2637 

2638def disable_graph_collection(): 

2639 """Disables graph collection of executed functions.""" 

2640 context().disable_graph_collection() 

2641 

2642 

2643def export_run_metadata(): 

2644 """Returns a RunMetadata proto with accumulated information. 

2645 

2646 The returned protocol buffer contains information since the most recent call 

2647 to either enable_run_metadata or export_run_metadata. 

2648 

2649 Returns: 

2650 A RunMetadata protocol buffer. 

2651 """ 

2652 return context().export_run_metadata() 

2653 

2654 

2655@contextlib.contextmanager 

2656def collect_graphs(optimized=True): 

2657 """Collects a flat list of pre- or post-optimization graphs. 

2658 

2659 The collected graphs include device placements, which can be useful for 

2660 testing. 

2661 

2662 Usage: 

2663 

2664 ``` 

2665 @def_function.function 

2666 def f(x): 

2667 return x + constant_op.constant(1.) 

2668 

2669 with context.collect_graphs() as graphs: 

2670 with ops.device("CPU:0"): 

2671 f(constant_op.constant(1.)) 

2672 

2673 graph, = graphs # `graph` contains a single GraphDef for inspection 

2674 ``` 

2675 

2676 Args: 

2677 optimized: whether to collect optimized graphs or non-optimized graphs 

2678 

2679 Yields: 

2680 A list of GraphDefs, populated when the context manager exits. 

2681 """ 

2682 ctx = context() 

2683 ctx.enable_graph_collection() 

2684 try: 

2685 graphs = [] 

2686 yield graphs 

2687 metadata = ctx.export_run_metadata() 

2688 finally: 

2689 ctx.disable_graph_collection() 

2690 for graph in metadata.function_graphs: 

2691 if optimized: 

2692 graphs.append(graph.post_optimization_graph) 

2693 else: 

2694 graphs.append(graph.pre_optimization_graph) 

2695 

2696 

2697def get_server_def(): 

2698 return context().get_server_def() 

2699 

2700 

2701def set_server_def(server_def): 

2702 context().set_server_def(server_def) 

2703 

2704 

2705def update_server_def(server_def): 

2706 context().update_server_def(server_def) 

2707 

2708 

2709def check_alive(worker_name): 

2710 return context().check_alive(worker_name) 

2711 

2712 

2713@tf_export("experimental.async_scope") 

2714@tf_contextlib.contextmanager 

2715def async_scope(): 

2716 """Context manager for grouping async operations. 

2717 

2718 Ops/function calls inside the scope can return before finishing the actual 

2719 execution. When exiting the async scope, a synchronization barrier will be 

2720 automatically added to ensure the completion of all async op and function 

2721 execution, potentially raising exceptions if async execution results in 

2722 an error state. 

2723 

2724 Users may write the following code to asynchronously invoke `train_step_fn` 

2725 and log the `loss` metric for every `num_steps` steps in a training loop. 

2726 `train_step_fn` internally consumes data using `iterator.get_next()`, and may 

2727 throw OutOfRangeError when running out of data. In that case: 

2728 

2729 ``` 

2730 try: 

2731 with tf.experimental.async_scope(): 

2732 for _ in range(num_steps): 

2733 # Step function updates the metric `loss` internally 

2734 train_step_fn() 

2735 except tf.errors.OutOfRangeError: 

2736 tf.experimental.async_clear_error() 

2737 logging.info('loss = %s', loss.numpy()) 

2738 ``` 

2739 

2740 Yields: 

2741 Context manager for grouping async operations. 

2742 """ 

2743 # TODO(haoyuzhang): replace env var once we have a config method to turn on 

2744 # and off async streaming RPC 

2745 remote_async_env_var = "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE" 

2746 old_policy = os.environ.get(remote_async_env_var) 

2747 try: 

2748 os.environ[remote_async_env_var] = str(True) 

2749 yield 

2750 # Note: sync local and remote executors iff the async block does not raise 

2751 # an exception. Triggering sync after an exception may lead to derived 

2752 # runtime errors and unexpected exception types. 

2753 context().sync_executors() 

2754 finally: 

2755 if old_policy is None: 

2756 del os.environ[remote_async_env_var] 

2757 else: 

2758 os.environ[remote_async_env_var] = old_policy 

2759 

2760 

2761def async_wait(): 

2762 """Sync all async operations and raise any errors during execution. 

2763 

2764 In async execution mode, an op/function call can return before finishing the 

2765 actual execution. Calling this method creates a synchronization barrier for 

2766 all async op and function execution. It only returns when all pending nodes 

2767 are finished, potentially raising exceptions if async execution results in 

2768 an error state. It is a no-op if the context is not initialized. 

2769 """ 

2770 disable_async_executor_env_var = "TF_PS_DISABLE_ASYNC_EXECUTOR_GLOBALLY" 

2771 if os.environ.get(disable_async_executor_env_var) == str(True): 

2772 return 

2773 if context()._context_handle is not None: # pylint: disable=protected-access 

2774 context().sync_executors() 

2775 

2776 

2777@tf_export("experimental.async_clear_error") 

2778def async_clear_error(): 

2779 """Clear pending operations and error statuses in async execution. 

2780 

2781 In async execution mode, an error in op/function execution can lead to errors 

2782 in subsequent ops/functions that are scheduled but not yet executed. Calling 

2783 this method clears all pending operations and resets the async execution state. 

2784 

2785 Example: 

2786 

2787 ``` 

2788 while True: 

2789 try: 

2790 # Step function updates the metric `loss` internally 

2791 train_step_fn() 

2792 except tf.errors.OutOfRangeError: 

2793 tf.experimental.async_clear_error() 

2794 break 

2795 logging.info('loss = %s', loss.numpy()) 

2796 ``` 

2797 """ 

2798 context().clear_executor_errors() 

2799 

2800 

2801def add_c_function(c_func): 

2802 """Add a C API TF_Function to the context.""" 

2803 context().add_c_function(c_func) 

2804 

2805 

2806def get_c_function(name): 

2807 """Get a C API TF_Function from the context.""" 

2808 return context().get_c_function(name) 

2809 

2810 

2811def remove_function(name): 

2812 """Remove a function from the context.""" 

2813 context().remove_function(name) 

2814 

2815 

2816def get_function_def(name): 

2817 return context().get_function_def(name) 

2818 

2819 

2820def is_custom_device(device_name): 

2821 """Calls TFE_IsCustomDevice. 

2822 

2823 Enables using C extensions specifying a custom device from Python. See the 

2824 experimental eager C API in tensorflow/c/eager/c_api_experimental.h for 

2825 details. 

2826 

2827 Args: 

2828 device_name: A string indicating the device name to check for 

2829 registration as a custom device. 

2830 

2831 Returns: 

2832 A boolean. 

2833 """ 

2834 return context().is_custom_device(device_name) 

2835 

2836 

2837def register_custom_device(device_capsule, device_name, device_info_capsule): 

2838 """Calls TFE_RegisterCustomDevice to register a custom device with Python. 

2839 

2840 Enables using C extensions specifying a custom device from Python. See the 

2841 experimental eager C API in tensorflow/c/eager/c_api_experimental.h for 

2842 details. 

2843 

2844 Note that custom devices are not currently supported inside `tf.function`s. 

2845 

2846 Args: 

2847 device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice' 

2848 containing a pointer to a TFE_CustomDevice struct. The capsule retains 

2849 ownership of the memory. 

2850 device_name: A string indicating the name to register the custom device 

2851 under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may 

2852 subsequently be passed to `with tf.device(...):`. 

2853 device_info_capsule: A PyCapsule with the name set to 

2854 'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific 

2855 struct with the initial state of the custom device (the void* device_info 

2856 argument to TFE_RegisterCustomDevice). This method takes ownership of the 

2857 memory and clears the capsule destructor. 

2858 """ 

2859 context().register_custom_device(device_capsule, device_name, 

2860 device_info_capsule) 

2861 

2862 

2863# Not every user creates a Context via context.context() 

2864# (for example, enable_eager_execution in python/framework/ops.py), 

2865# but they do all import this file. Note that IS_IN_GRAPH_MODE and 

2866# in_graph_mode are both parameterless functions. 

2867def _tmp_in_graph_mode(): 

2868 if context_safe() is None: 

2869 # Context not yet initialized. Assume graph mode following the 

2870 # default implementation in `is_in_graph_mode`. 

2871 return True 

2872 return not executing_eagerly() 

2873 

2874 

2875is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode