Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/device_setter.py: 25%
60 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Device function for replicated training."""
16from tensorflow.core.framework import node_def_pb2
17from tensorflow.python.framework import device as pydev
18from tensorflow.python.platform import tf_logging as logging
19from tensorflow.python.training import server_lib
20from tensorflow.python.util.tf_export import tf_export
# Operation types that tf.estimator.Estimator places on parameter servers by
# default; this set should cover almost all use cases.
STANDARD_PS_OPS = (
    "Variable",
    "VariableV2",
    "AutoReloadVariable",
    "MutableHashTable",
    "MutableHashTableV2",
    "MutableHashTableOfTensors",
    "MutableHashTableOfTensorsV2",
    "MutableDenseHashTable",
    "MutableDenseHashTableV2",
    "VarHandleOp",
    "BoostedTreesEnsembleResourceHandleOp",
    "BoostedTreesQuantileStreamResourceHandleOp",
    "ResourceConditionalAccumulator",
    "DecisionTreeResource",
)
34class _RoundRobinStrategy:
35 """Returns the next ps task index for placement in round-robin order.
37 This class is not to be used directly by users. See instead
38 `replica_device_setter()` below.
39 """
41 def __init__(self, num_tasks):
42 """Create a new `_RoundRobinStrategy`.
44 Args:
45 num_tasks: Number of ps tasks to cycle among.
46 """
47 self._num_tasks = num_tasks
48 self._next_task = 0
50 def __call__(self, unused_op):
51 """Choose a ps task index for the given `Operation`.
53 Args:
54 unused_op: An `Operation` to be placed on ps.
56 Returns:
57 The next ps task index to use for the `Operation`. Returns the next
58 index, in the range `[offset, offset + num_tasks)`.
59 """
60 task = self._next_task
61 self._next_task = (self._next_task + 1) % self._num_tasks
62 return task
65class _ReplicaDeviceChooser:
66 """Class to choose devices for Ops in a replicated training setup.
68 This class is not to be used directly by users. See instead
69 `replica_device_setter()` below.
70 """
72 def __init__(self, ps_tasks, ps_device, worker_device, merge_devices, ps_ops,
73 ps_strategy):
74 """Create a new `_ReplicaDeviceChooser`.
76 Args:
77 ps_tasks: Number of tasks in the `ps` job.
78 ps_device: String. Name of the `ps` job.
79 worker_device: String. Name of the `worker` job.
80 merge_devices: Boolean. Set to True to allow merging of device specs.
81 ps_ops: List of strings representing `Operation` types that need to be
82 placed on `ps` devices.
83 ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by
84 `ps_ops`), that takes the `Operation` and returns the ps task index to
85 use.
86 """
87 self._ps_tasks = ps_tasks
88 self._ps_device = ps_device
89 self._worker_device = worker_device
90 self._merge_devices = merge_devices
91 self._ps_ops = ps_ops
92 self._ps_strategy = ps_strategy
94 def device_function(self, op):
95 """Choose a device for `op`.
97 Args:
98 op: an `Operation`.
100 Returns:
101 The device to use for the `Operation`.
102 """
103 # If we don't return early here, either merge_devices is True, or op.device
104 # is empty (in which case merging is a no-op). So we can always merge below.
105 if not self._merge_devices and op.device:
106 return op.device
108 current_device = pydev.DeviceSpec.from_string(op.device or "")
110 # The ps_device will be used for specified ops (ps_ops) whenever it is
111 # present and ps_tasks is non-zero. However, its task number will only be
112 # set (using ps_strategy) if there is a job field in ps_device that won't be
113 # changed by the job field (if present) in current_device.
114 node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
115 if self._ps_tasks and self._ps_device and node_def.op in self._ps_ops:
116 ps_device = pydev.DeviceSpec.from_string(self._ps_device)
118 current_job, ps_job = current_device.job, ps_device.job
119 if ps_job and (not current_job or current_job == ps_job):
120 ps_device = ps_device.replace(task=self._ps_strategy(op))
122 ps_device = ps_device.make_merged_spec(current_device)
123 return ps_device.to_string()
125 worker_device = pydev.DeviceSpec.from_string(self._worker_device or "")
126 worker_device = worker_device.make_merged_spec(current_device)
127 return worker_device.to_string()
@tf_export(v1=["train.replica_device_setter"])
def replica_device_setter(ps_tasks=0,
                          ps_device="/job:ps",
                          worker_device="/job:worker",
                          merge_devices=True,
                          cluster=None,
                          ps_ops=None,
                          ps_strategy=None):
  """Return a `device function` to use when building a Graph for replicas.

  Device Functions are used in `with tf.device(device_function):` statement to
  automatically assign devices to `Operation` objects as they are constructed.
  Device constraints are added from the inner-most context first, working
  outwards, and merging only fills in fields that a more inner context left
  unset.  The fields currently merged are (job, task, cpu/gpu).

  If `cluster` is `None` and `ps_tasks` is 0, the returned function is a no-op.
  Otherwise, the value of `ps_tasks` is derived from `cluster`.

  By default, only Variable ops are placed on ps tasks, and the placement
  strategy is round-robin over all ps tasks. A custom `ps_strategy` may be used
  to do more intelligent placement, such as
  `tf.contrib.training.GreedyLoadBalancingStrategy`.

  For example,

  ```python
  # To build a cluster with two ps jobs on hosts ps0 and ps1, and 3 worker
  # jobs on hosts worker0, worker1 and worker2.
  cluster_spec = {
      "ps": ["ps0:2222", "ps1:2222"],
      "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
  with
  tf.compat.v1.device(tf.compat.v1.train.replica_device_setter(cluster=cluster_spec)):
    # Build your graph
    v1 = tf.Variable(...)  # assigned to /job:ps/task:0
    v2 = tf.Variable(...)  # assigned to /job:ps/task:1
    v3 = tf.Variable(...)  # assigned to /job:ps/task:0
  # Run compute
  ```

  Args:
    ps_tasks: Number of tasks in the `ps` job.  Ignored if `cluster` is
      provided.
    ps_device: String.  Device of the `ps` job.  If empty no `ps` job is used.
      Defaults to `ps`.
    worker_device: String.  Device of the `worker` job.  If empty no `worker`
      job is used.
    merge_devices: `Boolean`. If `True`, merges device specifications rather
      than overriding them; a device is only set where the existing constraint
      leaves the field unset.
    cluster: `ClusterDef` proto or `ClusterSpec`.
    ps_ops: List of strings representing `Operation` types that need to be
      placed on `ps` devices.  If `None`, defaults to `STANDARD_PS_OPS`.
    ps_strategy: A callable invoked for every ps `Operation` (i.e. matched by
      `ps_ops`), that takes the `Operation` and returns the ps task index to
      use.  If `None`, defaults to a round-robin strategy across all `ps`
      devices.

  Returns:
    A function to pass to `tf.device()`.

  Raises:
    TypeError if `cluster` is not a dictionary or `ClusterDef` protocol buffer,
    or if `ps_strategy` is provided but not a callable.
  """
  if cluster is not None:
    # Normalize to a ClusterSpec first, then derive the number of ps tasks
    # from the entry named by ps_device (e.g. "/job:ps" -> "ps").
    spec = (cluster if isinstance(cluster, server_lib.ClusterSpec)
            else server_lib.ClusterSpec(cluster))
    cluster_dict = spec.as_dict()
    ps_job_name = pydev.DeviceSpec.from_string(ps_device).job
    ps_hosts = cluster_dict.get(ps_job_name)
    if ps_hosts is None:
      return None
    ps_tasks = len(ps_hosts)

  if not ps_tasks:
    return None

  if ps_ops is None:
    # TODO(sherrym): Variables in the LOCAL_VARIABLES collection should not be
    # placed in the parameter server.
    ps_ops = list(STANDARD_PS_OPS)

  if not merge_devices:
    logging.warning(
        "DEPRECATION: It is recommended to set merge_devices=true in "
        "replica_device_setter")
  if ps_strategy is None:
    ps_strategy = _RoundRobinStrategy(ps_tasks)
  if not callable(ps_strategy):
    raise TypeError("ps_strategy must be callable")

  chooser = _ReplicaDeviceChooser(ps_tasks, ps_device, worker_device,
                                  merge_devices, ps_ops, ps_strategy)
  return chooser.device_function