Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/context.py: 2%
1260 statements
coverage.py v7.3.2, created at 2023-10-05 06:32 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""State management for eager execution."""
17import collections
18import contextlib
19import copy
20import gc
21import itertools
22import os
23import random
24import threading
26from absl import logging
27import numpy as np
29from tensorflow.core.framework import function_pb2
30from tensorflow.core.framework import graph_debug_info_pb2
31from tensorflow.core.protobuf import config_pb2
32from tensorflow.core.protobuf import rewriter_config_pb2
33from tensorflow.python import pywrap_tfe
34from tensorflow.python import tf2
35from tensorflow.python.client import pywrap_tf_session
36from tensorflow.python.eager import cancellation
37from tensorflow.python.eager import execute
38from tensorflow.python.eager import executor
39from tensorflow.python.eager import monitoring
40from tensorflow.python.framework import c_api_util
41from tensorflow.python.framework import device as pydev
42from tensorflow.python.framework import tfrt_utils
43from tensorflow.python.util import compat
44from tensorflow.python.util import function_utils
45from tensorflow.python.util import is_in_graph_mode
46from tensorflow.python.util import tf_contextlib
47from tensorflow.python.util.deprecation import deprecated
48from tensorflow.python.util.tf_export import tf_export
49from tensorflow.tsl.protobuf import coordination_config_pb2
52GRAPH_MODE = 0
53EAGER_MODE = 1
55default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE
57# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
58# new_device_spec).
59# Note that we do not protect this with a lock and instead rely on python's GIL
60# and the idempotent nature of writes to provide thread safety.
61_device_parsing_cache = {}
62_starting_device_spec = pydev.DeviceSpec.from_string("")
64_MAXINT32 = 2**31 - 1
66DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT
67DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN
68DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT
69DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
70 pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
72SYNC = 0
73ASYNC = 1
75_KEEP_ALIVE_SECS = 600
77_python_eager_context_create_counter = monitoring.Counter(
78 "/tensorflow/api/python/eager_context_create_counter",
79 "Counter for number of eager contexts created in Python.")
81# Re-exporting through context.
82is_tfrt_enabled = tfrt_utils.enabled
84# This flag and the associated environment var are transient and will eventually
85# be removed, once this experiment is enabled by default.
86_JIT_COMPILE_REWRITE_ENABLED = os.getenv("TF_JIT_COMPILE_REWRITE") == "1"
89def run_eager_op_as_function_enabled():
90 return True
93 # This method should only be called after the context has been initialized.
94def enable_jit_compile_rewrite():
95 """Run jit_compile functions through rewrite pass.
97 This runs jit_compile functions through all of the multidevice function
98 rewrite passes.
99 """
100 global _JIT_COMPILE_REWRITE_ENABLED
101 _JIT_COMPILE_REWRITE_ENABLED = True
102 if context_safe() is not None:
103 context_safe().jit_compile_rewrite = True
106# This method should only be called after the context has been initialized.
107def disable_jit_compile_rewrite():
108 global _JIT_COMPILE_REWRITE_ENABLED
109 _JIT_COMPILE_REWRITE_ENABLED = False
110 if context_safe() is not None:
111 context_safe().jit_compile_rewrite = False
114def jit_compile_rewrite_enabled():
115 if context_safe() is not None:
116 return context_safe().jit_compile_rewrite
117 return _JIT_COMPILE_REWRITE_ENABLED
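# A minimal sketch (not part of the original module) of how the rewrite-flag
# helpers above are meant to be used; per the comments above, it assumes the
# eager context has already been initialized.
def _sketch_toggle_jit_compile_rewrite():
  enable_jit_compile_rewrite()
  assert jit_compile_rewrite_enabled()
  disable_jit_compile_rewrite()
  assert not jit_compile_rewrite_enabled()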
120# Expose it as internally public APIs for Keras use cases in b/171080602.
121tf_export("__internal__.is_tfrt_enabled", v1=[])(is_tfrt_enabled)
124class _EagerTensorCache(object):
125 """Simple cache which evicts items based on length in a FIFO manner."""
127 __slots__ = ["_data", "_max_items", "_max_tensor_size"]
129 def __init__(self, max_items=256, max_tensor_size=10000):
130 self._data = collections.OrderedDict()
131 self._max_items = max_items
132 self._max_tensor_size = max_tensor_size
134 def put(self, key, value):
135 if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access
136 return
138 self._data[key] = value
140 if len(self._data) > self._max_items:
141 self._data.popitem(last=False)
143 def get(self, key):
144 return self._data.get(key, None)
146 def flush(self):
147 self._data.clear()
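# A minimal usage sketch (not part of the original module) of the FIFO
# semantics above. `_FakeTensor` is a hypothetical stand-in for an object
# exposing the protected `_num_elements()` hook that `put` relies on.
def _sketch_eager_tensor_cache():
  class _FakeTensor(object):

    def __init__(self, num_elements):
      self._num = num_elements

    def _num_elements(self):
      return self._num

  cache = _EagerTensorCache(max_items=2, max_tensor_size=10)
  cache.put("a", _FakeTensor(3))
  cache.put("b", _FakeTensor(3))
  cache.put("c", _FakeTensor(3))       # evicts "a": FIFO with max_items=2
  cache.put("big", _FakeTensor(100))   # skipped: exceeds max_tensor_size
  assert cache.get("a") is None
  assert cache.get("c") is not None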
150class FunctionCallOptions:
151 """Options applied at call sites of eager functions.
153 Eager functions are functions decorated with `tf.function` or the legacy `tf.contrib.eager.defun`.
154 """
156 __slots__ = ["_config_proto_serialized", "_executor_type"]
158 def __init__(self, executor_type=None, config_proto=None):
159 """Constructor.
161 Args:
162 executor_type: (optional) name of the executor to be used to execute the
163 eager function. If None or an empty string, the default Tensorflow
164 executor will be used.
165 config_proto: (optional) a `config_pb2.ConfigProto` proto or a serialized
166 string of that proto. The config used by Grappler when optimizing the
167 function graph. Each concrete function is optimized the first time it is
168 called. Changing config_proto after the first call has no effect. If
169 config_proto is None, an empty RewriterConfig will be used.
170 """
171 self.config_proto_serialized = config_proto
172 self.executor_type = executor_type
174 @property
175 def executor_type(self):
176 return self._executor_type
178 @executor_type.setter
179 def executor_type(self, executor_type):
180 self._executor_type = executor_type
182 @property
183 def config_proto_serialized(self):
184 return self._config_proto_serialized
186 @config_proto_serialized.setter
187 def config_proto_serialized(self, config):
188 if isinstance(config, config_pb2.ConfigProto):
189 self._config_proto_serialized = config.SerializeToString(
190 deterministic=True)
191 elif isinstance(config, str):
192 self._config_proto_serialized = config
193 elif config is None:
194 self._config_proto_serialized = (
195 config_pb2.ConfigProto().SerializeToString())
196 else:
197 raise ValueError("The rewriter config must be either a "
198 "config_pb2.ConfigProto, a serialized string of that "
199 "proto, or None. Got: {}".format(type(config)))
201 def as_attrs(self):
202 if self.config_proto_serialized is None:
203 config = function_utils.get_disabled_rewriter_config()
204 else:
205 config = self.config_proto_serialized
206 executor_type = self.executor_type or ""
208 return {"executor_type": executor_type, "config_proto": config}
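# A minimal sketch (not part of the original module) of building call options
# and inspecting the attributes they translate to.
def _sketch_function_call_options():
  proto = config_pb2.ConfigProto()
  proto.allow_soft_placement = True
  options = FunctionCallOptions(executor_type="", config_proto=proto)
  attrs = options.as_attrs()
  assert attrs["executor_type"] == ""
  # The config is carried as a deterministically serialized ConfigProto.
  assert isinstance(attrs["config_proto"], bytes)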
211# Map from context_id (an int) to _TensorCaches.
212# Dicts are thread safe in CPython.
213# TODO(iga): Remove this once TensorCaches are moved to C++.
214_tensor_caches_map = {}
217class _TensorCaches(threading.local):
218 """Thread local tensor caches."""
220 __slots__ = ["_ones_rank_cache", "_zeros_cache"]
222 def __init__(self):
223 super().__init__()
224 self._ones_rank_cache = None
225 self._zeros_cache = None
227 @property
228 def ones_rank_cache(self):
229 if not self._ones_rank_cache:
230 self._ones_rank_cache = _EagerTensorCache()
231 return self._ones_rank_cache
233 @property
234 def zeros_cache(self):
235 if not self._zeros_cache:
236 self._zeros_cache = _EagerTensorCache()
237 return self._zeros_cache
240ContextSwitch = collections.namedtuple(
241 "ContextSwitch",
242 ["is_building_function", "enter_context_fn", "device_stack"])
245# `_ContextSwitchStack` is a `threading.local` to match the semantics of
246 # `_DefaultGraphStack`, which is also a `threading.local`.
247class _ContextSwitchStack(threading.local):
248 """A thread-local stack of context switches."""
250 def __init__(self, eager):
251 super().__init__()
252 self.stack = []
253 if eager:
254 # Initialize the stack with a pointer to enter the eager context; this
255 # ensures that the fact that eager execution was enabled is propagated
256 # across threads, since (1) `enable_eager_execution` modifies a
257 # process-level flag (`default_execution_mode`) and (2) `__init__` is
258 # called each time a threading.local object is used in a separate thread.
259 self.push(
260 is_building_function=False,
261 enter_context_fn=eager_mode,
262 device_stack=None)
264 def push(self, is_building_function, enter_context_fn, device_stack):
265 """Push metadata about a context switch onto the stack.
267 A context switch can take any one of the two forms: installing a graph as
268 the default graph, or entering the eager context. For each context switch,
269 we record whether or not the entered context is building a function.
271 Args:
272 is_building_function: (bool.) Whether the context is building a function.
273 enter_context_fn: (function.) A callable that executes the context switch.
274 For example, `graph.as_default` or `eager_mode`.
275 device_stack: If applicable, the device function stack for this graph.
276 When breaking out of graphs in init_scope, the innermost nonempty device
277 stack is used. Eager contexts put `None` here and the value is never
278 used.
279 """
281 self.stack.append(
282 ContextSwitch(is_building_function, enter_context_fn, device_stack))
284 def pop(self):
285 """Pop the stack."""
287 self.stack.pop()
290@tf_export("config.LogicalDevice")
291class LogicalDevice(
292 collections.namedtuple("LogicalDevice", ["name", "device_type"])):
293 """Abstraction for a logical device initialized by the runtime.
295 A `tf.config.LogicalDevice` corresponds to an initialized logical device on a
296 `tf.config.PhysicalDevice` or a remote device visible to the cluster. Tensors
297 and operations can be placed on a specific logical device by calling
298 `tf.device` with a specified `tf.config.LogicalDevice`.
300 Fields:
301 name: The fully qualified name of the device. Can be used for Op or function
302 placement.
303 device_type: String declaring the type of device such as "CPU" or "GPU".
304 """
305 pass
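# A minimal sketch (not part of the original module): listing logical devices
# through the public API and pinning an op to one of them. Assumes an
# installed TensorFlow with eager execution and at least one CPU device.
def _sketch_logical_device_placement():
  import tensorflow as tf

  cpus = tf.config.list_logical_devices("CPU")
  with tf.device(cpus[0].name):
    x = tf.constant([1.0, 2.0])
  return x.device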
308@tf_export("config.LogicalDeviceConfiguration",
309 "config.experimental.VirtualDeviceConfiguration")
310class LogicalDeviceConfiguration(
311 collections.namedtuple("LogicalDeviceConfiguration", [
312 "memory_limit", "experimental_priority", "experimental_device_ordinal"
313 ])):
314 """Configuration class for a logical devices.
316 The class specifies the parameters to configure a `tf.config.PhysicalDevice`
317 as it is initialized to a `tf.config.LogicalDevice` during runtime
318 initialization. Not all fields are valid for all device types.
320 See `tf.config.get_logical_device_configuration` and
321 `tf.config.set_logical_device_configuration` for usage examples.
323 Fields:
324 memory_limit: (optional) Maximum memory (in MB) to allocate on the virtual
325 device. Currently only supported for GPUs.
326 experimental_priority: (optional) Priority to assign to a virtual device.
327 Lower values have higher priorities and 0 is the default.
328 Within a physical GPU, the GPU scheduler will prioritize ops on virtual
329 devices with higher priority. Currently only supported for Nvidia GPUs.
330 experimental_device_ordinal: (optional) Ordinal number to order the virtual
331 device.
332 A LogicalDevice with a lower ordinal number will receive a lower device id.
333 The physical device id and location in the list are used to break ties.
334 Currently only supported for Nvidia GPUs.
335 """
337 def __new__(cls,
338 memory_limit=None,
339 experimental_priority=None,
340 experimental_device_ordinal=None):
341 return super().__new__(cls, memory_limit, experimental_priority,
342 experimental_device_ordinal)
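# A minimal sketch (not part of the original module): splitting one physical
# GPU into two logical devices with memory limits. Assumes an installed
# TensorFlow, a visible GPU, and that the runtime has not been initialized
# yet, as the docstring above requires.
def _sketch_virtual_gpu_configuration():
  import tensorflow as tf

  gpus = tf.config.list_physical_devices("GPU")
  if gpus:
    tf.config.set_logical_device_configuration(
        gpus[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
         tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
  return tf.config.list_logical_devices("GPU")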
345@tf_export("config.PhysicalDevice")
346class PhysicalDevice(
347 collections.namedtuple("PhysicalDevice", ["name", "device_type"])):
348 """Abstraction for a locally visible physical device.
350 TensorFlow can utilize various devices such as the CPU or multiple GPUs
351 for computation. Before initializing a local device for use, the user can
352 customize certain properties of the device such as its visibility or memory
353 configuration.
355 Once a visible `tf.config.PhysicalDevice` is initialized one or more
356 `tf.config.LogicalDevice` objects are created. Use
357 `tf.config.set_visible_devices` to configure the visibility of a physical
358 device and `tf.config.set_logical_device_configuration` to configure multiple
359 `tf.config.LogicalDevice` objects for a `tf.config.PhysicalDevice`. This is
360 useful when separation between models is needed or to simulate a multi-device
361 environment.
363 Fields:
364 name: Unique identifier for device.
365 device_type: String declaring the type of device such as "CPU" or "GPU".
366 """
367 pass
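# A minimal sketch (not part of the original module): querying physical
# devices and restricting visibility before the runtime initializes. Assumes
# an installed TensorFlow.
def _sketch_restrict_visible_devices():
  import tensorflow as tf

  cpus = tf.config.list_physical_devices("CPU")
  # Only the first CPU will be initialized as a logical device.
  tf.config.set_visible_devices(cpus[:1], "CPU")
  return tf.config.get_visible_devices("CPU")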
370class _AtomicCounter(object):
371 """A simple atomic counter."""
373 __slots__ = ["_value", "_lock"]
375 def __init__(self):
376 self._value = 0
377 self._lock = threading.Lock()
379 def increment_and_get(self):
380 with self._lock:
381 self._value += 1
382 return self._value
385_context_id_counter = _AtomicCounter()
388class _TensorCacheDeleter(object):
389 """Deletes tensor caches for a given context."""
391 __slots__ = ["_context_id"]
393 def __init__(self, context_id):
394 self._context_id = context_id
396 def __del__(self):
397 if _tensor_caches_map is None:
398 return
399 if self._context_id in _tensor_caches_map:
400 del _tensor_caches_map[self._context_id]
403# TODO(agarwal): rename to EagerContext / EagerRuntime ?
404# TODO(agarwal): consider keeping the corresponding Graph here.
405class Context:
406 """Environment in which eager operations execute."""
408 # TODO(agarwal): create and link in some documentation for `execution_mode`.
409 # pylint: disable=redefined-outer-name
410 def __init__(self,
411 config=None,
412 device_policy=None,
413 execution_mode=None,
414 server_def=None):
415 """Creates a new Context.
417 Args:
418 config: (Optional.) A `ConfigProto` protocol buffer with configuration
419 options for the Context. Note that many of these options may be
420 currently unimplemented or irrelevant when eager execution is enabled.
421 device_policy: (Optional.) What policy to use when trying to run an
422 operation on a device with inputs which are not on that device. When set
423 to None, an appropriate value will be picked automatically. The value
424 picked may change between TensorFlow releases. Defaults to
425 DEVICE_PLACEMENT_SILENT.
426 Valid values:
427 - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
428 correct.
429 - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the right
430 device but raises a warning.
431 - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide
432 performance problems.
433 - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
434 raising errors on the other ones.
435 execution_mode: (Optional.) Policy controlling how operations dispatched
436 are actually executed. When set to None, an appropriate value will be
437 picked automatically. The value picked may change between TensorFlow
438 releases.
439 Valid values:
440 - SYNC: executes each operation synchronously.
441 - ASYNC: executes each operation asynchronously. These operations may
442 return "non-ready" handles.
443 server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution
444 on remote devices. GrpcServers need to be started by creating an
445 identical server_def to this, and setting the appropriate task_indexes,
446 so that the servers can communicate. It will then be possible to execute
447 operations on remote devices.
449 Raises:
450 ValueError: If execution_mode is not valid.
451 """
452 # This _id is used only to index the tensor caches.
453 # TODO(iga): Remove this when tensor caches are moved to C++.
454 self._id = _context_id_counter.increment_and_get()
455 self._tensor_cache_deleter = _TensorCacheDeleter(self._id)
456 _tensor_caches_map[self._id] = _TensorCaches()
458 self._config = config
459 self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData(
460 self,
461 is_eager=lambda: default_execution_mode == EAGER_MODE,
462 device_spec=_starting_device_spec)
463 self._context_switches = _ContextSwitchStack(self.executing_eagerly())
464 self._context_handle = None
465 self._context_devices = None
466 self._seed = None
467 self._initialize_lock = threading.Lock()
468 self._initialized = False
469 if device_policy is None:
470 device_policy = DEVICE_PLACEMENT_SILENT
471 self._device_policy = device_policy
472 self._mirroring_policy = None
473 if execution_mode not in (None, SYNC, ASYNC):
474 raise ValueError("execution_mode should be None/SYNC/ASYNC. Got %s" %
475 execution_mode)
476 if execution_mode is None:
477 execution_mode = SYNC
478 self._default_is_async = execution_mode == ASYNC
479 self._use_tfrt = is_tfrt_enabled()
480 self._jit_compile_rewrite = jit_compile_rewrite_enabled()
481 self._server_def = server_def
482 self._collective_ops_server_def = None
483 self._collective_leader = None
484 self._collective_scoped_allocator_enabled_ops = None
485 self._collective_use_nccl_communication = None
486 self._collective_device_filters = None
487 self._coordination_service_config = None
489 self._device_lock = threading.Lock()
490 self._physical_devices = None
491 self._physical_device_to_index = None
492 self._pluggable_devices = None
493 self._visible_device_list = []
494 self._memory_growth_map = None
495 self._virtual_device_map = {}
497 # Values set after construction
498 self._optimizer_jit = None
499 self._intra_op_parallelism_threads = None
500 self._inter_op_parallelism_threads = None
501 self._soft_device_placement = None
502 self._log_device_placement = None
503 self._operation_timeout_in_ms = None
504 self._enable_mlir_graph_optimization = None
505 self._optimizer_experimental_options = {}
507 _python_eager_context_create_counter.get_cell().increase_by(1)
509 self._is_global_context = False
511 # pylint: enable=redefined-outer-name
513 def _set_global_seed(self, seed):
514 """Set a global eager mode seed for random ops."""
515 self._seed = seed
516 # `random.Random(seed)` needs `seed` to be hashable, while values of type
517 # e.g. `np.int64` or `np.ndarray` are not. We use `int(...)` to convert them
518 # to int.
519 try:
520 hash(seed)
521 self._rng = random.Random(seed)
522 except TypeError:
523 seed = int(np.array(seed))
524 self._rng = random.Random(seed)
525 # Also clear the kernel cache, to reset any existing seeds
526 if self._context_handle is not None:
527 pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
529 def _internal_operation_seed(self):
530 """Returns a fake operation seed.
532 In eager mode, the user shouldn't set or depend on the operation seed.
533 Here, we generate a random seed based on the global seed to make each
534 operation's randomness different while still depending on the global seed.
536 Returns:
537 A fake operation seed based on global seed.
538 """
539 return self._rng.randint(0, _MAXINT32)
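  # Illustrative sketch (not part of the original source): the global seed
  # consumed by `_set_global_seed` above is normally set through the public
  # API rather than on the context directly, e.g.:
  #
  #   import tensorflow as tf
  #   tf.random.set_seed(42)       # feeds the eager context's global seed
  #   x = tf.random.uniform([2])   # per-op seeds are then derived from it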
541 def _initialize_logical_devices(self):
542 """Helper to initialize devices."""
543 # Store list of devices
544 logical_devices = []
545 context_devices = []
546 device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle)
547 try:
548 self._num_gpus = 0
549 current_job, current_task = None, None
550 server_def = self._server_def or self._collective_ops_server_def
551 if server_def is not None:
552 current_job, current_task = server_def.job_name, server_def.task_index
553 for i in range(pywrap_tfe.TF_DeviceListCount(device_list)):
554 dev_name = pywrap_tfe.TF_DeviceListName(device_list, i)
555 context_devices.append(pydev.canonical_name(dev_name))
556 spec = pydev.DeviceSpec.from_string(dev_name)
557 # If the job is localhost, we assume that the cluster has not yet been
558 # configured and thus clear the job, replica & task.
559 if spec.job == "localhost":
560 spec = spec.replace(job=None, replica=None, task=None)
561 logical_devices.append(
562 LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
563 dev_type = pywrap_tfe.TF_DeviceListType(device_list, i)
564 if (dev_type == "GPU" and spec.job == current_job and
565 spec.task == current_task):
566 self._num_gpus += 1
568 finally:
569 self._logical_devices = logical_devices
570 self._context_devices = context_devices
571 pywrap_tfe.TF_DeleteDeviceList(device_list)
573 def ensure_initialized(self):
574 """Initialize handle and devices if not already done so."""
575 if self._initialized:
576 return
577 with self._initialize_lock:
578 if self._initialized:
579 return
580 assert self._context_devices is None
581 opts = pywrap_tfe.TFE_NewContextOptions()
582 try:
583 config_str = self.config.SerializeToString()
584 pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str)
585 if self._device_policy is not None:
586 pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy(
587 opts, self._device_policy)
588 if self._mirroring_policy is not None:
589 pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy(
590 opts, self._mirroring_policy)
591 if self._default_is_async == ASYNC:
592 pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True)
593 if self._use_tfrt is not None:
594 pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt)
595 pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction(opts, True)
596 pywrap_tfe.TFE_ContextOptionsSetJitCompileRewrite(
597 opts, self._jit_compile_rewrite)
598 context_handle = pywrap_tfe.TFE_NewContext(opts)
599 finally:
600 pywrap_tfe.TFE_DeleteContextOptions(opts)
601 assert not (self._server_def and self._collective_ops_server_def), (
602 "Cannot enable remote execution as well as collective ops at the "
603 "moment. If this is important to you, please file an issue.")
604 if self._server_def is not None:
605 server_def_str = self._server_def.SerializeToString()
606 pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS,
607 server_def_str)
608 elif self._collective_ops_server_def is not None:
609 server_def_str = self._collective_ops_server_def.SerializeToString()
610 pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str)
612 self._context_handle = context_handle
613 self._initialize_logical_devices()
614 self._initialized = True
616 if self._is_global_context:
617 pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle)
619 def ensure_uninitialized(self):
620 """Uninitialize handle and devices if not already done so."""
621 with self._initialize_lock:
622 if not self._initialized:
623 return
624 self._context_devices = None
625 self._logical_devices = None
626 self._server_def = None
627 self._initialized = False
629 if self._is_global_context:
630 pywrap_tfe.TFE_Py_SetCEagerContext(None)
632 self._context_handle = None
634 def mark_as_global_context(self):
635 # If the context was already initialized, publish it. Otherwise wait with
636 # publication until it's initialized.
637 if self._initialized:
638 pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle)
639 self._is_global_context = True
641 def _clear_caches(self):
642 self.ones_rank_cache().flush()
643 self.zeros_cache().flush()
644 pywrap_tfe.TFE_ClearScalarCache()
646 def get_server_def(self):
647 return self._server_def
649 def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
650 """Allow setting a server_def on the context.
652 When a server def is replaced, it effectively clears a bunch of caches
653 within the context. If you attempt to use a tensor object that was pointing
654 to a tensor on the remote device, it will raise an error.
656 Args:
657 server_def: A tensorflow::ServerDef proto. Enables execution on remote
658 devices.
659 keep_alive_secs: Num. seconds after which the remote end will hang up. As
660 long as the client is still alive, the server state for the context will
661 be kept alive. If the client is killed (or there is some failure), the
662 server will clean up its context keep_alive_secs after the final RPC it
663 receives.
665 Raises:
666 ValueError: if server_def is None.
667 """
668 if not server_def:
669 raise ValueError("server_def is None.")
671 self._server_def = server_def
673 if self._context_handle:
674 server_def_str = server_def.SerializeToString()
675 pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs,
676 server_def_str)
677 self._initialize_logical_devices()
679 # Clear all the caches in case there are remote tensors in them.
680 self._clear_caches()
682 def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
683 """Update a server_def on the context.
685 Args:
686 server_def: A tensorflow::ServerDef proto. Enables execution on remote
687 devices.
688 keep_alive_secs: Num. seconds after which the remote end will hang up. As
689 long as the client is still alive, the server state for the context will
690 be kept alive. If the client is killed (or there is some failure), the
691 server will clean up its context keep_alive_secs after the final RPC it
692 receives.
694 Raises:
695 ValueError: if server_def is None.
696 """
697 if not server_def:
698 raise ValueError("server_def is None.")
700 self._server_def = server_def
702 if self._context_handle:
703 server_def_str = server_def.SerializeToString()
704 pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle,
705 keep_alive_secs, server_def_str)
706 self._initialize_logical_devices()
708 self._clear_caches()
710 def check_alive(self, worker_name):
711 """Checks whether a remote worker is alive or not.
713 Args:
714 worker_name: a string representing the remote worker. It must be a fully
715 specified name like "/job:worker/replica:0/task:0".
717 Returns:
718 a boolean indicating whether the remote worker is alive or not.
720 Raises:
721 ValueError: if context is not initialized.
722 """
723 # TODO(yuefengz): support checking multiple workers.
724 if self._context_handle:
725 return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name)
726 else:
727 raise ValueError("Context is not initialized.")
729 def sync_executors(self):
730 """Sync both local executors and the ones on remote workers.
732 In async execution mode, local function calls can return before the
733 corresponding remote op/function execution requests are completed. Calling
734 this method creates a synchronization barrier for remote executors. It only
735 returns when all remote pending nodes are finished, potentially with errors
736 if any remote executors are in error state.
738 Raises:
739 ValueError: if context is not initialized.
740 """
741 if self._context_handle:
742 pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle)
743 else:
744 raise ValueError("Context is not initialized.")
746 def clear_executor_errors(self):
747 """Clear errors in both local executors and remote workers.
749 After receiving errors from remote workers, additional requests on the fly
750 could further taint the status on the remote workers due to the async nature
751 of remote execution. Calling this method blocks waiting for all pending
752 nodes in remote executors to finish and clear their error statuses.
754 Raises:
755 ValueError: if context is not initialized.
756 """
757 if self._context_handle:
758 pywrap_tfe.TFE_ContextClearExecutors(self._context_handle)
759 else:
760 raise ValueError("Context is not initialized.")
762 def configure_coordination_service(self,
763 service_type,
764 service_leader="",
765 enable_health_check=True,
766 cluster_register_timeout_in_ms=0,
767 heartbeat_timeout_in_ms=0,
768 shutdown_barrier_timeout_in_ms=0,
769 coordinated_jobs=None,
770 allow_new_incarnation_to_reconnect=False):
771 """Enable distributed coordination service with specified configs."""
772 if self._context_handle:
773 logging.warning("Configuring coordination service type may not be "
774 "effective because the context is already initialized.")
775 config = coordination_config_pb2.CoordinationServiceConfig()
776 config.service_type = service_type
777 if service_leader:
778 config.service_leader = pydev.canonical_name(service_leader)
779 config.enable_health_check = enable_health_check
780 config.cluster_register_timeout_in_ms = cluster_register_timeout_in_ms
781 config.heartbeat_timeout_in_ms = heartbeat_timeout_in_ms
782 config.shutdown_barrier_timeout_in_ms = shutdown_barrier_timeout_in_ms
783 config.allow_new_incarnation_to_reconnect = (
784 allow_new_incarnation_to_reconnect)
785 if coordinated_jobs is not None:
786 if isinstance(coordinated_jobs, list):
787 config.coordinated_job_list.extend(coordinated_jobs)
788 else:
789 raise ValueError("`coordinated_jobs` must be list[CoordinatedJob] or "
790 "None, but got: %s" % (coordinated_jobs,))
791 self._coordination_service_config = config
793 @property
794 def coordination_service(self):
795 return self._coordination_service_config
797 def set_config_key_value(self, key, value):
798 ensure_initialized()
799 pywrap_tfe.TFE_InsertConfigKeyValue(self._context_handle, key, value)
801 # If `timeout_in_ms=0`, this will block until the key-value is set or the
802 # worker shuts down.
803 def get_config_key_value(self, key, timeout_in_ms=0):
804 ensure_initialized()
805 with c_api_util.tf_buffer() as buffer_:
806 pywrap_tfe.TFE_GetConfigKeyValue(self._context_handle, key,
807 timeout_in_ms, buffer_)
808 value = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8")
809 return value
811 def delete_config_key_value(self, key):
812 ensure_initialized()
813 pywrap_tfe.TFE_DeleteConfigKeyValue(self._context_handle, key)
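  # Illustrative sketch (not part of the original source): the key-value
  # helpers above talk to the cluster-wide coordination service, so they are
  # only meaningful once a coordination service has been configured and the
  # context initialized, e.g.:
  #
  #   ctx = context()   # module-level accessor defined later in this file
  #   ctx.set_config_key_value("job/0/status", "ready")
  #   value = ctx.get_config_key_value("job/0/status")   # blocks until set
  #   ctx.delete_config_key_value("job/0/status")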
815 def report_error_to_cluster(self, error_code, error_message):
816 """Report error to other members in a multi-client cluster.
818 Args:
819 error_code: a `tf.errors` error code.
820 error_message: a string. The error message.
821 """
822 if self._context_handle:
823 pywrap_tfe.TFE_ReportErrorToCluster(self._context_handle, error_code,
824 error_message)
825 else:
826 raise ValueError("Context is not initialized.")
828 def get_task_states(self, job_configs):
829 """Get task states from the Coordination Service.
831 Args:
832 job_configs: A list of tuples of job name and task number.
834 Returns:
835 A list of TF_Status.
836 """
837 if self._context_handle:
838 job_names, task_nums = zip(*job_configs)
839 return pywrap_tfe.TFE_GetTaskStates(self._context_handle, job_names,
840 task_nums)
841 else:
842 raise ValueError("Context is not initialized.")
844 def wait_at_barrier(self, barrier_id, timeout_in_ms):
845 """Blocks until all coordinated tasks are at the barrier.
847 The barrier may fail if it times out or if one of the tasks is unhealthy.
849 Args:
850 barrier_id: Unique string identifying the barrier.
851 timeout_in_ms: Duration before the barrier times out and fails.
852 """
853 ensure_initialized()
854 pywrap_tfe.TFE_WaitAtBarrier(self._context_handle, barrier_id,
855 timeout_in_ms)
857 def clear_kernel_cache(self):
858 """Clear kernel cache and reset all stateful kernels."""
859 if self._context_handle is not None:
860 pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
862 def enable_collective_ops(self, server_def):
863 """Enable distributed collective ops with an appropriate server_def.
865 Args:
866 server_def: A tensorflow::ServerDef proto. Enables execution on remote
867 devices.
869 Raises:
870 ValueError: if server_def is None.
871 RuntimeError: if this method is not called at program startup.
872 """
873 if not server_def:
874 raise ValueError("server_def is None.")
876 self._collective_ops_server_def = server_def
878 # TODO(b/129298253): Allow creating datasets/tensors before enabling
879 # collective ops.
880 if self._context_handle is not None:
881 logging.warning("Enabling collective ops after program startup may cause "
882 "error when accessing previously created tensors.")
883 with self._initialize_lock:
884 assert self._initialized
885 server_def_str = self._collective_ops_server_def.SerializeToString()
886 pywrap_tfe.TFE_EnableCollectiveOps(self._context_handle, server_def_str)
887 self._initialize_logical_devices()
888 self._clear_caches()
890 def configure_collective_ops(
891 self,
892 collective_leader="",
893 scoped_allocator_enabled_ops=("CollectiveReduce",),
894 use_nccl_communication=False,
895 device_filters=None):
896 """Configure collective ops.
898 A collective group leader is necessary for collective ops to run; the other
899 configurations are mainly for performance.
901 Args:
902 collective_leader: a device string for collective leader, e.g.
903 "/job:worker/replica:0/task:0"; empty string means local execution of
904 collective ops.
905 scoped_allocator_enabled_ops: a tuple or a list of op names for scoped
906 allocator to run with.
907 use_nccl_communication: whether to use nccl communication for collective
908 ops.
909 device_filters: a tuple or a list of device strings. If set, the
910 corresponding task can only see the devices matched by these device filters.
912 Raises:
913 RuntimeError: if this method is not called at program startup.
914 """
915 if self._collective_leader is not None:
916 if (self._collective_leader != collective_leader or
917 self._collective_scoped_allocator_enabled_ops !=
918 scoped_allocator_enabled_ops or
919 self._collective_use_nccl_communication != use_nccl_communication or
920 self._collective_device_filters != device_filters):
921 raise ValueError("Collective ops are already configured.")
922 else:
923 return
925 if self._context_handle is not None:
926 raise RuntimeError("Collective ops must be configured at program startup")
928 self._collective_leader = collective_leader
929 self._collective_scoped_allocator_enabled_ops = scoped_allocator_enabled_ops
930 self._collective_use_nccl_communication = use_nccl_communication
931 self._collective_device_filters = device_filters
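  # Illustrative sketch (not part of the original source): collective ops must
  # be configured before the context handle is created, e.g. at program
  # startup:
  #
  #   ctx = context()   # module-level accessor defined later in this file
  #   ctx.configure_collective_ops(
  #       collective_leader="/job:worker/replica:0/task:0",
  #       use_nccl_communication=False)
  #   ctx.ensure_initialized()   # the settings are folded into self.config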
933 def abort_collective_ops(self, code, message):
934 """Abort the collective ops.
936 This is intended to be used when a peer failure is detected, which allows
937 the user to handle the case instead of hanging. This aborts all ongoing
938 collectives. Afterwards, all subsequent collectives error immediately, and you
939 need to reset_context() to use collectives again.
941 Args:
942 code: a `tf.errors` error code.
943 message: a string. The error message.
944 """
945 self.ensure_initialized()
946 pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message)
948 def check_collective_ops_peer_health(self, task, timeout_in_ms):
949 """Check collective peer health.
951 This probes each task to see if it is still alive. Note that a restarted
952 task is considered a different one, and it is treated as not healthy.
954 This should only be used in multi client multi worker training.
956 Args:
957 task: a task string, must be in the format of /job:xxx/replica:0/task:N.
958 timeout_in_ms: an integer, the timeout. If zero, there's no timeout.
960 Raises:
961 tf.errors.UnavailableError: when a peer is down.
962 tf.errors.FailedPreconditionError: when a peer is a different one from the
963 one this task has talked to, e.g. the peer has restarted.
964 tf.errors.InvalidArgumentError: when the task string is invalid.
965 """
966 self.ensure_initialized()
967 pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task,
968 timeout_in_ms)
970 @property
971 def _handle(self):
972 if self._context_handle is None:
973 raise AssertionError("Context must be initialized first.")
975 return self._context_handle
977 @property
978 def _devices(self):
979 if self._context_devices is None:
980 raise AssertionError("Context must be initialized first.")
982 return self._context_devices
984 def __str__(self):
985 if self._context_handle is None:
986 return "Eager TensorFlow Context. Devices currently uninitialized."
987 else:
988 devices = self._devices
989 lines = ["Eager TensorFlow Context with %d devices" % (len(devices))]
990 for i, d in enumerate(devices):
991 lines.append(" Device %d: %s" % (i, d))
992 return "\n".join(lines)
994 @tf_contextlib.contextmanager
995 def _mode(self, mode):
996 """A context manager to allow setting the mode to EAGER/GRAPH."""
997 ctx = self._thread_local_data
998 old_is_eager = ctx.is_eager
999 ctx.is_eager = mode == EAGER_MODE
1000 if mode == EAGER_MODE:
1001 # Entering graph mode does not provide us with sufficient information to
1002 # record a context switch; graph-based context switches are only logged
1003 # when a graph is registered as the default graph.
1004 self.context_switches.push(False, eager_mode, None)
1005 try:
1006 yield
1007 finally:
1008 ctx.is_eager = old_is_eager
1009 if mode == EAGER_MODE:
1010 self.context_switches.pop()
1012 def executing_eagerly(self):
1013 """Returns True if current thread has eager executing enabled."""
1014 return self._thread_local_data.is_eager
1016 def ones_rank_cache(self):
1017 """Per-device cache for scalars."""
1018 return _tensor_caches_map[self._id].ones_rank_cache
1020 def zeros_cache(self):
1021 """Per-device cache for scalars."""
1022 return _tensor_caches_map[self._id].zeros_cache
1024 @property
1025 def scope_name(self):
1026 """Returns scope name for the current thread."""
1027 return self._thread_local_data.scope_name
1029 @scope_name.setter
1030 def scope_name(self, s):
1031 """Sets scope name for the current thread."""
1032 self._thread_local_data.scope_name = s
1034 @property
1035 def device_name(self):
1036 """Returns the device name for the current thread."""
1037 return self._thread_local_data.device_name
1039 @property
1040 def device_spec(self):
1041 """Returns the device spec for the current thread."""
1042 return self._thread_local_data.device_spec
1044 def _set_device(self, device_name, device_spec):
1045 self._thread_local_data.device_name = device_name
1046 self._thread_local_data.device_spec = device_spec
1048 def device(self, name):
1049 """Context-manager to force placement of operations and Tensors on a device.
1051 Args:
1052 name: Name of the device or None to get default placement.
1054 Returns:
1055 Context manager that forces device placement.
1057 Raises:
1058 ValueError: If name is not a string or is an invalid device name.
1059 RuntimeError: If device scopes are not properly nested.
1060 """
1061 if isinstance(name, LogicalDevice):
1062 name = name.name
1063 elif pydev.is_device_spec(name):
1064 name = name.to_string()
1065 return _EagerDeviceContext(self, name)
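  # Illustrative sketch (not part of the original source): in eager mode this
  # method backs the public `tf.device` context manager, e.g.:
  #
  #   import tensorflow as tf
  #   with tf.device("/CPU:0"):
  #     x = tf.constant([1.0, 2.0])   # forced onto the host CPU
  #   print(x.device)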
1067 def devices(self):
1068 """List of the names of devices available to execute operations."""
1069 return self._devices
1071 def host_address_space(self):
1072 self.ensure_initialized()
1073 with c_api_util.tf_buffer() as buffer_:
1074 pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_)
1075 address_space = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8")
1076 return address_space
1078 # TODO(fishx): remove this property.
1079 @property
1080 def execution_mode(self):
1081 """Gets execution mode for current thread."""
1082 return ASYNC if self.is_async() else SYNC
1084 @execution_mode.setter
1085 def execution_mode(self, mode):
1086 """Sets execution mode for current thread."""
1087 if mode not in (None, SYNC, ASYNC):
1088 raise ValueError("Execution mode should be None/SYNC/ASYNC. Got %s" %
1089 mode)
1091 if mode is None:
1092 mode = SYNC
1094 enable_async = (mode == ASYNC)
1095 if self.is_async() != enable_async:
1096 # Only set the execution mode if the context has already been initialized
1097 if self._context_handle is not None:
1098 self.executor.wait()
1099 executor_new = executor.new_executor(enable_async)
1100 self._thread_local_data.executor = executor_new
1101 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle,
1102 executor_new.handle())
1103 else:
1104 self._default_is_async = enable_async
1106 def is_async(self):
1107 if self._context_handle is not None:
1108 return self.executor.is_async()
1109 else:
1110 return self._default_is_async
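  # Illustrative sketch (not part of the original source): the SYNC/ASYNC
  # execution mode above is normally toggled through the public API, e.g.:
  #
  #   import tensorflow as tf
  #   tf.config.experimental.set_synchronous_execution(False)  # ASYNC
  #   # ... enqueue eager ops; results may be "non-ready" handles ...
  #   tf.config.experimental.set_synchronous_execution(True)   # back to SYNC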
1112 @property
1113 def executor(self):
1114 self.ensure_initialized()
1115 return executor.Executor(
1116 pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle))
1118 @executor.setter
1119 def executor(self, e):
1120 self.ensure_initialized()
1121 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle())
1123 @property
1124 def config(self):
1125 """Return the ConfigProto with all runtime deltas applied."""
1126 # Ensure physical devices have been discovered and config has been imported
1127 self._initialize_physical_devices()
1129 config = config_pb2.ConfigProto()
1130 if self._config is not None:
1131 config.CopyFrom(self._config)
1133 if self._optimizer_jit is not None:
1134 config.graph_options.optimizer_options.global_jit_level = (
1135 config_pb2.OptimizerOptions.ON_1
1136 if self._optimizer_jit else config_pb2.OptimizerOptions.OFF)
1137 if self._intra_op_parallelism_threads is not None:
1138 config.intra_op_parallelism_threads = self._intra_op_parallelism_threads
1139 if self._inter_op_parallelism_threads is not None:
1140 config.inter_op_parallelism_threads = self._inter_op_parallelism_threads
1142 if self._soft_device_placement is not None:
1143 config.allow_soft_placement = self._soft_device_placement
1144 else:
1145 config.allow_soft_placement = self.executing_eagerly()
1147 if self._log_device_placement is not None:
1148 config.log_device_placement = self._log_device_placement
1150 if self._operation_timeout_in_ms is not None:
1151 config.operation_timeout_in_ms = self._operation_timeout_in_ms
1153 is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled()
1154 config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled
1155 if (is_mlir_bridge_enabled ==
1156 config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED):
1157 config.experimental.enable_mlir_bridge = True
1159 if self._enable_mlir_graph_optimization is not None:
1160 config.experimental.enable_mlir_graph_optimization = (
1161 self._enable_mlir_graph_optimization)
1163 def rewriter_toggle(option):
1164 toggle = self._optimizer_experimental_options.get(option, None)
1165 if toggle is None:
1166 return
1168 setattr(config.graph_options.rewrite_options, option,
1169 (rewriter_config_pb2.RewriterConfig.ON
1170 if toggle else rewriter_config_pb2.RewriterConfig.OFF))
1172 def rewriter_bool(option):
1173 toggle = self._optimizer_experimental_options.get(option, None)
1174 if toggle is None:
1175 return
1177 setattr(config.graph_options.rewrite_options, option, toggle)
1179 rewriter_toggle("layout_optimizer")
1180 rewriter_toggle("constant_folding")
1181 rewriter_toggle("shape_optimization")
1182 rewriter_toggle("remapping")
1183 rewriter_toggle("arithmetic_optimization")
1184 rewriter_toggle("dependency_optimization")
1185 rewriter_toggle("loop_optimization")
1186 rewriter_toggle("function_optimization")
1187 rewriter_toggle("debug_stripper")
1188 rewriter_bool("disable_model_pruning")
1189 rewriter_toggle("scoped_allocator_optimization")
1190 rewriter_toggle("pin_to_host_optimization")
1191 rewriter_toggle("implementation_selector")
1192 rewriter_toggle("auto_mixed_precision")
1193 rewriter_toggle("use_plugin_optimizers")
1194 rewriter_bool("disable_meta_optimizer")
1195 rewriter_toggle("auto_mixed_precision_onednn_bfloat16")
1196 rewriter_toggle("auto_mixed_precision_mkl")
1197 nodes = self._optimizer_experimental_options.get("min_graph_nodes", None)
1198 if nodes is not None:
1199 config.graph_options.rewrite_options.min_graph_nodes = nodes
1201 # Compute device counts
1202 config.device_count["CPU"] = 0
1203 config.device_count["GPU"] = 0
1204 for dev in self._physical_devices:
1205 if dev not in self._visible_device_list:
1206 continue
1208 virtual_devices = self._virtual_device_map.get(dev)
1209 if virtual_devices is None:
1210 config.device_count[dev.device_type] += 1
1211 else:
1212 config.device_count[dev.device_type] += len(virtual_devices)
1214 # Configure gpu_options
1215 gpu_options = self._compute_gpu_options()
1216 config.gpu_options.MergeFrom(gpu_options)
1218 # Configure collective ops
1219 if self._collective_leader:
1220 config.experimental.collective_group_leader = self._collective_leader
1221 if self._collective_scoped_allocator_enabled_ops:
1222 rewrite_options = config.graph_options.rewrite_options
1223 rewrite_options.scoped_allocator_optimization = (
1224 rewriter_config_pb2.RewriterConfig.ON)
1225 del rewrite_options.scoped_allocator_opts.enable_op[:]
1226 for op in self._collective_scoped_allocator_enabled_ops:
1227 rewrite_options.scoped_allocator_opts.enable_op.append(op)
1228 if self._collective_use_nccl_communication:
1229 config.experimental.collective_nccl = True
1230 if self._collective_device_filters:
1231 del config.device_filters[:]
1232 for f in self._collective_device_filters:
1233 config.device_filters.append(f)
1235 # Configure coordination service
1236 if self._coordination_service_config:
1237 config.experimental.coordination_config.CopyFrom(
1238 self._coordination_service_config)
1240 return config
1242 def _compute_gpu_options(self):
1243 """Build the GPUOptions proto."""
1244 visible_device_list = []
1245 virtual_devices = []
1246 gpu_index = -1
1247 memory_growths = set()
1248 gpu_devices = self.list_physical_devices("GPU")
1249 pluggable_devices = self._pluggable_devices
1250 compatible_devices = gpu_devices
1251 for dev in pluggable_devices:
1252 if dev not in gpu_devices:
1253 compatible_devices.append(dev)
1254 for dev in compatible_devices:
1255 gpu_index += 1
1257 if dev not in self._visible_device_list:
1258 continue
1260 growth = self._memory_growth_map[dev]
1261 memory_growths.add(growth)
1262 visible_device_list.append(str(gpu_index))
1264 if self._virtual_device_map:
1265 vdevs = self._virtual_device_map.get(dev, [])
1266 device_ordinals = []
1267 device_limits = []
1268 priority = []
1269 for virt_dev in vdevs:
1270 if virt_dev.experimental_device_ordinal is not None:
1271 device_ordinals.append(virt_dev.experimental_device_ordinal)
1272 device_limits.append(virt_dev.memory_limit)
1273 if virt_dev.experimental_priority is not None:
1274 priority.append(virt_dev.experimental_priority)
1275 # If priority is specified, it must be specified for all virtual
1276 # devices.
1277 if priority and len(device_limits) != len(priority):
1278 raise ValueError("priority must be specified for all virtual devices")
1279 # If device_ordinals is specified, it must be specified for all virtual
1280 # devices.
1281 if device_ordinals and len(device_limits) != len(device_ordinals):
1282 raise ValueError(
1283 "device_ordinals must be specified for all virtual devices")
1285 virtual_devices.append(
1286 config_pb2.GPUOptions.Experimental.VirtualDevices(
1287 memory_limit_mb=device_limits,
1288 priority=priority,
1289 device_ordinal=device_ordinals))
1291 # Only compute growth if virtual devices have not been configured and we
1292 # have GPUs
1293 if not virtual_devices and memory_growths:
1294 if len(memory_growths) > 1:
1295 raise ValueError("Memory growth cannot differ between GPU devices")
1296 allow_growth = memory_growths.pop()
1297 else:
1298 allow_growth = None
1300 return config_pb2.GPUOptions(
1301 allow_growth=allow_growth,
1302 visible_device_list=",".join(visible_device_list),
1303 experimental=config_pb2.GPUOptions.Experimental(
1304 virtual_devices=virtual_devices))
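  # Illustrative sketch (not part of the original source): the memory-growth
  # map consumed above is populated through the public API before runtime
  # initialization, e.g.:
  #
  #   import tensorflow as tf
  #   for gpu in tf.config.list_physical_devices("GPU"):
  #     tf.config.experimental.set_memory_growth(gpu, True)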
1306 @property
1307 def function_call_options(self):
1308 """Returns function call options for current thread.
1310 Note that the returned object is still referenced by the eager context.
1312 Returns: the FunctionCallOptions for current thread.
1313 """
1314 if self._thread_local_data.function_call_options is None:
1315 config = self.config
1317 # Default to soft placement for functions unless specified
1318 if self._soft_device_placement is None:
1319 config.allow_soft_placement = True
1320 self._thread_local_data.function_call_options = FunctionCallOptions(
1321 config_proto=config)
1323 return self._thread_local_data.function_call_options
1325 @function_call_options.setter
1326 def function_call_options(self, options):
1327 """Returns function call options for current thread."""
1328 self._thread_local_data.function_call_options = options
1330 def num_gpus(self):
1331 """The number of GPUs available to execute operations."""
1332 self.ensure_initialized()
1333 return self._num_gpus
1335 def add_c_function(self, c_func):
1336 """Add a C API TF_Function to the context.
1338 Once added, the function (identified by its name) can be executed like any
1339 other operation.
1341 Args:
1342 c_func: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
1343 """
1344 self.ensure_initialized()
1345 pywrap_tfe.TFE_ContextAddFunction(self._handle, c_func)
1347 def get_c_function(self, name):
1348 """Get a C API TF_Function from the context.
1350 Args:
1351 name: Name of the function to get.
1353 Returns:
1354 A ScopedTFFunction wrapping the C API TF_Function.
1355 """
1356 self.ensure_initialized()
1357 return c_api_util.ScopedTFFunction(
1358 pywrap_tfe.TFE_ContextGetFunction(self._handle, name), name
1359 )
1361 def add_function_def(self, fdef):
1362 """Add a function definition to the context.
1364 Once added, the function (identified by its name) can be executed like any
1365 other operation.
1367 Args:
1368 fdef: A FunctionDef protocol buffer message.
1369 """
1370 self.ensure_initialized()
1371 fdef_string = fdef.SerializeToString()
1372 pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string,
1373 len(fdef_string))
1375 def get_function_def(self, name):
1376 """Get a function definition from the context.
1378 Args:
1379 name: function signature name.
1381 Returns:
1382 The requested FunctionDef.
1384 Raises:
1385 tf.errors.NotFoundError: if name is not the name of a registered function.
1386 """
1387 with c_api_util.tf_buffer() as buffer_:
1388 pywrap_tfe.TFE_ContextGetFunctionDef(self._handle, name, buffer_)
1389 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
1390 function_def = function_pb2.FunctionDef()
1391 function_def.ParseFromString(proto_data)
1393 return function_def
1395 def get_graph_debug_info(self, name):
1396 """Get GraphDebugInfo associated with a function from the context.
1398 Args:
1399 name: function signature name.
1401 Returns:
1402 The requested GraphDebugInfo.
1404 Raises:
1405 tf.errors.NotFoundError: if name is not the name of a registered function.
1406 """
1407 with c_api_util.tf_buffer() as buffer_:
1408 pywrap_tfe.TFE_ContextGetGraphDebugInfo(self._handle, name, buffer_)
1409 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
1410 graph_debug_info = graph_debug_info_pb2.GraphDebugInfo()
1411 graph_debug_info.ParseFromString(proto_data)
1413 return graph_debug_info
1415 def is_custom_device(self, device_name):
1416 """Calls TFE_IsCustomDevice. See the non-member function."""
1417 self.ensure_initialized()
1418 return pywrap_tfe.TFE_Py_IsCustomDevice(self._handle, device_name)
1420 def register_custom_device(self, device_capsule, device_name,
1421 device_info_capsule):
1422 """Calls TFE_RegisterCustomDevice. See the non-member function."""
1423 self.ensure_initialized()
1424 pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule,
1425 device_name, device_info_capsule)
1427 def pack_eager_tensors(self, tensors):
1428 """Pack multiple `EagerTensor`s of the same dtype and shape.
1430 Args:
1431 tensors: a list of EagerTensors to pack.
1433 Returns:
1434 A packed EagerTensor.
1435 """
1436 self.ensure_initialized()
1437 return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors)
1439 def list_function_names(self):
1440 """Get a list of names of registered functions.
1442 Returns:
1443 A set of names of all registered functions for the context.
1444 """
1445 self.ensure_initialized()
1446 return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle))
1448 def remove_function(self, name):
1449 """Remove a function from the context.
1451 Once removed, the function cannot be executed anymore.
1453 Args:
1454 name: function signature name.
1455 """
1456 self.ensure_initialized()
1457 pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name)
1459 def has_function(self, name):
1460 """Check if a function `name` is registered."""
1461 self.ensure_initialized()
1462 return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name))
1464 @property
1465 def function_scope_id(self):
1466 """Returns an id that is unique to each scope holding functions."""
1467 return id(self._context_handle)
1469 def call_function(self, name, tensor_inputs, num_outputs):
1470 """Calls the function associated with the given name."""
1471 attrs = tuple(
1472 itertools.chain(
1473 *self.function_call_options.as_attrs().items()
1474 )
1475 )
1477 cancellation_context = cancellation.context()
1478 if cancellation_context is None:
1479 outputs = execute.execute(
1480 name.decode("utf-8"),
1481 num_outputs=num_outputs,
1482 inputs=tensor_inputs,
1483 attrs=attrs,
1484 ctx=self,
1485 )
1486 else:
1487 outputs = execute.execute_with_cancellation(
1488 name.decode("utf-8"),
1489 num_outputs=num_outputs,
1490 inputs=tensor_inputs,
1491 attrs=attrs,
1492 ctx=self,
1493 cancellation_manager=cancellation_context,
1494 )
1495 # Empty list means no function outputs so return None
1496 outputs = outputs or None
1498 return outputs
1500 def add_op_callback(self, callback):
1501 """Add a post-op callback to the context.
1503 A post-op callback is invoked immediately after an eager operation or
1504 function has finished execution or after an op has been added to a graph,
1505 providing access to the op's type, name, and input and output tensors. Multiple
1506 op callbacks can be added, in which case the callbacks will be invoked in
1507 the order in which they are added.
1509 Args:
1510 callback: a callable of the signature `f(op_type, inputs, attrs, outputs,
1511 op_name=None, graph=None)`. See doc strings in `op_callbacks.py` for
1512 details on the function signature and its semantics.
1513 """
1514 if callback not in self._thread_local_data.op_callbacks:
1515 self._thread_local_data.op_callbacks.append(callback)
1517 def remove_op_callback(self, callback):
1518 """Remove an already-registered op callback.
1520 Args:
1521 callback: The op callback to be removed.
1523 Raises:
1524 KeyError: If `callback` is not already registered.
1525 """
1526 if callback not in self._thread_local_data.op_callbacks:
1527 raise KeyError("The specified op callback has not been registered, "
1528 "and hence cannot be removed.")
1529 del self._thread_local_data.op_callbacks[
1530 self._thread_local_data.op_callbacks.index(callback)]
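  # Illustrative sketch (not part of the original source): a post-op callback
  # with the signature documented in `add_op_callback` above, registered and
  # later removed on the context.
  #
  #   def _log_op(op_type, inputs, attrs, outputs, op_name=None, graph=None):
  #     logging.info("ran %s with %d input(s)", op_type, len(inputs))
  #
  #   ctx = context()   # module-level accessor defined later in this file
  #   ctx.add_op_callback(_log_op)
  #   # ... run some eager ops ...
  #   ctx.remove_op_callback(_log_op)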
1532 @property
1533 def op_callbacks(self):
1534 return self._thread_local_data.op_callbacks
1536 @property
1537 def invoking_op_callbacks(self):
1538 return self._thread_local_data.invoking_op_callbacks
1540 @invoking_op_callbacks.setter
1541 def invoking_op_callbacks(self, value):
1542 self._thread_local_data.invoking_op_callbacks = value
1544 def _initialize_physical_devices(self, reinitialize=False):
1545 """Gets local devices visible to the system.
1547 Args:
1548 reinitialize: If True, reinitializes self._physical_devices so that
1549 dynamic registered devices will also be visible to the python front-end.
1550 """
1551 # We lazily initialize self._physical_devices since we do not want to do this
1552 # in the constructor, because the backend may not be initialized yet.
1553 with self._device_lock:
1554 if not reinitialize and self._physical_devices is not None:
1555 return
1557 devs = pywrap_tfe.TF_ListPhysicalDevices()
1558 self._physical_devices = [
1559 PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1])
1560 for d in devs
1561 ]
1562 self._physical_device_to_index = {
1563 p: i for i, p in enumerate(self._physical_devices)
1564 }
1565 # We maintain a separate list just so we can check whether the device in
1566 # _physical_devices is a PluggableDevice.
1567 pluggable_devs = pywrap_tfe.TF_ListPluggablePhysicalDevices()
1568 self._pluggable_devices = [
1569 PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1])
1570 for d in pluggable_devs
1571 ]
1573 self._visible_device_list = list(self._physical_devices)
1574 self._memory_growth_map = {
1575 d: None
1576 for d in self._physical_devices
1577 if d.device_type == "GPU" or d in self._pluggable_devices
1578 }
1580 # Import device settings that may have been passed into the constructor
1581 self._import_config()
1583 def reinitialize_physical_devices(self):
1584 """Gets local devices visible to the system."""
1585 # Reinitialize the physical device list after registering
1586 # the pluggable device.
1587 self._initialize_physical_devices(True)
1589 def list_physical_devices(self, device_type=None):
1590 """List local devices visible to the system.
1592 This API allows a client to query the devices before they have been
1593 initialized by the eager runtime. Additionally, a user can filter by device
1594 type to get only CPUs or GPUs.
1596 Args:
1597 device_type: Optional device type to limit results to.
1599 Returns:
1600 List of PhysicalDevice objects.
1601 """
1602 self._initialize_physical_devices()
1604 if device_type is None:
1605 return list(self._physical_devices)
1607 return [d for d in self._physical_devices if d.device_type == device_type]
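`Context.list_physical_devices` is surfaced publicly as `tf.config.list_physical_devices`; a minimal sketch of the typical query, assuming a standard TF 2.x install:

```python
import tensorflow as tf

# Devices the runtime can see, queried before logical-device initialization.
gpus = tf.config.list_physical_devices("GPU")
cpus = tf.config.list_physical_devices("CPU")
print(f"{len(gpus)} GPU(s), {len(cpus)} CPU(s) visible")
```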
1609 def get_device_details(self, device): # pylint: disable=redefined-outer-name
1610 """Returns details about a physical devices.
1612 Args:
1613 device: A `tf.config.PhysicalDevice` returned by
1614 `tf.config.list_physical_devices` or `tf.config.get_visible_devices`.
1616 Returns:
1617 A dict with string keys.
1618 """
1619 if not isinstance(device, PhysicalDevice):
1620 raise ValueError("device must be a tf.config.PhysicalDevice, but got: "
1621 "%s" % (device,))
1622 if (self._physical_device_to_index is None or
1623 device not in self._physical_device_to_index):
1624 raise ValueError("The PhysicalDevice must be one obtained from "
1625 "calling `tf.config.list_physical_devices`, but got: "
1626 "%s" % (device,))
1627 index = self._physical_device_to_index[device]
1628 details = pywrap_tfe.TF_GetDeviceDetails(index)
1630 # Change compute_capability from a string to a tuple
1631 if "compute_capability" in details:
1632 try:
1633 major, minor = details["compute_capability"].split(".")
1634 details["compute_capability"] = (int(major), int(minor))
1635 except ValueError:
1636 raise RuntimeError("Device returned compute capability in an invalid "
1637 "format: %s" % details["compute_capability"])
1638 return details
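The public counterpart is `tf.config.experimental.get_device_details`; a hedged sketch of reading the parsed `compute_capability` tuple on a machine that has a GPU:

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
if gpus:
  details = tf.config.experimental.get_device_details(gpus[0])
  # e.g. {'device_name': ..., 'compute_capability': (major, minor)}
  print(details.get("compute_capability"))
```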
1640 def _import_config(self):
1641 """Import config if passed in during construction.
1643 If Context was created with a ConfigProto such as when calling
1644 tf.compat.v1.enable_eager_execution(), then we need to pull out the
1645 various pieces we might be replacing and import them into our internal
1646 class representation.
1647 """
1648 if self._config is None:
1649 return
1651 num_cpus = self._config.device_count.get("CPU", 1)
1652 if num_cpus != 1:
1653 cpus = [d for d in self._physical_devices if d.device_type == "CPU"]
1654 if num_cpus == 0:
1655 self.set_visible_devices([], "CPU")
1656 elif num_cpus > 1:
1657 self.set_logical_device_configuration(
1658 cpus[0], [LogicalDeviceConfiguration() for _ in range(num_cpus)])
1660 # Parse GPU options
1661 gpus = [d for d in self._physical_devices if d.device_type == "GPU"]
1663 # If there are no GPUs detected, simply ignore all the GPU options passed in
1664 # rather than doing any validation checks.
1665 if not gpus:
1666 return
1668 gpu_count = self._config.device_count.get("GPU", None)
1670 visible_gpus = []
1671 # TODO(gjn): Handle importing existing virtual GPU configuration
1672 visible_indices = self._config.gpu_options.visible_device_list
1673 if visible_indices:
1674 for index in visible_indices.split(","):
1675 if int(index) >= len(gpus):
1676 raise ValueError("Invalid visible device index: %s" % index)
1677 visible_gpus.append(gpus[int(index)])
1678 else:
1679 visible_gpus = gpus
1681 if gpu_count is not None:
1682 visible_gpus = visible_gpus[:gpu_count]
1684 self.set_visible_devices(visible_gpus, "GPU")
1686 def list_logical_devices(self, device_type=None):
1687 """Return logical devices."""
1688 self.ensure_initialized()
1689 if device_type is None:
1690 return list(self._logical_devices)
1692 return [d for d in self._logical_devices if d.device_type == device_type]
1694 def get_visible_devices(self, device_type=None):
1695 """Get the list of visible devices."""
1696 self._initialize_physical_devices()
1698 if device_type is None:
1699 return list(self._visible_device_list)
1701 return [
1702 d for d in self._visible_device_list if d.device_type == device_type
1703 ]
1705 def set_visible_devices(self, devices, device_type=None):
1706 """Set the list of visible devices."""
1707 self._initialize_physical_devices()
1709 if not isinstance(devices, list):
1710 devices = [devices]
1712 for d in devices:
1713 if d not in self._physical_devices:
1714 raise ValueError("Unrecognized device: %s" % repr(d))
1715 if device_type is not None and d.device_type != device_type:
1716 raise ValueError("Unrecognized device: %s" % repr(d))
1718 visible_device_list = []
1719 if device_type is not None:
1720 visible_device_list = [
1721 d for d in self._visible_device_list if d.device_type != device_type
1722 ]
1724 visible_device_list += devices
1726 if self._visible_device_list == visible_device_list:
1727 return
1729 if self._context_handle is not None:
1730 raise RuntimeError(
1731 "Visible devices cannot be modified after being initialized")
1733 self._visible_device_list = visible_device_list
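Publicly this is `tf.config.set_visible_devices`; as the RuntimeError above notes, it must run before the context handle is created. A minimal sketch restricting the runtime to the first GPU:

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
if gpus:
  # Must be called before any op initializes the eager context.
  tf.config.set_visible_devices(gpus[:1], "GPU")
  print(tf.config.get_visible_devices("GPU"))
```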
1735 def get_memory_info(self, dev):
1736 """Returns a dict of memory info for the device."""
1737 self._initialize_physical_devices()
1738 self.ensure_initialized()
1739 return pywrap_tfe.TFE_GetMemoryInfo(self._context_handle, dev)
1741 def reset_memory_stats(self, dev):
1742 """Resets the tracked memory stats for the device."""
1743 self._initialize_physical_devices()
1744 self.ensure_initialized()
1745 pywrap_tfe.TFE_ResetMemoryStats(self._context_handle, dev)
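These two methods back `tf.config.experimental.get_memory_info` and `tf.config.experimental.reset_memory_stats`; a sketch assuming a machine with at least one GPU (the device string is an example):

```python
import tensorflow as tf

# Both calls force context initialization; "GPU:0" is an assumed device.
info = tf.config.experimental.get_memory_info("GPU:0")
print(info["current"], info["peak"])  # bytes currently / at peak allocated
tf.config.experimental.reset_memory_stats("GPU:0")
```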
1747 def get_memory_growth(self, dev):
1748 """Get if memory growth is enabled for a PhysicalDevice."""
1749 self._initialize_physical_devices()
1751 if dev not in self._physical_devices:
1752 raise ValueError("Unrecognized device: %s" % repr(dev))
1754 return self._memory_growth_map[dev]
1756 def set_memory_growth(self, dev, enable):
1757 """Set if memory growth should be enabled for a PhysicalDevice."""
1758 self._initialize_physical_devices()
1760 if dev not in self._physical_devices:
1761 raise ValueError("Unrecognized device: %s" % repr(dev))
1763 if dev in self._virtual_device_map:
1764 raise ValueError(
1765 "Cannot set memory growth on device when virtual devices configured")
1767 if dev.device_type != "GPU" and dev not in self._pluggable_devices:
1768 raise ValueError(
1769 "Cannot set memory growth on non-GPU and non-Pluggable devices")
1771 if self._memory_growth_map.get(dev) == enable:
1772 return
1774 if self._context_handle is not None:
1775 raise RuntimeError(
1776 "Physical devices cannot be modified after being initialized")
1778 self._memory_growth_map[dev] = enable
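Exposed as `tf.config.experimental.set_memory_growth`; like the other physical-device settings, it must be applied before the context is initialized. A minimal sketch:

```python
import tensorflow as tf

for gpu in tf.config.list_physical_devices("GPU"):
  # Grow GPU memory on demand instead of reserving it all at startup.
  tf.config.experimental.set_memory_growth(gpu, True)
```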
1780 def get_logical_device_configuration(self, dev):
1781 """Get the virtual device configuration for a PhysicalDevice."""
1782 self._initialize_physical_devices()
1784 if dev not in self._physical_devices:
1785 raise ValueError("Unrecognized device: %s" % repr(dev))
1787 return self._virtual_device_map.get(dev)
1789 def set_logical_device_configuration(self, dev, virtual_devices):
1790 """Set the virtual device configuration for a PhysicalDevice."""
1791 self._initialize_physical_devices()
1793 if dev not in self._physical_devices:
1794 raise ValueError("Unrecognized device: %s" % repr(dev))
1796 if dev.device_type == "CPU":
1797 for vdev in virtual_devices:
1798 if vdev.memory_limit is not None:
1799 raise ValueError("Setting memory limit on CPU virtual devices is "
1800 "currently not supported")
1801 if vdev.experimental_priority is not None:
1802 raise ValueError("Setting experimental_priority on CPU virtual "
1803 " devices is currently not supported")
1804 if vdev.experimental_device_ordinal is not None:
1805 raise ValueError("Setting experimental_device_ordinal on CPU virtual "
1806 " devices is currently not supported")
1807 elif dev.device_type == "GPU":
1808 for vdev in virtual_devices:
1809 if vdev.memory_limit is None:
1810 raise ValueError(
1811 "Setting memory limit is required for GPU virtual devices")
1812 else:
1813 raise ValueError("Virtual devices are not supported for %s" %
1814 dev.device_type)
1816 if self._virtual_device_map.get(dev) == virtual_devices:
1817 return
1819 if self._context_handle is not None:
1820 raise RuntimeError(
1821 "Virtual devices cannot be modified after being initialized")
1823 self._virtual_device_map[dev] = virtual_devices
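The public wrapper is `tf.config.set_logical_device_configuration`; a sketch that splits one GPU into two logical devices (the 1024 MB limits are illustrative, and GPU virtual devices require a memory limit per the check above):

```python
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
if gpus:
  tf.config.set_logical_device_configuration(
      gpus[0],
      [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
       tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
  print(tf.config.list_logical_devices("GPU"))
```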
1825 def set_logical_cpu_devices(self, num_cpus, prefix=""):
1826 """Set virtual CPU devices in context.
1828 If virtual CPU devices are already configured at context initialization
1829 by tf.config.set_logical_device_configuration(), this method should not be
1830 called.
1832 Args:
1833 num_cpus: Number of virtual CPUs.
1834 prefix: Device name prefix.
1836 Raises:
1837 RuntimeError: If virtual CPUs are already configured at context
1838 initialization.
1839 """
1840 server_def = self._server_def or self._collective_ops_server_def
1841 local_prefix = ["/device"]
1842 if server_def is not None:
1843 local_prefix.append("/job:%s/replica:0/task:%d" % (server_def.job_name,
1844 server_def.task_index))
1845 logical_local_devices = [d for d in self.list_logical_devices("CPU") if
1846 d.name.startswith(tuple(local_prefix))]
1847 self.ensure_initialized()
1848 # Error out if there are already multiple logical CPUs in the context.
1849 if len(logical_local_devices) > 1:
1850 raise RuntimeError("Virtual CPUs already set, cannot modify again.")
1852 pywrap_tfe.TFE_SetLogicalCpuDevices(self._context_handle, num_cpus, prefix)
1853 self._initialize_logical_devices()
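As the docstring notes, the usual way to get multiple logical CPUs is to configure them up front with `tf.config.set_logical_device_configuration`; a minimal sketch (CPU virtual devices take no memory limit, per the checks earlier in this class):

```python
import tensorflow as tf

cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0],
    [tf.config.LogicalDeviceConfiguration(),
     tf.config.LogicalDeviceConfiguration()])
print(tf.config.list_logical_devices("CPU"))  # two logical CPU devices
```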
1855 def get_compiler_ir(
1856 self,
1857 device_name,
1858 function_name,
1859 flat_args,
1860 captured_inputs,
1861 stage="hlo",
1862 ):
1863 return pywrap_tfe.TF_GetCompilerIr(
1864 self._context_handle,
1865 function_name,
1866 stage,
1867 device_name,
1868 flat_args,
1869 captured_inputs,
1870 )
1872 @deprecated(
1873 None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True)
1874 def enable_xla_devices(self):
1875 """Enables XLA:CPU and XLA:GPU devices registration."""
1876 pywrap_tfe.TF_EnableXlaDevices()
1878 @property
1879 def enable_mlir_bridge(self):
1880 return pywrap_tfe.TF_IsMlirBridgeEnabled()
1882 @property
1883 def enable_mlir_graph_optimization(self):
1884 return self._enable_mlir_graph_optimization
1886 @enable_mlir_bridge.setter
1887 def enable_mlir_bridge(self, enabled):
1888 pywrap_tfe.TF_EnableMlirBridge(enabled)
1889 self._thread_local_data.function_call_options = None
1891 @enable_mlir_graph_optimization.setter
1892 def enable_mlir_graph_optimization(self, enabled):
1893 self._enable_mlir_graph_optimization = enabled
1894 self._thread_local_data.function_call_options = None
1896 @property
1897 def optimizer_jit(self):
1898 level = self.config.graph_options.optimizer_options.global_jit_level
1899 return (level == config_pb2.OptimizerOptions.ON_1 or
1900 level == config_pb2.OptimizerOptions.ON_2)
1902 @optimizer_jit.setter
1903 def optimizer_jit(self, enabled):
1904 self._optimizer_jit = enabled
1906 self._thread_local_data.function_call_options = None
1908 def get_optimizer_experimental_options(self):
1909 """Get experimental options for the optimizer.
1911 Returns:
1912 Dictionary of current option values
1913 """
1914 rewrite_options = self.config.graph_options.rewrite_options
1915 options = {}
1917 def rewriter_toggle(option):
1918 attr = getattr(rewrite_options, option)
1919 if attr != 0:
1920 options[option] = (attr == rewriter_config_pb2.RewriterConfig.ON)
1922 def rewriter_bool(option):
1923 options[option] = getattr(rewrite_options, option)
1925 rewriter_toggle("layout_optimizer")
1926 rewriter_toggle("constant_folding")
1927 rewriter_toggle("shape_optimization")
1928 rewriter_toggle("remapping")
1929 rewriter_toggle("arithmetic_optimization")
1930 rewriter_toggle("dependency_optimization")
1931 rewriter_toggle("loop_optimization")
1932 rewriter_toggle("function_optimization")
1933 rewriter_toggle("debug_stripper")
1934 rewriter_bool("disable_model_pruning")
1935 rewriter_toggle("scoped_allocator_optimization")
1936 rewriter_toggle("pin_to_host_optimization")
1937 rewriter_toggle("implementation_selector")
1938 rewriter_toggle("auto_mixed_precision")
1939 rewriter_toggle("use_plugin_optimizers")
1940 rewriter_bool("disable_meta_optimizer")
1941 rewriter_toggle("auto_mixed_precision_onednn_bfloat16")
1942 rewriter_toggle("auto_mixed_precision_mkl")
1944 if rewrite_options.min_graph_nodes != 0:
1945 options["min_graph_nodes"] = rewrite_options.min_graph_nodes
1947 return options
1949 def set_optimizer_experimental_options(self, options):
1950 """Set experimental options for the optimizer.
1952 Args:
1953 options: Dictionary of options to modify
1954 """
1955 self._optimizer_experimental_options.update(options)
1957 self._thread_local_data.function_call_options = None
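These getters and setters back `tf.config.optimizer.get_experimental_options` and `tf.config.optimizer.set_experimental_options`; a hedged sketch toggling two of the Grappler rewriters listed above:

```python
import tensorflow as tf

tf.config.optimizer.set_experimental_options(
    {"layout_optimizer": True, "constant_folding": False})
print(tf.config.optimizer.get_experimental_options())
```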
1959 @property
1960 def intra_op_parallelism_threads(self):
1961 return self.config.intra_op_parallelism_threads
1963 @intra_op_parallelism_threads.setter
1964 def intra_op_parallelism_threads(self, num_threads):
1965 if self._intra_op_parallelism_threads == num_threads:
1966 return
1968 if self._context_handle is not None:
1969 raise RuntimeError(
1970 "Intra op parallelism cannot be modified after initialization.")
1972 self._intra_op_parallelism_threads = num_threads
1974 @property
1975 def inter_op_parallelism_threads(self):
1976 return self.config.inter_op_parallelism_threads
1978 @inter_op_parallelism_threads.setter
1979 def inter_op_parallelism_threads(self, num_threads):
1980 if self._inter_op_parallelism_threads == num_threads:
1981 return
1983 if self._context_handle is not None:
1984 raise RuntimeError(
1985 "Inter op parallelism cannot be modified after initialization.")
1987 self._inter_op_parallelism_threads = num_threads
1989 @property
1990 def soft_device_placement(self):
1991 return self.config.allow_soft_placement
1993 @soft_device_placement.setter
1994 def soft_device_placement(self, enable):
1995 if self._context_handle is not None:
1996 pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable)
1998 self._soft_device_placement = enable
1999 self._thread_local_data.function_call_options = None
2001 @property
2002 def log_device_placement(self):
2003 return self.config.log_device_placement
2005 @log_device_placement.setter
2006 def log_device_placement(self, enable):
2007 if self._context_handle is not None:
2008 pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable)
2010 self._log_device_placement = enable
2011 self._thread_local_data.function_call_options = None
2013 @property
2014 def jit_compile_rewrite(self):
2015 return self._jit_compile_rewrite
2017 @jit_compile_rewrite.setter
2018 def jit_compile_rewrite(self, enable):
2019 if self._context_handle is not None:
2020 pywrap_tfe.TFE_ContextSetJitCompileRewrite(self._handle, enable)
2021 self._jit_compile_rewrite = enable
2023 @property
2024 def device_policy(self):
2025 # Only get the policy from the context if it has already been initialized
2026 if self._context_handle is not None:
2027 return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle)
2029 return self._device_policy
2031 @device_policy.setter
2032 def device_policy(self, policy):
2033 if policy is None:
2034 policy = DEVICE_PLACEMENT_SILENT
2036 if self._device_policy != policy:
2037 self._device_policy = policy
2039 # Only set the policy if the context has already been initialized
2040 if self._context_handle is not None:
2041 pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy(
2042 self._handle, self._device_policy)
2044 @property
2045 def use_tfrt(self):
2046 return self._use_tfrt
2048 @use_tfrt.setter
2049 def use_tfrt(self, tfrt):
2050 """Sets whether to use TFRT."""
2051 if not isinstance(tfrt, bool):
2052 raise ValueError("Expecting a boolean but got %s" % type(tfrt))
2054 if self._use_tfrt != tfrt:
2055 if self._initialized:
2056 raise ValueError("use_tfrt should be set before being initialized.")
2057 self._use_tfrt = tfrt
2059 @property
2060 def operation_timeout_in_ms(self):
2061 return self.config.operation_timeout_in_ms
2063 @operation_timeout_in_ms.setter
2064 def operation_timeout_in_ms(self, timeout_in_ms):
2065 if self._operation_timeout_in_ms == timeout_in_ms:
2066 return
2068 if self._context_handle is not None:
2069 raise RuntimeError(
2070 "Operation timeout cannot be modified after initialization.")
2072 self._operation_timeout_in_ms = timeout_in_ms
2074 def enable_run_metadata(self):
2075 """Enables tracing of op execution via RunMetadata.
2077 To retrieve the accumulated metadata call context.export_run_metadata()
2078 and to stop tracing call context.disable_run_metadata().
2079 """
2080 self.ensure_initialized()
2081 pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle)
2083 def disable_run_metadata(self):
2084 """Disables tracing of op execution via RunMetadata."""
2085 if not self._context_handle:
2086 return
2087 pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle)
2089 def enable_graph_collection(self):
2090 """Enables graph collection of executed functions.
2092 To retrieve the accumulated graphs call context.export_run_metadata()
2093 and to stop collecting graphs call context.disable_graph_collection().
2094 """
2095 self.ensure_initialized()
2096 pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle)
2098 def disable_graph_collection(self):
2099 """Disables graph collection of executed functions."""
2100 if not self._context_handle:
2101 return
2102 pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle)
2104 def export_run_metadata(self):
2105 """Returns a RunMetadata proto with accumulated information.
2107 The returned protocol buffer contains information since the most recent call
2108 to either enable_run_metadata or export_run_metadata.
2110 Returns:
2111 A RunMetadata protocol buffer. Or None if not enabled.
2112 """
2113 if not self._context_handle:
2114 return None
2115 with c_api_util.tf_buffer() as buffer_:
2116 pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_)
2117 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
2118 run_metadata = config_pb2.RunMetadata()
2119 run_metadata.ParseFromString(compat.as_bytes(proto_data))
2120 return run_metadata
2122 @property
2123 def context_switches(self):
2124 """Returns a stack of context switches."""
2125 return self._context_switches
2128class _EagerDeviceContext(object):
2129 """Context-manager forcing placement of ops and Tensors on a device."""
2131 __slots__ = ["_device_name", "_ctx", "_stack"]
2133 def __init__(self, ctx, device_name):
2134 self._device_name = device_name
2135 self._ctx = ctx
2136 self._stack = []
2138 # TODO(b/189233748): Consolidate the device string parsing logic with
2139 # tensorflow/core/util/device_name_utils.cc.
2140 def __enter__(self):
2141 ctx = self._ctx
2142 old_device_name = ctx.device_name
2143 old_device_spec = ctx.device_spec
2144 new_device_name = self._device_name
2145 cache_key = (old_device_name, new_device_name)
2146 try:
2147 new_device_name, new_device_spec = _device_parsing_cache[cache_key]
2148 except TypeError:
2149 # Error while trying to compute the cache key.
2150 raise ValueError("Expecting a string device name. Got %s(%s)" %
2151 (type(new_device_name), new_device_name))
2152 except KeyError:
2153 # Handle a cache miss.
2154 if new_device_name is not None:
2155 if not isinstance(new_device_name, str):
2156 raise ValueError("Expecting a string device name. Got %s(%s)" %
2157 (type(new_device_name), new_device_name))
2158 device_spec = pydev.DeviceSpec.from_string(new_device_name)
2159 if old_device_name:
2160 new_device_spec = copy.copy(old_device_spec)
2161 else:
2162 ctx.ensure_initialized()
2163 new_device_spec = pydev.DeviceSpec.from_string(
2164 ctx._context_devices[0]) # pylint: disable=protected-access
2165 new_device_spec = new_device_spec.make_merged_spec(device_spec)
2166 else:
2167 new_device_spec = pydev.DeviceSpec.from_string("")
2168 new_device_name = new_device_spec.to_string()
2169 _device_parsing_cache[cache_key] = (new_device_name, new_device_spec)
2171 ctx._set_device(new_device_name, new_device_spec) # pylint: disable=protected-access
2172 self._stack.append((old_device_name, old_device_spec, new_device_spec))
2174 def __exit__(self, *ex_info):
2175 ctx = self._ctx
2176 old_device_name, old_device_spec, new_device_spec = self._stack[-1]
2177 if ctx.device_spec is not new_device_spec:
2178 raise RuntimeError("Exiting device scope without proper scope nesting")
2179 del self._stack[-1]
2180 ctx._set_device(old_device_name, old_device_spec) # pylint: disable=protected-access
2183# Do not change directly.
2184_context = None
2185_context_lock = threading.Lock()
2188def _set_context_locked(ctx):
2189 global _context
2190 pywrap_tfe.TFE_Py_SetEagerContext(ctx)
2191 ctx.mark_as_global_context()
2192 _context = ctx
2195def _set_context(ctx):
2196 with _context_lock:
2197 _set_context_locked(ctx)
2200def _create_context():
2201 with _context_lock:
2202 if _context is None:
2203 ctx = Context()
2204 _set_context_locked(ctx)
2207def _reset_context():
2208 """Clears and re-initializes the singleton context.
2210 Should only be used for testing.
2211 """
2212 global _context
2213 global _device_parsing_cache
2215 # Garbage collect and clear scalar cache to avoid Tensor from current context
2216 # polluting next context.
2217 gc.collect()
2218 pywrap_tfe.TFE_ClearScalarCache()
2219 with _context_lock:
2220 if _context is not None:
2221 _context._clear_caches()
2222 _context = None
2223 _create_context()
2224 _device_parsing_cache = {}
2227def _reset_jit_compiler_flags():
2228 """Clears and re-initializes the TF JIT compiler flags.
2230 Should only be used for testing.
2231 """
2232 pywrap_tfe.TF_ResetJitCompilerFlags()
2235def context():
2236 """Returns a singleton context object."""
2237 if _context is None:
2238 _create_context()
2239 return _context
2242def context_safe():
2243 """Returns current context (or None if one hasn't been initialized)."""
2244 return _context
2247def ensure_initialized():
2248 """Initialize the context."""
2249 context().ensure_initialized()
2252def initialize_logical_devices():
2253 """Initialize the virtual devices."""
2254 context()._initialize_logical_devices() # pylint: disable=protected-access
2257def set_global_seed(seed):
2258 """Sets the eager mode seed."""
2259 context()._set_global_seed(seed) # pylint: disable=protected-access
2262def global_seed():
2263 """Returns the eager mode seed."""
2264 return context()._seed # pylint: disable=protected-access
2267def internal_operation_seed():
2268 """Returns the operation seed generated based on global seed."""
2269 return context()._internal_operation_seed() # pylint: disable=protected-access
2272@tf_export("executing_eagerly", v1=[])
2273def executing_eagerly():
2274 """Checks whether the current thread has eager execution enabled.
2276 Eager execution is enabled by default and this API returns `True`
2277 in most cases. However, this API might return `False` in the following use
2278 cases.
2280 * Executing inside `tf.function`, unless under `tf.init_scope` or
2281 `tf.config.run_functions_eagerly(True)` has previously been called.
2282 * Executing inside a transformation function for `tf.dataset`.
2283 * `tf.compat.v1.disable_eager_execution()` is called.
2285 General case:
2287 >>> print(tf.executing_eagerly())
2288 True
2290 Inside `tf.function`:
2292 >>> @tf.function
2293 ... def fn():
2294 ... with tf.init_scope():
2295 ... print(tf.executing_eagerly())
2296 ... print(tf.executing_eagerly())
2297 >>> fn()
2298 True
2299 False
2301 Inside `tf.function` after `tf.config.run_functions_eagerly(True)` is called:
2303 >>> tf.config.run_functions_eagerly(True)
2304 >>> @tf.function
2305 ... def fn():
2306 ... with tf.init_scope():
2307 ... print(tf.executing_eagerly())
2308 ... print(tf.executing_eagerly())
2309 >>> fn()
2310 True
2311 True
2312 >>> tf.config.run_functions_eagerly(False)
2314 Inside a transformation function for `tf.dataset`:
2316 >>> def data_fn(x):
2317 ... print(tf.executing_eagerly())
2318 ... return x
2319 >>> dataset = tf.data.Dataset.range(100)
2320 >>> dataset = dataset.map(data_fn)
2321 False
2323 Returns:
2324 `True` if the current thread has eager execution enabled.
2325 """
2326 ctx = context_safe()
2327 if ctx is None:
2328 return default_execution_mode == EAGER_MODE
2330 return ctx.executing_eagerly()
2333@tf_export(v1=["executing_eagerly"])
2334def executing_eagerly_v1():
2335 """Checks whether the current thread has eager execution enabled.
2337 Eager execution is typically enabled via
2338 `tf.compat.v1.enable_eager_execution`, but may also be enabled within the
2339 context of a Python function via tf.contrib.eager.py_func.
2341 When eager execution is enabled, returns `True` in most cases. However,
2342 this API might return `False` in the following use cases.
2344 * Executing inside `tf.function`, unless under `tf.init_scope` or
2345 `tf.config.run_functions_eagerly(True)` has previously been called.
2346 * Executing inside a transformation function for `tf.dataset`.
2347 * `tf.compat.v1.disable_eager_execution()` is called.
2349 >>> tf.compat.v1.enable_eager_execution()
2351 General case:
2353 >>> print(tf.executing_eagerly())
2354 True
2356 Inside `tf.function`:
2358 >>> @tf.function
2359 ... def fn():
2360 ... with tf.init_scope():
2361 ... print(tf.executing_eagerly())
2362 ... print(tf.executing_eagerly())
2363 >>> fn()
2364 True
2365 False
2367 Inside `tf.function`
2368 after `tf.config.run_functions_eagerly(True)` is called:
2370 >>> tf.config.run_functions_eagerly(True)
2371 >>> @tf.function
2372 ... def fn():
2373 ... with tf.init_scope():
2374 ... print(tf.executing_eagerly())
2375 ... print(tf.executing_eagerly())
2376 >>> fn()
2377 True
2378 True
2379 >>> tf.config.run_functions_eagerly(False)
2381 Inside a transformation function for `tf.dataset`:
2383 >>> def data_fn(x):
2384 ... print(tf.executing_eagerly())
2385 ... return x
2386 >>> dataset = tf.data.Dataset.range(100)
2387 >>> dataset = dataset.map(data_fn)
2388 False
2390 Returns:
2391 `True` if the current thread has eager execution enabled.
2392 """
2393 return executing_eagerly()
2396def in_eager_mode():
2397 """Use executing_eagerly() instead. This function will be removed."""
2398 return executing_eagerly()
2401def anonymous_name():
2402 """Returns the anonymous shared name.
2404 In eager mode we create anonymous resources to avoid spurious sharing issues.
2405 The runtime generates a unique name on our behalf when the reserved
2406 anonymous shared name is used as a shared name.
2408 Returns:
2409 The anonymous shared name.
2410 """
2412 # The magic value is defined as
2413 # `tensorflow::ResourceHandle::ANONYMOUS_NAME` in C++.
2414 return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
2417def graph_mode():
2418 """Context-manager to disable eager execution for the current thread."""
2419 return context()._mode(GRAPH_MODE) # pylint: disable=protected-access
2422# Used by b/167638505 for keras backend API and Lambda layer.
2423@tf_export("__internal__.eager_context.eager_mode", v1=[])
2424def eager_mode():
2425 """Context-manager to enable eager execution for the current thread."""
2426 return context()._mode(EAGER_MODE) # pylint: disable=protected-access
2429def scope_name():
2430 """Name of the current scope."""
2431 return context().scope_name
2434def device(name):
2435 """Context-manager to force placement of operations and Tensors on a device.
2437 Example:
2438 ```python
2439 with tf.device('gpu:0'):
2440 with tf.device('cpu:0'):
2441 shape = tf.constant([], dtype=tf.int32)
2442 x = tf.random.truncated_normal(shape, tf.float32)
2443 ```
2444 will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
2445 operation runs on GPU 0.
2447 Args:
2448 name: Name of the device (see context().devices()), or None to perform
2449 automatic placement.
2451 Returns:
2452 Context manager for setting the device.
2453 """
2454 ensure_initialized()
2455 return context().device(name)
2458# Expose some properties of Context as internally public APIs (b/160348781).
2459@tf_export("__internal__.eager_context.get_config", v1=[])
2460def get_config():
2461 """Get the ConfigProto of Context.
2463 Returns:
2464 The ConfigProto of Context.
2465 """
2466 return context().config
2469@tf_export("__internal__.eager_context.get_device_name", v1=[])
2470def get_device_name():
2471 """Get the device name for the current thread.
2473 Returns:
2474 The device name for the current thread.
2475 """
2476 return context().device_name
2479@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[])
2480def set_soft_device_placement(enabled):
2481 """Set if soft device placements should be allowed.
2483 Args:
2484 enabled: Whether to enable soft device placement.
2485 """
2486 context().soft_device_placement = enabled
2489@tf_export("__internal__.eager_context.get_executor", v1=[])
2490def get_executor():
2491 """Get the Executor of the current thread.
2493 Returns:
2494 The Executor of the current thread.
2495 """
2496 return context().executor
2499@tf_export("debugging.get_log_device_placement")
2500def get_log_device_placement():
2501 """Get if device placements are logged.
2503 Returns:
2504 If device placements are logged.
2505 """
2506 return context().log_device_placement
2509@tf_export("debugging.set_log_device_placement")
2510def set_log_device_placement(enabled):
2511 """Turns logging for device placement decisions on or off.
2513 Operations execute on a particular device, producing and consuming tensors on
2514 that device. This may change the performance of the operation or require
2515 TensorFlow to copy data to or from an accelerator, so knowing where operations
2516 execute is useful for debugging performance issues.
2518 For more advanced profiling, use the [TensorFlow
2519 profiler](https://www.tensorflow.org/guide/profiler).
2521 Device placement for operations is typically controlled by a `tf.device`
2522 scope, but there are exceptions, for example operations on a `tf.Variable`
2523 which follow the initial placement of the variable. Turning off soft device
2524 placement (with `tf.config.set_soft_device_placement`) provides more explicit
2525 control.
2527 >>> tf.debugging.set_log_device_placement(True)
2528 >>> tf.ones([])
2529 >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:GPU:0
2530 >>> with tf.device("CPU"):
2531 ... tf.ones([])
2532 >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:CPU:0
2533 >>> tf.debugging.set_log_device_placement(False)
2535 Turning on `tf.debugging.set_log_device_placement` also logs the placement of
2536 ops inside `tf.function` when the function is called.
2538 Args:
2539 enabled: Whether to enable device placement logging.
2540 """
2541 context().log_device_placement = enabled
2544@tf_contextlib.contextmanager
2545def device_policy(policy):
2546 """Context manager for setting device placement policy for current thread."""
2547 ctx = context()
2548 old_policy = ctx.device_policy
2549 try:
2550 ctx.device_policy = policy
2551 yield
2552 finally:
2553 ctx.device_policy = old_policy
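A minimal sketch of the internal `device_policy` helper defined above, using one of the placement-policy constants exported by this module:

```python
from tensorflow.python.eager import context

with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
  # Ops here may silently copy inputs between devices instead of raising.
  pass
```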
2556def set_execution_mode(mode):
2557 """Sets execution mode for the current thread."""
2558 context().execution_mode = mode
2561# TODO(fishx): remove this method.
2562@tf_contextlib.contextmanager
2563def execution_mode(mode):
2564 """Context manager for setting execution mode for current thread."""
2565 if mode is None:
2566 yield
2567 else:
2568 ctx = context()
2569 executor_new = executor.new_executor(mode == ASYNC)
2570 executor_old = ctx.executor
2571 try:
2572 executor_old.wait()
2573 ctx.executor = executor_new
2574 yield
2575 finally:
2576 ctx.executor = executor_old
2577 executor_new.wait()
2580@tf_contextlib.contextmanager
2581def executor_scope(e):
2582 """Context manager for changing executor for current thread.
2584 Args:
2585 e: An Executor to execute eager ops under this scope. Setting it to None will
2586 switch back to use the default executor for the context.
2588 Yields:
2589 Context manager for setting the executor for current thread.
2590 """
2591 ctx = context()
2592 executor_old = ctx.executor
2593 try:
2594 ctx.executor = e
2595 yield
2596 finally:
2597 ctx.executor = executor_old
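A sketch of `executor_scope`, pairing it with `executor.new_executor` from the module imported at the top of this file (using an async executor is an illustrative choice):

```python
from tensorflow.python.eager import context, executor

async_executor = executor.new_executor(enable_async=True)
with context.executor_scope(async_executor):
  pass  # eager ops issued here are dispatched through async_executor
async_executor.wait()  # block until everything dispatched has finished
```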
2600@tf_export("experimental.function_executor_type")
2601@tf_contextlib.contextmanager
2602def function_executor_type(executor_type):
2603 """Context manager for setting the executor of eager defined functions.
2605 Eager defined functions are functions decorated by tf.contrib.eager.defun.
2607 Args:
2608 executor_type: a string for the name of the executor to be used to execute
2609 functions defined by tf.contrib.eager.defun.
2611 Yields:
2612 Context manager for setting the executor of eager defined functions.
2613 """
2614 current_options = context().function_call_options
2615 old_options = copy.copy(current_options)
2616 try:
2617 current_options.executor_type = executor_type
2618 yield
2619 finally:
2620 context().function_call_options = old_options
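Exported as `tf.experimental.function_executor_type`; a hedged sketch, with the executor name given only as an example of a value the runtime may accept:

```python
import tensorflow as tf

@tf.function
def add_one(x):
  return x + 1

# "SINGLE_THREADED_EXECUTOR" is an example executor-type string.
with tf.experimental.function_executor_type("SINGLE_THREADED_EXECUTOR"):
  add_one(tf.constant(1))
```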
2623def is_async():
2624 """Returns true if current thread is in async mode."""
2625 return context().is_async()
2628def num_gpus():
2629 """Get the number of available GPU devices.
2631 Returns:
2632 The number of available GPU devices.
2633 """
2634 return context().num_gpus()
2637def enable_run_metadata():
2638 """Enables tracing of op execution via RunMetadata.
2640 To retrieve the accumulated metadata call context.export_run_metadata()
2641 and to stop tracing call context.disable_run_metadata().
2642 """
2643 context().enable_run_metadata()
2646def disable_run_metadata():
2647 """Disables tracing of op execution via RunMetadata."""
2648 context().disable_run_metadata()
2651def enable_graph_collection():
2652 """Enables graph collection of executed functions.
2654 To retrieve the accumulated graphs call context.export_run_metadata()
2655 and to stop collecting graphs call context.disable_graph_collection().
2656 """
2657 context().enable_graph_collection()
2660def disable_graph_collection():
2661 """Disables graph collection of executed functions."""
2662 context().disable_graph_collection()
2665def export_run_metadata():
2666 """Returns a RunMetadata proto with accumulated information.
2668 The returned protocol buffer contains information since the most recent call
2669 to either enable_run_metadata or export_run_metadata.
2671 Returns:
2672 A RunMetadata protocol buffer.
2673 """
2674 return context().export_run_metadata()
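A minimal sketch of the tracing round-trip provided by the three module-level helpers above (enable, export, disable), assuming an eager context:

```python
import tensorflow as tf
from tensorflow.python.eager import context

context.enable_run_metadata()
tf.matmul(tf.ones([2, 2]), tf.ones([2, 2]))
run_metadata = context.export_run_metadata()  # config_pb2.RunMetadata proto
context.disable_run_metadata()
```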
2677@contextlib.contextmanager
2678def collect_graphs(optimized=True):
2679 """Collects a flat list of pre- or post-optimization graphs.
2681 The collected graphs include device placements, which can be useful for
2682 testing.
2684 Usage:
2686 ```
2687 @def_function.function
2688 def f(x):
2689 return x + constant_op.constant(1.)
2691 with context.collect_graphs() as graphs:
2692 with ops.device("CPU:0"):
2693 f(constant_op.constant(1.))
2695 graph, = graphs # `graph` contains a single GraphDef for inspection
2696 ```
2698 Args:
2699 optimized: whether to collect optimized graphs or non-optimized graphs
2701 Yields:
2702 A list of GraphDefs, populated when the context manager exits.
2703 """
2704 ctx = context()
2705 ctx.enable_graph_collection()
2706 try:
2707 graphs = []
2708 yield graphs
2709 metadata = ctx.export_run_metadata()
2710 finally:
2711 ctx.disable_graph_collection()
2712 for graph in metadata.function_graphs:
2713 if optimized:
2714 graphs.append(graph.post_optimization_graph)
2715 else:
2716 graphs.append(graph.pre_optimization_graph)
2719def get_server_def():
2720 return context().get_server_def()
2723def set_server_def(server_def):
2724 context().set_server_def(server_def)
2727def update_server_def(server_def):
2728 context().update_server_def(server_def)
2731def check_alive(worker_name):
2732 return context().check_alive(worker_name)
2735@tf_export("experimental.async_scope")
2736@tf_contextlib.contextmanager
2737def async_scope():
2738 """Context manager for grouping async operations.
2740 Ops/function calls inside the scope can return before finishing the actual
2741 execution. When exiting the async scope, a synchronization barrier will be
2742 automatically added to ensure the completion of all async op and function
2743 execution, potentially raising exceptions if async execution results in
2744 an error state.
2746 Users may write the following code to asynchronously invoke `train_step_fn`
2747 and log the `loss` metric for every `num_steps` steps in a training loop.
2748 `train_step_fn` internally consumes data using `iterator.get_next()`, and may
2749 throw OutOfRangeError when running out of data. In that case:
2751 ```
2752 try:
2753 with tf.experimental.async_scope():
2754 for _ in range(num_steps):
2755 # Step function updates the metric `loss` internally
2756 train_step_fn()
2757 except tf.errors.OutOfRangeError:
2758 tf.experimental.async_clear_error()
2759 logging.info('loss = %s', loss.numpy())
2760 ```
2762 Yields:
2763 Context manager for grouping async operations.
2764 """
2765 # TODO(haoyuzhang): replace env var once we have a config method to turn on
2766 # and off async streaming RPC
2767 remote_async_env_var = "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
2768 old_policy = os.environ.get(remote_async_env_var)
2769 try:
2770 os.environ[remote_async_env_var] = str(True)
2771 yield
2772 # Note: sync local and remote executors iff the async block does not raise
2773 # an exception. Triggering sync after an exception may lead to derived
2774 # runtime errors and unexpected exception types.
2775 context().sync_executors()
2776 finally:
2777 if old_policy is None:
2778 del os.environ[remote_async_env_var]
2779 else:
2780 os.environ[remote_async_env_var] = old_policy
2783def async_wait():
2784 """Sync all async operations and raise any errors during execution.
2786 In async execution mode, an op/function call can return before finishing the
2787 actual execution. Calling this method creates a synchronization barrier for
2788 all async op and function execution. It only returns when all pending nodes
2789 are finished, potentially raising exceptions if async execution results in
2790 an error state. It is a no-op if the context is not initialized.
2791 """
2792 disable_async_executor_env_var = "TF_PS_DISABLE_ASYNC_EXECUTOR_GLOBALLY"
2793 if os.environ.get(disable_async_executor_env_var) == str(True):
2794 return
2795 if context()._context_handle is not None: # pylint: disable=protected-access
2796 context().sync_executors()
2799@tf_export("experimental.async_clear_error")
2800def async_clear_error():
2801 """Clear pending operations and error statuses in async execution.
2803 In async execution mode, an error in op/function execution can lead to errors
2804 in subsequent ops/functions that are scheduled but not yet executed. Calling
2805 this method clears all pending operations and resets the async execution state.
2807 Example:
2809 ```
2810 while True:
2811 try:
2812 # Step function updates the metric `loss` internally
2813 train_step_fn()
2814 except tf.errors.OutOfRangeError:
2815 tf.experimental.async_clear_error()
2816 break
2817 logging.info('loss = %s', loss.numpy())
2818 ```
2819 """
2820 context().clear_executor_errors()
2823def add_c_function(c_func):
2824 """Add a C API TF_Function to the context."""
2825 context().add_c_function(c_func)
2828def get_c_function(name):
2829 """Get a C API TF_Function from the context."""
2830 return context().get_c_function(name)
2833def remove_function(name):
2834 """Remove a function from the context."""
2835 context().remove_function(name)
2838def get_function_def(name):
2839 return context().get_function_def(name)
2842def is_custom_device(device_name):
2843 """Calls TFE_IsCustomDevice.
2845 Enables using C extensions specifying a custom device from Python. See the
2846 experimental eager C API in tensorflow/c/eager/c_api_experimental.h for
2847 details.
2849 Args:
2850 device_name: A string indicating the name to check whether it is a
2851 registered custom device.
2853 Returns:
2854 A boolean.
2855 """
2856 return context().is_custom_device(device_name)
2859def register_custom_device(device_capsule, device_name, device_info_capsule):
2860 """Calls TFE_RegisterCustomDevice to register a custom device with Python.
2862 Enables using C extensions specifying a custom device from Python. See the
2863 experimental eager C API in tensorflow/c/eager/c_api_experimental.h for
2864 details.
2866 Note that custom devices are not currently supported inside `tf.function`s.
2868 Args:
2869 device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice'
2870 containing a pointer to a TFE_CustomDevice struct. The capsule retains
2871 ownership of the memory.
2872 device_name: A string indicating the name to register the custom device
2873 under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may
2874 subsequently be passed to `with tf.device(...):`.
2875 device_info_capsule: A PyCapsule with the name set to
2876 'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific
2877 struct with the initial state of the custom device (the void* device_info
2878 argument to TFE_RegisterCustomDevice). This method takes ownership of the
2879 memory and clears the capsule destructor.
2880 """
2881 context().register_custom_device(device_capsule, device_name,
2882 device_info_capsule)
2885# Not every user creates a Context via context.context()
2886# (for example, enable_eager_execution in python/framework/ops.py),
2887# but they do all import this file. Note that IS_IN_GRAPH_MODE and
2888# in_graph_mode are both parameterless functions.
2889def _tmp_in_graph_mode():
2890 if context_safe() is None:
2891 # Context not yet initialized. Assume graph mode following the
2892 # default implementation in `is_in_graph_mode`.
2893 return True
2894 return not executing_eagerly()
2897is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode