Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/util/tf_export.py: 45%

148 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-05 06:32 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for exporting TensorFlow symbols to the API. 

16 

17Exporting a function or a class: 

18 

19To export a function or a class, use the tf_export decorator. For example:

20```python 

21@tf_export('foo', 'bar.foo') 

22def foo(...): 

23 ... 

24``` 

25 

26If a function is assigned to a variable, you can export it by calling 

27tf_export explicitly. For example:

28```python 

29foo = get_foo(...) 

30tf_export('foo', 'bar.foo')(foo) 

31``` 

32 

33 

34Exporting a constant 

35```python 

36foo = 1 

37tf_export('consts.foo').export_constant(__name__, 'foo') 

38``` 

39""" 

40import collections 

41import functools 

42import sys 

43 

44from tensorflow.python.util import tf_decorator 

45from tensorflow.python.util import tf_inspect 

46 

ESTIMATOR_API_NAME = 'estimator'
KERAS_API_NAME = 'keras'
TENSORFLOW_API_NAME = 'tensorflow'

# Top-level subpackage names owned by separately packaged TensorFlow
# components. The core TensorFlow repo must not export any symbols under
# these names; `api_export._validate_symbol_names` enforces this.
SUBPACKAGE_NAMESPACES = [ESTIMATOR_API_NAME]

# Pair of attribute names: `names` is the attribute that holds the API names
# a symbol is exported under, and `constants` is the module attribute that
# holds the constants exported from that module.
_Attributes = collections.namedtuple(
    'ExportedApiAttributes', ['names', 'constants'])

# TF2 attribute names, keyed by API name. Every attribute value must be
# unique to its API so that exports for different APIs never collide.
API_ATTRS = {
    TENSORFLOW_API_NAME: _Attributes(
        names='_tf_api_names',
        constants='_tf_api_constants'),
    ESTIMATOR_API_NAME: _Attributes(
        names='_estimator_api_names',
        constants='_estimator_api_constants'),
    KERAS_API_NAME: _Attributes(
        names='_keras_api_names',
        constants='_keras_api_constants'),
}

# TF1 counterparts of `API_ATTRS`. They use the same keys, and every
# attribute name carries a `_v1` suffix.
API_ATTRS_V1 = {
    TENSORFLOW_API_NAME: _Attributes(
        names='_tf_api_names_v1',
        constants='_tf_api_constants_v1'),
    ESTIMATOR_API_NAME: _Attributes(
        names='_estimator_api_names_v1',
        constants='_estimator_api_constants_v1'),
    KERAS_API_NAME: _Attributes(
        names='_keras_api_names_v1',
        constants='_keras_api_constants_v1'),
}

82 

83 

class SymbolAlreadyExposedError(Exception):
  """Signals an attempt to add API names to a symbol that already has them."""

87 

88 

class InvalidSymbolNameError(Exception):
  """Signals an attempt to export a symbol under an invalid or disallowed name."""

92 

# Registry that maps each exported API name to the symbol exported under it.
# `api_export.__call__` fills it; V1 names are stored with a `compat.v1.`
# prefix.
_NAME_TO_SYMBOL_MAPPING = {}


def get_symbol_from_name(name):
  """Looks up the symbol registered under the API name `name`.

  Args:
    name: Dot-delimited API name, e.g. `'keras.optimizers.Adam'` or
      `'compat.v1.foo'`.

  Returns:
    The exported symbol, or None if nothing was exported under `name`.
  """
  try:
    return _NAME_TO_SYMBOL_MAPPING[name]
  except KeyError:
    return None

98 

99 

def get_canonical_name_for_symbol(
    symbol, api_name=TENSORFLOW_API_NAME,
    add_prefix_to_v1_names=False):
  """Get canonical name for the API symbol.

  The TF2 canonical name is preferred. If the symbol has no TF2 canonical
  name, its TF1 canonical name is returned instead.

  Example:
  ```python
  from tensorflow.python.util import tf_export
  cls = tf_export.get_symbol_from_name('keras.optimizers.Adam')

  # Gives `<class 'keras.optimizer_v2.adam.Adam'>`
  print(cls)

  # Gives `keras.optimizers.Adam`
  print(tf_export.get_canonical_name_for_symbol(cls, api_name='keras'))
  ```

  Args:
    symbol: API function or class.
    api_name: API name (`tensorflow`, `estimator` or `keras`).
    add_prefix_to_v1_names: Specifies whether a name available only in V1
      should be prefixed with `compat.v1.`.

  Returns:
    Canonical name for the API symbol (e.g. `initializers.zeros`) if a
    canonical name could be determined. Otherwise, returns None.
  """
  if not hasattr(symbol, '__dict__'):
    return None
  api_names_attr = API_ATTRS[api_name].names
  _, undecorated_symbol = tf_decorator.unwrap(symbol)
  # Look in `__dict__` instead of using hasattr, so that names inherited
  # from an exported base class are not reported for a subclass.
  if api_names_attr not in undecorated_symbol.__dict__:
    return None
  api_names = getattr(undecorated_symbol, api_names_attr)
  deprecated_api_names = undecorated_symbol.__dict__.get(
      '_tf_deprecated_api_names', [])

  canonical_name = get_canonical_name(api_names, deprecated_api_names)
  if canonical_name:
    return canonical_name

  # If there is no V2 canonical name, get V1 canonical name. A missing V1
  # attribute is treated as "no V1 names" rather than raising.
  api_names_attr_v1 = API_ATTRS_V1[api_name].names
  api_names_v1 = getattr(undecorated_symbol, api_names_attr_v1, ())
  v1_canonical_name = get_canonical_name(api_names_v1, deprecated_api_names)
  if v1_canonical_name is None:
    # Without this guard the prefix branch would produce 'compat.v1.None'.
    return None
  if add_prefix_to_v1_names:
    return 'compat.v1.%s' % v1_canonical_name
  return v1_canonical_name

148 

149 

def get_canonical_name(api_names, deprecated_api_names):
  """Get preferred endpoint name.

  Args:
    api_names: API names iterable. Any iterable is accepted (it is
      materialized once), so generators work as well as lists and tuples.
    deprecated_api_names: Deprecated API names iterable. `None` is treated
      as empty.

  Returns:
    Returns one of the following in decreasing preference:
      - first non-deprecated endpoint
      - first endpoint
      - None
  """
  # Materialize both inputs. Without this, indexing a generator fails and
  # repeated `in` tests would consume a one-shot deprecated iterable.
  names = list(api_names)
  deprecated = tuple(deprecated_api_names or ())

  non_deprecated_name = next(
      (name for name in names if name not in deprecated), None)
  # Compare against None rather than testing truthiness, so a falsy name
  # such as '' still counts as found.
  if non_deprecated_name is not None:
    return non_deprecated_name
  if names:
    return names[0]
  return None

170 

171 

def get_v1_names(symbol):
  """Collects every TF 1.* API name recorded on `symbol`.

  Names are gathered in order from the TensorFlow, Estimator and Keras V1
  name attributes. Only attributes stored directly in the symbol's own
  `__dict__` count; inherited attributes are ignored.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of V1 API names. It is empty if `symbol` has no `__dict__` or no V1
    names recorded.
  """
  if not hasattr(symbol, '__dict__'):
    return []
  own_attrs = symbol.__dict__
  collected = []
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME):
    attr_name = API_ATTRS_V1[api].names
    if attr_name in own_attrs:
      collected.extend(getattr(symbol, attr_name))
  return collected

196 

197 

def get_v2_names(symbol):
  """Collects every TF 2.0 API name recorded on `symbol`.

  Names are gathered in order from the TensorFlow, Estimator and Keras V2
  name attributes. Only attributes stored directly in the symbol's own
  `__dict__` count; inherited attributes are ignored.

  Args:
    symbol: symbol to get API names for.

  Returns:
    List of V2 API names. It is empty if `symbol` has no `__dict__` or no V2
    names recorded.
  """
  if not hasattr(symbol, '__dict__'):
    return []
  name_attrs = [
      API_ATTRS[api].names
      for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME, KERAS_API_NAME)
  ]
  return [
      name
      for attr_name in name_attrs
      if attr_name in symbol.__dict__
      for name in getattr(symbol, attr_name)
  ]

222 

223 

def get_v1_constants(module):
  """Collects the TF 1.* constants exported from `module`.

  The module's TensorFlow and Estimator V1 constant lists (as appended by
  `api_export.export_constant`) are concatenated, TensorFlow first. The
  lists are found with ordinary attribute lookup (`hasattr`/`getattr`).

  NOTE(review): Keras V1 constants are not collected here, although
  `get_v1_names` does include Keras names. Confirm whether this omission is
  intentional.

  Args:
    module: TensorFlow module.

  Returns:
    List of V1 constant entries. It is empty if the module records none.
  """
  gathered = []
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME):
    constants_attr = API_ATTRS_V1[api].constants
    if hasattr(module, constants_attr):
      gathered.extend(getattr(module, constants_attr))
  return gathered

243 

244 

def get_v2_constants(module):
  """Collects the TF 2.0 constants exported from `module`.

  The module's TensorFlow and Estimator V2 constant lists (as appended by
  `api_export.export_constant`) are concatenated, TensorFlow first. The
  lists are found with ordinary attribute lookup (`hasattr`/`getattr`).

  NOTE(review): Keras constants are not collected here, although
  `get_v2_names` does include Keras names. Confirm whether this omission is
  intentional.

  Args:
    module: TensorFlow module.

  Returns:
    List of V2 constant entries. It is empty if the module records none.
  """
  gathered = []
  for api in (TENSORFLOW_API_NAME, ESTIMATOR_API_NAME):
    constants_attr = API_ATTRS[api].constants
    if hasattr(module, constants_attr):
      gathered.extend(getattr(module, constants_attr))
  return gathered

264 

265 

class api_export(object):  # pylint: disable=invalid-name
  """Provides ways to export symbols to the TensorFlow API."""

  def __init__(self, *args, **kwargs):  # pylint: disable=g-doc-args
    """Export under the names *args (first one is considered canonical).

    Args:
      *args: API names in dot delimited format.
      **kwargs: Optional keyed arguments.
        v1: Names for the TensorFlow V1 API. If not set, we will use V2 API
          names both for TensorFlow V1 and V2 APIs.
        overrides: List of symbols that this is overriding
          (those overridden api exports will be removed). Note: passing
          overrides has no effect on exporting a constant.
        api_name: Name of the API you want to generate (e.g. `tensorflow` or
          `estimator`). Default is `tensorflow`.
        allow_multiple_exports: Allow symbol to be exported multiple times
          under different names.

    Raises:
      ValueError: If a `v2` keyword argument is passed.
      InvalidSymbolNameError: If any name lies outside the namespace allowed
        for `api_name`.
    """
    self._names = args
    self._names_v1 = kwargs.get('v1', args)
    if 'v2' in kwargs:
      raise ValueError('You passed a "v2" argument to tf_export. This is not '
                       'what you want. Pass v2 names directly as positional '
                       'arguments instead.')
    self._api_name = kwargs.get('api_name', TENSORFLOW_API_NAME)
    self._overrides = kwargs.get('overrides', [])
    self._allow_multiple_exports = kwargs.get('allow_multiple_exports', False)

    self._validate_symbol_names()

  @staticmethod
  def _is_in_namespace(name, package):
    """Returns True if `name` is `package` itself or a dotted name under it.

    Whole dotted components are compared, so `estimator.Estimator` is in the
    `estimator` namespace but `estimator_utils` is not.
    """
    return name == package or name.startswith(package + '.')

  def _validate_symbol_names(self):
    """Validate you are exporting symbols under an allowed package.

    We need to ensure things exported by tf_export, estimator_export, etc.
    export symbols under disjoint top-level package names.

    For TensorFlow, we check that it does not export anything under the
    subpackage names listed in `SUBPACKAGE_NAMESPACES`.

    For each component, we check that it exports everything under its own
    subpackage.

    Raises:
      InvalidSymbolNameError: If you try to export symbol under disallowed name.
    """
    all_symbol_names = set(self._names) | set(self._names_v1)
    if self._api_name == TENSORFLOW_API_NAME:
      for subpackage in SUBPACKAGE_NAMESPACES:
        if any(self._is_in_namespace(n, subpackage)
               for n in all_symbol_names):
          raise InvalidSymbolNameError(
              '@tf_export is not allowed to export symbols under %s.*' % (
                  subpackage))
    else:
      if not all(self._is_in_namespace(n, self._api_name)
                 for n in all_symbol_names):
        raise InvalidSymbolNameError(
            'Can only export symbols under package name of component. '
            'e.g. tensorflow_estimator must export all symbols under '
            'tf.estimator')

  def __call__(self, func):
    """Calls this decorator.

    Sets the V2 and V1 API-name attributes on the undecorated `func`,
    removes them from any `overrides` symbols, and registers every name
    (V1 names with a `compat.v1.` prefix) so that `get_symbol_from_name` can
    find it.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The input function with _tf_api_names attribute set.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names
        and kwarg `allow_multiple_exports` not set.
    """
    api_names_attr = API_ATTRS[self._api_name].names
    api_names_attr_v1 = API_ATTRS_V1[self._api_name].names
    # Undecorate overridden names
    for f in self._overrides:
      _, undecorated_f = tf_decorator.unwrap(f)
      delattr(undecorated_f, api_names_attr)
      delattr(undecorated_f, api_names_attr_v1)

    _, undecorated_func = tf_decorator.unwrap(func)
    self.set_attr(undecorated_func, api_names_attr, self._names)
    self.set_attr(undecorated_func, api_names_attr_v1, self._names_v1)

    # The registry stores the decorated `func`, not the undecorated one.
    for name in self._names:
      _NAME_TO_SYMBOL_MAPPING[name] = func
    for name_v1 in self._names_v1:
      _NAME_TO_SYMBOL_MAPPING['compat.v1.%s' % name_v1] = func

    return func

  def set_attr(self, func, api_names_attr, names):
    """Stores `names` on `func` under `api_names_attr`.

    Raises:
      SymbolAlreadyExposedError: If `func` already has its own
        `api_names_attr` and multiple exports are not allowed.
    """
    # Check for an existing api. We check if attribute name is in
    # __dict__ instead of using hasattr to verify that subclasses have
    # their own _tf_api_names as opposed to just inheriting it.
    if api_names_attr in func.__dict__:
      if not self._allow_multiple_exports:
        raise SymbolAlreadyExposedError(
            'Symbol %s is already exposed as %s.' %
            (func.__name__, getattr(func, api_names_attr)))  # pylint: disable=protected-access
    setattr(func, api_names_attr, names)

  def export_constant(self, module_name, name):
    """Store export information for constants/string literals.

    Export information is stored in the module where constants/string literals
    are defined.

    e.g.
    ```python
    foo = 1
    bar = 2
    tf_export("consts.foo").export_constant(__name__, 'foo')
    tf_export("consts.bar").export_constant(__name__, 'bar')
    ```

    Args:
      module_name: (string) Name of the module to store constant at.
      name: (string) Current constant name.
    """
    module = sys.modules[module_name]
    api_constants_attr = API_ATTRS[self._api_name].constants
    api_constants_attr_v1 = API_ATTRS_V1[self._api_name].constants

    # Each entry appended below is a tuple (api_names, constant_name).
    if not hasattr(module, api_constants_attr):
      setattr(module, api_constants_attr, [])
    getattr(module, api_constants_attr).append(
        (self._names, name))

    if not hasattr(module, api_constants_attr_v1):
      setattr(module, api_constants_attr_v1, [])
    getattr(module, api_constants_attr_v1).append(
        (self._names_v1, name))

401 

402 

def kwarg_only(f):
  """Wraps `f` so that it accepts keyword arguments only.

  Calling the wrapper with any positional argument raises `TypeError`. The
  message lists the accepted keyword names. Keyword arguments are forwarded
  to `f` unchanged.

  Args:
    f: Function to wrap.

  Returns:
    The wrapper, built with `tf_decorator.make_decorator` and given `f`'s
    full argspec as its `decorator_argspec`.
  """
  f_argspec = tf_inspect.getfullargspec(f)
  # Keyword-only parameters are also valid keys, so list them together with
  # the positional-or-keyword parameters in the error message.
  possible_keys = list(f_argspec.args) + list(f_argspec.kwonlyargs or [])

  def wrapper(*args, **kwargs):
    if args:
      raise TypeError(
          '{f} only takes keyword args (possible keys: {kwargs}). '
          'Please pass these args as kwargs instead.'
          .format(f=f.__name__, kwargs=possible_keys))
    return f(**kwargs)

  return tf_decorator.make_decorator(
      f, wrapper, decorator_argspec=f_argspec)

417 

418 

# Decorator factory for exporting symbols to the core TensorFlow API.
tf_export = functools.partial(
    api_export,
    api_name=TENSORFLOW_API_NAME,
)

# Decorator factory for exporting symbols to the Keras API.
keras_export = functools.partial(
    api_export,
    api_name=KERAS_API_NAME,
)