# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

16"""Standard functions for creating slots. 

17 

18A slot is a `Variable` created with the same first m-dimension as a primary 

19variable or `Tensor`. A slot is always scoped in the namespace of the primary 

20object and typically has the same device and type. 

21 

22Slots are typically used as accumulators to track values associated with 

23the primary object: 

24 

25```python 

26# Optimizers can create a slot for each variable to track accumulators 

27accumulators = {var : create_zeros_slot(var, "momentum") for var in vs} 

28for var in vs: 

29 apply_momentum(var, accumulators[var], lr, grad, momentum_tensor) 

30 

31# Slots can also be used for moving averages 

32mavg = create_slot(var, var.initialized_value(), "exponential_moving_avg") 

33update_mavg = mavg.assign_sub((mavg - var) * (1 - decay)) 

34``` 

35""" 

# pylint: disable=g-bad-name

from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import ref_variable
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables


def _create_slot_var(primary,
                     val,
                     scope,
                     validate_shape,
                     shape,
                     dtype,
                     *,
                     copy_xla_sharding=False):
  """Helper function for creating a slot variable."""

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  current_partitioner = variable_scope.get_variable_scope().partitioner
  variable_scope.get_variable_scope().set_partitioner(None)
  # When initializing from a value instead of a callable initializer, the
  # shape is expected to be None, not <unknown> or any fully defined shape.
  shape = shape if callable(val) else None
  if resource_variable_ops.is_resource_variable(primary):
    use_resource = True
  elif isinstance(primary, ref_variable.RefVariable):
    use_resource = False
  else:
    use_resource = None
  slot = variable_scope.get_variable(
      scope,
      initializer=val,
      trainable=False,
      use_resource=use_resource,
      shape=shape,
      dtype=dtype,
      validate_shape=validate_shape)
  variable_scope.get_variable_scope().set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable. Slots have the same partitioning
    # as their primaries.
    # For example, when using AdamOptimizer in a linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weights". We want "Adam" as real_slot_name, so we strip
    # "linear//weights" + "/" from the front and ":0" from the end (see the
    # worked sketch after this function).
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    # Support the slot's shape differing from the primary's shape. Example:
    # if the primary's shape is [10, 20, 30], a slot shape of None, [],
    # [10], [10, 20] or [10, 20, 30] is allowed:
    # - None or [10, 20, 30]: give the slot the same slice_info as the
    #   primary;
    # - []: don't set the slot's slice_info;
    # - [10] or [10, 20]: set the slot's slice_info according to its ndims.
    n = slot.shape.ndims
    if n is None or n > 0:
      slot._set_save_slice_info(
          variables.Variable.SaveSliceInfo(
              slice_info.full_name + "/" + real_slot_name,
              slice_info.full_shape[:n], slice_info.var_offset[:n],
              slice_info.var_shape[:n]))
  # pylint: enable=protected-access

  # Copy XLA sharding attributes from the primary if the slot variable has the
  # same rank as the primary.
  def _has_same_rank(primary_shape, slot_shape):
    return (primary_shape.rank is not None and slot_shape.rank is not None and
            primary_shape.rank == slot_shape.rank)

  if copy_xla_sharding and _has_same_rank(primary.shape, slot.shape):
    slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
  return slot

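
# Worked example of the slot-name parsing above (a minimal sketch with
# hypothetical names, not part of the TensorFlow API): stripping the
# primary's op name plus "/" from the front of slot.name and ":0" from the
# end leaves just the slot's own name.
def _example_real_slot_name():
  primary_op_name = "linear//weights"   # stands in for primary.op.name
  slot_name = "linear//weights/Adam:0"  # stands in for slot.name
  return slot_name[len(primary_op_name + "/"):-2]  # -> "Adam"
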

def create_slot(primary,
                val,
                name,
                colocate_with_primary=True,
                *,
                copy_xla_sharding=False):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean. If True the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True also copies XLA sharding
      from primary.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set primary's name + '/' + name as the default name, so the scope name of
  # the optimizer can be shared when reuse is True. Meanwhile, when reuse is
  # False and the same name has been used before, the scope name will get an
  # '_N' suffix for unique identification.
  validate_shape = val.get_shape().is_fully_defined()
  if isinstance(primary, variables.Variable):
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribute_lib.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(primary):
        return _create_slot_var(
            primary,
            val,
            "",
            validate_shape,
            None,
            None,
            copy_xla_sharding=copy_xla_sharding)
    else:
      return _create_slot_var(
          primary,
          val,
          "",
          validate_shape,
          None,
          None,
          copy_xla_sharding=copy_xla_sharding)

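
# Usage sketch (illustrative, assuming TF1-style graph construction; the
# variable and decay value are hypothetical): an exponential-moving-average
# slot seeded with the primary's initial value, as in the module docstring.
def _example_moving_average_slot(decay=0.999):
  v = variables.Variable([1.0, 2.0], name="v")
  # The slot lives under the scope "v/ema" and shares v's device and dtype.
  mavg = create_slot(v, v.initial_value, "ema")
  return mavg.assign_sub((mavg - v) * (1 - decay))
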

def create_slot_with_initializer(primary,
                                 initializer,
                                 shape,
                                 dtype,
                                 name,
                                 colocate_with_primary=True,
                                 *,
                                 copy_xla_sharding=False):
  """Creates a slot initialized using an `Initializer`.

  The type of the slot is determined by the given `dtype`.

  Args:
    primary: The primary `Variable` or `Tensor`.
    initializer: An `Initializer`. The initial value of the slot.
    shape: Shape of the initial value of the slot.
    dtype: Type of the value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean. If True the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True also copies XLA sharding
      from primary.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set "primary.op.name + '/' + name" as the default name, so the scope name
  # of the optimizer can be shared when reuse is True. Meanwhile, when reuse
  # is False and the same name has been used before, the scope name will get
  # an '_N' suffix for unique identification.
  validate_shape = shape.is_fully_defined()
  if isinstance(primary, variables.Variable):
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribute_lib.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(primary):
        return _create_slot_var(
            primary,
            initializer,
            "",
            validate_shape,
            shape,
            dtype,
            copy_xla_sharding=copy_xla_sharding)
    else:
      return _create_slot_var(
          primary,
          initializer,
          "",
          validate_shape,
          shape,
          dtype,
          copy_xla_sharding=copy_xla_sharding)

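
# Usage sketch (illustrative; names are hypothetical): building a slot from
# an `Initializer` instead of a concrete value, so the initial value is only
# materialized when the slot variable itself is created.
def _example_initializer_slot():
  v = variables.Variable([[1.0, 2.0], [3.0, 4.0]], name="w")
  # A ones-filled accumulator with the same shape and dtype as the primary.
  return create_slot_with_initializer(
      v, init_ops.ones_initializer(), v.get_shape(), v.dtype.base_dtype,
      "ones_acc")
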

def create_zeros_slot(primary,
                      name,
                      dtype=None,
                      colocate_with_primary=True,
                      *,
                      copy_xla_sharding=False):
  """Create a slot initialized to 0 with the same shape as the primary object.

  Args:
    primary: The primary `Variable` or `Tensor`.
    name: Name to use for the slot variable.
    dtype: Type of the slot variable. Defaults to the type of `primary`.
    colocate_with_primary: Boolean. If True the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True also copies XLA sharding
      from primary.

  Returns:
    A `Variable` object.
  """
  if dtype is None:
    dtype = primary.dtype
  slot_shape = primary.get_shape()
  if slot_shape.is_fully_defined():
    initializer = init_ops.zeros_initializer()
    return create_slot_with_initializer(
        primary,
        initializer,
        slot_shape,
        dtype,
        name,
        colocate_with_primary=colocate_with_primary,
        copy_xla_sharding=copy_xla_sharding)
  else:
    # The primary's static shape is not fully defined, so compute its shape
    # at run time and initialize the slot from a dynamically shaped zeros
    # tensor instead of an `Initializer`.
    if isinstance(primary, variables.Variable):
      slot_shape = array_ops.shape(
          control_flow_ops.cond(
              variable_v1.is_variable_initialized(primary), primary.read_value,
              lambda: primary.initial_value))
    else:
      slot_shape = array_ops.shape(primary)
    val = array_ops.zeros(slot_shape, dtype=dtype)
    return create_slot(
        primary,
        val,
        name,
        colocate_with_primary=colocate_with_primary,
        copy_xla_sharding=copy_xla_sharding)
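

# Usage sketch (illustrative; names are hypothetical): the momentum-style
# accumulator pattern from the module docstring, one zeros slot per variable.
def _example_accumulator_slots(var_list):
  # Each slot is a non-trainable zeros variable scoped under its primary,
  # e.g. "v/momentum" for a primary named "v".
  return {v: create_zeros_slot(v, "momentum") for v in var_list}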