# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A variable which packs a list of variables distributed across devices."""

from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops


class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
  """A variable which packs multiple variables distributed across devices.

  It's only supported when eager execution is enabled.
  For op-by-op execution, use an unpacked handle on the current device; for
  function execution, use the packed handle to reduce the overhead of function
  calls.
  """
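
  # A minimal usage sketch (an illustration only; the names `v0`, `v1`, and
  # `new_value` are hypothetical and assume two per-replica ResourceVariables
  # created eagerly on different devices):
  #
  #   packed = PackedDistributedVariable([v0, v1])
  #   with ops.device(v1.device):
  #     packed.assign(new_value)  # routes to the component variable on v1's device
  #
  # Inside a tf.function, `packed.handle` falls back to the single packed
  # handle, so the traced function captures one resource instead of one
  # resource per device.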

  def __init__(self, distributed_variables=None, name=None, **unused_kwargs):
    """Packs a list of variables which are distributed across devices.

    Args:
      distributed_variables: A list of distributed Variables to pack.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
    """
    if not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "PackedDistributedVariable should be created in eager mode.")
    if not distributed_variables:
      raise ValueError("Expect a non-empty list of variables to pack.")
    for i, var in enumerate(distributed_variables):
      if not resource_variable_ops.is_resource_variable(var):
        raise ValueError("Expect a list of ResourceVariables to pack, "
                         "but the %d-th variable is %s" % (i, type(var)))

    self._distributed_variables = distributed_variables
    self._devices = [v.device for v in distributed_variables]
    with ops.init_scope():
      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
        handle = ops.pack_eager_tensors(
            [var.handle for var in distributed_variables])
        handle_name = ops.name_from_scope_name(name)
        unique_id = "%s_%d" % (handle_name, ops.uid())
        super(PackedDistributedVariable, self).__init__(
            trainable=distributed_variables[0].trainable,
            shape=distributed_variables[0].shape,
            dtype=distributed_variables[0].dtype,
            handle=handle,
            synchronization=distributed_variables[0].synchronization,
            constraint=distributed_variables[0].constraint,
            aggregation=distributed_variables[0].aggregation,
            distribute_strategy=distributed_variables[0]._distribute_strategy,  # pylint: disable=protected-access
            name=name,
            unique_id=unique_id,
            handle_name=handle_name,
            graph_element=None,
            initial_value=None,
            initializer_op=None,
            is_initialized_op=None,
            cached_value=None,
            caching_device=None,
            is_distributed_variables=True)

  @property
  def devices(self):
    return self._devices

  def on_device(self, device):
    return PackedVarAndDevice(self, device)

  def get_var_on_device(self, device):
    for i, d in enumerate(self._devices):
      if d == device:
        return self._distributed_variables[i]
    raise ValueError("Device %s is not found" % device)

  def get_var_on_current_device(self):
    current_device = device_util.canonicalize(device_util.current())
    return self.get_var_on_device(current_device)

  def initial_value(self, device):
    """Returns the Tensor used as the initial value for the variable."""
    return self.get_var_on_device(device).initial_value

  @property
  def handle(self):
    if context.executing_eagerly():
      return self.get_var_on_current_device().handle
    else:
      return self._handle

  @property
  def packed_handle(self):
    return self._handle

  def _read_variable_op(self):
    if context.executing_eagerly():
      return self.get_var_on_current_device().value()
    else:
      return super(PackedDistributedVariable, self)._read_variable_op()

  def value(self):
    return self._read_variable_op()

  def is_initialized(self, name=None):
    if context.executing_eagerly():
      result = self._distributed_variables[0].is_initialized()
      for v in self._distributed_variables[1:-1]:
        result = math_ops.logical_and(result, v.is_initialized())
      result = math_ops.logical_and(
          result, self._distributed_variables[-1].is_initialized(), name=name)
    else:
      with ops.device(self._devices[0]):
        result = super(PackedDistributedVariable, self).is_initialized(name)
      for d in self._devices[1:-1]:
        with ops.device(d):
          initialized = super(PackedDistributedVariable,
                              self).is_initialized(name)
        result = math_ops.logical_and(result, initialized)
      with ops.device(self._devices[-1]):
        initialized = super(PackedDistributedVariable,
                            self).is_initialized(name)
        result = math_ops.logical_and(result, initialized, name=name)
    return result

  def _update(self, update_fn, value, **kwargs):
    if context.executing_eagerly():
      return update_fn(self.get_var_on_current_device(), value, **kwargs)
    else:
      return update_fn(super(PackedDistributedVariable, self), value, **kwargs)
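
  # The assign*/scatter* methods below all funnel through `_update`. A rough
  # sketch of the dispatch (names here are hypothetical): in eager mode,
  #
  #   with ops.device(packed.devices[1]):
  #     packed.assign_add(delta)
  #
  # updates only the component variable on devices[1], whereas the same call
  # traced inside a tf.function is applied to the packed handle via the
  # BaseResourceVariable implementation.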

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._update(
        update_fn=assign_sub_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._update(
        update_fn=assign_add_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._update(
        update_fn=assign_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
    return self._update(
        update_fn=scatter_sub_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
    return self._update(
        update_fn=scatter_add_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
    return self._update(
        update_fn=scatter_mul_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
    return self._update(
        update_fn=scatter_div_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
    return self._update(
        update_fn=scatter_min_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
    return self._update(
        update_fn=scatter_max_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
    return self._update(
        update_fn=scatter_update_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    if context.executing_eagerly():
      return self.get_var_on_current_device()._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)
    else:
      return super(PackedDistributedVariable, self)._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)


class PackedVarAndDevice(object):
  """Holds a packed distributed variable and a device."""
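
  # A minimal sketch of how this wrapper is typically obtained (the name
  # `packed_var` and the device index are hypothetical):
  #
  #   pvd = packed_var.on_device(packed_var.devices[0])
  #   pvd.read_value()    # read pinned to that device
  #   pvd.assign_add(1.)  # update pinned to that device
  #
  # Every attribute access and update below is wrapped in `ops.device` so the
  # underlying packed variable is always used from the stored device.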

  def __init__(self, var, device):
    self._var = var
    self._device = device

  def __getattr__(self, name):
    # Exceptions raised inside the context manager can cause a reference
    # cycle.[1] The cycle involves the current frame, which holds a reference
    # to the outer frame. TensorFlow, e.g. iterators, relies on object
    # finalizers to clean up resources. Such references prevent the resource
    # from being deleted and can cause leaks and errors. One corner case is
    # that iterators are kept alive and the garbage collector happens to run
    # after auto control dependencies; this causes the deletion to lose the
    # control dependencies to operations that use such resources.
    #
    # Catching and re-raising the exception seems to work around the issue.
    #
    # [1] https://bugs.python.org/issue43533
    try:
      with ops.device(self._device):
        return getattr(self._var, name)
    except:  # pylint: disable=try-except-raise
      raise

  def var(self):
    return self._var

  def value(self):
    with ops.device(self._device):
      return self._var.value()

  def read_value(self):
    with ops.device(self._device):
      return self._var.read_value()

  @property
  def initial_value(self):
    return self._var.initial_value(self._device)

  def initialized_value(self):
    with ops.device(self._device):
      return self._var.initialized_value()

  @property
  def device(self):
    return self._device

  @property
  def handle(self):
    with ops.device(self._device):
      return self._var.handle

  def on_device_handle(self):
    with ops.device(self._device):
      return self._var.get_var_on_current_device().handle

  @property
  def op(self):
    with ops.device(self._device):
      return self._var.op

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_update(sparse_delta, use_locking, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    with ops.device(self._device):
      return self._var._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)

  def _as_graph_element(self):
    return self._var._as_graph_element()  # pylint: disable=protected-access


def _tensor_conversion_packed_var_and_device(var,
                                             dtype=None,
                                             name=None,
                                             as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


tensor_conversion_registry.register_tensor_conversion_function(
    PackedVarAndDevice, _tensor_conversion_packed_var_and_device)
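
# Registering the conversion function above is what lets a PackedVarAndDevice
# be passed directly to ops that expect tensors. A rough sketch (the name
# `pvd` is hypothetical, and `math_ops.add` is just one example of an op that
# triggers the conversion):
#
#   pvd = packed_var.on_device(packed_var.devices[0])
#   total = math_ops.add(pvd, 1.0)  # converts pvd via _dense_var_to_tensor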