Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/context.py: 45%
1252 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""State management for eager execution."""
17import collections
18import contextlib
19import copy
20import gc
21import itertools
22import os
23import random
24import threading
26from absl import logging
27import numpy as np
29from tensorflow.core.framework import function_pb2
30from tensorflow.core.protobuf import config_pb2
31from tensorflow.core.protobuf import rewriter_config_pb2
32from tensorflow.python import pywrap_tfe
33from tensorflow.python import tf2
34from tensorflow.python.client import pywrap_tf_session
35from tensorflow.python.eager import cancellation
36from tensorflow.python.eager import execute
37from tensorflow.python.eager import executor
38from tensorflow.python.eager import monitoring
39from tensorflow.python.framework import c_api_util
40from tensorflow.python.framework import device as pydev
41from tensorflow.python.framework import tfrt_utils
42from tensorflow.python.util import compat
43from tensorflow.python.util import function_utils
44from tensorflow.python.util import is_in_graph_mode
45from tensorflow.python.util import tf_contextlib
46from tensorflow.python.util.deprecation import deprecated
47from tensorflow.python.util.tf_export import tf_export
48from tensorflow.tsl.protobuf import coordination_config_pb2
50GRAPH_MODE = 0
51EAGER_MODE = 1
53default_execution_mode = EAGER_MODE if tf2.enabled() else GRAPH_MODE
55# Cache from (old_device_name, partial_new_device_name) -> (new_device_name,
56# new_device_spec).
57# Note that we do not protect this with a lock and instead rely on python's GIL
58# and the idempotent nature of writes to provide thread safety.
59_device_parsing_cache = {}
60_starting_device_spec = pydev.DeviceSpec.from_string("")
62_MAXINT32 = 2**31 - 1
64DEVICE_PLACEMENT_EXPLICIT = pywrap_tfe.TFE_DEVICE_PLACEMENT_EXPLICIT
65DEVICE_PLACEMENT_WARN = pywrap_tfe.TFE_DEVICE_PLACEMENT_WARN
66DEVICE_PLACEMENT_SILENT = pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT
67DEVICE_PLACEMENT_SILENT_FOR_INT32 = (
68 pywrap_tfe.TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
70SYNC = 0
71ASYNC = 1
73_KEEP_ALIVE_SECS = 600
75_python_eager_context_create_counter = monitoring.Counter(
76 "/tensorflow/api/python/eager_context_create_counter",
77 "Counter for number of eager contexts created in Python.")
79# Re-exporting through context.
80is_tfrt_enabled = tfrt_utils.enabled
82# This flag and the associated environment var are transient and will eventually
83# be removed, once this experiment is enabled by default.
84_JIT_COMPILE_REWRITE_ENABLED = os.getenv("TF_JIT_COMPILE_REWRITE") == "1"
87def run_eager_op_as_function_enabled():
88 return True
91 # This method should only be called after the context has been initialized.
92def enable_jit_compile_rewrite():
93 """Run jit_compile functions through rewrite pass.
95 This runs jit_compile functions through all of the multidevice function
96 rewrite passes.
97 """
98 global _JIT_COMPILE_REWRITE_ENABLED
99 _JIT_COMPILE_REWRITE_ENABLED = True
100 if context_safe() is not None:
101 context_safe().jit_compile_rewrite = True
104# This method should only be called after the context has been initialized.
105def disable_jit_compile_rewrite():
106 global _JIT_COMPILE_REWRITE_ENABLED
107 _JIT_COMPILE_REWRITE_ENABLED = False
108 if context_safe() is not None:
109 context_safe().jit_compile_rewrite = False
112def jit_compile_rewrite_enabled():
113 if context_safe() is not None:
114 return context_safe().jit_compile_rewrite
115 return _JIT_COMPILE_REWRITE_ENABLED
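# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): how the jit_compile
# rewrite toggles defined above are expected to behave. `_sketch_jit_rewrite_toggle`
# is a hypothetical helper added purely for illustration.
def _sketch_jit_rewrite_toggle():
  # Before a context exists, the toggles only flip the module-level flag;
  # once a context has been created they also update
  # `context_safe().jit_compile_rewrite`.
  enable_jit_compile_rewrite()
  assert jit_compile_rewrite_enabled()
  disable_jit_compile_rewrite()
  assert not jit_compile_rewrite_enabled()
# ---------------------------------------------------------------------------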
118# Expose it as internally public APIs for Keras use cases in b/171080602.
119tf_export("__internal__.is_tfrt_enabled", v1=[])(is_tfrt_enabled)
122class _EagerTensorCache(object):
123 """Simple cache which evicts items based on length in a FIFO manner."""
125 __slots__ = ["_data", "_max_items", "_max_tensor_size"]
127 def __init__(self, max_items=256, max_tensor_size=10000):
128 self._data = collections.OrderedDict()
129 self._max_items = max_items
130 self._max_tensor_size = max_tensor_size
132 def put(self, key, value):
133 if value._num_elements() > self._max_tensor_size: # pylint: disable=protected-access
134 return
136 self._data[key] = value
138 if len(self._data) > self._max_items:
139 self._data.popitem(last=False)
141 def get(self, key):
142 return self._data.get(key, None)
144 def flush(self):
145 self._data.clear()
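# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): the FIFO eviction
# behaviour of `_EagerTensorCache`. `_FakeTensor` is a hypothetical stand-in
# that only provides the private `_num_elements()` hook the cache relies on.
def _sketch_eager_tensor_cache():
  class _FakeTensor(object):
    def __init__(self, num_elements):
      self._n = num_elements

    def _num_elements(self):
      return self._n

  cache = _EagerTensorCache(max_items=2, max_tensor_size=10)
  cache.put("a", _FakeTensor(1))
  cache.put("b", _FakeTensor(1))
  cache.put("c", _FakeTensor(1))    # evicts "a" (oldest entry, FIFO)
  cache.put("d", _FakeTensor(100))  # exceeds max_tensor_size; silently ignored
  assert cache.get("a") is None and cache.get("d") is None
  assert cache.get("b") is not None and cache.get("c") is not None
  cache.flush()
  assert cache.get("b") is None
# ---------------------------------------------------------------------------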
148class FunctionCallOptions:
149 """Options applied at call sites of eager functions.
151 Eager functions are functions decorated with tf.contrib.eager.defun.
152 """
154 __slots__ = ["_config_proto_serialized", "_executor_type"]
156 def __init__(self, executor_type=None, config_proto=None):
157 """Constructor.
159 Args:
160 executor_type: (optional) name of the executor to be used to execute the
161 eager function. If None or an empty string, the default Tensorflow
162 executor will be used.
163 config_proto: (optional) a `config_pb2.ConfigProto` proto or a serialized
164 string of that proto. The config used by Grappler when optimizing the
165 function graph. Each concrete function is optimized the first time it is
166 called. Changing config_proto after the first call has no effect. If
167 config_proto is None, an empty RewriterConfig will be used.
168 """
169 self.config_proto_serialized = config_proto
170 self.executor_type = executor_type
172 @property
173 def executor_type(self):
174 return self._executor_type
176 @executor_type.setter
177 def executor_type(self, executor_type):
178 self._executor_type = executor_type
180 @property
181 def config_proto_serialized(self):
182 return self._config_proto_serialized
184 @config_proto_serialized.setter
185 def config_proto_serialized(self, config):
186 if isinstance(config, config_pb2.ConfigProto):
187 self._config_proto_serialized = config.SerializeToString(
188 deterministic=True)
189 elif isinstance(config, str):
190 self._config_proto_serialized = config
191 elif config is None:
192 self._config_proto_serialized = (
193 config_pb2.ConfigProto().SerializeToString())
194 else:
195 raise ValueError("The rewriter config must be either a "
196 "config_pb2.ConfigProto, a serialized string of that "
197 "proto, or None. Got: {}".format(type(config)))
199 def as_attrs(self):
200 if self.config_proto_serialized is None:
201 config = function_utils.get_disabled_rewriter_config()
202 else:
203 config = self.config_proto_serialized
204 executor_type = self.executor_type or ""
206 return {"executor_type": executor_type, "config_proto": config}
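# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): the accepted forms
# of `config_proto` and the attribute dict produced by `as_attrs()`.
# `_sketch_function_call_options` is a hypothetical helper for illustration.
def _sketch_function_call_options():
  explicit = config_pb2.ConfigProto(allow_soft_placement=True)
  opts = FunctionCallOptions(executor_type=None, config_proto=explicit)
  attrs = opts.as_attrs()
  assert set(attrs) == {"executor_type", "config_proto"}
  assert attrs["executor_type"] == ""  # None maps to the default executor.
  assert isinstance(attrs["config_proto"], bytes)
  # Passing None is also accepted; it is stored as an empty serialized
  # ConfigProto rather than raising.
  FunctionCallOptions(config_proto=None)
# ---------------------------------------------------------------------------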
209# Map from context_id (an int) to _TensorCaches.
210# Dicts are thread safe in CPython.
211# TODO(iga): Remove this once TensorCaches are moved to C++.
212_tensor_caches_map = {}
215class _TensorCaches(threading.local):
216 """Thread local tensor caches."""
218 __slots__ = ["_ones_rank_cache", "_zeros_cache"]
220 def __init__(self):
221 super().__init__()
222 self._ones_rank_cache = None
223 self._zeros_cache = None
225 @property
226 def ones_rank_cache(self):
227 if not self._ones_rank_cache:
228 self._ones_rank_cache = _EagerTensorCache()
229 return self._ones_rank_cache
231 @property
232 def zeros_cache(self):
233 if not self._zeros_cache:
234 self._zeros_cache = _EagerTensorCache()
235 return self._zeros_cache
238ContextSwitch = collections.namedtuple(
239 "ContextSwitch",
240 ["is_building_function", "enter_context_fn", "device_stack"])
243# `_ContextSwitchStack` is a `threading.local` to match the semantics of
244 # `_DefaultGraphStack`, which is also a `threading.local`.
245class _ContextSwitchStack(threading.local):
246 """A thread-local stack of context switches."""
248 def __init__(self, eager):
249 super().__init__()
250 self.stack = []
251 if eager:
252 # Initialize the stack with a pointer to enter the eager context; this
253 # ensures that the fact that eager execution was enabled is propagated
254 # across threads, since (1) `enable_eager_execution` modifies a
255 # process-level flag (`default_execution_mode`) and (2) `__init__` is
256 # called each time a threading.local object is used in a separate thread.
257 self.push(
258 is_building_function=False,
259 enter_context_fn=eager_mode,
260 device_stack=None)
262 def push(self, is_building_function, enter_context_fn, device_stack):
263 """Push metadata about a context switch onto the stack.
265 A context switch can take one of two forms: installing a graph as
266 the default graph, or entering the eager context. For each context switch,
267 we record whether or not the entered context is building a function.
269 Args:
270 is_building_function: (bool.) Whether the context is building a function.
271 enter_context_fn: (function.) A callable that executes the context switch.
272 For example, `graph.as_default` or `eager_mode`.
273 device_stack: If applicable, the device function stack for this graph.
274 When breaking out of graphs in init_scope, the innermost nonempty device
275 stack is used. Eager contexts put `None` here and the value is never
276 used.
277 """
279 self.stack.append(
280 ContextSwitch(is_building_function, enter_context_fn, device_stack))
282 def pop(self):
283 """Pop the stack."""
285 self.stack.pop()
288@tf_export("config.LogicalDevice")
289class LogicalDevice(
290 collections.namedtuple("LogicalDevice", ["name", "device_type"])):
291 """Abstraction for a logical device initialized by the runtime.
293 A `tf.config.LogicalDevice` corresponds to an initialized logical device on a
294 `tf.config.PhysicalDevice` or a remote device visible to the cluster. Tensors
295 and operations can be placed on a specific logical device by calling
296 `tf.device` with a specified `tf.config.LogicalDevice`.
298 Fields:
299 name: The fully qualified name of the device. Can be used for Op or function
300 placement.
301 device_type: String declaring the type of device such as "CPU" or "GPU".
302 """
303 pass
306@tf_export("config.LogicalDeviceConfiguration",
307 "config.experimental.VirtualDeviceConfiguration")
308class LogicalDeviceConfiguration(
309 collections.namedtuple("LogicalDeviceConfiguration", [
310 "memory_limit", "experimental_priority", "experimental_device_ordinal"
311 ])):
312 """Configuration class for a logical devices.
314 The class specifies the parameters to configure a `tf.config.PhysicalDevice`
315 as it is initialized to a `tf.config.LogicalDevice` during runtime
316 initialization. Not all fields are valid for all device types.
318 See `tf.config.get_logical_device_configuration` and
319 `tf.config.set_logical_device_configuration` for usage examples.
321 Fields:
322 memory_limit: (optional) Maximum memory (in MB) to allocate on the virtual
323 device. Currently only supported for GPUs.
324 experimental_priority: (optional) Priority to assign to a virtual device.
325 Lower values have higher priorities and 0 is the default.
326 Within a physical GPU, the GPU scheduler will prioritize ops on virtual
327 devices with higher priority. Currently only supported for Nvidia GPUs.
328 experimental_device_ordinal: (optional) Ordinal number to order the virtual
329 device.
330 A LogicalDevice with a lower ordinal number will receive a lower device id.
331 Physical device id and location in the list are used to break ties.
332 Currently only supported for Nvidia GPUs.
333 """
335 def __new__(cls,
336 memory_limit=None,
337 experimental_priority=None,
338 experimental_device_ordinal=None):
339 return super().__new__(cls, memory_limit, experimental_priority,
340 experimental_device_ordinal)
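# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): splitting one
# physical GPU into two logical devices through the public tf.config wrappers
# around this class, as referenced in the docstring above. Assumes a machine
# with at least one GPU and that configuration happens before the runtime is
# initialized; the 1024 MB memory limits are illustrative.
def _sketch_logical_device_configuration():
  import tensorflow as tf  # assumption: the full TF API is importable here

  gpus = tf.config.list_physical_devices("GPU")
  if gpus:
    tf.config.set_logical_device_configuration(
        gpus[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
         tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
    logical_gpus = tf.config.list_logical_devices("GPU")
    assert len(logical_gpus) >= 2
# ---------------------------------------------------------------------------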
343@tf_export("config.PhysicalDevice")
344class PhysicalDevice(
345 collections.namedtuple("PhysicalDevice", ["name", "device_type"])):
346 """Abstraction for a locally visible physical device.
348 TensorFlow can utilize various devices such as the CPU or multiple GPUs
349 for computation. Before initializing a local device for use, the user can
350 customize certain properties of the device such as its visibility or memory
351 configuration.
353 Once a visible `tf.config.PhysicalDevice` is initialized, one or more
354 `tf.config.LogicalDevice` objects are created. Use
355 `tf.config.set_visible_devices` to configure the visibility of a physical
356 device and `tf.config.set_logical_device_configuration` to configure multiple
357 `tf.config.LogicalDevice` objects for a `tf.config.PhysicalDevice`. This is
358 useful when separation between models is needed or to simulate a multi-device
359 environment.
361 Fields:
362 name: Unique identifier for device.
363 device_type: String declaring the type of device such as "CPU" or "GPU".
364 """
365 pass
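# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): restricting runtime
# visibility to a subset of physical devices, as the docstring above
# describes. Uses the public tf.config wrappers and assumes the configuration
# runs before the runtime is initialized.
def _sketch_physical_device_visibility():
  import tensorflow as tf  # assumption: the full TF API is importable here

  physical_gpus = tf.config.list_physical_devices("GPU")
  if len(physical_gpus) > 1:
    # Hide every GPU except the first one from the TensorFlow runtime.
    tf.config.set_visible_devices(physical_gpus[:1], "GPU")
    assert tf.config.get_visible_devices("GPU") == physical_gpus[:1]
# ---------------------------------------------------------------------------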
368class _AtomicCounter(object):
369 """A simple atomic counter."""
371 __slots__ = ["_value", "_lock"]
373 def __init__(self):
374 self._value = 0
375 self._lock = threading.Lock()
377 def increment_and_get(self):
378 with self._lock:
379 self._value += 1
380 return self._value
383_context_id_counter = _AtomicCounter()
386class _TensorCacheDeleter(object):
387 """Deletes tensor caches for a given context."""
389 __slots__ = ["_context_id"]
391 def __init__(self, context_id):
392 self._context_id = context_id
394 def __del__(self):
395 if _tensor_caches_map is None:
396 return
397 if self._context_id in _tensor_caches_map:
398 del _tensor_caches_map[self._context_id]
401# TODO(agarwal): rename to EagerContext / EagerRuntime ?
402# TODO(agarwal): consider keeping the corresponding Graph here.
403class Context:
404 """Environment in which eager operations execute."""
406 # TODO(agarwal): create and link in some documentation for `execution_mode`.
407 # pylint: disable=redefined-outer-name
408 def __init__(self,
409 config=None,
410 device_policy=None,
411 execution_mode=None,
412 server_def=None):
413 """Creates a new Context.
415 Args:
416 config: (Optional.) A `ConfigProto` protocol buffer with configuration
417 options for the Context. Note that a lot of these options may be
418 currently unimplemented or irrelevant when eager execution is enabled.
419 device_policy: (Optional.) What policy to use when trying to run an
420 operation on a device with inputs which are not on that device. When set
421 to None, an appropriate value will be picked automatically. The value
422 picked may change between TensorFlow releases. Defaults to
423 DEVICE_PLACEMENT_SILENT.
424 Valid values:
425 - DEVICE_PLACEMENT_EXPLICIT: raises an error if the placement is not
426 correct.
427 - DEVICE_PLACEMENT_WARN: copies the tensors which are not on the right
428 device but raises a warning.
429 - DEVICE_PLACEMENT_SILENT: silently copies the tensors. This might hide
430 performance problems.
431 - DEVICE_PLACEMENT_SILENT_FOR_INT32: silently copies int32 tensors,
432 raising errors on the other ones.
433 execution_mode: (Optional.) Policy controlling how operations dispatched
434 are actually executed. When set to None, an appropriate value will be
435 picked automatically. The value picked may change between TensorFlow
436 releases.
437 Valid values:
438 - SYNC: executes each operation synchronously.
439 - ASYNC: executes each operation asynchronously. These operations may
440 return "non-ready" handles.
441 server_def: (Optional.) A tensorflow::ServerDef proto. Enables execution
442 on remote devices. GrpcServers need to be started by creating an
443 identical server_def to this, and setting the appropriate task_indexes,
444 so that the servers can communicate. It will then be possible to execute
445 operations on remote devices.
447 Raises:
448 ValueError: If execution_mode is not valid.
449 """
450 # This _id is used only to index the tensor caches.
451 # TODO(iga): Remove this when tensor caches are moved to C++.
452 self._id = _context_id_counter.increment_and_get()
453 self._tensor_cache_deleter = _TensorCacheDeleter(self._id)
454 _tensor_caches_map[self._id] = _TensorCaches()
456 self._config = config
457 self._thread_local_data = pywrap_tfe.EagerContextThreadLocalData(
458 self,
459 is_eager=lambda: default_execution_mode == EAGER_MODE,
460 device_spec=_starting_device_spec)
461 self._context_switches = _ContextSwitchStack(self.executing_eagerly())
462 self._context_handle = None
463 self._context_devices = None
464 self._seed = None
465 self._initialize_lock = threading.Lock()
466 self._initialized = False
467 if device_policy is None:
468 device_policy = DEVICE_PLACEMENT_SILENT
469 self._device_policy = device_policy
470 self._mirroring_policy = None
471 if execution_mode not in (None, SYNC, ASYNC):
472 raise ValueError("execution_mode should be None/SYNC/ASYNC. Got %s" %
473 execution_mode)
474 if execution_mode is None:
475 execution_mode = SYNC
476 self._default_is_async = execution_mode == ASYNC
477 self._use_tfrt = is_tfrt_enabled()
478 self._jit_compile_rewrite = jit_compile_rewrite_enabled()
479 self._server_def = server_def
480 self._collective_ops_server_def = None
481 self._collective_leader = None
482 self._collective_scoped_allocator_enabled_ops = None
483 self._collective_use_nccl_communication = None
484 self._collective_device_filters = None
485 self._coordination_service_config = None
487 self._device_lock = threading.Lock()
488 self._physical_devices = None
489 self._physical_device_to_index = None
490 self._pluggable_devices = None
491 self._visible_device_list = []
492 self._memory_growth_map = None
493 self._virtual_device_map = {}
495 # Values set after construction
496 self._optimizer_jit = None
497 self._intra_op_parallelism_threads = None
498 self._inter_op_parallelism_threads = None
499 self._soft_device_placement = None
500 self._log_device_placement = None
501 self._operation_timeout_in_ms = None
502 self._enable_mlir_graph_optimization = None
503 self._optimizer_experimental_options = {}
505 _python_eager_context_create_counter.get_cell().increase_by(1)
507 self._is_global_context = False
509 # pylint: enable=redefined-outer-name
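# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): constructing a
# standalone Context with an explicit placement policy and asynchronous
# dispatch, using the constants and constructor arguments documented above.
# Most user code never builds a Context directly; this is illustration only.
def _sketch_context_constructor():
  ctx = Context(device_policy=DEVICE_PLACEMENT_WARN, execution_mode=ASYNC)
  # Before initialization the async preference is only recorded...
  assert ctx.is_async()
  # ...and it is applied to the underlying C context once the handle exists.
  ctx.ensure_initialized()
  assert ctx.is_async()
# ---------------------------------------------------------------------------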
511 def _set_global_seed(self, seed):
512 """Set a global eager mode seed for random ops."""
513 self._seed = seed
514 # `random.Random(seed)` needs `seed` to be hashable, while values of type
515 # e.g. `np.int64` or `np.ndarray` are not. We use `int(...)` to convert them
516 # to int.
517 try:
518 hash(seed)
519 self._rng = random.Random(seed)
520 except TypeError:
521 seed = int(np.array(seed))
522 self._rng = random.Random(seed)
523 # Also clear the kernel cache, to reset any existing seeds
524 if self._context_handle is not None:
525 pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
527 def _internal_operation_seed(self):
528 """Returns a fake operation seed.
530 In eager mode, users shouldn't set or depend on operation seeds.
531 Here, we generate a random seed based on the global seed so that each
532 operation's randomness is different while still depending on the global seed.
534 Returns:
535 A fake operation seed based on global seed.
536 """
537 return self._rng.randint(0, _MAXINT32)
539 def _initialize_logical_devices(self):
540 """Helper to initialize devices."""
541 # Store list of devices
542 logical_devices = []
543 context_devices = []
544 device_list = pywrap_tfe.TFE_ContextListDevices(self._context_handle)
545 try:
546 self._num_gpus = 0
547 current_job, current_task = None, None
548 server_def = self._server_def or self._collective_ops_server_def
549 if server_def is not None:
550 current_job, current_task = server_def.job_name, server_def.task_index
551 for i in range(pywrap_tfe.TF_DeviceListCount(device_list)):
552 dev_name = pywrap_tfe.TF_DeviceListName(device_list, i)
553 context_devices.append(pydev.canonical_name(dev_name))
554 spec = pydev.DeviceSpec.from_string(dev_name)
555 # If the job is localhost, we assume that the cluster has not yet been
556 # configured and thus clear the job, replica & task.
557 if spec.job == "localhost":
558 spec = spec.replace(job=None, replica=None, task=None)
559 logical_devices.append(
560 LogicalDevice(name=spec.to_string(), device_type=spec.device_type))
561 dev_type = pywrap_tfe.TF_DeviceListType(device_list, i)
562 if (dev_type == "GPU" and spec.job == current_job and
563 spec.task == current_task):
564 self._num_gpus += 1
566 finally:
567 self._logical_devices = logical_devices
568 self._context_devices = context_devices
569 pywrap_tfe.TF_DeleteDeviceList(device_list)
571 def ensure_initialized(self):
572 """Initialize handle and devices if not already done so."""
573 if self._initialized:
574 return
575 with self._initialize_lock:
576 if self._initialized:
577 return
578 assert self._context_devices is None
579 opts = pywrap_tfe.TFE_NewContextOptions()
580 try:
581 config_str = self.config.SerializeToString()
582 pywrap_tfe.TFE_ContextOptionsSetConfig(opts, config_str)
583 if self._device_policy is not None:
584 pywrap_tfe.TFE_ContextOptionsSetDevicePlacementPolicy(
585 opts, self._device_policy)
586 if self._mirroring_policy is not None:
587 pywrap_tfe.TFE_ContextOptionsSetMirroringPolicy(
588 opts, self._mirroring_policy)
589 if self._default_is_async == ASYNC:
590 pywrap_tfe.TFE_ContextOptionsSetAsync(opts, True)
591 if self._use_tfrt is not None:
592 pywrap_tfe.TFE_ContextOptionsSetTfrt(opts, self._use_tfrt)
593 pywrap_tfe.TFE_ContextOptionsSetRunEagerOpAsFunction(opts, True)
594 pywrap_tfe.TFE_ContextOptionsSetJitCompileRewrite(
595 opts, self._jit_compile_rewrite)
596 context_handle = pywrap_tfe.TFE_NewContext(opts)
597 finally:
598 pywrap_tfe.TFE_DeleteContextOptions(opts)
599 assert not (self._server_def and self._collective_ops_server_def), (
600 "Cannot enable remote execution as well as collective ops at the "
601 "moment. If this is important to you, please file an issue.")
602 if self._server_def is not None:
603 server_def_str = self._server_def.SerializeToString()
604 pywrap_tfe.TFE_ContextSetServerDef(context_handle, _KEEP_ALIVE_SECS,
605 server_def_str)
606 elif self._collective_ops_server_def is not None:
607 server_def_str = self._collective_ops_server_def.SerializeToString()
608 pywrap_tfe.TFE_EnableCollectiveOps(context_handle, server_def_str)
610 self._context_handle = context_handle
611 self._initialize_logical_devices()
612 self._initialized = True
614 if self._is_global_context:
615 pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle)
617 def ensure_uninitialized(self):
618 """Uninitialize handle and devices if not already done so."""
619 with self._initialize_lock:
620 if not self._initialized:
621 return
622 self._context_devices = None
623 self._logical_devices = None
624 self._server_def = None
625 self._initialized = False
627 if self._is_global_context:
628 pywrap_tfe.TFE_Py_SetCEagerContext(None)
630 self._context_handle = None
632 def mark_as_global_context(self):
633 # If the context was already initialized, publish it. Otherwise wait with
634 # publication until it's initialized.
635 if self._initialized:
636 pywrap_tfe.TFE_Py_SetCEagerContext(self._context_handle)
637 self._is_global_context = True
639 def _clear_caches(self):
640 self.ones_rank_cache().flush()
641 self.zeros_cache().flush()
642 pywrap_tfe.TFE_ClearScalarCache()
644 def get_server_def(self):
645 return self._server_def
647 def set_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
648 """Allow setting a server_def on the context.
650 When a server def is replaced, it effectively clears a bunch of caches
651 within the context. If you attempt to use a tensor object that was pointing
652 to a tensor on the remote device, it will raise an error.
654 Args:
655 server_def: A tensorflow::ServerDef proto. Enables execution on remote
656 devices.
657 keep_alive_secs: Num. seconds after which the remote end will hang up. As
658 long as the client is still alive, the server state for the context will
659 be kept alive. If the client is killed (or there is some failure), the
660 server will clean up its context keep_alive_secs after the final RPC it
661 receives.
663 Raises:
664 ValueError: if server_def is None.
665 """
666 if not server_def:
667 raise ValueError("server_def is None.")
669 self._server_def = server_def
671 if self._context_handle:
672 server_def_str = server_def.SerializeToString()
673 pywrap_tfe.TFE_ContextSetServerDef(self._context_handle, keep_alive_secs,
674 server_def_str)
675 self._initialize_logical_devices()
677 # Clear all the caches in case there are remote tensors in them.
678 self._clear_caches()
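# ---------------------------------------------------------------------------
# Hedged sketch (not part of the original module): building a minimal
# single-job ServerDef and handing it to the context, as described in the
# docstring above. The job name "worker", localhost ports, and the
# tensorflow_server_pb2 import path are assumptions for illustration; a real
# deployment would start matching gRPC servers on every task.
def _sketch_set_server_def():
  from tensorflow.core.protobuf import tensorflow_server_pb2  # assumption

  server_def = tensorflow_server_pb2.ServerDef(
      job_name="worker", task_index=0, protocol="grpc")
  job = server_def.cluster.job.add()
  job.name = "worker"
  job.tasks[0] = "localhost:7000"
  job.tasks[1] = "localhost:7001"

  ctx = context()  # module-level accessor defined later in this file
  ctx.set_server_def(server_def)
# ---------------------------------------------------------------------------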
680 def update_server_def(self, server_def, keep_alive_secs=_KEEP_ALIVE_SECS):
681 """Update a server_def on the context.
683 Args:
684 server_def: A tensorflow::ServerDef proto. Enables execution on remote
685 devices.
686 keep_alive_secs: Num. seconds after which the remote end will hang up. As
687 long as the client is still alive, the server state for the context will
688 be kept alive. If the client is killed (or there is some failure), the
689 server will clean up its context keep_alive_secs after the final RPC it
690 receives.
692 Raises:
693 ValueError: if server_def is None.
694 """
695 if not server_def:
696 raise ValueError("server_def is None.")
698 self._server_def = server_def
700 if self._context_handle:
701 server_def_str = server_def.SerializeToString()
702 pywrap_tfe.TFE_ContextUpdateServerDef(self._context_handle,
703 keep_alive_secs, server_def_str)
704 self._initialize_logical_devices()
706 self._clear_caches()
708 def check_alive(self, worker_name):
709 """Checks whether a remote worker is alive or not.
711 Args:
712 worker_name: a string representing the remote worker. It must be a fully
713 specified name like "/job:worker/replica:0/task:0".
715 Returns:
716 a boolean indicating whether the remote worker is alive or not.
718 Raises:
719 ValueError: if context is not initialized.
720 """
721 # TODO(yuefengz): support checking multiple workers.
722 if self._context_handle:
723 return pywrap_tfe.TFE_ContextCheckAlive(self._context_handle, worker_name)
724 else:
725 raise ValueError("Context is not initialized.")
727 def sync_executors(self):
728 """Sync both local executors and the ones on remote workers.
730 In async execution mode, local function calls can return before the
731 corresponding remote op/function execution requests are completed. Calling
732 this method creates a synchronization barrier for remote executors. It only
733 returns when all remote pending nodes are finished, potentially with errors
734 if any remote executors are in error state.
736 Raises:
737 ValueError: if context is not initialized.
738 """
739 if self._context_handle:
740 pywrap_tfe.TFE_ContextSyncExecutors(self._context_handle)
741 else:
742 raise ValueError("Context is not initialized.")
744 def clear_executor_errors(self):
745 """Clear errors in both local executors and remote workers.
747 After receiving errors from remote workers, additional requests on the fly
748 could further taint the status on the remote workers due to the async nature
749 of remote execution. Calling this method blocks waiting for all pending
750 nodes in remote executors to finish and clears their error statuses.
752 Raises:
753 ValueError: if context is not initialized.
754 """
755 if self._context_handle:
756 pywrap_tfe.TFE_ContextClearExecutors(self._context_handle)
757 else:
758 raise ValueError("Context is not initialized.")
760 def configure_coordination_service(self,
761 service_type,
762 service_leader="",
763 enable_health_check=True,
764 cluster_register_timeout_in_ms=0,
765 heartbeat_timeout_in_ms=0,
766 shutdown_barrier_timeout_in_ms=0,
767 coordinated_jobs=None,
768 allow_new_incarnation_to_reconnect=False):
769 """Enable distributed coordination service with specified configs."""
770 if self._context_handle:
771 logging.warning("Configuring coordination service type may not be "
772 "effective because the context is already initialized.")
773 config = coordination_config_pb2.CoordinationServiceConfig()
774 config.service_type = service_type
775 if service_leader:
776 config.service_leader = pydev.canonical_name(service_leader)
777 config.enable_health_check = enable_health_check
778 config.cluster_register_timeout_in_ms = cluster_register_timeout_in_ms
779 config.heartbeat_timeout_in_ms = heartbeat_timeout_in_ms
780 config.shutdown_barrier_timeout_in_ms = shutdown_barrier_timeout_in_ms
781 config.allow_new_incarnation_to_reconnect = (
782 allow_new_incarnation_to_reconnect)
783 if coordinated_jobs is not None:
784 if isinstance(coordinated_jobs, list):
785 config.coordinated_job_list.extend(coordinated_jobs)
786 else:
787 raise ValueError("`coordinated_jobs` must be list[CoordinatedJob] or "
788 "None, but got: %s" % (coordinated_jobs,))
789 self._coordination_service_config = config
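# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): enabling the
# coordination service before the context is initialized. The service type
# string "standalone" and the timeout values are assumptions for
# illustration; supported values are defined by the coordination service.
def _sketch_configure_coordination_service():
  ctx = context()  # module-level accessor defined later in this file
  ctx.configure_coordination_service(
      service_type="standalone",  # assumed service type
      service_leader="/job:worker/replica:0/task:0",
      heartbeat_timeout_in_ms=10 * 1000,
      shutdown_barrier_timeout_in_ms=5 * 1000)
  assert ctx.coordination_service is not None
# ---------------------------------------------------------------------------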
791 @property
792 def coordination_service(self):
793 return self._coordination_service_config
795 def set_config_key_value(self, key, value):
796 ensure_initialized()
797 pywrap_tfe.TFE_InsertConfigKeyValue(self._context_handle, key, value)
799 # If `timeout_in_ms=0`, this will block until the key-value is set or the
800 # worker shuts down.
801 def get_config_key_value(self, key, timeout_in_ms=0):
802 ensure_initialized()
803 with c_api_util.tf_buffer() as buffer_:
804 pywrap_tfe.TFE_GetConfigKeyValue(self._context_handle, key,
805 timeout_in_ms, buffer_)
806 value = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8")
807 return value
809 def delete_config_key_value(self, key):
810 ensure_initialized()
811 pywrap_tfe.TFE_DeleteConfigKeyValue(self._context_handle, key)
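# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): using the
# coordination service's key-value store through the three methods above.
# Assumes a cluster where the coordination service is configured and the
# context is initialized; without that these calls will fail. The key and
# value strings are illustrative.
def _sketch_config_key_value_store():
  ctx = context()  # module-level accessor defined later in this file
  ctx.set_config_key_value("train/phase", "warmup")
  # Blocks until the key is available (or the worker shuts down), because
  # timeout_in_ms defaults to 0.
  assert ctx.get_config_key_value("train/phase") == "warmup"
  ctx.delete_config_key_value("train/phase")
# ---------------------------------------------------------------------------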
813 def report_error_to_cluster(self, error_code, error_message):
814 """Report error to other members in a multi-client cluster.
816 Args:
817 error_code: a `tf.errors` error code.
818 error_message: a string. The error message.
819 """
820 if self._context_handle:
821 pywrap_tfe.TFE_ReportErrorToCluster(self._context_handle, error_code,
822 error_message)
823 else:
824 raise ValueError("Context is not initialized.")
826 def get_task_states(self, job_configs):
827 """Get task states from the Coordination Service.
829 Args:
830 job_configs: A list of tuples of job name and task number.
832 Returns:
833 A list of TF_Status.
834 """
835 if self._context_handle:
836 job_names, task_nums = zip(*job_configs)
837 return pywrap_tfe.TFE_GetTaskStates(self._context_handle, job_names,
838 task_nums)
839 else:
840 raise ValueError("Context is not initialized.")
842 def wait_at_barrier(self, barrier_id, timeout_in_ms):
843 """Blocks until all coordinated tasks are at the barrier.
845 The barrier may fail if it times out or if one of the tasks is unhealthy.
847 Args:
848 barrier_id: Unique string identifying the barrier.
849 timeout_in_ms: Duration before the barrier times out and fails.
850 """
851 ensure_initialized()
852 pywrap_tfe.TFE_WaitAtBarrier(self._context_handle, barrier_id,
853 timeout_in_ms)
855 def clear_kernel_cache(self):
856 """Clear kernel cache and reset all stateful kernels."""
857 if self._context_handle is not None:
858 pywrap_tfe.TFE_ContextClearCaches(self._context_handle)
860 def enable_collective_ops(self, server_def):
861 """Enable distributed collective ops with an appropriate server_def.
863 Args:
864 server_def: A tensorflow::ServerDef proto. Enables execution on remote
865 devices.
867 Raises:
868 ValueError: if server_def is None.
869 RuntimeError: if this method is not called at program startup.
870 """
871 if not server_def:
872 raise ValueError("server_def is None.")
874 self._collective_ops_server_def = server_def
876 # TODO(b/129298253): Allow creating datasets/tensors before enabling
877 # collective ops.
878 if self._context_handle is not None:
879 logging.warning("Enabling collective ops after program startup may cause "
880 "error when accessing previously created tensors.")
881 with self._initialize_lock:
882 assert self._initialized
883 server_def_str = self._collective_ops_server_def.SerializeToString()
884 pywrap_tfe.TFE_EnableCollectiveOps(self._context_handle, server_def_str)
885 self._initialize_logical_devices()
886 self._clear_caches()
888 def configure_collective_ops(
889 self,
890 collective_leader="",
891 scoped_allocator_enabled_ops=("CollectiveReduce",),
892 use_nccl_communication=False,
893 device_filters=None):
894 """Configure collective ops.
896 A collective group leader is necessary for collective ops to run; the other
897 configurations mainly affect performance.
899 Args:
900 collective_leader: a device string for collective leader, e.g.
901 "/job:worker/replica:0/task:0"; empty string means local execution of
902 collective ops.
903 scoped_allocator_enabled_ops: a tuple or a list of op names for scoped
904 allocator to run with.
905 use_nccl_communication: whether to use nccl communication for collective
906 ops.
907 device_filters: a tuple or a list of device strings. If set, the
908 corresponding task can only see the devices matching these filters.
910 Raises:
911 RuntimeError: if this method is not called at program startup.
912 """
913 if self._collective_leader is not None:
914 if (self._collective_leader != collective_leader or
915 self._collective_scoped_allocator_enabled_ops !=
916 scoped_allocator_enabled_ops or
917 self._collective_use_nccl_communication != use_nccl_communication or
918 self._collective_device_filters != device_filters):
919 raise ValueError("Collective ops are already configured.")
920 else:
921 return
923 if self._context_handle is not None:
924 raise RuntimeError("Collective ops must be configured at program startup")
926 self._collective_leader = collective_leader
927 self._collective_scoped_allocator_enabled_ops = scoped_allocator_enabled_ops
928 self._collective_use_nccl_communication = use_nccl_communication
929 self._collective_device_filters = device_filters
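# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): configuring
# collective ops at program startup, before any tensors or datasets are
# created. The leader and filter strings follow the formats described in the
# docstring above and are illustrative.
def _sketch_configure_collective_ops():
  ctx = context()  # module-level accessor defined later in this file
  ctx.configure_collective_ops(
      collective_leader="/job:worker/replica:0/task:0",
      use_nccl_communication=False,
      device_filters=["/job:worker/task:0"])
  # Calling this again with different values raises, since the settings are
  # folded into the ConfigProto exactly once.
# ---------------------------------------------------------------------------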
931 def abort_collective_ops(self, code, message):
932 """Abort the collective ops.
934 This is intended to be used when a peer failure is detected, which allows
935 the user to handle the case instead of hanging. This aborts all on-going
936 collectives. Afterwards, all subsequent collectives error immediately, and
937 you need to reset_context() to use collectives again.
939 Args:
940 code: a `tf.errors` error code.
941 message: a string. The error message.
942 """
943 self.ensure_initialized()
944 pywrap_tfe.TFE_AbortCollectiveOps(self._handle, code, message)
946 def check_collective_ops_peer_health(self, task, timeout_in_ms):
947 """Check collective peer health.
949 This probes each task to see if it is still alive. Note that a restarted
950 task is considered a different one, and it is considered not healthy.
952 This should only be used in multi client multi worker training.
954 Args:
955 task: a task string, must be in the format of /job:xxx/replica:0/task:N.
956 timeout_in_ms: an integer, the timeout. If zero, there's no timeout.
958 Raises:
959 tf.errors.UnavailableError: when a peer is down.
960 tf.errors.FailedPreconditionError: when a peer is a different one from the
961 one this task has talked to, e.g. the peer has restarted.
962 tf.errors.InvalidArgumentError: when the task string is invalid.
963 """
964 self.ensure_initialized()
965 pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(self._handle, task,
966 timeout_in_ms)
968 @property
969 def _handle(self):
970 if self._context_handle is None:
971 raise AssertionError("Context must be initialized first.")
973 return self._context_handle
975 @property
976 def _devices(self):
977 if self._context_devices is None:
978 raise AssertionError("Context must be initialized first.")
980 return self._context_devices
982 def __str__(self):
983 if self._context_handle is None:
984 return "Eager TensorFlow Context. Devices currently uninitialized."
985 else:
986 devices = self._devices
987 lines = ["Eager TensorFlow Context with %d devices" % (len(devices))]
988 for i, d in enumerate(devices):
989 lines.append(" Device %d: %s" % (i, d))
990 return "\n".join(lines)
992 @tf_contextlib.contextmanager
993 def _mode(self, mode):
994 """A context manager to allow setting the mode to EAGER/GRAPH."""
995 ctx = self._thread_local_data
996 old_is_eager = ctx.is_eager
997 ctx.is_eager = mode == EAGER_MODE
998 if mode == EAGER_MODE:
999 # Entering graph mode does not provide us with sufficient information to
1000 # record a context switch; graph-based context switches are only logged
1001 # when a graph is registered as the default graph.
1002 self.context_switches.push(False, eager_mode, None)
1003 try:
1004 yield
1005 finally:
1006 ctx.is_eager = old_is_eager
1007 if mode == EAGER_MODE:
1008 self.context_switches.pop()
1010 def executing_eagerly(self):
1011 """Returns True if current thread has eager executing enabled."""
1012 return self._thread_local_data.is_eager
1014 def ones_rank_cache(self):
1015 """Per-device cache for scalars."""
1016 return _tensor_caches_map[self._id].ones_rank_cache
1018 def zeros_cache(self):
1019 """Per-device cache for scalars."""
1020 return _tensor_caches_map[self._id].zeros_cache
1022 @property
1023 def scope_name(self):
1024 """Returns scope name for the current thread."""
1025 return self._thread_local_data.scope_name
1027 @scope_name.setter
1028 def scope_name(self, s):
1029 """Sets scope name for the current thread."""
1030 self._thread_local_data.scope_name = s
1032 @property
1033 def device_name(self):
1034 """Returns the device name for the current thread."""
1035 return self._thread_local_data.device_name
1037 @property
1038 def device_spec(self):
1039 """Returns the device spec for the current thread."""
1040 return self._thread_local_data.device_spec
1042 def _set_device(self, device_name, device_spec):
1043 self._thread_local_data.device_name = device_name
1044 self._thread_local_data.device_spec = device_spec
1046 def device(self, name):
1047 """Context-manager to force placement of operations and Tensors on a device.
1049 Args:
1050 name: Name of the device or None to get default placement.
1052 Returns:
1053 Context manager that forces device placement.
1055 Raises:
1056 ValueError: If name is not a string or is an invalid device name.
1057 RuntimeError: If device scopes are not properly nested.
1058 """
1059 if isinstance(name, LogicalDevice):
1060 name = name.name
1061 elif pydev.is_device_spec(name):
1062 name = name.to_string()
1063 return _EagerDeviceContext(self, name)
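# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): the accepted forms
# of `name` for Context.device(). In practice users reach this through
# `tf.device`; the direct calls below are illustration only.
def _sketch_device_scopes():
  ctx = context()  # module-level accessor defined later in this file
  ctx.ensure_initialized()

  with ctx.device("/device:CPU:0"):  # a device name string
    pass
  cpu = ctx.list_logical_devices("CPU")[0]
  with ctx.device(cpu):              # a LogicalDevice is also accepted
    pass
  with ctx.device(None):             # None restores default placement
    pass
# ---------------------------------------------------------------------------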
1065 def devices(self):
1066 """List of the names of devices available to execute operations."""
1067 return self._devices
1069 def host_address_space(self):
1070 self.ensure_initialized()
1071 with c_api_util.tf_buffer() as buffer_:
1072 pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_)
1073 address_space = pywrap_tf_session.TF_GetBuffer(buffer_).decode("utf-8")
1074 return address_space
1076 # TODO(fishx): remove this property.
1077 @property
1078 def execution_mode(self):
1079 """Gets execution mode for current thread."""
1080 return ASYNC if self.is_async() else SYNC
1082 @execution_mode.setter
1083 def execution_mode(self, mode):
1084 """Sets execution mode for current thread."""
1085 if mode not in (None, SYNC, ASYNC):
1086 raise ValueError("Execution mode should be None/SYNC/ASYNC. Got %s" %
1087 mode)
1089 if mode is None:
1090 mode = SYNC
1092 enable_async = (mode == ASYNC)
1093 if self.is_async() != enable_async:
1094 # Only set the execution mode if the context has already been initialized
1095 if self._context_handle is not None:
1096 self.executor.wait()
1097 executor_new = executor.new_executor(enable_async)
1098 self._thread_local_data.executor = executor_new
1099 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle,
1100 executor_new.handle())
1101 else:
1102 self._default_is_async = enable_async
1104 def is_async(self):
1105 if self._context_handle is not None:
1106 return self.executor.is_async()
1107 else:
1108 return self._default_is_async
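# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): switching the
# per-thread execution mode through the property above. The usual public
# entry point is `tf.config.experimental.set_synchronous_execution`.
def _sketch_execution_mode():
  ctx = context()  # module-level accessor defined later in this file
  ctx.execution_mode = ASYNC
  assert ctx.is_async()
  # Pending async work can be flushed explicitly before switching back.
  ctx.executor.wait()
  ctx.execution_mode = SYNC
  assert not ctx.is_async()
# ---------------------------------------------------------------------------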
1110 @property
1111 def executor(self):
1112 self.ensure_initialized()
1113 return executor.Executor(
1114 pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle))
1116 @executor.setter
1117 def executor(self, e):
1118 self.ensure_initialized()
1119 pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle())
1121 @property
1122 def config(self):
1123 """Return the ConfigProto with all runtime deltas applied."""
1124 # Ensure physical devices have been discovered and config has been imported
1125 self._initialize_physical_devices()
1127 config = config_pb2.ConfigProto()
1128 if self._config is not None:
1129 config.CopyFrom(self._config)
1131 if self._optimizer_jit is not None:
1132 config.graph_options.optimizer_options.global_jit_level = (
1133 config_pb2.OptimizerOptions.ON_1
1134 if self._optimizer_jit else config_pb2.OptimizerOptions.OFF)
1135 if self._intra_op_parallelism_threads is not None:
1136 config.intra_op_parallelism_threads = self._intra_op_parallelism_threads
1137 if self._inter_op_parallelism_threads is not None:
1138 config.inter_op_parallelism_threads = self._inter_op_parallelism_threads
1140 if self._soft_device_placement is not None:
1141 config.allow_soft_placement = self._soft_device_placement
1142 else:
1143 config.allow_soft_placement = self.executing_eagerly()
1145 if self._log_device_placement is not None:
1146 config.log_device_placement = self._log_device_placement
1148 if self._operation_timeout_in_ms is not None:
1149 config.operation_timeout_in_ms = self._operation_timeout_in_ms
1151 is_mlir_bridge_enabled = pywrap_tfe.TF_IsMlirBridgeEnabled()
1152 config.experimental.mlir_bridge_rollout = is_mlir_bridge_enabled
1153 if (is_mlir_bridge_enabled ==
1154 config_pb2.ConfigProto.Experimental.MLIR_BRIDGE_ROLLOUT_ENABLED):
1155 config.experimental.enable_mlir_bridge = True
1157 if self._enable_mlir_graph_optimization is not None:
1158 config.experimental.enable_mlir_graph_optimization = (
1159 self._enable_mlir_graph_optimization)
1161 def rewriter_toggle(option):
1162 toggle = self._optimizer_experimental_options.get(option, None)
1163 if toggle is None:
1164 return
1166 setattr(config.graph_options.rewrite_options, option,
1167 (rewriter_config_pb2.RewriterConfig.ON
1168 if toggle else rewriter_config_pb2.RewriterConfig.OFF))
1170 def rewriter_bool(option):
1171 toggle = self._optimizer_experimental_options.get(option, None)
1172 if toggle is None:
1173 return
1175 setattr(config.graph_options.rewrite_options, option, toggle)
1177 rewriter_toggle("layout_optimizer")
1178 rewriter_toggle("constant_folding")
1179 rewriter_toggle("shape_optimization")
1180 rewriter_toggle("remapping")
1181 rewriter_toggle("arithmetic_optimization")
1182 rewriter_toggle("dependency_optimization")
1183 rewriter_toggle("loop_optimization")
1184 rewriter_toggle("function_optimization")
1185 rewriter_toggle("debug_stripper")
1186 rewriter_bool("disable_model_pruning")
1187 rewriter_toggle("scoped_allocator_optimization")
1188 rewriter_toggle("pin_to_host_optimization")
1189 rewriter_toggle("implementation_selector")
1190 rewriter_toggle("auto_mixed_precision")
1191 rewriter_toggle("use_plugin_optimizers")
1192 rewriter_bool("disable_meta_optimizer")
1193 rewriter_toggle("auto_mixed_precision_onednn_bfloat16")
1194 rewriter_toggle("auto_mixed_precision_mkl")
1195 nodes = self._optimizer_experimental_options.get("min_graph_nodes", None)
1196 if nodes is not None:
1197 config.graph_options.rewrite_options.min_graph_nodes = nodes
1199 # Compute device counts
1200 config.device_count["CPU"] = 0
1201 config.device_count["GPU"] = 0
1202 for dev in self._physical_devices:
1203 if dev not in self._visible_device_list:
1204 continue
1206 virtual_devices = self._virtual_device_map.get(dev)
1207 if virtual_devices is None:
1208 config.device_count[dev.device_type] += 1
1209 else:
1210 config.device_count[dev.device_type] += len(virtual_devices)
1212 # Configure gpu_options
1213 gpu_options = self._compute_gpu_options()
1214 config.gpu_options.MergeFrom(gpu_options)
1216 # Configure collective ops
1217 if self._collective_leader:
1218 config.experimental.collective_group_leader = self._collective_leader
1219 if self._collective_scoped_allocator_enabled_ops:
1220 rewrite_options = config.graph_options.rewrite_options
1221 rewrite_options.scoped_allocator_optimization = (
1222 rewriter_config_pb2.RewriterConfig.ON)
1223 del rewrite_options.scoped_allocator_opts.enable_op[:]
1224 for op in self._collective_scoped_allocator_enabled_ops:
1225 rewrite_options.scoped_allocator_opts.enable_op.append(op)
1226 if self._collective_use_nccl_communication:
1227 config.experimental.collective_nccl = True
1228 if self._collective_device_filters:
1229 del config.device_filters[:]
1230 for f in self._collective_device_filters:
1231 config.device_filters.append(f)
1233 # Configure coordination service
1234 if self._coordination_service_config:
1235 config.experimental.coordination_config.CopyFrom(
1236 self._coordination_service_config)
1238 return config
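# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): how a grappler
# toggle set through the public tf.config API surfaces in the ConfigProto
# assembled by the `config` property above.
def _sketch_config_rewriter_toggle():
  import tensorflow as tf  # assumption: the full TF API is importable here

  tf.config.optimizer.set_experimental_options({"constant_folding": True})
  proto = context().config  # module-level accessor defined later in this file
  assert (proto.graph_options.rewrite_options.constant_folding ==
          rewriter_config_pb2.RewriterConfig.ON)
# ---------------------------------------------------------------------------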
1240 def _compute_gpu_options(self):
1241 """Build the GPUOptions proto."""
1242 visible_device_list = []
1243 virtual_devices = []
1244 gpu_index = -1
1245 memory_growths = set()
1246 gpu_devices = self.list_physical_devices("GPU")
1247 pluggable_devices = self._pluggable_devices
1248 compatible_devices = gpu_devices
1249 for dev in pluggable_devices:
1250 if dev not in gpu_devices:
1251 compatible_devices.append(dev)
1252 for dev in compatible_devices:
1253 gpu_index += 1
1255 if dev not in self._visible_device_list:
1256 continue
1258 growth = self._memory_growth_map[dev]
1259 memory_growths.add(growth)
1260 visible_device_list.append(str(gpu_index))
1262 if self._virtual_device_map:
1263 vdevs = self._virtual_device_map.get(dev, [])
1264 device_ordinals = []
1265 device_limits = []
1266 priority = []
1267 for virt_dev in vdevs:
1268 if virt_dev.experimental_device_ordinal is not None:
1269 device_ordinals.append(virt_dev.experimental_device_ordinal)
1270 device_limits.append(virt_dev.memory_limit)
1271 if virt_dev.experimental_priority is not None:
1272 priority.append(virt_dev.experimental_priority)
1273 # If priority is specified, it must be specified for all virtual
1274 # devices.
1275 if priority and len(device_limits) != len(priority):
1276 raise ValueError("priority must be specified for all virtual devices")
1277 # If device_ordinals is specified, it must be specified for all virtual
1278 # devices.
1279 if device_ordinals and len(device_limits) != len(device_ordinals):
1280 raise ValueError(
1281 "device_ordinals must be specified for all virtual devices")
1283 virtual_devices.append(
1284 config_pb2.GPUOptions.Experimental.VirtualDevices(
1285 memory_limit_mb=device_limits,
1286 priority=priority,
1287 device_ordinal=device_ordinals))
1289 # Only compute growth if virtual devices have not been configured and we
1290 # have GPUs
1291 if not virtual_devices and memory_growths:
1292 if len(memory_growths) > 1:
1293 raise ValueError("Memory growth cannot differ between GPU devices")
1294 allow_growth = memory_growths.pop()
1295 else:
1296 allow_growth = None
1298 return config_pb2.GPUOptions(
1299 allow_growth=allow_growth,
1300 visible_device_list=",".join(visible_device_list),
1301 experimental=config_pb2.GPUOptions.Experimental(
1302 virtual_devices=virtual_devices))
1304 @property
1305 def function_call_options(self):
1306 """Returns function call options for current thread.
1308 Note that the returned object is still referenced by the eager context.
1310 Returns: the FunctionCallOptions for current thread.
1311 """
1312 if self._thread_local_data.function_call_options is None:
1313 config = self.config
1315 # Default to soft placement for functions unless specified
1316 if self._soft_device_placement is None:
1317 config.allow_soft_placement = True
1318 self._thread_local_data.function_call_options = FunctionCallOptions(
1319 config_proto=config)
1321 return self._thread_local_data.function_call_options
1323 @function_call_options.setter
1324 def function_call_options(self, options):
1325 """Returns function call options for current thread."""
1326 self._thread_local_data.function_call_options = options
1328 def num_gpus(self):
1329 """The number of GPUs available to execute operations."""
1330 self.ensure_initialized()
1331 return self._num_gpus
1333 def add_c_function(self, c_func):
1334 """Add a C API TF_Function to the context.
1336 Once added, the function (identified by its name) can be executed like any
1337 other operation.
1339 Args:
1340 c_func: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
1341 """
1342 self.ensure_initialized()
1343 pywrap_tfe.TFE_ContextAddFunction(self._handle, c_func)
1345 def get_c_function(self, name):
1346 """Get a C API TF_Function from the context.
1348 Args:
1349 name: Name of the function to get.
1351 Returns:
1352 A ScopedTFFunction wrapping the C API TF_Function.
1353 """
1354 self.ensure_initialized()
1355 return c_api_util.ScopedTFFunction(
1356 pywrap_tfe.TFE_ContextGetFunction(self._handle, name), name
1357 )
1359 def add_function_def(self, fdef):
1360 """Add a function definition to the context.
1362 Once added, the function (identified by its name) can be executed like any
1363 other operation.
1365 Args:
1366 fdef: A FunctionDef protocol buffer message.
1367 """
1368 self.ensure_initialized()
1369 fdef_string = fdef.SerializeToString()
1370 pywrap_tfe.TFE_ContextAddFunctionDef(self._handle, fdef_string,
1371 len(fdef_string))
1373 def get_function_def(self, name):
1374 """Get a function definition from the context.
1376 Args:
1377 name: function signature name.
1379 Returns:
1380 The requested FunctionDef.
1382 Raises:
1383 tf.errors.NotFoundError: if name is not the name of a registered function.
1384 """
1385 with c_api_util.tf_buffer() as buffer_:
1386 pywrap_tfe.TFE_ContextGetFunctionDef(self._handle, name, buffer_)
1387 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
1388 function_def = function_pb2.FunctionDef()
1389 function_def.ParseFromString(proto_data)
1391 return function_def
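# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): looking up a
# registered function's FunctionDef. Tracing a `tf.function` registers its
# concrete function in the eager context; the exact registered names are an
# implementation detail, so this sketch just picks one from
# `list_function_names()`.
def _sketch_function_def_lookup():
  import tensorflow as tf  # assumption: the full TF API is importable here

  @tf.function
  def add_one(x):
    return x + 1

  add_one.get_concrete_function(tf.TensorSpec([], tf.float32))

  ctx = context()  # module-level accessor defined later in this file
  names = ctx.list_function_names()
  assert names, "tracing should have registered at least one function"
  fdef = ctx.get_function_def(next(iter(names)))
  assert fdef.signature.name in names
# ---------------------------------------------------------------------------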
1393 def is_custom_device(self, device_name):
1394 """Calls TFE_IsCustomDevice. See the non-member function."""
1395 self.ensure_initialized()
1396 return pywrap_tfe.TFE_Py_IsCustomDevice(self._handle, device_name)
1398 def register_custom_device(self, device_capsule, device_name,
1399 device_info_capsule):
1400 """Calls TFE_RegisterCustomDevice. See the non-member function."""
1401 self.ensure_initialized()
1402 pywrap_tfe.TFE_Py_RegisterCustomDevice(self._handle, device_capsule,
1403 device_name, device_info_capsule)
1405 def pack_eager_tensors(self, tensors):
1406 """Pack multiple `EagerTensor`s of the same dtype and shape.
1408 Args:
1409 tensors: a list of EagerTensors to pack.
1411 Returns:
1412 A packed EagerTensor.
1413 """
1414 self.ensure_initialized()
1415 return pywrap_tfe.TFE_Py_PackEagerTensors(self._handle, tensors)
1417 def list_function_names(self):
1418 """Get a list of names of registered functions.
1420 Returns:
1421 A set of names of all registered functions for the context.
1422 """
1423 self.ensure_initialized()
1424 return set(pywrap_tfe.TFE_ContextListFunctionNames(self._handle))
1426 def remove_function(self, name):
1427 """Remove a function from the context.
1429 Once removed, the function cannot be executed anymore.
1431 Args:
1432 name: function signature name.
1433 """
1434 self.ensure_initialized()
1435 pywrap_tfe.TFE_ContextRemoveFunction(self._handle, name)
1437 def has_function(self, name):
1438 """Check if a function `name` is registered."""
1439 self.ensure_initialized()
1440 return bool(pywrap_tfe.TFE_ContextHasFunction(self._handle, name))
1442 @property
1443 def function_scope_id(self):
1444 """Returns an id that is unique to each scope holding functions."""
1445 return id(self._context_handle)
1447 def call_function(self, name, tensor_inputs, num_outputs):
1448 """Calls the function associated with the given name."""
1449 attrs = tuple(
1450 itertools.chain(
1451 *self.function_call_options.as_attrs().items()
1452 )
1453 )
1455 cancellation_context = cancellation.context()
1456 if cancellation_context is None:
1457 outputs = execute.execute(
1458 name.decode("utf-8"),
1459 num_outputs=num_outputs,
1460 inputs=tensor_inputs,
1461 attrs=attrs,
1462 ctx=self,
1463 )
1464 else:
1465 outputs = execute.execute_with_cancellation(
1466 name.decode("utf-8"),
1467 num_outputs=num_outputs,
1468 inputs=tensor_inputs,
1469 attrs=attrs,
1470 ctx=self,
1471 cancellation_manager=cancellation_context,
1472 )
1473 # Empty list means no function outputs so return None
1474 outputs = outputs or None
1476 return outputs
1478 def add_op_callback(self, callback):
1479 """Add a post-op callback to the context.
1481 A post-op callback is invoked immediately after an eager operation or
1482 function has finished execution or after an op has been added to a graph,
1483 providing access to the op's type, name, input and output tensors. Multiple
1484 op callbacks can be added, in which case the callbacks will be invoked in
1485 the order in which they are added.
1487 Args:
1488 callback: a callable of the signature `f(op_type, inputs, attrs, outputs,
1489 op_name=None, graph=None)`. See doc strings in `op_callbacks.py` for
1490 details on the function signature and its semantics.
1491 """
1492 if callback not in self._thread_local_data.op_callbacks:
1493 self._thread_local_data.op_callbacks.append(callback)
1495 def remove_op_callback(self, callback):
1496 """Remove an already-registered op callback.
1498 Args:
1499 callback: The op callback to be removed.
1501 Raises:
1502 KeyError: If `callback` is not already registered.
1503 """
1504 if callback not in self._thread_local_data.op_callbacks:
1505 raise KeyError("The specified op callback has not been registered, "
1506 "and hence cannot be removed.")
1507 del self._thread_local_data.op_callbacks[
1508 self._thread_local_data.op_callbacks.index(callback)]
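# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): registering a
# post-op callback with the signature documented above and removing it
# again. The callback only records op types here; the `op_callbacks.py`
# wrappers are the usual public entry point.
def _sketch_op_callbacks():
  seen_op_types = []

  def log_callback(op_type, inputs, attrs, outputs, op_name=None, graph=None):
    del inputs, attrs, outputs, op_name, graph  # unused in this sketch
    seen_op_types.append(op_type)

  ctx = context()  # module-level accessor defined later in this file
  ctx.add_op_callback(log_callback)
  try:
    # ... ops executed while the callback is registered are reported to it.
    pass
  finally:
    ctx.remove_op_callback(log_callback)
# ---------------------------------------------------------------------------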
1510 @property
1511 def op_callbacks(self):
1512 return self._thread_local_data.op_callbacks
1514 @property
1515 def invoking_op_callbacks(self):
1516 return self._thread_local_data.invoking_op_callbacks
1518 @invoking_op_callbacks.setter
1519 def invoking_op_callbacks(self, value):
1520 self._thread_local_data.invoking_op_callbacks = value
1522 def _initialize_physical_devices(self, reinitialize=False):
1523 """Gets local devices visible to the system.
1525 Args:
1526 reinitialize: If True, reinitializes self._physical_devices so that
1527 dynamic registered devices will also be visible to the python front-end.
1528 """
1529 # We lazily initialize self._physical_devices since we do not want to do
1530 # this in the constructor, as the backend may not be initialized yet.
1531 with self._device_lock:
1532 if not reinitialize and self._physical_devices is not None:
1533 return
1535 devs = pywrap_tfe.TF_ListPhysicalDevices()
1536 self._physical_devices = [
1537 PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1])
1538 for d in devs
1539 ]
1540 self._physical_device_to_index = {
1541 p: i for i, p in enumerate(self._physical_devices)
1542 }
1543 # We maintain a separate list just so we can check whether the device in
1544 # _physical_devices is a PluggableDevice.
1545 pluggable_devs = pywrap_tfe.TF_ListPluggablePhysicalDevices()
1546 self._pluggable_devices = [
1547 PhysicalDevice(name=d.decode(), device_type=d.decode().split(":")[1])
1548 for d in pluggable_devs
1549 ]
1551 self._visible_device_list = list(self._physical_devices)
1552 self._memory_growth_map = {
1553 d: None
1554 for d in self._physical_devices
1555 if d.device_type == "GPU" or d in self._pluggable_devices
1556 }
1558 # Import device settings that may have been passed into the constructor
1559 self._import_config()
1561 def reinitialize_physical_devices(self):
1562 """Gets local devices visible to the system."""
1563 # Reinitialize the physical device list after registering
1564 # the pluggable device.
1565 self._initialize_physical_devices(True)
1567 def list_physical_devices(self, device_type=None):
1568 """List local devices visible to the system.
1570 This API allows a client to query the devices before they have been
1571 initialized by the eager runtime. Additionally a user can filter by device
1572 type, to get only CPUs or GPUs.
1574 Args:
1575 device_type: Optional device type to limit results to
1577 Returns:
1578 List of PhysicalDevice objects.
1579 """
1580 self._initialize_physical_devices()
1582 if device_type is None:
1583 return list(self._physical_devices)
1585 return [d for d in self._physical_devices if d.device_type == device_type]
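# Illustrative sketch (not part of the original source): enumerating devices
# through the public `tf.config.list_physical_devices` wrapper, which as far as
# I can tell routes to the method above. It works before the runtime is
# initialized and accepts an optional device-type filter.
def _example_list_physical_devices():
  import tensorflow as tf

  gpus = tf.config.list_physical_devices("GPU")    # Only GPUs.
  all_devices = tf.config.list_physical_devices()  # Every local device.
  print("GPUs:", len(gpus), "total:", len(all_devices))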
1587 def get_device_details(self, device): # pylint: disable=redefined-outer-name
1588 """Returns details about a physical device.
1590 Args:
1591 device: A `tf.config.PhysicalDevice` returned by
1592 `tf.config.list_physical_devices` or `tf.config.get_visible_devices`.
1594 Returns:
1595 A dict with string keys.
1596 """
1597 if not isinstance(device, PhysicalDevice):
1598 raise ValueError("device must be a tf.config.PhysicalDevice, but got: "
1599 "%s" % (device,))
1600 if (self._physical_device_to_index is None or
1601 device not in self._physical_device_to_index):
1602 raise ValueError("The PhysicalDevice must be one obtained from "
1603 "calling `tf.config.list_physical_devices`, but got: "
1604 "%s" % (device,))
1605 index = self._physical_device_to_index[device]
1606 details = pywrap_tfe.TF_GetDeviceDetails(index)
1608 # Change compute_capability from a string to a tuple
1609 if "compute_capability" in details:
1610 try:
1611 major, minor = details["compute_capability"].split(".")
1612 details["compute_capability"] = (int(major), int(minor))
1613 except ValueError:
1614 raise RuntimeError("Device returned compute capability in an invalid "
1615 "format: %s" % details["compute_capability"])
1616 return details
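# Illustrative sketch (not part of the original source): querying device
# details via the public `tf.config.experimental.get_device_details` wrapper.
# The keys in the returned dict vary by device; for CUDA GPUs,
# `compute_capability` is converted to a (major, minor) tuple by the method
# above, so `.get()` is used defensively here.
def _example_get_device_details():
  import tensorflow as tf

  gpus = tf.config.list_physical_devices("GPU")
  if gpus:
    details = tf.config.experimental.get_device_details(gpus[0])
    print(details.get("device_name"), details.get("compute_capability"))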
1618 def _import_config(self):
1619 """Import config if passed in during construction.
1621 If Context was created with a ConfigProto such as when calling
1622 tf.compat.v1.enable_eager_execution(), then we need to pull out the
1623 various pieces we might be replacing and import them into our internal
1624 class representation.
1625 """
1626 if self._config is None:
1627 return
1629 num_cpus = self._config.device_count.get("CPU", 1)
1630 if num_cpus != 1:
1631 cpus = [d for d in self._physical_devices if d.device_type == "CPU"]
1632 if num_cpus == 0:
1633 self.set_visible_devices([], "CPU")
1634 elif num_cpus > 1:
1635 self.set_logical_device_configuration(
1636 cpus[0], [LogicalDeviceConfiguration() for _ in range(num_cpus)])
1638 # Parse GPU options
1639 gpus = [d for d in self._physical_devices if d.device_type == "GPU"]
1641 # If there are no GPUs detected, simply ignore all the GPU options passed in
1642 # rather than doing any validation checks.
1643 if not gpus:
1644 return
1646 gpu_count = self._config.device_count.get("GPU", None)
1648 visible_gpus = []
1649 # TODO(gjn): Handle importing existing virtual GPU configuration
1650 visible_indices = self._config.gpu_options.visible_device_list
1651 if visible_indices:
1652 for index in visible_indices.split(","):
1653 if int(index) >= len(gpus):
1654 raise ValueError("Invalid visible device index: %s" % index)
1655 visible_gpus.append(gpus[int(index)])
1656 else:
1657 visible_gpus = gpus
1659 if gpu_count is not None:
1660 visible_gpus = visible_gpus[:gpu_count]
1662 self.set_visible_devices(visible_gpus, "GPU")
1664 def list_logical_devices(self, device_type=None):
1665 """Return logical devices."""
1666 self.ensure_initialized()
1667 if device_type is None:
1668 return list(self._logical_devices)
1670 return [d for d in self._logical_devices if d.device_type == device_type]
1672 def get_visible_devices(self, device_type=None):
1673 """Get the list of visible devices."""
1674 self._initialize_physical_devices()
1676 if device_type is None:
1677 return list(self._visible_device_list)
1679 return [
1680 d for d in self._visible_device_list if d.device_type == device_type
1681 ]
1683 def set_visible_devices(self, devices, device_type=None):
1684 """Set the list of visible devices."""
1685 self._initialize_physical_devices()
1687 if not isinstance(devices, list):
1688 devices = [devices]
1690 for d in devices:
1691 if d not in self._physical_devices:
1692 raise ValueError("Unrecognized device: %s" % repr(d))
1693 if device_type is not None and d.device_type != device_type:
1694 raise ValueError("Unrecognized device: %s" % repr(d))
1696 visible_device_list = []
1697 if device_type is not None:
1698 visible_device_list = [
1699 d for d in self._visible_device_list if d.device_type != device_type
1700 ]
1702 visible_device_list += devices
1704 if self._visible_device_list == visible_device_list:
1705 return
1707 if self._context_handle is not None:
1708 raise RuntimeError(
1709 "Visible devices cannot be modified after being initialized")
1711 self._visible_device_list = visible_device_list
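# Illustrative sketch (not part of the original source): restricting the
# runtime to the first GPU through the public `tf.config.set_visible_devices`
# wrapper. As with the method above, this must happen before the context is
# initialized or a RuntimeError is raised.
def _example_set_visible_devices():
  import tensorflow as tf

  gpus = tf.config.list_physical_devices("GPU")
  if gpus:
    tf.config.set_visible_devices(gpus[:1], "GPU")
    # Only the GPU portion of the visible-device list was replaced;
    # CPUs remain visible.
    print(tf.config.get_visible_devices("GPU"))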
1713 def get_memory_info(self, dev):
1714 """Returns a dict of memory info for the device."""
1715 self._initialize_physical_devices()
1716 self.ensure_initialized()
1717 return pywrap_tfe.TFE_GetMemoryInfo(self._context_handle, dev)
1719 def reset_memory_stats(self, dev):
1720 """Resets the tracked memory stats for the device."""
1721 self._initialize_physical_devices()
1722 self.ensure_initialized()
1723 pywrap_tfe.TFE_ResetMemoryStats(self._context_handle, dev)
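# Illustrative sketch (not part of the original source): reading and resetting
# tracked memory statistics through the public `tf.config.experimental`
# wrappers. For supported devices the returned dict is expected to contain
# `current` and `peak` byte counts.
def _example_memory_stats():
  import tensorflow as tf

  if tf.config.list_physical_devices("GPU"):
    info = tf.config.experimental.get_memory_info("GPU:0")
    print("current:", info["current"], "peak:", info["peak"])
    tf.config.experimental.reset_memory_stats("GPU:0")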
1725 def get_memory_growth(self, dev):
1726 """Get if memory growth is enabled for a PhysicalDevice."""
1727 self._initialize_physical_devices()
1729 if dev not in self._physical_devices:
1730 raise ValueError("Unrecognized device: %s" % repr(dev))
1732 return self._memory_growth_map[dev]
1734 def set_memory_growth(self, dev, enable):
1735 """Set if memory growth should be enabled for a PhysicalDevice."""
1736 self._initialize_physical_devices()
1738 if dev not in self._physical_devices:
1739 raise ValueError("Unrecognized device: %s" % repr(dev))
1741 if dev in self._virtual_device_map:
1742 raise ValueError(
1743 "Cannot set memory growth on device when virtual devices are configured")
1745 if dev.device_type != "GPU" and dev not in self._pluggable_devices:
1746 raise ValueError(
1747 "Cannot set memory growth on non-GPU and non-Pluggable devices")
1749 if self._memory_growth_map.get(dev) == enable:
1750 return
1752 if self._context_handle is not None:
1753 raise RuntimeError(
1754 "Physical devices cannot be modified after being initialized")
1756 self._memory_growth_map[dev] = enable
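# Illustrative sketch (not part of the original source): enabling memory growth
# for every detected GPU with the public
# `tf.config.experimental.set_memory_growth` wrapper, which is only valid
# before the context has been initialized.
def _example_memory_growth():
  import tensorflow as tf

  for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)
    assert tf.config.experimental.get_memory_growth(gpu)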
1758 def get_logical_device_configuration(self, dev):
1759 """Get the virtual device configuration for a PhysicalDevice."""
1760 self._initialize_physical_devices()
1762 if dev not in self._physical_devices:
1763 raise ValueError("Unrecognized device: %s" % repr(dev))
1765 return self._virtual_device_map.get(dev)
1767 def set_logical_device_configuration(self, dev, virtual_devices):
1768 """Set the virtual device configuration for a PhysicalDevice."""
1769 self._initialize_physical_devices()
1771 if dev not in self._physical_devices:
1772 raise ValueError("Unrecognized device: %s" % repr(dev))
1774 if dev.device_type == "CPU":
1775 for vdev in virtual_devices:
1776 if vdev.memory_limit is not None:
1777 raise ValueError("Setting memory limit on CPU virtual devices is "
1778 "currently not supported")
1779 if vdev.experimental_priority is not None:
1780 raise ValueError("Setting experimental_priority on CPU virtual "
1781 "devices is currently not supported")
1782 if vdev.experimental_device_ordinal is not None:
1783 raise ValueError("Setting experimental_device_ordinal on CPU virtual "
1784 "devices is currently not supported")
1785 elif dev.device_type == "GPU":
1786 for vdev in virtual_devices:
1787 if vdev.memory_limit is None:
1788 raise ValueError(
1789 "Setting memory limit is required for GPU virtual devices")
1790 else:
1791 raise ValueError("Virtual devices are not supported for %s" %
1792 dev.device_type)
1794 if self._virtual_device_map.get(dev) == virtual_devices:
1795 return
1797 if self._context_handle is not None:
1798 raise RuntimeError(
1799 "Virtual devices cannot be modified after being initialized")
1801 self._virtual_device_map[dev] = virtual_devices
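# Illustrative sketch (not part of the original source): splitting one physical
# GPU into two logical devices with explicit memory limits using the public
# `tf.config.set_logical_device_configuration` wrapper. Memory limits are
# required for GPU virtual devices, matching the validation above.
def _example_virtual_gpus():
  import tensorflow as tf

  gpus = tf.config.list_physical_devices("GPU")
  if gpus:
    tf.config.set_logical_device_configuration(
        gpus[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=1024),
         tf.config.LogicalDeviceConfiguration(memory_limit=1024)])
    # Listing logical devices initializes the context; two GPUs are expected.
    print(len(tf.config.list_logical_devices("GPU")))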
1803 def set_logical_cpu_devices(self, num_cpus, prefix=""):
1804 """Set virtual CPU devices in context.
1806 If virtual CPU devices are already configured at context initialization
1807 by tf.config.set_logical_device_configuration(), this method should not be
1808 called.
1810 Args:
1811 num_cpus: Number of virtual CPUs.
1812 prefix: Device name prefix.
1814 Raises:
1815 RuntimeError: If virtual CPUs are already configured at context
1816 initialization.
1817 """
1818 server_def = self._server_def or self._collective_ops_server_def
1819 local_prefix = ["/device"]
1820 if server_def is not None:
1821 local_prefix.append("/job:%s/replica:0/task:%d" % (server_def.job_name,
1822 server_def.task_index))
1823 logical_local_devices = [d for d in self.list_logical_devices("CPU") if
1824 d.name.startswith(tuple(local_prefix))]
1825 self.ensure_initialized()
1826 # Error out if there are already multiple logical CPUs in the context.
1827 if len(logical_local_devices) > 1:
1828 raise RuntimeError("Virtual CPUs already set, cannot modify again.")
1830 pywrap_tfe.TFE_SetLogicalCpuDevices(self._context_handle, num_cpus, prefix)
1831 self._initialize_logical_devices()
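# Illustrative sketch (not part of the original source): creating several
# logical CPU devices directly on the context, as multi-device tests commonly
# do. This raises RuntimeError if virtual CPUs were already configured.
def _example_logical_cpus():
  from tensorflow.python.eager import context as context_lib

  ctx = context_lib.context()
  ctx.set_logical_cpu_devices(4)
  print([d.name for d in ctx.list_logical_devices("CPU")])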
1833 def get_compiler_ir(
1834 self,
1835 device_name,
1836 function_name,
1837 flat_args,
1838 captured_inputs,
1839 stage="hlo",
1840 ):
1841 return pywrap_tfe.TF_GetCompilerIr(
1842 self._context_handle,
1843 function_name,
1844 stage,
1845 device_name,
1846 flat_args,
1847 captured_inputs,
1848 )
1850 @deprecated(
1851 None, "XLA:CPU and XLA:GPU devices are deprecated", warn_once=True)
1852 def enable_xla_devices(self):
1853 """Enables XLA:CPU and XLA:GPU devices registration."""
1854 pywrap_tfe.TF_EnableXlaDevices()
1856 @property
1857 def enable_mlir_bridge(self):
1858 return pywrap_tfe.TF_IsMlirBridgeEnabled()
1860 @property
1861 def enable_mlir_graph_optimization(self):
1862 return self._enable_mlir_graph_optimization
1864 @enable_mlir_bridge.setter
1865 def enable_mlir_bridge(self, enabled):
1866 pywrap_tfe.TF_EnableMlirBridge(enabled)
1867 self._thread_local_data.function_call_options = None
1869 @enable_mlir_graph_optimization.setter
1870 def enable_mlir_graph_optimization(self, enabled):
1871 self._enable_mlir_graph_optimization = enabled
1872 self._thread_local_data.function_call_options = None
1874 @property
1875 def optimizer_jit(self):
1876 level = self.config.graph_options.optimizer_options.global_jit_level
1877 return (level == config_pb2.OptimizerOptions.ON_1 or
1878 level == config_pb2.OptimizerOptions.ON_2)
1880 @optimizer_jit.setter
1881 def optimizer_jit(self, enabled):
1882 self._optimizer_jit = enabled
1884 self._thread_local_data.function_call_options = None
1886 def get_optimizer_experimental_options(self):
1887 """Get experimental options for the optimizer.
1889 Returns:
1890 Dictionary of current option values
1891 """
1892 rewrite_options = self.config.graph_options.rewrite_options
1893 options = {}
1895 def rewriter_toggle(option):
1896 attr = getattr(rewrite_options, option)
1897 if attr != 0:
1898 options[option] = (attr == rewriter_config_pb2.RewriterConfig.ON)
1900 def rewriter_bool(option):
1901 options[option] = getattr(rewrite_options, option)
1903 rewriter_toggle("layout_optimizer")
1904 rewriter_toggle("constant_folding")
1905 rewriter_toggle("shape_optimization")
1906 rewriter_toggle("remapping")
1907 rewriter_toggle("arithmetic_optimization")
1908 rewriter_toggle("dependency_optimization")
1909 rewriter_toggle("loop_optimization")
1910 rewriter_toggle("function_optimization")
1911 rewriter_toggle("debug_stripper")
1912 rewriter_bool("disable_model_pruning")
1913 rewriter_toggle("scoped_allocator_optimization")
1914 rewriter_toggle("pin_to_host_optimization")
1915 rewriter_toggle("implementation_selector")
1916 rewriter_toggle("auto_mixed_precision")
1917 rewriter_toggle("use_plugin_optimizers")
1918 rewriter_bool("disable_meta_optimizer")
1919 rewriter_toggle("auto_mixed_precision_onednn_bfloat16")
1920 rewriter_toggle("auto_mixed_precision_mkl")
1922 if rewrite_options.min_graph_nodes != 0:
1923 options["min_graph_nodes"] = rewrite_options.min_graph_nodes
1925 return options
1927 def set_optimizer_experimental_options(self, options):
1928 """Set experimental options for the optimizer.
1930 Args:
1931 options: Dictionary of options to modify
1932 """
1933 self._optimizer_experimental_options.update(options)
1935 self._thread_local_data.function_call_options = None
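# Illustrative sketch (not part of the original source): toggling grappler
# passes through the public `tf.config.optimizer` wrappers, which correspond to
# the getter/setter pair above. Options left unset are omitted from the
# returned dictionary.
def _example_optimizer_options():
  import tensorflow as tf

  tf.config.optimizer.set_experimental_options(
      {"layout_optimizer": True, "constant_folding": False})
  print(tf.config.optimizer.get_experimental_options())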
1937 @property
1938 def intra_op_parallelism_threads(self):
1939 return self.config.intra_op_parallelism_threads
1941 @intra_op_parallelism_threads.setter
1942 def intra_op_parallelism_threads(self, num_threads):
1943 if self._intra_op_parallelism_threads == num_threads:
1944 return
1946 if self._context_handle is not None:
1947 raise RuntimeError(
1948 "Intra op parallelism cannot be modified after initialization.")
1950 self._intra_op_parallelism_threads = num_threads
1952 @property
1953 def inter_op_parallelism_threads(self):
1954 return self.config.inter_op_parallelism_threads
1956 @inter_op_parallelism_threads.setter
1957 def inter_op_parallelism_threads(self, num_threads):
1958 if self._inter_op_parallelism_threads == num_threads:
1959 return
1961 if self._context_handle is not None:
1962 raise RuntimeError(
1963 "Inter op parallelism cannot be modified after initialization.")
1965 self._inter_op_parallelism_threads = num_threads
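# Illustrative sketch (not part of the original source): configuring the
# op-level thread pools with the public `tf.config.threading` wrappers. Like
# the setters above, both calls must happen before the context is initialized.
def _example_parallelism_threads():
  import tensorflow as tf

  tf.config.threading.set_intra_op_parallelism_threads(8)
  tf.config.threading.set_inter_op_parallelism_threads(2)
  print(tf.config.threading.get_inter_op_parallelism_threads())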
1967 @property
1968 def soft_device_placement(self):
1969 return self.config.allow_soft_placement
1971 @soft_device_placement.setter
1972 def soft_device_placement(self, enable):
1973 if self._context_handle is not None:
1974 pywrap_tfe.TFE_ContextSetSoftDevicePlacement(self._handle, enable)
1976 self._soft_device_placement = enable
1977 self._thread_local_data.function_call_options = None
1979 @property
1980 def log_device_placement(self):
1981 return self.config.log_device_placement
1983 @log_device_placement.setter
1984 def log_device_placement(self, enable):
1985 if self._context_handle is not None:
1986 pywrap_tfe.TFE_ContextSetLogDevicePlacement(self._handle, enable)
1988 self._log_device_placement = enable
1989 self._thread_local_data.function_call_options = None
1991 @property
1992 def jit_compile_rewrite(self):
1993 return self._jit_compile_rewrite
1995 @jit_compile_rewrite.setter
1996 def jit_compile_rewrite(self, enable):
1997 if self._context_handle is not None:
1998 pywrap_tfe.TFE_ContextSetJitCompileRewrite(self._handle, enable)
1999 self._jit_compile_rewrite = enable
2001 @property
2002 def device_policy(self):
2003 # Only get the policy from the context if it has already been initialized
2004 if self._context_handle is not None:
2005 return pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(self._handle)
2007 return self._device_policy
2009 @device_policy.setter
2010 def device_policy(self, policy):
2011 if policy is None:
2012 policy = DEVICE_PLACEMENT_SILENT
2014 if self._device_policy != policy:
2015 self._device_policy = policy
2017 # Only set the policy if the context has already been initialized
2018 if self._context_handle is not None:
2019 pywrap_tfe.TFE_ContextSetThreadLocalDevicePlacementPolicy(
2020 self._handle, self._device_policy)
2022 @property
2023 def use_tfrt(self):
2024 return self._use_tfrt
2026 @use_tfrt.setter
2027 def use_tfrt(self, tfrt):
2028 """Sets whether to use TFRT."""
2029 if not isinstance(tfrt, bool):
2030 raise ValueError("Expecting a boolean but got %s" % type(tfrt))
2032 if self._use_tfrt != tfrt:
2033 if self._initialized:
2034 raise ValueError("use_tfrt should be set before being initialized.")
2035 self._use_tfrt = tfrt
2037 @property
2038 def operation_timeout_in_ms(self):
2039 return self.config.operation_timeout_in_ms
2041 @operation_timeout_in_ms.setter
2042 def operation_timeout_in_ms(self, timeout_in_ms):
2043 if self._operation_timeout_in_ms == timeout_in_ms:
2044 return
2046 if self._context_handle is not None:
2047 raise RuntimeError(
2048 "Operation timeout cannot be modified after initialization.")
2050 self._operation_timeout_in_ms = timeout_in_ms
2052 def enable_run_metadata(self):
2053 """Enables tracing of op execution via RunMetadata.
2055 To retrieve the accumulated metadata call context.export_run_metadata()
2056 and to stop tracing call context.disable_run_metadata().
2057 """
2058 self.ensure_initialized()
2059 pywrap_tfe.TFE_ContextEnableRunMetadata(self._handle)
2061 def disable_run_metadata(self):
2062 """Disables tracing of op execution via RunMetadata."""
2063 if not self._context_handle:
2064 return
2065 pywrap_tfe.TFE_ContextDisableRunMetadata(self._context_handle)
2067 def enable_graph_collection(self):
2068 """Enables graph collection of executed functions.
2070 To retrieve the accumulated graphs call context.export_run_metadata()
2071 and to stop collecting graphs call context.disable_graph_collection().
2072 """
2073 self.ensure_initialized()
2074 pywrap_tfe.TFE_ContextEnableGraphCollection(self._handle)
2076 def disable_graph_collection(self):
2077 """Disables graph collection of executed functions."""
2078 if not self._context_handle:
2079 return
2080 pywrap_tfe.TFE_ContextDisableGraphCollection(self._context_handle)
2082 def export_run_metadata(self):
2083 """Returns a RunMetadata proto with accumulated information.
2085 The returned protocol buffer contains information since the most recent call
2086 to either enable_run_metadata or export_run_metadata.
2088 Returns:
2089 A RunMetadata protocol buffer, or None if not enabled.
2090 """
2091 if not self._context_handle:
2092 return None
2093 with c_api_util.tf_buffer() as buffer_:
2094 pywrap_tfe.TFE_ContextExportRunMetadata(self._context_handle, buffer_)
2095 proto_data = pywrap_tf_session.TF_GetBuffer(buffer_)
2096 run_metadata = config_pb2.RunMetadata()
2097 run_metadata.ParseFromString(compat.as_bytes(proto_data))
2098 return run_metadata
2100 @property
2101 def context_switches(self):
2102 """Returns a stack of context switches."""
2103 return self._context_switches
2106class _EagerDeviceContext(object):
2107 """Context-manager forcing placement of ops and Tensors on a device."""
2109 __slots__ = ["_device_name", "_ctx", "_stack"]
2111 def __init__(self, ctx, device_name):
2112 self._device_name = device_name
2113 self._ctx = ctx
2114 self._stack = []
2116 # TODO(b/189233748): Consolidate the device string parsing logic with
2117 # tensorflow/core/util/device_name_utils.cc.
2118 def __enter__(self):
2119 ctx = self._ctx
2120 old_device_name = ctx.device_name
2121 old_device_spec = ctx.device_spec
2122 new_device_name = self._device_name
2123 cache_key = (old_device_name, new_device_name)
2124 try:
2125 new_device_name, new_device_spec = _device_parsing_cache[cache_key]
2126 except TypeError:
2127 # Error while trying to compute the cache key.
2128 raise ValueError("Expecting a string device name. Got %s(%s)" %
2129 (type(new_device_name), new_device_name))
2130 except KeyError:
2131 # Handle a cache miss.
2132 if new_device_name is not None:
2133 if not isinstance(new_device_name, str):
2134 raise ValueError("Expecting a string device name. Got %s(%s)" %
2135 (type(new_device_name), new_device_name))
2136 device_spec = pydev.DeviceSpec.from_string(new_device_name)
2137 if old_device_name:
2138 new_device_spec = copy.copy(old_device_spec)
2139 else:
2140 ctx.ensure_initialized()
2141 new_device_spec = pydev.DeviceSpec.from_string(
2142 ctx._context_devices[0]) # pylint: disable=protected-access
2143 new_device_spec = new_device_spec.make_merged_spec(device_spec)
2144 else:
2145 new_device_spec = pydev.DeviceSpec.from_string("")
2146 new_device_name = new_device_spec.to_string()
2147 _device_parsing_cache[cache_key] = (new_device_name, new_device_spec)
2149 ctx._set_device(new_device_name, new_device_spec) # pylint: disable=protected-access
2150 self._stack.append((old_device_name, old_device_spec, new_device_spec))
2152 def __exit__(self, *ex_info):
2153 ctx = self._ctx
2154 old_device_name, old_device_spec, new_device_spec = self._stack[-1]
2155 if ctx.device_spec is not new_device_spec:
2156 raise RuntimeError("Exiting device scope without proper scope nesting")
2157 del self._stack[-1]
2158 ctx._set_device(old_device_name, old_device_spec) # pylint: disable=protected-access
2161# Do not modify `_context` directly; use the helper functions below.
2162_context = None
2163_context_lock = threading.Lock()
2166def _set_context_locked(ctx):
2167 global _context
2168 pywrap_tfe.TFE_Py_SetEagerContext(ctx)
2169 ctx.mark_as_global_context()
2170 _context = ctx
2173def _set_context(ctx):
2174 with _context_lock:
2175 _set_context_locked(ctx)
2178def _create_context():
2179 with _context_lock:
2180 if _context is None:
2181 ctx = Context()
2182 _set_context_locked(ctx)
2185def _reset_context():
2186 """Clears and re-initializes the singleton context.
2188 Should only be used for testing.
2189 """
2190 global _context
2191 global _device_parsing_cache
2193 # Garbage collect and clear scalar cache to avoid Tensors from the current
2194 # context polluting the next context.
2195 gc.collect()
2196 pywrap_tfe.TFE_ClearScalarCache()
2197 with _context_lock:
2198 if _context is not None:
2199 _context._clear_caches()
2200 _context = None
2201 _create_context()
2202 _device_parsing_cache = {}
2205def _reset_jit_compiler_flags():
2206 """Clears and re-initializes the TF JIT compiler flags.
2208 Should only be used for testing.
2209 """
2210 pywrap_tfe.TF_ResetJitCompilerFlags()
2213def context():
2214 """Returns a singleton context object."""
2215 if _context is None:
2216 _create_context()
2217 return _context
2220def context_safe():
2221 """Returns current context (or None if one hasn't been initialized)."""
2222 return _context
2225def ensure_initialized():
2226 """Initialize the context."""
2227 context().ensure_initialized()
2230def initialize_logical_devices():
2231 """Initialize the virtual devices."""
2232 context()._initialize_logical_devices() # pylint: disable=protected-access
2235def set_global_seed(seed):
2236 """Sets the eager mode seed."""
2237 context()._set_global_seed(seed) # pylint: disable=protected-access
2240def global_seed():
2241 """Returns the eager mode seed."""
2242 return context()._seed # pylint: disable=protected-access
2245def internal_operation_seed():
2246 """Returns the operation seed generated based on global seed."""
2247 return context()._internal_operation_seed() # pylint: disable=protected-access
2250@tf_export("executing_eagerly", v1=[])
2251def executing_eagerly():
2252 """Checks whether the current thread has eager execution enabled.
2254 Eager execution is enabled by default and this API returns `True`
2255 in most cases. However, this API might return `False` in the following use
2256 cases.
2258 * Executing inside `tf.function`, unless under `tf.init_scope` or
2259 `tf.config.run_functions_eagerly(True)` was previously called.
2260 * Executing inside a transformation function for `tf.data.Dataset`.
2261 * `tf.compat.v1.disable_eager_execution()` is called.
2263 General case:
2265 >>> print(tf.executing_eagerly())
2266 True
2268 Inside `tf.function`:
2270 >>> @tf.function
2271 ... def fn():
2272 ... with tf.init_scope():
2273 ... print(tf.executing_eagerly())
2274 ... print(tf.executing_eagerly())
2275 >>> fn()
2276 True
2277 False
2279 Inside `tf.function` after `tf.config.run_functions_eagerly(True)` is called:
2281 >>> tf.config.run_functions_eagerly(True)
2282 >>> @tf.function
2283 ... def fn():
2284 ... with tf.init_scope():
2285 ... print(tf.executing_eagerly())
2286 ... print(tf.executing_eagerly())
2287 >>> fn()
2288 True
2289 True
2290 >>> tf.config.run_functions_eagerly(False)
2292 Inside a transformation function for `tf.data.Dataset`:
2294 >>> def data_fn(x):
2295 ... print(tf.executing_eagerly())
2296 ... return x
2297 >>> dataset = tf.data.Dataset.range(100)
2298 >>> dataset = dataset.map(data_fn)
2299 False
2301 Returns:
2302 `True` if the current thread has eager execution enabled.
2303 """
2304 ctx = context_safe()
2305 if ctx is None:
2306 return default_execution_mode == EAGER_MODE
2308 return ctx.executing_eagerly()
2311@tf_export(v1=["executing_eagerly"])
2312def executing_eagerly_v1():
2313 """Checks whether the current thread has eager execution enabled.
2315 Eager execution is typically enabled via
2316 `tf.compat.v1.enable_eager_execution`, but may also be enabled within the
2317 context of a Python function via tf.contrib.eager.py_func.
2319 When eager execution is enabled, returns `True` in most cases. However,
2320 this API might return `False` in the following use cases.
2322 * Executing inside `tf.function`, unless under `tf.init_scope` or
2323 `tf.config.run_functions_eagerly(True)` was previously called.
2324 * Executing inside a transformation function for `tf.data.Dataset`.
2325 * `tf.compat.v1.disable_eager_execution()` is called.
2327 >>> tf.compat.v1.enable_eager_execution()
2329 General case:
2331 >>> print(tf.executing_eagerly())
2332 True
2334 Inside `tf.function`:
2336 >>> @tf.function
2337 ... def fn():
2338 ... with tf.init_scope():
2339 ... print(tf.executing_eagerly())
2340 ... print(tf.executing_eagerly())
2341 >>> fn()
2342 True
2343 False
2345 Inside `tf.function`
2346 after `tf.config.run_functions_eagerly(True)` is called:
2348 >>> tf.config.run_functions_eagerly(True)
2349 >>> @tf.function
2350 ... def fn():
2351 ... with tf.init_scope():
2352 ... print(tf.executing_eagerly())
2353 ... print(tf.executing_eagerly())
2354 >>> fn()
2355 True
2356 True
2357 >>> tf.config.run_functions_eagerly(False)
2359 Inside a transformation function for `tf.data.Dataset`:
2361 >>> def data_fn(x):
2362 ... print(tf.executing_eagerly())
2363 ... return x
2364 >>> dataset = tf.data.Dataset.range(100)
2365 >>> dataset = dataset.map(data_fn)
2366 False
2368 Returns:
2369 `True` if the current thread has eager execution enabled.
2370 """
2371 return executing_eagerly()
2374def in_eager_mode():
2375 """Use executing_eagerly() instead. This function will be removed."""
2376 return executing_eagerly()
2379def anonymous_name():
2380 """Returns the anonymous shared name.
2382 In eager mode we create anonymous resources to avoid spurious sharing issues.
2383 The runtime generates a unique name on our behalf when the reserved
2384 anonymous shared name is used as a shared name.
2386 Returns:
2387 The anonymous shared name.
2388 """
2390 # The magic value is defined as
2391 # `tensorflow::ResourceHandle::ANONYMOUS_NAME` in C++.
2392 return "cd2c89b7-88b7-44c8-ad83-06c2a9158347"
2395def graph_mode():
2396 """Context-manager to disable eager execution for the current thread."""
2397 return context()._mode(GRAPH_MODE) # pylint: disable=protected-access
2400# Used by b/167638505 for keras backend API and Lambda layer.
2401@tf_export("__internal__.eager_context.eager_mode", v1=[])
2402def eager_mode():
2403 """Context-manager to enable eager execution for the current thread."""
2404 return context()._mode(EAGER_MODE) # pylint: disable=protected-access
2407def scope_name():
2408 """Name of the current scope."""
2409 return context().scope_name
2412def device(name):
2413 """Context-manager to force placement of operations and Tensors on a device.
2415 Example:
2416 ```python
2417 with tf.device('gpu:0'):
2418   with tf.device('cpu:0'):
2419     shape = tf.constant([], dtype=tf.int32)
2420     x = tf.random.truncated_normal(shape, dtype=tf.float32)
2421 ```
2422 will ensure that the `shape` Tensor is on CPU but the `truncated_normal`
2423 operation runs on GPU 0.
2425 Args:
2426 name: Name of the device (see context().devices()), or None to perform
2427 automatic placement.
2429 Returns:
2430 Context manager for setting the device.
2431 """
2432 ensure_initialized()
2433 return context().device(name)
2436# Expose some properties of Context as internally public APIs (b/160348781).
2437@tf_export("__internal__.eager_context.get_config", v1=[])
2438def get_config():
2439 """Get the ConfigProto of Context.
2441 Returns:
2442 The ConfigProto of Context.
2443 """
2444 return context().config
2447@tf_export("__internal__.eager_context.get_device_name", v1=[])
2448def get_device_name():
2449 """Get the device name for the current thread.
2451 Returns:
2452 The device name for the current thread.
2453 """
2454 return context().device_name
2457@tf_export("__internal__.eager_context.set_soft_device_placement", v1=[])
2458def set_soft_device_placement(enabled):
2459 """Set if soft device placements should be allowed.
2461 Args:
2462 enabled: Whether to enable soft device placement.
2463 """
2464 context().soft_device_placement = enabled
2467@tf_export("__internal__.eager_context.get_executor", v1=[])
2468def get_executor():
2469 """Get the Executor of the current thread.
2471 Returns:
2472 The Executor of the current thread.
2473 """
2474 return context().executor
2477@tf_export("debugging.get_log_device_placement")
2478def get_log_device_placement():
2479 """Get if device placements are logged.
2481 Returns:
2482 If device placements are logged.
2483 """
2484 return context().log_device_placement
2487@tf_export("debugging.set_log_device_placement")
2488def set_log_device_placement(enabled):
2489 """Turns logging for device placement decisions on or off.
2491 Operations execute on a particular device, producing and consuming tensors on
2492 that device. This may change the performance of the operation or require
2493 TensorFlow to copy data to or from an accelerator, so knowing where operations
2494 execute is useful for debugging performance issues.
2496 For more advanced profiling, use the [TensorFlow
2497 profiler](https://www.tensorflow.org/guide/profiler).
2499 Device placement for operations is typically controlled by a `tf.device`
2500 scope, but there are exceptions, for example operations on a `tf.Variable`
2501 which follow the initial placement of the variable. Turning off soft device
2502 placement (with `tf.config.set_soft_device_placement`) provides more explicit
2503 control.
2505 >>> tf.debugging.set_log_device_placement(True)
2506 >>> tf.ones([])
2507 >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:GPU:0
2508 >>> with tf.device("CPU"):
2509 ... tf.ones([])
2510 >>> # [...] op Fill in device /job:localhost/replica:0/task:0/device:CPU:0
2511 >>> tf.debugging.set_log_device_placement(False)
2513 Turning on `tf.debugging.set_log_device_placement` also logs the placement of
2514 ops inside `tf.function` when the function is called.
2516 Args:
2517 enabled: Whether to enable device placement logging.
2518 """
2519 context().log_device_placement = enabled
2522@tf_contextlib.contextmanager
2523def device_policy(policy):
2524 """Context manager for setting device placement policy for current thread."""
2525 ctx = context()
2526 old_policy = ctx.device_policy
2527 try:
2528 ctx.device_policy = policy
2529 yield
2530 finally:
2531 ctx.device_policy = old_policy
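# Illustrative sketch (not part of the original source): temporarily switching
# the device placement policy for the current thread using the context manager
# above together with the module-level policy constants.
def _example_device_policy():
  from tensorflow.python.eager import context as context_lib

  with context_lib.device_policy(context_lib.DEVICE_PLACEMENT_WARN):
    # Ops executed here log a warning when an input tensor is copied across
    # devices, instead of failing or copying silently.
    pass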
2534def set_execution_mode(mode):
2535 """Sets execution mode for the current thread."""
2536 context().execution_mode = mode
2539# TODO(fishx): remove this method.
2540@tf_contextlib.contextmanager
2541def execution_mode(mode):
2542 """Context manager for setting execution mode for current thread."""
2543 if mode is None:
2544 yield
2545 else:
2546 ctx = context()
2547 executor_new = executor.new_executor(mode == ASYNC)
2548 executor_old = ctx.executor
2549 try:
2550 executor_old.wait()
2551 ctx.executor = executor_new
2552 yield
2553 finally:
2554 ctx.executor = executor_old
2555 executor_new.wait()
2558@tf_contextlib.contextmanager
2559def executor_scope(e):
2560 """Context manager for changing executor for current thread.
2562 Args:
2563 e: An Executor to execute eager ops under this scope. Setting it to None
2564 will switch back to using the default executor for the context.
2566 Yields:
2567 Context manager for setting the executor for current thread.
2568 """
2569 ctx = context()
2570 executor_old = ctx.executor
2571 try:
2572 ctx.executor = e
2573 yield
2574 finally:
2575 ctx.executor = executor_old
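# Illustrative sketch (not part of the original source): running a block of
# eager ops on a dedicated asynchronous executor via `executor_scope`, then
# waiting for it to drain. `new_executor` is called positionally, the same way
# `execution_mode` above calls it.
def _example_executor_scope():
  from tensorflow.python.eager import context as context_lib
  from tensorflow.python.eager import executor as executor_lib

  async_executor = executor_lib.new_executor(True)  # True => async executor.
  with context_lib.executor_scope(async_executor):
    pass  # Ops issued here are queued on async_executor.
  async_executor.wait()  # Block until all queued ops have finished.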
2578@tf_export("experimental.function_executor_type")
2579@tf_contextlib.contextmanager
2580def function_executor_type(executor_type):
2581 """Context manager for setting the executor of eager defined functions.
2583 Eager defined functions are functions decorated by tf.contrib.eager.defun.
2585 Args:
2586 executor_type: a string for the name of the executor to be used to execute
2587 functions defined by tf.contrib.eager.defun.
2589 Yields:
2590 Context manager for setting the executor of eager defined functions.
2591 """
2592 current_options = context().function_call_options
2593 old_options = copy.copy(current_options)
2594 try:
2595 current_options.executor_type = executor_type
2596 yield
2597 finally:
2598 context().function_call_options = old_options
2601def is_async():
2602 """Returns true if current thread is in async mode."""
2603 return context().is_async()
2606def num_gpus():
2607 """Get the number of available GPU devices.
2609 Returns:
2610 The number of available GPU devices.
2611 """
2612 return context().num_gpus()
2615def enable_run_metadata():
2616 """Enables tracing of op execution via RunMetadata.
2618 To retrieve the accumulated metadata call context.export_run_metadata()
2619 and to stop tracing call context.disable_run_metadata().
2620 """
2621 context().enable_run_metadata()
2624def disable_run_metadata():
2625 """Disables tracing of op execution via RunMetadata."""
2626 context().disable_run_metadata()
2629def enable_graph_collection():
2630 """Enables graph collection of executed functions.
2632 To retrieve the accumulated graphs call context.export_run_metadata()
2633 and to stop collecting graphs call context.disable_graph_collection().
2634 """
2635 context().enable_graph_collection()
2638def disable_graph_collection():
2639 """Disables graph collection of executed functions."""
2640 context().disable_graph_collection()
2643def export_run_metadata():
2644 """Returns a RunMetadata proto with accumulated information.
2646 The returned protocol buffer contains information since the most recent call
2647 to either enable_run_metadata or export_run_metadata.
2649 Returns:
2650 A RunMetadata protocol buffer.
2651 """
2652 return context().export_run_metadata()
2655@contextlib.contextmanager
2656def collect_graphs(optimized=True):
2657 """Collects a flat list of pre- or post-optimization graphs.
2659 The collected graphs include device placements, which can be useful for
2660 testing.
2662 Usage:
2664 ```
2665 @def_function.function
2666 def f(x):
2667 return x + constant_op.constant(1.)
2669 with context.collect_graphs() as graphs:
2670 with ops.device("CPU:0"):
2671 f(constant_op.constant(1.))
2673 graph, = graphs # `graph` contains a single GraphDef for inspection
2674 ```
2676 Args:
2677 optimized: whether to collect optimized graphs or non-optimized graphs
2679 Yields:
2680 A list of GraphDefs, populated when the context manager exits.
2681 """
2682 ctx = context()
2683 ctx.enable_graph_collection()
2684 try:
2685 graphs = []
2686 yield graphs
2687 metadata = ctx.export_run_metadata()
2688 finally:
2689 ctx.disable_graph_collection()
2690 for graph in metadata.function_graphs:
2691 if optimized:
2692 graphs.append(graph.post_optimization_graph)
2693 else:
2694 graphs.append(graph.pre_optimization_graph)
2697def get_server_def():
2698 return context().get_server_def()
2701def set_server_def(server_def):
2702 context().set_server_def(server_def)
2705def update_server_def(server_def):
2706 context().update_server_def(server_def)
2709def check_alive(worker_name):
2710 return context().check_alive(worker_name)
2713@tf_export("experimental.async_scope")
2714@tf_contextlib.contextmanager
2715def async_scope():
2716 """Context manager for grouping async operations.
2718 Ops/function calls inside the scope can return before finishing the actual
2719 execution. When exiting the async scope, a synchronization barrier will be
2720 automatically added to ensure the completion of all async op and function
2721 execution, potentially raising exceptions if async execution results in
2722 an error state.
2724 Users may write the following code to asynchronously invoke `train_step_fn`
2725 and log the `loss` metric for every `num_steps` steps in a training loop.
2726 `train_step_fn` internally consumes data using `iterator.get_next()`, and may
2727 throw OutOfRangeError when running out of data. In that case:
2729 ```
2730 try:
2731 with tf.experimental.async_scope():
2732 for _ in range(num_steps):
2733 # Step function updates the metric `loss` internally
2734 train_step_fn()
2735 except tf.errors.OutOfRangeError:
2736 tf.experimental.async_clear_error()
2737 logging.info('loss = %s', loss.numpy())
2738 ```
2740 Yields:
2741 Context manager for grouping async operations.
2742 """
2743 # TODO(haoyuzhang): replace env var once we have a config method to turn on
2744 # and off async streaming RPC
2745 remote_async_env_var = "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
2746 old_policy = os.environ.get(remote_async_env_var)
2747 try:
2748 os.environ[remote_async_env_var] = str(True)
2749 yield
2750 # Note: sync local and remote executors iff the async block does not raise
2751 # an exception. Triggering sync after an exception may lead to derived
2752 # runtime errors and unexpected exception types.
2753 context().sync_executors()
2754 finally:
2755 if old_policy is None:
2756 del os.environ[remote_async_env_var]
2757 else:
2758 os.environ[remote_async_env_var] = old_policy
2761def async_wait():
2762 """Sync all async operations and raise any errors during execution.
2764 In async execution mode, an op/function call can return before finishing the
2765 actual execution. Calling this method creates a synchronization barrier for
2766 all async op and function execution. It only returns when all pending nodes
2767 are finished, potentially raising exceptions if async execution results in
2768 an error state. It is a no-op if the context is not initialized.
2769 """
2770 disable_async_executor_env_var = "TF_PS_DISABLE_ASYNC_EXECUTOR_GLOBALLY"
2771 if os.environ.get(disable_async_executor_env_var) == str(True):
2772 return
2773 if context()._context_handle is not None: # pylint: disable=protected-access
2774 context().sync_executors()
2777@tf_export("experimental.async_clear_error")
2778def async_clear_error():
2779 """Clear pending operations and error statuses in async execution.
2781 In async execution mode, an error in op/function execution can lead to errors
2782 in subsequent ops/functions that are scheduled but not yet executed. Calling
2783 this method clears all pending operations and resets the async execution state.
2785 Example:
2787 ```
2788 while True:
2789 try:
2790 # Step function updates the metric `loss` internally
2791 train_step_fn()
2792 except tf.errors.OutOfRangeError:
2793 tf.experimental.async_clear_error()
2794 break
2795 logging.info('loss = %s', loss.numpy())
2796 ```
2797 """
2798 context().clear_executor_errors()
2801def add_c_function(c_func):
2802 """Add a C API TF_Function to the context."""
2803 context().add_c_function(c_func)
2806def get_c_function(name):
2807 """Get a C API TF_Function from the context."""
2808 return context().get_c_function(name)
2811def remove_function(name):
2812 """Remove a function from the context."""
2813 context().remove_function(name)
2816def get_function_def(name):
2817 return context().get_function_def(name)
2820def is_custom_device(device_name):
2821 """Calls TFE_IsCustomDevice.
2823 Enables using C extensions specifying a custom device from Python. See the
2824 experimental eager C API in tensorflow/c/eager/c_api_experimental.h for
2825 details.
2827 Args:
2828 device_name: A string indicating the name to check whether it is a
2829 registered custom device.
2831 Returns:
2832 A boolean.
2833 """
2834 return context().is_custom_device(device_name)
2837def register_custom_device(device_capsule, device_name, device_info_capsule):
2838 """Calls TFE_RegisterCustomDevice to register a custom device with Python.
2840 Enables using C extensions specifying a custom device from Python. See the
2841 experimental eager C API in tensorflow/c/eager/c_api_experimental.h for
2842 details.
2844 Note that custom devices are not currently supported inside `tf.function`s.
2846 Args:
2847 device_capsule: A PyCapsule with the name set to 'TFE_CustomDevice'
2848 containing a pointer to a TFE_CustomDevice struct. The capsule retains
2849 ownership of the memory.
2850 device_name: A string indicating the name to register the custom device
2851 under, e.g. '/job:localhost/replica:0/task:0/device:CUSTOM:0'. It may
2852 subsequently be passed to `with tf.device(...):`.
2853 device_info_capsule: A PyCapsule with the name set to
2854 'TFE_CustomDevice_DeviceInfo' containing a pointer to a device-specific
2855 struct with the initial state of the custom device (the void* device_info
2856 argument to TFE_RegisterCustomDevice). This method takes ownership of the
2857 memory and clears the capsule destructor.
2858 """
2859 context().register_custom_device(device_capsule, device_name,
2860 device_info_capsule)
2863# Not every user creates a Context via context.context()
2864# (for example, enable_eager_execution in python/framework/ops.py),
2865# but they do all import this file. Note that IS_IN_GRAPH_MODE and
2866# in_graph_mode are both parameterless functions.
2867def _tmp_in_graph_mode():
2868 if context_safe() is None:
2869 # Context not yet initialized. Assume graph mode following the
2870 # default implementation in `is_in_graph_mode`.
2871 return True
2872 return not executing_eagerly()
2875is_in_graph_mode.IS_IN_GRAPH_MODE = _tmp_in_graph_mode