Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/stack.py: 58%

38 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Classes used to handle thread-local stacks.""" 

16 

17import threading 

18 

19from tensorflow.python.util import tf_contextlib 

20from tensorflow.python.util.tf_export import tf_export 

21 

22 

class DefaultStack(threading.local):
    """A thread-local stack of objects for providing implicit defaults.

    Because this class subclasses ``threading.local``, each thread that uses an
    instance sees its own ``stack`` list and its own ``enforce_nesting`` flag;
    ``__init__`` runs again the first time a new thread touches the instance.
    Setting ``enforce_nesting`` therefore only affects the calling thread.

    The most recently pushed entry is the current default. Entries are pushed
    and removed through the ``get_controller`` context manager.
    """

    def __init__(self):
        super().__init__()
        # When True, controllers must be exited in strict LIFO order.
        self._enforce_nesting = True
        self.stack = []

    def get_default(self):
        """Returns the innermost (last pushed) entry, or None if empty."""
        return self.stack[-1] if self.stack else None

    def reset(self):
        """Discards every entry on the calling thread's stack."""
        self.stack = []

    def is_cleared(self):
        """Returns True if the calling thread's stack has no entries."""
        return not self.stack

    @property
    def enforce_nesting(self):
        """Whether controllers must be exited in strict LIFO order."""
        return self._enforce_nesting

    @enforce_nesting.setter
    def enforce_nesting(self, value):
        self._enforce_nesting = value

    def _remove_entry(self, default):
        """Removes the most recently pushed entry that *is* `default`.

        Matching is by identity rather than `==`. This means an unrelated
        entry that merely compares equal (or whose `__eq__` does not return a
        plain bool) is never removed by mistake. Does nothing if `default` is
        no longer on the stack, e.g. because `reset()` ran while its
        controller was still open.
        """
        for index in range(len(self.stack) - 1, -1, -1):
            if self.stack[index] is default:
                del self.stack[index]
                return

    @tf_contextlib.contextmanager
    def get_controller(self, default):
        """A context manager for manipulating a default stack.

        Pushes `default` on entry and removes it again on exit.

        Args:
          default: The object to install as the innermost default.

        Yields:
          `default` itself.

        Raises:
          AssertionError: On exit, if `enforce_nesting` is True and `default`
            is not the innermost entry (controllers exited out of order).
        """
        self.stack.append(default)
        try:
            yield default
        finally:
            # stack may be empty if reset() was called
            if self.stack:
                if self._enforce_nesting:
                    if self.stack[-1] is not default:
                        raise AssertionError(
                            "Nesting violated for default stack of %s objects" %
                            type(default))
                    self.stack.pop()
                else:
                    # Identity-based and tolerant of an entry already
                    # discarded by reset(); list.remove() would compare with
                    # == and raise ValueError if the entry were missing.
                    self._remove_entry(default)

65 

66 

# Module-wide stack of default sessions; each thread sees its own entries
# because DefaultStack is a threading.local subclass.
_default_session_stack = DefaultStack()


def default_session(session):
    """Returns a `with` handler that installs `session` as the default session.

    While the returned context is active, `session` is the innermost entry of
    the default-session stack. It is what `get_default_session()` returns, and
    it handles implicit `Tensor.eval()` and `Operation.run()` calls. The main
    user is `session.Session`, but any object that implements the
    `Session.run()` interface can be installed.

    The default session is tracked per thread, so the active default can
    always be determined from the current thread's call stack. Code running in
    a newly created thread does not inherit the default. That thread's
    function must enter its own `with ops.default_session(sess):` block.

    Example:
      The following snippets are equivalent:

        # 1. Calling the session directly:
        sess = ...
        c = tf.constant(5.0)
        sess.run(c)

        # 2. Relying on default_session():
        sess = ...
        with ops.default_session(sess):
          c = tf.constant(5.0)
          result = c.eval()

        # 3. Passing the session explicitly under another default:
        sess = ...
        with ops.default_session(sess):
          c = tf.constant(5.0)
          with ops.default_session(other_sess):
            c.eval(session=sess)

    Args:
      session: The session to install as the default session.

    Returns:
      A context manager that makes `session` the default for its duration.
    """
    controller = _default_session_stack.get_controller(session)
    return controller

116 

117 

@tf_export(v1=["get_default_session"])
def get_default_session():
    """Returns the default session for the current thread.

    The result is the innermost session whose `Session` or
    `Session.as_default()` context is currently entered on this thread.

    NOTE: Default sessions are tracked per thread. A newly created thread
    starts with no default. To use a session as the default there, enter
    `with sess.as_default():` inside that thread's function.

    Returns:
      The `Session` currently acting as the default on this thread, or `None`
      if no default session has been installed.
    """
    stack = _default_session_stack
    return stack.get_default()