Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/tpu/tpu_function.py: 56%

25 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15 

16"""Helper library for functions used during TPU compilation.""" 

17 

18import contextlib 

19import threading 

20 

21 

class TpuContext(threading.local):
  """Thread-local holder for state about the TPU computation being built.

  The only state tracked is the number of shards. Because this class derives
  from `threading.local`, every thread sees its own independent value, and
  `__init__` runs again the first time the object is touched from a new
  thread, so each thread starts out with no shard count set.
  """

  def __init__(self):
    """Initializes the context with no shard count recorded."""
    self._number_of_shards = None

  @property
  def number_of_shards(self):
    """The value last passed to `set_number_of_shards`, or None if unset."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Records `number_of_shards` for the calling thread.

    Args:
      number_of_shards: The value to store; pass None to clear it.
    """
    self._number_of_shards = number_of_shards

35 

36 

# Module-level TPU context used by the helpers in this file. It holds the
# number of shards while a sharded computation is being built, or None if no
# computation is being built. TpuContext derives from threading.local, so the
# stored value is tracked separately for each thread.
_current_tpu_context = TpuContext()

40 

41 

@contextlib.contextmanager
def tpu_shard_context(number_of_shards):
  """Context manager that publishes the shard count for the enclosed block.

  On entry, `number_of_shards` is stored on the module-level TPU context; on
  exit (whether the block finishes normally or raises) the stored value is
  reset to None.

  Args:
    number_of_shards: The value exposed through
      `get_tpu_context().number_of_shards` while the block runs. Note that
      None is also the "no active context" marker, so entering with None
      cannot be detected as nesting by a later call.

  Yields:
    None.

  Raises:
    NotImplementedError: If a shard count is already set, i.e. this context
      manager is being entered while another one is still active.
  """
  if _current_tpu_context.number_of_shards is not None:
    # Adjacent string literals are joined with no separator, so every piece
    # that is followed by more text must end with its own trailing space.
    raise NotImplementedError(
        "tpu_shard_context cannot be nested. "
        "If you're using TPUEstimator with inference_on_tpu, "
        "make sure you have set "
        "export_saved_model_api_version=ExportSavedModelApiVersion.V2 in "
        "the creation of TPUEstimator.")
  try:
    _current_tpu_context.set_number_of_shards(number_of_shards)
    yield
  finally:
    # Always clear the shard count so a failure inside the block does not
    # leave the context looking permanently "in use".
    _current_tpu_context.set_number_of_shards(None)

57 

58 

def get_tpu_context():
  """Returns the module-level TPU context object.

  The same object is returned on every call; its `number_of_shards` is None
  unless a `tpu_shard_context` block is currently active.
  """
  return _current_tpu_context

61 

62 

def on_device_training_loop(func):
  """Tags a TPU computation function as containing an on-device training loop.

  Decorator for the TPU computation `func` that is passed to tpu.rewrite(). If
  there is an embedded training loop in `func`, trace tools will generate step
  markers for each iteration.

  Args:
    func: The function to tag. It is modified in place, not wrapped.

  Returns:
    The same `func` object, now carrying a `step_marker_location` attribute.
  """
  # The value for this attribute comes from xla.DebugOptions.StepMarkerLocation.
  func.step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP"
  return func

69 return func