Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/stack.py: 58%
38 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes used to handle thread-local stacks."""
17import threading
19from tensorflow.python.util import tf_contextlib
20from tensorflow.python.util.tf_export import tf_export
class DefaultStack(threading.local):
  """A thread-local stack of objects for providing implicit defaults.

  Because this class derives from `threading.local`, the `stack` list and the
  nesting flag are per-thread: entries pushed in one thread are not visible
  from another thread.
  """

  def __init__(self):
    super().__init__()
    # When True, controllers must be exited in strict LIFO order.
    self._enforce_nesting = True
    # The innermost (current) default is the last element.
    self.stack = []

  def get_default(self):
    """Returns the innermost default, or `None` if the stack is empty."""
    return self.stack[-1] if self.stack else None

  def reset(self):
    """Drops every entry by rebinding `stack` to a fresh empty list."""
    self.stack = []

  def is_cleared(self):
    """Returns `True` if the stack currently holds no entries."""
    return not self.stack

  @property
  def enforce_nesting(self):
    """Whether controllers must be exited in strict LIFO order."""
    return self._enforce_nesting

  @enforce_nesting.setter
  def enforce_nesting(self, value):
    self._enforce_nesting = value

  @tf_contextlib.contextmanager
  def get_controller(self, default):
    """A context manager for manipulating a default stack.

    Pushes `default` on entry and takes it off again on exit.

    Args:
      default: The object to install as the innermost default.

    Yields:
      The `default` object that was passed in.

    Raises:
      AssertionError: On exit, if nesting is enforced and `default` is not the
        top of a non-empty stack.
    """
    self.stack.append(default)
    try:
      yield default
    finally:
      # stack may be empty if reset() was called
      if self.stack:
        if self._enforce_nesting:
          if self.stack[-1] is not default:
            raise AssertionError(
                "Nesting violated for default stack of %s objects" %
                type(default))
          self.stack.pop()
        else:
          # Take off the most recently pushed entry that *is* `default`.
          # `list.remove` would drop the first entry that merely compares
          # equal, which reorders the stack when the same (or an equal)
          # object was pushed more than once, e.g. [A, B, A] -> [B, A].
          # If the entry is already gone (reset() followed by other pushes)
          # there is nothing left to undo, so exit quietly instead of raising
          # ValueError from this `finally` block.
          for index in range(len(self.stack) - 1, -1, -1):
            if self.stack[index] is default:
              del self.stack[index]
              break
# Module-wide stack backing default_session() and get_default_session().
# DefaultStack derives from threading.local, so its contents are per-thread.
_default_session_stack = DefaultStack()
def default_session(session):
  """Builds a `with` handler that installs `session` as the default session.

  While the returned context manager is active, `session` is the object that
  handles Tensor.eval() and Operation.run() calls. The primary user is
  session.Session, but any object implementing the Session.run() interface
  can be registered this way.

  The default session is scoped to the current thread, so the extent of a
  default session can always be determined by inspecting the call stack. A
  newly created thread does not inherit it: to use the default session there,
  explicitly add a "with ops.default_session(sess):" block in that thread's
  function.

  Example:
    The following code examples are equivalent:

    # 1. Using the Session object directly:
    sess = ...
    c = tf.constant(5.0)
    sess.run(c)

    # 2. Using default_session():
    sess = ...
    with ops.default_session(sess):
      c = tf.constant(5.0)
      result = c.eval()

    # 3. Overriding default_session():
    sess = ...
    with ops.default_session(sess):
      c = tf.constant(5.0)
      with ops.default_session(...):
        c.eval(session=sess)

  Args:
    session: The session to be installed as the default session.

  Returns:
    A context manager for the default session.
  """
  controller = _default_session_stack.get_controller(session)
  return controller
@tf_export(v1=["get_default_session"])
def get_default_session():
  """Returns the default session for the current thread.

  This is the innermost session for which a `Session` or
  `Session.as_default()` context has been entered.

  NOTE: The default session is a property of the current thread. A newly
  created thread has no default session of its own; to use one there, you
  must explicitly add a `with sess.as_default():` in that thread's function.

  Returns:
    The default `Session` being used in the current thread.
  """
  innermost_session = _default_session_stack.get_default()
  return innermost_session