1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of Cluster Resolvers for TF_CONFIG Environment Variables."""
16
17
18import json
19import os
20
21from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver
22from tensorflow.python.training.server_lib import ClusterSpec
23from tensorflow.python.util.tf_export import tf_export
24
# Name of the environment variable holding the JSON cluster configuration.
_TF_CONFIG_ENV = 'TF_CONFIG'

# Top-level keys looked up inside the parsed TF_CONFIG value.
_SESSION_MASTER_KEY = 'session_master'  # explicit master address override
_RPC_LAYER_KEY = 'rpc_layer'  # protocol prefix, e.g. used as '<rpc>://'
_TASK_KEY = 'task'  # dict with the 'type' and 'index' of this task
29
30
def format_master_url(master, rpc_layer=None):
  """Prefixes a master address with an RPC protocol scheme when one is given.

  Args:
    master: The master address, e.g. 'localhost:12345'.
    rpc_layer: (String, optional) Protocol name such as 'grpc'. When falsy
      (None or the empty string) no prefix is added.

  Returns:
    '<rpc_layer>://<master>' if `rpc_layer` is truthy, otherwise `master`
    unchanged.
  """
  if not rpc_layer:
    return master
  return '%s://%s' % (rpc_layer, master)
36
37
def _load_tf_config():
  """Parses the TF_CONFIG environment variable as JSON.

  The variable is read afresh on every call, so changes to `os.environ` are
  picked up immediately.

  Returns:
    The decoded JSON value. An unset, empty, or whitespace-only TF_CONFIG is
    treated as "no configuration" and yields a new empty dict.

  Raises:
    json.JSONDecodeError: If TF_CONFIG is set to non-blank text that is not
      valid JSON.
  """
  raw_config = os.environ.get(_TF_CONFIG_ENV, '')
  # `export TF_CONFIG=` should behave like leaving it unset; without this
  # check json.loads('') would raise instead of returning an empty config.
  if not raw_config.strip():
    return {}
  return json.loads(raw_config)
40
41
def _get_value_in_tfconfig(key, default=None):
  """Looks up a top-level key in the current TF_CONFIG.

  Args:
    key: The top-level TF_CONFIG key to read, e.g. 'task'.
    default: Value to return when `key` is not present.

  Returns:
    The value stored under `key` in the parsed TF_CONFIG, or `default` if the
    key is absent.
  """
  config = _load_tf_config()
  if key not in config:
    return default
  return config[key]
45
46
@tf_export('distribute.cluster_resolver.TFConfigClusterResolver')
class TFConfigClusterResolver(ClusterResolver):
  """Implementation of a ClusterResolver which reads the TF_CONFIG EnvVar.

  This is an implementation of cluster resolvers when using TF_CONFIG to set
  information about the cluster. The cluster spec returned will be
  initialized from the TF_CONFIG environment variable.

  An example to set TF_CONFIG is:

  ```Python
  os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {
        'worker': ["localhost:12345", "localhost:23456"]
    },
    'task': {'type': 'worker', 'index': 0}
  })
  ```

  However, sometimes the container orchestration framework will set TF_CONFIG
  for you. In this case, you can just create an instance without passing in any
  arguments. You can find an example here to let Kubernetes set TF_CONFIG for
  you: https://github.com/tensorflow/ecosystem/tree/master/kubernetes. Then you
  can use it with `tf.distribute.Strategy` as:

  ```Python
  # `TFConfigClusterResolver` is already the default one in the following
  # strategy.
  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
      cluster_resolver=TFConfigClusterResolver())
  ```

  TF_CONFIG is re-read from the environment on every access. Values that were
  not overridden (via the constructor or a setter) therefore reflect the
  environment at the moment they are requested.
  """

  def __init__(self,
               task_type=None,
               task_id=None,
               rpc_layer=None,
               environment=None):
    """Creates a new TFConfigClusterResolver.

    Leaving `task_type`, `task_id` or `rpc_layer` as None means the value is
    looked up in TF_CONFIG whenever it is requested.

    Args:
      task_type: (String, optional) Overrides the task type specified in the
        TF_CONFIG environment variable.
      task_id: (Integer, optional) Overrides the task index specified in the
        TF_CONFIG environment variable.
      rpc_layer: (String, optional) Overrides the rpc layer TensorFlow uses.
      environment: (String, optional) Overrides the environment TensorFlow
        operates in.
    """
    self._task_type = task_type
    self._task_id = task_id
    self._rpc_layer = rpc_layer
    self._environment = environment

  @property
  def task_type(self):
    """The task type of this process as a string, or None if unknown.

    Returns the override set through the constructor or setter, converted with
    `str`, when it is not None. Otherwise returns the `type` entry of the
    TF_CONFIG `task` section, or None when that entry is missing. Setting the
    property back to None re-enables lookup from TF_CONFIG.
    """
    if self._task_type is None:
      task_info = _get_value_in_tfconfig(_TASK_KEY, {})
      return str(task_info['type']) if 'type' in task_info else None
    else:
      return str(self._task_type)

  @property
  def task_id(self):
    """The task index of this process as an int, or None if unknown.

    Returns the override set through the constructor or setter, converted with
    `int`, when it is not None. Otherwise returns the `index` entry of the
    TF_CONFIG `task` section, or None when that entry is missing. Setting the
    property back to None re-enables lookup from TF_CONFIG.
    """
    if self._task_id is None:
      task_info = _get_value_in_tfconfig(_TASK_KEY, {})
      return int(task_info['index']) if 'index' in task_info else None
    else:
      return int(self._task_id)

  @task_type.setter
  def task_type(self, task_type):
    self._task_type = task_type

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def environment(self):
    """The `environment` value passed to the constructor (None by default)."""
    return self._environment

  @property
  def rpc_layer(self):
    """The RPC protocol to use, or None if none is configured.

    Returns the override set through the constructor or setter when it is not
    None; otherwise the top-level `rpc_layer` entry of TF_CONFIG (None if
    absent).
    """
    if self._rpc_layer is None:
      return _get_value_in_tfconfig(_RPC_LAYER_KEY)
    else:
      return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the accelerators available to a task in the cluster.

    Delegates to `ClusterResolver.num_accelerators`. A `task_type` or `task_id`
    of None is replaced with this resolver's own `task_type` / `task_id`.

    Args:
      task_type: (String, optional) The task type to query. Defaults to
        `self.task_type`.
      task_id: (Integer, optional) The task index to query. Defaults to
        `self.task_id`.
      config_proto: (Optional) Passed through unchanged to the base class.

    Returns:
      Whatever `ClusterResolver.num_accelerators` returns for the resolved
      task.
    """
    task_type = self.task_type if task_type is None else task_type
    task_id = self.task_id if task_id is None else task_id
    return super(TFConfigClusterResolver, self).num_accelerators(
        task_type, task_id, config_proto)

  def cluster_spec(self):
    """Returns a ClusterSpec based on the TF_CONFIG environment variable.

    Returns:
      A ClusterSpec with information from the TF_CONFIG environment variable,
      or an empty ClusterSpec when TF_CONFIG has no `cluster` entry.
    """
    tf_config = _load_tf_config()
    if 'cluster' not in tf_config:
      return ClusterSpec({})
    return ClusterSpec(tf_config['cluster'])

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the master address to use when creating a TensorFlow session.

    Note: this is only useful for TensorFlow 1.x.

    Resolution order:
      1. The TF_CONFIG `session_master` entry, returned verbatim if present.
      2. The empty string if the cluster spec has no jobs, or has exactly one
         job with exactly one task.
      3. Otherwise, the address of the task identified by `task_type` /
         `task_id`, falling back to this resolver's values. It is prefixed
         with `rpc_layer` (or this resolver's `rpc_layer`) when that is set.

    Args:
      task_type: (String, optional) Overrides and sets the task_type of the
        master.
      task_id: (Integer, optional) Overrides and sets the task id of the
        master.
      rpc_layer: (String, optional) Overrides and sets the protocol over which
        TensorFlow nodes communicate with each other.

    Returns:
      The address of the master.

    Raises:
      ValueError: If step 3 is reached and the resolved task type / task id do
        not identify a task in the cluster spec. This includes the case where
        neither the arguments nor the TF_CONFIG `task` section supply them.
        The error comes from `ClusterSpec.task_address`.
    """

    # If `session_master` is set, just use that.
    session_master = _get_value_in_tfconfig(_SESSION_MASTER_KEY)
    if session_master is not None:
      return session_master

    # Return an empty string if we are the only job in the ClusterSpec.
    cluster_spec = self.cluster_spec()
    if (not cluster_spec.jobs or
        (len(cluster_spec.jobs) == 1 and
         len(cluster_spec.job_tasks(cluster_spec.jobs[0])) == 1)):
      return ''

    # We try to auto-detect the task type and id, but use the user-supplied
    # ones where available.
    task_type = task_type if task_type is not None else self.task_type
    task_id = task_id if task_id is not None else self.task_id
    rpc_layer = rpc_layer if rpc_layer is not None else self.rpc_layer

    return format_master_url(cluster_spec.task_address(task_type, task_id),
                             rpc_layer)