
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Maintain moving averages of parameters."""
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls


32@tf_export("__internal__.train.assign_moving_average", v1=[]) 

33def assign_moving_average(variable, value, decay, zero_debias=True, name=None): 

34 """Compute the moving average of a variable. 

35 

36 The moving average of 'variable' updated with 'value' is: 

37 variable * decay + value * (1 - decay) 

38 

39 The returned Operation sets 'variable' to the newly computed moving average, 

40 by performing this subtraction: 

41 variable -= (1 - decay) * (variable - value) 

42 

43 Since variables that are initialized to a `0` value will be `0` biased, 

44 `zero_debias` optionally enables scaling by the mathematically correct 

45 debiasing factor of 

46 1 - decay ** num_updates 

47 See Section 3 of (Kingma et al., 2015) for more details. 

48 

49 The names of the debias shadow variables, by default, include both the scope 

50 they were created in and the scope of the variables they debias. They are also 

51 given a uniquifying-suffix. 

52 

53 E.g.: 

54 

55 ``` 

56 with tf.compat.v1.variable_scope('scope1'): 

57 with tf.compat.v1.variable_scope('scope2'): 

58 var = tf.compat.v1.get_variable('foo') 

59 update_1 = tf.assign_moving_average(var, 0.0, 1.0) 

60 update_2 = tf.assign_moving_average(var, 0.0, 0.9) 

61 

62 # var.name: 'scope1/scope2/foo' 

63 # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased' 

64 # 'scope1/scope2/scope1/scope2/foo/biased_1' 

65 ``` 

66 

67 Args: 

68 variable: A Variable. 

69 value: A tensor with the same shape as 'variable'. 

70 decay: A float `Tensor` or float value. The moving average decay. 

71 zero_debias: A python bool. If true, assume the variable is 0-initialized 

72 and unbias it, as in (Kingma et al., 2015). See docstring in 

73 `_zero_debias` for more details. 

74 name: Optional name of the returned operation. 

75 

76 Returns: 

77 A tensor which if evaluated will compute and return the new moving average. 

78 

79 References: 

80 Adam - A Method for Stochastic Optimization: 

81 [Kingma et al., 2015](https://arxiv.org/abs/1412.6980) 

82 ([pdf](https://arxiv.org/pdf/1412.6980.pdf)) 

83 """ 

84 with ops.name_scope(name, "AssignMovingAvg", 

85 [variable, value, decay]) as scope: 

86 decay = ops.convert_to_tensor(1.0 - decay, name="decay") 

87 if decay.dtype != variable.dtype.base_dtype: 

88 decay = math_ops.cast(decay, variable.dtype.base_dtype) 

89 

90 def update_fn(v, value): 

91 return state_ops.assign_sub(v, (v - value) * decay, name=scope) 

92 

93 def update(strategy, v, value): 

94 if zero_debias: 

95 return _zero_debias(strategy, v, value, decay) 

96 else: 

97 return _update(strategy, v, update_fn, args=(value,)) 

98 

99 replica_context = distribute_lib.get_replica_context() 

100 if replica_context: 

101 # In a replica context, we update variable using the mean of value across 

102 # replicas. 

103 def merge_fn(strategy, v, value): 

104 value = strategy.extended.reduce_to(ds_reduce_util.ReduceOp.MEAN, value, 

105 v) 

106 return update(strategy, v, value) 

107 

108 return replica_context.merge_call(merge_fn, args=(variable, value)) 

109 else: 

110 strategy = distribute_lib.get_cross_replica_context() 

111 return update(strategy, variable, value) 

112 

113 
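
# A minimal usage sketch (illustrative, not part of this module). The function
# is exported for TF2 as `tf.__internal__.train.assign_moving_average` (per
# the `tf_export` decorator above); with `decay=0.9` and `zero_debias=False`,
# one update moves the variable 10% of the way toward `value`:
#
#   import tensorflow as tf
#
#   var = tf.Variable(0.0)
#   tf.__internal__.train.assign_moving_average(
#       var, 10.0, 0.9, zero_debias=False)
#   # var -= (1 - 0.9) * (0.0 - 10.0)  ->  var now holds 1.0
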

def weighted_moving_average(value,
                            decay,
                            weight,
                            truediv=True,
                            collections=None,
                            name=None):
  """Compute the weighted moving average of `value`.

  Conceptually, the weighted moving average is:
    `moving_average(value * weight) / moving_average(weight)`,
  where a moving average updates by the rule
    `new_value = decay * old_value + (1 - decay) * update`
  Internally, this Op keeps moving average variables of both `value * weight`
  and `weight`.

  Args:
    value: A numeric `Tensor`.
    decay: A float `Tensor` or float value. The moving average decay.
    weight: `Tensor` that keeps the current value of a weight. Shape should be
      able to multiply `value`.
    truediv: Boolean, if `True`, dividing by `moving_average(weight)` is
      floating point division. If `False`, use division implied by dtypes.
    collections: List of graph collections keys to add the internal variables
      `value * weight` and `weight` to. Defaults to
      `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name of the returned operation. Defaults to
      "WeightedMovingAvg".

  Returns:
    An Operation that updates and returns the weighted moving average.
  """
  # Unlike assign_moving_average, the weighted moving average doesn't modify
  # user-visible variables. It is the ratio of two internal variables, which
  # are moving averages of the updates. Thus, the signature of this function
  # is quite different from assign_moving_average.
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  with variable_scope.variable_scope(name, "WeightedMovingAvg",
                                     [value, weight, decay]) as scope:
    value_x_weight_var = variable_scope.get_variable(
        "value_x_weight",
        shape=value.get_shape(),
        dtype=value.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    weight_var = variable_scope.get_variable(
        "weight",
        shape=weight.get_shape(),
        dtype=weight.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    numerator = assign_moving_average(
        value_x_weight_var, value * weight, decay, zero_debias=False)
    denominator = assign_moving_average(
        weight_var, weight, decay, zero_debias=False)

    if truediv:
      return math_ops.truediv(numerator, denominator, name=scope.name)
    else:
      return math_ops.divide(numerator, denominator, name=scope.name)

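
# An illustrative sketch of typical usage (an assumption, not an example from
# this module; `batch_loss` and `batch_size` are hypothetical placeholders).
# Since this helper creates variables through `variable_scope.get_variable`,
# it targets TF1-style graph execution. Here a batch statistic is averaged so
# that larger batches count proportionally more:
#
#   value = tf.reduce_mean(batch_loss)         # statistic to average
#   weight = tf.cast(batch_size, tf.float32)   # per-batch weight
#   avg = weighted_moving_average(value, decay=0.99, weight=weight)
#   # `avg` evaluates to
#   #   moving_average(value * weight) / moving_average(weight)
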

def _update(strategy, var, update_fn, args):
  """Applies updates depending on the context."""
  assert distribute_lib.in_cross_replica_context(), (
      "_update can only be called in cross-replica context")
  if distribute_lib.get_update_replica_id() is not None:
    # Call update_fn on var to delegate the implementation. We expect `var`
    # will do the right thing in update context, e.g., if `var` is a
    # MirroredVariable, it should pick its component variable based on
    # `update_replica_id` and only update that.
    return update_fn(var, *args)
  else:
    return strategy.extended.update(var, update_fn, args)


def _zero_debias(strategy, unbiased_var, value, decay):
  """Compute the delta required for a debiased Variable.

  All exponential moving averages initialized with Tensors are initialized to
  0, and therefore are biased to 0. Variables initialized to 0 and used as
  EMAs are similarly biased. This function creates the debias updated amount
  according to a scale factor, as in (Kingma et al., 2015).

  To demonstrate the bias that results from 0-initialization, take an EMA that
  was initialized to `0` with decay `b`. After `t` timesteps of seeing the
  constant `c`, the variable has the following value:

  ```
  EMA = 0*b^(t) + c*(1 - b)*b^(t-1) + c*(1 - b)*b^(t-2) + ...
      = c*(1 - b^t)
  ```

  To have the true value `c`, we would divide by the scale factor `1 - b^t`.

  In order to perform debiasing, we use two shadow variables. One keeps track
  of the biased estimate, and the other keeps track of the number of updates
  that have occurred.

  Args:
    strategy: `Strategy` used to create and update variables.
    unbiased_var: A Variable representing the current value of the unbiased
      EMA.
    value: A Tensor representing the most recent value.
    decay: A Tensor representing `1-decay` for the EMA.

  Returns:
    The amount that the unbiased variable should be updated. Computing this
    tensor will also update the shadow variables appropriately.

  References:
    Adam - A Method for Stochastic Optimization:
      [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
      ([pdf](https://arxiv.org/pdf/1412.6980.pdf))
  """
  with variable_scope.variable_scope(
      unbiased_var.name[:-len(":0")], values=[unbiased_var, value, decay]):
    with ops.init_scope():
      biased_initializer = init_ops.zeros_initializer()
      local_step_initializer = init_ops.zeros_initializer()

    def _maybe_get_unique(name):
      """Get name for a unique variable, if not `reuse=True`."""
      if variable_scope.get_variable_scope().reuse:
        return name
      vs_vars = [
          x.op.name
          for x in variable_scope.get_variable_scope().global_variables()
      ]
      full_name = variable_scope.get_variable_scope().name + "/" + name
      if full_name not in vs_vars:
        return name
      idx = 1
      while full_name + ("_%d" % idx) in vs_vars:
        idx += 1
      return name + ("_%d" % idx)

    with strategy.extended.colocate_vars_with(unbiased_var):
      biased_var = variable_scope.get_variable(
          _maybe_get_unique("biased"),
          initializer=biased_initializer,
          shape=unbiased_var.get_shape(),
          dtype=unbiased_var.dtype,
          trainable=False)
      local_step = variable_scope.get_variable(
          _maybe_get_unique("local_step"),
          shape=[],
          dtype=unbiased_var.dtype,
          initializer=local_step_initializer,
          trainable=False)

    def update_fn(v, value, biased_var, local_step):
      update_biased = state_ops.assign_sub(biased_var,
                                           (biased_var - value) * decay)
      update_local_step = local_step.assign_add(1)

      # This function gets `1 - decay`, so use `1.0 - decay` in the exponent.
      bias_factor = 1 - math_ops.pow(1.0 - decay, update_local_step)
      return state_ops.assign(
          v, update_biased / bias_factor, name=ops.get_name_scope() + "/")

    return _update(
        strategy, unbiased_var, update_fn,
        args=(value, biased_var, local_step))

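
# Worked example of the debias factor above (a sketch of the math in the
# docstring, not code in this module): with decay b = 0.9 and a constant
# update c = 1.0, starting the biased EMA at 0,
#
#   t = 1: biased EMA = 0.1,  factor = 1 - 0.9**1 = 0.1,  0.1 / 0.1  = 1.0
#   t = 2: biased EMA = 0.19, factor = 1 - 0.9**2 = 0.19, 0.19 / 0.19 = 1.0
#
# so dividing the biased estimate by `1 - b**t` recovers the true value `c`
# exactly at every step.
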

281@tf_export("train.ExponentialMovingAverage") 

282class ExponentialMovingAverage: 

283 """Maintains moving averages of variables by employing an exponential decay. 

284 

285 When training a model, it is often beneficial to maintain moving averages of 

286 the trained parameters. Evaluations that use averaged parameters sometimes 

287 produce significantly better results than the final trained values. 

288 

289 The `apply()` method adds shadow copies of trained variables the first time 

290 it is called, and maintains a moving average of the trained variables in 

291 their shadow copies at every additional invocation. 

292 It should generally be called immediately after creating the model weights, 

293 and then after each training step. 

294 

295 The `average()` method gives access to the shadow variables. 

296 It allows you to use the moving averages in place of the last trained values 

297 for evaluations, by loading the moving averages into your model via 

298 `var.assign(ema.average(var))`. 

299 Additionally, although `ExponentialMovingAverage` 

300 objects are not directly trackable by checkpoints, 

301 `average()` returns the moving average variables for your model weights, 

302 which you can then checkpoint. (There is an example 

303 of this near the bottom of this docstring). 

304 So, `average()` is useful when 

305 building an evaluation model, or when restoring a model from a checkpoint 

306 file. 

307 

308 The moving averages are computed using exponential decay. You specify the 

309 decay value (as a scalar float value, `Tensor`, or `Variable`) when creating 

310 the `ExponentialMovingAverage` object. The shadow variables are initialized 

311 with the same initial values as the trained variables. When you run `apply` 

312 to update the moving averages, each shadow variable is updated with the 

313 formula: 

314 

315 `shadow_variable -= (1 - decay) * (shadow_variable - variable)` 

316 

317 This is mathematically equivalent to the classic formula below, but the use 

318 of an `assign_sub` op (the `"-="` in the formula) allows concurrent lockless 

319 updates to the variables: 

320 

321 `shadow_variable = decay * shadow_variable + (1 - decay) * variable` 

322 

323 Reasonable values for `decay` are close to 1.0, typically in the 

324 multiple-nines range: 0.999, 0.9999, etc. 

325 

326 To have fine-grained control over the value of the decay parameter during 

327 training, pass a scalar `tf.Variable` as the `decay` value to the constructor, 

328 and update the variable as needed. 

329 

330 Example usage when creating a training model: 

331 

332 ```python 

333 # Create variables. 

334 var0 = tf.Variable(...) 

335 var1 = tf.Variable(...) 

336 # ... use the variables to build a training model... 

337 

338 # Create an ExponentialMovingAverage object 

339 ema = tf.train.ExponentialMovingAverage(decay=0.9999) 

340 

341 # The first `apply` creates the shadow variables that hold the moving averages 

342 ema.apply([var0, var1]) 

343 

344 # grab the moving averages for checkpointing purposes or to be able to 

345 # load the moving averages into the model weights 

346 averages = [ema.average(var0), ema.average(var1)] 

347 

348 ... 

349 def train_step(...): 

350 ... 

351 # Apply the optimizer. 

352 opt.minimize(my_loss, [var0, var1]) 

353 

354 # Update the moving averages 

355 # of var0 and var1 with additional calls to `apply` 

356 ema.apply([var0, var1]) 

357 

358 ...train the model by running train_step multiple times... 

359 ``` 

360 

361 There are several ways to use the moving averages for evaluations: 

362 

363 1. Assign the values of the shadow variables to your model variables with 

364 `Variable.assign(...)` before evaluating your 

365 model. You can use the `average()` 

366 method to get the shadow variable for a given variable. To continue 

367 training after using this approach, make sure to record the unaveraged 

368 weights and restore them before continuing to train. You can see the 

369 tensorflow-addons' MovingAverage optimizer's `swap_weights` method for 

370 one example of how to swap variables efficiently in distributed settings: 

371 https://github.com/tensorflow/addons/blob/v0.13.0/tensorflow_addons/optimizers/moving_average.py#L151 

372 2. Make sure to checkpoint out your moving average variables in your 

373 `tf.train.Checkpoint`. At evaluation time, create your shadow variables and 

374 use `tf.train.Checkpoint` to restore the moving averages into the shadow 

375 variables. Then, load the moving averages into the actual model weights via 

376 `var.assign(moving_avg)`. 

377 3. Checkpoint out your moving average variables in your `tf.train.Checkpoint`. 

378 For evaluation, restore your model weights directly from the moving 

379 averages instead of from the non-averaged weights. 

380 Caution: If you choose this approach, include only the object-graph paths 

381 to the averaged path in your checkpoint restore. 

382 If you point both the unaveraged and averaged paths in a checkpoint 

383 restore to the same variables, it is hard to reason about whether your 

384 model will restore the averaged or non-averaged variables. 

385 

386 Example of saving out then restoring the shadow variable values: 

387 

388 ```python 

389 # Create variables. 

390 var0 = tf.Variable(...) 

391 var1 = tf.Variable(...) 

392 # ... use the variables to build a training model... 

393 

394 # Create an ExponentialMovingAverage object, create the shadow variables, 

395 # and grab the moving averages for checkpointing purposes. 

396 # (The ExponentialMovingAverage object itself is not checkpointable) 

397 ema = tf.train.ExponentialMovingAverage(decay=0.9999) 

398 ema.apply([var0, var1]) 

399 avg_var0 = ema.average(var0) 

400 avg_var1 = ema.average(var1) 

401 

402 # Create a Checkpoint that will manage the model weights and the averages, 

403 checkpoint = tf.train.Checkpoint(model_weights=[var0, var1], 

404 averaged_weights=[avg_var0, avg_var1]) 

405 ... # Do training 

406 

407 # Save out the checkpoint including the model weights and the moving averages 

408 checkpoint.save(...) 

409 ``` 

410 

411 Restore option: restore all averaged & non-averaged weights, then load 

412 moving averages into the model via `var.assign()` 

413 ```python 

414 # Create variables. 

415 var0 = tf.Variable(...) 

416 var1 = tf.Variable(...) 

417 # ... use the variables to build a training model... 

418 

419 # Create an ExponentialMovingAverage object, create the shadow variables, 

420 # and grab the moving averages for checkpoint restore purposes. 

421 # (The ExponentialMovingAverage object itself is not checkpointable) 

422 ema = tf.train.ExponentialMovingAverage(decay=0.9999) 

423 ema.apply([var0, var1]) 

424 avg_var0 = ema.average(var0) 

425 avg_var1 = ema.average(var1) 

426 

427 # Create a Checkpoint that will manage the model weights and the averages, 

428 checkpoint = tf.train.Checkpoint(model_weights=[var0, var1], 

429 averaged_weights=[avg_var0, avg_var1]) 

430 checkpoint.restore(...) 

431 var0.assign(avg_var0) 

432 var1.assign(avg_var1) 

433 # var0 and var1 now hold the moving average values 

434 ``` 

435 

436 Restore option: Directly restore the moving averages into the model weights. 

437 ```python 

438 # Create variables. 

439 var0 = tf.Variable(...) 

440 var1 = tf.Variable(...) 

441 # ... use the variables to build a training model... 

442 

443 # Create a Checkpoint that will manage two objects with trackable state, 

444 checkpoint = tf.train.Checkpoint(averaged_weights=[var0, var1]) 

445 checkpoint.restore(...) 

446 # var0 and var1 now hold the moving average values 

447 ``` 

448 """ 

449 

  def __init__(self,
               decay,
               num_updates=None,
               zero_debias=False,
               name="ExponentialMovingAverage"):
    """Creates a new ExponentialMovingAverage object.

    The `apply()` method has to be called to create shadow variables.
    Follow-on calls to the `apply()` method will update the moving averages
    in the shadow variables.
    (In TF 1.x graphs `apply()` will return an update op to update
    the moving averages which must be explicitly run).

    The optional `num_updates` parameter allows one to tweak the decay rate
    dynamically. It is typical to pass the count of training steps, usually
    kept in a variable that is incremented at each step, in which case the
    decay rate is lower at the start of training. This makes moving averages
    move faster. If passed, the actual decay rate used is:

      `min(decay, (1 + num_updates) / (10 + num_updates))`

    Args:
      decay: A scalar float value, `Tensor`, or `Variable`. The decay
        parameter.
      num_updates: Optional count of number of updates applied to variables.
      zero_debias: If `True`, zero debias moving-averages that are initialized
        with tensors. (Note: moving averages may not be initialized with
        non-variable tensors when eager execution is enabled).
      name: String. Optional prefix name to use for the name of ops added in
        `apply()`.
    """
    self._decay = decay
    self._num_updates = num_updates
    self._zero_debias = zero_debias
    self._name = name
    self._averages = {}

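
  # Worked example of the `num_updates` schedule (a sketch; the numbers just
  # follow the formula quoted in the docstring above): with decay=0.9999,
  #
  #   num_updates = 0     -> min(0.9999, 1/10)     = 0.1
  #   num_updates = 100   -> min(0.9999, 101/110)  ~ 0.918
  #   num_updates = 10**5 -> min(0.9999, ~0.99991) = 0.9999
  #
  # so early updates let the averages track the raw variables quickly, and
  # the effective decay approaches the configured value as training proceeds.
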

  @property
  def name(self):
    """The name of this ExponentialMovingAverage object."""
    return self._name

  def apply(self, var_list=None):
    """Maintains moving averages of variables.

    `var_list` must be a list of `Variable` objects. This method
    creates shadow variables (holding the moving averages)
    for all elements of `var_list`, and
    updates the moving averages using the current `var_list` values. Shadow
    variables for `Variable` objects are initialized to the variable's
    initial value.

    Shadow variables are created with `trainable=False`. To access them you
    can use the EMA object's `average` method. Note that `EMA` objects are
    not trackable by checkpoints, so if you want to checkpoint or restore the
    moving variables you will need to manually grab the shadow
    variables via `average()` and assign them as `tf.Module` properties or
    directly pass them to your `tf.train.Checkpoint`.

    Note that `apply()` can be called multiple times. When eager execution is
    enabled each call to apply will update the variables once, so this needs
    to be called in a loop.

    In legacy TF 1.x graphs, this method returns an op that updates all
    shadow variables from the current value of their associated variables. In
    TF 1.x graphs without automatic control dependencies this op needs to be
    manually run.

    Args:
      var_list: A list of Variable objects. The variables
        must be of types bfloat16, float16, float32, or float64.
        (In legacy TF 1.x graphs these may be tensors, but this is unsupported
        when eager execution is enabled.)

    Returns:
      An Operation that updates the moving averages.

    Raises:
      TypeError: If the arguments are not an allowed type.
    """
    # TODO(touts): op_scope
    if var_list is None:
      var_list = variables.trainable_variables()
    for v in var_list:
      if (isinstance(v, ops.Tensor)
          and ops.executing_eagerly_outside_functions()):
        raise TypeError(
            "tf.train.ExponentialMovingAverage does not support non-Variable"
            " tensors when eager execution is enabled.")
    zero_debias_true = set()  # set of vars to set `zero_debias=True`
    for var in var_list:
      if var.dtype.base_dtype not in [
          dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
      ]:
        raise TypeError("The variables must be half, float, or double: %s" %
                        var.name)

      if var.ref() not in self._averages:
        # For variables: to lower communication bandwidth across devices we
        # keep the moving averages on the same device as the variables. For
        # other tensors, we rely on the existing device allocation mechanism.
        with ops.init_scope():
          if isinstance(var, variables.Variable):
            with ops.device(var.device):
              initialized_value = control_flow_ops.cond(
                  variable_v1.is_variable_initialized(var), var.read_value,
                  lambda: var.initial_value)  # pylint: disable=cell-var-from-loop
            avg = slot_creator.create_slot(
                var,
                initialized_value,
                self.name,
                colocate_with_primary=True,
                copy_xla_sharding=True)
            # NOTE(mrry): We only add `tf.Variable` objects to the
            # `MOVING_AVERAGE_VARIABLES` collection.
            ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
          else:
            avg = slot_creator.create_zeros_slot(
                var,
                self.name,
                colocate_with_primary=(var.op.type in [
                    "Variable", "VariableV2", "VarHandleOp"
                ]),
                copy_xla_sharding=True)
            if self._zero_debias:
              zero_debias_true.add(avg.ref())
        self._averages[var.ref()] = avg

    with ops.name_scope(self.name) as scope:
      decay = ops.convert_to_tensor(
          self._decay, dtype=dtypes.float32, name="decay")
      if self._num_updates is not None:
        num_updates = math_ops.cast(
            self._num_updates, dtypes.float32, name="num_updates")
        decay = math_ops.minimum(decay,
                                 (1.0 + num_updates) / (10.0 + num_updates))
      updates = []
      for var in var_list:
        avg = self._averages[var.ref()]
        zero_debias = avg.ref() in zero_debias_true
        updates.append(assign_moving_average(avg, var, decay, zero_debias))
      return control_flow_ops.group(*updates, name=scope)

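
  # A minimal eager-mode sketch (illustrative; it mirrors the class
  # docstring's usage and the update formula documented there):
  #
  #   ema = tf.train.ExponentialMovingAverage(decay=0.999)
  #   v = tf.Variable(1.0)
  #   ema.apply([v])      # first call creates the shadow variable at 1.0
  #   v.assign(3.0)
  #   ema.apply([v])      # shadow -= (1 - 0.999) * (1.0 - 3.0) -> 1.002
  #   ema.average(v)      # the shadow Variable, now holding 1.002
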

  def average(self, var):
    """Returns the `Variable` holding the average of `var`.

    Args:
      var: A `Variable` object.

    Returns:
      A `Variable` object or `None` if the moving average of `var`
      is not maintained.
    """
    return self._averages.get(var.ref(), None)

  @doc_controls.do_not_generate_docs
  def average_name(self, var):
    """[Meant for TF1] Returns name of `Variable` holding the average for `var`.

    (Designed to work with legacy `tf.compat.v1.train.Saver`, it is sensitive
    to specific variable names and not recommended for TF2)

    The typical scenario for `ExponentialMovingAverage` is to compute moving
    averages of variables during training, and restore the variables from the
    computed moving averages during evaluations.

    To restore variables, you have to know the name of the shadow variables.
    That name and the original variable can then be passed to a `Saver()`
    object to restore the variable from the moving average value with:
      `saver = tf.compat.v1.train.Saver({ema.average_name(var): var})`

    `average_name()` can be called whether or not `apply()` has been called.

    Args:
      var: A `Variable` object.

    Returns:
      A string: The name of the variable that will be used or was used
      by the `ExponentialMovingAverage` class to hold the moving average of
      `var`.
    """
    if var.ref() in self._averages:
      return self._averages[var.ref()].name[:-len(":0")]
    return ops.get_default_graph().unique_name(
        var.name[:-len(":0")] + "/" + self.name, mark_as_used=False)


  @doc_controls.do_not_generate_docs
  def variables_to_restore(self, moving_avg_variables=None):
    """[Designed for TF 1.x] Returns a map of names to `Variables` to restore.

    (Designed to work with legacy `tf.compat.v1.train.Saver`, sensitive to
    specific variable names and not recommended for TF2)

    If a variable has a moving average, use the moving average variable name
    as the restore name; otherwise, use the variable name.

    For example,

    ```python
    variables_to_restore = ema.variables_to_restore()
    saver = tf.compat.v1.train.Saver(variables_to_restore)
    ```

    Below is an example of such a mapping:

    ```
    conv/batchnorm/gamma/ExponentialMovingAverage: conv/batchnorm/gamma,
    conv_4/conv2d_params/ExponentialMovingAverage: conv_4/conv2d_params,
    global_step: global_step
    ```

    Args:
      moving_avg_variables: a list of variables that require the use of the
        moving average variable name to be restored. If None, it will default
        to variables.moving_average_variables() +
        variables.trainable_variables()

    Returns:
      A map from restore_names to variables. The restore_name is either the
      original or the moving average version of the variable name, depending
      on whether the variable name is in the `moving_avg_variables`.
    """
    name_map = {}
    if moving_avg_variables is None:
      # Include trainable variables and variables which have been explicitly
      # added to the moving_average_variables collection.
      moving_avg_variables = variables.trainable_variables()
      moving_avg_variables += variables.moving_average_variables()
    # Remove duplicates
    moving_avg_variables = set(v.ref() for v in moving_avg_variables)
    # Collect all the variables with moving average,
    for v in moving_avg_variables:
      name_map[self.average_name(v.deref())] = v.deref()
    # Make sure we restore variables without moving averages as well.
    moving_avg_variable_names = set(
        v.deref().name for v in moving_avg_variables)
    for v in list(set(variables.global_variables())):
      if v.name not in moving_avg_variable_names and v.op.name not in name_map:
        name_map[v.op.name] = v
    return name_map

687 return name_map