Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/cluster_resolver/tfconfig_cluster_resolver.py: 44%

75 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implementation of Cluster Resolvers for TF_CONFIG Environment Variables.""" 

16 

17 

18import json 

19import os 

20 

21from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver 

22from tensorflow.python.training.server_lib import ClusterSpec 

23from tensorflow.python.util.tf_export import tf_export 

24 

25_TF_CONFIG_ENV = 'TF_CONFIG' 

26_SESSION_MASTER_KEY = 'session_master' 

27_RPC_LAYER_KEY = 'rpc_layer' 

28_TASK_KEY = 'task' 

29 

30 

def format_master_url(master, rpc_layer=None):
    """Prefixes a master address with its RPC layer, when one is given.

    Args:
        master: The master address to format.
        rpc_layer: (optional) Protocol name to put in front of the address.
            Any falsy value (``None``, ``''``) leaves the address untouched.

    Returns:
        ``'<rpc_layer>://<master>'`` when `rpc_layer` is truthy, otherwise
        `master` exactly as it was passed in.
    """
    if not rpc_layer:
        return master
    return '%s://%s' % (rpc_layer, master)

36 

37 

def _load_tf_config():
    """Decodes the JSON stored in the TF_CONFIG environment variable.

    Returns:
        The decoded JSON value, or an empty dict when the variable is unset
        or set to an empty string.

    Raises:
        ValueError: If the variable holds non-empty text that is not valid
            JSON (raised by `json.loads`).
    """
    raw_config = os.environ.get(_TF_CONFIG_ENV)
    # An exported-but-empty variable is treated like an unset one rather than
    # being handed to json.loads(''), which would fail.
    if not raw_config:
        return {}
    return json.loads(raw_config)

40 

41 

def _get_value_in_tfconfig(key, default=None):
    """Fetches one top-level entry from the decoded TF_CONFIG.

    The environment variable is re-read on every call, so changes made to it
    after import are picked up.

    Args:
        key: Top-level key to look up.
        default: Value returned when `key` is not present.

    Returns:
        The value stored under `key`, or `default` if the key is absent.
    """
    tf_config = _load_tf_config()
    if key in tf_config:
        return tf_config[key]
    return default

45 

46 

@tf_export('distribute.cluster_resolver.TFConfigClusterResolver')
class TFConfigClusterResolver(ClusterResolver):
    """Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar.

    This is an implementation of cluster resolvers when using TF_CONFIG to set
    information about the cluster. The cluster spec returned will be
    initialized from the TF_CONFIG environment variable.

    An example to set TF_CONFIG is:

      ```Python
      os.environ['TF_CONFIG'] = json.dumps({
        'cluster': {
            'worker': ["localhost:12345", "localhost:23456"]
        },
        'task': {'type': 'worker', 'index': 0}
      })
      ```

    However, sometimes the container orchestration framework will set
    TF_CONFIG for you. In this case, you can just create an instance without
    passing in any arguments. You can find an example here to let Kubernetes
    set TF_CONFIG for you:
    https://github.com/tensorflow/ecosystem/tree/master/kubernetes. Then you
    can use it with `tf.distribute.Strategy` as:

      ```Python
      # `TFConfigClusterResolver` is already the default one in the following
      # strategy.
      strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
          cluster_resolver=TFConfigClusterResolver())
      ```
    """

    def __init__(self,
                 task_type=None,
                 task_id=None,
                 rpc_layer=None,
                 environment=None):
        """Creates a new TFConfigClusterResolver.

        Args:
            task_type: (String, optional) Overrides the task type specified in
                the TF_CONFIG environment variable.
            task_id: (Integer, optional) Overrides the task index specified in
                the TF_CONFIG environment variable.
            rpc_layer: (String, optional) Overrides the rpc layer TensorFlow
                uses.
            environment: (String, optional) Overrides the environment
                TensorFlow operates in.
        """
        # Overrides are stored as given; a value of None means "fall back to
        # TF_CONFIG" and is resolved lazily by the properties below.
        self._task_type = task_type
        self._task_id = task_id
        self._rpc_layer = rpc_layer
        self._environment = environment

    @property
    def task_type(self):
        """The task type as a string.

        Returns the override if one was set; otherwise the `type` entry of the
        TF_CONFIG `task` section, or None when that entry is absent.
        """
        if self._task_type is None:
            task_info = _get_value_in_tfconfig(_TASK_KEY, {})
            return str(task_info['type']) if 'type' in task_info else None
        else:
            return str(self._task_type)

    @property
    def task_id(self):
        """The task index as an integer.

        Returns the override if one was set; otherwise the `index` entry of
        the TF_CONFIG `task` section, or None when that entry is absent.
        """
        if self._task_id is None:
            task_info = _get_value_in_tfconfig(_TASK_KEY, {})
            return int(task_info['index']) if 'index' in task_info else None
        else:
            return int(self._task_id)

    @task_type.setter
    def task_type(self, task_type):
        self._task_type = task_type

    @task_id.setter
    def task_id(self, task_id):
        self._task_id = task_id

    @property
    def environment(self):
        """The environment value passed to the constructor (may be None)."""
        return self._environment

    @property
    def rpc_layer(self):
        """The rpc layer override, or the TF_CONFIG `rpc_layer` entry.

        Returns None when neither an override nor the TF_CONFIG entry exists.
        """
        if self._rpc_layer is None:
            return _get_value_in_tfconfig(_RPC_LAYER_KEY)
        else:
            return self._rpc_layer

    @rpc_layer.setter
    def rpc_layer(self, rpc_layer):
        self._rpc_layer = rpc_layer

    def num_accelerators(self,
                         task_type=None,
                         task_id=None,
                         config_proto=None):
        """Delegates to the parent class, filling in this resolver's task.

        Args:
            task_type: (optional) Task type to query; defaults to
                `self.task_type` when None.
            task_id: (optional) Task index to query; defaults to
                `self.task_id` when None.
            config_proto: (optional) Forwarded unchanged to the parent
                implementation.

        Returns:
            Whatever the parent class's `num_accelerators` returns.
        """
        task_type = self.task_type if task_type is None else task_type
        task_id = self.task_id if task_id is None else task_id
        return super(TFConfigClusterResolver, self).num_accelerators(
            task_type, task_id, config_proto)

    def cluster_spec(self):
        """Returns a ClusterSpec based on the TF_CONFIG environment variable.

        Returns:
            A ClusterSpec built from the `cluster` section of TF_CONFIG, or an
            empty ClusterSpec when that section is missing.
        """
        tf_config = _load_tf_config()
        if 'cluster' not in tf_config:
            return ClusterSpec({})
        return ClusterSpec(tf_config['cluster'])

    def master(self, task_type=None, task_id=None, rpc_layer=None):
        """Returns the master address to use when creating a TensorFlow session.

        Note: this is only useful for TensorFlow 1.x.

        Args:
            task_type: (String, optional) Overrides and sets the task_type of
                the master.
            task_id: (Integer, optional) Overrides and sets the task id of the
                master.
            rpc_layer: (String, optional) Overrides and sets the protocol over
                which TensorFlow nodes communicate with each other.

        Returns:
            The address of the master: the TF_CONFIG `session_master` entry
            verbatim if present; an empty string when the cluster has no jobs
            or exactly one job with exactly one task; otherwise the address of
            the selected task, prefixed with the rpc layer when one is known.

        Note:
            This method has no explicit `raise`. When neither the arguments
            nor TF_CONFIG supply a task type or id, None is passed on to
            `ClusterSpec.task_address`, which decides the resulting error.
            NOTE(review): that behavior is defined outside this file; confirm
            the exception type before relying on it.
        """

        # If `session_master` is set, just use that.
        session_master = _get_value_in_tfconfig(_SESSION_MASTER_KEY)
        if session_master is not None:
            return session_master

        # Return an empty string if the cluster is empty or consists of a
        # single job holding a single task.
        cluster_spec = self.cluster_spec()
        if (not cluster_spec.jobs or
                (len(cluster_spec.jobs) == 1 and
                 len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)):
            return ''

        # We try to auto-detect the task type and id, but use the
        # user-supplied ones where available.
        task_type = task_type if task_type is not None else self.task_type
        task_id = task_id if task_id is not None else self.task_id
        rpc_layer = rpc_layer if rpc_layer is not None else self.rpc_layer

        return format_master_url(
            cluster_spec.task_address(task_type, task_id), rpc_layer)
200 rpc_layer)