Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/one_device_strategy.py: 46%

168 statements

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A tf.distribute.Strategy for running on a single device."""

from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import input_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute.v1 import input_lib as input_lib_v1
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import while_loop
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(josh11b): Do we wrap values in types to generate errors if you are
# doing something that won't work with other DistributionStrategy
# implementations?


@tf_export("distribute.OneDeviceStrategy", v1=[])
class OneDeviceStrategy(distribute_lib.Strategy):
  """A distribution strategy for running on a single device.

  Using this strategy will place any variables created in its scope on the
  specified device. Input distributed through this strategy will be
  prefetched to the specified device. Moreover, any functions called via
  `strategy.run` will be placed on the specified device as well.

  Typical usage of this strategy could be testing your code with the
  tf.distribute.Strategy API before switching to other strategies which
  actually distribute to multiple devices/machines.

  For example:
  ```
  strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

  with strategy.scope():
    v = tf.Variable(1.0)
    print(v.device)  # /job:localhost/replica:0/task:0/device:GPU:0

  def step_fn(x):
    return x * 2

  result = 0
  for i in range(10):
    result += strategy.run(step_fn, args=(i,))
  print(result)  # 90
  ```
  """

  def __init__(self, device):
    """Creates a `OneDeviceStrategy`.

    Args:
      device: Device string identifier for the device on which the variables
        should be placed. See class docs for more details on how the device is
        used. Examples: "/cpu:0", "/gpu:0", "/device:CPU:0", "/device:GPU:0"
    """
    super(OneDeviceStrategy, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "OneDeviceStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):  # pylint: disable=useless-super-delegation
    """Distributes a tf.data.Dataset instance provided via dataset.

    In this case, there is only one device, so this is only a thin wrapper
    around the input dataset. It will, however, prefetch the input data to the
    specified device. The returned distributed dataset can be iterated over
    similar to how regular datasets can.

    NOTE: Currently, the user cannot add any more transformations to a
    distributed dataset.

    Example:
    ```
    strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    dataset = tf.data.Dataset.range(10).batch(2)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    for x in dist_dataset:
      print(x)  # [0, 1], [2, 3],...
    ```
    Args:
      dataset: `tf.data.Dataset` to be prefetched to device.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.
    Returns:
      A "distributed `Dataset`" that the caller can iterate over.
    """
    return super(OneDeviceStrategy, self).experimental_distribute_dataset(
        dataset, options)

  def distribute_datasets_from_function(
      self,
      dataset_fn,  # pylint: disable=useless-super-delegation
      options=None):
    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.

    `dataset_fn` will be called once for each worker in the strategy. In this
    case, we only have one worker and one device so `dataset_fn` is called
    once.

    The `dataset_fn` should take a `tf.distribute.InputContext` instance where
    information about batching and input replication can be accessed:

    ```
    def dataset_fn(input_context):
      batch_size = input_context.get_per_replica_batch_size(global_batch_size)
      d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
      return d.shard(
          input_context.num_input_pipelines, input_context.input_pipeline_id)

    inputs = strategy.distribute_datasets_from_function(dataset_fn)

    for batch in inputs:
      replica_results = strategy.run(replica_fn, args=(batch,))
    ```

    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
    per-replica batch size, unlike `experimental_distribute_dataset`, which
    uses the global batch size. This may be computed using
    `input_context.get_per_replica_batch_size`.

    Args:
      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
        returning a `tf.data.Dataset`.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A "distributed `Dataset`", which the caller can iterate over like regular
      datasets.
    """
    return super(OneDeviceStrategy,
                 self).distribute_datasets_from_function(dataset_fn, options)

  def experimental_local_results(self, value):  # pylint: disable=useless-super-delegation
    """Returns the list of all local per-replica values contained in `value`.

    In `OneDeviceStrategy`, the `value` is always expected to be a single
    value, so the result is just the value in a tuple.

    Args:
      value: A value returned by `experimental_run()`, `run()`,
        `extended.call_for_each_replica()`, or a variable created in `scope`.

    Returns:
      A tuple of values contained in `value`. If `value` represents a single
      value, this returns `(value,)`.
    """
    return super(OneDeviceStrategy, self).experimental_local_results(value)
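
  # Illustrative sketch (added here, not part of the upstream file): with a
  # single device there is exactly one replica, so `experimental_local_results`
  # simply wraps the value returned by `run` in a 1-element tuple. Assumes
  # `import tensorflow as tf` and a host with a CPU device.
  #
  #   strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
  #   result = strategy.run(lambda: tf.constant(3.0))
  #   local = strategy.experimental_local_results(result)
  #   print(local)  # (<tf.Tensor: shape=(), dtype=float32, numpy=3.0>,)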

  def run(self, fn, args=(), kwargs=None, options=None):  # pylint: disable=useless-super-delegation
    """Run `fn` on each replica, with the given arguments.

    In `OneDeviceStrategy`, `fn` is simply called within a device scope for the
    given device, with the provided arguments.

    Args:
      fn: The function to run. The output must be a `tf.nest` of `Tensor`s.
      args: (Optional) Positional arguments to `fn`.
      kwargs: (Optional) Keyword arguments to `fn`.
      options: (Optional) An instance of `tf.distribute.RunOptions` specifying
        the options to run `fn`.

    Returns:
      Return value from running `fn`.
    """
    return super(OneDeviceStrategy, self).run(fn, args, kwargs, options)
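
  # Illustrative sketch (added for this report, not in the upstream file):
  # wrapping the call in `tf.function` stages `fn` as a graph placed on the
  # configured device. Assumes `import tensorflow as tf`.
  #
  #   strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
  #
  #   @tf.function
  #   def train_step(x):
  #     return strategy.run(lambda v: v * 2, args=(x,))
  #
  #   print(train_step(tf.constant(3)))  # tf.Tensor(6, shape=(), dtype=int32)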

  def reduce(self, reduce_op, value, axis):  # pylint: disable=useless-super-delegation
    """Reduce `value` across replicas.

    In `OneDeviceStrategy`, there is only one replica, so if `axis=None`, the
    value is simply returned. If `axis` is specified as something other than
    `None`, such as `axis=0`, the value is reduced along that axis and
    returned.

    Example:
    ```
    t = tf.range(10)

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=None).numpy()
    # result: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    result = strategy.reduce(tf.distribute.ReduceOp.SUM, t, axis=0).numpy()
    # result: 45
    ```

    Args:
      reduce_op: A `tf.distribute.ReduceOp` value specifying how values should
        be combined.
      value: A "per replica" value, e.g. returned by `run` to
        be combined into a single tensor.
      axis: Specifies the dimension to reduce along within each
        replica's tensor. Should typically be set to the batch dimension, or
        `None` to only reduce across replicas (e.g. if the tensor has no batch
        dimension).

    Returns:
      A `Tensor`.
    """
    return super(OneDeviceStrategy, self).reduce(reduce_op, value, axis)

  def scope(self):  # pylint: disable=useless-super-delegation
    """Returns a context manager selecting this Strategy as current.

    Inside a `with strategy.scope():` code block, this thread
    will use a variable creator set by `strategy`, and will
    enter its "cross-replica context".

    In `OneDeviceStrategy`, all variables created inside `strategy.scope()`
    will be on the `device` specified at strategy construction time.
    See example in the docs for this class.

    Returns:
      A context manager to use for creating variables with this strategy.
    """
    return super(OneDeviceStrategy, self).scope()


@tf_export(v1=["distribute.OneDeviceStrategy"])  # pylint: disable=empty-docstring
class OneDeviceStrategyV1(distribute_lib.StrategyV1):

  __doc__ = OneDeviceStrategy.__doc__.replace(
      "For example:\n  ```",
      "For example:\n  ```\n  tf.enable_eager_execution()")

  def __init__(self, device):
    super(OneDeviceStrategyV1, self).__init__(OneDeviceExtended(self, device))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "OneDeviceStrategy")
  __init__.__doc__ = OneDeviceStrategy.__init__.__doc__


# TODO(josh11b): Switch to V2 after callers have been updated to only V2 APIs.
class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of OneDeviceStrategy."""

  def __init__(self, container_strategy, device):
    super(OneDeviceExtended, self).__init__(container_strategy)
    self._device = device_util.resolve(device)
    self._input_device = device_util.get_host_for_device(self._device)

  def _input_workers_with_options(self, options=None):
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers([(self._input_device, (self._device,))])
    else:
      return input_lib.InputWorkers([(self._input_device,
                                      (self._input_device,))])
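
  # Note added for clarity (not in the upstream file): with the default
  # options, the compute device itself is registered as the input destination,
  # so data is prefetched onto `self._device`; passing
  # `tf.distribute.InputOptions(experimental_fetch_to_device=False)` keeps the
  # data on the host input device instead. A hedged usage sketch:
  #
  #   options = tf.distribute.InputOptions(experimental_fetch_to_device=False)
  #   dist_dataset = strategy.experimental_distribute_dataset(dataset, options)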

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _create_variable(self, next_creator, **kwargs):
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      with ops.device(self._device):
        return next_creator(**kwargs)
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      with ops.colocate_with(colocate_with):
        return next_creator(**kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    """Make iterator from dataset without splitting the batch."""
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    return input_lib_v1.DatasetIterator(dataset, self._input_workers,
                                        self._container_strategy())

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    return input_lib_v1.InputFunctionIterator(input_fn, self._input_workers,
                                              [distribute_lib.InputContext()],
                                              self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, numpy_dataset.SingleDevice(self._input_device), session)

  def _broadcast_to(self, tensor, destinations):
    del destinations
    return tensor

  def _experimental_distribute_dataset(self, dataset, options):
    # Note that split_batch_by argument is not passed because it is always 1 in
    # this strategy, and adding it adds unnecessary overhead to the dataset.
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    return input_util.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        options=options)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    return input_util.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options),
        [distribute_lib.InputContext()],
        self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    # TODO(b/137795644): This should return a PerReplica value but other
    # methods like run in OneDeviceStrategy need to be modified
    # to do the same.
    return value_fn(distribute_lib.ValueContext())
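
  # Usage sketch for the public counterpart of this hook (illustrative, not in
  # the upstream file; assumes `import tensorflow as tf`): `value_fn` is called
  # exactly once with a `tf.distribute.experimental.ValueContext` whose replica
  # id is 0.
  #
  #   strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
  #   values = strategy.experimental_distribute_values_from_function(
  #       lambda ctx: tf.constant(ctx.replica_id_in_sync_group))
  #   print(values)  # tf.Tensor(0, shape=(), dtype=int32)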

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, for
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    # TODO(priyag): Use max_iterations instead of an explicit counter.
    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = while_loop.while_loop(
        cond,
        body, [i] + initial_loop_values,
        name="",
        parallel_iterations=1,
        back_prop=False,
        swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _call_for_each_replica(self, fn, args, kwargs):
    strategy = self._container_strategy()
    with ops.device(self._device), _OneDeviceReplicaContext(strategy):
      return fn(*args, **kwargs)

  def _reduce_to(self, reduce_op, value, destinations, options):
    del reduce_op, destinations, options
    return value

  def _gather_to_implementation(self, value, destinations, axis, options):
    del destinations, axis, options
    return value

  def _update(self, var, fn, args, kwargs, group):
    # The implementations of _update() and _update_non_slot() are identical
    # except _update() passes `var` as the first argument to `fn()`.
    return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    del colocate_with
    with ops.device(self._device), distribute_lib.UpdateContext(self._device):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)
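
  # A hedged sketch of how `_update` is typically reached through the public
  # API (added for illustration, not in the upstream file): `extended.update`
  # runs `fn` with the variable as its first argument inside a device scope
  # for `self._device`.
  #
  #   with strategy.scope():
  #     v = tf.Variable(1.0)
  #   strategy.extended.update(v, lambda var, inc: var.assign_add(inc),
  #                            args=(1.0,))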

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    return array_ops.identity(replica_local_var)

  def _local_results(self, value):
    return (value,)

  def value_container(self, value):
    return value

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  @property
  def _num_replicas_in_sync(self):
    return 1

  @property
  def worker_devices(self):
    return (self._device,)

  @property
  def parameter_devices(self):
    return (self._device,)

  def non_slot_devices(self, var_list):
    del var_list
    return (self._device,)

  @property
  def experimental_should_init(self):
    return True

  @property
  def experimental_between_graph(self):
    return False

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """Global and per-replica batching are equivalent for OneDeviceStrategy."""
    return True

  @property
  def _support_per_replica_values(self):
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group


class _OneDeviceReplicaContext(distribute_lib.ReplicaContext):
  """ReplicaContext for OneDeviceStrategy."""

  def __init__(self, strategy):
    distribute_lib.ReplicaContext.__init__(
        self, strategy, replica_id_in_sync_group=0)

  @property
  def devices(self):
    return self._strategy.extended.worker_devices
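

# Illustrative sketch (appended for this report, not part of the upstream
# module): inside `strategy.run`, `tf.distribute.get_replica_context()` returns
# the `_OneDeviceReplicaContext` above, with a single worker device and replica
# id 0. Assumes `import tensorflow as tf`.
#
#   strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
#
#   def fn():
#     ctx = tf.distribute.get_replica_context()
#     return ctx.replica_id_in_sync_group
#
#   print(strategy.run(fn))  # 0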