Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/cluster_resolver/kubernetes_cluster_resolver.py: 21%

48 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Implementation of Cluster Resolvers for Kubernetes.""" 

16 

17from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver 

18from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url 

19from tensorflow.python.training import server_lib 

20from tensorflow.python.util.tf_export import tf_export 

21 

22 

@tf_export('distribute.cluster_resolver.KubernetesClusterResolver')
class KubernetesClusterResolver(ClusterResolver):
  """ClusterResolver for Kubernetes.

  This is an implementation of cluster resolvers for Kubernetes. Given a
  mapping of TensorFlow jobs to Kubernetes label selectors, it lists the pods
  matching each selector across all namespaces and returns a ClusterSpec whose
  task addresses are built from each pod's `status.host_ip` and the configured
  TensorFlow server port.

  Note: it cannot retrieve `task_type`, `task_id` or `rpc_layer` from
  Kubernetes. `task_type` and `task_id` start out as `None`, and `rpc_layer`
  is whatever was passed to the constructor. To use it with some distribution
  strategies like `tf.distribute.experimental.MultiWorkerMirroredStrategy`,
  you will need to specify `task_type` and `task_id` by setting these
  attributes.

  Usage example with tf.distribute.Strategy:

    ```Python
    # On worker 0
    cluster_resolver = KubernetesClusterResolver(
        {"worker": ["job-name=worker-cluster-a", "job-name=worker-cluster-b"]})
    cluster_resolver.task_type = "worker"
    cluster_resolver.task_id = 0
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)

    # On worker 1
    cluster_resolver = KubernetesClusterResolver(
        {"worker": ["job-name=worker-cluster-a", "job-name=worker-cluster-b"]})
    cluster_resolver.task_type = "worker"
    cluster_resolver.task_id = 1
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver)
    ```
  """

  def __init__(self,
               job_to_label_mapping=None,
               tf_server_port=8470,
               rpc_layer='grpc',
               override_client=None):
    """Initializes a new KubernetesClusterResolver.

    This initializes a new Kubernetes ClusterResolver. The ClusterResolver
    will attempt to talk to the Kubernetes master to retrieve all the instances
    of pods matching a label selector.

    Args:
      job_to_label_mapping: A mapping of TensorFlow jobs to label selectors.
        This allows users to specify many TensorFlow jobs in one Cluster
        Resolver, and each job can have pods belonging to different label
        selectors. For example, a sample mapping might be
        ```
        {'worker': ['job-name=worker-cluster-a', 'job-name=worker-cluster-b'],
         'ps': ['job-name=ps-1', 'job-name=ps-2']}
        ```
        When empty or `None`, `{'worker': ['job-name=tensorflow']}` is used.
      tf_server_port: The port the TensorFlow server is listening on.
      rpc_layer: (Optional) The RPC layer TensorFlow should use to communicate
        between tasks in Kubernetes. Defaults to 'grpc'.
      override_client: The Kubernetes client (usually automatically retrieved
        using `from kubernetes import client as k8sclient`). If you pass this
        in, you are responsible for setting Kubernetes credentials manually.

    Raises:
      ImportError: If the Kubernetes Python client is not installed and no
        `override_client` is passed in.
    """
    try:
      from kubernetes import config as k8sconfig  # pylint: disable=g-import-not-at-top

      # NOTE(review): this runs whenever the `kubernetes` package imports,
      # including when `override_client` is supplied. Only ImportError is
      # handled here, so any other error from loading the config propagates.
      k8sconfig.load_kube_config()
    except ImportError as import_error:
      # A missing client library is tolerated only when the caller supplied
      # their own client object.
      if not override_client:
        raise ImportError('The Kubernetes Python client must be installed '
                          'before using the Kubernetes Cluster Resolver. '
                          'To install the Kubernetes Python client, run '
                          '`pip install kubernetes` on your command line.'
                         ) from import_error

    if not job_to_label_mapping:
      job_to_label_mapping = {'worker': ['job-name=tensorflow']}

    self._job_to_label_mapping = job_to_label_mapping
    self._tf_server_port = tf_server_port
    self._override_client = override_client

    # Not discoverable from Kubernetes; callers set these after construction.
    self.task_type = None
    self.task_id = None
    self.rpc_layer = rpc_layer

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the master address to use when creating a session.

    You must have set the task_type and task_id object properties before
    calling this function, or pass in the `task_type` and `task_id`
    parameters when using this function. If you do both, the function parameters
    will override the object properties.

    Note: this is only useful for TensorFlow 1.x.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the master.
      task_id: (Optional) The index of the TensorFlow task of the master.
      rpc_layer: (Optional) The RPC protocol for the given cluster. Falls back
        to the `rpc_layer` attribute when not given.

    Returns:
      The name or URL of the session master, or an empty string when either
      the task type or the task id is unknown.
    """
    # Explicit arguments win over the instance attributes. The `is not None`
    # checks keep a task_id of 0 from being treated as "unset".
    task_type = task_type if task_type is not None else self.task_type
    task_id = task_id if task_id is not None else self.task_id

    if task_type is None or task_id is None:
      return ''

    return format_master_url(
        self.cluster_spec().task_address(task_type, task_id),
        rpc_layer or self.rpc_layer)

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest info from Kubernetes.

    We retrieve the information from the Kubernetes master every time this
    method is called.

    Returns:
      A ClusterSpec containing host information returned from Kubernetes.

    Raises:
      RuntimeError: If any of the pods returned by the master is not in the
        `Running` phase.
    """
    if self._override_client:
      client = self._override_client
    else:
      from kubernetes import config as k8sconfig  # pylint: disable=g-import-not-at-top
      from kubernetes import client as k8sclient  # pylint: disable=g-import-not-at-top

      k8sconfig.load_kube_config()
      client = k8sclient.CoreV1Api()

    cluster_map = {}

    for tf_job in self._job_to_label_mapping:
      # A job's task list is the concatenation of the pods found for each of
      # its selectors, in selector order.
      job_addresses = []
      for selector in self._job_to_label_mapping[tf_job]:
        job_addresses.extend(self._running_pod_addresses(client, selector))
      cluster_map[tf_job] = job_addresses

    return server_lib.ClusterSpec(cluster_map)

  def _running_pod_addresses(self, client, selector):
    """Returns `host_ip:port` addresses of the pods matching one selector.

    Args:
      client: The object used to query Kubernetes; it must provide
        `list_pod_for_all_namespaces(label_selector=...)`.
      selector: The label selector string to filter pods with.

    Returns:
      A list of `'<host_ip>:<tf_server_port>'` strings, ordered by pod name.

    Raises:
      RuntimeError: If a matching pod is not in the `Running` phase.
    """
    ret = client.list_pod_for_all_namespaces(label_selector=selector)
    addresses = []

    # Sort the list by the name to make sure it doesn't change call to call.
    for pod in sorted(ret.items, key=lambda x: x.metadata.name):
      if pod.status.phase != 'Running':
        raise RuntimeError('Pod "%s" is not running; phase: "%s"' %
                           (pod.metadata.name, pod.status.phase))
      addresses.append('%s:%s' % (pod.status.host_ip, self._tf_server_port))

    return addresses