# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A variable which packs a list of variables distributed across devices."""

from tensorflow.python.distribute import device_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops


class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
  """A variable which packs multiple variables distributed across devices.

  It's only supported when eager execution is enabled.
  For op-by-op execution, use an unpacked handle on the current device; for
  function execution, use the packed handle to reduce the overhead of function
  calls.
  """
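
  # A minimal usage sketch (illustrative only; the device strings and initial
  # values below are hypothetical, not part of this module):
  #
  #   with ops.device("/cpu:0"):
  #     v0 = resource_variable_ops.ResourceVariable(1.0)
  #   with ops.device("/cpu:1"):
  #     v1 = resource_variable_ops.ResourceVariable(2.0)
  #   packed = PackedDistributedVariable([v0, v1])
  #   packed.devices                      # canonical device strings of v0, v1
  #   packed.on_device(packed.devices[0]).value()  # reads v0 on its device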

  def __init__(self, distributed_variables=None, name=None, **unused_kwargs):
    """Packs a list of variables which are distributed across devices.

    Args:
      distributed_variables: A list of distributed Variables to pack.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
    """
    if not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "PackedDistributedVariable should be created in eager mode.")
    if not distributed_variables:
      raise ValueError("Expect a non-empty list of variables to pack.")
    for i, var in enumerate(distributed_variables):
      if not resource_variable_ops.is_resource_variable(var):
        raise ValueError("Expect a list of ResourceVariables to pack, "
                         "but the %d-th variable is %s" % (i, type(var)))

    self._distributed_variables = distributed_variables
    self._devices = [v.device for v in distributed_variables]
    with ops.init_scope():
      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
        handle = ops.pack_eager_tensors(
            [var.handle for var in distributed_variables])
        handle_name = ops.name_from_scope_name(name)
        unique_id = "%s_%d" % (handle_name, ops.uid())
        super(PackedDistributedVariable, self).__init__(
            trainable=distributed_variables[0].trainable,
            shape=distributed_variables[0].shape,
            dtype=distributed_variables[0].dtype,
            handle=handle,
            synchronization=distributed_variables[0].synchronization,
            constraint=distributed_variables[0].constraint,
            aggregation=distributed_variables[0].aggregation,
            distribute_strategy=distributed_variables[0]._distribute_strategy,  # pylint: disable=protected-access
            name=name,
            unique_id=unique_id,
            handle_name=handle_name,
            graph_element=None,
            initial_value=None,
            initializer_op=None,
            is_initialized_op=None,
            cached_value=None,
            caching_device=None,
            is_distributed_variables=True)

  @property
  def devices(self):
    return self._devices

  def on_device(self, device):
    return PackedVarAndDevice(self, device)

  def get_var_on_device(self, device):
    for i, d in enumerate(self._devices):
      if d == device:
        return self._distributed_variables[i]
    raise ValueError("Device %s is not found" % device)

  def get_var_on_current_device(self):
    current_device = device_util.canonicalize(device_util.current())
    return self.get_var_on_device(current_device)

  def initial_value(self, device):
    """Returns the Tensor used as the initial value for the variable."""
    return self.get_var_on_device(device).initial_value
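
  # In eager (op-by-op) execution, `handle` resolves to the unpacked handle of
  # the component variable on the current device; under function tracing it
  # returns the packed handle, so one function call can serve all devices.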

  @property
  def handle(self):
    if context.executing_eagerly():
      return self.get_var_on_current_device().handle
    else:
      return self._handle

  @property
  def packed_handle(self):
    return self._handle

  def _read_variable_op(self):
    if context.executing_eagerly():
      return self.get_var_on_current_device().value()
    else:
      return super(PackedDistributedVariable, self)._read_variable_op()

  def value(self):
    return self._read_variable_op()
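
  # The packed variable is initialized only if every component variable is
  # initialized, so the per-device checks are chained with logical AND; the
  # final AND carries `name` so the returned op is the named one.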

  def is_initialized(self, name=None):
    if context.executing_eagerly():
      result = self._distributed_variables[0].is_initialized()
      for v in self._distributed_variables[1:-1]:
        result = math_ops.logical_and(result, v.is_initialized())
      result = math_ops.logical_and(
          result, self._distributed_variables[-1].is_initialized(), name=name)
    else:
      with ops.device(self._devices[0]):
        result = super(PackedDistributedVariable, self).is_initialized(name)
      for d in self._devices[1:-1]:
        with ops.device(d):
          initialized = super(PackedDistributedVariable,
                              self).is_initialized(name)
        result = math_ops.logical_and(result, initialized)
      with ops.device(self._devices[-1]):
        initialized = super(PackedDistributedVariable,
                            self).is_initialized(name)
        result = math_ops.logical_and(result, initialized, name=name)
    return result
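
  # All assign*/scatter* methods below funnel through `_update`: in eager mode
  # the update is applied directly to the component variable on the current
  # device, while under function tracing it is applied through the base-class
  # implementation on the packed handle.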

  def _update(self, update_fn, value, **kwargs):
    if context.executing_eagerly():
      return update_fn(self.get_var_on_current_device(), value, **kwargs)
    else:
      return update_fn(super(PackedDistributedVariable, self), value, **kwargs)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._update(
        update_fn=assign_sub_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._update(
        update_fn=assign_add_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._update(
        update_fn=assign_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
    return self._update(
        update_fn=scatter_sub_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
    return self._update(
        update_fn=scatter_add_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
    return self._update(
        update_fn=scatter_mul_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
    return self._update(
        update_fn=scatter_div_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
    return self._update(
        update_fn=scatter_min_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
    return self._update(
        update_fn=scatter_max_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
    return self._update(
        update_fn=scatter_update_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)
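
  # Implicit tensor conversion (e.g. via `ops.convert_to_tensor`) follows the
  # same split: read the component on the current device in eager mode, or
  # fall back to the base class when tracing a function.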

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    if context.executing_eagerly():
      return self.get_var_on_current_device()._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)
    else:
      return super(PackedDistributedVariable, self)._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)


class PackedVarAndDevice(object):
  """Holds a packed distributed variable and a device."""
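
  # A minimal sketch of intended use (`packed` is the hypothetical variable
  # from the PackedDistributedVariable docstring sketch):
  #
  #   pv = packed.on_device(packed.devices[1])
  #   pv.read_value()  # reads the component variable on devices[1]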

  def __init__(self, var, device):
    self._var = var
    self._device = device

  def __getattr__(self, name):
    # Exceptions raised inside the context manager can cause a reference
    # cycle.[1] The cycle involves the current frame, which holds the reference
    # to the outer frame. TensorFlow, e.g. in iterators, relies on object
    # finalizers to clean up resources. Such references prevent the resource
    # from being deleted and can cause leaks and errors. One corner case is
    # that iterators are kept alive and the garbage collector happens to run
    # after auto control dependencies; this causes the deletion to lose the
    # control dependencies to operations that use such resources.
    #
    # Catching and re-raising the exception seems to work around the issue.
    #
    # [1] https://bugs.python.org/issue43533
    try:
      with ops.device(self._device):
        return getattr(self._var, name)
    except:  # pylint: disable=try-except-raise
      raise

  def var(self):
    return self._var

  def value(self):
    with ops.device(self._device):
      return self._var.value()

  def read_value(self):
    with ops.device(self._device):
      return self._var.read_value()

  @property
  def initial_value(self):
    return self._var.initial_value(self._device)

  def initialized_value(self):
    with ops.device(self._device):
      return self._var.initialized_value()

  @property
  def device(self):
    return self._device

  @property
  def handle(self):
    with ops.device(self._device):
      return self._var.handle

  def on_device_handle(self):
    with ops.device(self._device):
      return self._var.get_var_on_current_device().handle

  @property
  def op(self):
    with ops.device(self._device):
      return self._var.op

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    with ops.device(self._device):
      return self._var.assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    with ops.device(self._device):
      return self._var.scatter_update(sparse_delta, use_locking, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    with ops.device(self._device):
      return self._var._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)

  def _as_graph_element(self):
    return self._var._as_graph_element()  # pylint: disable=protected-access


def _tensor_conversion_packed_var_and_device(var,
                                             dtype=None,
                                             name=None,
                                             as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


tensor_conversion_registry.register_tensor_conversion_function(
    PackedVarAndDevice, _tensor_conversion_packed_var_and_device)
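
# With the conversion function registered above, a PackedVarAndDevice can be
# passed where a Tensor is expected. A sketch, reusing the hypothetical
# `packed` variable from the class docstring:
#
#   t = ops.convert_to_tensor(packed.on_device(packed.devices[0]))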