Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/device_util.py: 33%
54 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Device-related support functions."""
18from tensorflow.python.eager import context
19from tensorflow.python.framework import config
20from tensorflow.python.framework import device as tf_device
21from tensorflow.python.framework import ops
def canonicalize(d, default=None):
  """Canonicalize device string.

  If d has missing components, the rest would be deduced from the `default`
  argument or from '/replica:0/task:0/device:CPU:0'. For example:
    If d = '/cpu:0', default='/job:worker/task:1', it returns
      '/job:worker/replica:0/task:1/device:CPU:0'.
    If d = '/cpu:0', default='/job:worker', it returns
      '/job:worker/replica:0/task:0/device:CPU:0'.
    If d = '/gpu:0', default=None, it returns
      '/replica:0/task:0/device:GPU:0'.

  Note: when `ops.executing_eagerly_outside_functions()` is true, the fields
  of the first logical CPU device are merged in if that device has a job
  (e.g. in a multi-worker setup); otherwise the job defaults to "localhost".
  Precedence, lowest to highest: built-in defaults, eager host defaults,
  `default`, `d`.

  Args:
    d: a device string or tf.config.LogicalDevice.
    default: a string for default device if d doesn't have all components.
      Falsy values (None, "") are ignored.

  Returns:
    a canonicalized device string.

  Raises:
    AssertionError: if the device type in `d` is not all upper-case. This is
      an explicit `raise` rather than an `assert` statement so the check is
      not stripped when Python runs with `-O`.
  """
  if isinstance(d, context.LogicalDevice):
    d = tf_device.DeviceSpec.from_string(d.name)
  else:
    d = tf_device.DeviceSpec.from_string(d)

  # Validate explicitly: an `assert` would vanish under `python -O`.
  if d.device_type is not None and d.device_type != d.device_type.upper():
    raise AssertionError(
        "Device type '%s' must be all-caps." % (d.device_type,))

  # Fill in missing device fields using defaults.
  result = tf_device.DeviceSpec(
      replica=0, task=0, device_type="CPU", device_index=0)
  if ops.executing_eagerly_outside_functions():
    # Try to deduce job, replica and task in case it's in a multi worker setup.
    # TODO(b/151452748): Using list_logical_devices is not always safe since it
    # may return remote devices as well, but we're already doing this elsewhere.
    host_cpu = tf_device.DeviceSpec.from_string(
        config.list_logical_devices("CPU")[0].name)
    if host_cpu.job:
      result = result.make_merged_spec(host_cpu)
    else:
      # The default job is localhost if eager execution is enabled.
      result = result.replace(job="localhost")
  if default:
    # Overrides any defaults with values from the default device if given.
    result = result.make_merged_spec(
        tf_device.DeviceSpec.from_string(default))

  # Apply `d` last, so that its values take precedence over the defaults.
  result = result.make_merged_spec(d)
  return result.to_string()
def canonicalize_without_job_and_task(d):
  """Partially canonicalize device string.

  This returns the canonicalized form of `d` (as produced by `canonicalize`
  with no `default`), with the job and task components removed and the
  replica forced to 0. This is most useful for parameter server strategy,
  where the device strings are generated on the chief but executed on
  workers.

  For example:
    If d = '/cpu:0', it returns '/replica:0/device:CPU:0'.
    If d = '/job:worker/task:1/device:CPU:0', it returns
      '/replica:0/device:CPU:0'.
    If d = '/gpu:0', it returns '/replica:0/device:GPU:0'.

  Because job and task are dropped, any job that `canonicalize` fills in
  (e.g. "localhost" when executing eagerly) does not appear in the result.

  Args:
    d: a device string or tf.config.LogicalDevice.

  Returns:
    a partially canonicalized device string.
  """
  canonicalized_device = canonicalize(d)
  spec = tf_device.DeviceSpec.from_string(canonicalized_device)
  # Strip location information so the string is portable across workers.
  spec = spec.replace(job=None, task=None, replica=0)
  return spec.to_string()
def resolve(d):
  """Canonicalizes `d`, filling missing fields from the current device.

  Args:
    d: a device string or tf.config.LogicalDevice.

  Returns:
    The result of `canonicalize(d, default=current())`.
  """
  current_device = current()
  return canonicalize(d, default=current_device)
110class _FakeNodeDef(object):
111 """A fake NodeDef for _FakeOperation."""
113 __slots__ = ["op", "name"]
115 def __init__(self):
116 self.op = ""
117 self.name = ""
120class _FakeOperation(object):
121 """A fake Operation object to pass to device functions."""
123 def __init__(self):
124 self.device = ""
125 self.type = ""
126 self.name = ""
127 self.node_def = _FakeNodeDef()
129 def _set_device(self, device):
130 self.device = ops._device_string(device) # pylint: disable=protected-access
132 def _set_device_from_string(self, device_str):
133 self.device = device_str
def current():
  """Returns the current device as a string, without canonicalizing it.

  When `ops.executing_eagerly_outside_functions()` is true, this is the eager
  context's `device_name`. Otherwise it is the device that the default graph's
  device functions assign to a placeholder `_FakeOperation`.
  """
  # TODO(josh11b): Work out how this function interacts with ops.colocate_with.
  if not ops.executing_eagerly_outside_functions():
    probe = _FakeOperation()
    ops.get_default_graph()._apply_device_functions(probe)  # pylint: disable=protected-access
    return probe.device
  return context.context().device_name
def get_host_for_device(device):
  """Returns the host CPU device string corresponding to `device`.

  The job, replica and task of `device` are preserved; the device component
  is replaced with CPU:0.

  Args:
    device: a device string.

  Returns:
    The host device string.
  """
  parsed = tf_device.DeviceSpec.from_string(device)
  host = tf_device.DeviceSpec(
      job=parsed.job,
      replica=parsed.replica,
      task=parsed.task,
      device_type="CPU",
      device_index=0,
  )
  return host.to_string()
def local_devices_from_num_gpus(num_gpus):
  """Returns device strings for local GPUs or CPU.

  Args:
    num_gpus: number of local GPUs to list.

  Returns:
    A tuple of "/device:GPU:<i>" strings for i in range(num_gpus). If that
    range is empty (num_gpus <= 0), returns ("/device:CPU:0",).
  """
  gpu_devices = tuple("/device:GPU:%d" % index for index in range(num_gpus))
  if not gpu_devices:
    return ("/device:CPU:0",)
  return gpu_devices