Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/checkpoint_utils.py: 21%

164 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Tools to work with name-based checkpoints. 

16 

17While some of these symbols also work with the TF2 object-based checkpoints, 

18they are not recommended for TF2. Please check `tensorflow/python/checkpoint` 

19for newer utilities built to work with TF2 checkpoints. 

20""" 

21 

22from collections import abc 

23import os 

24import time 

25 

26from tensorflow.python.checkpoint import checkpoint_management 

27from tensorflow.python.distribute import distribute_lib 

28from tensorflow.python.framework import ops 

29from tensorflow.python.ops import io_ops 

30from tensorflow.python.ops import resource_variable_ops 

31from tensorflow.python.ops import variable_scope as vs 

32from tensorflow.python.ops import variables 

33from tensorflow.python.platform import gfile 

34from tensorflow.python.platform import tf_logging as logging 

35from tensorflow.python.training import py_checkpoint_reader 

36from tensorflow.python.training.saving import saveable_object_util 

37from tensorflow.python.util.tf_export import tf_export 

38 

39 

# Names exported by `from <this module> import *`.
__all__ = [
    "load_checkpoint",
    "load_variable",
    "list_variables",
    "checkpoints_iterator",
    "init_from_checkpoint",
]

44 

45 

@tf_export("train.load_checkpoint")
def load_checkpoint(ckpt_dir_or_file):
  """Returns a `CheckpointReader` for the checkpoint at `ckpt_dir_or_file`.

  When `ckpt_dir_or_file` names a directory holding several checkpoints, the
  reader is opened on the latest one.

  Example usage:

  ```python
  import tensorflow as tf
  a = tf.Variable(1.0)
  b = tf.Variable(2.0)
  ckpt = tf.train.Checkpoint(var_list={'a': a, 'b': b})
  ckpt_path = ckpt.save('tmp-ckpt')
  reader = tf.train.load_checkpoint(ckpt_path)
  print(reader.get_tensor('var_list/a/.ATTRIBUTES/VARIABLE_VALUE'))  # 1.0
  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint
      file.

  Returns:
    A `CheckpointReader` object.

  Raises:
    ValueError: If `ckpt_dir_or_file` resolves to a directory in which no
      checkpoint can be found.
  """
  resolved = _get_checkpoint_filename(ckpt_dir_or_file)
  if resolved is not None:
    return py_checkpoint_reader.NewCheckpointReader(resolved)
  raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                   "given directory %s" % ckpt_dir_or_file)

81 

82 

@tf_export("train.load_variable")
def load_variable(ckpt_dir_or_file, name):
  """Returns the value stored for variable `name` in a checkpoint.

  If you do not know the variable's name, `tf.train.list_variables` lists
  every key in the checkpoint.

  Example usage:

  ```python
  import tensorflow as tf
  a = tf.Variable(1.0)
  b = tf.Variable(2.0)
  ckpt = tf.train.Checkpoint(var_list={'a': a, 'b': b})
  ckpt_path = ckpt.save('tmp-ckpt')
  var = tf.train.load_variable(
      ckpt_path, 'var_list/a/.ATTRIBUTES/VARIABLE_VALUE')
  print(var)  # 1.0
  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    name: Name of the variable to return. A trailing ":0" is ignored.

  Returns:
    A numpy `ndarray` with a copy of the value of this variable.
  """
  # TODO(b/29227106): Fix this in the right place and remove this.
  # Drop a trailing ":0" (as in tensor names like "w:0") before the lookup.
  key = name[:-2] if name.endswith(":0") else name
  return load_checkpoint(ckpt_dir_or_file).get_tensor(key)

115 

116 

@tf_export("train.list_variables")
def list_variables(ckpt_dir_or_file):
  """Lists the checkpoint keys and shapes of variables in a checkpoint.

  Checkpoint keys are paths in a checkpoint graph.

  Example usage:

  ```python
  import tensorflow as tf
  import os
  # `optimizer`, `model` and `train_and_checkpoint` are defined elsewhere.
  ckpt_directory = "/tmp/training_checkpoints/ckpt"
  ckpt = tf.train.Checkpoint(optimizer=optimizer, model=model)
  manager = tf.train.CheckpointManager(ckpt, ckpt_directory, max_to_keep=3)
  train_and_checkpoint(model, manager)
  tf.train.list_variables(manager.latest_checkpoint)
  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.

  Returns:
    List of `(key, shape)` tuples, sorted by key.
  """
  reader = load_checkpoint(ckpt_dir_or_file)
  shape_map = reader.get_variable_to_shape_map()
  return [(key, shape_map[key]) for key in sorted(shape_map)]

148 

149 

def wait_for_new_checkpoint(checkpoint_dir,
                            last_checkpoint=None,
                            seconds_to_sleep=1,
                            timeout=None):
  """Polls `checkpoint_dir` until a checkpoint other than `last_checkpoint` appears.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    last_checkpoint: The last checkpoint path used, or `None` when no
      checkpoint has been seen yet.
    seconds_to_sleep: Seconds to sleep between polls.
    timeout: The maximum number of seconds to wait. `None` means wait
      indefinitely.

  Returns:
    The new checkpoint path, or `None` when sleeping another
    `seconds_to_sleep` would go past the timeout.
  """
  logging.info("Waiting for new checkpoint at %s", checkpoint_dir)
  deadline = None if timeout is None else time.time() + timeout
  while True:
    latest = checkpoint_management.latest_checkpoint(checkpoint_dir)
    if latest is not None and latest != last_checkpoint:
      logging.info("Found new checkpoint at %s", latest)
      return latest
    # Give up before a sleep that would overshoot the deadline.
    if deadline is not None and time.time() + seconds_to_sleep > deadline:
      return None
    time.sleep(seconds_to_sleep)

179 

180 

@tf_export("train.checkpoints_iterator")
def checkpoints_iterator(checkpoint_dir,
                         min_interval_secs=0,
                         timeout=None,
                         timeout_fn=None):
  """Continuously yield new checkpoint files as they appear.

  The iterator only looks for new checkpoints when control flow has returned
  to it. It can therefore miss checkpoints if your code takes longer to run
  between iterations than `min_interval_secs`, or longer than the interval at
  which new checkpoints are written.

  The `timeout` argument is the maximum number of seconds to block waiting for
  a new checkpoint. It is used together with `timeout_fn` as follows:

  * If the timeout expires and no `timeout_fn` was specified, the iterator
    stops yielding.
  * If a `timeout_fn` was specified, that function is called. If it returns a
    true value, the iterator stops yielding.
  * If the function returns a false value, the iterator resumes waiting for
    new checkpoints, and the timeout logic applies again.

  This gives callers control over what happens when checkpoints do not arrive
  fast enough or stop being generated. For example, a caller that can detect
  that training has stopped can provide a `timeout_fn` that returns `True`
  once training has stopped and `False` while it is still running.

  Args:
    checkpoint_dir: The directory in which checkpoints are saved.
    min_interval_secs: The minimum number of seconds between yielding
      checkpoints.
    timeout: The maximum number of seconds to wait between checkpoints. If left
      as `None`, then the process will wait indefinitely.
    timeout_fn: Optional function to call after a timeout. If the function
      returns True, then it means that no new checkpoints will be generated and
      the iterator will exit. The function is called with no arguments.

  Yields:
    String paths to latest checkpoint files as they arrive.
  """
  previous = None
  while True:
    found = wait_for_new_checkpoint(checkpoint_dir, previous, timeout=timeout)
    if found is None:
      # Timed out.
      if not timeout_fn:
        logging.info("Timed-out waiting for a checkpoint.")
        return
      if not timeout_fn():
        # The caller expects more checkpoints: keep waiting.
        continue
      # The caller says no more checkpoints will come.
      return
    yielded_at = time.time()
    previous = found
    yield found
    remaining = yielded_at + min_interval_secs - time.time()
    if remaining > 0:
      time.sleep(remaining)

245 

246 

@tf_export(v1=["train.init_from_checkpoint"])
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  @compatibility(TF2)
  `tf.compat.v1.train.init_from_checkpoint` is not recommended for restoring
  variable values in TF2.

  To restore checkpoints in TF2, please use
  `tf.keras.Model.load_weights` or `tf.train.Checkpoint.restore`. These APIs
  use an [object-based method of checkpointing]
  (https://www.tensorflow.org/guide/checkpoint#loading_mechanics), while
  `tf.compat.v1.train.init_from_checkpoint` relies on a more fragile
  variable-name-based method of checkpointing. There is no object-based
  equivalent of `init_from_checkpoint` in TF2.

  Please re-write your checkpoints immediately using the object-based APIs;
  see the [migration guide]
  (https://www.tensorflow.org/guide/migrate#checkpoint_compatibility) for more
  details.

  You can load a name-based checkpoint written by `tf.compat.v1.train.Saver`
  using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`.
  However, you may have to change the names of the variables in your model to
  match the variable names in the name-based checkpoint. You can view those
  names with `tf.train.list_variables(path)`.

  Another option is to create an `assignment_map` that maps the names of the
  variables in the name-based checkpoint to the variables in your model, e.g.:
  ```
  {
      'sequential/dense/bias': model.variables[0],
      'sequential/dense/kernel': model.variables[1]
  }
  ```
  and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to
  restore the name-based checkpoint.

  After restoring, re-encode your checkpoint using `tf.train.Checkpoint.save`
  or `tf.keras.Model.save_weights`.

  @end_compatibility

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.compat.v1.global_variables_initializer` op).

  Note: This overrides default initialization ops of specified variables and
  redefines dtype.

  Assignment map supports the following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize list of
    partitioned variables with tensor 'scope_variable_name' from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    checkpoint's root (i.e. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Assignment map can be a dict, or a list of pairs. The latter is
  necessary to initialize multiple variables in the current graph from
  the same variable in the checkpoint.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create new model's variables
  with tf.compat.v1.variable_scope('new_scope_1'):
    var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                           initializer=tf.compat.v1.zeros_initializer())
  with tf.compat.v1.variable_scope('new_scope_2'):
    var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                           initializer=tf.compat.v1.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.compat.v1.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using variable's name
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, or a list of key-value pairs, where keys are names
      of the variables in the checkpoint and values are current variables or
      names of current variables (in default graph).

  Raises:
    ValueError: If missing variables in current graph, or if missing
      checkpoints or tensors in checkpoints.

  """

  def _init_fn(_):
    # The single positional argument (supplied by `merge_call`, or `None`
    # below) is not used.
    return _init_from_checkpoint(ckpt_dir_or_file, assignment_map)

  if distribute_lib.get_cross_replica_context():
    # Already in cross-replica context: run directly.
    _init_fn(None)
  else:
    # Inside a replica: route the work through the replica context's
    # `merge_call`.
    distribute_lib.get_replica_context().merge_call(_init_fn)

380 

381 

def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation.

  Args:
    ckpt_dir_or_file: Directory with checkpoints file or path to checkpoint.
    assignment_map: A mapping, or an iterable of
      `(tensor_name_in_ckpt, current_var_or_name)` pairs.

  Raises:
    ValueError: If no checkpoint can be found, a requested tensor is missing
      from the checkpoint, a variable's shape is incompatible with the stored
      tensor, or a scope mapping is not of the form 'scope/': 'other_scope/'.
  """
  # Resolve the checkpoint path exactly once. Re-resolving a directory (as
  # `load_checkpoint(ckpt_dir_or_file)` would) could pick up a newer checkpoint
  # written in the meantime. The reader used for validation below and the
  # restore ops built from `ckpt_file` would then refer to different
  # checkpoints.
  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  if ckpt_file is None:
    # Same error `load_checkpoint` raises for this situation.
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % ckpt_dir_or_file)
  reader = py_checkpoint_reader.NewCheckpointReader(ckpt_file)
  variable_map = reader.get_variable_to_shape_map()
  if isinstance(assignment_map, abc.Mapping):
    assignment_map = assignment_map.items()

  # We only want to sort by tensor names.
  sort_key = lambda pair: pair[0]

  for tensor_name_in_ckpt, current_var_or_name in sorted(
      assignment_map, key=sort_key):
    # Check if this is Variable object or list of Variable objects (in case of
    # partitioned variables).
    if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in var_store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if variable is partitioned as list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If 1 to 1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if _is_variable(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join(v.name for v in var)
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      # Scope-to-scope mapping. `store_vars` is always bound here: a Variable
      # or list value sets a non-None `var` above, so this branch is only
      # reached through the name-lookup path.
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If scope to scope mapping was provided, find all variables in the scope
      # and create variable to variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Consume /part_ if partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Lookup name with specified prefix and suffix from current variable.
        # If tensor_name given is '/' (root), don't use it for full name.
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove trailing '/', if any, in the full_tensor_name
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)

473 

474 

def _get_checkpoint_filename(ckpt_dir_or_file):
  """Resolves a checkpoint directory or file argument to a checkpoint path.

  Args:
    ckpt_dir_or_file: A path (string or `os.PathLike`) naming either a
      directory of checkpoints or one specific checkpoint.

  Returns:
    For a directory, the result of `checkpoint_management.latest_checkpoint`
    on it (possibly `None`). Otherwise the argument itself, passed through
    `os.fspath` if it was path-like.
  """
  path = (os.fspath(ckpt_dir_or_file)
          if isinstance(ckpt_dir_or_file, os.PathLike) else ckpt_dir_or_file)
  if not gfile.IsDirectory(path):
    return path
  return checkpoint_management.latest_checkpoint(path)

482 

483 

def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides given variable's initialization op.

  Sets the variable's initializer to an assign op that initializes the
  variable from the tensor's value in the checkpoint.

  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.

  Raises:
    ValueError: If `variable` does not map to exactly one saveable object.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

  names_to_saveables = saveable_object_util.op_list_to_dict([variable])
  # The loop variables deliberately avoid `name`, so the `name` argument is
  # not shadowed.
  saveable_objects = [
      saveable
      for saveable_name, op in names_to_saveables.items()
      for saveable in saveable_object_util.saveable_objects_for_op(
          op, saveable_name)
  ]

  # Should be only one variable. This is an explicit check rather than an
  # `assert` so it still runs under `python -O`.
  if len(saveable_objects) != 1:
    raise ValueError(
        "Expected exactly one saveable object for variable %s, got %d." %
        (variable.name, len(saveable_objects)))
  init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

  # pylint:disable=protected-access
  variable._initializer_op = init_op
  restore_op.set_shape(variable.shape)
  variable._initial_value = restore_op
  # pylint:enable=protected-access

523 # pylint:enable=protected-access 

524 

525 

def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
                                      tensor_name):
  """Overrides initialization op of given variable or list of variables.

  For a single variable, calls `_set_checkpoint_initializer` with an empty
  slice spec. For a list or tuple of partitions, calls it once per partition
  with that partition's slice spec.

  Args:
    variable_or_list: `tf.Variable` object or a list of `tf.Variable` objects.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.

  Raises:
    ValueError: If the objects in `variable_or_list` are not all partitions of
      the same full variable. Every partition is checked before any
      initializer is modified, so no variable is changed when this is raised.
  """
  if not isinstance(variable_or_list, (list, tuple)):
    _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")
    return

  # A set of slices. Validate the whole list first so that a mismatch does
  # not leave some partitions re-initialized from the checkpoint and others
  # untouched.
  slice_infos = [v._save_slice_info for v in variable_or_list]  # pylint:disable=protected-access
  slice_name = None
  for slice_info in slice_infos:
    if slice_name is None:
      slice_name = slice_info.full_name
    elif slice_name != slice_info.full_name:
      raise ValueError("Slices must all be from the same tensor: %s != %s" %
                       (slice_name, slice_info.full_name))
  for v, slice_info in zip(variable_or_list, slice_infos):
    _set_checkpoint_initializer(v, ckpt_file, tensor_name, slice_info.spec)

555 

556 

def _is_variable(x):
  """Returns True if `x` is a `variables.Variable` or a resource variable."""
  if isinstance(x, variables.Variable):
    return True
  return resource_variable_ops.is_resource_variable(x)

560 

561 

def _collect_partitioned_variable(name, all_vars):
  """Returns list of `tf.Variable` that comprise the partitioned variable.

  Looks up `<name>/part_0`, `<name>/part_1`, ... in `all_vars` and stops at
  the first missing index.

  Args:
    name: Name of the (unpartitioned) variable.
    all_vars: Mapping from variable names to variables.

  Returns:
    The consecutive partitions in index order, or `None` if `<name>/part_0`
    is not present.
  """
  prefix = name + "/part_"
  partitions = []
  index = 0
  while True:
    key = prefix + str(index)
    if key not in all_vars:
      break
    partitions.append(all_vars[key])
    index += 1
  return partitions if partitions else None

571 return None