Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/template.py: 30%

247 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Provides templates which allow variable sharing.""" 

16import functools 

17import traceback 

18from tensorflow.python.checkpoint import checkpoint as trackable_util 

19from tensorflow.python.eager import context 

20from tensorflow.python.eager import def_function 

21from tensorflow.python.framework import ops 

22from tensorflow.python.ops import variable_scope 

23from tensorflow.python.platform import tf_logging as logging 

24from tensorflow.python.trackable import base as trackable 

25from tensorflow.python.util import object_identity 

26from tensorflow.python.util import tf_contextlib 

27from tensorflow.python.util.deprecation import deprecated 

28from tensorflow.python.util.tf_export import tf_export 

29 

# Public names exported by `from ... import *` on this module. Only the v1
# `make_template` entry point is listed; the `Template`/`EagerTemplate`
# classes and `make_template_internal` remain reachable as module attributes.
__all__ = ["make_template"]

31 

32 

@tf_export(v1=["make_template"])
def make_template(name_,
                  func_,
                  create_scope_now_=False,
                  unique_name_=None,
                  custom_getter_=None,
                  **kwargs):
  """Given an arbitrary function, wrap it so that it does variable sharing.

  @compatibility(TF2)
  `tf.compat.v1.make_template` is a legacy API that is only compatible
  with eager execution enabled and `tf.function` if you combine it with
  `tf.compat.v1.keras.utils.track_tf1_style_variables`. See the model mapping
  migration guide section on `make_template` for more info:

  https://www.tensorflow.org/guide/migrate/model_mapping#using_tfcompatv1make_template_in_the_decorated_method

  Even if you use legacy apis for `variable_scope`-based variable reuse,
  we recommend using
  `tf.compat.v1.keras.utils.track_tf1_style_variables` directly and not using
  `tf.compat.v1.make_template`, as it interoperates with eager execution in a
  simpler and more predictable fashion than `make_template`.

  The TF2 API approach would be tracking your variables using
  `tf.Module`s or Keras layers and models rather than relying on
  `make_template`.
  @end_compatibility

  This wraps `func_` in a Template and partially evaluates it. Templates are
  functions that create variables the first time they are called and reuse them
  thereafter. In order for `func_` to be compatible with a `Template` it must
  have the following properties:

  * The function should create all trainable variables and any variables that
     should be reused by calling `tf.compat.v1.get_variable`. If a trainable
     variable is created using `tf.Variable`, then a ValueError will be thrown.
     Variables that are intended to be locals can be created by specifying
     `tf.Variable(..., trainable=False)`.
  * The function may use variable scopes and other templates internally to
      create and reuse variables, but it shouldn't use
      `tf.compat.v1.global_variables` to capture variables that are defined
      outside of the scope of the function.
  * Internal scopes and variable names should not depend on any arguments that
      are not supplied to `make_template`. In general you will get a ValueError
      telling you that you are trying to reuse a variable that doesn't exist
      if you make a mistake.

  In the following example, both `z` and `w` will be scaled by the same `y`. It
  is important to note that if we didn't assign `scalar_name` and used a
  different name for z and w that a `ValueError` would be thrown because it
  couldn't reuse the variable.

  ```python
  def my_op(x, scalar_name):
    var1 = tf.compat.v1.get_variable(scalar_name,
                                     shape=[],
                                     initializer=tf.compat.v1.constant_initializer(1))
    return x * var1

  scale_by_y = tf.compat.v1.make_template('scale_by_y', my_op, scalar_name='y')

  z = scale_by_y(input1)
  w = scale_by_y(input2)
  ```

  As a safe-guard, the returned function will raise a `ValueError` after the
  first call if trainable variables are created by calling `tf.Variable`.

  If all of these are true, then 2 properties are enforced by the template:

  1. Calling the same template multiple times will share all non-local
      variables.
  2. Two different templates are guaranteed to be unique, unless you reenter the
      same variable scope as the initial definition of a template and redefine
      it. An example of this exception:

  ```python
  def my_op(x, scalar_name):
    var1 = tf.compat.v1.get_variable(scalar_name,
                                     shape=[],
                                     initializer=tf.compat.v1.constant_initializer(1))
    return x * var1

  with tf.compat.v1.variable_scope('scope') as vs:
    scale_by_y = tf.compat.v1.make_template('scale_by_y', my_op,
                                            scalar_name='y')
    z = scale_by_y(input1)
    w = scale_by_y(input2)

  # Creates a template that reuses the variables above.
  with tf.compat.v1.variable_scope(vs, reuse=True):
    scale_by_y2 = tf.compat.v1.make_template('scale_by_y', my_op,
                                             scalar_name='y')
    z2 = scale_by_y2(input1)
    w2 = scale_by_y2(input2)
  ```

  Depending on the value of `create_scope_now_`, the full variable scope may be
  captured either at the time of first call or at the time of construction. If
  this option is set to True, then all Tensors created by repeated calls to the
  template will have an extra trailing _N+1 to their name, as the first time the
  scope is entered in the Template constructor no Tensors are created.

  Note: `name_`, `func_`, `create_scope_now_`, `unique_name_` and
  `custom_getter_` have a trailing underscore to reduce the likelihood of
  collisions with kwargs.

  Args:
    name_: A name for the scope created by this template. If necessary, the name
      will be made unique by appending `_N` to the name.
    func_: The function to wrap.
    create_scope_now_: Boolean controlling whether the scope should be created
      when the template is constructed or when the template is called. Default
      is False, meaning the scope is created when the template is called.
    unique_name_: When used, it overrides name_ and is not made unique. If a
      template of the same scope/unique_name already exists and reuse is false,
      an error is raised. Defaults to None. Must be None when executing eagerly.
    custom_getter_: Optional custom getter for variables used in `func_`. See
      the `tf.compat.v1.get_variable` `custom_getter` documentation for more
      information.
    **kwargs: Keyword arguments to apply to `func_`.

  Returns:
    A function to encapsulate a set of variables which should be created once
    and reused. An enclosing scope will be created either when `make_template`
    is called or when the result is called, depending on the value of
    `create_scope_now_`. Regardless of the value, the first time the template
    is called it will enter the scope with no reuse, and call `func_` to create
    variables, which are guaranteed to be unique. All subsequent calls will
    re-enter the scope and reuse those variables.

  Raises:
    ValueError: if `name_` is None.
    ValueError: if `unique_name_` is not None and eager execution is enabled.
  """
  # The public v1 API never compiles `func_` into a graph function; that option
  # is only reachable through `make_template_internal`.
  return make_template_internal(
      name_,
      func_,
      create_scope_now_,
      unique_name_,
      custom_getter_,
      create_graph_function_=False,
      **kwargs)

175 

176 

def make_template_internal(name_,
                           func_,
                           create_scope_now_=False,
                           unique_name_=None,
                           custom_getter_=None,
                           create_graph_function_=False,
                           **kwargs):
  """Builds a `Template` (graph mode) or `EagerTemplate` (eager mode).

  This is the implementation behind `make_template`; see that function for the
  full contract of templates.

  Args:
    name_: Scope name for the template; uniquified with an `_N` suffix when
      needed.
    func_: The callable being wrapped.
    create_scope_now_: If True the variable scope is created immediately at
      construction; if False (the default) it is created on the first call.
    unique_name_: Overrides `name_` without uniquification. Reusing an existing
      scope/unique_name with reuse disabled raises an error. Must be None when
      executing eagerly.
    custom_getter_: Optional custom getter forwarded to the variable scope; see
      the `custom_getter` argument of `tf.compat.v1.get_variable`.
    create_graph_function_: If True, `func_` is wrapped with `tf.function` and
      run as a graph function, so it must meet `tf.function`'s requirements.
      This gives graph-function semantics (variable accesses are totally
      ordered, side-effecting ops are not pruned) and may speed up eager
      execution.
    **kwargs: Bound onto `func_` via `functools.partial` before wrapping.

  Returns:
    An `EagerTemplate` when executing eagerly, otherwise a `Template`. The
    first call creates the template's variables inside its scope; later calls
    re-enter the scope and reuse them.

  Raises:
    ValueError: if `name_` is None.
    ValueError: if `unique_name_` is given while executing eagerly.
  """
  bound_func = functools.partial(func_, **kwargs) if kwargs else func_

  if not context.executing_eagerly():
    return Template(
        name_,
        bound_func,
        create_scope_now=create_scope_now_,
        unique_name=unique_name_,
        custom_getter=custom_getter_,
        create_graph_function=create_graph_function_)

  # Eager templates have no notion of a caller-chosen unique scope name.
  if unique_name_ is not None:
    raise ValueError(
        "unique_name_ cannot be used when eager execution is enabled.")
  return EagerTemplate(
      name_,
      bound_func,
      create_scope_now=create_scope_now_,
      custom_getter=custom_getter_,
      create_graph_function=create_graph_function_)

245 

246 

247def _skip_common_stack_elements(stacktrace, base_case): 

248 """Skips items that the target stacktrace shares with the base stacktrace.""" 

249 for i, (trace, base) in enumerate(zip(stacktrace, base_case)): 

250 if trace != base: 

251 return stacktrace[i:] 

252 return stacktrace[-1:] 

253 

254 

class Template(trackable.Trackable):
  """Wrap a function to aid in variable sharing.

  Templates are functions that create variables the first time they are called
  and reuse them thereafter. See `make_template` for full documentation.

  Note: By default, the full variable scope is captured at the time of first
  call. If `create_scope_now` is passed as True to the constructor, the full
  scope will be captured there, but no variables will be created until the
  first call.
  """

  def __init__(self,
               name,
               func,
               create_scope_now=False,
               unique_name=None,
               custom_getter=None,
               create_graph_function=False):
    """Creates a template for the given function.

    Args:
      name: A name for the scope created by this template. The name will be made
        unique by appending `_N` to it (see how
        `tf.compat.v1.variable_scope` treats the `default_name` for details).
      func: The function to apply each time.
      create_scope_now: Whether to create the scope at Template construction
        time, rather than first call. Defaults to false. Creating the scope at
        construction time may be more convenient if the template is to be
        passed through much lower level code, and you want to be sure of the
        scope name without knowing exactly where it will be first called. If
        set to True, the scope will be created in the constructor, and all
        subsequent times in `__call__`, leading to a trailing numeral being
        added to the names of all created Tensors. If set to False, the scope
        will be created at the first call location.
      unique_name: When used, it overrides `name` and is not made unique. If a
        template of the same scope/unique_name already exists and reuse is
        false, an error is raised. Defaults to None.
      custom_getter: optional custom getter to pass to `variable_scope()`.
      create_graph_function: When True, `func` is wrapped with `tf.function`
        and executed as a graph function. Enabling this flag gives the caller
        access to graph-function semantics, i.e., accesses to variables are
        totally ordered and side-effecting ops are not pruned.

    Raises:
      ValueError: if `name` is None.
    """
    if create_graph_function:
      self._func = def_function.function(func)
    else:
      self._func = func
    # Record the definition site for error messages raised from `_call_func`.
    # The two innermost frames (this __init__ and its immediate caller) are
    # dropped; presumably these are factory/subclass frames — TODO confirm for
    # direct instantiation.
    self._stacktrace = traceback.format_stack()[:-2]
    self._name = name
    self._unique_name = unique_name
    self._custom_getter = custom_getter
    if name is None:
      raise ValueError("name cannot be None.")
    if create_scope_now:
      with variable_scope._pure_variable_scope(  # pylint:disable=protected-access
          (self._unique_name or
           variable_scope._get_unique_variable_scope(self._name)),  # pylint:disable=protected-access
          custom_getter=self._custom_getter) as vs:
        self._variable_scope = vs
    else:
      self._variable_scope = None
    # This variable keeps track of whether the template has been called to
    # completion, which is not the same as whether the scope has been created.
    self._variables_created = False
    # `MirroredStrategy` builds the graph with multiple threads. If a
    # `merge_call` happens within a template, multiple calls may be in progress
    # simultaneously. This variable keeps track of whether any call of the
    # template has started.
    self._first_call = True

  def _call_func(self, args, kwargs):
    """Runs the wrapped function and validates variable creation.

    After the first completed call, creating new trainable variables raises a
    ValueError and creating other new global variables is logged. Any exception
    raised is re-raised with the template's definition stack appended to its
    first argument.
    """
    try:
      if self._variables_created:
        vars_at_start = len(
            ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
        trainable_at_start = len(
            ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))

        result = self._func(*args, **kwargs)

        # Variables were previously created, implying this is not the first
        # time the template has been called. Check to make sure that no new
        # trainable variables were created this time around.
        trainable_variables = ops.get_collection_ref(
            ops.GraphKeys.TRAINABLE_VARIABLES)

        # If a variable that we intend to train is created as a side effect
        # of creating a template, then that is almost certainly an error.
        if trainable_at_start != len(trainable_variables):
          raise ValueError("Trainable variable created when calling a template "
                           "after the first time, perhaps you used tf.Variable "
                           "when you meant tf.get_variable: %s" %
                           (trainable_variables[trainable_at_start:],))

        # Non-trainable tracking variables are a legitimate reason why a new
        # variable would be created, but it is a relatively advanced use-case,
        # so log it.
        variables = ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)
        if vars_at_start != len(variables):
          logging.info(
              "New variables created when calling a template after "
              "the first time, perhaps you used tf.Variable when you "
              "meant tf.get_variable: %s", variables[vars_at_start:])
      elif self._first_call:
        self._first_call = False
        try:
          # The first time we run, restore variables if necessary (via
          # Trackable).
          with trackable_util.capture_dependencies(template=self):
            result = self._func(*args, **kwargs)
        except BaseException:
          # Let a later call retry variable creation if this one failed.
          self._first_call = True
          raise
        self._variables_created = True
      else:  # We are calling the template in parallel from another thread.
        result = self._func(*args, **kwargs)
      return result
    except Exception as exc:
      # Reraise the exception, but append the original definition to the
      # trace.
      args = exc.args
      if not args:
        arg0 = ""
      else:
        arg0 = args[0]
      trace = "".join(
          _skip_common_stack_elements(self._stacktrace,
                                      traceback.format_stack()))
      arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace)
      new_args = [arg0]
      new_args.extend(args[1:])
      exc.args = tuple(new_args)
      raise

  def __call__(self, *args, **kwargs):
    if self._variable_scope:
      # Only reuse variables if not on first call.
      with variable_scope.variable_scope(
          self._variable_scope, reuse=not self._first_call):
        return self._call_func(args, kwargs)
    else:
      # The scope was not created at construction time, so create it here.
      # Subsequent calls should reuse variables.
      with variable_scope.variable_scope(
          self._unique_name, self._name,
          custom_getter=self._custom_getter) as vs:
        self._variable_scope = vs
        return self._call_func(args, kwargs)

  @property
  def name(self):
    """Returns the name given to this Template."""
    return self._name

  @property
  def func(self):
    """Returns the func given to this Template."""
    return self._func

  @property
  def variable_scope(self):
    """Returns the variable scope object created by this Template."""
    return self._variable_scope

  @property
  def variable_scope_name(self):
    """Returns the variable scope name created by this Template.

    The name ends with "/" (unless it is empty) so that collection lookups by
    prefix do not partially match sibling scopes. Returns None if the scope
    has not been created yet.
    """
    if not self._variable_scope:
      return None
    name = self._variable_scope.name
    if not name or name[-1] == "/":
      return name
    # To prevent partial matches on the scope_name, we add '/' at the end.
    return name + "/"

  @property
  def variables(self):
    """Returns the list of global and local variables created by the Template."""
    return self.global_variables + self.local_variables

  @property
  def trainable_variables(self):
    """Returns the list of trainable variables created by the Template."""
    if self._variables_created:
      return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES,
                                self.variable_scope_name)
    else:
      return []

  @property
  def non_trainable_variables(self):
    """Returns the list of non-trainable variables created by the Template."""
    # TODO(apassos) Make sure it matches Eager when using local variables.
    global_variables = self.global_variables
    # Compare by object identity: variables may be unhashable (and `==` may
    # return a tensor) when tensor equality is enabled, which would break a
    # plain `set`.
    trainable_variables = object_identity.ObjectIdentitySet(
        self.trainable_variables)
    return [x for x in global_variables if x not in trainable_variables]

  @property
  def global_variables(self):
    """Returns the list of global variables created by the Template."""
    if self._variables_created:
      return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                self.variable_scope_name)
    else:
      return []

  @property
  def local_variables(self):
    """Returns the list of local variables created by the Template."""
    if self._variables_created:
      return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES,
                                self.variable_scope_name)
    else:
      return []

  @property
  def weights(self):
    """List of weights/variables created by the Template."""
    return self.variables

  @property
  def trainable_weights(self):
    """List of trainable weights/variables created by the Template."""
    return self.trainable_variables

  @property
  def non_trainable_weights(self):
    """List of non-trainable weights/variables created by the Template."""
    return self.non_trainable_variables

  @property
  @deprecated("2017-02-21",
              "The .var_scope property is deprecated. Please change your "
              "code to use the .variable_scope property")
  def var_scope(self):
    """Returns the variable scope object created by this Template."""
    return self._variable_scope

496 

497 

class _EagerTemplateVariableStore:
  """Per-template view over an `EagerVariableStore`, enabling nested templates.

  The view is bound to a variable scope name, which may be supplied later via
  `set_variable_scope_name`. Variable queries return only those variables whose
  names fall under that scope.
  """

  def __init__(self, variable_scope_name):
    self._variable_scope_name = variable_scope_name
    outer_store = variable_scope._get_default_variable_store()  # pylint: disable=protected-access
    # Back onto an outer eager store when one is already active; otherwise the
    # template gets a brand-new store of its own.
    backing = (outer_store,) if outer_store._store_eager_variables else ()  # pylint: disable=protected-access
    self._eager_variable_store = variable_scope.EagerVariableStore(*backing)
    self._used_once = False

  def set_variable_scope_name(self, variable_scope_name):
    self._variable_scope_name = variable_scope_name

  @tf_contextlib.contextmanager
  def as_default(self):
    """Makes the wrapped eager store the default for the `with` body.

    On the first use only, an outer eager store that became active after
    construction is adopted as the backing store. On exit, whether normal or
    exceptional, the subscopes of this view's variable scope are closed, which
    requires that a scope name has been set.
    """
    try:
      if not self._used_once:
        active_store = variable_scope._get_default_variable_store()  # pylint: disable=protected-access
        if active_store._store_eager_variables:  # pylint: disable=protected-access
          self._eager_variable_store._store = active_store  # pylint: disable=protected-access
        self._used_once = True
      with self._eager_variable_store.as_default():
        yield
    finally:
      # The enclosing EagerTemplate variable scope closes its own subscopes,
      # but for a top-level view the scope lives in a different store than
      # this EagerVariableStore, so its subscopes are closed here too.
      scope_name = self._variable_scope_name
      if scope_name is None:
        raise RuntimeError("A variable scope must be set before an "
                           "_EagerTemplateVariableStore object exits.")
      variable_scope.get_variable_scope_store().close_variable_subscopes(
          scope_name)

  def _variables_in_scope(self, variable_list):
    """Filters `variable_list` to variables named under this view's scope."""
    if self._variable_scope_name is None:
      raise RuntimeError(
          "A variable scope must be set before variables can be accessed.")
    scope_prefix = self._variable_scope_name + "/"
    return [var for var in variable_list if var.name.startswith(scope_prefix)]

  def variables(self):
    every_variable = self._eager_variable_store.variables()
    return self._variables_in_scope(every_variable)

  def trainable_variables(self):
    every_trainable = self._eager_variable_store.trainable_variables()
    return self._variables_in_scope(every_trainable)

  def non_trainable_variables(self):
    every_non_trainable = self._eager_variable_store.non_trainable_variables()
    return self._variables_in_scope(every_non_trainable)

561 

562 

class EagerTemplate(Template):
  """Wrap a function to aid in variable sharing in Eager mode.

  Templates are functions that create variables the first time they are called
  and reuse them thereafter. See `make_template` for full documentation.

  Note: By default, the full variable scope is captured at the time of first
  call. If `create_scope_now` is passed as True to the constructor, the full
  scope will be captured there, but no variables will be created until the first
  call.
  """

  def __init__(self,
               name,
               func,
               create_scope_now=False,
               custom_getter=None,
               create_graph_function=False):
    """Creates a template for the given function.

    Args:
      name: A name for the scope created by this template. The name will be made
        unique by appending `_N` to it (see how
        `tf.compat.v1.variable_scope` treats the `default_name` for details).
      func: The function to apply each time.
      create_scope_now: Whether to create the scope at Template construction
        time, rather than first call. Defaults to false. Creating the scope at
        construction time may be more convenient if the template is passed
        through much lower level code, and you want to be sure of the scope name
        without knowing exactly where it will be first called. If set to True,
        the scope will be created in the constructor, and all subsequent times
        in `__call__`, leading to a trailing numeral being added to the names of
        all created Tensors. If set to False, the scope will be created at the
        first call location.
      custom_getter: optional custom getter to pass to `variable_scope()`.
      create_graph_function: When True, `func` is wrapped with `tf.function`
        and executed as a graph function. Enabling this flag allows the caller
        to reap the performance benefits associated with executing graphs, at
        the cost of sacrificing debuggability; however, not all Python
        functions can be compiled into graph functions. See the `tf.function`
        documentation for details.

    Raises:
      RuntimeError: if eager execution is not enabled.
    """
    if not context.executing_eagerly():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.Template for graph construction".format(type(self)))
    # Called directly from this __init__ so Template's recorded definition
    # stack trims the same number of frames as for a plain Template.
    super(EagerTemplate, self).__init__(name, func, create_scope_now, None,
                                        custom_getter, create_graph_function)
    if self._variable_scope is not None:
      variable_scope_name = self._variable_scope.name
    else:
      # Defer setting the variable scope name until the variable scope
      # is created in __call__.
      variable_scope_name = None
    self._template_store = _EagerTemplateVariableStore(variable_scope_name)
    self._variable_scope_context_manager = None

  def _call_func(self, args, kwargs):
    """Runs the wrapped function and validates variable creation.

    After the first completed call, creating new trainable variables in the
    template's store raises a ValueError and creating other new variables is
    logged. Any exception raised is re-raised with the template's definition
    stack appended to its first argument.
    """
    try:
      if self._variables_created:
        # The "before" snapshots are only needed to detect unexpected variable
        # creation on calls after the first, so take them only here.
        vars_at_start = self._template_store.variables()
        trainable_at_start = self._template_store.trainable_variables()

        result = self._func(*args, **kwargs)

        # Variables were previously created, implying this is not the first
        # time the template has been called. Check to make sure that no new
        # trainable variables were created this time around.
        trainable_variables = self._template_store.trainable_variables()
        # If a variable that we intend to train is created as a side effect
        # of creating a template, then that is almost certainly an error.
        if len(trainable_at_start) != len(trainable_variables):
          raise ValueError(
              "Trainable variable created when calling a template "
              "after the first time, perhaps you used tf.Variable "
              "when you meant tf.get_variable: %s" % list(
                  object_identity.ObjectIdentitySet(trainable_variables) -
                  object_identity.ObjectIdentitySet(trainable_at_start)))

        # Non-trainable tracking variables are a legitimate reason why a new
        # variable would be created, but it is a relatively advanced use-case,
        # so log it.
        variables = self._template_store.variables()
        if len(vars_at_start) != len(variables):
          logging.info(
              "New variables created when calling a template after "
              "the first time, perhaps you used tf.Variable when you "
              "meant tf.get_variable: %s",
              list(
                  object_identity.ObjectIdentitySet(variables) -
                  object_identity.ObjectIdentitySet(vars_at_start)))
      else:
        # The first time we run, restore variables if necessary (via
        # Trackable).
        with trackable_util.capture_dependencies(template=self):
          result = self._func(*args, **kwargs)
        self._variables_created = True
      return result
    except Exception as exc:
      # Reraise the exception, but append the original definition to the
      # trace.
      args = exc.args
      if not args:
        arg0 = ""
      else:
        arg0 = args[0]
      trace = "".join(
          _skip_common_stack_elements(self._stacktrace,
                                      traceback.format_stack()))
      arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace)
      new_args = [arg0]
      new_args.extend(args[1:])
      exc.args = tuple(new_args)
      raise

  def __call__(self, *args, **kwargs):
    # In both branches below, the template store is installed as default after
    # the variable scope is opened in order to ensure that templates nested at
    # the same level correctly uniquify lower variable scope names.
    if self._variable_scope:
      # Create a cache for the variable scope context manager the first time
      # around so that we don't have to keep recreating it.
      if not self._variable_scope_context_manager:
        self._variable_scope_context_manager = variable_scope.variable_scope(
            self._variable_scope, reuse=variable_scope.AUTO_REUSE)
      with self._variable_scope_context_manager:
        with self._template_store.as_default():
          return self._call_func(args, kwargs)
    else:
      # The scope was not created at construction time, so create it here.
      # Subsequent calls should reuse variables.
      with variable_scope.variable_scope(
          self._unique_name, self._name,
          custom_getter=self._custom_getter) as vs:
        self._variable_scope = vs
        # Because the scope was not created at construction time, the template
        # store's variable scope name is unset; set it here.
        self._template_store.set_variable_scope_name(vs.name)
        with self._template_store.as_default():
          return self._call_func(args, kwargs)

  @property
  def variables(self):
    """Returns the list of variables created by the Template."""
    # Currently there is no local variable in Eager mode.
    if not self._variables_created:
      return []
    return self._template_store.variables()

  @property
  def trainable_variables(self):
    """Returns the list of trainable variables created by the Template."""
    # Currently there is no local variable in Eager mode.
    if not self._variables_created:
      return []
    return self._template_store.trainable_variables()

  @property
  def non_trainable_variables(self):
    """Returns the list of non-trainable variables created by the Template."""
    # Currently there is no local variable in Eager mode.
    if not self._variables_created:
      return []
    return self._template_store.non_trainable_variables()

  @property
  def global_variables(self):
    """Returns the list of global variables created by the Template."""
    # Currently there is no local variable in Eager mode.
    if not self._variables_created:
      return []
    return self.variables

  @property
  def local_variables(self):
    """Returns the list of local variables created by the Template.

    Always empty: there are currently no local variables in Eager mode.
    """
    return []