# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Variable class that is replicated to logical cores for model parallelism."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import abc
import contextlib

from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.distribute import tpu_util
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gen_tpu_partition_ops as tpu_partition_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save_context
from tensorflow.python.trackable import base as trackable


def _on_device_update(update_fn, var, value, **kwargs):
  with ops.device(var.device):
    return update_fn(var, value, **kwargs)


class TPUReplicatedVariable(variables_lib.Variable):
  """Container for replicated `Variables` that are treated as a single variable.

  This class maintains a list of replicated variables that are stored on
  separate logical TPU devices. The TF2XLA bridge accesses these variables
  as if they were a single variable.
  """

  def __init__(self, variables, name='TPUReplicatedVariable'):
    """Treats `variables` as a replicated list of `tf.Variable`s.

    Example:

    ```
    variables = [
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
      tf.Variable(..., shape=(10, 100), dtype=tf.float32),
    ]
    replicated_variable = TPUReplicatedVariable(variables)
    assert replicated_variable.shape.as_list() == [10, 100]
    ```

    Args:
      variables: A list of `ResourceVariable`s that comprise this replicated
        variable. Variables should not be shared between different
        `TPUReplicatedVariable` objects.
      name: String. Name of this container. Defaults to
        "TPUReplicatedVariable".
    """
    if not isinstance(variables, abc.Sequence) or not variables or any(
        not isinstance(v, variables_lib.Variable) for v in variables):
      raise TypeError('Argument `variables` should be a non-empty list of '
                      f'`variables.Variable`s. Received {variables}')

    if any(v.dtype != variables[0].dtype for v in variables):
      raise ValueError(
          'All elements in argument `variables` must have the same dtype. '
          f'Received dtypes: {[v.dtype for v in variables]}')

    if any(v.shape != variables[0].shape for v in variables):
      raise ValueError(
          'All elements in argument `variables` must have the same shape. '
          f'Received shapes: {[v.shape for v in variables]}')

    self._vars = variables
    self._name = name
    self._common_name = self._name.split(':')[0]
    self._cached_value = None

  def __iter__(self):
    """Return an iterable for accessing the underlying sharded variables."""
    return iter(self._vars)

  @property
  def name(self):
    """The name of this object. Used for checkpointing."""
    return self._name

  @property
  def dtype(self):
    """The dtype of all `Variable`s in this object."""
    return self._vars[0].dtype

  @property
  def is_initialized(self):
    return self._vars[0].is_initialized

  @property
  def trainable(self):
    return self._vars[0].trainable

  @property
  def device(self):
    """The device this variable is on."""
    return self._vars[0].device

  @contextlib.contextmanager
  def _handle_graph(self):
    with self.handle.graph.as_default():
      yield

  @contextlib.contextmanager
  def _assign_dependencies(self):
    if self._cached_value is not None:
      with ops.control_dependencies([self._cached_value]):
        yield
    else:
      yield

  @property
  def constraint(self):
    return self._vars[0].constraint

  @property
  def _in_graph_mode(self):
    return self._vars[0]._in_graph_mode  # pylint: disable=protected-access

  @property
  def _unique_id(self):
    return self._vars[0]._unique_id  # pylint: disable=protected-access

  @property
  def graph(self):
    return self._vars[0].graph

  @property
  def _shared_name(self):
    return self._common_name

  @property
  def synchronization(self):
    return variable_scope.VariableSynchronization.NONE

  @property
  def aggregation(self):
    return variable_scope.VariableAggregation.NONE

  @property
  def variables(self):
    """The list of `Variables`."""
    if save_context.in_save_context():
      return [self._vars[0]]
    return self._vars

  def _export_to_saved_model_graph(self, object_map, tensor_map,
                                   options, **kwargs):
    """For implementing `Trackable`."""
    first_var = self._vars[0]
    resource_list = first_var._export_to_saved_model_graph(  # pylint:disable=protected-access
        object_map, tensor_map, options, **kwargs)
    for v in self._vars[1:]:
      object_map[v] = object_map[first_var]
      tensor_map[v.handle] = tensor_map[first_var.handle]
      resource_list.append(v.handle)
    object_map[self] = object_map[first_var]
    tensor_map[self] = tensor_map[first_var.handle]
    resource_list.append(self)
    return resource_list

  def _gather_saveables_for_saved_model(self):
    return {trackable.VARIABLE_VALUE_KEY: self._vars[0]}

  @property
  def shape(self):
    return self._vars[0].shape

  @property
  def handle(self):
    if save_context.in_save_context() or context.executing_eagerly():
      return self._vars[0].handle

    if tpu_util.enclosing_tpu_context() is None:
      raise NotImplementedError('TPUReplicatedVariable.handle is not available '
                                'outside tpu context or save context')
    else:
      with tpu_util.outside_or_skip_tpu_context():
        packed_var = getattr(self, '_packed_var', None)

        # Combine the per-core handles into a single handle so the TF2XLA
        # bridge can treat the replicas as one variable.
        # TODO(b/202047549): Enable packed variables with soft device placement
        if packed_var is None or config.get_soft_device_placement():
          tensor = tpu_partition_ops.tpu_partitioned_input_v2(
              [v.handle for v in self._vars],
              partition_dims=[], is_packed=False)
        else:
          tensor = tpu_partition_ops.tpu_partitioned_input_v2(
              [packed_var.packed_handle], partition_dims=[], is_packed=True)

      return xla_sharding.replicate(tensor)

  def _read_variable_op(self):
    return gen_resource_variable_ops.read_variable_op(self.handle, self.dtype)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts a variable to a tensor."""
    # pylint: disable=protected-access
    if tpu_util.enclosing_tpu_context() is None:
      return self.read_value()
    else:
      return self._read_variable_op()

  def read_value(self):
    return self._vars[0].read_value()

  def _update(self, update_fn, value, **kwargs):
    """Converts the value to tensor and updates the variable list."""
    input_tensor = ops.convert_to_tensor(
        value, name='value_in_tensor', dtype=self.dtype)

    return control_flow_ops.group(
        *tuple(
            _on_device_update(update_fn, v, input_tensor, **kwargs)
            for v in self.variables))

  def assign(self, value, use_locking=False, name=None, read_value=True):
    if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
      assign_fn = lambda var, *a, **ka: var.assign(*a, **ka)
      return self._update(
          assign_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)

  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
    if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
      assign_sub_fn = lambda var, *a, **ka: var.assign_sub(*a, **ka)
      return self._update(
          assign_sub_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_sub_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)

  def assign_add(self, value, use_locking=False, name=None, read_value=True):
    if tpu_util.enclosing_tpu_context() is None or context.executing_eagerly():
      assign_add_fn = lambda var, *a, **ka: var.assign_add(*a, **ka)
      return self._update(
          assign_add_fn,
          value=value,
          use_locking=use_locking,
          name=name,
          read_value=read_value)
    else:
      return tpu_util.make_raw_assign_fn(
          gen_resource_variable_ops.assign_add_variable_op)(
              self,
              value=value,
              use_locking=use_locking,
              name=name,
              read_value=read_value)

  def __str__(self):
    debug_str = ',\n'.join(
        '  %d: %s' % (i, v) for i, v in enumerate(self._vars))
    return '%s:{\n%s\n}' % (self.__class__.__name__, debug_str)

  def __repr__(self):
    debug_repr = ',\n'.join(
        '  %d: %r' % (i, v) for i, v in enumerate(self._vars))
    return '%s:{\n%s\n}' % (self.__class__.__name__, debug_repr)


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_tpu_replicated_var(var,
                                          dtype=None,
                                          name=None,
                                          as_ref=False):
  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access


tensor_conversion_registry.register_tensor_conversion_function(
    TPUReplicatedVariable, _tensor_conversion_tpu_replicated_var)
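

# A minimal usage sketch, not part of the original module: it mirrors the
# example in the class docstring above and assumes two logical TPU cores with
# illustrative '/device:TPU:<n>' device strings. The helper name
# `_example_usage` is hypothetical.
def _example_usage():
  """Builds a TPUReplicatedVariable and uses it like a single variable."""
  import tensorflow as tf  # Local import keeps the sketch self-contained.
  replicas = []
  for core in range(2):  # Assumption: two logical cores are available.
    with tf.device('/device:TPU:%d' % core):
      replicas.append(tf.Variable(tf.zeros([10, 100]), name='w'))
  replicated = TPUReplicatedVariable(replicas)
  assert replicated.shape.as_list() == [10, 100]
  assert replicated.dtype == tf.float32
  # Outside a TPU context, assign() applies the update to each replica on its
  # own device and groups the resulting ops.
  replicated.assign(tf.ones([10, 100]))
  # The conversion function registered above lets the container be used
  # wherever a tensor is expected; outside a TPU context this reads the
  # first replica's value.
  return tf.convert_to_tensor(replicated)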