Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/coordinator/coordinator_context.py: 43%

56 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The execution context for ClusterCoordinator."""

import contextlib
import threading

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute.coordinator import remote_value
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export

_dispatch_context = threading.local()
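# NOTE: `_dispatch_context` is thread-local, so each coordinator thread
# (typically one per worker) sees only the DispatchContext of the closure it
# is currently executing.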

def get_current_dispatch_context():
  try:
    return _dispatch_context.current
  except AttributeError:
    return None


@contextlib.contextmanager
def with_dispatch_context(worker_obj):
  previous_context = getattr(_dispatch_context, "current", None)
  _dispatch_context.current = DispatchContext(worker_obj)
  yield
  _dispatch_context.current = previous_context
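# Illustrative sketch (not part of this module): the coordinator is expected
# to wrap closure execution roughly like this, so that code running inside the
# closure can call get_current_dispatch_context(). The `worker` object is
# hypothetical; it only needs to expose a `worker_index` attribute.
#
#   with with_dispatch_context(worker):
#     closure_fn()  # get_current_dispatch_context() now returns the
#                   # DispatchContext wrapping `worker`.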

class DispatchContext(object):
  """Context entered when executing a closure on a given worker."""

  def __init__(self, worker_obj):
    self._worker = worker_obj
    self._worker_index = worker_obj.worker_index

  @property
  def worker(self):
    return self._worker

  @property
  def worker_index(self):
    return self._worker_index

  def maybe_get_remote_value(self, ret):
    return maybe_get_remote_value(ret)

def maybe_get_remote_value(val):
  """Gets the value of `val` if it is a `RemoteValue`."""
  if isinstance(val, remote_value.RemoteValue):
    error = val._get_error()  # pylint: disable=protected-access
    if error:
      raise AssertionError(
          "RemoteValue doesn't have a value because it has error %r:%s" %
          (error, error))
    elif val._status is not remote_value.RemoteValueStatus.READY:  # pylint: disable=protected-access
      raise AssertionError("The input RemoteValue has not been executed.")
    else:
      return val._get_values()  # pylint: disable=protected-access
  else:
    return val
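# In short: a RemoteValue that carries an error, or that has not finished
# executing, raises an AssertionError; a READY RemoteValue is unwrapped to its
# concrete values; anything that is not a RemoteValue is returned unchanged.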

@tf_export("distribute.coordinator.experimental_get_current_worker_index",
           v1=[])
def get_current_worker_index():
  """Returns the current worker index, when called within a worker closure.

  Some parameter server training workloads may require the worker to know its
  index, for example for data sharding for reduced-variance training.

  This method may be used within a `tf.function` that is executed on a worker.
  That is, either a `dataset_fn` that runs via
  `ClusterCoordinator.create_per_worker_dataset`, or any other function
  scheduled via `ClusterCoordinator.schedule`.

  Example (sharding data by worker):

  ```python
  strategy = tf.distribute.ParameterServerStrategy(
      cluster_resolver=...)
  coordinator = (
      tf.distribute.coordinator.ClusterCoordinator(strategy))

  def dataset_fn(context):
    dataset = tf.data.Dataset.range(10)
    worker_index = (
        tf.distribute.coordinator.experimental_get_current_worker_index()
    )
    dataset = dataset.shard(
        num_shards=num_workers,
        index=worker_index,
    )
    return dataset

  @tf.function
  def per_worker_dataset_fn():
    return strategy.distribute_datasets_from_function(dataset_fn)

  per_worker_dataset = coordinator.create_per_worker_dataset(
      per_worker_dataset_fn)
  ```

  Raises:
    RuntimeError: if called from outside a `tf.function` or outside of a remote
      closure execution context (that is, on a non-worker machine).
  """

  msg = ("Cannot retrieve the worker index. "
         "`experimental_get_current_worker_index` should be called from "
         "within a tf.function being executed on a worker. This method should "
         "only be called from either a dataset_fn that is passed into "
         "`ClusterCoordinator.create_per_worker_dataset`, or a tf.function "
         "that is passed into `ClusterCoordinator.schedule`.")
  if not ops.inside_function():
    raise RuntimeError(msg)

  def call_time_worker_index():
    dispatch_context = get_current_dispatch_context()
    if not dispatch_context:
      raise RuntimeError(msg)
    return dispatch_context.worker_index
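  # The worker index is not known when the tf.function is traced; it only
  # becomes known when the traced function is dispatched to a particular
  # worker. capture_call_time_value defers evaluating call_time_worker_index
  # until that call happens, so each worker receives its own index.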

  worker_index = ops.get_default_graph().capture_call_time_value(
      call_time_worker_index, tensor.TensorSpec([], dtype=dtypes.int64))
  worker_index.op._set_attr(  # pylint: disable=protected-access
      "_user_specified_name",
      attr_value_pb2.AttrValue(s=compat.as_bytes("worker_index")))
  return worker_index
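The docstring above covers the `create_per_worker_dataset` path. As a rough sketch of the other supported entry point, a `tf.function` passed to `ClusterCoordinator.schedule`, usage might look like the following; the `strategy` and `coordinator` objects are assumed to be created as in the docstring example, and the function body is illustrative only.

```python
@tf.function
def worker_fn():
  # Executed on a worker; the index identifies which worker runs this closure.
  return tf.distribute.coordinator.experimental_get_current_worker_index()

# schedule() returns a RemoteValue; fetch() blocks until the result is ready.
remote_index = coordinator.schedule(worker_fn)
print(remote_index.fetch())
```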