Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/types/trace.py: 81%

47 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-05 06:32 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""tf.function tracing types. 

16 

17See `core.GenericFunction` and `core.ConcreteFunction`. 

18 

19`GenericFunction` assigns types to call arguments, forming a signature. 

20Function signatures are used to match arguments to `ConcreteFunction`s. 

21For example, when a new `ConcreteFunction` is traced, it is assigned 

22the signature of the arguments it was traced with. Subsequent call arguments 

23which match its signature will be dispatched to the same `ConcreteFunction`. 

24If no `ConcreteFunction` with a matching signature is found, a new one may be 

25traced (a process known as retracing). 

26""" 

27 

28import abc 

29from typing import Any, List, Optional, Sequence, Iterator 

30 

31from typing_extensions import Protocol 

32from typing_extensions import runtime_checkable 

33 

34from tensorflow.python.types import core 

35from tensorflow.python.util.tf_export import tf_export 

36from tensorflow.tools.docs import doc_controls 

37 

38 

39@tf_export("types.experimental.TraceType", v1=[]) 

40class TraceType(metaclass=abc.ABCMeta): 

41 """Represents the type of object(s) for tf.function tracing purposes. 

42 

43 `TraceType` is an abstract class that other classes might inherit from to 

44 provide information regarding associated class(es) for the purposes of 

45 tf.function tracing. The typing logic provided through this mechanism will be 

46 used to make decisions regarding usage of cached concrete functions and 

47 retracing. 

48 

49 For example, if we have the following tf.function and classes: 

50 ```python 

51 @tf.function 

52 def get_mixed_flavor(fruit_a, fruit_b): 

53 return fruit_a.flavor + fruit_b.flavor 

54 

55 class Fruit: 

56 flavor = tf.constant([0, 0]) 

57 

58 class Apple(Fruit): 

59 flavor = tf.constant([1, 2]) 

60 

61 class Mango(Fruit): 

62 flavor = tf.constant([3, 4]) 

63 ``` 

64 

65 tf.function does not know when to re-use an existing concrete function in 

66 regards to the `Fruit` class so naively it retraces for every new instance. 

67 ```python 

68 get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function 

69 get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again 

70 ``` 

71 

72 However, we, as the designers of the `Fruit` class, know that each subclass 

73 has a fixed flavor and we can reuse an existing traced concrete function if 

74 it was the same subclass. Avoiding such unnecessary tracing of concrete 

75 functions can have significant performance benefits. 

76 

77 ```python 

78 class FruitTraceType(tf.types.experimental.TraceType): 

79 def __init__(self, fruit): 

80 self.fruit_type = type(fruit) 

81 self.fruit_value = fruit 

82 

83 def is_subtype_of(self, other): 

84 return (type(other) is FruitTraceType and 

85 self.fruit_type is other.fruit_type) 

86 

87 def most_specific_common_supertype(self, others): 

88 return self if all(self == other for other in others) else None 

89 

90 def placeholder_value(self, placeholder_context=None): 

91 return self.fruit_value 

92 

93 class Fruit: 

94 

95 def __tf_tracing_type__(self, context): 

96 return FruitTraceType(self) 

97 ``` 

98 

99 Now if we try calling it again: 

100 ```python 

101 get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function 

102 get_mixed_flavor(Apple(), Mango()) # Re-uses the traced concrete function 

103 ``` 

104 """ 

105 

106 @abc.abstractmethod 

107 def is_subtype_of(self, other: "TraceType") -> bool: 

108 """Returns True if `self` is a subtype of `other`. 

109 

110 For example, `tf.function` uses subtyping for dispatch: 

111 if `a.is_subtype_of(b)` is True, then an argument of `TraceType` 

112 `a` can be used as argument to a `ConcreteFunction` traced with 

113 a `TraceType` `b`. 

114 

115 Args: 

116 other: A TraceType object to be compared against. 

117 

118 Example: 

119 

120 ```python 

121 class Dimension(TraceType): 

122 def __init__(self, value: Optional[int]): 

123 self.value = value 

124 

125 def is_subtype_of(self, other): 

126 # Either the value is the same or other has a generalized value that 

127 # can represent any specific ones. 

128 return (self.value == other.value) or (other.value is None) 

129 ``` 

130 """ 

131 

132 @abc.abstractmethod 

133 def most_specific_common_supertype( 

134 self, others: Sequence["TraceType"]) -> Optional["TraceType"]: 

135 """Returns the most specific supertype of `self` and `others`, if exists. 

136 

137 The returned `TraceType` is a supertype of `self` and `others`, that is, 

138 they are all subtypes (see `is_subtype_of`) of it. 

139 It is also most specific, that is, it has no subtype that is also 

140 a common supertype of `self` and `others`. 

141 

142 If `self` and `others` have no common supertype, this returns `None`. 

143 

144 Args: 

145 others: A sequence of TraceTypes. 

146 

147 Example: 

148 ```python 

149 class Dimension(TraceType): 

150 def __init__(self, value: Optional[int]): 

151 self.value = value 

152 

153 def most_specific_common_supertype(self, others): 

154 # Keep this value when every other dimension agrees with it; otherwise 

155 # fall back to the generalized dimension that can represent any value. 

156 if all(self.value == other.value for other in others): 

157 return self 

158 else: 

159 return Dimension(None) 

160 ``` 

161 """ 

162 

163 @abc.abstractmethod 

164 def placeholder_value(self, placeholder_context) -> Any: 

165 """Creates a placeholder for tracing. 

166 

167 tf.function traces with the placeholder value rather than the actual value. 

168 For example, a placeholder value can represent multiple different 

169 actual values. This means that the trace generated with that placeholder 

170 value is more general and reusable which saves expensive retracing. 

171 

172 Args: 

173 placeholder_context: A `PlaceholderContext` container for context 

174 information when creating a placeholder value. 

175 

176 For the `Fruit` example shared above, implementing: 

177 

178 ```python 

179 class FruitTraceType: 

180 def placeholder_value(self, placeholder_context): 

181 return Fruit() 

182 ``` 

183 instructs tf.function to trace with the `Fruit()` objects 

184 instead of the actual `Apple()` and `Mango()` objects when it receives a 

185 call to `get_mixed_flavor(Apple(), Mango())`. For example, Tensor arguments 

186 are replaced with Tensors of similar shape and dtype, output from 

187 a tf.Placeholder op. 

188 

189 More generally, placeholder values are the arguments of a tf.function, 

190 as seen from the function's body: 

191 ```python 

192 @tf.function 

193 def foo(x): 

194 # Here `x` is the placeholder value 

195 ... 

196 

197 foo(x) # Here `x` is the actual value 

198 ``` 

199 """ 

200 

201 @doc_controls.do_not_doc_inheritable 

202 def _to_tensors(self, value: Any) -> List[core.Tensor]: 

203 """Breaks down a value of this type into Tensors. 

204 

205 Args: 

206 value: An input value belonging to this TraceType 

207 

208 Returns: 

209 List of Tensors. 

210 """ 

211 del value 

212 return [] 

213 

214 @doc_controls.do_not_doc_inheritable 

215 def _from_tensors(self, tensors: Iterator[core.Tensor]) -> Any: 

216 """Regenerates a value of this type from Tensors. 

217 

218 Must use the same fixed number of tensors as `_to_tensors`. 

219 

220 Args: 

221 tensors: An iterator from which the tensors can be pulled. 

222 

223 Returns: 

224 A value of this type. 

225 """ 

226 del tensors 

227 return self.placeholder_value(PlaceholderContext()) 

228 

229 @doc_controls.do_not_doc_inheritable 

230 def _flatten(self) -> List["TraceType"]: 

231 """Returns a list of TensorSpecs corresponding to `_to_tensors` values.""" 

232 return [] 

233 

234 @doc_controls.do_not_doc_inheritable 

235 def _cast(self, value, casting_context) -> Any: # pylint:disable=unused-argument 

236 """Cast value to this type. 

237 

238 Args: 

239 value: An input value belonging to this TraceType. 

240 casting_context: A context reserved for future usage such as to determine 

241 casting rules. 

242 

243 Returns: 

244 The value cast to this TraceType. 

245 

246 Raises: 

247 AssertionError: When _cast is not overridden in a subclass, 

248 the value is returned directly, and it should be the same as 

249 self.placeholder_value(). 

250 """ 

251 assert value == self.placeholder_value( 

252 PlaceholderContext()), f"Can not cast {value!r} to type {self!r}" 

253 return value 

254 

255 @abc.abstractmethod 

256 def __hash__(self) -> int: 

257 pass 

258 

259 @abc.abstractmethod 

260 def __eq__(self, other) -> bool: 

261 pass 

262 

263 

264@tf_export("types.experimental.TracingContext", v1=[]) 

265class TracingContext(metaclass=abc.ABCMeta): 

266 """Contains information scoped to the tracing of multiple objects. 

267 

268 `TracingContext` is a container class for flags and variables that have 

269 any kind of influence on the tracing behaviour of the class implementing 

270 the __tf_tracing_type__. This context will be shared across all 

271 __tf_tracing_type__ calls while constructing the TraceType for a particular 

272 set of objects. 

273 """ 

274 

275 

276class PlaceholderContext(): 

277 """Contains context information for generating placeholders within a scope.""" 

278 

279 

280class CastContext(): 

281 """Contains context info and rules for casting values to a TypeSpec.""" 

282 

283 

284@runtime_checkable 

285class SupportsTracingProtocol(Protocol): 

286 """A protocol allowing custom classes to control tf.function retracing.""" 

287 

288 @doc_controls.doc_private 

289 @abc.abstractmethod 

290 def __tf_tracing_type__(self, context: TracingContext) -> TraceType: 

291 """Returns the tracing type of this object. 

292 

293 The tracing type is used to build the signature of a tf.function 

294 when traced, and to match arguments with existing signatures. 

295 When a Function object is called, tf.function looks at the tracing type 

296 of the call arguments. If an existing signature of matching type exists, 

297 it will be used. Otherwise, a new function is traced, and its signature 

298 will use the tracing type of the call arguments. 

299 

300 Args: 

301 context: a context object created for each function call for tracking 

302 information about the call arguments as a whole 

303 Returns: 

304 The tracing type of this object. 

305 """ 

306 

307# TODO(b/219556836): Direct tf_export decorator adds non-method members to the 

308# Protocol which breaks @runtime_checkable since it does not support them. 

309tf_export( 

310 "types.experimental.SupportsTracingProtocol", 

311 v1=[]).export_constant(__name__, "SupportsTracingProtocol")