Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/template.py: 30%
247 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Provides templates which allow variable sharing."""
16import functools
17import traceback
18from tensorflow.python.checkpoint import checkpoint as trackable_util
19from tensorflow.python.eager import context
20from tensorflow.python.eager import def_function
21from tensorflow.python.framework import ops
22from tensorflow.python.ops import variable_scope
23from tensorflow.python.platform import tf_logging as logging
24from tensorflow.python.trackable import base as trackable
25from tensorflow.python.util import object_identity
26from tensorflow.python.util import tf_contextlib
27from tensorflow.python.util.deprecation import deprecated
28from tensorflow.python.util.tf_export import tf_export
30__all__ = ["make_template"]
@tf_export(v1=["make_template"])
def make_template(name_,
                  func_,
                  create_scope_now_=False,
                  unique_name_=None,
                  custom_getter_=None,
                  **kwargs):
  """Given an arbitrary function, wrap it so that it does variable sharing.

  @compatibility(TF2)
  `tf.compat.v1.make_template` is a legacy API that is only compatible
  with eager execution enabled and `tf.function` if you combine it with
  `tf.compat.v1.keras.utils.track_tf1_style_variables`. See the model mapping
  migration guide section on `make_template` for more info:

  https://www.tensorflow.org/guide/migrate/model_mapping#using_tfcompatv1make_template_in_the_decorated_method

  Even if you use legacy apis for `variable_scope`-based variable reuse,
  we recommend using
  `tf.compat.v1.keras.utils.track_tf1_style_variables` directly and not using
  `tf.compat.v1.make_template`, as it interoperates with eager execution in a
  simpler and more predictable fashion than `make_template`.

  The TF2 API approach would be tracking your variables using
  `tf.Module`s or Keras layers and models rather than relying on
  `make_template`.
  @end_compatibility

  This wraps `func_` in a Template and partially evaluates it. Templates are
  functions that create variables the first time they are called and reuse them
  thereafter. In order for `func_` to be compatible with a `Template` it must
  have the following properties:

  * The function should create all trainable variables and any variables that
     should be reused by calling `tf.compat.v1.get_variable`. If a trainable
     variable is created using `tf.Variable`, then a ValueError will be thrown.
     Variables that are intended to be locals can be created by specifying
     `tf.Variable(..., trainable=False)`.
  * The function may use variable scopes and other templates internally to
      create and reuse variables, but it shouldn't use
      `tf.compat.v1.global_variables` to capture variables that are defined
      outside of the scope of the function.
  * Internal scopes and variable names should not depend on any arguments that
      are not supplied to `make_template`. In general you will get a ValueError
      telling you that you are trying to reuse a variable that doesn't exist
      if you make a mistake.

  In the following example, both `z` and `w` will be scaled by the same `y`. It
  is important to note that if we didn't assign `scalar_name` and used a
  different name for z and w that a `ValueError` would be thrown because it
  couldn't reuse the variable.

  ```python
  def my_op(x, scalar_name):
    var1 = tf.compat.v1.get_variable(scalar_name,
                           shape=[],
                           initializer=tf.compat.v1.constant_initializer(1))
    return x * var1

  scale_by_y = tf.compat.v1.make_template('scale_by_y', my_op, scalar_name='y')

  z = scale_by_y(input1)
  w = scale_by_y(input2)
  ```

  As a safe-guard, the returned function will raise a `ValueError` after the
  first call if trainable variables are created by calling `tf.Variable`.

  If all of these are true, then 2 properties are enforced by the template:

  1. Calling the same template multiple times will share all non-local
      variables.
  2. Two different templates are guaranteed to be unique, unless you reenter the
      same variable scope as the initial definition of a template and redefine
      it. An example of this exception:

  ```python
  def my_op(x, scalar_name):
    var1 = tf.compat.v1.get_variable(scalar_name,
                           shape=[],
                           initializer=tf.compat.v1.constant_initializer(1))
    return x * var1

  with tf.compat.v1.variable_scope('scope') as vs:
    scale_by_y = tf.compat.v1.make_template('scale_by_y', my_op,
    scalar_name='y')
    z = scale_by_y(input1)
    w = scale_by_y(input2)

  # Creates a template that reuses the variables above.
  with tf.compat.v1.variable_scope(vs, reuse=True):
    scale_by_y2 = tf.compat.v1.make_template('scale_by_y', my_op,
    scalar_name='y')
    z2 = scale_by_y2(input1)
    w2 = scale_by_y2(input2)
  ```

  Depending on the value of `create_scope_now_`, the full variable scope may be
  captured either at the time of first call or at the time of construction. If
  this option is set to True, then all Tensors created by repeated calls to the
  template will have an extra trailing _N+1 to their name, as the first time the
  scope is entered in the Template constructor no Tensors are created.

  Note: `name_`, `func_` and `create_scope_now_` have a trailing underscore to
  reduce the likelihood of collisions with kwargs.

  Args:
    name_: A name for the scope created by this template. If necessary, the name
      will be made unique by appending `_N` to the name.
    func_: The function to wrap.
    create_scope_now_: Boolean controlling whether the scope should be created
      when the template is constructed or when the template is called. Default
      is False, meaning the scope is created when the template is called.
    unique_name_: When used, it overrides name_ and is not made unique. If a
      template of the same scope/unique_name already exists and reuse is false,
      an error is raised. Defaults to None.
    custom_getter_: Optional custom getter for variables used in `func_`. See
      the `tf.compat.v1.get_variable` `custom_getter` documentation for more
      information.
    **kwargs: Keyword arguments to apply to `func_`.

  Returns:
    A function to encapsulate a set of variables which should be created once
    and reused. An enclosing scope will be created either when `make_template`
    is called or when the result is called, depending on the value of
    `create_scope_now_`. Regardless of the value, the first time the template
    is called it will enter the scope with no reuse, and call `func_` to create
    variables, which are guaranteed to be unique. All subsequent calls will
    re-enter the scope and reuse those variables.

  Raises:
    ValueError: if `name_` is None.
  """
  # The public entry point never compiles `func_` into a graph function; every
  # other option is forwarded unchanged to the internal factory.
  return make_template_internal(
      name_=name_,
      func_=func_,
      create_scope_now_=create_scope_now_,
      unique_name_=unique_name_,
      custom_getter_=custom_getter_,
      create_graph_function_=False,
      **kwargs)
def make_template_internal(name_,
                           func_,
                           create_scope_now_=False,
                           unique_name_=None,
                           custom_getter_=None,
                           create_graph_function_=False,
                           **kwargs):
  """Build a `Template` (graph mode) or an `EagerTemplate` (eager mode).

  This is the factory behind `make_template`; it additionally lets the caller
  request that `func_` be compiled into a graph function. See `make_template`
  for full documentation.

  Args:
    name_: A name for the scope created by this template. If necessary, the name
      will be made unique by appending `_N` to the name.
    func_: The function to wrap.
    create_scope_now_: Boolean controlling whether the scope should be created
      when the template is constructed or when the template is called. Default
      is False, meaning the scope is created when the template is called.
    unique_name_: When used, it overrides name_ and is not made unique. If a
      template of the same scope/unique_name already exists and reuse is false,
      an error is raised. Defaults to None. If executing eagerly, must be None.
    custom_getter_: Optional custom getter for variables used in `func_`. See
      the `tf.compat.v1.get_variable` `custom_getter` documentation for more
      information.
    create_graph_function_: When True, `func_` will be executed as a graph
      function. This implies that `func_` must satisfy the properties that
      `function.defun` requires of functions: See the documentation of
        `function.defun` for details. When executing eagerly, setting this flag
        to True can improve performance. Regardless of whether eager execution
        is enabled, enabling this flag gives the caller access to graph-function
        semantics, i.e., accesses to variables are totally ordered and
        side-effecting ops are not pruned.
    **kwargs: Keyword arguments to apply to `func_`.

  Returns:
    A function to encapsulate a set of variables which should be created once
    and reused. An enclosing scope will be created either when `make_template`
    is called or when the result is called, depending on the value of
    `create_scope_now_`. Regardless of the value, the first time the template
    is called it will enter the scope with no reuse, and call `func_` to create
    variables, which are guaranteed to be unique. All subsequent calls will
    re-enter the scope and reuse those variables.

  Raises:
    ValueError: if `name_` is None.
    ValueError: if `unique_name_` is not None and eager execution is enabled.
  """
  # Bind the extra keyword arguments up front; leave `func_` untouched when
  # there is nothing to bind.
  wrapped_func = functools.partial(func_, **kwargs) if kwargs else func_

  if not context.executing_eagerly():
    return Template(
        name_,
        wrapped_func,
        create_scope_now=create_scope_now_,
        unique_name=unique_name_,
        custom_getter=custom_getter_,
        create_graph_function=create_graph_function_)

  # Eager mode: `unique_name_` is rejected outright.
  if unique_name_ is not None:
    raise ValueError(
        "unique_name_ cannot be used when eager execution is enabled.")
  return EagerTemplate(
      name_,
      wrapped_func,
      create_scope_now=create_scope_now_,
      custom_getter=custom_getter_,
      create_graph_function=create_graph_function_)
247def _skip_common_stack_elements(stacktrace, base_case):
248 """Skips items that the target stacktrace shares with the base stacktrace."""
249 for i, (trace, base) in enumerate(zip(stacktrace, base_case)):
250 if trace != base:
251 return stacktrace[i:]
252 return stacktrace[-1:]
class Template(trackable.Trackable):
  """Wrap a function to aid in variable sharing.

  Templates are functions that create variables the first time they are called
  and reuse them thereafter. See `make_template` for full documentation.

  Note: By default, the full variable scope is captured at the time of first
  call. If `create_scope_now` is passed as True to the constructor, the full
  scope will be captured there, but no variables will be created until the
  first call.
  """

  def __init__(self,
               name,
               func,
               create_scope_now=False,
               unique_name=None,
               custom_getter=None,
               create_graph_function=False):
    """Creates a template for the given function.

    Args:
      name: A name for the scope created by this template. The name will be made
        unique by appending `_N` to it (see how
        `tf.compat.v1.variable_scope` treats the `default_name` for details).
      func: The function to apply each time.
      create_scope_now: Whether to create the scope at Template construction
        time, rather than first call. Defaults to false. Creating the scope at
        construction time may be more convenient if the template is to be
        passed through much lower level code, and you want to be sure of the
        scope name without knowing exactly where it will be first called. If
        set to True, the scope will be created in the constructor, and all
        subsequent times in `__call__`, leading to a trailing numeral being
        added to the names of all created Tensors. If set to False, the scope
        will be created at the first call location.
      unique_name: When used, it overrides `name` and is not made unique. If a
        template of the same scope/unique_name already exists and reuse is
        false, an error is raised. Defaults to None.
      custom_getter: optional custom getter to pass to `variable_scope()`
      create_graph_function: When True, `func` will be executed as a graph
        function. Enabling this flag gives the caller access to graph-function
        semantics, i.e., accesses to variables are totally ordered and
        side-effecting ops are not pruned.

    Raises:
      ValueError: if `name` is None.
    """
    # Validate first so that an invalid template fails before any work is done.
    if name is None:
      raise ValueError("name cannot be None.")
    if create_graph_function:
      self._func = def_function.function(func)
    else:
      self._func = func
    # Remember where the template was defined (minus the two innermost stack
    # entries) so that errors raised while calling it can point back here.
    self._stacktrace = traceback.format_stack()[:-2]
    self._name = name
    self._unique_name = unique_name
    self._custom_getter = custom_getter
    if create_scope_now:
      with variable_scope._pure_variable_scope(  # pylint:disable=protected-access
          (self._unique_name or
           variable_scope._get_unique_variable_scope(self._name)),  # pylint:disable=protected-access
          custom_getter=self._custom_getter) as vs:
        self._variable_scope = vs
    else:
      self._variable_scope = None
    # This variable keeps track of whether the template has been called to
    # completion, which is not the same as whether the scope has been created.
    self._variables_created = False
    # `MirroredStrategy` builds the graph with multiple threads. If a
    # `merge_call` happens within a template, multiple calls may be in progress
    # simultaneously. This variable keeps track of whether any call of the
    # template has started.
    self._first_call = True

  def _call_func(self, args, kwargs):
    """Invokes the wrapped function and polices variable creation.

    Args:
      args: Positional arguments (a sequence) forwarded to the wrapped function.
      kwargs: Keyword arguments (a mapping) forwarded to the wrapped function.

    Returns:
      Whatever the wrapped function returns.

    Raises:
      ValueError: if the template has already completed a call and this call
        grew the TRAINABLE_VARIABLES collection.
      Exception: any `Exception` raised inside is re-raised with the location
        where the template was defined appended to its first argument.
    """
    try:
      if self._variables_created:
        vars_at_start = len(
            ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
        trainable_at_start = len(
            ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))

        result = self._func(*args, **kwargs)

        # Variables were previously created, implying this is not the first
        # time the template has been called. Check to make sure that no new
        # trainable variables were created this time around.
        trainable_variables = ops.get_collection_ref(
            ops.GraphKeys.TRAINABLE_VARIABLES)

        # If a variable that we intend to train is created as a side effect
        # of creating a template, then that is almost certainly an error.
        if trainable_at_start != len(trainable_variables):
          raise ValueError("Trainable variable created when calling a template "
                           "after the first time, perhaps you used tf.Variable "
                           "when you meant tf.get_variable: %s" %
                           (trainable_variables[trainable_at_start:],))

        # Non-trainable tracking variables are a legitimate reason why a new
        # variable would be created, but it is a relatively advanced use-case,
        # so log it.
        variables = ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)
        if vars_at_start != len(variables):
          logging.info(
              "New variables created when calling a template after "
              "the first time, perhaps you used tf.Variable when you "
              "meant tf.get_variable: %s", variables[vars_at_start:])
      elif self._first_call:
        self._first_call = False
        try:
          # The first time we run, restore variables if necessary (via
          # Trackable).
          with trackable_util.capture_dependencies(template=self):
            result = self._func(*args, **kwargs)
        except:  # pylint: disable=bare-except
          # Deliberately broad: whatever interrupted the first call, the flag
          # is restored so that a later call is treated as the first one. The
          # exception is always re-raised.
          self._first_call = True
          raise
        self._variables_created = True
      else:  # We are calling the template in parallel from another thread.
        result = self._func(*args, **kwargs)
      return result
    except Exception as exc:
      # Reraise the exception, but append the original definition to the
      # trace.
      exc_args = exc.args
      if not exc_args:
        arg0 = ""
      else:
        arg0 = exc_args[0]
      trace = "".join(
          _skip_common_stack_elements(self._stacktrace,
                                      traceback.format_stack()))
      arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace)
      new_args = [arg0]
      new_args.extend(exc_args[1:])
      exc.args = tuple(new_args)
      raise

  def __call__(self, *args, **kwargs):
    """Calls the wrapped function inside this template's variable scope."""
    if self._variable_scope:
      # Only reuse variables if not on first call.
      with variable_scope.variable_scope(
          self._variable_scope, reuse=not self._first_call):
        return self._call_func(args, kwargs)
    else:
      # The scope was not created at construction time, so create it here.
      # Subsequent calls should reuse variables.
      with variable_scope.variable_scope(
          self._unique_name, self._name,
          custom_getter=self._custom_getter) as vs:
        self._variable_scope = vs
        return self._call_func(args, kwargs)

  @property
  def name(self):
    """Returns the name given to this Template."""
    return self._name

  @property
  def func(self):
    """Returns the func given to this Template."""
    return self._func

  @property
  def variable_scope(self):
    """Returns the variable scope object created by this Template."""
    return self._variable_scope

  @property
  def variable_scope_name(self):
    """Returns the variable scope name created by this Template.

    The name is returned with a trailing "/" (unless it is empty or already
    ends in one). Returns None when no variable scope has been created yet.
    """
    if self._variable_scope:
      name = self._variable_scope.name
      if not name or name[-1] == "/":
        return name
      else:
        # To prevent partial matches on the scope_name, we add '/' at the end.
        return name + "/"
    return None

  @property
  def variables(self):
    """Returns the list of global and local variables created by the Template."""
    return self.global_variables + self.local_variables

  @property
  def trainable_variables(self):
    """Returns the list of trainable variables created by the Template."""
    if self._variables_created:
      return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES,
                                self.variable_scope_name)
    else:
      return []

  @property
  def non_trainable_variables(self):
    """Returns the list of non-trainable variables created by the Template."""
    # TODO(apassos) Make sure it matches Eager when using local variables.
    global_variables = self.global_variables
    # Compare by object identity rather than by hash/equality: variables are
    # not guaranteed to be hashable, and identity is what is meant here. This
    # mirrors how EagerTemplate compares variable collections.
    trainable_variables = object_identity.ObjectIdentitySet(
        self.trainable_variables)
    return [x for x in global_variables if x not in trainable_variables]

  @property
  def global_variables(self):
    """Returns the list of global variables created by the Template."""
    if self._variables_created:
      return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                self.variable_scope_name)
    else:
      return []

  @property
  def local_variables(self):
    """Returns the list of local variables created by the Template."""
    if self._variables_created:
      return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES,
                                self.variable_scope_name)
    else:
      return []

  @property
  def weights(self):
    """List of weights/variables created by the Template."""
    return self.variables

  @property
  def trainable_weights(self):
    """List of trainable weights/variables created by the Template."""
    return self.trainable_variables

  @property
  def non_trainable_weights(self):
    """List of non-trainable weights/variables created by the Template."""
    return self.non_trainable_variables

  @property
  @deprecated("2017-02-21",
              "The .var_scope property is deprecated. Please change your "
              "code to use the .variable_scope property")
  def var_scope(self):
    """Returns the variable scope object created by this Template."""
    return self._variable_scope
class _EagerTemplateVariableStore:
  """Wrapper around EagerVariableStore to support nesting EagerTemplates."""

  def __init__(self, variable_scope_name):
    """Creates the store, chaining to the current default store if it is eager.

    Args:
      variable_scope_name: Name of the variable scope this store lives under,
        or None if it will be supplied later via `set_variable_scope_name`.
    """
    self._used_once = False
    self._variable_scope_name = variable_scope_name
    outer_store = variable_scope._get_default_variable_store()  # pylint: disable=protected-access
    # Wrap the outer store when it stores eager variables; if no outer eager
    # variable store has been made, the template needs to create one.
    wrapped = (outer_store,) if outer_store._store_eager_variables else ()  # pylint: disable=protected-access
    self._eager_variable_store = variable_scope.EagerVariableStore(*wrapped)

  def set_variable_scope_name(self, variable_scope_name):
    """Records the name of the variable scope this store lives under."""
    self._variable_scope_name = variable_scope_name

  def _adopt_outer_store_once(self):
    """On first use only, picks up an eager default store if one is set."""
    if self._used_once:
      return
    # If an outer eager VariableStore was explicitly created and set by
    # the first time this template store was used (even if not at
    # constructor time) then pick up the outer variable store.
    current_default = variable_scope._get_default_variable_store()  # pylint: disable=protected-access
    if current_default._store_eager_variables:  # pylint: disable=protected-access
      self._eager_variable_store._store = current_default  # pylint: disable=protected-access
    self._used_once = True

  def _require_scope_name(self, message):
    """Returns the scope name, raising RuntimeError(message) if it is unset."""
    scope_name = self._variable_scope_name
    if scope_name is None:
      raise RuntimeError(message)
    return scope_name

  @tf_contextlib.contextmanager
  def as_default(self):
    """Context manager installing the wrapped eager store as the default."""
    try:
      self._adopt_outer_store_once()
      with self._eager_variable_store.as_default():
        yield
    finally:
      # Each _EagerTemplateVariableStore object lives underneath a variable
      # scope (see EagerTemplate.__call__). This variable scope's subscopes are
      # closed when the EagerTemplate object returns from __call__. For
      # top-level _EagerTemplateVariableStore objects, the variable store to
      # which the variable scope is attached is different from the
      # EagerVariableStore; as such it is necessary to close its subscopes
      # here as well.
      scope_name = self._require_scope_name(
          "A variable scope must be set before an "
          "_EagerTemplateVariableStore object exits.")
      scope_store = variable_scope.get_variable_scope_store()
      scope_store.close_variable_subscopes(scope_name)

  def _variables_in_scope(self, variable_list):
    """Filters `variable_list` to names starting with "<scope name>/"."""
    scope_name = self._require_scope_name(
        "A variable scope must be set before variables can be accessed.")
    prefix = scope_name + "/"
    return [var for var in variable_list if var.name.startswith(prefix)]

  def variables(self):
    """Returns the wrapped store's variables that fall under this scope."""
    everything = self._eager_variable_store.variables()
    return self._variables_in_scope(everything)

  def trainable_variables(self):
    """Returns the wrapped store's trainable variables under this scope."""
    trainable = self._eager_variable_store.trainable_variables()
    return self._variables_in_scope(trainable)

  def non_trainable_variables(self):
    """Returns the wrapped store's non-trainable variables under this scope."""
    frozen = self._eager_variable_store.non_trainable_variables()
    return self._variables_in_scope(frozen)
class EagerTemplate(Template):
  """Wrap a function to aid in variable sharing in Eager mode.

  Templates are functions that create variables the first time they are called
  and reuse them thereafter. See `make_template` for full documentation.

  Note: By default, the full variable scope is captured at the time of first
  call. If `create_scope_now` is passed as True to the constructor, the full
  scope will be captured there, but no variables will be created until the first
  call.
  """

  def __init__(self,
               name,
               func,
               create_scope_now=False,
               custom_getter=None,
               create_graph_function=False):
    """Creates a template for the given function.

    Args:
      name: A name for the scope created by this template. The name will be made
        unique by appending `_N` to it (see how
        `tf.compat.v1.variable_scope` treats the `default_name` for details).
      func: The function to apply each time.
      create_scope_now: Whether to create the scope at Template construction
        time, rather than first call. Defaults to false. Creating the scope at
        construction time may be more convenient if the template is passed
        through much lower level code, and you want to be sure of the scope name
        without knowing exactly where it will be first called. If set to True,
        the scope will be created in the constructor, and all subsequent times
        in `__call__`, leading to a trailing numeral being added to the names of
        all created Tensors. If set to False, the scope will be created at the
        first call location.
      custom_getter: optional custom getter to pass to `variable_scope()`
      create_graph_function: When True, `func` will be executed as a graph
        function. Enabling this flag allows the caller to reap the performance
        benefits associated with executing graphs, at the cost of sacrificing
        debuggability; however, not all Python functions can be compiled into
        graph functions. See the documentation for `function.defun` for details.

    Raises:
      RuntimeError: if eager execution is not enabled.
    """
    if not context.executing_eagerly():
      raise RuntimeError(
          "{} objects can only be used when eager execution is enabled, use "
          "tf.Template for graph construction".format(type(self)))
    super().__init__(
        name,
        func,
        create_scope_now=create_scope_now,
        unique_name=None,
        custom_getter=custom_getter,
        create_graph_function=create_graph_function)
    # When the scope does not exist yet, the store's scope name stays unset
    # until the variable scope is created in __call__.
    scope = self._variable_scope
    initial_scope_name = None if scope is None else scope.name
    self._template_store = _EagerTemplateVariableStore(initial_scope_name)
    self._variable_scope_context_manager = None

  def _call_func(self, args, kwargs):
    """Invokes the wrapped function and polices variable creation.

    Args:
      args: Positional arguments (a sequence) forwarded to the wrapped function.
      kwargs: Keyword arguments (a mapping) forwarded to the wrapped function.

    Returns:
      Whatever the wrapped function returns.

    Raises:
      ValueError: if variables had already been created and this call changed
        the number of trainable variables in the template store.
      Exception: any `Exception` raised inside is re-raised with the location
        where the template was defined appended to its first argument.
    """
    try:
      store = self._template_store
      variables_before = store.variables()
      trainable_before = store.trainable_variables()

      if not self._variables_created:
        # The first time we run, restore variables if necessary (via
        # Trackable).
        with trackable_util.capture_dependencies(template=self):
          result = self._func(*args, **kwargs)
      else:
        result = self._func(*args, **kwargs)

      # The flag is read again on purpose, after the wrapped function has run.
      if not self._variables_created:
        self._variables_created = True
        return result

      # Variables were previously created, implying this is not the first
      # time the template has been called. Check to make sure that no new
      # trainable variables were created this time around.
      trainable_after = store.trainable_variables()
      # If a variable that we intend to train is created as a side effect
      # of creating a template, then that is almost certainly an error.
      if len(trainable_before) != len(trainable_after):
        raise ValueError(
            "Trainable variable created when calling a template "
            "after the first time, perhaps you used tf.Variable "
            "when you meant tf.get_variable: %s" % list(
                object_identity.ObjectIdentitySet(trainable_after) -
                object_identity.ObjectIdentitySet(trainable_before)))

      # Non-trainable tracking variables are a legitimate reason why a new
      # variable would be created, but it is a relatively advanced use-case,
      # so log it.
      variables_after = store.variables()
      if len(variables_before) != len(variables_after):
        logging.info(
            "New variables created when calling a template after "
            "the first time, perhaps you used tf.Variable when you "
            "meant tf.get_variable: %s",
            list(
                object_identity.ObjectIdentitySet(variables_after) -
                object_identity.ObjectIdentitySet(variables_before)))
      return result
    except Exception as exc:
      # Reraise the exception, but append the original definition to the
      # trace.
      original_args = exc.args
      message = original_args[0] if original_args else ""
      definition_site = "".join(
          _skip_common_stack_elements(self._stacktrace,
                                      traceback.format_stack()))
      annotated = "%s\n\noriginally defined at:\n%s" % (message,
                                                        definition_site)
      exc.args = (annotated,) + tuple(original_args[1:])
      raise

  def __call__(self, *args, **kwargs):
    """Calls the wrapped function inside this template's scope and store."""
    # In both paths below, the template store is installed as default after
    # the variable scope is opened in order to ensure that templates nested at
    # the same level correctly uniquify lower variable scope names.
    if not self._variable_scope:
      # The scope was not created at construction time, so create it here.
      # Subsequent calls should reuse variables.
      with variable_scope.variable_scope(
          self._unique_name, self._name,
          custom_getter=self._custom_getter) as vs:
        self._variable_scope = vs
        # Because the scope was not created at construction time, the template
        # store's variable scope name is unset; set it here.
        self._template_store.set_variable_scope_name(vs.name)
        with self._template_store.as_default():
          return self._call_func(args, kwargs)

    # Create a cache for the variable scope context manager the first time
    # around so that we don't have to keep recreating it.
    if not self._variable_scope_context_manager:
      self._variable_scope_context_manager = variable_scope.variable_scope(
          self._variable_scope, reuse=variable_scope.AUTO_REUSE)
    with self._variable_scope_context_manager, \
        self._template_store.as_default():
      return self._call_func(args, kwargs)

  @property
  def variables(self):
    """Returns the list of variables created by the Template."""
    # Currently there is no local variable in Eager mode.
    return self._template_store.variables() if self._variables_created else []

  @property
  def trainable_variables(self):
    """Returns the list of trainable variables created by the Template."""
    # Currently there is no local variable in Eager mode.
    return (self._template_store.trainable_variables()
            if self._variables_created else [])

  @property
  def non_trainable_variables(self):
    """Returns the list of non-trainable variables created by the Template."""
    # Currently there is no local variable in Eager mode.
    return (self._template_store.non_trainable_variables()
            if self._variables_created else [])

  @property
  def global_variables(self):
    """Returns the list of global variables created by the Template."""
    # Currently there is no local variable in Eager mode.
    return self.variables if self._variables_created else []

  @property
  def local_variables(self):
    """Returns the list of local variables created by the Template (empty)."""
    # Currently there is no local variable in Eager mode.
    return []