Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/context.py: 2%

1260 statements  

coverage.py v7.3.2, created at 2023-10-05 06:32 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""State management for eager execution.""" 

16 

17import collections 

18import contextlib 

19import copy 

20import gc 

21import itertools 

22import os 

23import random 

24import threading 

25 

26from absl import logging 

27import numpy as np 

28 

29from tensorflow.core.framework import function_pb2 

30from tensorflow.core.framework import graph_debug_info_pb2 

31from tensorflow.core.protobuf import config_pb2 

32from tensorflow.core.protobuf import rewriter_config_pb2 

33from tensorflow.python import pywrap_tfe 

34from tensorflow.python import tf2 

35from tensorflow.python.client import pywrap_tf_session 

36from tensorflow.python.eager import cancellation 

37from tensorflow.python.eager import execute 

38from tensorflow.python.eager import executor 

39from tensorflow.python.eager import monitoring 

40from tensorflow.python.framework import c_api_util 

41from tensorflow.python.framework import device as pydev 

42from tensorflow.python.framework import tfrt_utils 

43from tensorflow.python.util import compat 

44from tensorflow.python.util import function_utils 

45from tensorflow.python.util import is_in_graph_mode 

46from tensorflow.python.util import tf_contextlib 

47from tensorflow.python.util.deprecation import deprecated 

48from tensorflow.python.util.tf_export import tf_export 

49from tensorflow.tsl.protobuf import coordination_config_pb2 

50 

51 

52GRAPH_MODE = 0 

53EAGER_MODE = 1 

54 

55default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE 

56 

57# Cache from (old_device_name, partial_new_device_name) -> (new_device_name, 

58# new_device_spec). 

59# Note that we do not protect this with a lock and instead rely on python's GIL 

60# and the idempotent nature of writes to provide thread safety. 

61_device_parsing_cache = {} 

62_starting_device_spec = pydev.DeviceSpec.from_string("") 

63 

64_MAXINT32 = 2**31 - 1 

65 

66DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT 

67DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN 

68DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT 

69DEVICE_PLACEMENT_SILENT_FOR_INT32 = ( 

70 pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) 

71 

72SYNC = 0 

73ASYNC = 1 

74 

75_KEEP_ALIVE_SECS = 600 

76 

77_python_eager_context_create_counter = monitoring.Counter( 

78 "/tensorflow/api/python/eager_context_create_counter", 

79 "Counter for number of eager contexts created in Python.") 

80 

81# Re-exporting through context. 

82is_tfrt_enabled = tfrt_utils.enabled 

83 

84# This flag and the associated environment var are transient and will eventually 

85# be removed, once this experiment is enabled by default. 

86_JIT_COMPILE_REWRITE_ENABLED = os.getenv("TF_JIT_COMPILE_REWRITE") == "1" 

87 

88 

89def run_eager_op_as_function_enabled(): 

90 return True 

91 

92 

93# This method should only be called after the context has been initialized.

94def enable_jit_compile_rewrite(): 

95 """Run jit_compile functions through rewrite pass. 

96 

97 This runs jit_compile functions through all of the multidevice function 

98 rewrite passes. 

99 """ 

100 global _JIT_COMPILE_REWRITE_ENABLED 

101 _JIT_COMPILE_REWRITE_ENABLED = True 

102 if context_safe() is not None: 

103 context_safe().jit_compile_rewrite = True 

104 

105 

106# This method should only be called after the context has been initialized. 

107def disable_jit_compile_rewrite(): 

108 global _JIT_COMPILE_REWRITE_ENABLED 

109 _JIT_COMPILE_REWRITE_ENABLED = False 

110 if context_safe() is not None: 

111 context_safe().jit_compile_rewrite = False 

112 

113 

114def jit_compile_rewrite_enabled(): 

115 if context_safe() is not None: 

116 return context_safe().jit_compile_rewrite 

117 return _JIT_COMPILE_REWRITE_ENABLED 

118 
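# A minimal sketch (not part of the original file) of how the toggles above
# interact; `_jit_compile_rewrite_example` is a hypothetical helper added only
# for illustration.
def _jit_compile_rewrite_example():
  enable_jit_compile_rewrite()
  assert jit_compile_rewrite_enabled()
  disable_jit_compile_rewrite()
  assert not jit_compile_rewrite_enabled()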

119 

120# Expose it as internally public APIs for Keras use cases in b/171080602. 

121tf_export("__internal__.is_tfrt_enabled", v1=[])(is_tfrt_enabled) 

122 

123 

124class _EagerTensorCache(object): 

125 """Simple cache which evicts items based on length in a FIFO manner.""" 

126 

127 __slots__ = ["_data", "_max_items", "_max_tensor_size"] 

128 

129 def __init__(self, max_items=256, max_tensor_size=10000): 

130 self._data = collections.OrderedDict() 

131 self._max_items = max_items 

132 self._max_tensor_size = max_tensor_size 

133 

134 def put(self, key, value): 

135 if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access 

136 return 

137 

138 self._data[key] = value 

139 

140 if len(self._data) > self._max_items: 

141 self._data.popitem(last=False) 

142 

143 def get(self, key): 

144 return self._data.get(key, None) 

145 

146 def flush(self): 

147 self._data.clear() 

148 
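# A minimal sketch (not part of the original file) showing the FIFO eviction
# behaviour of `_EagerTensorCache`. `_FakeTensor` is a hypothetical stand-in
# for an EagerTensor, since `put()` only relies on `_num_elements()`.
class _FakeTensor(object):

  def __init__(self, num_elements):
    self._num = num_elements

  def _num_elements(self):
    return self._num


def _eager_tensor_cache_example():
  cache = _EagerTensorCache(max_items=2, max_tensor_size=10)
  cache.put("a", _FakeTensor(1))
  cache.put("b", _FakeTensor(2))
  cache.put("c", _FakeTensor(3))   # Evicts "a" (FIFO, max_items=2).
  cache.put("d", _FakeTensor(99))  # Ignored: larger than max_tensor_size.
  assert cache.get("a") is None
  assert cache.get("c") is not None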

149 

150class FunctionCallOptions: 

151 """Options applied at call sites of eager functions. 

152 

153 Eager functions are functions decorated with tf.contrib.eager.defun. 

154 """ 

155 

156 __slots__ = ["_config_proto_serialized", "_executor_type"] 

157 

158 def __init__(self, executor_type=None, config_proto=None): 

159 """Constructor. 

160 

161 Args: 

162 executor_type: (optional) name of the executor to be used to execute the 

163 eager function. If None or an empty string, the default Tensorflow 

164 executor will be used. 

165 config_proto: (optional) a `config_pb2.ConfigProto` proto or a serialized 

166 string of that proto. The config used by Grappler when optimizing the 

167 function graph. Each concrete function is optimized the first time it is 

168 called. Changing config_proto after the first call has no effect. If 

169 config_proto is None, an empty RewriterConfig will be used. 

170 """ 

171 self.config_proto_serialized = config_proto 

172 self.executor_type = executor_type 

173 

174 @property 

175 def executor_type(self): 

176 return self._executor_type 

177 

178 @executor_type.setter 

179 def executor_type(self, executor_type): 

180 self._executor_type = executor_type 

181 

182 @property 

183 def config_proto_serialized(self): 

184 return self._config_proto_serialized 

185 

186 @config_proto_serialized.setter 

187 def config_proto_serialized(self, config): 

188 if isinstance(config, config_pb2.ConfigProto): 

189 self._config_proto_serialized = config.SerializeToString( 

190 deterministic=True) 

191 elif isinstance(config, str): 

192 self._config_proto_serialized = config 

193 elif config is None: 

194 self._config_proto_serialized = ( 

195 config_pb2.ConfigProto().SerializeToString()) 

196 else: 

197 raise ValueError("the rewriter config must be either a " 

198 "config_pb2.ConfigProto, or a serialized string of that " 

199 "proto or None. got: {}".format(type(config))) 

200 

201 def as_attrs(self): 

202 if self.config_proto_serialized is None: 

203 config = function_utils.get_disabled_rewriter_config() 

204 else: 

205 config = self.config_proto_serialized 

206 executor_type = self.executor_type or "" 

207 

208 return {"executor_type": executor_type, "config_proto": config} 

209 
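# A minimal sketch (not part of the original file): building FunctionCallOptions
# from a ConfigProto and reading back the attributes that get attached to eager
# function calls. `_function_call_options_example` is a hypothetical helper.
def _function_call_options_example():
  proto = config_pb2.ConfigProto(allow_soft_placement=True)
  options = FunctionCallOptions(executor_type=None, config_proto=proto)
  attrs = options.as_attrs()
  # attrs["config_proto"] holds the deterministic serialization of `proto`;
  # attrs["executor_type"] falls back to "" when no executor type is set.
  return attrs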

210 

211# Map from context_id (an int) to _TensorCaches. 

212# Dicts are thread safe in CPython. 

213# TODO(iga): Remove this once TensorCaches are moved to C++. 

214_tensor_caches_map = {} 

215 

216 

217class _TensorCaches(threading.local): 

218 """Thread local tensor caches.""" 

219 

220 __slots__ = ["_ones_rank_cache", "_zeros_cache"] 

221 

222 def __init__(self): 

223 super().__init__() 

224 self._ones_rank_cache = None 

225 self._zeros_cache = None 

226 

227 @property 

228 def ones_rank_cache(self): 

229 if not self._ones_rank_cache: 

230 self._ones_rank_cache = _EagerTensorCache() 

231 return self._ones_rank_cache 

232 

233 @property 

234 def zeros_cache(self): 

235 if not self._zeros_cache: 

236 self._zeros_cache = _EagerTensorCache() 

237 return self._zeros_cache 

238 

239 

240ContextSwitch = collections.namedtuple( 

241 "ContextSwitch", 

242 ["is_building_function", "enter_context_fn", "device_stack"]) 

243 

244 

245# `_ContextSwitchStack` is a `threading.local` to match the semantics of 

246# `_DefaultGraphStack`, which is also a `threading.local`. 

247class _ContextSwitchStack(threading.local): 

248 """A thread-local stack of context switches.""" 

249 

250 def __init__(self, eager): 

251 super().__init__() 

252 self.stack = [] 

253 if eager: 

254 # Initialize the stack with a pointer to enter the eager context; this 

255 # ensures that the fact that eager execution was enabled is propagated 

256 # across threads, since (1) `enable_eager_execution` modifies a 

257 # process-level flag (`default_execution_mode`) and (2) `__init__` is 

258 # called each time a threading.local object is used in a separate thread. 

259 self.push( 

260 is_building_function=False, 

261 enter_context_fn=eager_mode, 

262 device_stack=None) 

263 

264 def push(self, is_building_function, enter_context_fn, device_stack): 

265 """Push metadata about a context switch onto the stack. 

266 

267 A context switch can take one of two forms: installing a graph as 

268 the default graph, or entering the eager context. For each context switch, 

269 we record whether or not the entered context is building a function. 

270 

271 Args: 

272 is_building_function: (bool.) Whether the context is building a function. 

273 enter_context_fn: (function.) A callable that executes the context switch. 

274 For example, `graph.as_default` or `eager_mode`. 

275 device_stack: If applicable, the device function stack for this graph. 

276 When breaking out of graphs in init_scope, the innermost nonempty device 

277 stack is used. Eager contexts put `None` here and the value is never 

278 used. 

279 """ 

280 

281 self.stack.append( 

282 ContextSwitch(is_building_function, enter_context_fn, device_stack)) 

283 

284 def pop(self): 

285 """Pop the stack.""" 

286 

287 self.stack.pop() 

288 

289 

290@tf_export("config.LogicalDevice") 

291class LogicalDevice( 

292 collections.namedtuple("LogicalDevice", ["name", "device_type"])): 

293 """Abstraction for a logical device initialized by the runtime. 

294 

295 A `tf.config.LogicalDevice` corresponds to an initialized logical device on a 

296 `tf.config.PhysicalDevice` or a remote device visible to the cluster. Tensors 

297 and operations can be placed on a specific logical device by calling 

298 `tf.device` with a specified `tf.config.LogicalDevice`. 

299 

300 Fields: 

301 name: The fully qualified name of the device. Can be used for Op or function 

302 placement. 

303 device_type: String declaring the type of device such as "CPU" or "GPU". 

304 """ 

305 pass 

306 

307 

308@tf_export("config.LogicalDeviceConfiguration", 

309 "config.experimental.VirtualDeviceConfiguration") 

310class LogicalDeviceConfiguration( 

311 collections.namedtuple("LogicalDeviceConfiguration", [ 

312 "memory_limit", "experimental_priority", "experimental_device_ordinal" 

313 ])): 

314 """Configuration class for a logical devices. 

315 

316 The class specifies the parameters to configure a `tf.config.PhysicalDevice` 

317 as it is initialized to a `tf.config.LogicalDevice` during runtime 

318 initialization. Not all fields are valid for all device types. 

319 

320 See `tf.config.get_logical_device_configuration` and 

321 `tf.config.set_logical_device_configuration` for usage examples. 

322 

323 Fields: 

324 memory_limit: (optional) Maximum memory (in MB) to allocate on the virtual 

325 device. Currently only supported for GPUs. 

326 experimental_priority: (optional) Priority to assign to a virtual device. 

327 Lower values have higher priorities and 0 is the default. 

328 Within a physical GPU, the GPU scheduler will prioritize ops on virtual 

329 devices with higher priority. Currently only supported for Nvidia GPUs. 

330 experimental_device_ordinal: (optional) Ordinal number to order the virtual 

331 device. 

332 A LogicalDevice with a lower ordinal number will receive a lower device id. 

333 The physical device id and its position in the list are used to break ties. 

334 Currently only supported for Nvidia GPUs. 

335 """ 

336 

337 def __new__(cls, 

338 memory_limit=None, 

339 experimental_priority=None, 

340 experimental_device_ordinal=None): 

341 return super().__new__(cls, memory_limit, experimental_priority, 

342 experimental_device_ordinal) 

343 
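# A minimal user-code sketch (not part of the original file) of the public API
# that consumes LogicalDeviceConfiguration: splitting one physical GPU into two
# logical devices. Assumes TensorFlow is importable as `tf` and at least one
# GPU is visible; it must run before the runtime initializes the devices.
def _logical_device_configuration_example():
  import tensorflow as tf  # Illustration only; not used by this module.
  gpus = tf.config.list_physical_devices("GPU")
  if gpus:
    tf.config.set_logical_device_configuration(
        gpus[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
         tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
    return tf.config.list_logical_devices("GPU")
  return []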

344 

345@tf_export("config.PhysicalDevice") 

346class PhysicalDevice( 

347 collections.namedtuple("PhysicalDevice", ["name", "device_type"])): 

348 """Abstraction for a locally visible physical device. 

349 

350 TensorFlow can utilize various devices such as the CPU or multiple GPUs 

351 for computation. Before initializing a local device for use, the user can 

352 customize certain properties of the device such as its visibility or memory 

353 configuration. 

354 

355 Once a visible `tf.config.PhysicalDevice` is initialized, one or more 

356 `tf.config.LogicalDevice` objects are created. Use 

357 `tf.config.set_visible_devices` to configure the visibility of a physical 

358 device and `tf.config.set_logical_device_configuration` to configure multiple 

359 `tf.config.LogicalDevice` objects for a `tf.config.PhysicalDevice`. This is 

360 useful when separation between models is needed or to simulate a multi-device 

361 environment. 

362 

363 Fields: 

364 name: Unique identifier for device. 

365 device_type: String declaring the type of device such as "CPU" or "GPU". 

366 """ 

367 pass 

368 
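# A minimal user-code sketch (not part of the original file): hiding all but the
# first GPU with tf.config.set_visible_devices, which controls which
# PhysicalDevices are turned into LogicalDevices. Assumes TensorFlow is
# importable as `tf`; it must run before devices are initialized.
def _visible_devices_example():
  import tensorflow as tf  # Illustration only; not used by this module.
  gpus = tf.config.list_physical_devices("GPU")
  if gpus:
    tf.config.set_visible_devices(gpus[:1], "GPU")
  return tf.config.get_visible_devices("GPU")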

369 

370class _AtomicCounter(object): 

371 """A simple atomic counter.""" 

372 

373 __slots__ = ["_value", "_lock"] 

374 

375 def __init__(self): 

376 self._value = 0 

377 self._lock = threading.Lock() 

378 

379 def increment_and_get(self): 

380 with self._lock: 

381 self._value += 1 

382 return self._value 

383 

384 

385_context_id_counter = _AtomicCounter() 

386 

387 

388class _TensorCacheDeleter(object): 

389 """Deletes tensor caches for a given context.""" 

390 

391 __slots__ = ["_context_id"] 

392 

393 def __init__(self, context_id): 

394 self._context_id = context_id 

395 

396 def __del__(self): 

397 if _tensor_caches_map is None: 

398 return 

399 if self._context_id in _tensor_caches_map: 

400 del _tensor_caches_map[self._context_id] 

401 

402 

403# TODO(agarwal): rename to EagerContext / EagerRuntime ? 

404# TODO(agarwal): consider keeping the corresponding Graph here. 

405class Context: 

406 """Environment in which eager operations execute.""" 

407 

408 # TODO(agarwal): create and link in some documentation for `execution_mode`. 

409 # pylint: disable=redefined-outer-name 

410 def __init__(self, 

411 config=None, 

412 device_policy=None, 

413 execution_mode=None, 

414 server_def=None): 

415 """Creates a new Context. 

416 

417 Args: 

418 config: (Optional.) A `ConfigProto` protocol buffer with configuration 

419 options for the Context. Note that a lot of these options may be 

420 currently unimplemented or irrelevant when eager execution is enabled. 

421 device_policy: (Optional.) What policy to use when trying to run an 

422 operation on a device with inputs which are not on that device. When set 

423 to None, an appropriate value will be picked automatically. The value 

424 picked may change between TensorFlow releases. Defaults to 

425 DEVICE_PLACEMENT_SILENT. 

426 Valid values: 

427 - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not 

428 correct. 

429 - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the right 

430 device but raises a warning. 

431 - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide 

432 performance problems. 

433 - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors, 

434 raising errors on the other ones. 

435 execution_mode: (Optional.) Policy controlling how operations dispatched 

436 are actually executed. When set to None, an appropriate value will be 

437 picked automatically. The value picked may change between TensorFlow 

438 releases. 

439 Valid values: 

440 - SYNC: executes each operation synchronously. 

441 - ASYNC: executes each operation asynchronously. These operations may 

442 return "non-ready" handles. 

443 server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution 

444 on remote devices. GrpcServers need to be started by creating an 

445 identical server_def to this, and setting the appropriate task_indexes, 

446 so that the servers can communicate. It will then be possible to execute 

447 operations on remote devices. 

448 

449 Raises: 

450 ValueError: If execution_mode is not valid. 

451 """ 

452 # This _id is used only to index the tensor caches. 

453 # TODO(iga): Remove this when tensor caches are moved to C++. 

454 self._id = _context_id_counter.increment_and_get() 

455 self._tensor_cache_deleter = _TensorCacheDeleter(self._id) 

456 _tensor_caches_map[self._id] = _TensorCaches() 

457 

458 self._config = config 

459 self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData( 

460 self, 

461 is_eager=lambda: default_execution_mode == EAGER_MODE, 

462 device_spec=_starting_device_spec) 

463 self._context_switches = _ContextSwitchStack(self.executing_eagerly()) 

464 self._context_handle = None 

465 self._context_devices = None 

466 self._seed = None 

467 self._initialize_lock = threading.Lock() 

468 self._initialized = False 

469 if device_policy is None: 

470 device_policy = DEVICE_PLACEMENT_SILENT 

471 self._device_policy = device_policy 

472 self._mirroring_policy = None 

473 if execution_mode not in (None, SYNC, ASYNC): 

474 raise ValueError("execution_mode should be None/SYNC/ASYNC. Got %s" % 

475 execution_mode) 

476 if execution_mode is None: 

477 execution_mode = SYNC 

478 self._default_is_async = execution_mode == ASYNC 

479 self._use_tfrt = is_tfrt_enabled() 

480 self._jit_compile_rewrite = jit_compile_rewrite_enabled() 

481 self._server_def = server_def 

482 self._collective_ops_server_def = None 

483 self._collective_leader = None 

484 self._collective_scoped_allocator_enabled_ops = None 

485 self._collective_use_nccl_communication = None 

486 self._collective_device_filters = None 

487 self._coordination_service_config = None 

488 

489 self._device_lock = threading.Lock() 

490 self._physical_devices = None 

491 self._physical_device_to_index = None 

492 self._pluggable_devices = None 

493 self._visible_device_list = [] 

494 self._memory_growth_map = None 

495 self._virtual_device_map = {} 

496 

497 # Values set after construction 

498 self._optimizer_jit = None 

499 self._intra_op_parallelism_threads = None 

500 self._inter_op_parallelism_threads = None 

501 self._soft_device_placement = None 

502 self._log_device_placement = None 

503 self._operation_timeout_in_ms = None 

504 self._enable_mlir_graph_optimization = None 

505 self._optimizer_experimental_options = {} 

506 

507 _python_eager_context_create_counter.get_cell().increase_by(1) 

508 

509 self._is_global_context = False 

510 

511 # pylint: enable=redefined-outer-name 

512 

513 def _set_global_seed(self, seed): 

514 """Set a global eager mode seed for random ops.""" 

515 self._seed = seed 

516 # `random.Random(seed)` needs `seed` to be hashable, while values of type 

517 # e.g. `np.int64` or `np.ndarray` are not. We use `int(...)` to convert them 

518 # to int. 

519 try: 

520 hash(seed) 

521 self._rng = random.Random(seed) 

522 except TypeError: 

523 seed = int(np.array(seed)) 

524 self._rng = random.Random(seed) 

525 # Also clear the kernel cache, to reset any existing seeds 

526 if self._context_handle is not None: 

527 pywrap_tfe.TFE_ContextClearCaches(self._context_handle) 

528 

529 def _internal_operation_seed(self): 

530 """Returns a fake operation seed. 

531 

532 In eager mode, users shouldn't set or depend on the operation seed. 

533 Here, we generate a random seed based on the global seed so that each 

534 operation's randomness differs while still depending on the global seed. 

535 

536 Returns: 

537 A fake operation seed based on global seed. 

538 """ 

539 return self._rng.randint(0, _MAXINT32) 

540 

541 def _initialize_logical_devices(self): 

542 """Helper to initialize devices.""" 

543 # Store list of devices 

544 logical_devices = [] 

545 context_devices = [] 

546 device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle) 

547 try: 

548 self._num_gpus = 0 

549 current_job, current_task = None, None 

550 server_def = self._server_def or self._collective_ops_server_def 

551 if server_def is not None: 

552 current_job, current_task = server_def.job_name, server_def.task_index 

553 for i in range(pywrap_tfe.TF_DeviceListCount(device_list)): 

554 dev_name = pywrap_tfe.TF_DeviceListName(device_list, i) 

555 context_devices.append(pydev.canonical_name(dev_name)) 

556 spec = pydev.DeviceSpec.from_string(dev_name) 

557 # If the job is localhost, we assume that the cluster has not yet been 

558 # configured and thus clear the job, replica & task. 

559 if spec.job == "localhost": 

560 spec = spec.replace(job=None, replica=None, task=None) 

561 logical_devices.append( 

562 LogicalDevice(name=spec.to_string(), device_type=spec.device_type)) 

563 dev_type = pywrap_tfe.TF_DeviceListType(device_list, i) 

564 if (dev_type == "GPU" and spec.job == current_job and 

565 spec.task == current_task): 

566 self._num_gpus += 1 

567 

568 finally: 

569 self._logical_devices = logical_devices 

570 self._context_devices = context_devices 

571 pywrap_tfe.TF_DeleteDeviceList(device_list) 

572 

573 def ensure_initialized(self): 

574 """Initialize handle and devices if not already done so.""" 

575 if self._initialized: 

576 return 

577 with self._initialize_lock: 

578 if self._initialized: 

579 return 

580 assert self._context_devices is None 

581 opts = pywrap_tfe.TFE_NewContextOptions() 

582 try: 

583 config_str = self.config.SerializeToString() 

584 pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str) 

585 if self._device_policy is not None: 

586 pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy( 

587 opts, self._device_policy) 

588 if self._mirroring_policy is not None: 

589 pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy( 

590 opts, self._mirroring_policy) 

591 if self._default_is_async == ASYNC: 

592 pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True) 

593 if self._use_tfrt is not None: 

594 pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt) 

595 pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction(opts, True) 

596 pywrap_tfe.TFE_ContextOptionsSetJitCompileRewrite( 

597 opts, self._jit_compile_rewrite) 

598 context_handle = pywrap_tfe.TFE_NewContext(opts) 

599 finally: 

600 pywrap_tfe.TFE_DeleteContextOptions(opts) 

601 assert not (self._server_def and self._collective_ops_server_def), ( 

602 "Cannot enable remote execution as well as collective ops at the " 

603 "moment. If this is important to you, please file an issue.") 

604 if self._server_def is not None: 

605 server_def_str = self._server_def.SerializeToString() 

606 pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS, 

607 server_def_str) 

608 elif self._collective_ops_server_def is not None: 

609 server_def_str = self._collective_ops_server_def.SerializeToString() 

610 pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str) 

611 

612 self._context_handle = context_handle 

613 self._initialize_logical_devices() 

614 self._initialized = True 

615 

616 if self._is_global_context: 

617 pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle) 

618 

619 def ensure_uninitialized(self): 

620 """Uninitialize handle and devices if not already done so.""" 

621 with self._initialize_lock: 

622 if not self._initialized: 

623 return 

624 self._context_devices = None 

625 self._logical_devices = None 

626 self._server_def = None 

627 self._initialized = False 

628 

629 if self._is_global_context: 

630 pywrap_tfe.TFE_Py_SetCEagerContext(None) 

631 

632 self._context_handle = None 

633 

634 def mark_as_global_context(self): 

635 # If the context was already initialized, publish it. Otherwise wait with 

636 # publication until it's initialized. 

637 if self._initialized: 

638 pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle) 

639 self._is_global_context = True 

640 

641 def _clear_caches(self): 

642 self.ones_rank_cache().flush() 

643 self.zeros_cache().flush() 

644 pywrap_tfe.TFE_ClearScalarCache() 

645 

646 def get_server_def(self): 

647 return self._server_def 

648 

649 def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): 

650 """Allow setting a server_def on the context. 

651 

652 When a server def is replaced, it effectively clears a bunch of caches 

653 within the context. If you attempt to use a tensor object that was pointing 

654 to a tensor on the remote device, it will raise an error. 

655 

656 Args: 

657 server_def: A tensorflow::ServerDef proto. Enables execution on remote 

658 devices. 

659 keep_alive_secs: Num. seconds after which the remote end will hang up. As 

660 long as the client is still alive, the server state for the context will 

661 be kept alive. If the client is killed (or there is some failure), the 

662 server will clean up its context keep_alive_secs after the final RPC it 

663 receives. 

664 

665 Raises: 

666 ValueError: if server_def is None. 

667 """ 

668 if not server_def: 

669 raise ValueError("server_def is None.") 

670 

671 self._server_def = server_def 

672 

673 if self._context_handle: 

674 server_def_str = server_def.SerializeToString() 

675 pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs, 

676 server_def_str) 

677 self._initialize_logical_devices() 

678 

679 # Clear all the caches in case there are remote tensors in them. 

680 self._clear_caches() 

681 
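  # A usage note (not part of the original file): user code does not normally
  # call set_server_def/update_server_def directly; the usual entry point is
  # something like
  #
  #   resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
  #   tf.config.experimental_connect_to_cluster(resolver)
  #
  # which builds a ServerDef and forwards it to these methods.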

682 def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS): 

683 """Update a server_def on the context. 

684 

685 Args: 

686 server_def: A tensorflow::ServerDef proto. Enables execution on remote 

687 devices. 

688 keep_alive_secs: Num. seconds after which the remote end will hang up. As 

689 long as the client is still alive, the server state for the context will 

690 be kept alive. If the client is killed (or there is some failure), the 

691 server will clean up its context keep_alive_secs after the final RPC it 

692 receives. 

693 

694 Raises: 

695 ValueError: if server_def is None. 

696 """ 

697 if not server_def: 

698 raise ValueError("server_def is None.") 

699 

700 self._server_def = server_def 

701 

702 if self._context_handle: 

703 server_def_str = server_def.SerializeToString() 

704 pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle, 

705 keep_alive_secs, server_def_str) 

706 self._initialize_logical_devices() 

707 

708 self._clear_caches() 

709 

710 def check_alive(self, worker_name): 

711 """Checks whether a remote worker is alive or not. 

712 

713 Args: 

714 worker_name: a string representing the remote worker. It must be a fully 

715 specified name like "/job:worker/replica:0/task:0". 

716 

717 Returns: 

718 a boolean indicating whether the remote worker is alive or not. 

719 

720 Raises: 

721 ValueError: if context is not initialized. 

722 """ 

723 # TODO(yuefengz): support checking multiple workers. 

724 if self._context_handle: 

725 return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name) 

726 else: 

727 raise ValueError("Context is not initialized.") 

728 

729 def sync_executors(self): 

730 """Sync both local executors and the ones on remote workers. 

731 

732 In async execution mode, local function calls can return before the 

733 corresponding remote op/function execution requests are completed. Calling 

734 this method creates a synchronization barrier for remote executors. It only 

735 returns when all remote pending nodes are finished, potentially with errors 

736 if any remote executors are in error state. 

737 

738 Raises: 

739 ValueError: if context is not initialized. 

740 """ 

741 if self._context_handle: 

742 pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle) 

743 else: 

744 raise ValueError("Context is not initialized.") 

745 

746 def clear_executor_errors(self): 

747 """Clear errors in both local executors and remote workers. 

748 

749 After receiving errors from remote workers, additional requests on the fly 

750 could further taint the status on the remote workers due to the async nature 

751 of remote execution. Calling this method blocks while waiting for all pending 

752 nodes in remote executors to finish, and clears their error statuses. 

753 

754 Raises: 

755 ValueError: if context is not initialized. 

756 """ 

757 if self._context_handle: 

758 pywrap_tfe.TFE_ContextClearExecutors(self._context_handle) 

759 else: 

760 raise ValueError("Context is not initialized.") 

761 

762 def configure_coordination_service(self, 

763 service_type, 

764 service_leader="", 

765 enable_health_check=True, 

766 cluster_register_timeout_in_ms=0, 

767 heartbeat_timeout_in_ms=0, 

768 shutdown_barrier_timeout_in_ms=0, 

769 coordinated_jobs=None, 

770 allow_new_incarnation_to_reconnect=False): 

771 """Enable distributed coordination service with specified configs.""" 

772 if self._context_handle: 

773 logging.warning("Configuring coordination service type may not be " 

774 "effective because the context is already initialized.") 

775 config = coordination_config_pb2.CoordinationServiceConfig() 

776 config.service_type = service_type 

777 if service_leader: 

778 config.service_leader = pydev.canonical_name(service_leader) 

779 config.enable_health_check = enable_health_check 

780 config.cluster_register_timeout_in_ms = cluster_register_timeout_in_ms 

781 config.heartbeat_timeout_in_ms = heartbeat_timeout_in_ms 

782 config.shutdown_barrier_timeout_in_ms = shutdown_barrier_timeout_in_ms 

783 config.allow_new_incarnation_to_reconnect = ( 

784 allow_new_incarnation_to_reconnect) 

785 if coordinated_jobs is not None: 

786 if isinstance(coordinated_jobs, list): 

787 config.coordinated_job_list.extend(coordinated_jobs) 

788 else: 

789 raise ValueError("`coordinated_jobs` must be list[CoordinatedJob] or " 

790 "None, but got: %s" % (coordinated_jobs,)) 

791 self._coordination_service_config = config 

792 

793 @property 

794 def coordination_service(self): 

795 return self._coordination_service_config 

796 

797 def set_config_key_value(self, key, value): 

798 ensure_initialized() 

799 pywrap_tfe.TFE_InsertConfigKeyValue(self._context_handle, key, value) 

800 

801 # If `timeout_in_ms=0`, this will block until the key-value is set or the 

802 # worker shuts down. 

803 def get_config_key_value(self, key, timeout_in_ms=0): 

804 ensure_initialized() 

805 with c_api_util.tf_buffer() as buffer_: 

806 pywrap_tfe.TFE_GetConfigKeyValue(self._context_handle, key, 

807 timeout_in_ms, buffer_) 

808 value = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8") 

809 return value 

810 

811 def delete_config_key_value(self, key): 

812 ensure_initialized() 

813 pywrap_tfe.TFE_DeleteConfigKeyValue(self._context_handle, key) 

814 

815 def report_error_to_cluster(self, error_code, error_message): 

816 """Report error to other members in a multi-client cluster. 

817 

818 Args: 

819 error_code: a `tf.errors` error code. 

820 error_message: a string. The error message. 

821 """ 

822 if self._context_handle: 

823 pywrap_tfe.TFE_ReportErrorToCluster(self._context_handle, error_code, 

824 error_message) 

825 else: 

826 raise ValueError("Context is not initialized.") 

827 

828 def get_task_states(self, job_configs): 

829 """Get task states from the Coordination Service. 

830 

831 Args: 

832 job_configs: A list of tuples of job name and task number. 

833 

834 Returns: 

835 A list of TF_Status. 

836 """ 

837 if self._context_handle: 

838 job_names, task_nums = zip(*job_configs) 

839 return pywrap_tfe.TFE_GetTaskStates(self._context_handle, job_names, 

840 task_nums) 

841 else: 

842 raise ValueError("Context is not initialized.") 

843 

844 def wait_at_barrier(self, barrier_id, timeout_in_ms): 

845 """Blocks until all coordinated tasks are at the barrier. 

846 

847 The barrier may fail if it times out or if one of the tasks is unhealthy. 

848 

849 Args: 

850 barrier_id: Unique string identifying the barrier. 

851 timeout_in_ms: Duration before the barrier times out and fails. 

852 """ 

853 ensure_initialized() 

854 pywrap_tfe.TFE_WaitAtBarrier(self._context_handle, barrier_id, 

855 timeout_in_ms) 

856 

857 def clear_kernel_cache(self): 

858 """Clear kernel cache and reset all stateful kernels.""" 

859 if self._context_handle is not None: 

860 pywrap_tfe.TFE_ContextClearCaches(self._context_handle) 

861 

862 def enable_collective_ops(self, server_def): 

863 """Enable distributed collective ops with an appropriate server_def. 

864 

865 Args: 

866 server_def: A tensorflow::ServerDef proto. Enables execution on remote 

867 devices. 

868 

869 Raises: 

870 ValueError: if server_def is None. 

871 RuntimeError: if this method is not called at program startup. 

872 """ 

873 if not server_def: 

874 raise ValueError("server_def is None.") 

875 

876 self._collective_ops_server_def = server_def 

877 

878 # TODO(b/129298253): Allow creating datasets/tensors before enabling 

879 # collective ops. 

880 if self._context_handle is not None: 

881 logging.warning("Enabling collective ops after program startup may cause " 

882 "error when accessing previously created tensors.") 

883 with self._initialize_lock: 

884 assert self._initialized 

885 server_def_str = self._collective_ops_server_def.SerializeToString() 

886 pywrap_tfe.TFE_EnableCollectiveOps(self._context_handle, server_def_str) 

887 self._initialize_logical_devices() 

888 self._clear_caches() 

889 

890 def configure_collective_ops( 

891 self, 

892 collective_leader="", 

893 scoped_allocator_enabled_ops=("CollectiveReduce",), 

894 use_nccl_communication=False, 

895 device_filters=None): 

896 """Configure collective ops. 

897 

898 A collective group leader is necessary for collective ops to run; the other 

899 configurations are mainly for performance. 

900 

901 Args: 

902 collective_leader: a device string for collective leader, e.g. 

903 "/job:worker/replica:0/task:0"; empty string means local execution of 

904 collective ops. 

905 scoped_allocator_enabled_ops: a tuple or a list of op names for scoped 

906 allocator to run with. 

907 use_nccl_communication: whether to use nccl communication for collective 

908 ops. 

909 device_filters: a tuple or a list of device strings. If set, the 

910 corresponding task can only see the devices matching these device filters. 

911 

912 Raises: 

913 RuntimeError: if this method is not called at program startup. 

914 """ 

915 if self._collective_leader is not None: 

916 if (self._collective_leader != collective_leader or 

917 self._collective_scoped_allocator_enabled_ops != 

918 scoped_allocator_enabled_ops or 

919 self._collective_use_nccl_communication != use_nccl_communication or 

920 self._collective_device_filters != device_filters): 

921 raise ValueError("Collective ops are already configured.") 

922 else: 

923 return 

924 

925 if self._context_handle is not None: 

926 raise RuntimeError("Collective ops must be configured at program startup") 

927 

928 self._collective_leader = collective_leader 

929 self._collective_scoped_allocator_enabled_ops = scoped_allocator_enabled_ops 

930 self._collective_use_nccl_communication = use_nccl_communication 

931 self._collective_device_filters = device_filters 

932 

933 def abort_collective_ops(self, code, message): 

934 """Abort the collective ops. 

935 

936 This is intended to be used when a peer failure is detected, which allows 

937 the user to handle the case instead of hanging. This aborts all on-going 

938 collectives. Afterwards, all subsequent collectives error immediately, and you 

939 need to reset_context() to use collectives again. 

940 

941 Args: 

942 code: a `tf.errors` error code. 

943 message: a string. The error message. 

944 """ 

945 self.ensure_initialized() 

946 pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message) 

947 

948 def check_collective_ops_peer_health(self, task, timeout_in_ms): 

949 """Check collective peer health. 

950 

951 This probes each task to see if it is still alive. Note that a restarted 

952 task is considered a different task, and is treated as not healthy. 

953 

954 This should only be used in multi-client, multi-worker training. 

955 

956 Args: 

957 task: a task string, must be in the format of /job:xxx/replica:0/task:N. 

958 timeout_in_ms: an integer, the timeout. If zero, there's no timeout. 

959 

960 Raises: 

961 tf.errors.UnavailableError: when a peer is down. 

962 tf.errors.FailedPreconditionError: when a peer is a different one from the 

963 one this task has talked to, e.g. the peer has restarted. 

964 tf.errors.InvalidArgumentError: when the task string is invalid. 

965 """ 

966 self.ensure_initialized() 

967 pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task, 

968 timeout_in_ms) 

969 

970 @property 

971 def _handle(self): 

972 if self._context_handle is None: 

973 raise AssertionError("Context must be initialized first.") 

974 

975 return self._context_handle 

976 

977 @property 

978 def _devices(self): 

979 if self._context_devices is None: 

980 raise AssertionError("Context must be initialized first.") 

981 

982 return self._context_devices 

983 

984 def __str__(self): 

985 if self._context_handle is None: 

986 return "Eager TensorFlow Context. Devices currently uninitialized." 

987 else: 

988 devices = self._devices 

989 lines = ["Eager TensorFlow Context with %d devices" % (len(devices))] 

990 for i, d in enumerate(devices): 

991 lines.append(" Device %d: %s" % (i, d)) 

992 return "\n".join(lines) 

993 

994 @tf_contextlib.contextmanager 

995 def _mode(self, mode): 

996 """A context manager to allow setting the mode to EAGER/GRAPH.""" 

997 ctx = self._thread_local_data 

998 old_is_eager = ctx.is_eager 

999 ctx.is_eager = mode == EAGER_MODE 

1000 if mode == EAGER_MODE: 

1001 # Entering graph mode does not provide us with sufficient information to 

1002 # record a context switch; graph-based context switches are only logged 

1003 # when a graph is registered as the default graph. 

1004 self.context_switches.push(False, eager_mode, None) 

1005 try: 

1006 yield 

1007 finally: 

1008 ctx.is_eager = old_is_eager 

1009 if mode == EAGER_MODE: 

1010 self.context_switches.pop() 

1011 

1012 def executing_eagerly(self): 

1013 """Returns True if current thread has eager executing enabled.""" 

1014 return self._thread_local_data.is_eager 

1015 

1016 def ones_rank_cache(self): 

1017 """Per-device cache for scalars.""" 

1018 return _tensor_caches_map[self._id].ones_rank_cache 

1019 

1020 def zeros_cache(self): 

1021 """Per-device cache for scalars.""" 

1022 return _tensor_caches_map[self._id].zeros_cache 

1023 

1024 @property 

1025 def scope_name(self): 

1026 """Returns scope name for the current thread.""" 

1027 return self._thread_local_data.scope_name 

1028 

1029 @scope_name.setter 

1030 def scope_name(self, s): 

1031 """Sets scope name for the current thread.""" 

1032 self._thread_local_data.scope_name = s 

1033 

1034 @property 

1035 def device_name(self): 

1036 """Returns the device name for the current thread.""" 

1037 return self._thread_local_data.device_name 

1038 

1039 @property 

1040 def device_spec(self): 

1041 """Returns the device spec for the current thread.""" 

1042 return self._thread_local_data.device_spec 

1043 

1044 def _set_device(self, device_name, device_spec): 

1045 self._thread_local_data.device_name = device_name 

1046 self._thread_local_data.device_spec = device_spec 

1047 

1048 def device(self, name): 

1049 """Context-manager to force placement of operations and Tensors on a device. 

1050 

1051 Args: 

1052 name: Name of the device or None to get default placement. 

1053 

1054 Returns: 

1055 Context manager that forces device placement. 

1056 

1057 Raises: 

1058 ValueError: If name is not a string or is an invalid device name. 

1059 RuntimeError: If device scopes are not properly nested. 

1060 """ 

1061 if isinstance(name, LogicalDevice): 

1062 name = name.name 

1063 elif pydev.is_device_spec(name): 

1064 name = name.to_string() 

1065 return _EagerDeviceContext(self, name) 

1066 
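  # A usage note (not part of the original file): in eager mode `tf.device`
  # reaches this context manager, e.g.
  #
  #   with tf.device("/CPU:0"):          # or a tf.config.LogicalDevice
  #     x = tf.constant([1.0, 2.0])      # placed on the CPU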

1067 def devices(self): 

1068 """List of the names of devices available to execute operations.""" 

1069 return self._devices 

1070 

1071 def host_address_space(self): 

1072 self.ensure_initialized() 

1073 with c_api_util.tf_buffer() as buffer_: 

1074 pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_) 

1075 address_space = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8") 

1076 return address_space 

1077 

1078 # TODO(fishx): remove this property. 

1079 @property 

1080 def execution_mode(self): 

1081 """Gets execution mode for current thread.""" 

1082 return ASYNC if self.is_async() else SYNC 

1083 

1084 @execution_mode.setter 

1085 def execution_mode(self, mode): 

1086 """Sets execution mode for current thread.""" 

1087 if mode not in (None, SYNC, ASYNC): 

1088 raise ValueError("Execution mode should be None/SYNC/ASYNC. Got %s" % 

1089 mode) 

1090 

1091 if mode is None: 

1092 mode = SYNC 

1093 

1094 enable_async = (mode == ASYNC) 

1095 if self.is_async() != enable_async: 

1096 # Only set the execution mode if the context has already been initialized 

1097 if self._context_handle is not None: 

1098 self.executor.wait() 

1099 executor_new = executor.new_executor(enable_async) 

1100 self._thread_local_data.executor = executor_new 

1101 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, 

1102 executor_new.handle()) 

1103 else: 

1104 self._default_is_async = enable_async 

1105 

1106 def is_async(self): 

1107 if self._context_handle is not None: 

1108 return self.executor.is_async() 

1109 else: 

1110 return self._default_is_async 

1111 
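  # A usage note (not part of the original file): user code toggles this via
  # tf.config.experimental.set_synchronous_execution, e.g.
  #
  #   tf.config.experimental.set_synchronous_execution(False)  # ASYNC
  #   tf.config.experimental.set_synchronous_execution(True)   # back to SYNC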

1112 @property 

1113 def executor(self): 

1114 self.ensure_initialized() 

1115 return executor.Executor( 

1116 pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle)) 

1117 

1118 @executor.setter 

1119 def executor(self, e): 

1120 self.ensure_initialized() 

1121 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle()) 

1122 

1123 @property 

1124 def config(self): 

1125 """Return the ConfigProto with all runtime deltas applied.""" 

1126 # Ensure physical devices have been discovered and config has been imported 

1127 self._initialize_physical_devices() 

1128 

1129 config = config_pb2.ConfigProto() 

1130 if self._config is not None: 

1131 config.CopyFrom(self._config) 

1132 

1133 if self._optimizer_jit is not None: 

1134 config.graph_options.optimizer_options.global_jit_level = ( 

1135 config_pb2.OptimizerOptions.ON_1 

1136 if self._optimizer_jit else config_pb2.OptimizerOptions.OFF) 

1137 if self._intra_op_parallelism_threads is not None: 

1138 config.intra_op_parallelism_threads = self._intra_op_parallelism_threads 

1139 if self._inter_op_parallelism_threads is not None: 

1140 config.inter_op_parallelism_threads = self._inter_op_parallelism_threads 

1141 

1142 if self._soft_device_placement is not None: 

1143 config.allow_soft_placement = self._soft_device_placement 

1144 else: 

1145 config.allow_soft_placement = self.executing_eagerly() 

1146 

1147 if self._log_device_placement is not None: 

1148 config.log_device_placement = self._log_device_placement 

1149 

1150 if self._operation_timeout_in_ms is not None: 

1151 config.operation_timeout_in_ms = self._operation_timeout_in_ms 

1152 

1153 is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled() 

1154 config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled 

1155 if (is_mlir_bridge_enabled == 

1156 config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED): 

1157 config.experimental.enable_mlir_bridge = True 

1158 

1159 if self._enable_mlir_graph_optimization is not None: 

1160 config.experimental.enable_mlir_graph_optimization = ( 

1161 self._enable_mlir_graph_optimization) 

1162 

1163 def rewriter_toggle(option): 

1164 toggle = self._optimizer_experimental_options.get(option, None) 

1165 if toggle is None: 

1166 return 

1167 

1168 setattr(config.graph_options.rewrite_options, option, 

1169 (rewriter_config_pb2.RewriterConfig.ON 

1170 if toggle else rewriter_config_pb2.RewriterConfig.OFF)) 

1171 

1172 def rewriter_bool(option): 

1173 toggle = self._optimizer_experimental_options.get(option, None) 

1174 if toggle is None: 

1175 return 

1176 

1177 setattr(config.graph_options.rewrite_options, option, toggle) 

1178 

1179 rewriter_toggle("layout_optimizer") 

1180 rewriter_toggle("constant_folding") 

1181 rewriter_toggle("shape_optimization") 

1182 rewriter_toggle("remapping") 

1183 rewriter_toggle("arithmetic_optimization") 

1184 rewriter_toggle("dependency_optimization") 

1185 rewriter_toggle("loop_optimization") 

1186 rewriter_toggle("function_optimization") 

1187 rewriter_toggle("debug_stripper") 

1188 rewriter_bool("disable_model_pruning") 

1189 rewriter_toggle("scoped_allocator_optimization") 

1190 rewriter_toggle("pin_to_host_optimization") 

1191 rewriter_toggle("implementation_selector") 

1192 rewriter_toggle("auto_mixed_precision") 

1193 rewriter_toggle("use_plugin_optimizers") 

1194 rewriter_bool("disable_meta_optimizer") 

1195 rewriter_toggle("auto_mixed_precision_onednn_bfloat16") 

1196 rewriter_toggle("auto_mixed_precision_mkl") 

1197 nodes = self._optimizer_experimental_options.get("min_graph_nodes", None) 

1198 if nodes is not None: 

1199 config.graph_options.rewrite_options.min_graph_nodes = nodes 

1200 

1201 # Compute device counts 

1202 config.device_count["CPU"] = 0 

1203 config.device_count["GPU"] = 0 

1204 for dev in self._physical_devices: 

1205 if dev not in self._visible_device_list: 

1206 continue 

1207 

1208 virtual_devices = self._virtual_device_map.get(dev) 

1209 if virtual_devices is None: 

1210 config.device_count[dev.device_type] += 1 

1211 else: 

1212 config.device_count[dev.device_type] += len(virtual_devices) 

1213 

1214 # Configure gpu_options 

1215 gpu_options = self._compute_gpu_options() 

1216 config.gpu_options.MergeFrom(gpu_options) 

1217 

1218 # Configure collective ops 

1219 if self._collective_leader: 

1220 config.experimental.collective_group_leader = self._collective_leader 

1221 if self._collective_scoped_allocator_enabled_ops: 

1222 rewrite_options = config.graph_options.rewrite_options 

1223 rewrite_options.scoped_allocator_optimization = ( 

1224 rewriter_config_pb2.RewriterConfig.ON) 

1225 del rewrite_options.scoped_allocator_opts.enable_op[:] 

1226 for op in self._collective_scoped_allocator_enabled_ops: 

1227 rewrite_options.scoped_allocator_opts.enable_op.append(op) 

1228 if self._collective_use_nccl_communication: 

1229 config.experimental.collective_nccl = True 

1230 if self._collective_device_filters: 

1231 del config.device_filters[:] 

1232 for f in self._collective_device_filters: 

1233 config.device_filters.append(f) 

1234 

1235 # Configure coordination service 

1236 if self._coordination_service_config: 

1237 config.experimental.coordination_config.CopyFrom( 

1238 self._coordination_service_config) 

1239 

1240 return config 

1241 
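  # A usage note (not part of the original file): the runtime deltas folded in
  # above are normally set through the public tf.config surface, e.g.
  #
  #   tf.config.set_soft_device_placement(True)
  #   tf.config.threading.set_intra_op_parallelism_threads(4)
  #   tf.config.optimizer.set_experimental_options({"layout_optimizer": False})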

1242 def _compute_gpu_options(self): 

1243 """Build the GPUOptions proto.""" 

1244 visible_device_list = [] 

1245 virtual_devices = [] 

1246 gpu_index = -1 

1247 memory_growths = set() 

1248 gpu_devices = self.list_physical_devices("GPU") 

1249 pluggable_devices = self._pluggable_devices 

1250 compatible_devices = gpu_devices 

1251 for dev in pluggable_devices: 

1252 if dev not in gpu_devices: 

1253 compatible_devices.append(dev) 

1254 for dev in compatible_devices: 

1255 gpu_index += 1 

1256 

1257 if dev not in self._visible_device_list: 

1258 continue 

1259 

1260 growth = self._memory_growth_map[dev] 

1261 memory_growths.add(growth) 

1262 visible_device_list.append(str(gpu_index)) 

1263 

1264 if self._virtual_device_map: 

1265 vdevs = self._virtual_device_map.get(dev, []) 

1266 device_ordinals = [] 

1267 device_limits = [] 

1268 priority = [] 

1269 for virt_dev in vdevs: 

1270 if virt_dev.experimental_device_ordinal is not None: 

1271 device_ordinals.append(virt_dev.experimental_device_ordinal) 

1272 device_limits.append(virt_dev.memory_limit) 

1273 if virt_dev.experimental_priority is not None: 

1274 priority.append(virt_dev.experimental_priority) 

1275 # If priority is specified, it must be specified for all virtual 

1276 # devices. 

1277 if priority and len(device_limits) != len(priority): 

1278 raise ValueError("priority must be specified for all virtual devices") 

1279 # If device_ordinals is specified, it must be specified for all virtual 

1280 # devices. 

1281 if device_ordinals and len(device_limits) != len(device_ordinals): 

1282 raise ValueError( 

1283 "device_ordinals must be specified for all virtual devices") 

1284 

1285 virtual_devices.append( 

1286 config_pb2.GPUOptions.Experimental.VirtualDevices( 

1287 memory_limit_mb=device_limits, 

1288 priority=priority, 

1289 device_ordinal=device_ordinals)) 

1290 

1291 # Only compute growth if virtual devices have not been configured and we 

1292 # have GPUs 

1293 if not virtual_devices and memory_growths: 

1294 if len(memory_growths) > 1: 

1295 raise ValueError("Memory growth cannot differ between GPU devices") 

1296 allow_growth = memory_growths.pop() 

1297 else: 

1298 allow_growth = None 

1299 

1300 return config_pb2.GPUOptions( 

1301 allow_growth=allow_growth, 

1302 visible_device_list=",".join(visible_device_list), 

1303 experimental=config_pb2.GPUOptions.Experimental( 

1304 virtual_devices=virtual_devices)) 

1305 

1306 @property 

1307 def function_call_options(self): 

1308 """Returns function call options for current thread. 

1309 

1310 Note that the returned object is still referenced by the eager context. 

1311 

1312 Returns: the FunctionCallOptions for current thread. 

1313 """ 

1314 if self._thread_local_data.function_call_options is None: 

1315 config = self.config 

1316 

1317 # Default to soft placement for functions unless specified 

1318 if self._soft_device_placement is None: 

1319 config.allow_soft_placement = True 

1320 self._thread_local_data.function_call_options = FunctionCallOptions( 

1321 config_proto=config) 

1322 

1323 return self._thread_local_data.function_call_options 

1324 

1325 @function_call_options.setter 

1326 def function_call_options(self, options): 

1327 """Returns function call options for current thread.""" 

1328 self._thread_local_data.function_call_options = options 

1329 

1330 def num_gpus(self): 

1331 """The number of GPUs available to execute operations.""" 

1332 self.ensure_initialized() 

1333 return self._num_gpus 

1334 

1335 def add_c_function(self, c_func): 

1336 """Add a C API TF_Function to the context. 

1337 

1338 Once added, the function (identified by its name) can be executed like any 

1339 other operation. 

1340 

1341 Args: 

1342 c_func: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper). 

1343 """ 

1344 self.ensure_initialized() 

1345 pywrap_tfe.TFE_ContextAddFunction(self._handle, c_func) 

1346 

1347 def get_c_function(self, name): 

1348 """Get a C API TF_Function from the context. 

1349 

1350 Args: 

1351 name: Name of the function to get. 

1352 

1353 Returns: 

1354 A ScopedTFFunction wrapping the C API TF_Function. 

1355 """ 

1356 self.ensure_initialized() 

1357 return c_api_util.ScopedTFFunction( 

1358 pywrap_tfe.TFE_ContextGetFunction(self._handle, name), name 

1359 ) 

1360 

1361 def add_function_def(self, fdef): 

1362 """Add a function definition to the context. 

1363 

1364 Once added, the function (identified by its name) can be executed like any 

1365 other operation. 

1366 

1367 Args: 

1368 fdef: A FunctionDef protocol buffer message. 

1369 """ 

1370 self.ensure_initialized() 

1371 fdef_string = fdef.SerializeToString() 

1372 pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string, 

1373 len(fdef_string)) 

1374 

1375 def get_function_def(self, name): 

1376 """Get a function definition from the context. 

1377 

1378 Args: 

1379 name: function signature name. 

1380 

1381 Returns: 

1382 The requested FunctionDef. 

1383 

1384 Raises: 

1385 tf.errors.NotFoundError: if name is not the name of a registered function. 

1386 """ 

1387 with c_api_util.tf_buffer() as buffer_: 

1388 pywrap_tfe.TFE_ContextGetFunctionDef(self._handle, name, buffer_) 

1389 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_) 

1390 function_def = function_pb2.FunctionDef() 

1391 function_def.ParseFromString(proto_data) 

1392 

1393 return function_def 

1394 

1395 def get_graph_debug_info(self, name): 

1396 """Get GraphDebugInfo associated with a function from the context. 

1397 

1398 Args: 

1399 name: function signature name. 

1400 

1401 Returns: 

1402 The requested GraphDebugInfo. 

1403 

1404 Raises: 

1405 tf.errors.NotFoundError: if name is not the name of a registered function. 

1406 """ 

1407 with c_api_util.tf_buffer() as buffer_: 

1408 pywrap_tfe.TFE_ContextGetGraphDebugInfo(self._handle, name, buffer_) 

1409 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_) 

1410 graph_debug_info = graph_debug_info_pb2.GraphDebugInfo() 

1411 graph_debug_info.ParseFromString(proto_data) 

1412 

1413 return graph_debug_info 

1414 

1415 def is_custom_device(self, device_name): 

1416 """Calls TFE_IsCustomDevice. See the non-member function.""" 

1417 self.ensure_initialized() 

1418 return pywrap_tfe.TFE_Py_IsCustomDevice(self._handle, device_name) 

1419 

1420 def register_custom_device(self, device_capsule, device_name, 

1421 device_info_capsule): 

1422 """Calls TFE_RegisterCustomDevice. See the non-member function.""" 

1423 self.ensure_initialized() 

1424 pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule, 

1425 device_name, device_info_capsule) 

1426 

1427 def pack_eager_tensors(self, tensors): 

1428 """Pack multiple `EagerTensor`s of the same dtype and shape. 

1429 

1430 Args: 

1431 tensors: a list of EagerTensors to pack. 

1432 

1433 Returns: 

1434 A packed EagerTensor. 

1435 """ 

1436 self.ensure_initialized() 

1437 return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors) 

1438 

1439 def list_function_names(self): 

1440 """Get a list of names of registered functions. 

1441 

1442 Returns: 

1443 A set of names of all registered functions for the context. 

1444 """ 

1445 self.ensure_initialized() 

1446 return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle)) 

1447 

1448 def remove_function(self, name): 

1449 """Remove a function from the context. 

1450 

1451 Once removed, the function cannot be executed anymore. 

1452 

1453 Args: 

1454 name: function signature name. 

1455 """ 

1456 self.ensure_initialized() 

1457 pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name) 

1458 

1459 def has_function(self, name): 

1460 """Check if a function `name` is registered.""" 

1461 self.ensure_initialized() 

1462 return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name)) 

1463 

1464 @property 

1465 def function_scope_id(self): 

1466 """Returns an id that is unique to each scope holding functions.""" 

1467 return id(self._context_handle) 

1468 

1469 def call_function(self, name, tensor_inputs, num_outputs): 

1470 """Calls the function associated with the given name.""" 

1471 attrs = tuple( 

1472 itertools.chain( 

1473 *self.function_call_options.as_attrs().items() 

1474 ) 

1475 ) 

1476 

1477 cancellation_context = cancellation.context() 

1478 if cancellation_context is None: 

1479 outputs = execute.execute( 

1480 name.decode("utf-8"), 

1481 num_outputs=num_outputs, 

1482 inputs=tensor_inputs, 

1483 attrs=attrs, 

1484 ctx=self, 

1485 ) 

1486 else: 

1487 outputs = execute.execute_with_cancellation( 

1488 name.decode("utf-8"), 

1489 num_outputs=num_outputs, 

1490 inputs=tensor_inputs, 

1491 attrs=attrs, 

1492 ctx=self, 

1493 cancellation_manager=cancellation_context, 

1494 ) 

1495 # Empty list means no function outputs so return None 

1496 outputs = outputs or None 

1497 

1498 return outputs 

1499 
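
The `attrs` tuple built at the top of `call_function` flattens the function-call options into the alternating name/value form that `execute.execute` expects. A tiny standalone illustration of that flattening (the dictionary below is a made-up stand-in for `function_call_options.as_attrs()`):

```python
import itertools

# Hypothetical attribute dict standing in for function_call_options.as_attrs().
options = {"executor_type": "", "config_proto": b""}

# chain(*dict.items()) interleaves keys and values into one flat tuple:
# ("executor_type", "", "config_proto", b"")
attrs = tuple(itertools.chain(*options.items()))
print(attrs)
```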

1500 def add_op_callback(self, callback): 

1501 """Add a post-op callback to the context. 

1502 

1503 A post-op callback is invoked immediately after an eager operation or 

1504 function has finished execution or after an op has been added to a graph, 

1505 providing access to the op's type, name, and input and output tensors. Multiple 

1506 op callbacks can be added, in which case the callbacks will be invoked in 

1507 the order in which they are added. 

1508 

1509 Args: 

1510 callback: a callable of the signature `f(op_type, inputs, attrs, outputs, 

1511 op_name=None, graph=None)`. See doc strings in `op_callbacks.py` for 

1512 details on the function signature and its semantics. 

1513 """ 

1514 if callback not in self._thread_local_data.op_callbacks: 

1515 self._thread_local_data.op_callbacks.append(callback) 

1516 

1517 def remove_op_callback(self, callback): 

1518 """Remove an already-registered op callback. 

1519 

1520 Args: 

1521 callback: The op callback to be removed. 

1522 

1523 Raises: 

1524 KeyError: If `callback` is not already registered. 

1525 """ 

1526 if callback not in self._thread_local_data.op_callbacks: 

1527 raise KeyError("The specified op callback has not been registered, " 

1528 "and hence cannot be removed.") 

1529 del self._thread_local_data.op_callbacks[ 

1530 self._thread_local_data.op_callbacks.index(callback)] 

1531 
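
In practice these two methods are reached through `op_callbacks.add_op_callback` / `remove_op_callback` in `tensorflow/python/framework/op_callbacks.py`, which also enable the callback-aware eager execute path. A hedged sketch with a print-only callback:

```python
import tensorflow as tf
from tensorflow.python.framework import op_callbacks


def log_ops(op_type, inputs, attrs, outputs, op_name=None, graph=None):
  # Invoked after each eager op finishes (or after an op is added to a graph).
  print("ran", op_type, "with", len(inputs), "inputs")
  return None  # keep the original outputs


op_callbacks.add_op_callback(log_ops)
tf.constant(1.0) + tf.constant(2.0)   # triggers the callback for the add op
op_callbacks.remove_op_callback(log_ops)
```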

1532 @property 

1533 def op_callbacks(self): 

1534 return self._thread_local_data.op_callbacks 

1535 

1536 @property 

1537 def invoking_op_callbacks(self): 

1538 return self._thread_local_data.invoking_op_callbacks 

1539 

1540 @invoking_op_callbacks.setter 

1541 def invoking_op_callbacks(self, value): 

1542 self._thread_local_data.invoking_op_callbacks = value 

1543 

1544 def _initialize_physical_devices(self, reinitialize=False): 

1545 """Gets local devices visible to the system. 

1546 

1547 Args: 

1548 reinitialize: If True, reinitializes self._physical_devices so that 

1549 dynamically registered devices will also be visible to the Python front end. 

1550 """ 

1551 # We lazily initialize self._physical_devices since we do not want to do this 

1552 # in the constructor, as the backend may not be initialized yet. 

1553 with self._device_lock: 

1554 if not reinitialize and self._physical_devices is not None: 

1555 return 

1556 

1557 devs = pywrap_tfe.TF_ListPhysicalDevices() 

1558 self._physical_devices = [ 

1559 PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1]) 

1560 for d in devs 

1561 ] 

1562 self._physical_device_to_index = { 

1563 p: i for i, p in enumerate(self._physical_devices) 

1564 } 

1565 # We maintain a separate list just so we can check whether the device in 

1566 # _physical_devices is a PluggableDevice. 

1567 pluggable_devs = pywrap_tfe.TF_ListPluggablePhysicalDevices() 

1568 self._pluggable_devices = [ 

1569 PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1]) 

1570 for d in pluggable_devs 

1571 ] 

1572 

1573 self._visible_device_list = list(self._physical_devices) 

1574 self._memory_growth_map = { 

1575 d: None 

1576 for d in self._physical_devices 

1577 if d.device_type == "GPU" or d in self._pluggable_devices 

1578 } 

1579 

1580 # Import device settings that may have been passed into the constructor 

1581 self._import_config() 

1582 

1583 def reinitialize_physical_devices(self): 

1584 """Gets local devices visible to the system.""" 

1585 # Reinitialize the physical device list after registering 

1586 # the pluggable device. 

1587 self._initialize_physical_devices(True) 

1588 

1589 def list_physical_devices(self, device_type=None): 

1590 """List local devices visible to the system. 

1591 

1592 This API allows a client to query the devices before they have been 

1593 initialized by the eager runtime. Additionally, a user can filter by device 

1594 type to get only CPUs or GPUs. 

1595 

1596 Args: 

1597 device_type: Optional device type to limit results to 

1598 

1599 Returns: 

1600 List of PhysicalDevice objects. 

1601 """ 

1602 self._initialize_physical_devices() 

1603 

1604 if device_type is None: 

1605 return list(self._physical_devices) 

1606 

1607 return [d for d in self._physical_devices if d.device_type == device_type] 

1608 
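
The public `tf.config.list_physical_devices` API (defined elsewhere in TensorFlow) delegates to this method; a short usage sketch, with device counts obviously depending on the host:

```python
import tensorflow as tf

# Enumerate devices without forcing the eager runtime to initialize them.
print("CPUs:", tf.config.list_physical_devices("CPU"))
print("GPUs:", tf.config.list_physical_devices("GPU"))
```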

1609 def get_device_details(self, device): # pylint: disable=redefined-outer-name 

1610 """Returns details about a physical device. 

1611 

1612 Args: 

1613 device: A `tf.config.PhysicalDevice` returned by 

1614 `tf.config.list_physical_devices` or `tf.config.get_visible_devices`. 

1615 

1616 Returns: 

1617 A dict with string keys. 

1618 """ 

1619 if not isinstance(device, PhysicalDevice): 

1620 raise ValueError("device must be a tf.config.PhysicalDevice, but got: " 

1621 "%s" % (device,)) 

1622 if (self._physical_device_to_index is None or 

1623 device not in self._physical_device_to_index): 

1624 raise ValueError("The PhysicalDevice must be one obtained from " 

1625 "calling `tf.config.list_physical_devices`, but got: " 

1626 "%s" % (device,)) 

1627 index = self._physical_device_to_index[device] 

1628 details = pywrap_tfe.TF_GetDeviceDetails(index) 

1629 

1630 # Change compute_capability from a string to a tuple 

1631 if "compute_capability" in details: 

1632 try: 

1633 major, minor = details["compute_capability"].split(".") 

1634 details["compute_capability"] = (int(major), int(minor)) 

1635 except ValueError: 

1636 raise RuntimeError("Device returned compute capability in an invalid " 

1637 "format: %s" % details["compute_capability"]) 

1638 return details 

1639 
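
The corresponding public entry point is `tf.config.experimental.get_device_details`; for a GPU the returned dict typically carries keys such as `device_name` plus the `compute_capability` tuple produced above. A sketch, assuming at least one GPU is present:

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
if gpus:
  details = tf.config.experimental.get_device_details(gpus[0])
  # e.g. {'device_name': 'NVIDIA A100-SXM4-40GB', 'compute_capability': (8, 0)}
  print(details.get("device_name"), details.get("compute_capability"))
```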

1640 def _import_config(self): 

1641 """Import config if passed in during construction. 

1642 

1643 If Context was created with a ConfigProto such as when calling 

1644 tf.compat.v1.enable_eager_execution(), then we need to pull out the 

1645 various pieces we might be replacing and import them into our internal 

1646 class representation. 

1647 """ 

1648 if self._config is None: 

1649 return 

1650 

1651 num_cpus = self._config.device_count.get("CPU", 1) 

1652 if num_cpus != 1: 

1653 cpus = [d for d in self._physical_devices if d.device_type == "CPU"] 

1654 if num_cpus == 0: 

1655 self.set_visible_devices([], "CPU") 

1656 elif num_cpus > 1: 

1657 self.set_logical_device_configuration( 

1658 cpus[0], [LogicalDeviceConfiguration() for _ in range(num_cpus)]) 

1659 

1660 # Parse GPU options 

1661 gpus = [d for d in self._physical_devices if d.device_type == "GPU"] 

1662 

1663 # If there are no GPUs detected, simply ignore all the GPU options passed in 

1664 # rather than doing any validation checks. 

1665 if not gpus: 

1666 return 

1667 

1668 gpu_count = self._config.device_count.get("GPU", None) 

1669 

1670 visible_gpus = [] 

1671 # TODO(gjn): Handle importing existing virtual GPU configuration 

1672 visible_indices = self._config.gpu_options.visible_device_list 

1673 if visible_indices: 

1674 for index in visible_indices.split(","): 

1675 if int(index) >= len(gpus): 

1676 raise ValueError("Invalid visible device index: %s" % index) 

1677 visible_gpus.append(gpus[int(index)]) 

1678 else: 

1679 visible_gpus = gpus 

1680 

1681 if gpu_count is not None: 

1682 visible_gpus = visible_gpus[:gpu_count] 

1683 

1684 self.set_visible_devices(visible_gpus, "GPU") 

1685 

1686 def list_logical_devices(self, device_type=None): 

1687 """Return logical devices.""" 

1688 self.ensure_initialized() 

1689 if device_type is None: 

1690 return list(self._logical_devices) 

1691 

1692 return [d for d in self._logical_devices if d.device_type == device_type] 

1693 

1694 def get_visible_devices(self, device_type=None): 

1695 """Get the list of visible devices.""" 

1696 self._initialize_physical_devices() 

1697 

1698 if device_type is None: 

1699 return list(self._visible_device_list) 

1700 

1701 return [ 

1702 d for d in self._visible_device_list if d.device_type == device_type 

1703 ] 

1704 

1705 def set_visible_devices(self, devices, device_type=None): 

1706 """Set the list of visible devices.""" 

1707 self._initialize_physical_devices() 

1708 

1709 if not isinstance(devices, list): 

1710 devices = [devices] 

1711 

1712 for d in devices: 

1713 if d not in self._physical_devices: 

1714 raise ValueError("Unrecognized device: %s" % repr(d)) 

1715 if device_type is not None and d.device_type != device_type: 

1716 raise ValueError("Unrecognized device: %s" % repr(d)) 

1717 

1718 visible_device_list = [] 

1719 if device_type is not None: 

1720 visible_device_list = [ 

1721 d for d in self._visible_device_list if d.device_type != device_type 

1722 ] 

1723 

1724 visible_device_list += devices 

1725 

1726 if self._visible_device_list == visible_device_list: 

1727 return 

1728 

1729 if self._context_handle is not None: 

1730 raise RuntimeError( 

1731 "Visible devices cannot be modified after being initialized") 

1732 

1733 self._visible_device_list = visible_device_list 

1734 
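
Via `tf.config.set_visible_devices` this is how devices are hidden from the runtime; as the RuntimeError above indicates, it must happen before the context is initialized. A sketch restricting TensorFlow to the first GPU:

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
if gpus:
  # Must run before any op forces context initialization.
  tf.config.set_visible_devices(gpus[:1], "GPU")
  print(tf.config.get_visible_devices("GPU"))
```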

1735 def get_memory_info(self, dev): 

1736 """Returns a dict of memory info for the device.""" 

1737 self._initialize_physical_devices() 

1738 self.ensure_initialized() 

1739 return pywrap_tfe.TFE_GetMemoryInfo(self._context_handle, dev) 

1740 

1741 def reset_memory_stats(self, dev): 

1742 """Resets the tracked memory stats for the device.""" 

1743 self._initialize_physical_devices() 

1744 self.ensure_initialized() 

1745 pywrap_tfe.TFE_ResetMemoryStats(self._context_handle, dev) 

1746 
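
These two methods back `tf.config.experimental.get_memory_info` and `reset_memory_stats`, which take a logical device string such as `"GPU:0"`; the returned dict reports `current` and `peak` bytes in use. A sketch, assuming a GPU is available:

```python
import tensorflow as tf

if tf.config.list_physical_devices("GPU"):
  info = tf.config.experimental.get_memory_info("GPU:0")
  print(info["current"], info["peak"])  # bytes in use now / at peak
  tf.config.experimental.reset_memory_stats("GPU:0")
```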

1747 def get_memory_growth(self, dev): 

1748 """Get if memory growth is enabled for a PhysicalDevice.""" 

1749 self._initialize_physical_devices() 

1750 

1751 if dev not in self._physical_devices: 

1752 raise ValueError("Unrecognized device: %s" % repr(dev)) 

1753 

1754 return self._memory_growth_map[dev] 

1755 

1756 def set_memory_growth(self, dev, enable): 

1757 """Set if memory growth should be enabled for a PhysicalDevice.""" 

1758 self._initialize_physical_devices() 

1759 

1760 if dev not in self._physical_devices: 

1761 raise ValueError("Unrecognized device: %s" % repr(dev)) 

1762 

1763 if dev in self._virtual_device_map: 

1764 raise ValueError( 

1765 "Cannot set memory growth on device when virtual devices configured") 

1766 

1767 if dev.device_type != "GPU" and dev not in self._pluggable_devices: 

1768 raise ValueError( 

1769 "Cannot set memory growth on non-GPU and non-Pluggable devices") 

1770 

1771 if self._memory_growth_map.get(dev) == enable: 

1772 return 

1773 

1774 if self._context_handle is not None: 

1775 raise RuntimeError( 

1776 "Physical devices cannot be modified after being initialized") 

1777 

1778 self._memory_growth_map[dev] = enable 

1779 
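
Publicly this pair is exposed as `tf.config.experimental.get_memory_growth` / `set_memory_growth`; like the other physical-device settings it has to be applied before the context initializes:

```python
import tensorflow as tf

for gpu in tf.config.list_physical_devices("GPU"):
  # Allocate GPU memory on demand rather than reserving it all up front.
  tf.config.experimental.set_memory_growth(gpu, True)
  print(gpu, tf.config.experimental.get_memory_growth(gpu))
```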

1780 def get_logical_device_configuration(self, dev): 

1781 """Get the virtual device configuration for a PhysicalDevice.""" 

1782 self._initialize_physical_devices() 

1783 

1784 if dev not in self._physical_devices: 

1785 raise ValueError("Unrecognized device: %s" % repr(dev)) 

1786 

1787 return self._virtual_device_map.get(dev) 

1788 

1789 def set_logical_device_configuration(self, dev, virtual_devices): 

1790 """Set the virtual device configuration for a PhysicalDevice.""" 

1791 self._initialize_physical_devices() 

1792 

1793 if dev not in self._physical_devices: 

1794 raise ValueError("Unrecognized device: %s" % repr(dev)) 

1795 

1796 if dev.device_type == "CPU": 

1797 for vdev in virtual_devices: 

1798 if vdev.memory_limit is not None: 

1799 raise ValueError("Setting memory limit on CPU virtual devices is " 

1800 "currently not supported") 

1801 if vdev.experimental_priority is not None: 

1802 raise ValueError("Setting experimental_priority on CPU virtual " 

1803 "devices is currently not supported") 

1804 if vdev.experimental_device_ordinal is not None: 

1805 raise ValueError("Setting experimental_device_ordinal on CPU virtual " 

1806 "devices is currently not supported") 

1807 elif dev.device_type == "GPU": 

1808 for vdev in virtual_devices: 

1809 if vdev.memory_limit is None: 

1810 raise ValueError( 

1811 "Setting memory limit is required for GPU virtual devices") 

1812 else: 

1813 raise ValueError("Virtual devices are not supported for %s" % 

1814 dev.device_type) 

1815 

1816 if self._virtual_device_map.get(dev) == virtual_devices: 

1817 return 

1818 

1819 if self._context_handle is not None: 

1820 raise RuntimeError( 

1821 "Virtual devices cannot be modified after being initialized") 

1822 

1823 self._virtual_device_map[dev] = virtual_devices 

1824 
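
This backs `tf.config.set_logical_device_configuration`. Per the checks above, GPU virtual devices require a `memory_limit` while CPU virtual devices must not set one. A sketch splitting one GPU into two logical devices (the 1024 MB limits are arbitrary):

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
if gpus:
  tf.config.set_logical_device_configuration(
      gpus[0],
      [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
       tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
  print(tf.config.list_logical_devices("GPU"))  # two logical GPUs
```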

1825 def set_logical_cpu_devices(self, num_cpus, prefix=""): 

1826 """Set virtual CPU devices in context. 

1827 

1828 If virtual CPU devices are already configured at context initialization 

1829 by tf.config.set_logical_device_configuration(), this method should not be 

1830 called. 

1831 

1832 Args: 

1833 num_cpus: Number of virtual CPUs. 

1834 prefix: Device name prefix. 

1835 

1836 Raises: 

1837 RuntimeError: If virtual CPUs are already configured at context 

1838 initialization. 

1839 """ 

1840 server_def = self._server_def or self._collective_ops_server_def 

1841 local_prefix = ["/device"] 

1842 if server_def is not None: 

1843 local_prefix.append("/job:%s/replica:0/task:%d" % (server_def.job_name, 

1844 server_def.task_index)) 

1845 logical_local_devices = [d for d in self.list_logical_devices("CPU") if 

1846 d.name.startswith(tuple(local_prefix))] 

1847 self.ensure_initialized() 

1848 # Error out if there are already multiple logical CPUs in the context. 

1849 if len(logical_local_devices) > 1: 

1850 raise RuntimeError("Virtual CPUs already set, cannot modify again.") 

1851 

1852 pywrap_tfe.TFE_SetLogicalCpuDevices(self._context_handle, num_cpus, prefix) 

1853 self._initialize_logical_devices() 

1854 

1855 def get_compiler_ir( 

1856 self, 

1857 device_name, 

1858 function_name, 

1859 flat_args, 

1860 captured_inputs, 

1861 stage="hlo", 

1862 ): 

1863 return pywrap_tfe.TF_GetCompilerIr( 

1864 self._context_handle, 

1865 function_name, 

1866 stage, 

1867 device_name, 

1868 flat_args, 

1869 captured_inputs, 

1870 ) 

1871 

1872 @deprecated( 

1873 None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True) 

1874 def enable_xla_devices(self): 

1875 """Enables XLA:CPU and XLA:GPU devices registration.""" 

1876 pywrap_tfe.TF_EnableXlaDevices() 

1877 

1878 @property 

1879 def enable_mlir_bridge(self): 

1880 return pywrap_tfe.TF_IsMlirBridgeEnabled() 

1881 

1882 @property 

1883 def enable_mlir_graph_optimization(self): 

1884 return self._enable_mlir_graph_optimization 

1885 

1886 @enable_mlir_bridge.setter 

1887 def enable_mlir_bridge(self, enabled): 

1888 pywrap_tfe.TF_EnableMlirBridge(enabled) 

1889 self._thread_local_data.function_call_options = None 

1890 

1891 @enable_mlir_graph_optimization.setter 

1892 def enable_mlir_graph_optimization(self, enabled): 

1893 self._enable_mlir_graph_optimization = enabled 

1894 self._thread_local_data.function_call_options = None 

1895 

1896 @property 

1897 def optimizer_jit(self): 

1898 level = self.config.graph_options.optimizer_options.global_jit_level 

1899 return (level == config_pb2.OptimizerOptions.ON_1 or 

1900 level == config_pb2.OptimizerOptions.ON_2) 

1901 

1902 @optimizer_jit.setter 

1903 def optimizer_jit(self, enabled): 

1904 self._optimizer_jit = enabled 

1905 

1906 self._thread_local_data.function_call_options = None 

1907 

1908 def get_optimizer_experimental_options(self): 

1909 """Get experimental options for the optimizer. 

1910 

1911 Returns: 

1912 Dictionary of current option values 

1913 """ 

1914 rewrite_options = self.config.graph_options.rewrite_options 

1915 options = {} 

1916 

1917 def rewriter_toggle(option): 

1918 attr = getattr(rewrite_options, option) 

1919 if attr != 0: 

1920 options[option] = (attr == rewriter_config_pb2.RewriterConfig.ON) 

1921 

1922 def rewriter_bool(option): 

1923 options[option] = getattr(rewrite_options, option) 

1924 

1925 rewriter_toggle("layout_optimizer") 

1926 rewriter_toggle("constant_folding") 

1927 rewriter_toggle("shape_optimization") 

1928 rewriter_toggle("remapping") 

1929 rewriter_toggle("arithmetic_optimization") 

1930 rewriter_toggle("dependency_optimization") 

1931 rewriter_toggle("loop_optimization") 

1932 rewriter_toggle("function_optimization") 

1933 rewriter_toggle("debug_stripper") 

1934 rewriter_bool("disable_model_pruning") 

1935 rewriter_toggle("scoped_allocator_optimization") 

1936 rewriter_toggle("pin_to_host_optimization") 

1937 rewriter_toggle("implementation_selector") 

1938 rewriter_toggle("auto_mixed_precision") 

1939 rewriter_toggle("use_plugin_optimizers") 

1940 rewriter_bool("disable_meta_optimizer") 

1941 rewriter_toggle("auto_mixed_precision_onednn_bfloat16") 

1942 rewriter_toggle("auto_mixed_precision_mkl") 

1943 

1944 if rewrite_options.min_graph_nodes != 0: 

1945 options["min_graph_nodes"] = rewrite_options.min_graph_nodes 

1946 

1947 return options 

1948 

1949 def set_optimizer_experimental_options(self, options): 

1950 """Set experimental options for the optimizer. 

1951 

1952 Args: 

1953 options: Dictionary of options to modify 

1954 """ 

1955 self._optimizer_experimental_options.update(options) 

1956 

1957 self._thread_local_data.function_call_options = None 

1958 
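
These map to `tf.config.optimizer.get_experimental_options` / `set_experimental_options`; the accepted keys are the RewriterConfig field names toggled above. A small sketch:

```python
import tensorflow as tf

# Turn off a couple of grappler passes and inspect the merged options.
tf.config.optimizer.set_experimental_options(
    {"layout_optimizer": False, "constant_folding": False})
print(tf.config.optimizer.get_experimental_options())
```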

1959 @property 

1960 def intra_op_parallelism_threads(self): 

1961 return self.config.intra_op_parallelism_threads 

1962 

1963 @intra_op_parallelism_threads.setter 

1964 def intra_op_parallelism_threads(self, num_threads): 

1965 if self._intra_op_parallelism_threads == num_threads: 

1966 return 

1967 

1968 if self._context_handle is not None: 

1969 raise RuntimeError( 

1970 "Intra op parallelism cannot be modified after initialization.") 

1971 

1972 self._intra_op_parallelism_threads = num_threads 

1973 

1974 @property 

1975 def inter_op_parallelism_threads(self): 

1976 return self.config.inter_op_parallelism_threads 

1977 

1978 @inter_op_parallelism_threads.setter 

1979 def inter_op_parallelism_threads(self, num_threads): 

1980 if self._inter_op_parallelism_threads == num_threads: 

1981 return 

1982 

1983 if self._context_handle is not None: 

1984 raise RuntimeError( 

1985 "Inter op parallelism cannot be modified after initialization.") 

1986 

1987 self._inter_op_parallelism_threads = num_threads 

1988 
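
The public setters live under `tf.config.threading` and, as the RuntimeErrors above indicate, must be called before the context is initialized:

```python
import tensorflow as tf

# Configure threading before any op initializes the eager context.
tf.config.threading.set_intra_op_parallelism_threads(4)
tf.config.threading.set_inter_op_parallelism_threads(2)
print(tf.config.threading.get_intra_op_parallelism_threads(),
      tf.config.threading.get_inter_op_parallelism_threads())
```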

1989 @property 

1990 def soft_device_placement(self): 

1991 return self.config.allow_soft_placement 

1992 

1993 @soft_device_placement.setter 

1994 def soft_device_placement(self, enable): 

1995 if self._context_handle is not None: 

1996 pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable) 

1997 

1998 self._soft_device_placement = enable 

1999 self._thread_local_data.function_call_options = None 

2000 

2001 @property 

2002 def log_device_placement(self): 

2003 return self.config.log_device_placement 

2004 

2005 @log_device_placement.setter 

2006 def log_device_placement(self, enable): 

2007 if self._context_handle is not None: 

2008 pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable) 

2009 

2010 self._log_device_placement = enable 

2011 self._thread_local_data.function_call_options = None 

2012 

2013 @property 

2014 def jit_compile_rewrite(self): 

2015 return self._jit_compile_rewrite 

2016 

2017 @jit_compile_rewrite.setter 

2018 def jit_compile_rewrite(self, enable): 

2019 if self._context_handle is not None: 

2020 pywrap_tfe.TFE_ContextSetJitCompileRewrite(self._handle, enable) 

2021 self._jit_compile_rewrite = enable 

2022 

2023 @property 

2024 def device_policy(self): 

2025 # Only get the policy from the context if it has already been initialized 

2026 if self._context_handle is not None: 

2027 return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle) 

2028 

2029 return self._device_policy 

2030 

2031 @device_policy.setter 

2032 def device_policy(self, policy): 

2033 if policy is None: 

2034 policy = DEVICE_PLACEMENT_SILENT 

2035 

2036 if self._device_policy != policy: 

2037 self._device_policy = policy 

2038 

2039 # Only set the policy if the context has already been initialized 

2040 if self._context_handle is not None: 

2041 pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy( 

2042 self._handle, self._device_policy) 

2043 

2044 @property 

2045 def use_tfrt(self): 

2046 return self._use_tfrt 

2047 

2048 @use_tfrt.setter 

2049 def use_tfrt(self, tfrt): 

2050 """Sets whether to use TFRT.""" 

2051 if not isinstance(tfrt, bool): 

2052 raise ValueError("Expecting a boolean but got %s" % type(tfrt)) 

2053 

2054 if self._use_tfrt != tfrt: 

2055 if self._initialized: 

2056 raise ValueError("use_tfrt should be set before being initialized.") 

2057 self._use_tfrt = tfrt 

2058 

2059 @property 

2060 def operation_timeout_in_ms(self): 

2061 return self.config.operation_timeout_in_ms 

2062 

2063 @operation_timeout_in_ms.setter 

2064 def operation_timeout_in_ms(self, timeout_in_ms): 

2065 if self._operation_timeout_in_ms == timeout_in_ms: 

2066 return 

2067 

2068 if self._context_handle is not None: 

2069 raise RuntimeError( 

2070 "Operation timeout cannot be modified after initialization.") 

2071 

2072 self._operation_timeout_in_ms = timeout_in_ms 

2073 

2074 def enable_run_metadata(self): 

2075 """Enables tracing of op execution via RunMetadata. 

2076 

2077 To retrieve the accumulated metadata call context.export_run_metadata() 

2078 and to stop tracing call context.disable_run_metadata(). 

2079 """ 

2080 self.ensure_initialized() 

2081 pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle) 

2082 

2083 def disable_run_metadata(self): 

2084 """Disables tracing of op execution via RunMetadata.""" 

2085 if not self._context_handle: 

2086 return 

2087 pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle) 

2088 

2089 def enable_graph_collection(self): 

2090 """Enables graph collection of executed functions. 

2091 

2092 To retrieve the accumulated graphs call context.export_run_metadata() 

2093 and to stop collecting graphs call context.disable_graph_collection(). 

2094 """ 

2095 self.ensure_initialized() 

2096 pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle) 

2097 

2098 def disable_graph_collection(self): 

2099 """Disables graph collection of executed functions.""" 

2100 if not self._context_handle: 

2101 return 

2102 pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle) 

2103 

2104 def export_run_metadata(self): 

2105 """Returns a RunMetadata proto with accumulated information. 

2106 

2107 The returned protocol buffer contains information since the most recent call 

2108 to either enable_run_metadata or export_run_metadata. 

2109 

2110 Returns: 

2111 A RunMetadata protocol buffer. Or None if not enabled. 

2112 """ 

2113 if not self._context_handle: 

2114 return None 

2115 with c_api_util.tf_buffer() as buffer_: 

2116 pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_) 

2117 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_) 

2118 run_metadata = config_pb2.RunMetadata() 

2119 run_metadata.ParseFromString(compat.as_bytes(proto_data)) 

2120 return run_metadata 

2121 

2122 @property 

2123 def context_switches(self): 

2124 """Returns a stack of context switches.""" 

2125 return self._context_switches 

2126 

2127 

2128class _EagerDeviceContext(object): 

2129 """Context-manager forcing placement of ops and Tensors on a device.""" 

2130 

2131 __slots__ = ["_device_name", "_ctx", "_stack"] 

2132 

2133 def __init__(self, ctx, device_name): 

2134 self._device_name = device_name 

2135 self._ctx = ctx 

2136 self._stack = [] 

2137 

2138 # TODO(b/189233748): Consolidate the device string parsing logic with 

2139 # tensorflow/core/util/device_name_utils.cc. 

2140 def __enter__(self): 

2141 ctx = self._ctx 

2142 old_device_name = ctx.device_name 

2143 old_device_spec = ctx.device_spec 

2144 new_device_name = self._device_name 

2145 cache_key = (old_device_name, new_device_name) 

2146 try: 

2147 new_device_name, new_device_spec = _device_parsing_cache[cache_key] 

2148 except TypeError: 

2149 # Error while trying to compute the cache key. 

2150 raise ValueError("Expecting a string device name. Got %s(%s)" % 

2151 (type(new_device_name), new_device_name)) 

2152 except KeyError: 

2153 # Handle a cache miss. 

2154 if new_device_name is not None: 

2155 if not isinstance(new_device_name, str): 

2156 raise ValueError("Expecting a string device name. Got %s(%s)" % 

2157 (type(new_device_name), new_device_name)) 

2158 device_spec = pydev.DeviceSpec.from_string(new_device_name) 

2159 if old_device_name: 

2160 new_device_spec = copy.copy(old_device_spec) 

2161 else: 

2162 ctx.ensure_initialized() 

2163 new_device_spec = pydev.DeviceSpec.from_string( 

2164 ctx._context_devices[0]) # pylint: disable=protected-access 

2165 new_device_spec = new_device_spec.make_merged_spec(device_spec) 

2166 else: 

2167 new_device_spec = pydev.DeviceSpec.from_string("") 

2168 new_device_name = new_device_spec.to_string() 

2169 _device_parsing_cache[cache_key] = (new_device_name, new_device_spec) 

2170 

2171 ctx._set_device(new_device_name, new_device_spec) # pylint: disable=protected-access 

2172 self._stack.append((old_device_name, old_device_spec, new_device_spec)) 

2173 

2174 def __exit__(self, *ex_info): 

2175 ctx = self._ctx 

2176 old_device_name, old_device_spec, new_device_spec = self._stack[-1] 

2177 if ctx.device_spec is not new_device_spec: 

2178 raise RuntimeError("Exiting device scope without proper scope nesting") 

2179 del self._stack[-1] 

2180 ctx._set_device(old_device_name, old_device_spec) # pylint: disable=protected-access 

2181 
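
The merge performed in `__enter__` mirrors `tf.DeviceSpec.make_merged_spec`: an inner `tf.device` scope only overrides the fields it actually specifies, inheriting the rest from the enclosing scope. A standalone sketch of that merge:

```python
import tensorflow as tf

outer = tf.DeviceSpec.from_string("/job:worker/replica:0/task:1/device:GPU:0")
inner = tf.DeviceSpec.from_string("/device:CPU:0")
# Fields the inner spec leaves unset are inherited from the outer one.
print(outer.make_merged_spec(inner).to_string())
# -> /job:worker/replica:0/task:1/device:CPU:0
```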

2182 

2183# Do not change directly. 

2184_context = None 

2185_context_lock = threading.Lock() 

2186 

2187 

2188def _set_context_locked(ctx): 

2189 global _context 

2190 pywrap_tfe.TFE_Py_SetEagerContext(ctx) 

2191 ctx.mark_as_global_context() 

2192 _context = ctx 

2193 

2194 

2195def _set_context(ctx): 

2196 with _context_lock: 

2197 _set_context_locked(ctx) 

2198 

2199 

2200def _create_context(): 

2201 with _context_lock: 

2202 if _context is None: 

2203 ctx = Context() 

2204 _set_context_locked(ctx) 

2205 

2206 

2207def _reset_context(): 

2208 """Clears and re-initializes the singleton context. 

2209 

2210 Should only be used for testing. 

2211 """ 

2212 global _context 

2213 global _device_parsing_cache 

2214 

2215 # Garbage collect and clear the scalar cache to avoid Tensors from the current 

2216 # context polluting the next context. 

2217 gc.collect() 

2218 pywrap_tfe.TFE_ClearScalarCache() 

2219 with _context_lock: 

2220 if _context is not None: 

2221 _context._clear_caches() 

2222 _context = None 

2223 _create_context() 

2224 _device_parsing_cache = {} 

2225 

2226 

2227def _reset_jit_compiler_flags(): 

2228 """Clears and re-initializes the TF JIT compiler flags. 

2229 

2230 Should only be used for testing. 

2231 """ 

2232 pywrap_tfe.TF_ResetJitCompilerFlags() 

2233 

2234 

2235def context(): 

2236 """Returns a singleton context object.""" 

2237 if _context is None: 

2238 _create_context() 

2239 return _context 

2240 

2241 

2242def context_safe(): 

2243 """Returns current context (or None if one hasn't been initialized).""" 

2244 return _context 

2245 

2246 

2247def ensure_initialized(): 

2248 """Initialize the context.""" 

2249 context().ensure_initialized() 

2250 

2251 

2252def initialize_logical_devices(): 

2253 """Initialize the virtual devices.""" 

2254 context()._initialize_logical_devices() # pylint: disable=protected-access 

2255 

2256 

2257def set_global_seed(seed): 

2258 """Sets the eager mode seed.""" 

2259 context()._set_global_seed(seed) # pylint: disable=protected-access 

2260 

2261 

2262def global_seed(): 

2263 """Returns the eager mode seed.""" 

2264 return context()._seed # pylint: disable=protected-access 

2265 

2266 

2267def internal_operation_seed(): 

2268 """Returns the operation seed generated based on global seed.""" 

2269 return context()._internal_operation_seed() # pylint: disable=protected-access 

2270 
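
These seed helpers are reached through `tf.random.set_seed`, which (in eager mode) stores the global seed on the context via `set_global_seed` so that op-level seeds are derived deterministically. A brief sketch:

```python
import tensorflow as tf

tf.random.set_seed(42)
# With only the global seed set, op seeds are derived from it, so re-running
# the program with the same seed and op order reproduces these values.
print(tf.random.uniform([2]))
```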

2271 

2272@tf_export("executing_eagerly", v1=[]) 

2273def executing_eagerly(): 

2274 """Checks whether the current thread has eager execution enabled. 

2275 

2276 Eager execution is enabled by default and this API returns `True` 

2277 Eager execution is enabled by default and this API returns `True` in most 

2278 cases. 

2279 

2280 * Executing inside `tf.function`, unless under `tf.init_scope` or unless 

2281 `tf.config.run_functions_eagerly(True)` was previously called. 

2282 * Executing inside a transformation function for `tf.data.Dataset`. 

2283 * `tf.compat.v1.disable_eager_execution()` is called. 

2284 

2285 General case: 

2286 

2287 >>> print(tf.executing_eagerly()) 

2288 True 

2289 

2290 Inside `tf.function`: 

2291 

2292 >>> @tf.function 

2293 ... def fn(): 

2294 ... with tf.init_scope(): 

2295 ... print(tf.executing_eagerly()) 

2296 ... print(tf.executing_eagerly()) 

2297 >>> fn() 

2298 True 

2299 False 

2300 

2301 Inside `tf.function` after `tf.config.run_functions_eagerly(True)` is called: 

2302 

2303 >>> tf.config.run_functions_eagerly(True) 

2304 >>> @tf.function 

2305 ... def fn(): 

2306 ... with tf.init_scope(): 

2307 ... print(tf.executing_eagerly()) 

2308 ... print(tf.executing_eagerly()) 

2309 >>> fn() 

2310 True 

2311 True 

2312 >>> tf.config.run_functions_eagerly(False) 

2313 

2314 Inside a transformation function for `tf.data.Dataset`: 

2315 

2316 >>> def data_fn(x): 

2317 ... print(tf.executing_eagerly()) 

2318 ... return x 

2319 >>> dataset = tf.data.Dataset.range(100) 

2320 >>> dataset = dataset.map(data_fn) 

2321 False 

2322 

2323 Returns: 

2324 `True` if the current thread has eager execution enabled. 

2325 """ 

2326 ctx = context_safe() 

2327 if ctx is None: 

2328 return default_execution_mode == EAGER_MODE 

2329 

2330 return ctx.executing_eagerly() 

2331 

2332 

2333@tf_export(v1=["executing_eagerly"]) 

2334def executing_eagerly_v1(): 

2335 """Checks whether the current thread has eager execution enabled. 

2336 

2337 Eager execution is typically enabled via 

2338 `tf.compat.v1.enable_eager_execution`, but may also be enabled within the 

2339 context of a Python function via tf.contrib.eager.py_func. 

2340 

2341 When eager execution is enabled, returns `True` in most cases. However, 

2342 this API might return `False` in the following use cases. 

2343 

2344 * Executing inside `tf.function`, unless under `tf.init_scope` or unless 

2345 `tf.config.run_functions_eagerly(True)` was previously called. 

2346 * Executing inside a transformation function for `tf.data.Dataset`. 

2347 * `tf.compat.v1.disable_eager_execution()` is called. 

2348 

2349 >>> tf.compat.v1.enable_eager_execution() 

2350 

2351 General case: 

2352 

2353 >>> print(tf.executing_eagerly()) 

2354 True 

2355 

2356 Inside `tf.function`: 

2357 

2358 >>> @tf.function 

2359 ... def fn(): 

2360 ... with tf.init_scope(): 

2361 ... print(tf.executing_eagerly()) 

2362 ... print(tf.executing_eagerly()) 

2363 >>> fn() 

2364 True 

2365 False 

2366 

2367 Inside `tf.function` 

2368 after `tf.config.run_functions_eagerly(True)` is called: 

2369 

2370 >>> tf.config.run_functions_eagerly(True) 

2371 >>> @tf.function 

2372 ... def fn(): 

2373 ... with tf.init_scope(): 

2374 ... print(tf.executing_eagerly()) 

2375 ... print(tf.executing_eagerly()) 

2376 >>> fn() 

2377 True 

2378 True 

2379 >>> tf.config.run_functions_eagerly(False) 

2380 

2381 Inside a transformation function for `tf.data.Dataset`: 

2382 

2383 >>> def data_fn(x): 

2384 ... print(tf.executing_eagerly()) 

2385 ... return x 

2386 >>> dataset = tf.data.Dataset.range(100) 

2387 >>> dataset = dataset.map(data_fn) 

2388 False 

2389 

2390 Returns: 

2391 `True` if the current thread has eager execution enabled. 

2392 """ 

2393 return executing_eagerly() 

2394 

2395 

2396def in_eager_mode(): 

2397 """Use executing_eagerly() instead. This function will be removed.""" 

2398 return executing_eagerly() 

2399 

2400 

2401def anonymous_name(): 

2402 """Returns the anonymous shared name. 

2403 

2404 In eager mode we create anonymous resources to avoid spurious sharing issues. 

2405 The runtime generates a unique name on our behalf when the reserved 

2406 anonymous shared name is used as a shared name. 

2407 

2408 Returns: 

2409 The anonymous shared name. 

2410 """ 

2411 

2412 # The magic value is defined as 

2413 # `tensorflow::ResourceHandle::ANONYMOUS_NAME` in C++. 

2414 return "cd2c89b7-88b7-44c8-ad83-06c2a9158347" 

2415 

2416 

2417def graph_mode(): 

2418 """Context-manager to disable eager execution for the current thread.""" 

2419 return context()._mode(GRAPH_MODE) # pylint: disable=protected-access 

2420 

2421 

2422# Used by b/167638505 for keras backend API and Lambda layer. 

2423@tf_export("__internal__.eager_context.eager_mode", v1=[]) 

2424def eager_mode(): 

2425 """Context-manager to enable eager execution for the current thread.""" 

2426 return context()._mode(EAGER_MODE) # pylint: disable=protected-access 

2427 

2428 

2429def scope_name(): 

2430 """Name of the current scope.""" 

2431 return context().scope_name 

2432 

2433 

2434def device(name): 

2435 """Context-manager to force placement of operations and Tensors on a device. 

2436 

2437 Example: 

2438 ```python 

2439 with tf.device('gpu:0'): 

2440 with tf.device('cpu:0'): 

2441 shape = tf.constant([], dtype=tf.int32) 

2442 x = tf.random.truncated_normal(shape, dtype=tf.float32) 

2443 ``` 

2444 will ensure that the `shape` Tensor is on CPU but the `truncated_normal` 

2445 operation runs on GPU 0. 

2446 

2447 Args: 

2448 name: Name of the device (see context().devices()), or None to perform 

2449 automatic placement. 

2450 

2451 Returns: 

2452 Context manager for setting the device. 

2453 """ 

2454 ensure_initialized() 

2455 return context().device(name) 

2456 

2457 

2458# Expose some properties of Context as internally public APIs (b/160348781). 

2459@tf_export("__internal__.eager_context.get_config", v1=[]) 

2460def get_config(): 

2461 """Get the ConfigProto of Context. 

2462 

2463 Returns: 

2464 The ConfigProto of Context. 

2465 """ 

2466 return context().config 

2467 

2468 

2469@tf_export("__internal__.eager_context.get_device_name", v1=[]) 

2470def get_device_name(): 

2471 """Get the device name for the current thread. 

2472 

2473 Returns: 

2474 The device name for the current thread. 

2475 """ 

2476 return context().device_name 

2477 

2478 

2479@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[]) 

2480def set_soft_device_placement(enabled): 

2481 """Set if soft device placements should be allowed. 

2482 

2483 Args: 

2484 enabled: Whether to enable soft device placement. 

2485 """ 

2486 context().soft_device_placement = enabled 

2487 

2488 

2489@tf_export("__internal__.eager_context.get_executor", v1=[]) 

2490def get_executor(): 

2491 """Get the Executor of the current thread. 

2492 

2493 Returns: 

2494 The Executor of the current thread. 

2495 """ 

2496 return context().executor 

2497 

2498 

2499@tf_export("debugging.get_log_device_placement") 

2500def get_log_device_placement(): 

2501 """Get if device placements are logged. 

2502 

2503 Returns: 

2504 If device placements are logged. 

2505 """ 

2506 return context().log_device_placement 

2507 

2508 

2509@tf_export("debugging.set_log_device_placement") 

2510def set_log_device_placement(enabled): 

2511 """Turns logging for device placement decisions on or off. 

2512 

2513 Operations execute on a particular device, producing and consuming tensors on 

2514 that device. This may change the performance of the operation or require 

2515 TensorFlow to copy data to or from an accelerator, so knowing where operations 

2516 execute is useful for debugging performance issues. 

2517 

2518 For more advanced profiling, use the [TensorFlow 

2519 profiler](https://www.tensorflow.org/guide/profiler). 

2520 

2521 Device placement for operations is typically controlled by a `tf.device` 

2522 scope, but there are exceptions, for example operations on a `tf.Variable` 

2523 which follow the initial placement of the variable. Turning off soft device 

2524 placement (with `tf.config.set_soft_device_placement`) provides more explicit 

2525 control. 

2526 

2527 >>> tf.debugging.set_log_device_placement(True) 

2528 >>> tf.ones([]) 

2529 >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:GPU:0 

2530 >>> with tf.device("CPU"): 

2531 ... tf.ones([]) 

2532 >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:CPU:0 

2533 >>> tf.debugging.set_log_device_placement(False) 

2534 

2535 Turning on `tf.debugging.set_log_device_placement` also logs the placement of 

2536 ops inside `tf.function` when the function is called. 

2537 

2538 Args: 

2539 enabled: Whether to enable device placement logging. 

2540 """ 

2541 context().log_device_placement = enabled 

2542 

2543 

2544@tf_contextlib.contextmanager 

2545def device_policy(policy): 

2546 """Context manager for setting device placement policy for current thread.""" 

2547 ctx = context() 

2548 old_policy = ctx.device_policy 

2549 try: 

2550 ctx.device_policy = policy 

2551 yield 

2552 finally: 

2553 ctx.device_policy = old_policy 

2554 

2555 

2556def set_execution_mode(mode): 

2557 """Sets execution mode for the current thread.""" 

2558 context().execution_mode = mode 

2559 

2560 

2561# TODO(fishx): remove this method. 

2562@tf_contextlib.contextmanager 

2563def execution_mode(mode): 

2564 """Context manager for setting execution mode for current thread.""" 

2565 if mode is None: 

2566 yield 

2567 else: 

2568 ctx = context() 

2569 executor_new = executor.new_executor(mode == ASYNC) 

2570 executor_old = ctx.executor 

2571 try: 

2572 executor_old.wait() 

2573 ctx.executor = executor_new 

2574 yield 

2575 finally: 

2576 ctx.executor = executor_old 

2577 executor_new.wait() 

2578 

2579 

2580@tf_contextlib.contextmanager 

2581def executor_scope(e): 

2582 """Context manager for changing executor for current thread. 

2583 

2584 Args: 

2585 e: An Executor to execute eager ops under this scope. Setting it to None will 

2586 switch back to using the default executor for the context. 

2587 

2588 Yields: 

2589 Context manager for setting the executor for current thread. 

2590 """ 

2591 ctx = context() 

2592 executor_old = ctx.executor 

2593 try: 

2594 ctx.executor = e 

2595 yield 

2596 finally: 

2597 ctx.executor = executor_old 

2598 
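
A hedged sketch of temporarily switching the current thread to an async executor through `executor_scope`, using the internal `executor.new_executor` helper imported at the top of this module:

```python
import tensorflow as tf
from tensorflow.python.eager import context as eager_context
from tensorflow.python.eager import executor as executor_lib

async_executor = executor_lib.new_executor(True)  # enable_async=True
with eager_context.executor_scope(async_executor):
  x = tf.constant(1.0) + tf.constant(2.0)  # may return before it finishes
async_executor.wait()  # block until the queued work has completed
print(x)
```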

2599 

2600@tf_export("experimental.function_executor_type") 

2601@tf_contextlib.contextmanager 

2602def function_executor_type(executor_type): 

2603 """Context manager for setting the executor of eager defined functions. 

2604 

2605 Eager defined functions are functions decorated by tf.contrib.eager.defun. 

2606 

2607 Args: 

2608 executor_type: a string for the name of the executor to be used to execute 

2609 functions defined by tf.contrib.eager.defun. 

2610 

2611 Yields: 

2612 Context manager for setting the executor of eager defined functions. 

2613 """ 

2614 current_options = context().function_call_options 

2615 old_options = copy.copy(current_options) 

2616 try: 

2617 current_options.executor_type = executor_type 

2618 yield 

2619 finally: 

2620 context().function_call_options = old_options 

2621 
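
A usage sketch of the context manager above; `"SINGLE_THREADED_EXECUTOR"` is the executor type registered by TensorFlow's single-threaded executor and is used here purely for illustration — which executor types are available depends on the build:

```python
import tensorflow as tf


@tf.function
def double(x):
  return 2.0 * x


# Run tf.function-backed functions under a specific executor type.
with tf.experimental.function_executor_type("SINGLE_THREADED_EXECUTOR"):
  print(double(tf.constant(3.0)))
```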

2622 

2623def is_async(): 

2624 """Returns true if current thread is in async mode.""" 

2625 return context().is_async() 

2626 

2627 

2628def num_gpus(): 

2629 """Get the number of available GPU devices. 

2630 

2631 Returns: 

2632 The number of available GPU devices. 

2633 """ 

2634 return context().num_gpus() 

2635 

2636 

2637def enable_run_metadata(): 

2638 """Enables tracing of op execution via RunMetadata. 

2639 

2640 To retrieve the accumulated metadata call context.export_run_metadata() 

2641 and to stop tracing call context.disable_run_metadata(). 

2642 """ 

2643 context().enable_run_metadata() 

2644 

2645 

2646def disable_run_metadata(): 

2647 """Disables tracing of op execution via RunMetadata.""" 

2648 context().disable_run_metadata() 

2649 

2650 

2651def enable_graph_collection(): 

2652 """Enables graph collection of executed functions. 

2653 

2654 To retrieve the accumulated graphs call context.export_run_metadata() 

2655 and to stop collecting graphs call context.disable_graph_collection(). 

2656 """ 

2657 context().enable_graph_collection() 

2658 

2659 

2660def disable_graph_collection(): 

2661 """Disables graph collection of executed functions.""" 

2662 context().disable_graph_collection() 

2663 

2664 

2665def export_run_metadata(): 

2666 """Returns a RunMetadata proto with accumulated information. 

2667 

2668 The returned protocol buffer contains information since the most recent call 

2669 to either enable_run_metadata or export_run_metadata. 

2670 

2671 Returns: 

2672 A RunMetadata protocol buffer. 

2673 """ 

2674 return context().export_run_metadata() 

2675 

2676 

2677@contextlib.contextmanager 

2678def collect_graphs(optimized=True): 

2679 """Collects a flat list of pre- or post-optimization graphs. 

2680 

2681 The collected graphs include device placements, which can be useful for 

2682 testing. 

2683 

2684 Usage: 

2685 

2686 ``` 

2687 @def_function.function 

2688 def f(x): 

2689 return x + constant_op.constant(1.) 

2690 

2691 with context.collect_graphs() as graphs: 

2692 with ops.device("CPU:0"): 

2693 f(constant_op.constant(1.)) 

2694 

2695 graph, = graphs # `graph` contains a single GraphDef for inspection 

2696 ``` 

2697 

2698 Args: 

2699 optimized: whether to collect optimized graphs or non-optimized graphs 

2700 

2701 Yields: 

2702 A list of GraphDefs, populated when the context manager exits. 

2703 """ 

2704 ctx = context() 

2705 ctx.enable_graph_collection() 

2706 try: 

2707 graphs = [] 

2708 yield graphs 

2709 metadata = ctx.export_run_metadata() 

2710 finally: 

2711 ctx.disable_graph_collection() 

2712 for graph in metadata.function_graphs: 

2713 if optimized: 

2714 graphs.append(graph.post_optimization_graph) 

2715 else: 

2716 graphs.append(graph.pre_optimization_graph) 

2717 

2718 

2719def get_server_def(): 

2720 return context().get_server_def() 

2721 

2722 

2723def set_server_def(server_def): 

2724 context().set_server_def(server_def) 

2725 

2726 

2727def update_server_def(server_def): 

2728 context().update_server_def(server_def) 

2729 

2730 

2731def check_alive(worker_name): 

2732 return context().check_alive(worker_name) 

2733 

2734 

2735@tf_export("experimental.async_scope") 

2736@tf_contextlib.contextmanager 

2737def async_scope(): 

2738 """Context manager for grouping async operations. 

2739 

2740 Ops/function calls inside the scope can return before finishing the actual 

2741 execution. When exiting the async scope, a synchronization barrier will be 

2742 automatically added to ensure the completion of all async op and function 

2743 execution, potentially raising exceptions if async execution results in 

2744 an error state. 

2745 

2746 Users may write the following code to asynchronously invoke `train_step_fn` 

2747 and log the `loss` metric for every `num_steps` steps in a training loop. 

2748 `train_step_fn` internally consumes data using `iterator.get_next()`, and may 

2749 throw OutOfRangeError when running out of data. In that case: 

2750 

2751 ``` 

2752 try: 

2753 with tf.experimental.async_scope(): 

2754 for _ in range(num_steps): 

2755 # Step function updates the metric `loss` internally 

2756 train_step_fn() 

2757 except tf.errors.OutOfRangeError: 

2758 tf.experimental.async_clear_error() 

2759 logging.info('loss = %s', loss.numpy()) 

2760 ``` 

2761 

2762 Yields: 

2763 Context manager for grouping async operations. 

2764 """ 

2765 # TODO(haoyuzhang): replace env var once we have a config method to turn on 

2766 # and off async streaming RPC 

2767 remote_async_env_var = "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE" 

2768 old_policy = os.environ.get(remote_async_env_var) 

2769 try: 

2770 os.environ[remote_async_env_var] = str(True) 

2771 yield 

2772 # Note: sync local and remote executors iff the async block does not raise 

2773 # an exception. Triggering sync after an exception may lead to derived 

2774 # runtime errors and unexpected exception types. 

2775 context().sync_executors() 

2776 finally: 

2777 if old_policy is None: 

2778 del os.environ[remote_async_env_var] 

2779 else: 

2780 os.environ[remote_async_env_var] = old_policy 

2781 

2782 

2783def async_wait(): 

2784 """Sync all async operations and raise any errors during execution. 

2785 

2786 In async execution mode, an op/function call can return before finishing the 

2787 actual execution. Calling this method creates a synchronization barrier for 

2788 all async op and function execution. It only returns when all pending nodes 

2789 are finished, potentially raising exceptions if async execution results in 

2790 an error state. It is a no-op if the context is not initialized. 

2791 """ 

2792 disable_async_executor_env_var = "TF_PS_DISABLE_ASYNC_EXECUTOR_GLOBALLY" 

2793 if os.environ.get(disable_async_executor_env_var) == str(True): 

2794 return 

2795 if context()._context_handle is not None: # pylint: disable=protected-access 

2796 context().sync_executors() 

2797 

2798 

2799@tf_export("experimental.async_clear_error") 

2800def async_clear_error(): 

2801 """Clear pending operations and error statuses in async execution. 

2802 

2803 In async execution mode, an error in op/function execution can lead to errors 

2804 in subsequent ops/functions that are scheduled but not yet executed. Calling 

2805 this method clears all pending operations and resets the async execution state. 

2806 

2807 Example: 

2808 

2809 ``` 

2810 while True: 

2811 try: 

2812 # Step function updates the metric `loss` internally 

2813 train_step_fn() 

2814 except tf.errors.OutOfRangeError: 

2815 tf.experimental.async_clear_error() 

2816 break 

2817 logging.info('loss = %s', loss.numpy()) 

2818 ``` 

2819 """ 

2820 context().clear_executor_errors() 

2821 

2822 

2823def add_c_function(c_func): 

2824 """Add a C API TF_Function to the context.""" 

2825 context().add_c_function(c_func) 

2826 

2827 

2828def get_c_function(name): 

2829 """Get a C API TF_Function from the context.""" 

2830 return context().get_c_function(name) 

2831 

2832 

2833def remove_function(name): 

2834 """Remove a function from the context.""" 

2835 context().remove_function(name) 

2836 

2837 

2838def get_function_def(name): 

2839 return context().get_function_def(name) 

2840 

2841 

2842def is_custom_device(device_name): 

2843 """Calls TFE_IsCustomDevice. 

2844 

2845 Enables using C extensions specifying a custom device from Python. See the 

2846 experimental eager C API in tensorflow/c/eager/c_api_experimental.h for 

2847 details. 

2848 

2849 Args: 

2850 device_name: A string device name to check against the set of registered 

2851 custom devices. 

2852 

2853 Returns: 

2854 A boolean. 

2855 """ 

2856 return context().is_custom_device(device_name) 

2857 

2858 

2859def register_custom_device(device_capsule, device_name, device_info_capsule): 

2860 """Calls TFE_RegisterCustomDevice to register a custom device with Python. 

2861 

2862 Enables using C extensions specifying a custom device from Python. See the 

2863 experimental eager C API in tensorflow/c/eager/c_api_experimental.h for 

2864 details. 

2865 

2866 Note that custom devices are not currently supported inside `tf.function`s. 

2867 

2868 Args: 

2869 device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice' 

2870 containing a pointer to a TFE_CustomDevice struct. The capsule retains 

2871 ownership of the memory. 

2872 device_name: A string indicating the name to register the custom device 

2873 under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may 

2874 subsequently be passed to `with tf.device(...):`. 

2875 device_info_capsule: A PyCapsule with the name set to 

2876 'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific 

2877 struct with the initial state of the custom device (the void* device_info 

2878 argument to TFE_RegisterCustomDevice). This method takes ownership of the 

2879 memory and clears the capsule destructor. 

2880 """ 

2881 context().register_custom_device(device_capsule, device_name, 

2882 device_info_capsule) 

2883 

2884 

2885# Not every user creates a Context via context.context() 

2886# (for example, enable_eager_execution in python/framework/ops.py), 

2887# but they do all import this file. Note that IS_IN_GRAPH_MODE and 

2888# in_graph_mode are both parameterless functions. 

2889def _tmp_in_graph_mode(): 

2890 if context_safe() is None: 

2891 # Context not yet initialized. Assume graph mode following the 

2892 # default implementation in `is_in_graph_mode`. 

2893 return True 

2894 return not executing_eagerly() 

2895 

2896 

2897is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode