Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/device_spec.py: 80%
165 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Class to represent a device."""
17from tensorflow.python.util.tf_export import tf_export
18from tensorflow.python import pywrap_tfe
# Device types accepted in the short "/<TYPE>:<index>" spec form (in addition
# to any pluggable device types discovered at parse time).
# EPU stands for TPU embedding for now. Subject to change in the future.
_VALID_DEVICE_TYPES = frozenset(("CPU", "GPU", "TPU", "CUSTOM", "EPU"))

# ==============================================================================
# == Global Implementation Details =============================================
# ==============================================================================
# Memoization tables shared by all DeviceSpec instances. Nothing in this module
# evicts entries from them.
# raw spec (str or None) -> (job, replica, task, device_type, device_index)
_STRING_TO_COMPONENTS_CACHE = dict()
# (job, replica, task, device_type, device_index) -> canonical spec string
_COMPONENTS_TO_STRING_CACHE = dict()
def _as_str_or_none(inp):
  """Converts `inp` with `str()`, passing `None` through unchanged."""
  if inp is None:
    return None
  return str(inp)
def _as_int_or_none(inp):
  """Converts `inp` with `int()`, passing `None` through unchanged.

  Any exception raised by `int()` (e.g. ValueError for a non-numeric string)
  propagates to the caller.
  """
  if inp is None:
    return None
  return int(inp)
def _as_device_str_or_none(device_type):
  """Normalizes a device type value to `str`, passing `None` through.

  For backwards compatibility only, the lowercase spellings "cpu" and "gpu"
  are accepted and upper-cased; every other value is converted with `str()`
  unchanged (no case folding).
  """
  if device_type not in ("cpu", "gpu"):
    return _as_str_or_none(device_type)
  return device_type.upper()
@tf_export("DeviceSpec", v1=[])
class DeviceSpecV2(object):
  """Represents a (possibly partial) specification for a TensorFlow device.

  `DeviceSpec`s are used throughout TensorFlow to describe where state is stored
  and computations occur. Using `DeviceSpec` allows you to parse device spec
  strings to verify their validity, merge them or compose them programmatically.

  Example:

  ```python
  # Place the operations on device "GPU:0" in the "ps" job.
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(device_spec.to_string()):
    # Both my_var and squared_var will be placed on /job:ps/device:GPU:0.
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  With eager execution disabled (by default in TensorFlow 1.x and by calling
  disable_eager_execution() in TensorFlow 2.x), the following syntax
  can be used:

  ```python
  tf.compat.v1.disable_eager_execution()

  # Same as previous
  device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  # No need for the .to_string() method.
  with tf.device(device_spec):
    my_var = tf.Variable(..., name="my_variable")
    squared_var = tf.square(my_var)
  ```

  If a `DeviceSpec` is partially specified, it will be merged with other
  `DeviceSpec`s according to the scope in which it is defined. `DeviceSpec`
  components defined in inner scopes take precedence over those defined in
  outer scopes.

  ```python
  gpu0_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0)
  with tf.device(DeviceSpec(job="train").to_string()):
    with tf.device(gpu0_spec.to_string()):
      # Nodes created here will be assigned to /job:ps/device:GPU:0.
      ...
    with tf.device(DeviceSpec(device_type="GPU", device_index=1).to_string()):
      # Nodes created here will be assigned to /job:train/device:GPU:1.
      ...
  ```

  A `DeviceSpec` consists of 5 components -- each of
  which is optionally specified:

  * Job: The job name.
  * Replica: The replica index.
  * Task: The task index.
  * Device type: The device type string (e.g. "CPU" or "GPU").
  * Device index: The device index.
  """

  # `_as_string` and `_hash` memoize `to_string()` / `__hash__()`.
  __slots__ = ("_job", "_replica", "_task", "_device_type", "_device_index",
               "_as_string", "_hash")

  def __init__(self,
               job=None,
               replica=None,
               task=None,
               device_type=None,
               device_index=None):
    """Create a new `DeviceSpec` object.

    Args:
      job: string. Optional job name.
      replica: int. Optional replica index.
      task: int. Optional task index.
      device_type: Optional device type string (e.g. "CPU" or "GPU"). The
        lowercase spellings "cpu" and "gpu" are upper-cased.
      device_index: int. Optional device index. If left unspecified, device
        represents 'any' device_index.

    Raises:
      ValueError: if `replica`, `task` or `device_index` cannot be converted
        with `int()`.
    """
    self._job = _as_str_or_none(job)
    self._replica = _as_int_or_none(replica)
    self._task = _as_int_or_none(task)
    self._device_type = _as_device_str_or_none(device_type)
    self._device_index = _as_int_or_none(device_index)
    # The string form is computed eagerly; the hash is derived from it.
    self._as_string = self._components_to_string(
        job=self._job,
        replica=self._replica,
        task=self._task,
        device_type=self._device_type,
        device_index=self._device_index)
    self._hash = hash(self.to_string())

  def to_string(self):
    """Return a string representation of this `DeviceSpec`.

    Returns:
      a string of the form
      /job:<name>/replica:<id>/task:<id>/device:<device_type>:<id>.
      Components that are `None` are omitted; an unspecified device index is
      rendered as `*`.
    """
    return self._as_string

  @classmethod
  def from_string(cls, spec):
    """Construct a `DeviceSpec` from a string.

    Args:
      spec: a string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id> or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>. Only one device
       type may appear in `spec`. All entries are optional.

    Returns:
      A DeviceSpec.

    Raises:
      ValueError: if the spec was not valid.
    """
    return cls(*cls._string_to_components(spec))

  def parse_from_string(self, spec):
    """Parse a `DeviceSpec` name into its components.

    **2.x behavior change**:

    In TensorFlow 1.x, this function mutates its own state and returns itself.
    In 2.x, DeviceSpecs are immutable, and this function will return a
    DeviceSpec which contains the spec.

    * Recommended:

      ```
      # my_spec and my_updated_spec are unrelated.
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_updated_spec = tf.DeviceSpec.from_string("/GPU:0")
      with tf.device(my_updated_spec):
        ...
      ```

    * Will work in 1.x and 2.x (though deprecated in 2.x):

      ```
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_updated_spec = my_spec.parse_from_string("/GPU:0")
      with tf.device(my_updated_spec):
        ...
      ```

    * Will NOT work in 2.x:

      ```
      my_spec = tf.DeviceSpec.from_string("/CPU:0")
      my_spec.parse_from_string("/GPU:0")  # <== Will not update my_spec
      with tf.device(my_spec):
        ...
      ```

    In general, `DeviceSpec.from_string` should completely replace
    `DeviceSpec.parse_from_string`, and `DeviceSpec.replace` should
    completely replace setting attributes directly.

    Args:
      spec: an optional string of the form
       /job:<name>/replica:<id>/task:<id>/device:CPU:<id> or
       /job:<name>/replica:<id>/task:<id>/device:GPU:<id>. Only one device
       type may appear in `spec`. All entries are optional.

    Returns:
      The `DeviceSpec` (a new instance in 2.x; see the behavior change above).

    Raises:
      ValueError: if the spec was not valid.
    """
    return self.from_string(spec)

  def make_merged_spec(self, dev):
    """Returns a new DeviceSpec which incorporates `dev`.

    When combining specs, `dev` will take precedence over the current spec:
    every component that is not `None` in `dev` overrides the corresponding
    component of `self`. So for instance:

    ```
    first_spec = tf.DeviceSpec(job="ps", device_type="CPU")
    second_spec = tf.DeviceSpec(device_type="GPU")
    combined_spec = first_spec.make_merged_spec(second_spec)
    ```

    is equivalent to:

    ```
    combined_spec = tf.DeviceSpec(job="ps", device_type="GPU")
    ```

    Args:
      dev: a `DeviceSpec`

    Returns:
      A new `DeviceSpec` (of the same class as `self`) which combines `self`
      and `dev`.
    """
    return self.__class__(*self._get_combined_properties(dev))

  def replace(self, **kwargs):
    """Convenience method for making a new DeviceSpec by overriding fields.

    For instance:

    ```
    my_spec = DeviceSpec(job="my_job", device_type="CPU")
    my_updated_spec = my_spec.replace(device_type="GPU")
    my_other_spec = my_spec.replace(device_type=None)
    ```

    Args:
      **kwargs: This method takes the same args as the DeviceSpec constructor
        (`job`, `replica`, `task`, `device_type`, `device_index`). Fields not
        given keep their current value.

    Returns:
      A DeviceSpec (of the same class as `self`) with the fields specified in
      kwargs overridden.

    Raises:
      TypeError: if `kwargs` contains a name the constructor does not accept.
    """
    init_kwargs = dict(
        job=self.job,
        replica=self.replica,
        task=self.task,
        device_type=self.device_type,
        device_index=self.device_index)

    # Explicitly provided kwargs take precedence.
    init_kwargs.update(kwargs)
    return self.__class__(**init_kwargs)

  @property
  def job(self):
    return self._job

  @property
  def replica(self):
    return self._replica

  @property
  def task(self):
    return self._task

  @property
  def device_type(self):
    return self._device_type

  @property
  def device_index(self):
    return self._device_index

  def _get_combined_properties(self, dev):
    """Combine the current DeviceSpec with another DeviceSpec.

    The combination gives priority to `dev`: each component of `dev` that is
    not `None` wins; otherwise the component of `self` is used.

    Args:
      dev: a `DeviceSpec`

    Returns:
      A tuple of (job, replica, task, device_type, device_index) which
      represents the combination of self and dev.
    """
    return (
        dev.job if dev.job is not None else self.job,
        dev.replica if dev.replica is not None else self.replica,
        dev.task if dev.task is not None else self.task,
        dev.device_type if dev.device_type is not None else self.device_type,
        dev.device_index if dev.device_index is not None else self.device_index,
    )

  @staticmethod
  def _get_valid_device_types():
    """Returns the set of device types accepted in the short spec form.

    This is `_VALID_DEVICE_TYPES` plus the type field of every device reported
    by `pywrap_tfe.TF_ListPluggablePhysicalDevices()`. Each reported entry is
    decoded from bytes and its second ':'-separated field is taken as the
    type (presumably names look like b"/physical_device:<TYPE>:<id>" -- TODO
    confirm against the C API).
    """
    pluggable_types = {
        device.decode().split(":")[1]
        for device in pywrap_tfe.TF_ListPluggablePhysicalDevices()
    }
    return pluggable_types | _VALID_DEVICE_TYPES

  @staticmethod
  def _string_to_components(spec=None):
    """Stateless portion of device spec string parsing.

    Results are memoized in `_STRING_TO_COMPONENTS_CACHE`, keyed on the raw
    `spec` argument (including `None`). A spec that fails to parse raises
    before anything is cached.

    Args:
      spec: An optional string specifying a device specification.

    Returns:
      The parsed components of `spec` as a tuple
      `(job, replica, task, device_type, device_index)`. `replica` and `task`
      are still raw strings here, so the result must go through the attribute
      setters / constructor of DeviceSpec, and should therefore NOT be used
      directly.

    Raises:
      ValueError: if `spec` names more than one device type, contains an
        unknown attribute, or has a non-integer device index.
    """
    cached_result = _STRING_TO_COMPONENTS_CACHE.get(spec)
    if cached_result is not None:
      return cached_result

    raw_spec = spec  # keep a copy of the original to update the cache
    job, replica, task, device_type, device_index = None, None, None, None, None

    spec = spec or ""
    splits = [x.split(":") for x in spec.split("/")]
    valid_device_types = DeviceSpecV2._get_valid_device_types()
    for y in splits:
      # `str.split` always returns at least one element, so `y` is never empty.
      ly = len(y)
      # NOTE(taylorrobie): these will go through setters later.
      if ly == 2 and y[0] == "job":
        job = y[1]
      elif ly == 2 and y[0] == "replica":
        replica = y[1]
      elif ly == 2 and y[0] == "task":
        task = y[1]
      elif ly in (1, 2) and y[0].upper() in valid_device_types:
        # Short form, e.g. "/GPU:0", "/cpu" or "/GPU:*".
        if device_type is not None:
          raise ValueError(f"Multiple device types are not allowed "
                           f"while parsing the device spec: {spec}.")
        device_type = y[0].upper()
        if ly == 2 and y[1] != "*":
          device_index = int(y[1])
      elif ly == 3 and y[0] == "device":
        # Long form, e.g. "/device:GPU:0" or "/device:GPU:*".
        if device_type is not None:
          raise ValueError(f"Multiple device types are not allowed "
                           f"while parsing the device spec: {spec}.")
        device_type = y[1]
        if y[2] != "*":
          device_index = int(y[2])
      elif y[0] != "":  # pylint: disable=g-explicit-bool-comparison
        # Segments whose first field is empty (such as the one produced by a
        # leading "/") are silently skipped; anything else is an error.
        raise ValueError(f"Unknown attribute '{y[0]}' is encountered "
                         f"while parsing the device spec: '{spec}'.")

    output = (job, replica, task, device_type, device_index)
    _STRING_TO_COMPONENTS_CACHE[raw_spec] = output
    return output

  @staticmethod
  def _components_to_string(job, replica, task, device_type, device_index):
    """Stateless portion of `to_string` (separated to allow caching).

    Results are memoized in `_COMPONENTS_TO_STRING_CACHE`, keyed on the tuple
    of components. `job` must already be a `str` (it is concatenated
    directly); `None` components are omitted from the output.
    """
    key = (job, replica, task, device_type, device_index)
    cached_result = _COMPONENTS_TO_STRING_CACHE.get(key)
    if cached_result is not None:
      return cached_result

    output = []
    if job is not None:
      output.append("/job:" + job)
    if replica is not None:
      output.append("/replica:" + str(replica))
    if task is not None:
      output.append("/task:" + str(task))
    if device_type is not None:
      device_index_string = "*"
      if device_index is not None:
        # Unlike the others, device_index is stored as an int.
        device_index_string = str(device_index)
      output.append("/device:%s:%s" % (device_type, device_index_string))

    output = "".join(output)
    _COMPONENTS_TO_STRING_CACHE[key] = output
    return output

  def __eq__(self, other):
    """Checks whether `other` is an equal DeviceSpec.

    Args:
      other: Another DeviceSpec

    Returns:
      `True` if `other` is an instance of this spec's class (or a subclass of
      it) and has the same string representation as the current instance.
      `False` otherwise.
    """
    return (isinstance(other, self.__class__) and
            self.to_string() == other.to_string())

  def __hash__(self):
    # Precomputed from `to_string()` in `__init__`.
    return self._hash

  def __repr__(self):
    return (
        f"<DeviceSpec(job={self.job}, replica={self.replica}, task={self.task}, "
        f"device_type={self.device_type}, device_index={self.device_index})>")
@tf_export(v1=["DeviceSpec"])  # pylint: disable=missing-docstring
class DeviceSpecV1(DeviceSpecV2):
  # Mutable TF1 flavour of DeviceSpec: every component has a setter, and the
  # memoized string form and hash are dropped on mutation and rebuilt lazily.
  __doc__ = DeviceSpecV2.__doc__
  __slots__ = DeviceSpecV2.__slots__

  def _invalidate_cached_string(self):
    """Forgets the memoized string and hash so they are recomputed on demand."""
    self._as_string = None
    self._hash = None

  @DeviceSpecV2.job.setter
  def job(self, job):
    self._job = _as_str_or_none(job)
    self._invalidate_cached_string()

  @DeviceSpecV2.replica.setter
  def replica(self, replica):
    self._replica = _as_int_or_none(replica)
    self._invalidate_cached_string()

  @DeviceSpecV2.task.setter
  def task(self, task):
    self._task = _as_int_or_none(task)
    self._invalidate_cached_string()

  @DeviceSpecV2.device_type.setter
  def device_type(self, device_type):
    self._device_type = _as_device_str_or_none(device_type)
    self._invalidate_cached_string()

  @DeviceSpecV2.device_index.setter
  def device_index(self, device_index):
    self._device_index = _as_int_or_none(device_index)
    self._invalidate_cached_string()

  def __hash__(self):
    # Lazily recomputed after a setter has cleared the cached value.
    cached_hash = self._hash
    if cached_hash is None:
      cached_hash = self._hash = hash(self.to_string())
    return cached_hash

  def to_string(self):
    # Docstring is copied from DeviceSpecV2.to_string at the end of the class.
    as_string = self._as_string
    if as_string is None:
      as_string = self._as_string = self._components_to_string(
          job=self.job,
          replica=self.replica,
          task=self.task,
          device_type=self.device_type,
          device_index=self.device_index)
    return as_string

  def parse_from_string(self, spec):
    # TF1 semantics: overwrite this instance's components in place (through
    # the setters above) and return `self`.
    components = self._string_to_components(spec)
    (self.job, self.replica, self.task, self.device_type,
     self.device_index) = components
    return self

  def merge_from(self, dev):
    """Merge the properties of "dev" into this `DeviceSpec`.

    Components of `dev` that are not `None` overwrite the corresponding
    components of this spec in place.

    Note: Will be removed in TensorFlow 2.x since DeviceSpecs will become
    immutable.

    Args:
      dev: a `DeviceSpec`.
    """
    merged = self._get_combined_properties(dev)
    (self.job, self.replica, self.task, self.device_type,
     self.device_index) = merged

  # Use parent class docstrings for public methods.
  to_string.__doc__ = DeviceSpecV2.to_string.__doc__
  parse_from_string.__doc__ = DeviceSpecV2.parse_from_string.__doc__