Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/types/core.py: 81%

43 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Core TensorFlow types.""" 

16 

17import sys 

18import textwrap 

19 

20from typing import Union 

21 

22import numpy as np 

23 

24from tensorflow.python.types import doc_typealias 

25from tensorflow.python.util.tf_export import tf_export 

26 

27# pylint:disable=g-import-not-at-top 

28if sys.version_info >= (3, 8): 

29 from typing import Protocol 

30 from typing import runtime_checkable 

31else: 

32 from typing_extensions import Protocol 

33 from typing_extensions import runtime_checkable 

34# pylint:enable=g-import-not-at-top 

35 

36# TODO(mdan): Consider adding ABC once the dependence on isinstance is reduced. 

37# TODO(mdan): Add type annotations. 

38 

39 

40# TODO(b/178822082): Revisit this API when tf.types gets more resource. 

@tf_export("__internal__.types.Tensor", v1=[])
class Tensor:
  """Common base class for every dense Tensor object.

  A dense tensor carries a static data type (dtype) and may carry a static
  rank and shape. Tensors are immutable; mutable objects can instead be
  backed by a Tensor holding the unique handle that identifies them.
  """

  @property
  def dtype(self):
    # Abstract stub — concrete subclasses override this with the tensor's
    # static data type; here it implicitly returns None.
    pass

  @property
  def shape(self):
    # Abstract stub — concrete subclasses override this with the tensor's
    # static shape; here it implicitly returns None.
    pass

57 

58 

# `ops.EagerTensor` subclasses `Symbol` by way of subclassing `ops.Tensor`;
# care should be taken when performing `isinstance` checks on `Value`, e.g.:
#
# ```
# if isinstance(t, core.Symbol) and not isinstance(t, core.Value):
#   ...
# ```

class Symbol(Tensor):
  """A symbolic, graph-mode Tensor.

  Instances stand for the output of an op definition; they carry no concrete
  value themselves.
  """

73 

74 

class Value(Tensor):
  """A Tensor that can be associated with a value (an "eager tensor").

  Instances represent the (usually future) result of executing an op
  immediately.
  """

  def numpy(self):
    # Stub — concrete subclasses return this tensor's value as a NumPy array;
    # here it implicitly returns None.
    pass

84 

85 

@tf_export("types.experimental.Callable", v1=[])
class Callable:
  """Base class for TF callables like those created by tf.function.

  Note: Callables are conceptually very similar to `tf.Operation`: a
  `tf.Operation` is a kind of callable.
  """

  def __call__(self, *args, **kwargs):
    """Invokes this callable.

    Calling behaves like a regular op: under eager execution it runs
    immediately and returns the results, while in graph mode it creates ops
    that yield symbolic TensorFlow values (such as `tf.Tensor`,
    `tf.data.Dataset`, etc.). A `tf.function` callable, for instance, usually
    generates a `tf.raw_ops.PartitionedCall` op — though not always; the
    exact operations produced are an internal implementation detail.

    Args:
      *args: positional arguments for this call
      **kwargs: keyword arguments for this call
    Returns:
      The execution results.
    """

110 

111 

@tf_export("types.experimental.ConcreteFunction", v1=[])
class ConcreteFunction(Callable):
  """Base class for graph functions.

  Each `ConcreteFunction` wraps exactly one graph function definition, and is
  differentiable under `tf.GradientTape` contexts.
  """

119 

120 

121# TODO(mdan): Name just `types.Function`, for historic continuity? 

@tf_export("types.experimental.GenericFunction", v1=[])
class GenericFunction(Callable):
  """Base class for polymorphic graph functions.

  Graph functions are Python callable objects that dispatch calls to a
  TensorFlow graph. Polymorphic graph functions can be backed by multiple TF
  graphs, and automatically select the appropriate specialization based on the
  type of input they were called with. They may also create specializations on
  the fly if necessary, for example by tracing.

  Also see `tf.function`.
  """

  def get_concrete_function(self, *args, **kwargs) -> ConcreteFunction:
    """Returns a `ConcreteFunction` specialized to input types.

    The arguments specified by `args` and `kwargs` follow normal function call
    rules. The returned `ConcreteFunction` has the same set of positional and
    keyword arguments as `self`, but their types are compatible to the types
    specified by `args` and `kwargs` (though not necessarily equal).

    >>> @tf.function
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function(tf.constant(1.0))
    >>> f_concrete = f.get_concrete_function(x=tf.constant(1.0))

    Unlike normal calls, `get_concrete_function` allows type specifiers
    instead of TensorFlow objects, so for example `tf.Tensor`s may be replaced
    with `tf.TensorSpec`s.

    >>> @tf.function
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function(tf.TensorSpec([], tf.float64))

    If the function definition allows only one specialization, `args` and
    `kwargs` may be omitted altogether.

    >>> @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function()

    The returned `ConcreteFunction` can be called normally:

    >>> f_concrete(tf.constant(1.0))
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>
    >>> f_concrete(x=tf.constant(1.0))
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      A `ConcreteFunction`.
    """
    # Abstract stub: implemented by concrete subclasses (e.g. tf.function's
    # polymorphic function type); here it implicitly returns None.
    pass

  def experimental_get_compiler_ir(self, *args, **kwargs):
    """Returns compiler IR for the compiled function.

    This API is intended *only* for debugging as there are no guarantees on
    backwards compatibility of returned IR or the allowed values of `stage`.

    Args:
      *args: compilation args supports inputs either: (1) all inputs are
        TensorSpec or (2) all inputs are tf.Tensor/Python variables.
      **kwargs: Keyword arguments used for compilation. Same requirement as
        compilation args.

    Returns:
      Function callable with the following kwargs:
        - `stage` at which the compiler IR should be serialized. Allowed values
          are:
          - `hlo`: HLO output after conversion from TF
            (https://www.tensorflow.org/xla/operation_semantics).
          - `hlo_serialized`: Like stage=`hlo`, but the output is a serialized
            HLO module proto (a bytes object).
          - `optimized_hlo`: HLO after compiler optimizations.
          - `optimized_hlo_serialized`: Like stage=`optimized_hlo`, but the
            output is a serialized HLO module proto (a bytes object).
          - `optimized_hlo_dot`: optimized HLO in DOT format suitable for
            Graphviz.
        - `device_name` can be either None, in which case the preferred device
          is used for compilation, or a device name. It can be a full device
          name, or a partial one, e.g., `/device:CPU:0`.

      For example, for

      ```python
      @tf.function(jit_compile=True)
      def f(x):
        return x + 1

      f.experimental_get_compiler_ir(tf.random.normal([10, 10]))(stage='hlo')
      ```

      the output is:

      ```
      HloModule a_inference_f_13__.9

      ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
        %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}
        %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
        %constant.3 = f32[] constant(1)
        %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3)
        %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2,
                                     f32[10,10]{1,0} %broadcast.4)
        %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5)
        %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6)
        ROOT %get-tuple-element.8 = f32[10,10]{1,0}
          get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0
      }
      ```

      Here is another example using tf.TensorSpec inputs:

      ```python
      y = tf.Variable(tf.zeros([10, 20], dtype=tf.float32))

      @tf.function(jit_compile=True)
      def f(x):
        return x + y

      hlo_str = f.experimental_get_compiler_ir(tf.TensorSpec(shape=(10,
      20)))(stage='hlo')
      ```

      The output is:

      ```
      HloModule a_inference_f_120__.8,
      entry_computation_layout={(f32[10,20]{1,0},f32[10,20]{1,0})->f32[10,20]{1,0}}

      ENTRY %a_inference_f_120__.8 (arg0.1: f32[10,20], arg1.2: f32[10,20]) ->
      f32[10,20] {
        %arg0.1 = f32[10,20]{1,0} parameter(0), parameter_replication={false},
        metadata={op_name="XLA_Args"}
        %reshape.3 = f32[10,20]{1,0} reshape(f32[10,20]{1,0} %arg0.1)
        %arg1.2 = f32[10,20]{1,0} parameter(1), parameter_replication={false},
        metadata={op_name="XLA_Args"}
        %add.4 = f32[10,20]{1,0} add(f32[10,20]{1,0} %reshape.3, f32[10,20]{1,0}
        %arg1.2), metadata={op_type="AddV2" op_name="add"
        source_file="<ipython-input-16-ea04879c1873>" source_line=4}
        %reshape.5 = f32[10,20]{1,0} reshape(f32[10,20]{1,0} %add.4),
        metadata={op_name="XLA_Retvals"}
        %tuple.6 = (f32[10,20]{1,0}) tuple(f32[10,20]{1,0} %reshape.5),
        metadata={op_name="XLA_Retvals"}
        ROOT %get-tuple-element.7 = f32[10,20]{1,0}
          get-tuple-element((f32[10,20]{1,0}) %tuple.6), index=0,
        metadata={op_name="XLA_Retvals"}
      }
      ```

      The HLO module accepts a flat list of inputs. To retrieve the order
      of these inputs signatures, users can call the
      `concrete_fn.structured_input_signature` and `concrete_fn.captured_inputs`:

      ```python
      # Use concrete_fn to get the hlo_module flat_args.
      concrete_fn = f.get_concrete_function(tf.TensorSpec(shape=(10, 20)))
      flat_args = list(
          tf.nest.flatten(concrete_fn.structured_input_signature)
      ) + concrete_fn.captured_inputs
      ```

    Raises:
      ValueError:
        (1) If an invalid `stage` is selected
        (2) or if applied to a function which is not compiled
        (`jit_compile=True` is not set).
        (3) or if input shapes are not fully defined for tf.TensorSpec inputs
      TypeError: When called with input in graph mode.
    """
    # Abstract stub: implemented by concrete subclasses; here it implicitly
    # returns None.
    pass

300 

301 

@runtime_checkable
class TensorProtocol(Protocol):
  """Structural type for objects that can be converted to a Tensor.

  Any object that implements `__tf_tensor__` satisfies this protocol, and
  `isinstance` checks against it are supported at runtime.
  """

  def __tf_tensor__(self, dtype=None, name=None):
    """Converts this object to a Tensor.

    Args:
      dtype: data type for the returned Tensor
      name: a name for the operations which create the Tensor
    Returns:
      A Tensor.
    """

316 

317 

# TODO(rahulkamat): Add missing types that are convertible to Tensor.
# Public union of everything `tf.convert_to_tensor` accepts. Kept as a
# PEP 484 `Union` so it can be used directly in user-facing annotations.
TensorLike = Union[Tensor, TensorProtocol, int, float, bool, str, bytes,
                   complex, tuple, list, np.ndarray, np.generic]
# Attach user-facing documentation to the alias; plain type aliases cannot
# carry a docstring directly, so `doc_typealias` registers it out of band.
doc_typealias.document(
    obj=TensorLike,
    doc=textwrap.dedent("""\
      Union of all types that can be converted to a `tf.Tensor` by `tf.convert_to_tensor`.

      This definition may be used in user code. Additional types may be added
      in the future as more input types are supported.

      Example:

      ```
      def foo(x: TensorLike):
        pass
      ```

      This definition passes static type verification for:

      ```
      foo(tf.constant([1, 2, 3]))
      foo([1, 2, 3])
      foo(np.array([1, 2, 3]))
      ```
      """),
)
# Export the alias under the public name `tf.types.experimental.TensorLike`.
tf_export("types.experimental.TensorLike").export_constant(
    __name__, "TensorLike")