Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/compiler/xla/jit.py: 36%

33 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Library for controlling the Tensorflow/XLA JIT compiler.""" 

16 

17import contextlib 

18 

19from tensorflow.core.framework import attr_value_pb2 

20from tensorflow.python.eager import context 

21from tensorflow.python.framework import ops 

22from tensorflow.python.util.tf_export import tf_export 

23 

24 

25_XLA_SCOPE_KEY = ("__xla_scope",) 

26 

27 

28class _XlaScope(object): 

29 """Keeps track of previous XLA scope calls, and depth of current call.""" 

30 

31 def __init__(self, count, depth): 

32 self.count = count 

33 self.depth = depth 

34 

35 

@contextlib.contextmanager
@tf_export("xla.experimental.jit_scope")
def experimental_jit_scope(compile_ops=True, separate_compiled_gradients=False):
  """Enable or disable JIT compilation of operators within the scope.

  NOTE: This is an experimental feature.

  The compilation is a hint and only supported on a best-effort basis.

  Example usage:

  ```python
  with tf.xla.experimental.jit_scope():
    c = tf.matmul(a, b)  # compiled
  with tf.xla.experimental.jit_scope(compile_ops=False):
    d = tf.matmul(a, c)  # not compiled
  with tf.xla.experimental.jit_scope(
      compile_ops=lambda node_def: 'matmul' in node_def.op.lower()):
    e = tf.matmul(a, b) + d  # matmul is compiled, the addition is not.
  ```

  Example of `separate_compiled_gradients`:

  ```python
  # In the example below, the computations for f, g and h will all be compiled
  # in separate scopes.
  with tf.xla.experimental.jit_scope(
      separate_compiled_gradients=True):
    f = tf.matmul(a, b)
  g = tf.gradients([f], [a, b], name='mygrads1')
  h = tf.gradients([f], [a, b], name='mygrads2')
  ```

  Ops that are not in the scope may be clustered and compiled with ops in
  the scope with `compile_ops=True`, while the ops in the scope with
  `compile_ops=False` will never be compiled.

  For example:

  ```python
  # In the example below, x and loss may be clustered and compiled together,
  # while y will not be compiled.
  with tf.xla.experimental.jit_scope():
    x = tf.matmul(a, b)
  with tf.xla.experimental.jit_scope(compile_ops=False):
    y = tf.matmul(c, d)
  loss = x + y
  ```

  If you want to only compile the ops in the scope with `compile_ops=True`,
  consider adding an outer `jit_scope(compile_ops=False)`:

  ```python
  # In the example below, only x will be compiled.
  with tf.xla.experimental.jit_scope(compile_ops=False):
    with tf.xla.experimental.jit_scope():
      x = tf.matmul(a, b)
    y = tf.matmul(c, d)
    loss = x + y
  ```

  Args:
    compile_ops: Whether to enable or disable compilation in the scope.
      Either a Python bool, or a callable that accepts the parameter
      `node_def` and returns a python bool.
    separate_compiled_gradients: If True, put each gradient subgraph into a
      separate compilation scope. This gives fine-grained control over which
      portions of the graph will be compiled as a single unit. Compiling
      gradients separately may yield better performance for some graphs.
      The scope is named based on the scope of the forward computation as well
      as the name of the gradients. As a result, the gradients will be compiled
      in a scope that is separate from both the forward computation, and from
      other gradients.

  Raises:
    RuntimeError: if called when eager execution is enabled.

  Yields:
    The current scope, enabling or disabling compilation.
  """
  if context.executing_eagerly():
    raise RuntimeError("xla.experimental.jit_scope is not supported when eager "
                       "execution is enabled. Try use it inside tf.function.")

  # `_XlaCompile` is either a fixed AttrValue or a per-node callable that
  # produces one, depending on how `compile_ops` was given.
  if callable(compile_ops):
    def xla_compile(node_def):
      return attr_value_pb2.AttrValue(b=compile_ops(node_def))
  else:
    xla_compile = attr_value_pb2.AttrValue(b=compile_ops)

  attrs = {
      "_XlaCompile":
          xla_compile,
      "_XlaSeparateCompiledGradients":
          attr_value_pb2.AttrValue(b=bool(separate_compiled_gradients))
  }

  # Find the singleton counter for the current scoped graph. If it
  # doesn't exist, create one.
  xla_scope_counter = ops.get_collection(_XLA_SCOPE_KEY)
  if not xla_scope_counter:
    xla_scope_counter = _XlaScope(0, 0)
    ops.add_to_collection(_XLA_SCOPE_KEY, xla_scope_counter)
  else:
    xla_scope_counter = xla_scope_counter[0]

  if xla_scope_counter.depth == 0:
    # If we're at the root xla scope, we can increase the counter so
    # future calls to jit_scope use a different scope value.
    # If we're already within a scope, we'll be fusing using the scope
    # controlled by the parent.
    attrs["_XlaScope"] = attr_value_pb2.AttrValue(
        s=("jit_scope_%d" % xla_scope_counter.count).encode())
    xla_scope_counter.count += 1

  xla_scope_counter.depth += 1
  try:
    # pylint: disable=protected-access
    with ops.get_default_graph()._attr_scope(attrs):
      yield
    # pylint: enable=protected-access
  finally:
    # Restore the depth even when the body raises. Otherwise the counter would
    # stay elevated and later root-level jit_scopes on this graph would never
    # receive their own `_XlaScope` attribute.
    xla_scope_counter.depth -= 1