Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/values_util.py: 30%

139 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utility functions used by values.py and ps_values.py.""" 

16 

17from tensorflow.python.distribute import distribute_lib 

18from tensorflow.python.distribute import reduce_util 

19from tensorflow.python.eager import context 

20from tensorflow.python.framework import ops 

21from tensorflow.python.framework import tensor_util 

22from tensorflow.python.ops import control_flow_ops 

23from tensorflow.python.ops import math_ops 

24from tensorflow.python.ops import variable_scope as vs 

25from tensorflow.python.saved_model import save_context 

26from tensorflow.python.saved_model import save_options 

27from tensorflow.python.training.saving import saveable_object 

28 

29 

def write_object_proto(var, proto, options):
  """Update a SavedObject proto for the caller.

  If a DistributedVariable object supports this method, it will be called when
  saving with a pre-built `SavedObject` proto representing the object, plus an
  instance of `SaveOptions`. This method is then free to modify that proto
  instance.

  `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
  write out information about their components to the
  `experimental_distributed_variable_components` field of a
  `SavedVariable` (depending on the `SaveOptions` variable policy).

  Args:
    var: The DistributedVariable object.
    proto: A pre-built `SavedObject` proto for this object. It is assumed this
      will be a `SavedVariable` instance.
    options: A `SaveOptions` instance.
  """
  if options.experimental_variable_policy._expand_distributed_variables(  # pylint: disable=protected-access
  ):
    # Use a distinct loop name instead of re-binding `var`: the original code
    # shadowed the `var` parameter, which was confusing and made the
    # distributed variable inaccessible inside the loop.
    for component in var.values:
      var_proto = (
          proto.variable.experimental_distributed_variable_components.add())
      # Strip the ":0" output suffix so the proto stores the bare op name.
      var_proto.name = component.name.split(":")[0]
      var_proto.device = component.device

56 

57 

def get_on_write_saveable(var, primary_var, name):
  """Return saveable spec for AUTO and ON_WRITE variables."""

  def tensor():
    # Deferred via a callable so nothing is evaluated when we are restoring
    # rather than saving.
    if context.executing_eagerly() and not primary_var.is_initialized():
      # A SaveSpec tensor value of `None` indicates that the variable is
      # uninitialized.
      return None
    return var.distribute_strategy.extended.read_var(var)

  return tensor, [
      saveable_object.SaveSpec(
          tensor=tensor,
          slice_spec="",
          name=name,
          dtype=var.dtype,
          device=primary_var.device)
  ]

78 

79 

def get_on_write_restore_ops(var, tensor):
  """Return restore ops for AUTO and ON_WRITE variables."""
  packed_var = var._packed_variable  # pylint: disable=protected-access
  if packed_var is None:
    # No packed variable: assign each component on its own device.
    restore_ops = [
        assign_on_device(v.device, v, tensor) for v in var.values
    ]
  else:
    # Packed variable: issue one assign per device of the packed handle.
    restore_ops = [
        assign_on_device(d, packed_var, tensor) for d in packed_var.devices
    ]
  return control_flow_ops.group(tuple(restore_ops))

92 

93 

def get_on_read_saveable(var, primary_var, name):
  """Return saveables for ON_READ variable."""

  def tensor():
    # Deferred via a callable so nothing is evaluated when we are restoring
    # rather than saving.
    return var._get_cross_replica()  # pylint: disable=protected-access

  return tensor, [
      saveable_object.SaveSpec(
          tensor=tensor,
          slice_spec="",
          name=name,
          dtype=var.dtype,
          device=primary_var.device)
  ]

110 

111 

def get_on_read_restore_ops(var, tensor, aggregation):
  """Return restore ops for ON_READ variables."""
  if aggregation == vs.VariableAggregation.SUM:
    # A SUM-aggregated variable was saved as the total across devices, so
    # divide it back evenly on restore to preserve the sum.
    num_in_sync = var.distribute_strategy.num_replicas_in_sync
    tensor = math_ops.cast(tensor / num_in_sync, var.dtype)
  return control_flow_ops.group(
      tuple(assign_on_device(v.device, v, tensor) for v in var.values))

125 

126 

def in_replica_update_context():
  """Returns True iff in an UpdateContext while running in a replica fn."""
  return distribute_lib.get_update_replica_id() is not None

131 

132 

def on_write_assign(var, value, use_locking=False, name=None, read_value=True):
  """Assigns `value` to an AUTO/ON_WRITE `var` via its `_update` mechanism."""

  def assign_fn(v, *args, **kwargs):
    return v.assign(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=assign_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)

141 

142 

def on_write_assign_add(var, value, use_locking=False, name=None,
                        read_value=True):
  """Adds `value` to an AUTO/ON_WRITE `var` via its `_update` mechanism."""

  def assign_add_fn(v, *args, **kwargs):
    return v.assign_add(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=assign_add_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)

152 

153 

def on_write_assign_sub(var, value, use_locking=False, name=None,
                        read_value=True):
  """Subtracts `value` from an AUTO/ON_WRITE `var` via `_update`."""

  def assign_sub_fn(v, *args, **kwargs):
    return v.assign_sub(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=assign_sub_fn,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)

163 

164 

def assign_on_each_device(var, assign_func, value, read_value):
  """Update the variable on each replica with the given assign_func and value."""
  packed = var._packed_variable  # pylint: disable=protected-access
  if packed is None:
    # Per-component update: one assign per component variable.
    update = control_flow_ops.group(
        tuple(assign_func(v.device, v, value) for v in var._values))  # pylint: disable=protected-access
  else:
    # Packed-handle update: one assign per device of the packed variable.
    update = control_flow_ops.group(
        tuple(assign_func(d, packed, value) for d in var._devices))  # pylint: disable=protected-access
  if not read_value:
    return update
  # Sequence the read after the update so callers observe the new value.
  with ops.control_dependencies([update] if update else []):
    return var.read_value()

178 

179 

def on_read_assign_sub_cross_replica(var, value, read_value=True):
  """Subtracts `value` on each device of an ON_READ `var`, cross-replica only."""
  with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
    if not distribute_lib.in_cross_replica_context():
      # Matches the original behavior: no-op (implicit None) outside
      # cross-replica context.
      return None
    if var.aggregation == vs.VariableAggregation.SUM:
      raise ValueError(
          "SyncOnReadVariable does not support `assign_sub` in "
          "cross-replica context when aggregation is set to "
          "`tf.VariableAggregation.SUM`.")
    return assign_on_each_device(var, assign_sub_on_device, value, read_value)

190 

191 

def on_read_assign_add_cross_replica(var, value, read_value=True):
  """Adds `value` on each device of an ON_READ `var`, cross-replica only."""
  with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
    if not distribute_lib.in_cross_replica_context():
      # Matches the original behavior: no-op (implicit None) outside
      # cross-replica context.
      return None
    if var.aggregation == vs.VariableAggregation.SUM:
      raise ValueError(
          "SyncOnReadVariable does not support `assign_add` in "
          "cross-replica context when aggregation is set to "
          "`tf.VariableAggregation.SUM`.")
    return assign_on_each_device(var, assign_add_on_device, value, read_value)

202 

203 

def on_read_assign_cross_replica(var, value, read_value=True):
  """Return the value of the variable in cross replica context."""
  with distribute_lib.enter_or_assert_strategy(var.distribute_strategy):
    if not distribute_lib.in_cross_replica_context():
      # Matches the original behavior: no-op (implicit None) outside
      # cross-replica context.
      return None
    tensor = value
    if var.aggregation == vs.VariableAggregation.SUM:
      # To preserve the sum across save and restore, we have to divide the
      # total across all devices when restoring a variable that was summed
      # when saving.
      strategy = var._distribute_strategy  # pylint: disable=protected-access
      tensor = math_ops.cast(tensor / strategy.num_replicas_in_sync, var.dtype)
    return assign_on_each_device(var, assign_on_device, tensor, read_value)

218 

219 

def scatter_sub(var, sparse_delta, use_locking=False, name=None):
  """Applies a sparse subtraction to `var` via its `_update` mechanism."""

  def scatter_sub_fn(v, *args, **kwargs):
    return v.scatter_sub(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_sub_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)

227 

228 

def scatter_add(var, sparse_delta, use_locking=False, name=None):
  """Applies a sparse addition to `var` via its `_update` mechanism."""

  def scatter_add_fn(v, *args, **kwargs):
    return v.scatter_add(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_add_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)

236 

237 

def scatter_mul(var, sparse_delta, use_locking=False, name=None):
  """Applies a sparse multiplication to `var` via its `_update` mechanism."""

  def scatter_mul_fn(v, *args, **kwargs):
    return v.scatter_mul(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_mul_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)

245 

246 

def scatter_div(var, sparse_delta, use_locking=False, name=None):
  """Applies a sparse division to `var` via its `_update` mechanism."""

  def scatter_div_fn(v, *args, **kwargs):
    return v.scatter_div(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_div_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)

254 

255 

def scatter_min(var, sparse_delta, use_locking=False, name=None):
  """Applies a sparse element-wise minimum to `var` via `_update`."""

  def scatter_min_fn(v, *args, **kwargs):
    return v.scatter_min(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_min_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)

263 

264 

def scatter_max(var, sparse_delta, use_locking=False, name=None):
  """Applies a sparse element-wise maximum to `var` via `_update`."""

  def scatter_max_fn(v, *args, **kwargs):
    return v.scatter_max(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_max_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)

272 

273 

def scatter_update(var, sparse_delta, use_locking=False, name=None):
  """Applies a sparse assignment to `var` via its `_update` mechanism."""

  def scatter_update_fn(v, *args, **kwargs):
    return v.scatter_update(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=scatter_update_fn,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)

281 

282 

def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`."""
  replica_context = distribute_lib.get_replica_context()
  if not replica_context:
    # Outside a replica context, fall back to the update-replica id (may be
    # None).
    return distribute_lib.get_update_replica_id()
  replica_id = replica_context._replica_id  # pylint: disable=protected-access
  if isinstance(replica_id, int):
    return replica_id
  # The replica id may be a tensor; extract its constant value if possible.
  return tensor_util.constant_value(replica_id)

293 

294 

def assign_on_device(device, variable, tensor):
  """Assigns `tensor` to `variable` with the op placed on `device`."""
  with ops.device(device):
    return variable.assign(tensor)

298 

299 

def assign_add_on_device(device, variable, tensor):
  """Adds `tensor` to `variable` with the op placed on `device`."""
  with ops.device(device):
    return variable.assign_add(tensor)

303 

304 

def assign_sub_on_device(device, variable, tensor):
  """Subtracts `tensor` from `variable` with the op placed on `device`."""
  with ops.device(device):
    return variable.assign_sub(tensor)

308 

309 

def assert_replica_context(strategy):
  """Raises RuntimeError unless called from a replica context of `strategy`."""
  replica_context = distribute_lib.get_replica_context()
  # The original raised the same message from two separate checks; they are
  # merged here (short-circuit keeps the attribute access safe).
  if not replica_context or replica_context.strategy is not strategy:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")

318 

319 

def apply_aggregation(strategy, value, aggregation, destinations):
  """Aggregates `value` across replicas and places the result on `destinations`."""
  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    # No reduction needed: broadcast the first replica's local value.
    first_value = strategy.experimental_local_results(value)[0]
    return strategy.extended.broadcast_to(
        first_value, destinations=destinations)
  reduce_op = reduce_util.ReduceOp.from_variable_aggregation(aggregation)
  return strategy.extended.reduce_to(reduce_op, value, destinations)

327 

328 

# Error-message template raised when a synchronized variable is updated in a
# replica context without an aggregation method; callers fill in
# `{variable_type}` via str.format.
aggregation_error_msg = (
    "You must specify an aggregation method to update a "
    "{variable_type} in Replica Context. You can do so by passing "
    "an explicit value for argument `aggregation` to tf.Variable(..)."
    "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`"
    "`tf.VariableAggregation` lists the possible aggregation methods."
    "This is required because {variable_type} should always be "
    "kept in sync. When updating them or assigning to them in a "
    "replica context, we automatically try to aggregate the values "
    "before updating the variable. For this aggregation, we need to "
    "know the aggregation method. "
    "Another alternative is to not try to update such "
    "{variable_type} in replica context, but in cross replica "
    "context. You can enter cross replica context by calling "
    "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`."
    "Inside `merge_fn`, you can then update the {variable_type} "
    "using `tf.distribute.StrategyExtended.update()`.")

346 

347 

# Error-message template for unsupported scatter ops; callers fill in
# `{op_name}` and `{aggregation}` via str.format.
scatter_error_msg = ("{op_name} is only supported for mirrored "
                     "variable (variable created within certain "
                     "`tf.distribute.Strategy` scope) with NONE or "
                     "`ONLY_FIRST_REPLICA` aggregation, got: {aggregation}.")

352 

353 

def is_saving_non_distributed():
  """Returns whether we're saving a non-distributed version of the model.

  It returns True iff we are in saving context and are saving a non-distributed
  version of the model. That is, SaveOptions.experimental_variable_policy is
  NONE.

  Returns:
    A boolean.
  """
  if not save_context.in_save_context():
    return False
  policy = save_context.get_save_options().experimental_variable_policy
  return policy != save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES

369 

370 

def mark_as_unsaveable():
  """Marks the function as unsaveable if not inside save context."""
  # Only flag graphs built inside a tf.function trace that is NOT part of a
  # SavedModel save; the message below is surfaced to users at save time.
  if ops.inside_function() and not save_context.in_save_context():
    ops.get_default_graph().mark_as_unsaveable("""
ConcreteFunction that uses distributed variables in certain way cannot be saved.
If you're saving with

tf.saved_model.save(..., signatures=f.get_concrete_function())

do

@tf.function(input_signature=...)
def f_with_input_signature():
  ...

tf.saved_model.save(..., signatures=f_with_input_signature)`

instead.""")