Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/core/function_wrappers.py: 25%
60 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Support for wrapping converted functions bodies with auxiliary logic."""
17from tensorflow.python.autograph.core import ag_ctx
18from tensorflow.python.autograph.core import converter
19from tensorflow.python.autograph.operators import variables
20from tensorflow.python.framework import auto_control_deps
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.util import nest
# TODO(mdan): Move this into operators - it represents a function definition.
class FunctionScope(object):
  """Context manager that wraps the body of a converted function.

  This context manager handles various operations related to the scope of a
  function:
    * optional TF name scopes - these name scopes match the name of the
      function, for easy visualization in TensorBoard;
    * optional automatic control dependencies - this adds the same mechanism
      for control dependencies that is used by `@tf.function`; it can be
      optionally enabled when using `tf.autograph.to_graph`;
    * tracking of autograph conversion state (whether it's enabled by the
      user, conversion options).

  The wrapped context managers are entered in the order: autograph control
  status (if user-requested), name scope (if enabled), automatic control
  dependencies (if enabled). They are exited in the reverse order, so the
  scopes stay properly nested. If entering one of them fails, those already
  entered are exited before the error propagates.
  """

  def __init__(self, function_name, scope_name, options):
    """Creates the scope; no context is entered until `__enter__`.

    Args:
      function_name: Name of the converted function. After `_sanitize`, it is
        used as the TF name scope when the NAME_SCOPES feature is enabled.
      scope_name: Stored as `self.name`.
      options: Conversion options object; this class reads its
        `user_requested` attribute and calls its `call_options()` and
        `uses(feature)` methods.
    """
    self.name = scope_name
    self.options = options

    if options.user_requested:
      self.autograph_ctx = ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED,
                                                   options)
    # Not read within this class; presumably consumed by code that receives
    # this scope object — TODO confirm.
    self.callopts = options.call_options()

    use_name_scope = options.uses(converter.Feature.NAME_SCOPES)
    self.use_name_scope = use_name_scope
    if use_name_scope:
      self.name_scope = ops.name_scope(self._sanitize(function_name))

    use_auto_deps = self.options.uses(converter.Feature.AUTO_CONTROL_DEPS)
    self.use_auto_deps = use_auto_deps
    if use_auto_deps:
      self.autodeps_scope = auto_control_deps.AutomaticControlDependencies()
      self._return_value_marked = False

    # Context managers entered by the most recent __enter__, in entry order.
    self._entered_scopes = []

  def _sanitize(self, name):
    """See https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope."""
    # TensorFlow doesn't like leading underscores at the top level.
    if name and name.startswith('_'):
      name = 'fn' + name
    return name

  def _active_scopes(self):
    """Returns the context managers this scope wraps, in entry order."""
    scopes = []
    if self.options.user_requested:
      scopes.append(self.autograph_ctx)
    if self.use_name_scope:
      scopes.append(self.name_scope)
    if self.use_auto_deps:
      scopes.append(self.autodeps_scope)
    return scopes

  def _exit_scopes(self, scopes, exc_type, exc_val, exc_tb):
    """Exits `scopes` innermost-first (reverse list order).

    Every scope is exited even if exiting an inner one raises; in that case
    the inner scope's exception propagates after the outer scopes are exited.
    """
    if not scopes:
      return
    try:
      scopes[-1].__exit__(exc_type, exc_val, exc_tb)
    finally:
      self._exit_scopes(scopes[:-1], exc_type, exc_val, exc_tb)

  def __enter__(self):
    entered = []
    try:
      for scope in self._active_scopes():
        scope.__enter__()
        entered.append(scope)
    except BaseException as e:
      # Python does not call __exit__ when __enter__ raises, so unwind the
      # scopes entered so far; otherwise they would be left entered.
      self._exit_scopes(entered, type(e), e, e.__traceback__)
      raise
    self._entered_scopes = entered
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    # LIFO unwind mirroring __enter__. Returns None, so exceptions raised in
    # the wrapped body are never suppressed.
    entered, self._entered_scopes = self._entered_scopes, []
    self._exit_scopes(entered, exc_type, exc_val, exc_tb)

  def ret(self, value, did_return):
    """Marks a value as returned from the function guarded by the scope.

    Args:
      value: The returned value; may be a nested structure.
      did_return: Unused.

    Returns:
      None if `value` is an `UndefinedReturnValue`, or if automatic control
      dependencies are enabled and `value` is None. Otherwise `value`; when
      automatic control dependencies are enabled, each TF-typed leaf is
      replaced by the result of `autodeps_scope.mark_as_return`.
    """
    del did_return

    if isinstance(value, variables.UndefinedReturnValue):
      return None

    if self.use_auto_deps:
      self._return_value_marked = True
      if value is None:
        # We don't create dummy returns, to preserve Python semantics. The user
        # is responsible for adding a return value to the top-level function.
        return None

      def _mark_return_if_tensor(t):
        if tensor_util.is_tf_type(t):
          return self.autodeps_scope.mark_as_return(t)
        return t

      value = nest.map_structure(_mark_return_if_tensor, value)
    return value
def with_function_scope(thunk, scope_name, options):
  """Calls `thunk` inside a `FunctionScope` and returns whatever it returns.

  This is the expression form of `with FunctionScope(...) as scope:`. The
  scope is always created with the function name 'lambda_'.

  Args:
    thunk: Callable invoked with the entered `FunctionScope` as its only
      argument.
    scope_name: Forwarded to `FunctionScope` as its `scope_name`.
    options: Forwarded to `FunctionScope` as its conversion options.

  Returns:
    The result of `thunk(scope)`, produced before the scope is exited.
  """
  function_scope = FunctionScope('lambda_', scope_name, options)
  with function_scope as fscope:
    result = thunk(fscope)
  return result