Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/init_ops.py: 32%

507 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Operations often used for initializing tensors. 

16 

17All variable initializers returned by functions in this file should have the 

18following signature: 

19 

20def _initializer(shape, dtype=dtypes.float32, partition_info=None): 

21 Args: 

22 shape: List of `int` representing the shape of the output `Tensor`. Some 

23 initializers may also be able to accept a `Tensor`. 

24 dtype: (Optional) Type of the output `Tensor`. 

25 partition_info: (Optional) variable_scope._PartitionInfo object holding 

26 additional information about how the variable is partitioned. May be 

27 `None` if the variable is not partitioned. 

28 

29 Returns: 

30 A `Tensor` of type `dtype` and `shape`. 

31""" 

32import math 

33 

34import numpy as np 

35 

36from tensorflow.python.framework import constant_op 

37from tensorflow.python.framework import dtypes 

38from tensorflow.python.framework import tensor_shape 

39from tensorflow.python.ops import array_ops 

40from tensorflow.python.ops import array_ops_stack 

41from tensorflow.python.ops import gen_linalg_ops 

42from tensorflow.python.ops import linalg_ops_impl 

43from tensorflow.python.ops import math_ops 

44from tensorflow.python.ops import random_ops 

45from tensorflow.python.util import deprecation 

46from tensorflow.python.util.deprecation import deprecated 

47from tensorflow.python.util.deprecation import deprecated_arg_values 

48from tensorflow.python.util.deprecation import deprecated_args 

49from tensorflow.python.util.tf_export import tf_export 

50 

51 

52class Initializer: 

53 """Initializer base class: all initializers inherit from this class.""" 

54 

55 def __call__(self, shape, dtype=None, partition_info=None): 

56 """Returns a tensor object initialized as specified by the initializer. 

57 

58 Args: 

59 shape: Shape of the tensor. 

60 dtype: Optional dtype of the tensor. If not provided use the initializer 

61 dtype. 

62 partition_info: Optional information about the possible partitioning of a 

63 tensor. 

64 """ 

65 raise NotImplementedError 

66 

67 def get_config(self): 

68 """Returns the configuration of the initializer as a JSON-serializable dict. 

69 

70 Returns: 

71 A JSON-serializable Python dict. 

72 """ 

73 return {} 

74 

75 @classmethod 

76 def from_config(cls, config): 

77 """Instantiates an initializer from a configuration dictionary. 

78 

79 Example: 

80 

81 ```python 

82 initializer = RandomUniform(-1, 1) 

83 config = initializer.get_config() 

84 initializer = RandomUniform.from_config(config) 

85 ``` 

86 

87 Args: 

88 config: A Python dictionary. It will typically be the output of 

89 `get_config`. 

90 

91 Returns: 

92 An Initializer instance. 

93 """ 

94 return cls(**config) 

95 

96 

97@tf_export(v1=["initializers.zeros", "zeros_initializer"]) 

98@deprecation.deprecated_endpoints("initializers.zeros") 

99class Zeros(Initializer): 

100 """Initializer that generates tensors initialized to 0. 

101 

102 @compatibility(TF2) 

103 `tf.compat.v1.zeros_initializer` is compatible with eager execution 

104 and `tf.function`. 

105 

106 To migrate to TF2, please use `tf.zeros_initializer` instead. The `dtype` 

107 argument in `tf.compat.v1.zeros_initializer.__init__()` does not exist in 

108 `tf.zeros_initializer.__init__()`. However, you can specify the `dtype` in 

109 `__call__()` in both cases. 

110 

111 #### Structural Mapping to TF2 

112 

113 Before: 

114 

115 ```python 

116 initializer = tf.compat.v1.zeros_initializer(dtype=tf.float32) 

117 variable = tf.Variable(initializer(shape=[3, 3])) 

118 ``` 

119 

120 After: 

121 

122 ```python 

123 initializer = tf.zeros_initializer() 

124 variable = tf.Variable(initializer(shape=[3, 3], dtype=tf.float32)) 

125 ``` 

126 

127 #### How to Map Arguments 

128 

129 | TF1 Arg Name | TF2 Arg Name | Note | 

130 | :------------------- | :--------------- | :------------------------- | 

131 | `dtype` | `dtype` | In `__call__()` method | 

132 | `partition_info` | - | (`__call__` arg in TF1) Not supported | 

133 

134 

135 #### Before & After Usage Example 

136 

137 Before: 

138 

139 >>> initializer = tf.compat.v1.zeros_initializer(dtype=tf.float32) 

140 >>> tf.Variable(initializer(shape=[3])).numpy() 

141 array([0., 0., 0.], dtype=float32) 

142 >>> tf.Variable(initializer(shape=[3, 3])).numpy() 

143 array([[0., 0., 0.], 

144 [0., 0., 0.], 

145 [0., 0., 0.]], dtype=float32) 

146 >>> initializer = tf.compat.v1.zeros_initializer() 

147 >>> tf.Variable(initializer(shape=[3], dtype=tf.float32)).numpy() 

148 array([0., 0., 0.], dtype=float32) 

149 >>> tf.Variable(initializer(shape=[3, 3], dtype=tf.float32)).numpy() 

150 array([[0., 0., 0.], 

151 [0., 0., 0.], 

152 [0., 0., 0.]], dtype=float32) 

153 

154 After: 

155 

156 >>> initializer = tf.zeros_initializer() 

157 >>> tf.Variable(initializer(shape=[3], dtype=tf.float32)).numpy() 

158 array([0., 0., 0.], dtype=float32) 

159 >>> tf.Variable(initializer(shape=[3, 3], dtype=tf.float32)).numpy() 

160 array([[0., 0., 0.], 

161 [0., 0., 0.], 

162 [0., 0., 0.]], dtype=float32) 

163 

164 @end_compatibility 

165 """ 

166 

167 @deprecated_args(None, 

168 "Call initializer instance with the dtype argument instead " 

169 "of passing it to the constructor", "dtype") 

170 def __init__(self, dtype=dtypes.float32): 

171 self.dtype = dtypes.as_dtype(dtype) 

172 

173 def __call__(self, shape, dtype=None, partition_info=None): 

174 if dtype is None: 

175 dtype = self.dtype 

176 return array_ops.zeros(shape, dtype) 

177 

178 def get_config(self): 

179 return {"dtype": self.dtype.name} 

180 

181 

182@tf_export(v1=["initializers.ones", "ones_initializer"]) 

183@deprecation.deprecated_endpoints("initializers.ones", "ones_initializer") 

184class Ones(Initializer): 

185 """Initializer that generates tensors initialized to 1. 

186 

187 @compatibility(TF2) 

188 This API is compatible with TF2 behavior and `tf.function`, and can be 

189 migrated immediately with `tf.keras.initializers.ones`. 

190 

191 Before: 

192 >>> initializer = tf.compat.v1.keras.initializers.ones() 

193 >>> initializer((1, 1)) 

194 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.]], dtype=float32)> 

195 

196 After: 

197 >>> initializer = tf.keras.initializers.ones() 

198 >>> initializer((1, 1)) 

199 <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.]], dtype=float32)> 

200 

201 @end_compatibility 

202 """ 

203 

204 @deprecated_args(None, 

205 "Call initializer instance with the dtype argument instead " 

206 "of passing it to the constructor", "dtype") 

207 def __init__(self, dtype=dtypes.float32): 

208 self.dtype = dtypes.as_dtype(dtype) 

209 

210 def __call__(self, shape, dtype=None, partition_info=None): 

211 if dtype is None: 

212 dtype = self.dtype 

213 return array_ops.ones(shape, dtype) 

214 

215 def get_config(self): 

216 return {"dtype": self.dtype.name} 

217 

218 

219@tf_export(v1=["initializers.constant", "constant_initializer"]) 

220@deprecation.deprecated_endpoints("constant_initializer") 

221class Constant(Initializer): 

222 """Initializer that generates tensors with constant values. 

223 

224 The resulting tensor is populated with values of type `dtype`, as 

225 specified by arguments `value` following the desired `shape` of the 

226 new tensor (see examples below). 

227 

228 The argument `value` can be a constant value, or a list of values of type 

229 `dtype`. If `value` is a list, then the length of the list must be less 

230 than or equal to the number of elements implied by the desired shape of the 

231 tensor. In the case where the total number of elements in `value` is less 

232 than the number of elements required by the tensor shape, the last element 

233 in `value` will be used to fill the remaining entries. If the total number of 

234 elements in `value` is greater than the number of elements required by the 

235 tensor shape, the initializer will raise a `ValueError`. 

236 

237 Args: 

238 value: A Python scalar, list or tuple of values, or a N-dimensional numpy 

239 array. All elements of the initialized variable will be set to the 

240 corresponding value in the `value` argument. 

241 dtype: Default data type, used if no `dtype` argument is provided when 

242 calling the initializer. 

243 verify_shape: Boolean that enables verification of the shape of `value`. If 

244 `True`, the initializer will throw an error if the shape of `value` is not 

245 compatible with the shape of the initialized tensor. 

246 

247 Raises: 

248 TypeError: If the input `value` is not one of the expected types. 

249 

250 Examples: 

251 The following example can be rewritten using a numpy.ndarray instead 

252 of the `value` list, even reshaped, as shown in the two commented lines 

253 below the `value` list initialization. 

254 

255 >>> value = [0, 1, 2, 3, 4, 5, 6, 7] 

256 >>> init = tf.compat.v1.constant_initializer(value) 

257 >>> # fitting shape 

258 >>> with tf.compat.v1.Session(): 

259 ... x = tf.compat.v1.get_variable('x', shape=[2, 4], initializer=init) 

260 ... x.initializer.run() 

261 ... print(x.eval()) 

262 [[0. 1. 2. 3.] 

263 [4. 5. 6. 7.]] 

264 >>> # Larger shape 

265 >>> with tf.compat.v1.Session(): 

266 ... y = tf.compat.v1.get_variable('y', shape=[3, 4], initializer=init) 

267 ... y.initializer.run() 

268 ... print(y.eval()) 

269 [[0. 1. 2. 3.] 

270 [4. 5. 6. 7.] 

271 [7. 7. 7. 7.]] 

272 >>> # Smaller shape 

273 >>> with tf.compat.v1.Session(): 

274 ... z = tf.compat.v1.get_variable('z', shape=[2, 3], initializer=init) 

275 Traceback (most recent call last): 

276 ... 

277 ValueError: Too many elements provided. Needed at most 6, but received 8 

278 >>> # Shape verification 

279 >>> init_verify = tf.compat.v1.constant_initializer(value, verify_shape=True) 

280 >>> with tf.compat.v1.Session(): 

281 ... u = tf.compat.v1.get_variable('u', shape=[3, 4], 

282 ... initializer=init_verify) 

283 Traceback (most recent call last): 

284 ... 

285 TypeError: Expected Tensor's shape: (3, 4), got (8,). 

286 

287 @compatibility(TF2) 

288 Although it is a legacy API endpoint, `tf.compat.v1.constant_initializer` 

289 is compatible with eager execution and `tf.function`. 

290 

291 To migrate to a non-legacy TF2 API, please use `tf.constant_initializer` 

292 instead. The `dtype` 

293 argument in `tf.compat.v1.constant_initializer.__init__()` does not exist in 

294 `tf.constant_initializer.__init__()`. However, you can specify the `dtype` in 

295 `__call__()` in both cases. 

296 

297 In the `compat.v1` symbol, if `verify_shape` is set to `True`, an exception 

298 is raised when initializing a variable with a different shape from 

299 `value`. If set to `False`, `value` is reshaped to initialize the variable 

300 if necessary. An exception would only be raised when the number of 

301 elements are different. 

302 

303 The `verify_shape` argument is not supported in TF2. Using 

304 `tf.constant_initializer` is equivalent to setting `verify_shape` to `False`. 

305 

306 #### Structural Mapping to TF2 

307 

308 Before: 

309 

310 ```python 

311 value = [0, 1, 2, 3, 4, 5, 6, 7] 

312 initializer = tf.compat.v1.constant_initializer( 

313 value=value, 

314 dtype=tf.float32, 

315 verify_shape=False) 

316 variable = tf.Variable(initializer(shape=[2, 4])) 

317 ``` 

318 

319 After: 

320 

321 ```python 

322 value = [0, 1, 2, 3, 4, 5, 6, 7] 

323 initializer = tf.constant_initializer(value=value) 

324 tf.Variable(initializer(shape=[2, 4], dtype=tf.float32)) 

325 ``` 

326 

327 #### How to Map Arguments 

328 

329 | TF1 Arg Name | TF2 Arg Name | Note | 

330 | :-------------------- | :--------------- | :-------------------------- | 

331 | `value` | `value` | In constructor | 

332 | `dtype` | `dtype` | In `__call__()` method | 

333 | `verify_shape` | Not Supported | Equivalent to set to `False`| 

334 | `partition_info` | - | (`__call__` arg in TF1) Not supported | 

335 

336 

337 #### Before & After Usage Example 

338 

339 Before: 

340 

341 >>> value = [1., 2., 3., 4.] 

342 >>> initializer = tf.compat.v1.constant_initializer( 

343 ... value=value, dtype=tf.float32, verify_shape=True) 

344 >>> tf.Variable(initializer(shape=[2, 2])).numpy() 

345 Traceback (most recent call last): 

346 ... 

347 TypeError: Expected Tensor's shape: (2, 2), got (4,). 

348 >>> initializer = tf.compat.v1.constant_initializer( 

349 ... value=value, dtype=tf.float32, verify_shape=False) 

350 >>> tf.Variable(initializer(shape=[2, 2])).numpy() 

351 array([[1., 2.], 

352 [3., 4.]], dtype=float32) 

353 

354 After: 

355 

356 >>> value = [1., 2., 3., 4.] 

357 >>> initializer = tf.constant_initializer(value=value) 

358 >>> tf.Variable(initializer(shape=[2, 2], dtype=tf.float32)).numpy() 

359 array([[1., 2.], 

360 [3., 4.]], dtype=float32) 

361 

362 @end_compatibility 

363 """ 

364 

365 @deprecated_args(None, 

366 "Call initializer instance with the dtype argument instead " 

367 "of passing it to the constructor", "dtype") 

368 @deprecated_args(None, "Objects must now be the required shape or no shape " 

369 "can be specified", "verify_shape") 

370 def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False): 

371 if not (np.isscalar(value) or isinstance(value, (list, tuple, np.ndarray))): 

372 raise TypeError( 

373 f"Invalid type for initial value={value} of type: " 

374 f"{type(value).__name__}. Expected Python scalar, list or tuple of " 

375 "values, or numpy.ndarray.") 

376 

377 self.value = value 

378 self.dtype = dtypes.as_dtype(dtype) 

379 self._verify_shape = verify_shape 

380 

381 def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None): 

382 if dtype is None: 

383 dtype = self.dtype 

384 if verify_shape is None: 

385 verify_shape = self._verify_shape 

386 return constant_op.constant_v1( 

387 self.value, dtype=dtype, shape=shape, verify_shape=verify_shape) 

388 

389 def get_config(self): 

390 # We don't include `verify_shape` for compatibility with Keras. 

391 # `verify_shape` should be passed as an argument to `__call__` rather 

392 # than as a constructor argument: conceptually it isn't a property 

393 # of the initializer. 

394 return {"value": self.value, "dtype": self.dtype.name} 

395 

396 

397@tf_export(v1=["initializers.random_uniform", "random_uniform_initializer"]) 

398@deprecation.deprecated_endpoints("initializers.random_uniform") 

399class RandomUniform(Initializer): 

400 """Initializer that generates tensors with a uniform distribution. 

401 

402 Args: 

403 minval: A python scalar or a scalar tensor. Lower bound of the range of 

404 random values to generate. 

405 maxval: A python scalar or a scalar tensor. Upper bound of the range of 

406 random values to generate. Defaults to 1 for float types. 

407 seed: A Python integer. Used to create random seeds. See 

408 `tf.compat.v1.set_random_seed` for behavior. 

409 dtype: Default data type, used if no `dtype` argument is provided when 

410 calling the initializer. 

411 

412 @compatibility(TF2) 

413 Although it is a legacy compat.v1 API, this symbol is compatible with eager 

414 execution and `tf.function`. 

415 

416 To switch to TF2, switch to using either 

417 `tf.initializers.RandomUniform` or `tf.keras.initializers.RandomUniform` 

418 (neither from `compat.v1`) and 

419 pass the dtype when calling the initializer. Keep in mind that 

420 the default minval, maxval and the behavior of fixed seeds have changed. 

421 

422 #### Structural Mapping to TF2 

423 

424 Before: 

425 

426 ```python 

427 initializer = tf.compat.v1.random_uniform_initializer( 

428 minval=minval, 

429 maxval=maxval, 

430 seed=seed, 

431 dtype=dtype) 

432 

433 weight_one = tf.Variable(initializer(shape_one)) 

434 weight_two = tf.Variable(initializer(shape_two)) 

435 ``` 

436 

437 After: 

438 

439 ```python 

440 initializer = tf.initializers.RandomUniform( 

441 minval=minval, 

442 maxval=maxval, 

443 seed=seed) 

444 

445 weight_one = tf.Variable(initializer(shape_one, dtype=dtype)) 

446 weight_two = tf.Variable(initializer(shape_two, dtype=dtype)) 

447 ``` 

448 

449 #### How to Map Arguments 

450 

451 | TF1 Arg Name | TF2 Arg Name | Note | 

452 | :-------------------- | :-------------- | :------------------------- | 

453 | `minval` | `minval` | Default changes from 0 to -0.05 | 

454 | `maxval` | `maxval` | Default changes from 1.0 to 0.05 | 

455 | `seed` | `seed` | | 

456 | `dtype` | `dtype` | The TF2 native api only takes it | 

457 : : : as a `__call__` arg, not a constructor arg. : 

458 | `partition_info` | - | (`__call__` arg in TF1) Not supported | 

459 

460 @end_compatibility 

461 """ 

462 

463 @deprecated_args(None, 

464 "Call initializer instance with the dtype argument instead " 

465 "of passing it to the constructor", "dtype") 

466 def __init__(self, minval=.0, maxval=None, seed=None, dtype=dtypes.float32): 

467 self.minval = minval 

468 self.maxval = maxval 

469 self.seed = seed 

470 self.dtype = dtypes.as_dtype(dtype) 

471 

472 def __call__(self, shape, dtype=None, partition_info=None): 

473 if dtype is None: 

474 dtype = self.dtype 

475 return random_ops.random_uniform( 

476 shape, self.minval, self.maxval, dtype, seed=self.seed) 

477 

478 def get_config(self): 

479 return { 

480 "minval": self.minval, 

481 "maxval": self.maxval, 

482 "seed": self.seed, 

483 "dtype": self.dtype.name 

484 } 

485 

486 

487@tf_export(v1=["initializers.random_normal", "random_normal_initializer"]) 

488@deprecation.deprecated_endpoints("initializers.random_normal") 

489class RandomNormal(Initializer): 

490 """Initializer that generates tensors with a normal distribution. 

491 

492 Args: 

493 mean: a python scalar or a scalar tensor. Mean of the random values to 

494 generate. 

495 stddev: a python scalar or a scalar tensor. Standard deviation of the random 

496 values to generate. 

497 seed: A Python integer. Used to create random seeds. See 

498 `tf.compat.v1.set_random_seed` for behavior. 

499 dtype: Default data type, used if no `dtype` argument is provided when 

500 calling the initializer. Only floating point types are supported. 

501 

502 @compatibility(TF2) 

503 Although it is a legacy `compat.v1` API, this symbol is compatible with eager 

504 execution and `tf.function`. 

505 

506 To switch to TF2, switch to using either 

507 `tf.initializers.RandomNormal` or `tf.keras.initializers.RandomNormal` 

508 (neither from `compat.v1`) and 

509 pass the dtype when calling the initializer. Keep in mind that 

510 the default stddev and the behavior of fixed seeds have changed. 

511 

512 #### Structural Mapping to TF2 

513 

514 Before: 

515 

516 ```python 

517 initializer = tf.compat.v1.random_normal_initializer( 

518 mean=mean, 

519 stddev=stddev, 

520 seed=seed, 

521 dtype=dtype) 

522 

523 weight_one = tf.Variable(initializer(shape_one)) 

524 weight_two = tf.Variable(initializer(shape_two)) 

525 ``` 

526 

527 After: 

528 

529 ```python 

530 initializer = tf.initializers.RandomNormal( 

531 mean=mean, 

532 seed=seed, 

533 stddev=stddev) 

534 

535 weight_one = tf.Variable(initializer(shape_one, dtype=dtype)) 

536 weight_two = tf.Variable(initializer(shape_two, dtype=dtype)) 

537 ``` 

538 

539 #### How to Map Arguments 

540 

541 | TF1 Arg Name | TF2 Arg Name | Note | 

542 | :----------------- | :-------------- | :------------------------- | 

543 | `mean` | `mean` | No change to defaults | 

544 | `stddev` | `stddev` | Default changes from 1.0 to 0.05 | 

545 | `seed` | `seed` | | 

546 | `dtype` | `dtype` | The TF2 native api only takes it as a | 

547 : : : `__call__` arg, not a constructor arg. : 

548 | `partition_info` | - | (`__call__` arg in TF1) Not supported. | 

549 

550 @end_compatibility 

551 """ 

552 

553 @deprecated_args(None, 

554 "Call initializer instance with the dtype argument instead " 

555 "of passing it to the constructor", "dtype") 

556 def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32): 

557 self.mean = mean 

558 self.stddev = stddev 

559 self.seed = seed 

560 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 

561 

562 def __call__(self, shape, dtype=None, partition_info=None): 

563 if dtype is None: 

564 dtype = self.dtype 

565 return random_ops.random_normal( 

566 shape, self.mean, self.stddev, dtype, seed=self.seed) 

567 

568 def get_config(self): 

569 return { 

570 "mean": self.mean, 

571 "stddev": self.stddev, 

572 "seed": self.seed, 

573 "dtype": self.dtype.name 

574 } 

575 

576 

577@tf_export(v1=["initializers.truncated_normal", "truncated_normal_initializer"]) 

578@deprecation.deprecated_endpoints("initializers.truncated_normal", 

579 "truncated_normal_initializer") 

580class TruncatedNormal(Initializer): 

581 """Initializer that generates a truncated normal distribution. 

582 

583 These values are similar to values from a `random_normal_initializer` 

584 except that values more than two standard deviations from the mean 

585 are discarded and re-drawn. This is the recommended initializer for 

586 neural network weights and filters. 

587 

588 Args: 

589 mean: a python scalar or a scalar tensor. Mean of the random values to 

590 generate. 

591 stddev: a python scalar or a scalar tensor. Standard deviation of the random 

592 values to generate. 

593 seed: A Python integer. Used to create random seeds. See 

594 `tf.compat.v1.set_random_seed` for behavior. 

595 dtype: Default data type, used if no `dtype` argument is provided when 

596 calling the initializer. Only floating point types are supported. 

597 

598 @compatibility(TF2) 

599 Although it is a legacy `compat.v1` API, this symbol is compatible with eager 

600 execution and `tf.function`. 

601 

602 To switch to TF2, switch to using either 

603 `tf.initializers.truncated_normal` or `tf.keras.initializers.TruncatedNormal` 

604 (neither from `compat.v1`) and 

605 pass the dtype when calling the initializer. Keep in mind that 

606 the default stddev and the behavior of fixed seeds have changed. 

607 

608 #### Structural Mapping to TF2 

609 

610 Before: 

611 

612 ```python 

613 initializer = tf.compat.v1.truncated_normal_initializer( 

614 mean=mean, 

615 stddev=stddev, 

616 seed=seed, 

617 dtype=dtype) 

618 

619 weight_one = tf.Variable(initializer(shape_one)) 

620 weight_two = tf.Variable(initializer(shape_two)) 

621 ``` 

622 

623 After: 

624 

625 ```python 

626 initializer = tf.initializers.truncated_normal( 

627 mean=mean, 

628 seed=seed, 

629 stddev=stddev) 

630 

631 weight_one = tf.Variable(initializer(shape_one, dtype=dtype)) 

632 weight_two = tf.Variable(initializer(shape_two, dtype=dtype)) 

633 ``` 

634 

635 #### How to Map Arguments 

636 

637 | TF1 Arg Name | TF2 Arg Name | Note | 

638 | :-------------------- | :-------------- | :------------------------- | 

639 | `mean` | `mean` | No change to defaults | 

640 | `stddev` | `stddev` | Default changes from 1.0 to 0.05 | 

641 | `seed` | `seed` | | 

642 | `dtype` | `dtype` | The TF2 native api only takes it | 

643 : : : as a `__call__` arg, not a constructor arg. : 

644 | `partition_info` | - | (`__call__` arg in TF1) Not supported | 

645 

646 @end_compatibility 

647 """ 

648 

649 @deprecated_args(None, 

650 "Call initializer instance with the dtype argument instead " 

651 "of passing it to the constructor", "dtype") 

652 def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32): 

653 self.mean = mean 

654 self.stddev = stddev 

655 self.seed = seed 

656 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 

657 

658 def __call__(self, shape, dtype=None, partition_info=None): 

659 if dtype is None: 

660 dtype = self.dtype 

661 return random_ops.truncated_normal( 

662 shape, self.mean, self.stddev, dtype, seed=self.seed) 

663 

664 def get_config(self): 

665 return { 

666 "mean": self.mean, 

667 "stddev": self.stddev, 

668 "seed": self.seed, 

669 "dtype": self.dtype.name 

670 } 

671 

672 

673@tf_export(v1=[ 

674 "initializers.uniform_unit_scaling", "uniform_unit_scaling_initializer" 

675]) 

676@deprecation.deprecated_endpoints("uniform_unit_scaling_initializer", 

677 "initializers.uniform_unit_scaling") 

678class UniformUnitScaling(Initializer): 

679 """Initializer that generates tensors without scaling variance. 

680 

681 When initializing a deep network, it is in principle advantageous to keep 

682 the scale of the input variance constant, so it does not explode or diminish 

683 by reaching the final layer. If the input is `x` and the operation `x * W`, 

684 and we want to initialize `W` uniformly at random, we need to pick `W` from 

685 

686 [-sqrt(3) / sqrt(dim), sqrt(3) / sqrt(dim)] 

687 

688 to keep the scale intact, where `dim = W.shape[0]` (the size of the input). 

689 A similar calculation for convolutional networks gives an analogous result 

690 with `dim` equal to the product of the first 3 dimensions. When 

691 nonlinearities are present, we need to multiply this by a constant `factor`. 

692 See (Sussillo et al., 2014) for deeper motivation, experiments 

693 and the calculation of constants. In section 2.3 there, the constants were 

694 numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15. 

695 

696 Args: 

697 factor: Float. A multiplicative factor by which the values will be scaled. 

698 seed: A Python integer. Used to create random seeds. See 

699 `tf.compat.v1.set_random_seed` for behavior. 

700 dtype: Default data type, used if no `dtype` argument is provided when 

701 calling the initializer. Only floating point types are supported. 

702 References: 

703 [Sussillo et al., 2014](https://arxiv.org/abs/1412.6558) 

704 ([pdf](http://arxiv.org/pdf/1412.6558.pdf)) 

705 """ 

706 

707 @deprecated_args(None, 

708 "Call initializer instance with the dtype argument instead " 

709 "of passing it to the constructor", "dtype") 

710 @deprecated(None, 

711 "Use tf.initializers.variance_scaling instead with distribution=" 

712 "uniform to get equivalent behavior.") 

713 def __init__(self, factor=1.0, seed=None, dtype=dtypes.float32): 

714 self.factor = factor 

715 self.seed = seed 

716 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 

717 

718 def __call__(self, shape, dtype=None, partition_info=None): 

719 if dtype is None: 

720 dtype = self.dtype 

721 scale_shape = shape 

722 if partition_info is not None: 

723 scale_shape = partition_info.full_shape 

724 

725 input_size = 1.0 

726 # Estimating input size is not possible to do perfectly, but we try. 

727 # The estimate, obtained by multiplying all dimensions but the last one, 

728 # is the right thing for matrix multiply and convolutions (see above). 

729 for dim in scale_shape[:-1]: 

730 input_size *= float(dim) 

731 # Avoid errors when initializing zero-size tensors. 

732 input_size = max(input_size, 1.0) 

733 max_val = math.sqrt(3 / input_size) * self.factor 

734 return random_ops.random_uniform( 

735 shape, -max_val, max_val, dtype, seed=self.seed) 

736 

737 def get_config(self): 

738 return {"factor": self.factor, "seed": self.seed, "dtype": self.dtype.name} 

739 

740 

741@tf_export(v1=["initializers.variance_scaling", "variance_scaling_initializer"]) 

742@deprecation.deprecated_endpoints("initializers.variance_scaling", 

743 "variance_scaling_initializer") 

744class VarianceScaling(Initializer): 

745 """Initializer capable of adapting its scale to the shape of weights tensors. 

746 

747 @compatibility(TF2) 

748 Although it is a legacy `compat.v1` API, this symbol is compatible with eager 

749 execution and `tf.function`. 

750 

751 To switch to TF2 APIs, move to using either 

752 `tf.initializers.variance_scaling` or `tf.keras.initializers.VarianceScaling` 

753 (neither from `compat.v1`) and 

754 pass the dtype when calling the initializer. 

755 

756 #### Structural Mapping to TF2 

757 

758 Before: 

759 

760 ```python 

761 initializer = tf.compat.v1.variance_scaling_initializer( 

762 scale=scale, 

763 mode=mode, 

764 distribution=distribution 

765 seed=seed, 

766 dtype=dtype) 

767 

768 weight_one = tf.Variable(initializer(shape_one)) 

769 weight_two = tf.Variable(initializer(shape_two)) 

770 ``` 

771 

772 After: 

773 

774 ```python 

775 initializer = tf.keras.initializers.VarianceScaling( 

776 scale=scale, 

777 mode=mode, 

778 distribution=distribution 

779 seed=seed) 

780 

781 weight_one = tf.Variable(initializer(shape_one, dtype=dtype)) 

782 weight_two = tf.Variable(initializer(shape_two, dtype=dtype)) 

783 ``` 

784 

785 #### How to Map Arguments 

786 

787 | TF1 Arg Name | TF2 Arg Name | Note | 

788 | :----------------- | :-------------- | :------------------------- | 

789 | `scale` | `scale` | No change to defaults | 

790 | `mode` | `mode` | No change to defaults | 

791 | `distribution` | `distribution` | No change to defaults. | 

792 : : : 'normal' maps to 'truncated_normal' : 

793 | `seed` | `seed` | | 

794 | `dtype` | `dtype` | The TF2 api only takes it | 

795 : : : as a `__call__` arg, not a constructor arg. : 

796 | `partition_info` | - | (`__call__` arg in TF1) Not supported | 

797 

798 @end_compatibility 

799 

800 With `distribution="truncated_normal" or "untruncated_normal"`, 

801 samples are drawn from a truncated/untruncated normal 

802 distribution with a mean of zero and a standard deviation (after truncation, 

803 if used) `stddev = sqrt(scale / n)` 

804 where n is: 

805 - number of input units in the weight tensor, if mode = "fan_in" 

806 - number of output units, if mode = "fan_out" 

807 - average of the numbers of input and output units, if mode = "fan_avg" 

808 

809 With `distribution="uniform"`, samples are drawn from a uniform distribution 

810 within [-limit, limit], with `limit = sqrt(3 * scale / n)`. 

811 

812 Args: 

813 scale: Scaling factor (positive float). 

814 mode: One of "fan_in", "fan_out", "fan_avg". 

815 distribution: Random distribution to use. One of "normal", "uniform". 

816 seed: A Python integer. Used to create random seeds. See 

817 `tf.compat.v1.set_random_seed` for behavior. 

818 dtype: Default data type, used if no `dtype` argument is provided when 

819 calling the initializer. Only floating point types are supported. 

820 

821 Raises: 

822 ValueError: In case of an invalid value for the "scale", mode" or 

823 "distribution" arguments. 

824 """ 

825 

826 @deprecated_args(None, 

827 "Call initializer instance with the dtype argument instead " 

828 "of passing it to the constructor", "dtype") 

829 @deprecated_arg_values( 

830 None, 

831 "`normal` is a deprecated alias for `truncated_normal`", 

832 distribution="normal") 

833 def __init__(self, 

834 scale=1.0, 

835 mode="fan_in", 

836 distribution="truncated_normal", 

837 seed=None, 

838 dtype=dtypes.float32): 

839 if scale <= 0.: 

840 raise ValueError("Argument `scale` must be a positive float. Received: " 

841 f"{scale}") 

842 if mode not in {"fan_in", "fan_out", "fan_avg"}: 

843 raise ValueError("Argument `mode` should be one of ('fan_in', 'fan_out', " 

844 f"'fan_avg'). Received: {mode}") 

845 distribution = distribution.lower() 

846 if distribution not in { 

847 "normal", "uniform", "truncated_normal", "untruncated_normal" 

848 }: 

849 raise ValueError("Argument `distribution` should be one of ('normal', " 

850 "uniform', 'truncated_normal', 'untruncated_normal'). " 

851 f"Received: {distribution}") 

852 self.scale = scale 

853 self.mode = mode 

854 self.distribution = distribution 

855 self.seed = seed 

856 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 

857 

858 def __call__(self, shape, dtype=None, partition_info=None): 

859 if dtype is None: 

860 dtype = self.dtype 

861 scale = self.scale 

862 scale_shape = shape 

863 if partition_info is not None: 

864 scale_shape = partition_info.full_shape 

865 fan_in, fan_out = _compute_fans(scale_shape) 

866 if self.mode == "fan_in": 

867 scale /= max(1., fan_in) 

868 elif self.mode == "fan_out": 

869 scale /= max(1., fan_out) 

870 else: 

871 scale /= max(1., (fan_in + fan_out) / 2.) 

872 if self.distribution == "normal" or self.distribution == "truncated_normal": 

873 # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) 

874 stddev = math.sqrt(scale) / .87962566103423978 

875 return random_ops.truncated_normal( 

876 shape, 0.0, stddev, dtype, seed=self.seed) 

877 elif self.distribution == "untruncated_normal": 

878 stddev = math.sqrt(scale) 

879 return random_ops.random_normal(shape, 0.0, stddev, dtype, seed=self.seed) 

880 else: 

881 limit = math.sqrt(3.0 * scale) 

882 return random_ops.random_uniform( 

883 shape, -limit, limit, dtype, seed=self.seed) 

884 

885 def get_config(self): 

886 return { 

887 "scale": self.scale, 

888 "mode": self.mode, 

889 "distribution": self.distribution, 

890 "seed": self.seed, 

891 "dtype": self.dtype.name 

892 } 

893 

894 

895@tf_export(v1=["initializers.orthogonal", "orthogonal_initializer"]) 

896@deprecation.deprecated_endpoints("initializers.orthogonal", 

897 "orthogonal_initializer") 

898class Orthogonal(Initializer): 

899 """Initializer that generates an orthogonal matrix. 

900 

901 If the shape of the tensor to initialize is two-dimensional, it is initialized 

902 with an orthogonal matrix obtained from the QR decomposition of a matrix of 

903 random numbers drawn from a normal distribution. 

904 If the matrix has fewer rows than columns then the output will have orthogonal 

905 rows. Otherwise, the output will have orthogonal columns. 

906 

907 If the shape of the tensor to initialize is more than two-dimensional, 

908 a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])` 

909 is initialized, where `n` is the length of the shape vector. 

910 The matrix is subsequently reshaped to give a tensor of the desired shape. 

911 

912 Args: 

913 gain: multiplicative factor to apply to the orthogonal matrix 

914 seed: A Python integer. Used to create random seeds. See 

915 `tf.compat.v1.set_random_seed` for behavior. 

916 dtype: Default data type, used if no `dtype` argument is provided when 

917 calling the initializer. Only floating point types are supported. 

918 References: 

919 [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C) 

920 ([pdf](https://arxiv.org/pdf/1312.6120.pdf)) 

921 """ 

922 

923 @deprecated_args(None, 

924 "Call initializer instance with the dtype argument instead " 

925 "of passing it to the constructor", "dtype") 

926 def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32): 

927 self.gain = gain 

928 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 

929 self.seed = seed 

930 

931 def __call__(self, shape, dtype=None, partition_info=None): 

932 if dtype is None: 

933 dtype = self.dtype 

934 # Check the shape 

935 if len(shape) < 2: 

936 raise ValueError("The tensor to initialize, specified by argument `shape`" 

937 " must be at least two-dimensional. Received shape=" 

938 f"{shape}") 

939 # Flatten the input shape with the last dimension remaining 

940 # its original shape so it works for conv2d 

941 num_rows = 1 

942 for dim in shape[:-1]: 

943 num_rows *= dim 

944 num_rows = int(num_rows) 

945 num_cols = int(shape[-1]) 

946 if num_rows < num_cols: 

947 flat_shape = (num_cols, num_rows) 

948 else: 

949 flat_shape = (num_rows, num_cols) 

950 

951 # Generate a random matrix 

952 a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed) 

953 # Compute the qr factorization 

954 q, r = gen_linalg_ops.qr(a, full_matrices=False) 

955 # Make Q uniform 

956 d = array_ops.diag_part(r) 

957 q *= math_ops.sign(d) 

958 if num_rows < num_cols: 

959 q = array_ops.matrix_transpose(q) 

960 return self.gain * array_ops.reshape(q, shape) 

961 

962 def get_config(self): 

963 return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name} 

964 

965 

966# Note these haven't been ported to TF2.0. They are not currently visible and 

967# the tests are non trivial to port 

968class ConvolutionDeltaOrthogonal(Initializer): 

969 """Initializer that generates a delta orthogonal kernel for ConvNets. 

970 

971 The shape of the tensor must have length 3, 4 or 5. The number of input 

972 filters must not exceed the number of output filters. The center pixels of the 

973 tensor form an orthogonal matrix. Other pixels are set to be zero. See 

974 algorithm 2 in (Xiao et al., 2018). 

975 

976 

977 Args: 

978 gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. 

979 The 2-norm of an input is multiplied by a factor of `gain` after applying 

980 this convolution. 

981 seed: A Python integer. Used to create random seeds. See 

982 `tf.compat.v1.set_random_seed` for behavior. 

983 dtype: Default data type, used if no `dtype` argument is provided when 

984 calling the initializer. Only floating point types are supported. 

985 References: 

986 [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html) 

987 ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf)) 

988 """ 

989 

990 def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32): 

991 self.gain = gain 

992 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 

993 self.seed = seed 

994 

995 def __call__(self, shape, dtype=None, partition_info=None): 

996 if dtype is None: 

997 dtype = self.dtype 

998 # Check the shape 

999 if len(shape) < 3 or len(shape) > 5: 

1000 raise ValueError("The tensor to initialize, specified by argument `shape`" 

1001 " must be at least three-dimensional and at most " 

1002 f"five-dimensional. Received shape={shape}") 

1003 

1004 if shape[-2] > shape[-1]: 

1005 raise ValueError(f"In_filters, specified by shape[-2]={shape[-2]} cannot " 

1006 "be greater than out_filters, specified by " 

1007 f"shape[-1]={shape[-1]}.") 

1008 

1009 # Generate a random matrix 

1010 a = random_ops.random_normal([shape[-1], shape[-1]], 

1011 dtype=dtype, 

1012 seed=self.seed) 

1013 # Compute the qr factorization 

1014 q, r = gen_linalg_ops.qr(a, full_matrices=False) 

1015 # Make Q uniform 

1016 d = array_ops.diag_part(r) 

1017 q *= math_ops.sign(d) 

1018 q = q[:shape[-2], :] 

1019 q *= math_ops.cast(self.gain, dtype=dtype) 

1020 if len(shape) == 3: 

1021 weight = array_ops.scatter_nd([[(shape[0] - 1) // 2]], 

1022 array_ops.expand_dims(q, 0), shape) 

1023 elif len(shape) == 4: 

1024 weight = array_ops.scatter_nd([[(shape[0] - 1) // 2, 

1025 (shape[1] - 1) // 2]], 

1026 array_ops.expand_dims(q, 0), shape) 

1027 else: 

1028 weight = array_ops.scatter_nd([[(shape[0] - 1) // 2, (shape[1] - 1) // 2, 

1029 (shape[2] - 1) // 2]], 

1030 array_ops.expand_dims(q, 0), shape) 

1031 return weight 

1032 

1033 def get_config(self): 

1034 return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name} 

1035 

1036 

1037class ConvolutionOrthogonal(Initializer): 

1038 """Initializer that generates orthogonal kernel for ConvNets. 

1039 

1040 Base class used to construct 1D, 2D and 3D orthogonal kernels for convolution. 

1041 

1042 Args: 

1043 gain: multiplicative factor to apply to the orthogonal matrix. Default is 1. 

1044 The 2-norm of an input is multiplied by a factor of `gain` after applying 

1045 this convolution. 

1046 seed: A Python integer. Used to create random seeds. See 

1047 `tf.compat.v1.set_random_seed` for behavior. 

1048 dtype: Default data type, used if no `dtype` argument is provided when 

1049 calling the initializer. Only floating point types are supported. 

1050 References: 

1051 [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html) 

1052 ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf)) 

1053 """ 

1054 

1055 def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32): 

1056 self.gain = gain 

1057 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 

1058 self.seed = seed 

1059 

1060 def __call__(self, shape, dtype=None, partition_info=None): 

1061 raise NotImplementedError 

1062 

1063 def get_config(self): 

1064 return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name} 

1065 

1066 # Helper functions. 

1067 def _orthogonal_matrix(self, n): 

1068 """Construct an n x n orthogonal matrix. 

1069 

1070 Args: 

1071 n: Dimension. 

1072 

1073 Returns: 

1074 A n x n orthogonal matrix. 

1075 """ 

1076 a = random_ops.random_normal([n, n], dtype=self.dtype, seed=self.seed) 

1077 if self.seed: 

1078 self.seed += 1 

1079 q, r = gen_linalg_ops.qr(a) 

1080 d = array_ops.diag_part(r) 

1081 # make q uniform 

1082 q *= math_ops.sign(d) 

1083 return q 

1084 

1085 def _symmetric_projection(self, n): 

1086 """Compute a n x n symmetric projection matrix. 

1087 

1088 Args: 

1089 n: Dimension. 

1090 

1091 Returns: 

1092 A n x n symmetric projection matrix, i.e. a matrix P s.t. P=P*P, P=P^T. 

1093 """ 

1094 q = self._orthogonal_matrix(n) 

1095 # randomly zeroing out some columns 

1096 mask = math_ops.cast( 

1097 random_ops.random_normal([n], seed=self.seed) > 0, self.dtype) 

1098 if self.seed: 

1099 self.seed += 1 

1100 c = math_ops.multiply(q, mask) 

1101 return math_ops.matmul(c, array_ops.matrix_transpose(c)) 

1102 

1103 

1104class ConvolutionOrthogonal2D(ConvolutionOrthogonal): 

1105 """Initializer that generates a 2D orthogonal kernel for ConvNets. 

1106 

1107 The shape of the tensor must have length 4. The number of input 

1108 filters must not exceed the number of output filters. 

1109 The orthogonality(==isometry) is exact when the inputs are circular padded. 

1110 There are finite-width effects with non-circular padding (e.g. zero padding). 

1111 See algorithm 1 in (Xiao et al., 2018). 

1112 

1113 Args: 

1114 gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. 

1115 This has the effect of scaling the output 2-norm by a factor of `gain`. 

1116 seed: A Python integer. Used to create random seeds. See 

1117 `tf.compat.v1.set_random_seed` for behavior. 

1118 dtype: Default data type, used if no `dtype` argument is provided when 

1119 calling the initializer. Only floating point types are supported. 

1120 References: 

1121 [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html) 

1122 ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf)) 

1123 """ 

1124 

1125 def __call__(self, shape, dtype=None, partition_info=None): 

1126 if dtype is None: 

1127 dtype = self.dtype 

1128 if len(shape) != 4: 

1129 raise ValueError("The tensor to initialize, specified by argument `shape`" 

1130 f" must be four-dimensional. Received: {shape}") 

1131 

1132 if shape[-2] > shape[-1]: 

1133 raise ValueError(f"In_filters, specified by shape[-2]={shape[-2]} cannot " 

1134 "be greater than out_filters, specified by " 

1135 f"shape[-1]={shape[-1]}.") 

1136 

1137 if shape[0] != shape[1]: 

1138 raise ValueError(f"Kernel sizes, specified by shape[0]={shape[0]} and " 

1139 f"shape[1]={shape[1]} must be equal.") 

1140 

1141 kernel = self._orthogonal_kernel(shape[0], shape[2], shape[3]) 

1142 kernel *= math_ops.cast(self.gain, dtype=dtype) 

1143 return kernel 

1144 

1145 def _dict_to_tensor(self, x, k1, k2): 

1146 """Convert a dictionary to a tensor. 

1147 

1148 Args: 

1149 x: A k1 * k2 dictionary. 

1150 k1: First dimension of x. 

1151 k2: Second dimension of x. 

1152 

1153 Returns: 

1154 A k1 * k2 tensor. 

1155 """ 

1156 

1157 return array_ops_stack.stack([ 

1158 array_ops_stack.stack([x[i, j] for j in range(k2)]) for i in range(k1)]) 

1159 

1160 def _block_orth(self, p1, p2): 

1161 """Construct a 2 x 2 kernel. 

1162 

1163 Used to construct orthgonal kernel. 

1164 

1165 Args: 

1166 p1: A symmetric projection matrix. 

1167 p2: A symmetric projection matrix. 

1168 

1169 Returns: 

1170 A 2 x 2 kernel [[p1p2, p1(1-p2)], 

1171 [(1-p1)p2, (1-p1)(1-p2)]]. 

1172 Raises: 

1173 ValueError: If the dimensions of p1 and p2 are different. 

1174 """ 

1175 if p1.shape.as_list() != p2.shape.as_list(): 

1176 raise ValueError("The dimension of the matrices must be the same. " 

1177 f"Received p1.shape={p1.shape} and p2.shape={p2.shape}.") 

1178 n = p1.shape.as_list()[0] 

1179 kernel2x2 = {} 

1180 eye = linalg_ops_impl.eye(n, dtype=self.dtype) 

1181 kernel2x2[0, 0] = math_ops.matmul(p1, p2) 

1182 kernel2x2[0, 1] = math_ops.matmul(p1, (eye - p2)) 

1183 kernel2x2[1, 0] = math_ops.matmul((eye - p1), p2) 

1184 kernel2x2[1, 1] = math_ops.matmul((eye - p1), (eye - p2)) 

1185 

1186 return kernel2x2 

1187 

1188 def _matrix_conv(self, m1, m2): 

1189 """Matrix convolution. 

1190 

1191 Args: 

1192 m1: A k x k dictionary, each element is a n x n matrix. 

1193 m2: A l x l dictionary, each element is a n x n matrix. 

1194 

1195 Returns: 

1196 (k + l - 1) * (k + l - 1) dictionary each element is a n x n matrix. 

1197 Raises: 

1198 ValueError: if the entries of m1 and m2 are of different dimensions. 

1199 """ 

1200 

1201 n = (m1[0, 0]).shape.as_list()[0] 

1202 if n != (m2[0, 0]).shape.as_list()[0]: 

1203 raise ValueError("The entries in matrices m1 and m2 must have the same " 

1204 f"dimensions. Received m1[0, 0].shape={m1[0, 0].shape} " 

1205 f"and m2[0, 0].shape={m2[0, 0].shape}.") 

1206 k = int(np.sqrt(len(m1))) 

1207 l = int(np.sqrt(len(m2))) 

1208 result = {} 

1209 size = k + l - 1 

1210 # Compute matrix convolution between m1 and m2. 

1211 for i in range(size): 

1212 for j in range(size): 

1213 result[i, j] = array_ops.zeros([n, n], self.dtype) 

1214 for index1 in range(min(k, i + 1)): 

1215 for index2 in range(min(k, j + 1)): 

1216 if (i - index1) < l and (j - index2) < l: 

1217 result[i, j] += math_ops.matmul(m1[index1, index2], 

1218 m2[i - index1, j - index2]) 

1219 return result 

1220 

1221 def _orthogonal_kernel(self, ksize, cin, cout): 

1222 """Construct orthogonal kernel for convolution. 

1223 

1224 Args: 

1225 ksize: Kernel size. 

1226 cin: Number of input channels. 

1227 cout: Number of output channels. 

1228 

1229 Returns: 

1230 An [ksize, ksize, cin, cout] orthogonal kernel. 

1231 Raises: 

1232 ValueError: If cin > cout. 

1233 """ 

1234 if cin > cout: 

1235 raise ValueError(f"The number of input channels (cin={cin}) cannot exceed" 

1236 f" the number of output channels (cout={cout}).") 

1237 orth = self._orthogonal_matrix(cout)[0:cin, :] 

1238 if ksize == 1: 

1239 return array_ops.expand_dims(array_ops.expand_dims(orth, 0), 0) 

1240 

1241 p = self._block_orth( 

1242 self._symmetric_projection(cout), self._symmetric_projection(cout)) 

1243 for _ in range(ksize - 2): 

1244 temp = self._block_orth( 

1245 self._symmetric_projection(cout), self._symmetric_projection(cout)) 

1246 p = self._matrix_conv(p, temp) 

1247 for i in range(ksize): 

1248 for j in range(ksize): 

1249 p[i, j] = math_ops.matmul(orth, p[i, j]) 

1250 

1251 return self._dict_to_tensor(p, ksize, ksize) 

1252 

1253 

1254class ConvolutionOrthogonal1D(ConvolutionOrthogonal): 

1255 """Initializer that generates a 1D orthogonal kernel for ConvNets. 

1256 

1257 The shape of the tensor must have length 3. The number of input 

1258 filters must not exceed the number of output filters. 

1259 The orthogonality(==isometry) is exact when the inputs are circular padded. 

1260 There are finite-width effects with non-circular padding (e.g. zero padding). 

1261 See algorithm 1 in (Xiao et al., 2018). 

1262 

1263 Args: 

1264 gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. 

1265 The 2-norm of an input is multiplied by a factor of `gain` after applying 

1266 this convolution. 

1267 seed: A Python integer. Used to create random seeds. See 

1268 `tf.compat.v1.set_random_seed` for behavior. 

1269 dtype: Default data type, used if no `dtype` argument is provided when 

1270 calling the initializer. Only floating point types are supported. 

1271 References: 

1272 [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html) 

1273 ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf)) 

1274 """ 

1275 

1276 def __call__(self, shape, dtype=None, partition_info=None): 

1277 if dtype is None: 

1278 dtype = self.dtype 

1279 if len(shape) != 3: 

1280 raise ValueError("The tensor to initialize, specified by argument `shape`" 

1281 f" must be three-dimensional. Received shape={shape}") 

1282 

1283 if shape[-2] > shape[-1]: 

1284 raise ValueError(f"In_filters, specified by shape[-2]={shape[-2]} cannot " 

1285 "be greater than out_filters, specified by " 

1286 f"shape[-1]={shape[-1]}.") 

1287 

1288 kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1]) 

1289 kernel *= math_ops.cast(self.gain, dtype=dtype) 

1290 return kernel 

1291 

1292 def _dict_to_tensor(self, x, k): 

1293 """Convert a dictionary to a tensor. 

1294 

1295 Args: 

1296 x: A dictionary of length k. 

1297 k: Dimension of x. 

1298 

1299 Returns: 

1300 A tensor with the same dimension. 

1301 """ 

1302 

1303 return array_ops_stack.stack([x[i] for i in range(k)]) 

1304 

1305 def _block_orth(self, projection_matrix): 

1306 """Construct a kernel. 

1307 

1308 Used to construct orthgonal kernel. 

1309 

1310 Args: 

1311 projection_matrix: A symmetric projection matrix of size n x n. 

1312 

1313 Returns: 

1314 [projection_matrix, (1 - projection_matrix)]. 

1315 """ 

1316 n = projection_matrix.shape.as_list()[0] 

1317 kernel = {} 

1318 eye = linalg_ops_impl.eye(n, dtype=self.dtype) 

1319 kernel[0] = projection_matrix 

1320 kernel[1] = eye - projection_matrix 

1321 return kernel 

1322 

1323 def _matrix_conv(self, m1, m2): 

1324 """Matrix convolution. 

1325 

1326 Args: 

1327 m1: A dictionary of length k, each element is a n x n matrix. 

1328 m2: A dictionary of length l, each element is a n x n matrix. 

1329 

1330 Returns: 

1331 (k + l - 1) dictionary each element is a n x n matrix. 

1332 Raises: 

1333 ValueError: Ff the entries of m1 and m2 are of different dimensions. 

1334 """ 

1335 

1336 n = (m1[0]).shape.as_list()[0] 

1337 if n != (m2[0]).shape.as_list()[0]: 

1338 raise ValueError("The entries in matrices m1 and m2 must have the same " 

1339 f"dimensions. Received m1[0].shape={m1[0].shape} " 

1340 f"and m2[0].shape={m2[0].shape}.") 

1341 k = len(m1) 

1342 l = len(m2) 

1343 result = {} 

1344 size = k + l - 1 

1345 # Compute matrix convolution between m1 and m2. 

1346 for i in range(size): 

1347 result[i] = array_ops.zeros([n, n], self.dtype) 

1348 for index in range(min(k, i + 1)): 

1349 if (i - index) < l: 

1350 result[i] += math_ops.matmul(m1[index], m2[i - index]) 

1351 return result 

1352 

1353 def _orthogonal_kernel(self, ksize, cin, cout): 

1354 """Construct orthogonal kernel for convolution. 

1355 

1356 Args: 

1357 ksize: Kernel size. 

1358 cin: Number of input channels. 

1359 cout: Number of output channels. 

1360 

1361 Returns: 

1362 An [ksize, ksize, cin, cout] orthogonal kernel. 

1363 Raises: 

1364 ValueError: If cin > cout. 

1365 """ 

1366 if cin > cout: 

1367 raise ValueError(f"The number of input channels (cin={cin}) cannot exceed" 

1368 f" the number of output channels (cout={cout}).") 

1369 orth = self._orthogonal_matrix(cout)[0:cin, :] 

1370 if ksize == 1: 

1371 return array_ops.expand_dims(orth, 0) 

1372 

1373 p = self._block_orth(self._symmetric_projection(cout)) 

1374 for _ in range(ksize - 2): 

1375 temp = self._block_orth(self._symmetric_projection(cout)) 

1376 p = self._matrix_conv(p, temp) 

1377 for i in range(ksize): 

1378 p[i] = math_ops.matmul(orth, p[i]) 

1379 

1380 return self._dict_to_tensor(p, ksize) 

1381 

1382 

1383class ConvolutionOrthogonal3D(ConvolutionOrthogonal): 

1384 """Initializer that generates a 3D orthogonal kernel for ConvNets. 

1385 

1386 The shape of the tensor must have length 5. The number of input 

1387 filters must not exceed the number of output filters. 

1388 The orthogonality(==isometry) is exact when the inputs are circular padded. 

1389 There are finite-width effects with non-circular padding (e.g. zero padding). 

1390 See algorithm 1 (Xiao et al., 2018). 

1391 

1392 Args: 

1393 gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. 

1394 The 2-norm of an input is multiplied by a factor of `gain` after applying 

1395 this convolution. 

1396 seed: A Python integer. Used to create random seeds. See 

1397 `tf.compat.v1.set_random_seed` for behavior. 

1398 dtype: Default data type, used if no `dtype` argument is provided when 

1399 calling the initializer. Only floating point types are supported. 

1400 References: 

1401 [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html) 

1402 ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf)) 

1403 """ 

1404 

1405 def __call__(self, shape, dtype=None, partition_info=None): 

1406 if dtype is None: 

1407 dtype = self.dtype 

1408 if len(shape) != 5: 

1409 raise ValueError("The tensor to initialize, specified by argument `shape`" 

1410 f" must be five-dimensional. Received shape={shape}") 

1411 

1412 if shape[-2] > shape[-1]: 

1413 raise ValueError(f"In_filters, specified by shape[-2]={shape[-2]} cannot " 

1414 "be greater than out_filters, specified by " 

1415 f"shape[-1]={shape[-1]}.") 

1416 

1417 if shape[0] != shape[1] or shape[0] != shape[2]: 

1418 raise ValueError(f"Kernel sizes, specified by shape[0]={shape[0]}, " 

1419 f"shape[1]={shape[1]} and shape[2]={shape[2]} must be " 

1420 "equal.") 

1421 

1422 kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1]) 

1423 kernel *= math_ops.cast(self.gain, dtype=dtype) 

1424 return kernel 

1425 

1426 def _dict_to_tensor(self, x, k1, k2, k3): 

1427 """Convert a dictionary to a tensor. 

1428 

1429 Args: 

1430 x: A k1 * k2 dictionary. 

1431 k1: First dimension of x. 

1432 k2: Second dimension of x. 

1433 k3: Third dimension of x. 

1434 

1435 Returns: 

1436 A k1 * k2 * k3 tensor. 

1437 """ 

1438 

1439 return array_ops_stack.stack([array_ops_stack.stack( 

1440 [array_ops_stack.stack([x[i, j, k] for k in range(k3)]) 

1441 for j in range(k2)]) for i in range(k1)]) 

1442 

1443 def _block_orth(self, p1, p2, p3): 

1444 """Construct a 3 x 3 kernel. 

1445 

1446 Used to construct orthgonal kernel. 

1447 

1448 Args: 

1449 p1: A symmetric projection matrix. 

1450 p2: A symmetric projection matrix. 

1451 p3: A symmetric projection matrix. 

1452 

1453 Returns: 

1454 A 2 x 2 x 2 kernel. 

1455 Raises: 

1456 ValueError: If the dimensions of p1, p2 and p3 are different. 

1457 """ 

1458 p1_shape = p1.shape.as_list() 

1459 if p1_shape != p2.shape.as_list() or p1_shape != p3.shape.as_list(): 

1460 raise ValueError("The dimension of the matrices must be the same. " 

1461 f"Received p1.shape={p1.shape}, p2.shape={p2.shape} and" 

1462 f" p3.shape={p3.shape}.") 

1463 n = p1_shape[0] 

1464 eye = linalg_ops_impl.eye(n, dtype=self.dtype) 

1465 kernel2x2x2 = {} 

1466 

1467 def matmul(p1, p2, p3): 

1468 return math_ops.matmul(math_ops.matmul(p1, p2), p3) 

1469 

1470 def cast(i, p): 

1471 """Return p or (1-p).""" 

1472 return i * p + (1 - i) * (eye - p) 

1473 

1474 for i in [0, 1]: 

1475 for j in [0, 1]: 

1476 for k in [0, 1]: 

1477 kernel2x2x2[i, j, k] = matmul(cast(i, p1), cast(j, p2), cast(k, p3)) 

1478 return kernel2x2x2 

1479 

1480 def _matrix_conv(self, m1, m2): 

1481 """Matrix convolution. 

1482 

1483 Args: 

1484 m1: is a k x k x k dictionary, each element is a n x n matrix. 

1485 m2: is a l x l x l dictionary, each element is a n x n matrix. 

1486 

1487 Returns: 

1488 (k + l - 1) x (k + l - 1) x (k + l - 1) dictionary each 

1489 element is a n x n matrix. 

1490 Raises: 

1491 ValueError: if the entries of m1 and m2 are of different dimensions. 

1492 """ 

1493 

1494 n = (m1[0, 0, 0]).shape.as_list()[0] 

1495 if n != (m2[0, 0, 0]).shape.as_list()[0]: 

1496 raise ValueError("The entries in matrices m1 and m2 must have the same " 

1497 "dimensions. Received m1[0, 0, 0].shape=" 

1498 f"{m1[0, 0, 0].shape} and m2[0, 0, 0].shape=" 

1499 f"{m2[0, 0, 0].shape}.") 

1500 k = int(np.cbrt(len(m1))) 

1501 l = int(np.cbrt(len(m2))) 

1502 result = {} 

1503 size = k + l - 1 

1504 # Compute matrix convolution between m1 and m2. 

1505 for i in range(size): 

1506 for j in range(size): 

1507 for r in range(size): 

1508 result[i, j, r] = array_ops.zeros([n, n], self.dtype) 

1509 for index1 in range(min(k, i + 1)): 

1510 for index2 in range(min(k, j + 1)): 

1511 for index3 in range(min(k, r + 1)): 

1512 if (i - index1) < l and (j - index2) < l and (r - index3) < l: 

1513 result[i, j, r] += math_ops.matmul( 

1514 m1[index1, index2, index3], 

1515 m2[i - index1, j - index2, r - index3]) 

1516 return result 

1517 

1518 def _orthogonal_kernel(self, ksize, cin, cout): 

1519 """Construct orthogonal kernel for convolution. 

1520 

1521 Args: 

1522 ksize: Kernel size. 

1523 cin: Number of input channels. 

1524 cout: Number of output channels. 

1525 

1526 Returns: 

1527 An [ksize, ksize, ksize, cin, cout] orthogonal kernel. 

1528 Raises: 

1529 ValueError: If cin > cout. 

1530 """ 

1531 if cin > cout: 

1532 raise ValueError(f"The number of input channels (cin={cin}) cannot exceed" 

1533 f" the number of output channels (cout={cout}).") 

1534 orth = self._orthogonal_matrix(cout)[0:cin, :] 

1535 if ksize == 1: 

1536 return array_ops.expand_dims( 

1537 array_ops.expand_dims(array_ops.expand_dims(orth, 0), 0), 0) 

1538 

1539 p = self._block_orth( 

1540 self._symmetric_projection(cout), self._symmetric_projection(cout), 

1541 self._symmetric_projection(cout)) 

1542 for _ in range(ksize - 2): 

1543 temp = self._block_orth( 

1544 self._symmetric_projection(cout), self._symmetric_projection(cout), 

1545 self._symmetric_projection(cout)) 

1546 p = self._matrix_conv(p, temp) 

1547 for i in range(ksize): 

1548 for j in range(ksize): 

1549 for k in range(ksize): 

1550 p[i, j, k] = math_ops.matmul(orth, p[i, j, k]) 

1551 

1552 return self._dict_to_tensor(p, ksize, ksize, ksize) 

1553 

1554 

1555@tf_export(v1=["initializers.identity"]) 

1556@deprecation.deprecated_endpoints("initializers.identity") 

1557class Identity(Initializer): 

1558 """Initializer that generates the identity matrix. 

1559 

1560 Only use for 2D matrices. 

1561 

1562 Args: 

1563 gain: Multiplicative factor to apply to the identity matrix. 

1564 dtype: Default data type, used if no `dtype` argument is provided when 

1565 calling the initializer. Only floating point types are supported. 

1566 """ 

1567 

1568 @deprecated_args(None, 

1569 "Call initializer instance with the dtype argument instead " 

1570 "of passing it to the constructor", "dtype") 

1571 def __init__(self, gain=1.0, dtype=dtypes.float32): 

1572 self.gain = gain 

1573 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 

1574 

1575 def __call__(self, shape, dtype=None, partition_info=None): 

1576 full_shape = shape if partition_info is None else partition_info.full_shape 

1577 if len(full_shape) != 2: 

1578 raise ValueError("The tensor to initialize, specified by argument `shape`" 

1579 " must be at least two-dimensional. Received shape=" 

1580 f"{shape}") 

1581 if dtype is None: 

1582 dtype = self.dtype 

1583 if isinstance(full_shape, tensor_shape.TensorShape): 

1584 full_shape = full_shape.as_list() 

1585 initializer = linalg_ops_impl.eye(*full_shape, dtype=dtype) 

1586 if partition_info is not None: 

1587 initializer = array_ops.slice(initializer, partition_info.var_offset, 

1588 shape) 

1589 return self.gain * initializer 

1590 

1591 def get_config(self): 

1592 return {"gain": self.gain, "dtype": self.dtype.name} 

1593 

1594 

1595@tf_export(v1=["glorot_uniform_initializer", "initializers.glorot_uniform"]) 

1596@deprecation.deprecated_endpoints("glorot_uniform_initializer", 

1597 "initializers.glorot_uniform") 

1598class GlorotUniform(VarianceScaling): 

1599 """The Glorot uniform initializer, also called Xavier uniform initializer. 

1600 

1601 It draws samples from a uniform distribution within [-limit, limit] 

1602 where `limit` is `sqrt(6 / (fan_in + fan_out))` 

1603 where `fan_in` is the number of input units in the weight tensor 

1604 and `fan_out` is the number of output units in the weight tensor. 

1605 

1606 Args: 

1607 seed: A Python integer. Used to create random seeds. See 

1608 `tf.compat.v1.set_random_seed` for behavior. 

1609 dtype: Default data type, used if no `dtype` argument is provided when 

1610 calling the initializer. Only floating point types are supported. 

1611 References: 

1612 [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html) 

1613 ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)) 

1614 """ 

1615 

1616 @deprecated_args(None, 

1617 "Call initializer instance with the dtype argument instead " 

1618 "of passing it to the constructor", "dtype") 

1619 def __init__(self, seed=None, dtype=dtypes.float32): 

1620 super(GlorotUniform, self).__init__( 

1621 scale=1.0, mode="fan_avg", distribution="uniform", seed=seed) 

1622 

1623 def get_config(self): 

1624 return {"seed": self.seed, "dtype": self.dtype.name} 

1625 

1626 

1627@tf_export(v1=["glorot_normal_initializer", "initializers.glorot_normal"]) 

1628@deprecation.deprecated_endpoints("glorot_normal_initializer", 

1629 "initializers.glorot_normal") 

1630class GlorotNormal(VarianceScaling): 

1631 """The Glorot normal initializer, also called Xavier normal initializer. 

1632 

1633 It draws samples from a truncated normal distribution centered on 0 

1634 with standard deviation (after truncation) given by 

1635 `stddev = sqrt(2 / (fan_in + fan_out))` where `fan_in` is the number 

1636 of input units in the weight tensor and `fan_out` is the number of 

1637 output units in the weight tensor. 

1638 

1639 Args: 

1640 seed: A Python integer. Used to create random seeds. See 

1641 `tf.compat.v1.set_random_seed` for behavior. 

1642 dtype: Default data type, used if no `dtype` argument is provided when 

1643 calling the initializer. Only floating point types are supported. 

1644 References: 

1645 [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html) 

1646 ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)) 

1647 """ 

1648 

1649 @deprecated_args(None, 

1650 "Call initializer instance with the dtype argument instead " 

1651 "of passing it to the constructor", "dtype") 

1652 def __init__(self, seed=None, dtype=dtypes.float32): 

1653 super(GlorotNormal, self).__init__( 

1654 scale=1.0, mode="fan_avg", distribution="truncated_normal", seed=seed) 

1655 

1656 def get_config(self): 

1657 return {"seed": self.seed, "dtype": self.dtype.name} 

1658 

1659 

1660# Aliases. 

1661 

1662# pylint: disable=invalid-name 

1663zeros_initializer = Zeros 

1664ones_initializer = Ones 

1665constant_initializer = Constant 

1666random_uniform_initializer = RandomUniform 

1667random_normal_initializer = RandomNormal 

1668truncated_normal_initializer = TruncatedNormal 

1669uniform_unit_scaling_initializer = UniformUnitScaling 

1670variance_scaling_initializer = VarianceScaling 

1671glorot_uniform_initializer = GlorotUniform 

1672glorot_normal_initializer = GlorotNormal 

1673orthogonal_initializer = Orthogonal 

1674identity_initializer = Identity 

1675convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal 

1676convolutional_orthogonal_1d = ConvolutionOrthogonal1D 

1677convolutional_orthogonal_2d = ConvolutionOrthogonal2D 

1678convolutional_orthogonal_3d = ConvolutionOrthogonal3D 

1679# pylint: enable=invalid-name 

1680 

1681 

1682@tf_export(v1=["initializers.lecun_normal"]) 

1683def lecun_normal(seed=None): 

1684 """LeCun normal initializer. 

1685 

1686 It draws samples from a truncated normal distribution centered on 0 

1687 with standard deviation (after truncation) given by 

1688 `stddev = sqrt(1 / fan_in)` where `fan_in` is the number of 

1689 input units in the weight tensor. 

1690 

1691 Args: 

1692 seed: A Python integer. Used to seed the random generator. 

1693 

1694 Returns: 

1695 An initializer. 

1696 

1697 References: 

1698 - Self-Normalizing Neural Networks, 

1699 [Klambauer et al., 

1700 2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks) 

1701 # pylint: disable=line-too-long 

1702 ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)) 

1703 - Efficient Backprop, 

1704 [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf) 

1705 """ 

1706 return VarianceScaling( 

1707 scale=1., mode="fan_in", distribution="truncated_normal", seed=seed) 

1708 

1709 

1710@tf_export(v1=["initializers.lecun_uniform"]) 

1711def lecun_uniform(seed=None): 

1712 """LeCun uniform initializer. 

1713 

1714 It draws samples from a uniform distribution within [-limit, limit] 

1715 where `limit` is `sqrt(3 / fan_in)` 

1716 where `fan_in` is the number of input units in the weight tensor. 

1717 

1718 Args: 

1719 seed: A Python integer. Used to seed the random generator. 

1720 

1721 Returns: 

1722 An initializer. 

1723 

1724 References: 

1725 - Self-Normalizing Neural Networks, 

1726 [Klambauer et al., 

1727 2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks) 

1728 # pylint: disable=line-too-long 

1729 ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)) 

1730 - Efficient Backprop, 

1731 [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf) 

1732 """ 

1733 return VarianceScaling( 

1734 scale=1., mode="fan_in", distribution="uniform", seed=seed) 

1735 

1736 

1737@tf_export(v1=["initializers.he_normal"]) 

1738def he_normal(seed=None): 

1739 """He normal initializer. 

1740 

1741 It draws samples from a truncated normal distribution centered on 0 

1742 with standard deviation (after truncation) given by 

1743 `stddev = sqrt(2 / fan_in)` where `fan_in` is the number of 

1744 input units in the weight tensor. 

1745 

1746 Args: 

1747 seed: A Python integer. Used to seed the random generator. 

1748 

1749 Returns: 

1750 An initializer. 

1751 

1752 References: 

1753 [He et al., 2015] 

1754 (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html) 

1755 # pylint: disable=line-too-long 

1756 ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf)) 

1757 """ 

1758 return VarianceScaling( 

1759 scale=2., mode="fan_in", distribution="truncated_normal", seed=seed) 

1760 

1761 

1762@tf_export(v1=["initializers.he_uniform"]) 

1763def he_uniform(seed=None): 

1764 """He uniform variance scaling initializer. 

1765 

1766 It draws samples from a uniform distribution within [-limit, limit] 

1767 where `limit` is `sqrt(6 / fan_in)` 

1768 where `fan_in` is the number of input units in the weight tensor. 

1769 

1770 Args: 

1771 seed: A Python integer. Used to seed the random generator. 

1772 

1773 Returns: 

1774 An initializer. 

1775 

1776 References: 

1777 [He et al., 2015] 

1778 (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html) 

1779 # pylint: disable=line-too-long 

1780 ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf)) 

1781 """ 

1782 return VarianceScaling( 

1783 scale=2., mode="fan_in", distribution="uniform", seed=seed) 

1784 

1785 

1786# Utility functions. 

1787 

1788 

1789def _compute_fans(shape): 

1790 """Computes the number of input and output units for a weight shape. 

1791 

1792 Args: 

1793 shape: Integer shape tuple or TF tensor shape. 

1794 

1795 Returns: 

1796 A tuple of integer scalars (fan_in, fan_out). 

1797 """ 

1798 if len(shape) < 1: # Just to avoid errors for constants. 

1799 fan_in = fan_out = 1 

1800 elif len(shape) == 1: 

1801 fan_in = fan_out = shape[0] 

1802 elif len(shape) == 2: 

1803 fan_in = shape[0] 

1804 fan_out = shape[1] 

1805 else: 

1806 # Assuming convolution kernels (2D, 3D, or more). 

1807 # kernel shape: (..., input_depth, depth) 

1808 receptive_field_size = 1 

1809 for dim in shape[:-2]: 

1810 receptive_field_size *= dim 

1811 fan_in = shape[-2] * receptive_field_size 

1812 fan_out = shape[-1] * receptive_field_size 

1813 return int(fan_in), int(fan_out) 

1814 

1815 

1816def _assert_float_dtype(dtype): 

1817 """Validate and return floating point type based on `dtype`. 

1818 

1819 `dtype` must be a floating point type. 

1820 

1821 Args: 

1822 dtype: The data type to validate. 

1823 

1824 Returns: 

1825 Validated type. 

1826 

1827 Raises: 

1828 ValueError: if `dtype` is not a floating point type. 

1829 """ 

1830 if not dtype.is_floating: 

1831 raise ValueError("Argument `dtype` is expected to be floating point. " 

1832 f"Received: {dtype}.") 

1833 return dtype