Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/client/session.py: 21%
624 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A client interface for TensorFlow."""
17import collections
18import functools
19import re
20import threading
21import warnings
23import numpy as np
24import wrapt
26from tensorflow.core.protobuf import config_pb2
27from tensorflow.core.protobuf import rewriter_config_pb2
28from tensorflow.python.client import pywrap_tf_session as tf_session
29from tensorflow.python.eager import context
30from tensorflow.python.eager import monitoring
31from tensorflow.python.framework import device
32from tensorflow.python.framework import error_interpolation
33from tensorflow.python.framework import errors
34from tensorflow.python.framework import indexed_slices
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import sparse_tensor
37from tensorflow.python.framework import stack
38from tensorflow.python.ops import session_ops
39from tensorflow.python.platform import tf_logging as logging
40from tensorflow.python.training.experimental import mixed_precision_global_state
41from tensorflow.python.util import compat
42from tensorflow.python.util import nest
43from tensorflow.python.util.compat import collections_abc
44from tensorflow.python.util.tf_export import tf_export
# Module-level monitoring counter. `BaseSession.__init__` bumps it by one each
# time a session object is constructed from Python.
_python_session_create_counter = monitoring.Counter(
    '/tensorflow/api/python/session_create_counter',
    'Counter for number of sessions created in Python.')
class SessionInterface(object):
  """Abstract contract shared by all TensorFlow client session classes.

  Nothing is implemented here: every member raises `NotImplementedError`
  whose message is the member's own name, so concrete sessions must override
  all of them.
  """

  @property
  def graph(self):
    """The TensorFlow graph in which this session's Operations are built."""
    raise NotImplementedError('graph')

  @property
  def sess_str(self):
    """The TensorFlow process that this session connects to."""
    raise NotImplementedError('sess_str')

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Executes operations in the session; see `BaseSession.run()`."""
    raise NotImplementedError('run')

  def partial_run_setup(self, fetches, feeds=None):
    """Declares the feeds and fetches that later partial runs will use."""
    raise NotImplementedError('partial_run_setup')

  def partial_run(self, handle, fetches, feed_dict=None):
    """Resumes a partial run with further feeds and fetches."""
    raise NotImplementedError('partial_run')
def _get_indexed_slices_value_from_fetches(fetched_vals):
  """Packs fetched components back into an `IndexedSlicesValue`.

  Args:
    fetched_vals: Sequence laid out as (values, indices) or
      (values, indices, dense_shape).

  Returns:
    An `IndexedSlicesValue`; its dense shape is None unless exactly three
    components were fetched.
  """
  values, indices = fetched_vals[0], fetched_vals[1]
  if len(fetched_vals) == 3:
    dense_shape = fetched_vals[2]
  else:
    dense_shape = None
  return indexed_slices.IndexedSlicesValue(values, indices, dense_shape)
def _get_feeds_for_indexed_slices(feed, feed_val):
  """Pairs the component tensors of an IndexedSlices feed with fed values.

  Args:
    feed: Object exposing `values`, `indices` and `dense_shape`.
    feed_val: Iterable of values in (values, indices[, dense_shape]) order.

  Returns:
    A list of (component, value) tuples. `dense_shape` only takes part when
    it is not None on `feed`.
  """
  if feed.dense_shape is None:
    components = [feed.values, feed.indices]
  else:
    components = [feed.values, feed.indices, feed.dense_shape]
  return list(zip(components, feed_val))
89# List of extensions supported to convert run arguments into actual fetches and
90# feeds.
91#
92# Each element in the list is a tuple of (Type, fetch_fn, feed_fn1, feed_fn2),
93# where the function signatures are:
94# fetch_fn : Type -> (list of Tensors,
95# lambda: list of fetched np.ndarray -> TypeVal)
96# feed_fn1 : Type, TypeVal -> list of (Tensor, value)
97# feed_fn2 : Type -> list of Tensors
98#
99# `fetch_fn` describes how to expand fetch into its
100# component Tensors and how to contract the fetched results back into
101# a single return value.
102#
103# Each feed function describes how to unpack a single fed value and map it to
104# feeds of one or more tensors and their corresponding values: `feed_fn1` is
105# used to feed a run, `feed_fn2` to set up a partial run.
106#
107# TODO(touts): We could reimplement these as specialized _FeedMapper
108# implementations after we refactor the feed handling code to use them.
109#
110# Eventually, this registration could be opened up to support custom Tensor
111# expansions.
112# pylint: disable=g-long-lambda
def _expand_sparse_tensor_fetch(fetch):
  """Returns the SparseTensor components and a SparseTensorValue builder."""
  components = [fetch.indices, fetch.values, fetch.dense_shape]

  def _contract(fetched_vals):
    return sparse_tensor.SparseTensorValue(*fetched_vals)

  return components, _contract


def _sparse_tensor_run_feeds(feed, feed_val):
  """Pairs each SparseTensor component with the matching fed value."""
  return list(zip([feed.indices, feed.values, feed.dense_shape], feed_val))


def _sparse_tensor_partial_run_feeds(feed):
  """Lists the SparseTensor components to declare for a partial run."""
  return [feed.indices, feed.values, feed.dense_shape]


def _indexed_slices_components(slices):
  """Lists IndexedSlices components, omitting a dense_shape that is None."""
  if slices.dense_shape is None:
    return [slices.values, slices.indices]
  return [slices.values, slices.indices, slices.dense_shape]


def _expand_indexed_slices_fetch(fetch):
  """Returns the IndexedSlices components and an IndexedSlicesValue builder."""
  return (_indexed_slices_components(fetch),
          _get_indexed_slices_value_from_fetches)


def _expand_plain_fetch(fetch):
  """Wraps any other fetch as a single component, unwrapped on return."""

  def _contract(fetched_vals):
    return fetched_vals[0]

  return [fetch], _contract


def _plain_run_feeds(feed, feed_val):
  """Feeds any other key unchanged."""
  return [(feed, feed_val)]


def _plain_partial_run_feeds(feed):
  """Declares any other key unchanged for a partial run."""
  return [feed]


_REGISTERED_EXPANSIONS = [
    # SparseTensors are fetched as SparseTensorValues. They can be fed
    # SparseTensorValues or normal tuples.
    (sparse_tensor.SparseTensor,
     _expand_sparse_tensor_fetch,
     _sparse_tensor_run_feeds,
     _sparse_tensor_partial_run_feeds),
    # IndexedSlices are fetched as IndexedSlicesValues. They can be fed
    # IndexedSlicesValues or normal tuples.
    (indexed_slices.IndexedSlices,
     _expand_indexed_slices_fetch,
     _get_feeds_for_indexed_slices,
     _indexed_slices_components),
    # The default catches all other types and performs no expansions. It must
    # stay last because every object is an instance of `object`.
    (object,
     _expand_plain_fetch,
     _plain_run_feeds,
     _plain_partial_run_feeds),
]
136# pylint: enable=g-long-lambda
def _convert_to_numpy_obj(numpy_dtype, obj):
  """Casts `obj` with `numpy_dtype`, or with `str` for the object dtype.

  Args:
    numpy_dtype: A callable type used as the converter, or the builtin
      `object` to request string conversion.
    obj: The value to convert.

  Returns:
    `str(obj)` when `numpy_dtype` is `object`, otherwise `numpy_dtype(obj)`.
  """
  if numpy_dtype is object:
    return str(obj)
  return numpy_dtype(obj)
def register_session_run_conversion_functions(
    tensor_type,
    fetch_function,
    feed_function=None,
    feed_function_for_partial_run=None):
  """Registers how `tf.Session.run()` fetches and feeds a user-defined type.

  The three callables are stored together with `tensor_type` at the front of
  the expansion registry, so they are consulted before the built-in entries.

  Example:

  ```python
     class SquaredTensor(object):
       def __init__(self, tensor):
         self.sq = tf.square(tensor)
     # Conversion functions can be defined like this:
     fetch_function = lambda squared_tensor:([squared_tensor.sq],
                                             lambda val: val[0])
     feed_function = lambda feed, feed_val: [(feed.sq, feed_val)]
     feed_function_for_partial_run = lambda feed: [feed.sq]
     # Once registered, instances can be used directly:
     session.run(squared_tensor1,
                 feed_dict = {squared_tensor2 : some_numpy_array})
  ```

  Args:
    tensor_type: The type whose instances the conversion functions handle.
    fetch_function: Callable taking a `tensor_type` object and returning a
      tuple `(tensors, contraction_fn)`: `tensors` is a list of `tf.Tensor`
      objects to fetch, and `contraction_fn` turns the list of fetched
      ndarrays back into a single value corresponding to `tensor_type`.
    feed_function: Callable taking `(feed_key, feed_value)`, where `feed_key`
      has type `tensor_type`, and returning a list of
      `(feed_tensor, feed_val)` tuples with `feed_tensor` a `tf.Tensor`.
    feed_function_for_partial_run: Callable taking a `tensor_type` object and
      returning the list of Tensors to declare as feeds when setting up a
      partial run.

  Raises:
    ValueError: If `tensor_type` has already been registered.
  """
  already_registered = any(
      issubclass(expansion[0], tensor_type)
      for expansion in _REGISTERED_EXPANSIONS)
  if already_registered:
    raise ValueError(
        f'{tensor_type} has already been registered so ignore it.')

  new_entry = (tensor_type, fetch_function, feed_function,
               feed_function_for_partial_run)
  _REGISTERED_EXPANSIONS.insert(0, new_entry)
def _is_attrs_instance(obj):
  """Tells whether `obj`'s class carries a non-None `__attrs_attrs__`."""
  cls = obj.__class__
  marker = getattr(cls, '__attrs_attrs__', None)
  return marker is not None
def _get_attrs_values(obj):
  """Collects the field values of an attrs instance in declaration order.

  Args:
    obj: Instance whose class defines `__attrs_attrs__`, an iterable of
      objects that each expose a `name`.

  Returns:
    A list holding `getattr(obj, field.name)` for every declared field.
  """
  field_values = []
  for field in obj.__class__.__attrs_attrs__:
    field_values.append(getattr(obj, field.name))
  return field_values
class _FetchMapper(object):
  """Interface implemented by all fetch mappers.

  `_FetchHandler` relies on fetch mappers to cope with the arbitrary
  structures accepted as the `fetch` argument of `Session.run()`: a single
  tensor or op, or lists, tuples, namedtuples and dicts of fetches, nested to
  any depth.

  The low level run() API only accepts a flat list of tensors or ops. Each
  subclass therefore flattens and de-duplicates the fetches of one kind of
  structure, and later rebuilds results with the structure's original shape.
  """

  def unique_fetches(self):
    """Return the list of unique tensors or ops needed by this fetch mapper.

    Returns:
      A list of tensors or ops.
    """
    raise NotImplementedError(
        'unique_fetches must be implemented by subclasses')

  def build_results(self, values):
    """Build results that match the original shape of the fetch.

    Args:
      values: List of values returned by run(), aligned one-to-one with the
        list returned by unique_fetches().

    Returns:
      A struct shaped like the original fetch object, with every fetch
      replaced by its fetched value.
    """
    raise NotImplementedError('build_results must be implemented by subclasses')

  @staticmethod
  def for_fetch(fetch):
    """Creates the fetch mapper matching the structure of `fetch`.

    The default graph must be the one from which values are to be fetched
    when this function is called.

    Args:
      fetch: An arbitrary fetch structure: singleton, list, tuple, namedtuple,
        or dict.

    Returns:
      An instance of a subclass of `_FetchMapper` that handles the shape.

    Raises:
      TypeError: If `fetch` is None or no registered expansion accepts it.
    """
    if fetch is None:
      raise TypeError(f'Argument `fetch` = {fetch} has invalid type '
                      f'"{type(fetch).__name__}". Cannot be None')

    # Containers first. Namedtuples are tuples, so they land here as well.
    if isinstance(fetch, (list, tuple)):
      return _ListFetchMapper(fetch)
    if isinstance(fetch, collections_abc.Mapping):
      return _DictFetchMapper(fetch)
    if _is_attrs_instance(fetch):
      return _AttrsFetchMapper(fetch)

    # Otherwise this is a leaf: use the first registered expansion whose type
    # matches.
    for registered_type, expand_fetch, _, _ in _REGISTERED_EXPANSIONS:
      if isinstance(fetch, registered_type):
        leaf_fetches, contraction_fn = expand_fetch(fetch)
        return _ElementFetchMapper(leaf_fetches, contraction_fn)

    # No expansion accepted the value.
    raise TypeError(f'Argument `fetch` = {fetch} has invalid type '
                    f'"{type(fetch).__name__}"')
class _ElementFetchMapper(_FetchMapper):
  """Fetch mapper for singleton tensors and ops."""

  def __init__(self, fetches, contraction_fn):
    """Creates an _ElementFetchMapper.

    This mapper sits at the leaves of a fetch struct. Because of the
    expansion mechanism, one leaf may stand for several tensors.

    The entries of `fetches` may be strings (tensor or op names) or any other
    object the graph can convert, such as a Variable, so each one is resolved
    through the default graph's `as_graph_element()`.

    Args:
      fetches: List of objects, as returned by a fetch_fn defined in
        _REGISTERED_EXPANSIONS.
      contraction_fn: Callable as returned by a fetch_fn.

    Raises:
      TypeError: If a fetch has a type the graph cannot convert.
      ValueError: If a fetch cannot be interpreted as a graph element (the
        graph reported a ValueError or a KeyError).
    """
    self._unique_fetches = []
    for fetch in fetches:
      try:
        graph_element = ops.get_default_graph().as_graph_element(
            fetch, allow_tensor=True, allow_operation=True)
      except TypeError as e:
        raise TypeError(f'Argument `fetch` = {fetch} has invalid type '
                        f'"{type(fetch).__name__}" must be a string or Tensor. '
                        f'({str(e)})')
      except (ValueError, KeyError) as e:
        # Both lookup failures are reported to the caller as a ValueError.
        raise ValueError(f'Argument `fetch` = {fetch} cannot be interpreted as '
                         f'a Tensor. ({str(e)})')
      else:
        self._unique_fetches.append(graph_element)
    self._contraction_fn = contraction_fn

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # An empty `values` is the 'Operation' case: there is nothing to return.
    return self._contraction_fn(values) if values else None
def _uniquify_fetches(fetch_mappers):
  """Merges the fetches of several mappers into one duplicate-free list.

  Used by the container fetch mappers. Fetches are de-duplicated by object
  identity (`id()`), keeping the order in which they are first seen.

  Besides the merged list, a 2-D list of integers is returned such that
    value_indices[mapper_index][mapper_fetch_index] = unique_fetches_index

  Args:
    fetch_mappers: list of fetch mappers.

  Returns:
    A list of fetches.
    A 2-D list of integers.
  """
  position_by_id = {}
  unique_fetches = []
  value_indices = []
  for mapper in fetch_mappers:
    mapper_positions = []
    for fetch in mapper.unique_fetches():
      key = id(fetch)
      if key not in position_by_id:
        # First sighting: it takes the next free slot of the merged list.
        position_by_id[key] = len(unique_fetches)
        unique_fetches.append(fetch)
      mapper_positions.append(position_by_id[key])
    value_indices.append(mapper_positions)
  return unique_fetches, value_indices
class _ListFetchMapper(_FetchMapper):
  """Fetch mapper for lists, tuples, and namedtuples."""

  def __init__(self, fetches):
    """Creates a _ListFetchMapper.

    Args:
      fetches: List, tuple, or namedtuple of fetches.
    """
    # For a wrapt proxy, remember the type of the wrapped container.
    is_proxy = isinstance(fetches, wrapt.ObjectProxy)
    container = fetches.__wrapped__ if is_proxy else fetches
    self._fetch_type = type(container)
    self._mappers = [_FetchMapper.for_fetch(item) for item in fetches]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Each child mapper receives only the values it asked for.
    results = [
        mapper.build_results([values[index] for index in indices])
        for mapper, indices in zip(self._mappers, self._value_indices)
    ]
    # Hand the results back in the container type of the original fetches.
    if issubclass(self._fetch_type, list):
      return results
    if self._fetch_type == tuple:
      return tuple(results)
    # Remaining case: namedtuple, built from positional fields.
    return self._fetch_type(*results)
class _DictFetchMapper(_FetchMapper):
  """Fetch mapper for dicts."""

  def __init__(self, fetches):
    """Creates a _DictFetchMapper.

    Args:
      fetches: Dict of fetches.
    """
    self._fetch_type = type(fetches)
    if isinstance(fetches, collections.defaultdict):
      # A defaultdict takes its factory as the first constructor argument, so
      # bind it to rebuild a result with the same default behavior.
      self._type_ctor = functools.partial(collections.defaultdict,
                                          fetches.default_factory)
    else:
      self._type_ctor = self._fetch_type

    # Snapshot the keys. `keys()` returns a live view for dicts, whereas the
    # mappers below are a fixed list; keeping the view would misalign keys and
    # results if the caller mutates `fetches` before build_results() runs.
    self._keys = list(fetches.keys())
    self._mappers = [
        _FetchMapper.for_fetch(fetch) for fetch in fetches.values()
    ]
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    """Rebuilds a mapping of the original type from the fetched `values`."""

    def _generator():
      for k, m, vi in zip(self._keys, self._mappers, self._value_indices):
        yield k, m.build_results([values[j] for j in vi])

    return self._type_ctor(_generator())
class _AttrsFetchMapper(_FetchMapper):
  """Fetch mapper for attrs decorated classes."""

  def __init__(self, fetches):
    """Creates a _AttrsFetchMapper.

    Args:
      fetches: An instance of an attrs decorated class.
    """
    field_values = _get_attrs_values(fetches)
    self._fetch_type = type(fetches)
    self._mappers = []
    for field_value in field_values:
      self._mappers.append(_FetchMapper.for_fetch(field_value))
    self._unique_fetches, self._value_indices = _uniquify_fetches(self._mappers)

  def unique_fetches(self):
    return self._unique_fetches

  def build_results(self, values):
    # Rebuild each field, then pass them positionally to the attrs class.
    rebuilt_fields = [
        mapper.build_results([values[index] for index in indices])
        for mapper, indices in zip(self._mappers, self._value_indices)
    ]
    return self._fetch_type(*rebuilt_fields)
class _FetchHandler(object):
  """Handler for structured fetches.

  Given a graph, a user-provided structure for fetches, and a feed dict, this
  class takes care of generating the flat lists of tensors to fetch and ops
  to run for a low level `run()` call.

  Given the results of the low level run call, this class can also rebuild a
  result structure matching the user-provided structure for fetches, but
  containing the corresponding results.
  """

  # TODO(touts): Make this class also take care of destructuring the feed
  # dict instead of doing it in the callers.

  def __init__(self, graph, fetches, feeds, feed_handles=None):
    """Creates a fetch handler.

    Args:
      graph: Graph of the fetches.   Used to check for fetchability and to
        convert all fetches to tensors or ops as needed.
      fetches: An arbitrary fetch structure: singleton, list, tuple, namedtuple,
        or dict.
      feeds: A feed dict. It is looked up with `fetch.ref()` keys below.
      feed_handles: A dict of TensorHandle objects used as direct feeds. It
        is looked up with `fetch.ref()` keys as well.

    Raises:
      errors.InaccessibleTensorError: If a fetch belongs to an op the graph
        reports as not fetchable.
    """
    # The mapper resolves fetches against the default graph, so make `graph`
    # the default while it is built.
    with graph.as_default():
      self._fetch_mapper = _FetchMapper.for_fetch(fetches)
    self._fetches = []  # Non-op fetches, in unique_fetches() order.
    self._targets = []  # Operation fetches, in unique_fetches() order.
    self._feeds = feeds
    self._feed_handles = feed_handles or {}
    # One bool per unique fetch: True for an Operation, False otherwise.
    self._ops = []
    # Maps fetch.ref() -> dtype for fetches produced by session-handle ops.
    self._fetch_handles = {}
    for fetch in self._fetch_mapper.unique_fetches():
      if isinstance(fetch, ops.Operation):
        self._assert_fetchable(graph, fetch)
        self._targets.append(fetch)
        self._ops.append(True)
      else:
        self._assert_fetchable(graph, fetch.op)
        self._fetches.append(fetch)
        self._ops.append(False)
      # Remember the fetch if it is for a tensor handle.
      if (isinstance(fetch, ops.Tensor) and
          (fetch.op.type == 'GetSessionHandle' or
           fetch.op.type == 'GetSessionHandleV2')):
        self._fetch_handles[fetch.ref()] = fetch.op.inputs[0].dtype
    # Fetches that are also fed are not requested from the runtime; their fed
    # value is substituted in build_results().
    self._final_fetches = [x for x in self._fetches if x.ref() not in feeds]

  def _assert_fetchable(self, graph, op):
    """Raises InaccessibleTensorError if `graph` says `op` is not fetchable."""
    if not graph.is_fetchable(op):
      raise errors.InaccessibleTensorError(
          f'Operation {op.name} has been marked as not fetchable. Typically '
          'this happens when it is defined in another function or code block. '
          'Use return values, explicit Python locals or TensorFlow collections '
          'to access it.')

  def fetches(self):
    """Return the unique tensors to fetch.

    Returns:
      A list of the non-op fetch objects whose `ref()` is not a key of the
      feed dict. These are the graph elements themselves, not their names.
    """
    return self._final_fetches

  def targets(self):
    """Return the unique ops to run.

    Returns:
      A list of the `Operation` fetches (the objects themselves, not names).
    """
    return self._targets

  def build_results(self, session, tensor_values):
    """Build results matching the original fetch shape.

    `tensor_values` must be a list of the same length as
    the one returned by `fetches()`, and holding the requested
    fetch values.

    This method builds a struct with the same shape as the original `fetches`
    passed to the constructor, in which the fetches are replaced by their
    fetched value.

    Args:
      session: The enclosing session.  Used for tensor handles.
      tensor_values: List of values matching the list returned by fetches().

    Returns:
      A structure of the same shape as the original `fetches` argument but
      containing tensors or None (for fetched ops).
    """
    full_values = []
    assert len(self._final_fetches) == len(tensor_values)
    i = 0  # Index into self._fetches (non-op fetches only).
    j = 0  # Index of the next unconsumed entry of tensor_values.
    for is_op in self._ops:
      if is_op:
        # Operations produce no value.
        full_values.append(None)
      else:
        # If the fetch was in the feeds, use the fed value, otherwise
        # use the returned value.
        if self._fetches[i].ref() in self._feed_handles:
          # A fetch had a corresponding direct TensorHandle feed. Call eval()
          # to obtain the Tensor value from the TensorHandle.
          value = self._feed_handles[self._fetches[i].ref()].eval()
        else:
          value = self._feeds.get(self._fetches[i].ref())
        if value is None:
          # Not fed: consume the next value returned by the runtime.
          value = tensor_values[j]
          j += 1
        dtype = self._fetch_handles.get(self._fetches[i].ref())
        if dtype:
          # Session-handle fetch: wrap the raw value in a TensorHandle.
          full_values.append(session_ops.TensorHandle(value, dtype, session))
        else:
          full_values.append(value)
        i += 1
    # Every value returned by the runtime must have been consumed.
    assert j == len(tensor_values)
    return self._fetch_mapper.build_results(full_values)
def _name_list(tensor_list):
  """Utility function for transitioning to the new session API.

  Args:
    tensor_list: an iterable of objects exposing a `name`, e.g. `Tensor`s.

  Returns:
    A list with each name converted by `compat.as_bytes`.
  """
  encoded_names = []
  for tensor in tensor_list:
    encoded_names.append(compat.as_bytes(tensor.name))
  return encoded_names
class _DeviceAttributes(object):
  """Read-only record describing one device.

  Four properties are exposed:
   - name: the device name, canonicalized with `device.canonical_name`. For
           example: /job:worker/replica:0/task:3/device:CPU:0
   - device_type: the type of the device (e.g. CPU, GPU, TPU, etc.)
   - memory_limit_bytes: the maximum amount of memory available on the device
           (in bytes).
   - incarnation: the incarnation value supplied by the caller.
  """

  def __init__(self, name, device_type, memory_limit_bytes, incarnation):
    self._name = device.canonical_name(name)
    self._device_type = device_type
    self._memory_limit_bytes = memory_limit_bytes
    self._incarnation = incarnation

  @property
  def name(self):
    return self._name

  @property
  def device_type(self):
    return self._device_type

  @property
  def memory_limit_bytes(self):
    return self._memory_limit_bytes

  @property
  def incarnation(self):
    return self._incarnation

  def __repr__(self):
    fields = (
        self.name,
        self.device_type,
        self.memory_limit_bytes,
        self.incarnation,
    )
    return '_DeviceAttributes(%s, %s, %d, %d)' % fields
635class BaseSession(SessionInterface):
636 """A class for interacting with a TensorFlow computation.
638 The BaseSession enables incremental graph building with inline
639 execution of Operations and evaluation of Tensors.
640 """
642 def __init__(self, target='', graph=None, config=None):
643 """Constructs a new TensorFlow session.
645 Args:
646 target: (Optional) The TensorFlow execution engine to connect to.
647 graph: (Optional) The graph to be used. If this argument is None, the
648 default graph will be used.
649 config: (Optional) ConfigProto proto used to configure the session. If no
650 config is specified, the global default will be used. The global default
651 can be configured via the tf.config APIs.
653 Raises:
654 tf.errors.OpError: Or one of its subclasses if an error occurs while
655 creating the TensorFlow session.
656 TypeError: If one of the arguments has the wrong type.
657 """
658 _python_session_create_counter.get_cell().increase_by(1)
659 if graph is None:
660 self._graph = ops.get_default_graph()
661 else:
662 if not isinstance(graph, ops.Graph):
663 raise TypeError('Argument `graph` must be a tf.Graph, but got '
664 f'"{type(graph).__name__}"')
665 self._graph = graph
667 self._closed = False
669 if target is not None:
670 try:
671 self._target = compat.as_bytes(target)
672 except TypeError:
673 if isinstance(target, config_pb2.ConfigProto):
674 raise TypeError('Argument `target` must be a string, but got '
675 f'"{type(target).__name__}". Did you do '
676 '"Session(config)" instead of '
677 '"Session(config=config)"?')
678 raise TypeError('Argument `target` must be a string, but got '
679 f'"{type(target).__name__}"')
680 else:
681 self._target = None
683 self._delete_lock = threading.Lock()
684 self._dead_handles = []
686 if config is None:
687 config = context.context().config
689 if not isinstance(config, config_pb2.ConfigProto):
690 raise TypeError('Argument `config` must be a tf.ConfigProto, but got '
691 f'"{type(config).__name__}"')
693 if (mixed_precision_global_state.is_mixed_precision_graph_rewrite_enabled()
694 and config.graph_options.rewrite_options.auto_mixed_precision !=
695 rewriter_config_pb2.RewriterConfig.OFF):
696 new_config = config_pb2.ConfigProto()
697 new_config.CopyFrom(config)
698 new_config.graph_options.rewrite_options.auto_mixed_precision = (
699 rewriter_config_pb2.RewriterConfig.ON)
700 config = new_config
701 elif (config.graph_options.rewrite_options.auto_mixed_precision !=
702 rewriter_config_pb2.RewriterConfig.ON):
703 mixed_precision_global_state.set_non_mixed_precision_session_created(True)
705 self._config = config
706 self._add_shapes = config.graph_options.infer_shapes
708 self._session = None
709 opts = tf_session.TF_NewSessionOptions(target=self._target, config=config)
710 try:
711 # pylint: disable=protected-access
712 with self._graph._c_graph.get() as c_graph:
713 self._session = tf_session.TF_NewSessionRef(c_graph, opts)
714 # pylint: enable=protected-access
715 finally:
716 tf_session.TF_DeleteSessionOptions(opts)
718 def list_devices(self):
719 """Lists available devices in this session.
721 ```python
722 devices = sess.list_devices()
723 for d in devices:
724 print(d.name)
725 ```
727 Where:
728 Each element in the list has the following properties
729 name: A string with the full name of the device. ex:
730 `/job:worker/replica:0/task:3/device:CPU:0`
731 device_type: The type of the device (e.g. `CPU`, `GPU`, `TPU`.)
732 memory_limit: The maximum amount of memory available on the device.
733 Note: depending on the device, it is possible the usable memory could
734 be substantially less.
736 Raises:
737 tf.errors.OpError: If it encounters an error (e.g. session is in an
738 invalid state, or network errors occur).
740 Returns:
741 A list of devices in the session.
742 """
743 raw_device_list = tf_session.TF_SessionListDevices(self._session)
744 device_list = []
745 size = tf_session.TF_DeviceListCount(raw_device_list)
746 for i in range(size):
747 name = tf_session.TF_DeviceListName(raw_device_list, i)
748 device_type = tf_session.TF_DeviceListType(raw_device_list, i)
749 memory = tf_session.TF_DeviceListMemoryBytes(raw_device_list, i)
750 incarnation = tf_session.TF_DeviceListIncarnation(raw_device_list, i)
751 device_list.append(
752 _DeviceAttributes(name, device_type, memory, incarnation))
753 tf_session.TF_DeleteDeviceList(raw_device_list)
754 return device_list
756 def close(self):
757 """Closes this session.
759 Calling this method frees all resources associated with the session.
761 Raises:
762 tf.errors.OpError: Or one of its subclasses if an error occurs while
763 closing the TensorFlow session.
764 """
765 if self._session and not self._closed:
766 self._closed = True
767 tf_session.TF_CloseSession(self._session)
769 def __del__(self):
770 # cleanly ignore all exceptions
771 try:
772 self.close()
773 except Exception: # pylint: disable=broad-except
774 pass
775 if self._session is not None:
776 try:
777 tf_session.TF_DeleteSession(self._session)
778 except (AttributeError, TypeError):
779 # At shutdown, `c_api_util`, `tf_session`, or
780 # `tf_session.TF_DeleteSession` may have been garbage collected, causing
781 # the above method calls to fail. In this case, silently leak since the
782 # program is about to terminate anyway.
783 pass
784 self._session = None
  @property
  def graph(self):
    """The graph that was launched in this session.

    This is the `graph` argument given to the constructor, or the default
    graph at construction time when that argument was None.
    """
    return self._graph
  @property
  def graph_def(self):
    """A serializable version of the underlying TensorFlow graph.

    Returns:
      A graph_pb2.GraphDef proto containing nodes for all of the Operations in
      the underlying TensorFlow graph. Shapes are added according to the
      `graph_options.infer_shapes` setting of the session's config.
    """
    return self._graph.as_graph_def(add_shapes=self._add_shapes)
  @property
  def sess_str(self):
    """The session target: the constructor's `target` as bytes, or None."""
    return self._target
  def as_default(self):
    """Returns a context manager that makes this object the default session.

    Use with the `with` keyword to specify that calls to
    `tf.Operation.run` or `tf.Tensor.eval` should be executed in
    this session.

    ```python
    c = tf.constant(..)
    sess = tf.compat.v1.Session()

    with sess.as_default():
      assert tf.compat.v1.get_default_session() is sess
      print(c.eval())
    ```

    To get the current default session, use `tf.compat.v1.get_default_session`.

    *N.B.* The `as_default` context manager *does not* close the
    session when you exit the context, and you must close the session
    explicitly.

    ```python
    c = tf.constant(...)
    sess = tf.compat.v1.Session()
    with sess.as_default():
      print(c.eval())
    # ...
    with sess.as_default():
      print(c.eval())

    sess.close()
    ```

    Alternatively, you can use `with tf.compat.v1.Session():` to create a
    session that is automatically closed on exiting the context,
    including when an uncaught exception is raised.

    *N.B.* The default session is a property of the current thread. If you
    create a new thread, and wish to use the default session in that
    thread, you must explicitly add a `with sess.as_default():` in that
    thread's function.

    *N.B.* Entering a `with sess.as_default():` block does not affect
    the current default graph. If you are using multiple graphs, and
    `sess.graph` is different from the value of
    `tf.compat.v1.get_default_graph`, you must explicitly enter a
    `with sess.graph.as_default():` block to make `sess.graph` the default
    graph.

    Returns:
      A context manager using this session as the default session.
    """
    # The context manager itself comes from `stack.default_session`.
    return stack.default_session(self)
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
  """Runs operations and evaluates tensors in `fetches`.

  This method runs one "step" of TensorFlow computation, by
  running the necessary graph fragment to execute every `Operation`
  and evaluate every `Tensor` in `fetches`, substituting the values in
  `feed_dict` for the corresponding input values.

  The `fetches` argument may be a single graph element, or an arbitrarily
  nested list, tuple, namedtuple, dict, or OrderedDict containing graph
  elements at its leaves. A graph element can be one of the following types:

  * A `tf.Operation`.
    The corresponding fetched value will be `None`.
  * A `tf.Tensor`.
    The corresponding fetched value will be a numpy ndarray containing the
    value of that tensor.
  * A `tf.sparse.SparseTensor`.
    The corresponding fetched value will be a
    `tf.compat.v1.SparseTensorValue`
    containing the value of that sparse tensor.
  * A `get_tensor_handle` op. The corresponding fetched value will be a
    numpy ndarray containing the handle of that tensor.
  * A `string` which is the name of a tensor or operation in the graph.

  The value returned by `run()` has the same shape as the `fetches` argument,
  where the leaves are replaced by the corresponding values returned by
  TensorFlow.

  Example:

  ```python
     a = tf.constant([10, 20])
     b = tf.constant([1.0, 2.0])
     # 'fetches' can be a singleton
     v = session.run(a)
     # v is the numpy array [10, 20]
     # 'fetches' can be a list.
     v = session.run([a, b])
     # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
     # 1-D array [1.0, 2.0]
     # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
     MyData = collections.namedtuple('MyData', ['a', 'b'])
     v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
     # v is a dict with
     # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
     # 'b' (the numpy array [1.0, 2.0])
     # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
     # [10, 20].
  ```

  The optional `feed_dict` argument allows the caller to override
  the value of tensors in the graph. Each key in `feed_dict` can be
  one of the following types:

  * If the key is a `tf.Tensor`, the
    value may be a Python scalar, string, list, or numpy ndarray
    that can be converted to the same `dtype` as that
    tensor. Additionally, if the key is a
    `tf.compat.v1.placeholder`, the shape of
    the value will be checked for compatibility with the placeholder.
  * If the key is a
    `tf.sparse.SparseTensor`,
    the value should be a
    `tf.compat.v1.SparseTensorValue`.
  * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value
    should be a nested tuple with the same structure that maps to their
    corresponding values as above.

  Each value in `feed_dict` must be convertible to a numpy array of the dtype
  of the corresponding key.

  The optional `options` argument expects a [`RunOptions`] proto. The options
  allow controlling the behavior of this particular step (e.g. turning tracing
  on).

  The optional `run_metadata` argument expects a [`RunMetadata`] proto. When
  appropriate, the non-Tensor output of this step will be collected there. For
  example, when users turn on tracing in `options`, the profiled info will be
  collected into this argument and passed back.

  Args:
    fetches: A single graph element, a list of graph elements, or a dictionary
      whose values are graph elements or lists of graph elements (described
      above).
    feed_dict: A dictionary that maps graph elements to values (described
      above).
    options: A [`RunOptions`] protocol buffer
    run_metadata: A [`RunMetadata`] protocol buffer

  Returns:
    Either a single value if `fetches` is a single graph element, or
    a list of values if `fetches` is a list, or a dictionary with the
    same keys as `fetches` if that is a dictionary (described above).
    Order in which `fetches` operations are evaluated inside the call
    is undefined.

  Raises:
    RuntimeError: If this `Session` is in an invalid state (e.g. has been
      closed).
    TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
    ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
      `Tensor` that doesn't exist.
  """
  # Hand the optional protos to the runtime as buffers: `options` is
  # serialized into a new buffer, `run_metadata` gets an empty buffer that is
  # read back after the step.
  if options:
    options_ptr = tf_session.TF_NewBufferFromString(
        compat.as_bytes(options.SerializeToString()))
  else:
    options_ptr = None

  if run_metadata:
    run_metadata_ptr = tf_session.TF_NewBuffer()
  else:
    run_metadata_ptr = None

  try:
    # A `None` handle selects a full (non-partial) run.
    result = self._run(None, fetches, feed_dict, options_ptr,
                       run_metadata_ptr)
    if run_metadata:
      # Copy what the runtime wrote into the caller's proto.
      serialized_metadata = tf_session.TF_GetBuffer(run_metadata_ptr)
      run_metadata.ParseFromString(compat.as_bytes(serialized_metadata))
  finally:
    # Always release the buffers, whether or not the step succeeded.
    if run_metadata_ptr:
      tf_session.TF_DeleteBuffer(run_metadata_ptr)
    if options:
      tf_session.TF_DeleteBuffer(options_ptr)
  return result
def partial_run(self, handle, fetches, feed_dict=None):
  """Continues the execution with more feeds and fetches.

  This is EXPERIMENTAL and subject to change.

  Partial execution is a two-phase protocol: the user first calls
  `partial_run_setup()`, which declares every feed and fetch that will be
  used, and then issues a sequence of `partial_run()` calls against the
  returned handle.

  The optional `feed_dict` argument allows the caller to override
  the value of tensors in the graph. See run() for more information.

  Below is a simple example:

  ```python
  a = array_ops.placeholder(dtypes.float32, shape=[])
  b = array_ops.placeholder(dtypes.float32, shape=[])
  c = array_ops.placeholder(dtypes.float32, shape=[])
  r1 = math_ops.add(a, b)
  r2 = math_ops.multiply(r1, c)

  h = sess.partial_run_setup([r1, r2], [a, b, c])
  res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2})
  res = sess.partial_run(h, r2, feed_dict={c: res})
  ```

  Args:
    handle: A handle for a sequence of partial runs.
    fetches: A single graph element, a list of graph elements, or a dictionary
      whose values are graph elements or lists of graph elements (see
      documentation for `run`).
    feed_dict: A dictionary that maps graph elements to values (see
      documentation for `run`).

  Returns:
    Either a single value if `fetches` is a single graph element, or
    a list of values if `fetches` is a list, or a dictionary with the
    same keys as `fetches` if that is a dictionary
    (see documentation for `run`).

  Raises:
    tf.errors.OpError: Or one of its subclasses on error.
  """
  # TODO(touts): Support feeding and fetching the same tensor.
  # Partial runs pass neither run options nor run metadata.
  results = self._run(handle, fetches, feed_dict, None, None)
  return results
def partial_run_setup(self, fetches, feeds=None):
  """Sets up a graph with feeds and fetches for partial run.

  This is EXPERIMENTAL and subject to change.

  Note that contrary to `run`, `feeds` only specifies the graph elements.
  The tensors will be supplied by the subsequent `partial_run` calls.

  Args:
    fetches: A single graph element, or a list of graph elements.
    feeds: A single graph element, or a list of graph elements.

  Returns:
    A handle for partial run.

  Raises:
    RuntimeError: If this `Session` is in an invalid state (e.g. has been
      closed) or its graph is empty.
    TypeError: If `fetches` or `feeds` are of an inappropriate type.
    tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens.
  """

  def _feed_fn(feed):
    # Expand a feed into its component tensors using the expansion registered
    # for the first matching type.
    for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS:
      if isinstance(feed, tensor_type):
        return feed_fn(feed)
    raise TypeError(f'Feed argument {feed} has invalid type '
                    f'"{type(feed).__name__}"')

  # Check session.
  if self._closed:
    raise RuntimeError('Attempted to use a closed Session.')
  if self.graph.version == 0:
    raise RuntimeError('The Session graph is empty. Add operations to the '
                       'graph before calling run().')

  if feeds is None:
    feeds = []
  # Create request.
  feed_list = []

  # Validate and process feed_list.
  is_list_feed = isinstance(feeds, (list, tuple))
  if not is_list_feed:
    feeds = [feeds]
  for feed in feeds:
    for subfeed in _feed_fn(feed):
      try:
        subfeed_t = self.graph.as_graph_element(
            subfeed, allow_tensor=True, allow_operation=False)
        # pylint: disable=protected-access
        feed_list.append(subfeed_t._as_tf_output())
        # pylint: enable=protected-access
      except Exception as e:
        # Built-in exceptions have no `message` attribute in Python 3, so
        # reading `e.message` unconditionally would raise AttributeError here
        # and hide the real failure. Prefer `message` when the exception
        # provides one, otherwise fall back to its first argument.
        detail = getattr(e, 'message', None)
        if detail is None:
          detail = e.args[0] if e.args else str(e)
        e.message = ('Cannot interpret argument `feed` key as Tensor: '
                     f'{detail}')
        e.args = (e.message,)
        # Re-raise the same exception object so its type is preserved.
        raise e

  # Validate and process fetches.
  # TODO(touts): Support feeding and fetching the same tensor.
  fetch_handler = _FetchHandler(self._graph, fetches, {})

  # Set up a graph with feeds and fetches for partial run.
  def _setup_fn(session, feed_list, fetch_list, target_list):
    # Push any pending graph changes to the runtime before the setup call.
    self._extend_graph()
    return tf_session.TF_SessionPRunSetup_wrapper(session, feed_list,
                                                  fetch_list, target_list)

  # pylint: disable=protected-access
  final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()]
  final_targets = [op._c_op for op in fetch_handler.targets()]
  # pylint: enable=protected-access

  return self._do_call(_setup_fn, self._session, feed_list, final_fetches,
                       final_targets)
def _run(self, handle, fetches, feed_dict, options, run_metadata):
  """Performs either run or partial_run, depending on the presence of `handle`.

  Validates the session state and every `feed_dict` entry, converts feed
  values to numpy arrays, builds a fetch handler for `fetches`, and then
  delegates the actual execution to `_do_run`.

  Args:
    handle: A partial-run handle, or `None` for a regular `run()` call.
    fetches: The (possibly nested) structure of graph elements to fetch.
    feed_dict: A dictionary mapping graph elements to feed values, or a falsy
      value when nothing is fed.
    options: Passed through unchanged to `_do_run`.
    run_metadata: Passed through unchanged to `_do_run`.

  Returns:
    The fetched values, arranged by the fetch handler to mirror `fetches`.

  Raises:
    RuntimeError: If the session is closed or its graph is empty.
    TypeError: If a feed key has an unsupported type or cannot be interpreted
      as a tensor, if a feed value is a `tf.Tensor`, or if an integer feed
      value does not survive conversion to the tensor's dtype.
    ValueError: If a feed value's shape is incompatible with its tensor, or
      the tensor may not be fed.
  """

  def _feed_fn(feed, feed_val):
    # Expand a (key, value) pair into (subfeed, subfeed value) pairs using the
    # expansion registered for the first matching key type.
    for tensor_type, _, feed_fn, _ in _REGISTERED_EXPANSIONS:
      if isinstance(feed, tensor_type):
        return feed_fn(feed, feed_val)
    raise TypeError(f'{feed} in argument `feed_dict` has invalid type '
                    f'"{type(feed).__name__}"')

  # Check session.
  if self._closed:
    raise RuntimeError('Attempted to use a closed Session.')
  if self.graph.version == 0:
    raise RuntimeError('The Session graph is empty. Add operations to the '
                       'graph before calling run().')

  # Create request.
  # `feed_dict_tensor` maps tensor refs to numpy values; `feed_map` maps
  # tensor names (as bytes) to (tensor, original feed value) pairs.
  feed_dict_tensor = {}
  feed_map = {}

  # Validate and process feed_dict.
  # `feed_handles` records the feeds whose value is a TensorHandle.
  feed_handles = {}
  if feed_dict:
    feed_dict = nest.flatten_dict_items(feed_dict)
    for feed, feed_val in feed_dict.items():
      for subfeed, subfeed_val in _feed_fn(feed, feed_val):
        try:
          subfeed_t = self.graph.as_graph_element(
              subfeed, allow_tensor=True, allow_operation=False)
        except Exception as e:
          raise TypeError(
              f'Cannot interpret feed_dict key as Tensor: {e.args[0]}')

        if isinstance(subfeed_val, ops.Tensor):
          raise TypeError(
              'The value of a feed cannot be a tf.Tensor object. Acceptable '
              'feed values include Python scalars, strings, lists, numpy '
              'ndarrays, or TensorHandles. For reference, the tensor object '
              f'was {str(feed_val)} which was passed to the argument '
              f'`feed_dict` with key {str(feed)}.')

        subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
        # Reject Python ints that change value when converted to the target
        # dtype.
        if isinstance(subfeed_val, int) and _convert_to_numpy_obj(
            subfeed_dtype, subfeed_val) != subfeed_val:
          raise TypeError(
              f'Type of feed value {str(subfeed_val)} with type ' +
              f'{str(type(subfeed_val))} is not compatible with Tensor type '
              f'{str(subfeed_dtype)}. Try explicitly setting the type of the '
              'feed tensor to a larger type (e.g. int64).')

        is_tensor_handle_feed = isinstance(subfeed_val,
                                           session_ops.TensorHandle)
        if is_tensor_handle_feed:
          np_val = subfeed_val.to_numpy_array()
          feed_handles[subfeed_t.ref()] = subfeed_val
        else:
          np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)

        # The shape check is skipped for TensorHandle feeds.
        if (not is_tensor_handle_feed and
            not subfeed_t.get_shape().is_compatible_with(np_val.shape)):
          raise ValueError(
              f'Cannot feed value of shape {str(np_val.shape)} for Tensor '
              f'{subfeed_t.name}, which has shape '
              f'{str(subfeed_t.get_shape())}')
        if not self.graph.is_feedable(subfeed_t):
          raise ValueError(f'Tensor {subfeed_t.name} may not be fed.')

        feed_dict_tensor[subfeed_t.ref()] = np_val
        feed_map[compat.as_bytes(subfeed_t.name)] = (subfeed_t, subfeed_val)

  # Create a fetch handler to take care of the structure of fetches.
  fetch_handler = _FetchHandler(
      self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)

  # Run request and get response.
  # We need to keep the returned movers alive for the following _do_run().
  # These movers are no longer needed when _do_run() completes, and
  # are deleted when `movers` goes out of scope when this _run() ends.
  # TODO(yuanbyu, keveman): Revisit whether we should just treat feeding
  # of a handle from a different device as an error.
  _ = self._update_with_movers(feed_dict_tensor, feed_map)
  final_fetches = fetch_handler.fetches()
  final_targets = fetch_handler.targets()
  # We only want to really perform the run if fetches or targets are provided,
  # or if the call is a partial run that specifies feeds.
  if final_fetches or final_targets or (handle and feed_dict_tensor):
    results = self._do_run(handle, final_targets, final_fetches,
                           feed_dict_tensor, options, run_metadata)
  else:
    results = []
  return fetch_handler.build_results(self, results)
def make_callable(self, fetches, feed_list=None, accept_options=False):
  """Returns a Python callable that runs a particular step.

  The returned callable will take `len(feed_list)` arguments whose types
  must be compatible feed values for the respective elements of `feed_list`.
  For example, if element `i` of `feed_list` is a `tf.Tensor`, the `i`th
  argument to the returned callable must be a numpy ndarray (or something
  convertible to an ndarray) with matching element type and shape. See
  `tf.Session.run` for details of the allowable feed key and value types.

  The returned callable will have the same return type as
  `tf.Session.run(fetches, ...)`. For example, if `fetches` is a `tf.Tensor`,
  the callable will return a numpy ndarray; if `fetches` is a `tf.Operation`,
  it will return `None`.

  Args:
    fetches: A value or list of values to fetch. See `tf.Session.run` for
      details of the allowable fetch types.
    feed_list: (Optional.) A list of `feed_dict` keys. See `tf.Session.run`
      for details of the allowable feed key types.
    accept_options: (Optional.) If `True`, the returned `Callable` will be
      able to accept `tf.compat.v1.RunOptions` and `tf.compat.v1.RunMetadata`
      as optional keyword arguments `options` and `run_metadata`,
      respectively, with the same syntax and semantics as `tf.Session.run`,
      which is useful for certain use cases (profiling and debugging) but will
      result in measurable slowdown of the `Callable`'s
      performance. Default: `False`.

  Returns:
    A function that when called will execute the step defined by
    `feed_list` and `fetches` in this session.

  Raises:
    TypeError: If `fetches` or `feed_list` cannot be interpreted
      as arguments to `tf.Session.run`.
  """
  if feed_list is not None:
    if not isinstance(feed_list, (list, tuple)):
      raise TypeError('Argument `feed_list` must be a list or tuple. '
                      f'Received: feed_list={feed_list}')

    # Whenever a feed list is supplied, fall back to the existing `run()`
    # logic.
    # TODO(mrry): Refactor the feed handling logic from
    # `Session._run()` so that we can convert the feeds to a list of
    # strings here.
    def _generic_run(*feed_args, **kwargs):
      # Pair each feed key with the positional value passed for it.
      feed_dict = dict(zip(feed_list, feed_args))
      return self.run(fetches, feed_dict=feed_dict, **kwargs)

    return _generic_run

  # Ensure any changes to the graph are reflected in the runtime.
  # Note that we don't need to do this on subsequent calls to the
  # returned object, because the arguments to `fetches` must already be
  # in the graph.
  self._extend_graph()

  # The handler flattens `fetches` now and rebuilds its structure later.
  handler = _FetchHandler(self._graph, fetches, {})
  # pylint: disable=protected-access
  fetch_outputs = [tensor._as_tf_output() for tensor in handler.fetches()]
  target_ops = [operation._c_op for operation in handler.targets()]

  # pylint: enable=protected-access

  def _callable_template_with_options_and_metadata(fetch_list,
                                                   target_list,
                                                   fetch_handler,
                                                   options=None,
                                                   run_metadata=None):
    """Template callable that accepts RunOptions and RunMetadata."""
    if options:
      options_ptr = tf_session.TF_NewBufferFromString(
          compat.as_bytes(options.SerializeToString()))
    else:
      options_ptr = None
    if run_metadata:
      run_metadata_ptr = tf_session.TF_NewBuffer()
    else:
      run_metadata_ptr = None
    try:
      results = self._call_tf_sessionrun(options_ptr, {}, fetch_list,
                                         target_list, run_metadata_ptr)
      if fetch_handler:
        results = fetch_handler.build_results(self, results)
      elif results:
        results = results[0]
      else:
        results = None
      if run_metadata:
        serialized_metadata = tf_session.TF_GetBuffer(run_metadata_ptr)
        run_metadata.ParseFromString(compat.as_bytes(serialized_metadata))
    finally:
      # Release the buffers on both the success and the failure path.
      if run_metadata_ptr:
        tf_session.TF_DeleteBuffer(run_metadata_ptr)
      if options:
        tf_session.TF_DeleteBuffer(options_ptr)
    return results

  if accept_options:
    return functools.partial(_callable_template_with_options_and_metadata,
                             fetch_outputs, target_ops, handler)

  if isinstance(fetches, ops.Operation):
    # Special case for fetching a single operation, because the
    # function will have no return value.
    assert not fetch_outputs
    assert len(target_ops) == 1

    def _single_operation_run():
      self._call_tf_sessionrun(None, {}, [], target_ops, None)

    return _single_operation_run

  if isinstance(fetches, ops.Tensor):
    # Special case for fetching a single tensor, because the
    # function can return the result of `TF_Run()` directly.
    assert len(fetch_outputs) == 1
    assert not target_ops

    def _single_tensor_run():
      results = self._call_tf_sessionrun(None, {}, fetch_outputs, [], None)
      return results[0]

    return _single_tensor_run

  # In all other cases, we must use the fetch handler to build the
  # results for us.
  def _fetch_handler_run():
    results = self._call_tf_sessionrun(None, {}, fetch_outputs, target_ops,
                                       None)
    return handler.build_results(self, results)

  return _fetch_handler_run
# Captures the name of a node in an error status; the node name is regex
# group 3. The regex below matches both the old and the new formats:
# Old format: [[Node: <node_name> = ...]]
# New format: [[{{node <node_name>}} = ...]]
_NODEDEF_NAME_RE = re.compile(
    r'\[\[(Node: )?(\{\{node )?([^\} ]*)(\}\})?\s*=*')
1331 def _do_run(self, handle, target_list, fetch_list, feed_dict, options,
1332 run_metadata):
1333 """Runs a step based on the given fetches and feeds.
1335 Args:
1336 handle: a handle for partial_run. None if this is just a call to run().
1337 target_list: A list of operations to be run, but not fetched.
1338 fetch_list: A list of tensors to be fetched.
1339 feed_dict: A dictionary that maps tensors to numpy ndarrays.
1340 options: A (pointer to a) [`RunOptions`] protocol buffer, or None
1341 run_metadata: A (pointer to a) [`RunMetadata`] protocol buffer, or None
1343 Returns:
1344 A list of numpy ndarrays, corresponding to the elements of
1345 `fetch_list`. If the ith element of `fetch_list` contains the
1346 name of an operation, the first Tensor output of that operation
1347 will be returned for that element.
1349 Raises:
1350 tf.errors.OpError: Or one of its subclasses on error.
1351 """
1352 # pylint: disable=protected-access
1353 feeds = dict((t.deref()._as_tf_output(), v) for t, v in feed_dict.items())
1354 fetches = [t._as_tf_output() for t in fetch_list]
1355 targets = [op._c_op for op in target_list]
1357 # pylint: enable=protected-access
1359 def _run_fn(feed_dict, fetch_list, target_list, options, run_metadata):
1360 # Ensure any changes to the graph are reflected in the runtime.
1361 self._extend_graph()
1362 return self._call_tf_sessionrun(options, feed_dict, fetch_list,
1363 target_list, run_metadata)
1365 def _prun_fn(handle, feed_dict, fetch_list):
1366 if target_list:
1367 raise RuntimeError('partial_run() requires empty `target_list`. '
1368 f'Received: target_list={target_list} (non-empty)')
1369 return self._call_tf_sessionprun(handle, feed_dict, fetch_list)
1371 if handle is None:
1372 return self._do_call(_run_fn, feeds, fetches, targets, options,
1373 run_metadata)
1374 else:
1375 return self._do_call(_prun_fn, handle, feeds, fetches)
1377 def _do_call(self, fn, *args):
1378 try:
1379 return fn(*args)
1380 except errors.OpError as e:
1381 message = compat.as_text(e.message)
1382 m = BaseSession._NODEDEF_NAME_RE.search(message)
1383 node_def = None
1384 op = None
1385 if m is not None:
1386 node_name = m.group(3)
1387 try:
1388 op = self._graph.get_operation_by_name(node_name)
1389 node_def = op.node_def
1390 except KeyError:
1391 pass
1392 message = error_interpolation.interpolate(message, self._graph)
1393 if 'only supports NHWC tensor format' in message:
1394 message += ('\nA possible workaround: Try disabling Grappler optimizer'
1395 '\nby modifying the config for creating the session eg.'
1396 '\nsession_config.graph_options.rewrite_options.'
1397 'disable_meta_optimizer = True')
1398 raise type(e)(node_def, op, message) # pylint: disable=no-value-for-parameter
def _extend_graph(self):
  """Extends the underlying session while holding the graph's run lock."""
  with self._graph._session_run_lock():  # pylint: disable=protected-access
    tf_session.ExtendSession(self._session)
# The threshold to run garbage collection to delete dead tensors:
# `_register_dead_handle` deletes the accumulated handles once this many have
# been registered.
_DEAD_HANDLES_THRESHOLD = 10
def _register_dead_handle(self, handle):
  """Registers a dead tensor handle, deleting the batch once it is full.

  Handles accumulate in `self._dead_handles`. When their number reaches
  `BaseSession._DEAD_HANDLES_THRESHOLD`, the accumulated batch is deleted by
  running one deleter op per handle.

  Args:
    handle: The dead tensor handle to register.
  """
  with self._delete_lock:
    self._dead_handles.append(handle)
    if len(self._dead_handles) != BaseSession._DEAD_HANDLES_THRESHOLD:
      # Not enough dead tensors yet; nothing to delete.
      return
    # Take the full batch and start a fresh list while the lock is held.
    batch = self._dead_handles
    self._dead_handles = []

  # Delete the dead tensors (outside the lock).
  deleter_feeds = {}
  deleter_fetches = []
  for deleter_key, tensor_handle in enumerate(batch):
    holder, deleter = session_ops._get_handle_deleter(
        self.graph, deleter_key, tensor_handle)
    deleter_feeds[holder] = tensor_handle
    deleter_fetches.append(deleter)
  self.run(deleter_fetches, feed_dict=deleter_feeds)
1427 def _update_with_movers(self, feed_dict, feed_map):
1428 # If a tensor handle that is fed to a device incompatible placeholder,
1429 # we move the tensor to the right device, generate a new tensor handle,
1430 # and update `feed_dict` to use the new handle.
1431 handle_movers = []
1432 for feed_name, val in feed_map.items():
1433 mover = session_ops._get_handle_mover(self.graph, *val)
1434 if mover:
1435 handle_movers.append((feed_name, val[1], mover))
1436 # Transfer a tensor to the right device if needed.
1437 if not handle_movers:
1438 return []
1439 else:
1440 feeds = {}
1441 fetches = []
1442 for _, handle, mover in handle_movers:
1443 feeds[mover[0]] = handle
1444 fetches.append(mover[1])
1445 handles = self.run(fetches, feed_dict=feeds)
1446 for handle_mover, handle in zip(handle_movers, handles):
1447 np_val = np.array(handle.handle, dtype=np.object_)
1448 feed_name = handle_mover[0]
1449 feed_tensor = feed_map[feed_name][0]
1450 feed_dict[feed_tensor.ref()] = np_val
1451 return handles
def _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list,
                        run_metadata):
  """Forwards a run request for this session to `TF_SessionRun_wrapper`."""
  return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
                                          fetch_list, target_list,
                                          run_metadata)
def _call_tf_sessionprun(self, handle, feed_dict, fetch_list):
  """Forwards a partial-run request to `TF_SessionPRun_wrapper`."""
  return tf_session.TF_SessionPRun_wrapper(self._session, handle, feed_dict,
                                           fetch_list)
# pylint: disable=protected-access
class _Callable(object):
  """Experimental wrapper for the C++ `Session::MakeCallable()` API."""

  def __init__(self, session, callable_options):
    """Creates a callable handle in `session` from `callable_options`.

    Args:
      session: The session object that owns the callable.
      callable_options: A message with a `SerializeToString()` method
        describing the computation.
    """
    self._session = session
    # Set before the C call so `__del__` is safe even if creation fails.
    self._handle = None
    options_ptr = tf_session.TF_NewBufferFromString(
        compat.as_bytes(callable_options.SerializeToString()))
    try:
      self._handle = tf_session.TF_SessionMakeCallable(
          session._session, options_ptr)
    finally:
      tf_session.TF_DeleteBuffer(options_ptr)

  def __call__(self, *args, **kwargs):
    """Runs the callable with `args` as its inputs.

    Args:
      *args: Values passed to `TF_SessionRunCallable`.
      **kwargs: Only `run_metadata` is read; when truthy, it is filled in
        from the buffer written during the run.

    Returns:
      The result of `TF_SessionRunCallable`.
    """
    run_metadata = kwargs.get('run_metadata', None)
    # Allocate outside the `try` so that a failed allocation propagates
    # directly instead of hitting an unbound name in the `finally` clause.
    run_metadata_ptr = tf_session.TF_NewBuffer() if run_metadata else None
    try:
      ret = tf_session.TF_SessionRunCallable(self._session._session,
                                             self._handle, args,
                                             run_metadata_ptr)
      if run_metadata:
        proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
        run_metadata.ParseFromString(compat.as_bytes(proto_data))
    finally:
      if run_metadata_ptr:
        tf_session.TF_DeleteBuffer(run_metadata_ptr)
    return ret

  def __del__(self):
    # NOTE(mrry): It is possible that `self._session.__del__()` could be
    # called before this destructor, in which case `self._session._session`
    # will be `None`.
    if (self._handle is not None and self._session._session is not None and
        not self._session._closed):
      tf_session.TF_SessionReleaseCallable(self._session._session,
                                           self._handle)

# pylint: enable=protected-access
def _make_callable_from_options(self, callable_options):
  """Returns a handle to a "callable" with the given options.

  Args:
    callable_options: A `CallableOptions` protocol buffer message describing
      the computation that will be performed by the callable.

  Returns:
    A handle to the new callable.
  """
  # Extend the session with the current graph before creating the callable.
  self._extend_graph()
  return BaseSession._Callable(self, callable_options)
@tf_export(v1=['Session'])
class Session(BaseSession):
  """A class for running TensorFlow operations.

  A `Session` object encapsulates the environment in which `Operation`
  objects are executed, and `Tensor` objects are evaluated. For
  example:

  ```python
  tf.compat.v1.disable_eager_execution() # need to disable eager in TF2.x
  # Build a graph.
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b

  # Launch the graph in a session.
  sess = tf.compat.v1.Session()

  # Evaluate the tensor `c`.
  print(sess.run(c)) # prints 30.0
  ```

  A session may own resources, such as
  `tf.Variable`, `tf.queue.QueueBase`,
  and `tf.compat.v1.ReaderBase`. It is important to release
  these resources when they are no longer required. To do this, either
  invoke the `tf.Session.close` method on the session, or use
  the session as a context manager. The following two examples are
  equivalent:

  ```python
  # Using the `close()` method.
  sess = tf.compat.v1.Session()
  sess.run(...)
  sess.close()

  # Using the context manager.
  with tf.compat.v1.Session() as sess:
    sess.run(...)
  ```

  The
  [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
  protocol buffer exposes various configuration options for a
  session. For example, to create a session that uses soft constraints
  for device placement, and log the resulting placement decisions,
  create a session as follows:

  ```python
  # Launch the graph in a session that allows soft device placement and
  # logs the placement decisions.
  sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
      allow_soft_placement=True,
      log_device_placement=True))
  ```

  @compatibility(TF2)
  `Session` does not work with either eager execution or `tf.function`, and you
  should not invoke it directly. To migrate code that uses sessions to TF2,
  rewrite the code without it. See the
  [migration
  guide](https://www.tensorflow.org/guide/migrate#1_replace_v1sessionrun_calls)
  on replacing `Session.run` calls.
  @end_compatibility
  """

  def __init__(self, target='', graph=None, config=None):
    """Creates a new TensorFlow session.

    If no `graph` argument is specified when constructing the session,
    the default graph will be launched in the session. If you are
    using more than one graph (created with `tf.Graph()`) in the same
    process, you will have to use different sessions for each graph,
    but each graph can be used in multiple sessions. In this case, it
    is often clearer to pass the graph to be launched explicitly to
    the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to. Defaults to using
        an in-process engine. See
        [Distributed TensorFlow](https://tensorflow.org/deploy/distributed) for
        more examples.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional.) A
        [`ConfigProto`](https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto)
        protocol buffer with configuration options for the session.
    """
    super(Session, self).__init__(target, graph, config=config)
    # NOTE(mrry): Create these on first `__enter__` to avoid a reference cycle.
    self._default_graph_context_manager = None
    self._default_session_context_manager = None

  def __enter__(self):
    """Enters the session's graph and the session itself as defaults.

    Returns:
      The value returned by entering the `as_default()` context manager.

    Raises:
      RuntimeError: If this session's context manager is already entered.
    """
    if self._default_graph_context_manager is None:
      self._default_graph_context_manager = self.graph.as_default()
    else:
      raise RuntimeError('Session context managers are not re-entrant. '
                         'Use `Session.as_default()` if you want to enter '
                         'a session multiple times.')
    if self._default_session_context_manager is None:
      self._default_session_context_manager = self.as_default()
    self._default_graph_context_manager.__enter__()
    return self._default_session_context_manager.__enter__()

  def __exit__(self, exec_type, exec_value, exec_tb):
    """Exits the default-session and default-graph contexts, then closes.

    Args:
      exec_type: The exception class raised in the `with` body, or `None`.
      exec_value: The exception instance, or `None`.
      exec_tb: The traceback, or `None`.
    """
    # Errors raised by ops are instances of `OpError` subclasses, so an
    # identity test against `OpError` itself would never match them.
    if exec_type is not None and issubclass(exec_type, errors.OpError):
      logging.error('Session closing due to OpError: %s', exec_value)
    try:
      self._default_session_context_manager.__exit__(exec_type, exec_value,
                                                     exec_tb)
    except RuntimeError as error:
      if error == exec_value:
        # NOTE(skyewm): for some reason, in Python3,
        # _default_session_context_manager.__exit__ will re-raise the "not
        # re-entrant" exception raised in __enter__ above (note that if we're
        # here, we're in the outer session context manager, since __exit__ is
        # not called when __enter__ raises an exception). We still want to
        # continue cleaning up this context manager before the exception is
        # further propagated, so we ignore it here (note that it'll continue
        # being propagated after this method completes).
        pass
      else:
        raise
    self._default_graph_context_manager.__exit__(exec_type, exec_value, exec_tb)

    self._default_session_context_manager = None
    self._default_graph_context_manager = None

    # If we are closing due to an exception, set a time limit on our Close() to
    # avoid blocking forever.
    # TODO(b/120204635) remove this when deadlock is fixed.
    if exec_type:
      close_thread = threading.Thread(
          name='SessionCloseThread', target=self.close)
      close_thread.daemon = True
      close_thread.start()
      close_thread.join(30.0)
      if close_thread.is_alive():
        logging.error(
            'Session failed to close after 30 seconds. Continuing after this '
            'point may leave your program in an undefined state.')
    else:
      self.close()

  @staticmethod
  def reset(target, containers=None, config=None):
    """Resets resource containers on `target`, and close all connected sessions.

    A resource container is distributed across all workers in the
    same cluster as `target`. When a resource container on `target`
    is reset, resources associated with that container will be cleared.
    In particular, all Variables in the container will become undefined:
    they lose their values and shapes.

    NOTE:
    (i) reset() is currently only implemented for distributed sessions.
    (ii) Any sessions on the master named by `target` will be closed.

    If no resource containers are provided, all containers are reset.

    Args:
      target: The execution engine to connect to.
      containers: A list of resource container name strings, or `None` if all
        the containers are to be reset.
      config: (Optional.) Protocol buffer with configuration options.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        resetting containers.
    """
    if target is not None:
      target = compat.as_bytes(target)
    if containers is not None:
      containers = [compat.as_bytes(c) for c in containers]
    else:
      # An empty list is passed when no containers are named.
      containers = []
    tf_session.TF_Reset(target, containers, config)
@tf_export(v1=['InteractiveSession'])
class InteractiveSession(BaseSession):
  """A TensorFlow `Session` for use in interactive contexts, such as a shell.

  An `InteractiveSession` differs from a regular `Session` in one way only:
  it installs itself as the default session as soon as it is constructed.
  The methods `tf.Tensor.eval` and `tf.Operation.run` will then use that
  session to run ops.

  This is convenient in interactive shells and [IPython
  notebooks](http://ipython.org), because no explicit `Session` object has
  to be passed around to run ops.

  For example:

  ```python
  sess = tf.compat.v1.InteractiveSession()
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b
  # We can just use 'c.eval()' without passing 'sess'
  print(c.eval())
  sess.close()
  ```

  Note that a regular session installs itself as the default session when
  it is created in a `with` statement. Non-interactive programs commonly
  follow that pattern instead:

  ```python
  a = tf.constant(5.0)
  b = tf.constant(6.0)
  c = a * b
  with tf.compat.v1.Session():
    # We can also use 'c.eval()' here.
    print(c.eval())
  ```
  """

  # Class-wide bookkeeping of how many interactive sessions have been
  # created and not yet explicitly closed.
  _count_lock = threading.Lock()
  _active_session_count = 0  # GUARDED_BY(_count_lock)

  def __init__(self, target='', graph=None, config=None):
    """Creates a new interactive TensorFlow session.

    When no `graph` argument is given, the default graph is launched in the
    session. If you are using more than one graph (created with
    `tf.Graph()`) in the same process, you will have to use different
    sessions for each graph, but each graph can be used in multiple
    sessions. In that case it is often clearer to pass the graph to be
    launched explicitly to the session constructor.

    Args:
      target: (Optional.) The execution engine to connect to. Defaults to
        using an in-process engine.
      graph: (Optional.) The `Graph` to be launched (described above).
      config: (Optional) `ConfigProto` proto used to configure the session.
        Note that `place_pruned_graph` is switched on in this object.
    """
    if not config:
      # No config supplied, so pick reasonable defaults for interactive use:
      #
      #   - Grow GPU memory as needed at the cost of fragmentation.
      config = config_pb2.ConfigProto(
          gpu_options=config_pb2.GPUOptions(allow_growth=True))
    # Interactive sessions always place pruned graphs.
    config.graph_options.place_pruned_graph = True

    super().__init__(target, graph, config)

    with InteractiveSession._count_lock:
      already_active = InteractiveSession._active_session_count > 0
      if already_active:
        warnings.warn('An interactive session is already active. This can '
                      'cause out-of-memory errors in some cases. You must '
                      'explicitly call `InteractiveSession.close()` to release '
                      'resources held by the other session(s).')
      InteractiveSession._active_session_count += 1

    # NOTE(mrry): `Session._closed` is not used here because its semantics are
    # unhelpful (in particular, it is not set to true if `Session.close()` is
    # called on a session that has not been "opened" by running a step) and
    # those semantics cannot be changed without breaking existing code.
    self._explicitly_closed = False

    # Install this session as the default; the context manager is kept so
    # that `close()` can exit it later.
    session_scope = self.as_default()
    self._default_session = session_scope
    session_scope.enforce_nesting = False
    session_scope.__enter__()

    # An explicitly supplied graph is installed as the default graph too.
    self._explicit_graph = graph
    if graph is not None:
      graph_scope = graph.as_default()
      self._default_graph = graph_scope
      graph_scope.enforce_nesting = False
      graph_scope.__enter__()

  def close(self):
    """Closes an `InteractiveSession`."""
    super().close()

    with InteractiveSession._count_lock:
      if self._explicitly_closed:
        # A previous call already did the bookkeeping and teardown below.
        return
      InteractiveSession._active_session_count -= 1
      self._explicitly_closed = True

    # Leave the default-graph context (if one was entered) and then the
    # default-session context that `__init__` entered.
    if self._explicit_graph is not None:
      self._default_graph.__exit__(None, None, None)
      self._default_graph = None
    self._default_session.__exit__(None, None, None)
    self._default_session = None