Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/input_util.py: 38%

16 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utils to create distributed datasets based on TF version.""" 

16 

17from tensorflow.python import tf2 

18from tensorflow.python.distribute import input_lib 

19from tensorflow.python.distribute.v1 import input_lib as input_lib_v1 

20 

21 

def get_distributed_dataset(
    dataset,
    input_workers,
    strategy,
    num_replicas_in_sync=None,
    input_context=None,
    options=None,
    build=True,
    replica_order=None,
):
  """Wraps a `tf.data.Dataset` in the appropriate distributed dataset class.

  Shared entry point used by all strategies. The concrete wrapper type
  depends on whether TF 2 behavior is enabled: TF 2 yields a
  `DistributedDataset`, TF 1 a `DistributedDatasetV1`. The two wrappers
  expose different iteration APIs to callers.

  Args:
    dataset: a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    num_replicas_in_sync: Optional integer. If this is not None, the value is
      used to decide how to rebatch datasets into smaller batches so that the
      total batch size for each step (across all workers and replicas) adds up
      to `dataset`'s batch size.
    input_context: `InputContext` for sharding. Only pass this in for between
      graph multi-worker cases where there is only one `input_worker`. In these
      cases, we will shard based on the `input_pipeline_id` and
      `num_input_pipelines` in the `InputContext`.
    options: Default is None. `tf.distribute.InputOptions` used to control
      options on how this dataset is distributed.
    build: whether to build underlying datasets when a DistributedDataset is
      created. This is only useful for `ParameterServerStrategy` now.
    replica_order: the order of the replicas, which will be used to reorder the
      iterators to match the device order.

  Returns:
    A distributed dataset instance.
  """
  # TF 1 path: the V1 wrapper does not take `build` or `replica_order`.
  if not tf2.enabled():
    return input_lib_v1.DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        options=options)

  return input_lib.DistributedDataset(
      input_workers,
      strategy,
      dataset,
      num_replicas_in_sync=num_replicas_in_sync,
      input_context=input_context,
      build=build,
      options=options,
      replica_order=replica_order,
  )

83 

84 

def get_distributed_datasets_from_function(
    dataset_fn,
    input_workers,
    input_contexts,
    strategy,
    options=None,
    build=True,
    replica_order=None,
):
  """Builds a distributed dataset by invoking a user-supplied dataset function.

  Shared entry point used by all strategies. The concrete wrapper type
  depends on whether TF 2 behavior is enabled: TF 2 yields a
  `DistributedDatasetsFromFunction`, TF 1 a
  `DistributedDatasetsFromFunctionV1`. The two wrappers expose different
  iteration APIs to callers.

  Args:
    dataset_fn: a function that returns a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    input_contexts: A list of `InputContext` instances to be passed to call(s)
      to `dataset_fn`. Length and order should match worker order in
      `worker_device_pairs`.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    options: Default is None. `tf.distribute.InputOptions` used to control
      options on how this dataset is distributed.
    build: whether to build underlying datasets when a
      `DistributedDatasetFromFunction` is created. This is only useful for
      `ParameterServerStrategy` now.
    replica_order: the order of the replicas, which will be used to reorder the
      iterators to match the device order.

  Returns:
    A distributed dataset instance.

  Raises:
    ValueError: if `options.experimental_replication_mode` and
      `options.experimental_place_dataset_on_device` are not consistent
  """
  # Validate option combinations: placing the dataset on device only makes
  # sense with PER_REPLICA replication, and is incompatible with also
  # fetching data to the device.
  if options is not None and options.experimental_place_dataset_on_device:
    per_replica = (options.experimental_replication_mode ==
                   input_lib.InputReplicationMode.PER_REPLICA)
    if not per_replica:
      raise ValueError(
          "When `experimental_place_dataset_on_device` is set for dataset "
          "placement, you must also specify `PER_REPLICA` for the "
          "replication mode")
    if options.experimental_fetch_to_device:
      raise ValueError(
          "`experimental_place_dataset_on_device` can not be set to True "
          "when experimental_fetch_to_device is True and "
          "replication mode is set to `PER_REPLICA`")

  # TF 1 path: the V1 wrapper takes positional args only and does not
  # support `build` or `replica_order`.
  if not tf2.enabled():
    return input_lib_v1.DistributedDatasetsFromFunctionV1(
        input_workers, strategy, input_contexts, dataset_fn, options)

  return input_lib.DistributedDatasetsFromFunction(
      input_workers,
      strategy,
      input_contexts=input_contexts,
      dataset_fn=dataset_fn,
      options=options,
      build=build,
      replica_order=replica_order,
  )