Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_strategy_util.py: 21%
124 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU specific APIs to be used in conjunction with TPU Strategy."""

import gc

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import monitoring
from tensorflow.python.eager.def_function import function
from tensorflow.python.eager.def_function import functions_run_eagerly
from tensorflow.python.eager.def_function import run_functions_eagerly
from tensorflow.python.framework import device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import topology
from tensorflow.python.tpu import tpu
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


_INITIALIZED_TPU_SYSTEMS = {}
_LOCAL_MASTERS = ("", "local")

_tpu_worker_address = monitoring.StringGauge(
    "/tensorflow/tpu/worker_address",
    "The worker address that the coordinator/client connects to.", "address")

@tf_export("tpu.experimental.initialize_tpu_system")
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. If called
    inside tf.function, it returns the serialized topology object instead.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices found in eager mode.
  """

  # Deallocate all TPU buffers by clearing out eager context caches and
  # triggering garbage collection, to avoid keeping invalid TPU buffers around
  # after the TPU system is reinitialized.
  logging.info("Deallocate tpu buffers before initializing tpu system.")
  context.context()._clear_caches()  # pylint: disable=protected-access
  context.context().clear_kernel_cache()
  gc.collect()

  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the
    # init ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  # This function is written the way it is for the following non-intuitive
  # reasons: tpu.initialize_system creates a dummy op whose sole purpose is to
  # trigger DistributedTPURewritePass. This pass actually adds real ops that
  # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
  # eagerly. We need to wrap it in a defun and trigger the rewrite passes on
  # it.
  if tpu_name not in _LOCAL_MASTERS:
    # Explicitly place the tpu.initialize_system in the first worker to avoid
    # the "output node matches multiple devices" error.
    job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

  if context.executing_eagerly():
    @function(autograph=False)
    def _tpu_init_fn():
      # In TF1, we usually close chips when compilation fails to clear the
      # data in infeed. In TF2, we don't need to do this because infeed is no
      # longer used, so users can recover from TPU compilation failures more
      # smoothly. The same holds for the cancellation of a TPU execution.
      return tpu.initialize_system(
          job=job,
          compilation_failure_closes_chips=False,
          tpu_cancellation_closes_chips=False)

    # The TPU_SYSTEM device must match the device used in tpu.initialize_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    run_eagerly = functions_run_eagerly()
    if run_eagerly:
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps using"
          " tf.config.run_functions_eagerly."
          " tf.tpu.experimental.initialize_tpu_system requires tf.function to"
          " work. This primitive will override the disable."
      )
      run_functions_eagerly(False)
    try:
      with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
        output = _tpu_init_fn()
      context.async_wait()
    except errors.InvalidArgumentError as e:
      raise errors.NotFoundError(
          None, None,
          "TPUs not found in the cluster. Failed in initialization: "
          + str(e))
    finally:
      if run_eagerly is not None:
        run_functions_eagerly(run_eagerly)
    # Clear out the eager context caches since the memory is invalid now.
    context.context()._initialize_logical_devices()  # pylint: disable=protected-access

    serialized_topology = output.numpy()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      serialized_topology = tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)
      # If initialize_tpu_system is called inside tf.function, we only return
      # the serialized topology object as the tf.tpu.Topology object has to be
      # constructed in eager mode.
      return serialized_topology

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  cluster_resolver.set_tpu_topology(serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  # Record the address of the TPU worker-0 that the coordinator connects to.
  # This can be used to associate the TPU worker with the right coordinator
  # when aggregating the metrics for the application. An example of the
  # address:
  # /bns/mb/borg/mb/bns/chienchunh/chienchunh_group_49640234.1.tfm_train_tpu_worker/0
  _tpu_worker_address.get_cell("address").set(cluster_resolver.get_master())

  return tpu_topology
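
# Example usage (a minimal sketch, not part of this module): "my-tpu" is a
# placeholder TPU name, and tf.distribute.TPUStrategy is shown only to
# illustrate the common pattern of initializing the TPU system before building
# a strategy on top of it.
#
#   resolver = TPUClusterResolver(tpu="my-tpu")
#   tpu_topology = initialize_tpu_system(resolver)
#   logging.info("TPU mesh shape: %s", tpu_topology.mesh_shape)
#   strategy = tf.distribute.TPUStrategy(resolver)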


def get_initialized_tpu_systems():
  """Returns all currently initialized tpu systems.

  Returns:
    A dictionary, with tpu name as the key and the tpu topology as the value.
  """
  return _INITIALIZED_TPU_SYSTEMS.copy()
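
# Example (a sketch; "my-tpu" is a placeholder name): the returned dictionary
# can be used to avoid re-initializing a TPU system that is already up in this
# process.
#
#   if "my-tpu" not in get_initialized_tpu_systems():
#     initialize_tpu_system(TPUClusterResolver(tpu="my-tpu"))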


@tf_export("tpu.experimental.shutdown_tpu_system")
def shutdown_tpu_system(cluster_resolver=None):
  """Shuts down the TPU devices.

  This will clear all caches, even those that are maintained through sequential
  calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
  cache.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution or if run in a
      tf.function.
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the
    # init ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "You are shutting down a TPU system %s that has not been "
        "initialized.", tpu_name)

  logging.info("Shutting down the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function is written the way it is for the following non-intuitive
    # reasons: tpu.shutdown_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. This pass actually adds real ops that
    # shut down the TPU system. Thus, we can't simply run tpu.shutdown_system
    # eagerly. We need to wrap it in a defun and trigger the rewrite passes on
    # it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place the tpu.shutdown_system in the first worker to avoid
      # the "output node matches multiple devices" error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function(autograph=False)
    def _tpu_shutdown_fn():
      tpu.shutdown_system(job=job)

    # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    run_eagerly = functions_run_eagerly()
    if run_eagerly:
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps using"
          " tf.config.run_functions_eagerly."
          " tf.tpu.experimental.shutdown_tpu_system requires tf.function to"
          " work. This primitive will override the disable."
      )
      run_functions_eagerly(False)
    try:
      with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
        _tpu_shutdown_fn()
    finally:
      if run_eagerly is not None:
        run_functions_eagerly(run_eagerly)

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
    context.context().clear_kernel_cache()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        sess.run(tpu.shutdown_system())
  else:
    raise RuntimeError(
        "shutdown_tpu_system is not supported within tf.functions. You "
        "should call shutdown_tpu_system outside of your tf.function."
    )

  logging.info("Finished shutting down TPU system.")
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    del _INITIALIZED_TPU_SYSTEMS[tpu_name]
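
# Example lifecycle (a sketch; the TPU name is a placeholder): shut the TPU
# system down before re-initializing it, e.g. between independent training
# runs in the same process.
#
#   resolver = TPUClusterResolver(tpu="my-tpu")
#   initialize_tpu_system(resolver)
#   ...  # run TPU computations
#   shutdown_tpu_system(resolver)
#   initialize_tpu_system(resolver)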