Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_function.py: 56%
25 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
16"""Helper library for functions used during TPU compilation."""
18import contextlib
19import threading
class TpuContext(threading.local):
  """Per-thread state describing the TPU computation currently being built.

  Because this subclasses `threading.local`, each thread that uses an instance
  sees its own attribute values. Per the `threading.local` contract,
  `__init__` runs again in each thread that first touches the instance.
  """

  def __init__(self):
    """Creates a new TpuContext with no shard count recorded."""
    self._number_of_shards = None

  @property
  def number_of_shards(self):
    """Shard count of the computation being built, or None if there is none."""
    shards = self._number_of_shards
    return shards

  def set_number_of_shards(self, number_of_shards):
    """Records `number_of_shards` for the calling thread (None clears it)."""
    self._number_of_shards = number_of_shards
# Module-wide TPU context. While a sharded computation is being built (inside
# `tpu_shard_context`), its `number_of_shards` holds the shard count; otherwise
# it is None. Because `TpuContext` is a `threading.local`, that value is
# tracked separately for each thread.
_current_tpu_context: TpuContext = TpuContext()
@contextlib.contextmanager
def tpu_shard_context(number_of_shards):
  """A context manager setting the current number of shards.

  While the context is active, `get_tpu_context().number_of_shards` returns
  `number_of_shards`. On exit, whether normal or via an exception, it is reset
  to None.

  Args:
    number_of_shards: The shard count to record for the computation being
      built.

  Yields:
    None.

  Raises:
    NotImplementedError: If a shard context is already active in the calling
      thread, i.e. `tpu_shard_context` is being nested.
  """
  if _current_tpu_context.number_of_shards is not None:
    # The trailing space after "nested." is required: the pieces below are
    # implicitly concatenated, and without it the sentences run together.
    raise NotImplementedError(
        "tpu_shard_context cannot be nested. "
        "If you're using TPUEstimator with inference_on_tpu, "
        "make sure you have set "
        "export_saved_model_api_version=ExportSavedModelApiVersion.V2 in "
        "the creation of TPUEstimator.")
  try:
    _current_tpu_context.set_number_of_shards(number_of_shards)
    yield
  finally:
    # Always clear the shard count so an error inside the block does not leave
    # a stale value that would make the next use look like nesting.
    _current_tpu_context.set_number_of_shards(None)
def get_tpu_context():
  """Returns the module-wide `TpuContext` holding the current shard count.

  Every call returns the same object. Its `number_of_shards` is per-thread
  because `TpuContext` subclasses `threading.local`.
  """
  context = _current_tpu_context
  return context
def on_device_training_loop(func):
  """Marks `func` as containing an embedded on-device training loop.

  Intended as a decorator for the computation function passed to
  `tpu.rewrite()`. If that function contains an embedded training loop, trace
  tools will generate step markers for each iteration.

  Args:
    func: The computation function to tag. It is modified in place.

  Returns:
    The same `func` object, now carrying a `step_marker_location` attribute.
  """
  # The value corresponds to a member of xla.DebugOptions.StepMarkerLocation.
  func.step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP"
  return func