Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/test_util.py: 19%

1564 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15 

16# pylint: disable=invalid-name 

17"""Test utils for tensorflow.""" 

18import collections 

19from collections import OrderedDict 

20import contextlib 

21import functools 

22import gc 

23import itertools 

24import math 

25import os 

26import random 

27import re 

28import tempfile 

29import threading 

30import time 

31import unittest 

32 

33from absl.testing import parameterized 

34import numpy as np 

35 

36from google.protobuf import descriptor_pool 

37from google.protobuf import text_format 

38 

39from tensorflow.core.config import flags 

40from tensorflow.core.framework import graph_pb2 

41from tensorflow.core.protobuf import rewriter_config_pb2 

42from tensorflow.python import pywrap_sanitizers 

43from tensorflow.python import tf2 

44from tensorflow.python.client import device_lib 

45from tensorflow.python.client import pywrap_tf_session 

46from tensorflow.python.client import session 

47from tensorflow.python.compat.compat import forward_compatibility_horizon 

48from tensorflow.python.eager import backprop 

49from tensorflow.python.eager import context 

50from tensorflow.python.eager import def_function 

51from tensorflow.python.framework import _test_metrics_util 

52from tensorflow.python.framework import config 

53from tensorflow.python.framework import device as pydev 

54from tensorflow.python.framework import dtypes 

55from tensorflow.python.framework import errors 

56from tensorflow.python.framework import errors_impl 

57from tensorflow.python.framework import gpu_util 

58from tensorflow.python.framework import importer 

59from tensorflow.python.framework import indexed_slices 

60from tensorflow.python.framework import ops 

61from tensorflow.python.framework import random_seed 

62from tensorflow.python.framework import sparse_tensor 

63from tensorflow.python.framework import tensor_shape 

64from tensorflow.python.framework import tensor_util 

65from tensorflow.python.framework import tfrt_utils 

66from tensorflow.python.framework import versions 

67from tensorflow.python.ops import array_ops 

68from tensorflow.python.ops import control_flow_util 

69from tensorflow.python.ops import control_flow_util_v2 

70from tensorflow.python.ops import gen_sync_ops 

71from tensorflow.python.ops import gradients_impl 

72from tensorflow.python.ops import math_ops 

73from tensorflow.python.ops import script_ops 

74from tensorflow.python.ops import summary_ops_v2 

75from tensorflow.python.ops import variables 

76from tensorflow.python.ops.ragged import ragged_ops # pylint: disable=unused-import 

77from tensorflow.python.ops.ragged import ragged_tensor 

78from tensorflow.python.ops.ragged import ragged_tensor_value 

79from tensorflow.python.platform import _pywrap_stacktrace_handler 

80from tensorflow.python.platform import googletest 

81from tensorflow.python.platform import tf_logging as logging 

82from tensorflow.python.training import server_lib 

83from tensorflow.python.util import _pywrap_util_port 

84from tensorflow.python.util import compat 

85from tensorflow.python.util import deprecation 

86from tensorflow.python.util import nest 

87from tensorflow.python.util import tf_decorator 

88from tensorflow.python.util import tf_inspect 

89from tensorflow.python.util import traceback_utils 

90from tensorflow.python.util.compat import collections_abc 

91from tensorflow.python.util.protobuf import compare 

92from tensorflow.python.util.tf_export import tf_export 

93 

94 

95# If the below import is made available through the BUILD rule, then this 

96# function is overridden and will instead return True and cause Tensorflow 

97# graphs to be compiled with XLA. 

def is_xla_enabled():
  """Returns False unless overridden by the BUILD-dependent import below."""
  return False

100 

101 

# If the BUILD rule links in `is_xla_test_true`, this import replaces the
# default `is_xla_enabled` above with a version that returns True, causing
# graphs to be compiled with XLA.
try:
  from tensorflow.python.framework.is_xla_test_true import is_xla_enabled  # pylint: disable=g-import-not-at-top, unused-import
except Exception:  # pylint: disable=broad-except
  # Module not linked in: keep the default (XLA disabled).
  pass

106 

107 

108# Uses the same mechanism as above to selectively enable/disable MLIR 

109# compilation. 

def is_mlir_bridge_enabled():
  """Returns None unless overridden by a BUILD-dependent import below."""
  return None

112 

113 

# Selectively enable/disable MLIR bridge compilation: whichever of the
# *_false / *_true modules the BUILD rule provides overrides
# `is_mlir_bridge_enabled` above; if neither is linked in, the default
# (returning None) is kept.
try:
  from tensorflow.python.framework.is_mlir_bridge_test_false import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
except ImportError:
  try:
    from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
  except ImportError:
    pass

121 

122 

def is_asan_enabled():
  """Check if ASAN (AddressSanitizer) is enabled; delegates to the native helper."""
  return pywrap_sanitizers.is_asan_enabled()

126 

127 

def is_msan_enabled():
  """Check if MSAN (MemorySanitizer) is enabled; delegates to the native helper."""
  return pywrap_sanitizers.is_msan_enabled()

131 

132 

def is_tsan_enabled():
  """Check if TSAN (ThreadSanitizer) is enabled; delegates to the native helper."""
  return pywrap_sanitizers.is_tsan_enabled()

136 

137 

def is_ubsan_enabled():
  """Check if UBSAN (UndefinedBehaviorSanitizer) is enabled; delegates to the native helper."""
  return pywrap_sanitizers.is_ubsan_enabled()

141 

142 

143def _get_object_count_by_type(exclude=()): 

144 return ( 

145 collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) - 

146 collections.Counter([type(obj).__name__ for obj in exclude])) 

147 

148 

@tf_export("test.gpu_device_name")
def gpu_device_name():
  """Returns the name of a GPU device if available or an empty string.

  This method should only be used in tests written with `tf.test.TestCase`.

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_gpu_support():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device(tf.test.gpu_device_name()):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  """
  # Return the first locally visible device whose type is "GPU".
  for device in device_lib.list_local_devices():
    if device.device_type == "GPU":
      return compat.as_str(device.name)
  return ""

169 

170 

def assert_ops_in_graph(expected_ops, graph):
  """Assert all expected operations are found.

  Args:
    expected_ops: `dict<string, string>` of op name to op type.
    graph: Graph to check.

  Returns:
    `dict<string, node>` of node name to node.

  Raises:
    ValueError: If the expected ops are not present in the graph.
  """
  graph_def = graph.as_graph_def()
  actual_ops = {}
  for node in graph_def.node:
    if node.name not in expected_ops:
      continue
    # A node with an expected name must also carry the expected op type.
    if expected_ops[node.name] != node.op:
      raise ValueError("Expected op for node %s is different. %s vs %s" %
                       (node.name, expected_ops[node.name], node.op))
    actual_ops[node.name] = node
  # Every expected op name must have been matched exactly once.
  if set(expected_ops.keys()) != set(actual_ops.keys()):
    raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                     (expected_ops.keys(), actual_ops.keys()))
  return actual_ops

196 

197 

@tf_export("test.assert_equal_graph_def", v1=[])
def assert_equal_graph_def_v2(expected, actual):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs. Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent. This function
  ignores randomized attribute values that may appear in V2 checkpoints.

  Args:
    expected: The `GraphDef` we expected.
    actual: The `GraphDef` we have.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  # V2 always strips both checkpoint-randomized attrs and randomized
  # HashTableV2 shared_names before comparison.
  assert_equal_graph_def(
      actual, expected, checkpoint_v2=True, hash_table_shared_name=True)

217 

218 

@tf_export(v1=["test.assert_equal_graph_def"])
def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False,
                              hash_table_shared_name=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs. Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
      values that appear in V2 checkpoints.
    hash_table_shared_name: boolean determining whether to ignore randomized
      shared_names that appear in HashTableV2 op defs.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  # Thin v1 wrapper: forward everything to the shared implementation.
  assert_equal_graph_def(
      actual, expected, checkpoint_v2=checkpoint_v2,
      hash_table_shared_name=hash_table_shared_name)

242 

243 

def assert_equal_graph_def(actual, expected, checkpoint_v2=False,
                           hash_table_shared_name=False):
  """Shared implementation behind the v1/v2 `assert_equal_graph_def` wrappers.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to strip randomized attribute
      values that appear in V2 checkpoints before comparing.
    hash_table_shared_name: boolean determining whether to strip randomized
      shared_names that appear in HashTableV2 op defs before comparing.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  for label, graph_def in (("actual", actual), ("expected", expected)):
    if not isinstance(graph_def, graph_pb2.GraphDef):
      raise TypeError("Expected tf.GraphDef for %s, got %s" %
                      (label, type(graph_def).__name__))

  if checkpoint_v2:
    _strip_checkpoint_v2_randomized(actual)
    _strip_checkpoint_v2_randomized(expected)

  if hash_table_shared_name:
    _strip_hash_table_shared_name(actual)
    _strip_hash_table_shared_name(expected)

  # The C++ comparator returns a human-readable diff string, empty on match.
  diff = pywrap_tf_session.EqualGraphDefWrapper(actual.SerializeToString(),
                                                expected.SerializeToString())
  if diff:
    raise AssertionError(compat.as_str(diff))

265 

266 

def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  Note: mutates `a` and `b` (clears `collection_def` and `graph_def` fields
  after comparing them piecewise).

  Args:
    tester: A test object (e.g. a `tf.test.TestCase` instance) providing
      `assertEqual` and `assertProtoEquals`.
    a: First MetaGraphDef to compare.
    b: Second MetaGraphDef to compare.
  """
  # Carefully check the collection_defs.
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  collection_keys = a.collection_def.keys()
  for k in collection_keys:
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same.
      tester.assertEqual(
          len(a_value.bytes_list.value), len(b_value.bytes_list.value))
      for (a_value_item, b_value_item) in zip(a_value.bytes_list.value,
                                              b_value.bytes_list.value):
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # Fix: `assertEquals` was a deprecated alias of `assertEqual` and is
      # removed in Python 3.12.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")

  # Check the graph_defs.
  assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
  # Check graph_def versions (ignored by assert_equal_graph_def).
  tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("graph_def")
  b.ClearField("graph_def")

  # Whatever remains in the protos must match exactly.
  tester.assertProtoEquals(a, b)

304 

305 

# Matches attributes named via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py
# (a "_temp_" prefix, 32 random lowercase hex/alphanumeric chars, "/part").
_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"

309 

310 

def _strip_checkpoint_v2_randomized(graph_def):
  """Deletes node attrs whose single string tensor value matches the
  randomized sharded-save pattern, so V2-checkpoint graphs compare equal."""
  sharded_pattern = compat.as_bytes(_SHARDED_SAVE_OP_PATTERN)
  for node in graph_def.node:
    doomed_keys = []
    for attr_key in node.attr:
      attr_tensor = node.attr[attr_key].tensor
      if not (attr_tensor and len(attr_tensor.string_val) == 1):
        continue
      value = attr_tensor.string_val[0]
      if value and re.match(sharded_pattern, value):
        doomed_keys.append(attr_key)
    # Delete after iteration so the attr map is not mutated mid-loop.
    for attr_key in doomed_keys:
      del node.attr[attr_key]

324 

325 

# Matches auto-generated HashTableV2 `shared_name` attr values of the form
# "hash_table_<random lowercase hex/dash suffix>"; used by
# _strip_hash_table_shared_name below.
_TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+"

327 

328 

def _strip_hash_table_shared_name(graph_def):
  """Clears randomized `shared_name` attrs from HashTableV2 nodes so graphs
  that differ only in those names compare equal."""
  shared_name_pattern = compat.as_bytes(_TABLE_SHARED_NAME_PATTERN)
  for node in graph_def.node:
    if node.op != "HashTableV2" or "shared_name" not in node.attr:
      continue
    if re.match(shared_name_pattern, node.attr["shared_name"].s):
      del node.attr["shared_name"]

338 

339 

def IsGoogleCudaEnabled():
  """Delegates to the native `_pywrap_util_port` Google-CUDA build check."""
  return _pywrap_util_port.IsGoogleCudaEnabled()

342 

343 

def IsBuiltWithROCm():
  """Delegates to the native `_pywrap_util_port` ROCm build check."""
  return _pywrap_util_port.IsBuiltWithROCm()

346 

347 

def IsBuiltWithXLA():
  """Delegates to the native `_pywrap_util_port` XLA build check."""
  return _pywrap_util_port.IsBuiltWithXLA()

350 

351 

def IsBuiltWithNvcc():
  """Delegates to the native `_pywrap_util_port` nvcc build check."""
  return _pywrap_util_port.IsBuiltWithNvcc()

354 

355 

def GpuSupportsHalfMatMulAndConv():
  """Delegates to the native `_pywrap_util_port` fp16 matmul/conv support check."""
  return _pywrap_util_port.GpuSupportsHalfMatMulAndConv()

358 

359 

def IsMklEnabled():
  """Delegates to the native `_pywrap_util_port` MKL build check."""
  return _pywrap_util_port.IsMklEnabled()

362 

363 

def InstallStackTraceHandler():
  """Installs the native stacktrace handler (side effect only; returns None)."""
  _pywrap_stacktrace_handler.InstallStacktraceHandler()

366 

367 

def NHWCToNCHW(input_tensor):
  """Converts the input from the NHWC format to NCHW.

  Args:
    input_tensor: a 3-, 4-, or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # Axis permutation keyed by rank: the channel axis (last in NHWC) moves to
  # position 1.
  permutations = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
  if isinstance(input_tensor, ops.Tensor):
    return array_ops.transpose(input_tensor,
                               permutations[input_tensor.shape.ndims])
  perm = permutations[len(input_tensor)]
  return [input_tensor[axis] for axis in perm]

385 

386 

def NHWCToNCHW_VECT_C(input_shape_or_tensor):
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
      divisible by 4.
  """
  permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  if is_tensor:
    temp_shape = input_shape_or_tensor.shape.as_list()
  else:
    temp_shape = input_shape_or_tensor
  if temp_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  # Split the channel dimension into (C/4, 4) before permuting.
  temp_shape[-1] //= 4
  temp_shape.append(4)
  permutation = permutations[len(temp_shape)]
  if is_tensor:
    reshaped = array_ops.reshape(input_shape_or_tensor, temp_shape)
    return array_ops.transpose(reshaped, permutation)
  return [temp_shape[axis] for axis in permutation]

420 

421 

def NCHW_VECT_CToNHWC(input_shape_or_tensor):
  """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  if is_tensor:
    input_shape = input_shape_or_tensor.shape.as_list()
  else:
    input_shape = input_shape_or_tensor
  if input_shape[-1] != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
  permutation = permutations[len(input_shape)]
  # Fold the trailing size-4 vector dimension back into the channel axis.
  nhwc_shape = [input_shape[axis] for axis in permutation[:-1]]
  nhwc_shape[-1] *= input_shape[-1]
  if not is_tensor:
    return nhwc_shape
  transposed = array_ops.transpose(input_shape_or_tensor, permutation)
  return array_ops.reshape(transposed, nhwc_shape)

452 

453 

def NCHWToNHWC(input_tensor):
  """Converts the input from the NCHW format to NHWC.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # Axis permutation keyed by rank: the channel axis (position 1 in NCHW)
  # moves to the end.
  permutations = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
  if isinstance(input_tensor, ops.Tensor):
    return array_ops.transpose(input_tensor,
                               permutations[input_tensor.shape.ndims])
  perm = permutations[len(input_tensor)]
  return [input_tensor[axis] for axis in perm]

471 

472 

def skip_if(condition):
  """Skips the decorated function if condition is or evaluates to True.

  Args:
    condition: Either an expression that can be used in "if not condition"
      statement, or a callable whose result should be a boolean.

  Returns:
    The wrapped function
  """

  def real_skip_if(fn):

    def wrapper(*args, **kwargs):
      # Callables are evaluated lazily at call time; plain values are used
      # as-is.
      should_skip = condition() if callable(condition) else condition
      if not should_skip:
        return fn(*args, **kwargs)
      # When skipped, the wrapper implicitly returns None.

    return wrapper

  return real_skip_if

497 

498 

@contextlib.contextmanager
def skip_if_error(test_obj, error_type, messages=None):
  """Context manager to skip cases not considered failures by the tests.

  Note that this does not work if used in setUpClass/tearDownClass.
  Usage in setUp/tearDown works fine just like regular test methods.

  Args:
    test_obj: A test object provided as `self` in the test methods; this object
      is usually an instance of `unittest.TestCase`'s subclass and should have
      `skipTest` method.
    error_type: The error type to skip. Note that if `messages` are given, both
      `error_type` and `messages` need to match for the test to be skipped.
    messages: Optional, a string or list of strings. If `None`, the test will be
      skipped if `error_type` matches what is raised; otherwise, the test is
      skipped if any of the `messages` is contained in the message of the error
      raised, and `error_type` matches the error raised.

  Yields:
    Nothing.
  """
  if messages:
    messages = nest.flatten(messages)
  try:
    yield
  except error_type as e:
    # If messages were given, all of them failing to match means this error
    # is a genuine failure and must propagate.
    if messages and not any(message in str(e) for message in messages):
      raise
    test_obj.skipTest("Skipping error: {}: {}".format(type(e), str(e)))

529 

530 

def enable_c_shapes(fn):
  """No-op. TODO(b/74620627): Remove this."""
  # Retained only so existing decorated call sites keep working; returns `fn`
  # unchanged.
  return fn

534 

535 

def with_c_shapes(cls):
  """No-op. TODO(b/74620627): Remove this."""
  # Retained only so existing decorated classes keep working; returns `cls`
  # unchanged.
  return cls

539 

540 

def enable_control_flow_v2(fn):
  """Decorator for enabling CondV2 and WhileV2 on a test.

  Note this enables using CondV2 and WhileV2 after running the test class's
  setup/teardown methods.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    # Save the flag, force-enable v2 control flow for the call, and restore
    # the original value even if the wrapped function raises.
    saved_flag = control_flow_util.ENABLE_CONTROL_FLOW_V2
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
    try:
      return fn(*args, **kwargs)
    finally:
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = saved_flag

  return wrapper

566 

567 

def with_control_flow_v2(cls):
  """Adds methods that call original methods with WhileV2 and CondV2 enabled.

  Note this enables CondV2 and WhileV2 in new methods after running the test
  class's setup method.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  If a test function has _disable_control_flow_v2 attr set to True (using the
  @disable_control_flow_v2 decorator), the v2 function is not generated for it.

  Example:

  @test_util.with_control_flow_v2
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    @test_util.disable_control_flow_v2("b/xyzabc")
    def testDisabledForV2(self):
      ...

  Generated class:
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    def testEnabledForV2WithControlFlowV2(self):
      // Enable V2 flags.
      testEnabledForV2(self)
      // Restore V2 flags.

    def testDisabledForV2(self):
      ...

  Args:
    cls: class to decorate

  Returns:
    cls with new test methods added
  """
  # If v2 control flow is already globally enabled, the extra variants would
  # be redundant.
  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    return cls

  test_prefix = unittest.TestLoader.testMethodPrefix
  # Snapshot the dict items before mutating the class via setattr.
  for name, value in list(cls.__dict__.items()):
    is_test_method = callable(value) and name.startswith(test_prefix)
    if is_test_method and not getattr(value, "_disable_control_flow_v2", False):
      setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value))
  return cls

621 

622 

def disable_control_flow_v2(unused_msg):
  """Decorator for a function in a with_control_flow_v2 enabled test class.

  Blocks the function from being run with v2 control flow ops.

  Args:
    unused_msg: Reason for disabling.

  Returns:
    The wrapped function with _disable_control_flow_v2 attr set to True.
  """

  def wrapper(func):
    # Tag the function in place; with_control_flow_v2 checks this attribute
    # and skips generating the "...WithControlFlowV2" variant.
    setattr(func, "_disable_control_flow_v2", True)
    return func

  return wrapper

640 

641 

def enable_output_all_intermediates(fn):
  """Force-enable outputing all intermediates from functional control flow ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    # Save the override, force it on for the call, and restore it even if the
    # wrapped function raises.
    saved_override = (
        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
    control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
    try:
      return fn(*args, **kwargs)
    finally:
      control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = (
          saved_override)

  return wrapper

663 

664 

def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
  """Decorator for asserting that no new Python objects persist after a test.

  Runs the test multiple times executing eagerly, first as a warmup and then to
  let objects accumulate. The warmup helps ignore caches which do not grow as
  the test is run repeatedly.

  Useful for checking that there are no missing Py_DECREFs in the C exercised by
  a bit of Python.

  Args:
    func: The function to test.
    warmup_iters: The number of warmup iterations, excluded from measuring.

  Returns:
    The wrapped function performing the test.
  """

  def wrap_f(f):
    def decorator(self, *args, **kwargs):
      """Warms up, gets object counts, runs the test, checks for new objects."""
      with context.eager_mode():
        gc.disable()
        # BUG FIX: gc.enable() previously ran only on the success path, so any
        # failing assertion (or exception from the wrapped test) left garbage
        # collection disabled for every subsequent test in the process.
        # Re-enable it unconditionally in the `finally` below.
        try:
          # Python 3.11 removed "errors" and "skipped" as members of
          # unittest.case._Outcome so get them from the test result object
          # instead.
          test_errors = None
          test_skipped = None
          if hasattr(self._outcome, "errors"):
            test_errors = self._outcome.errors
            test_skipped = self._outcome.skipped
          else:
            test_errors = self._outcome.result.errors
            test_skipped = self._outcome.result.skipped
          # Run the test 2 times as warmup, in an attempt to fill up caches,
          # which should not grow as the test is run repeatedly below.
          #
          # TODO(b/117156879): Running warmup twice is black magic; we have
          # seen tests that fail with 1 warmup run, and pass with 2, on
          # various versions of python2.7.x.
          for _ in range(warmup_iters):
            f(self, *args, **kwargs)
          # Since we aren't in the normal test lifecycle, we need to manually
          # run cleanups to clear out their object references.
          self.doCleanups()

          # Some objects are newly created by _get_object_count_by_type(). So
          # create and save as a dummy variable to include it as a baseline.
          obj_count_by_type = _get_object_count_by_type()
          gc.collect()

          # Make sure any registered functions are cleaned up in the C++
          # runtime.
          registered_function_names = context.context().list_function_names()

          # unittest.doCleanups adds to self._outcome with each unwound call.
          # These objects are retained across gc collections so we exclude
          # them from the object count calculation.
          obj_count_by_type = _get_object_count_by_type(
              exclude=gc.get_referents(test_errors, test_skipped))

          if ops.has_default_graph():
            collection_sizes_before = {
                collection: len(ops.get_collection(collection))
                for collection in ops.get_default_graph().collections
            }
          for _ in range(3):
            f(self, *args, **kwargs)
          # Since we aren't in the normal test lifecycle, we need to manually
          # run cleanups to clear out their object references.
          self.doCleanups()
          # Note that gc.get_objects misses anything that isn't subject to
          # garbage collection (C types). Collections are a common source of
          # leaks, so we test for collection sizes explicitly.
          if ops.has_default_graph():
            for collection_key in ops.get_default_graph().collections:
              collection = ops.get_collection(collection_key)
              size_before = collection_sizes_before.get(collection_key, 0)
              if len(collection) > size_before:
                raise AssertionError(
                    ("Collection %s increased in size from "
                     "%d to %d (current items %s).") %
                    (collection_key, size_before, len(collection), collection))
              # Make sure our collection checks don't show up as leaked memory
              # by removing references to temporary variables.
              del collection
              del collection_key
              del size_before
            del collection_sizes_before
          gc.collect()

          # There should be no new Python objects hanging around.
          obj_count_by_type = (
              _get_object_count_by_type(
                  exclude=gc.get_referents(test_errors, test_skipped)) -
              obj_count_by_type)

          # There should be no newly registered functions hanging around.
          leftover_functions = (
              context.context().list_function_names() -
              registered_function_names)
          assert not leftover_functions, (
              "The following functions were newly created: %s" %
              leftover_functions)

          # In some cases (specifically on MacOS), new_count is somehow
          # smaller than previous_count.
          # Using plain assert because not all classes using this decorator
          # have assertLessEqual
          assert not obj_count_by_type, (
              "The following objects were newly created: %s" %
              str(obj_count_by_type))
        finally:
          gc.enable()
    return decorator

  # Support both @decorator and @decorator(warmup_iters=...) usage.
  if func is None:
    return wrap_f
  else:
    return wrap_f(func)

782 

783 

def assert_no_new_tensors(f):
  """Decorator for asserting that no new Tensors persist after a test.

  Mainly useful for checking that code using the Python C API has correctly
  manipulated reference counts.

  Clears the caches that it knows about, runs the garbage collector, then checks
  that there are no Tensor or Tensor-like objects still around. This includes
  Tensors to which something still has a reference (e.g. from missing
  Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one
  of the objects has __del__ defined).

  Args:
    f: The test case to run.

  Returns:
    The decorated test case.
  """

  def decorator(self, **kwargs):
    """Finds existing Tensors, runs the test, checks for new Tensors."""

    def _is_tensorflow_object(obj):
      # Identifies the TF types we track for leaks; anything else is ignored.
      try:
        return isinstance(obj,
                          (ops.Tensor, variables.Variable,
                           tensor_shape.Dimension, tensor_shape.TensorShape))
      except (ReferenceError, AttributeError):
        # If the object no longer exists, we don't care about it.
        return False

    # Snapshot ids of pre-existing TF objects so only objects created during
    # the test are reported. NOTE(review): id() values can be reused once an
    # object is collected, which could in principle mask a leak — acceptable
    # for a test-only heuristic; confirm if tightening this check.
    tensors_before = set(
        id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
    outside_executed_eagerly = context.executing_eagerly()
    # Run the test in a new graph so that collections get cleared when it's
    # done, but inherit the graph key so optimizers behave.
    outside_graph_key = ops.get_default_graph()._graph_key
    with ops.Graph().as_default():
      ops.get_default_graph()._graph_key = outside_graph_key
      # Preserve the caller's execution mode (eager vs. graph) inside the
      # fresh graph.
      if outside_executed_eagerly:
        with context.eager_mode():
          result = f(self, **kwargs)
      else:
        result = f(self, **kwargs)
    # Make an effort to clear caches, which would otherwise look like leaked
    # Tensors.
    context.context()._clear_caches()  # pylint: disable=protected-access
    gc.collect()
    tensors_after = [
        obj for obj in gc.get_objects()
        if _is_tensorflow_object(obj) and id(obj) not in tensors_before
    ]
    if tensors_after:
      raise AssertionError(("%d Tensors not deallocated after test: %s" % (
          len(tensors_after),
          str(tensors_after),
      )))
    return result

  return decorator

844 

845 

def _find_reference_cycle(objects, idx):
  """Diagnoses a reference cycle reachable from `objects[idx]`.

  Builds a graph of referrers (via `gc.get_referrers`) starting at the given
  object, then searches that graph for a cycle and logs one sample cycle via
  `logging.error`.

  Args:
    objects: list of candidate objects (typically `gc.garbage`).
    idx: index into `objects` of the object to start from.

  Returns:
    True if a reference cycle was found and logged, False otherwise.
  """

  def get_ignore_reason(obj, denylist):
    """Tests whether an object should be omitted from the dependency graph."""
    # Cap recursion depth: the denylist grows by one entry per recursion level
    # in build_ref_graph, so its length doubles as a depth counter.
    if len(denylist) > 100:
      return "<depth limit>"
    if tf_inspect.isframe(obj):
      if "test_util.py" in tf_inspect.getframeinfo(obj)[0]:
        return "<test code>"
    # Ignore the bookkeeping objects of this very function (helpers, graphs,
    # referrer lists) so the traversal does not chase its own references.
    for b in denylist:
      if b is obj:
        return "<test code>"
    if obj is denylist:
      return "<test code>"
    return None

  # Note: this function is meant to help with diagnostics. Its output is purely
  # a human-readable representation, so you may freely modify it to suit your
  # needs.
  def describe(obj, denylist, leaves_only=False):
    """Returns a custom human-readable summary of obj.

    Args:
      obj: the value to describe.
      denylist: same as denylist in get_ignore_reason.
      leaves_only: boolean flag used when calling describe recursively. Useful
        for summarizing collections.
    """
    if get_ignore_reason(obj, denylist):
      return "{}{}".format(get_ignore_reason(obj, denylist), type(obj))
    if tf_inspect.isframe(obj):
      return "frame: {}".format(tf_inspect.getframeinfo(obj))
    elif tf_inspect.ismodule(obj):
      return "module: {}".format(obj.__name__)
    else:
      if leaves_only:
        return "{}, {}".format(type(obj), id(obj))
      elif isinstance(obj, list):
        return "list({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, tuple):
        return "tuple({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, dict):
        return "dict({}): {} keys".format(id(obj), len(obj.keys()))
      elif tf_inspect.isfunction(obj):
        return "function({}) {}; globals ID: {}".format(
            id(obj), obj.__name__, id(obj.__globals__))
      else:
        return "{}, {}".format(type(obj), id(obj))

  def build_ref_graph(obj, graph, reprs, denylist):
    """Builds a reference graph as <referrer> -> <list of referents>.

    Args:
      obj: The object to start from. The graph will be built by recursively
        adding its referrers.
      graph: Dict holding the graph to be built. To avoid creating extra
        references, the graph holds object IDs rather than actual objects.
      reprs: Auxiliary structure that maps object IDs to their human-readable
        description.
      denylist: List of objects to ignore.
    """
    referrers = gc.get_referrers(obj)
    # Each recursion level adds its own referrer list to the denylist so the
    # traversal never follows references created by this diagnostic itself.
    denylist = denylist + (referrers,)

    obj_id = id(obj)
    for r in referrers:
      if get_ignore_reason(r, denylist) is None:
        r_id = id(r)
        if r_id not in graph:
          graph[r_id] = []
        if obj_id not in graph[r_id]:
          # Only recurse on first discovery of this edge, so already-visited
          # referrers do not cause infinite recursion.
          graph[r_id].append(obj_id)
          build_ref_graph(r, graph, reprs, denylist)
          reprs[r_id] = describe(r, denylist)

  def find_cycle(el, graph, reprs, path):
    """Finds and prints a single cycle in the dependency graph."""
    if el not in graph:
      return
    for r in graph[el]:
      if r in path:
        # Found a back-edge: log the whole path plus the repeated node.
        logging.error("Reference cycle sample:")
        for p in path + (r,):
          logging.error(reprs.get(p, "unknown object " + str(p)))
        return True
      else:
        if find_cycle(r, graph, reprs, path + (r,)):
          return True
    return False

  obj = objects[idx]
  graph = {}  # referrer ID -> object ID
  reprs = {}  # object ID -> description
  # Seed the denylist with this function's own locals and helpers so the
  # reference search ignores the diagnostic machinery.
  build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason,
                                      describe, build_ref_graph, find_cycle))
  for k in graph:
    if find_cycle(k, graph, reprs, ()):
      return True
  return False

947 

948 

def assert_no_garbage_created(f):
  """Test method decorator to assert that no garbage has been created.

  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
  cannot be un-set (i.e. will disable garbage collection for any other unit
  tests in the same file/shard).

  Args:
    f: The function to decorate.

  Returns:
    The decorated function.
  """

  def decorator(self, **kwargs):
    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
    gc.disable()
    previous_debug_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    try:
      gc.collect()
      previous_garbage = len(gc.garbage)
      result = f(self, **kwargs)
      gc.collect()
      new_garbage = len(gc.garbage)
      if new_garbage > previous_garbage:
        for obj in gc.garbage[previous_garbage:]:
          # Known false positive for ast.fix_missing_locations.
          if getattr(obj, "__module__", "") == "ast":
            new_garbage -= 3

      if new_garbage > previous_garbage:
        logging.error(
            "The decorated test created work for Python's garbage collector, "
            "likely due to a reference cycle. New objects in cycle(s):")
        for i, obj in enumerate(gc.garbage[previous_garbage:]):
          try:
            logging.error("Object %d of %d", i,
                          len(gc.garbage) - previous_garbage)

            def _safe_object_str(obj):
              return "<%s %d>" % (obj.__class__.__name__, id(obj))

            logging.error("  Object type: %s", _safe_object_str(obj))
            logging.error(
                "  Referrer types: %s", ", ".join(
                    [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
            logging.error(
                "  Referent types: %s", ", ".join(
                    [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
            logging.error("  Object attribute names: %s", dir(obj))
            logging.error("  Object __str__:")
            logging.error(obj)
            logging.error("  Object __repr__:")
            logging.error(repr(obj))
          except Exception:  # pylint: disable=broad-except
            logging.error("(Exception while printing object)")

      # When garbage is created, this call can help identify reference cycles,
      # which are typically the cause of such garbage.
      if new_garbage > previous_garbage:
        for i in range(previous_garbage, new_garbage):
          if _find_reference_cycle(gc.garbage, i):
            break

      # This will fail if any garbage has been created, typically because of a
      # reference cycle.
      self.assertEqual(previous_garbage, new_garbage)
      return result
    finally:
      # Restore the GC configuration even when the test body (or the assertion
      # above) raises; otherwise a single failing test would leave garbage
      # collection disabled for every later test in the same process.
      # TODO(allenl): Figure out why this debug flag reset doesn't work. It
      # would be nice to be able to decorate arbitrary tests in a large test
      # suite and not hold on to every object in other tests.
      gc.set_debug(previous_debug_flags)
      gc.enable()

  return decorator

1025 

1026 

1027def _combine_named_parameters(**kwargs): 

1028 """Generate combinations based on its keyword arguments. 

1029 

1030 Two sets of returned combinations can be concatenated using +. Their product 

1031 can be computed using `times()`. 

1032 

1033 Args: 

1034 **kwargs: keyword arguments of form `option=[possibilities, ...]` or 

1035 `option=the_only_possibility`. 

1036 

1037 Returns: 

1038 a list of dictionaries for each combination. Keys in the dictionaries are 

1039 the keyword argument names. Each key has one value - one of the 

1040 corresponding keyword argument values. 

1041 """ 

1042 sort_by_key = lambda k: k[0] 

1043 combinations = [] 

1044 for key, values in sorted(kwargs.items(), key=sort_by_key): 

1045 if not isinstance(values, list): 

1046 values = [values] 

1047 combinations.append([(key, value) for value in values]) 

1048 

1049 return [OrderedDict(result) for result in itertools.product(*combinations)] 

1050 

1051 

def generate_combinations_with_testcase_name(**kwargs):
  """Generate combinations based on its keyword arguments using combine().

  This function calls combine() and appends a testcase name to the list of
  dictionaries returned. The 'testcase_name' key is a required for named
  parameterized tests.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`.

  Returns:
    a list of dictionaries for each combination. Keys in the dictionaries are
    the keyword argument names. Each key has one value - one of the
    corresponding keyword argument values.
  """
  named_combinations = []
  for combination in _combine_named_parameters(**kwargs):
    assert isinstance(combination, OrderedDict)
    # Build a suffix like "_key1_value1_key2_value2" using only alphanumeric
    # characters, so the generated name is a valid test identifier.
    suffix = "".join(
        "_{}_{}".format("".join(filter(str.isalnum, key)),
                        "".join(filter(str.isalnum, str(value))))
        for key, value in combination.items())
    entries = list(combination.items())
    entries.append(("testcase_name", "_test{}".format(suffix)))
    named_combinations.append(OrderedDict(entries))

  return named_combinations

1083 

1084 

def run_all_in_graph_and_eager_modes(cls):
  """Execute all test methods in the given class with and without eager."""
  base_decorator = run_in_graph_and_eager_modes
  skipped_prefixes = ("testSkipEager", "test_skip_eager")
  for name in dir(cls):
    # Only wrap real test methods; leave explicitly eager-skipping tests and
    # the test_session helper untouched.
    if not name.startswith(unittest.TestLoader.testMethodPrefix):
      continue
    if name.startswith(skipped_prefixes) or name == "test_session":
      continue
    attr = getattr(cls, name, None)
    if callable(attr):
      setattr(cls, name, base_decorator(attr))
  return cls

1098 

1099 

def enable_nested_function_shape_inference(fn):
  """Decorator for enabling nested_function_shape_inference on a test.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will set nested_function_shape_inference,
  reset the context, execute the test, then reset the context to the state
  it was in prior to this test.

  Example:

  class MyTest(test.TestCase):

    @enable_nested_function_shape_inference
    def testFoo(self):
      ...

  Args:
    fn: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  # functools.wraps preserves the test method's name and docstring (matching
  # the other decorators in this file, e.g. `set_xla_env_flag`), so test
  # discovery and reporting still show the original method name.
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # If `nested_function_shape_inference` is already enabled do nothing.
    if flags.config().enable_nested_function_shape_inference.value():
      return fn(*args, **kwargs)

    flags.config().enable_nested_function_shape_inference.reset(True)
    try:
      return fn(*args, **kwargs)
    finally:
      # Restore the default (False) even if the test raises.
      flags.config().enable_nested_function_shape_inference.reset(False)

  return wrapper

1135 

1136 

def enable_quantized_dtypes_training(fn):
  """Decorator for enabling quantized_dtypes_training on a test.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will set quantized_dtypes_training,
  reset the context, execute the test, then reset the context to the state
  it was in prior to this test.

  Example:

  class MyTest(test.TestCase):

    @enable_quantized_dtypes_training
    def testFoo(self):
      ...

  Args:
    fn: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  # functools.wraps preserves the test method's name and docstring (matching
  # the other decorators in this file, e.g. `set_xla_env_flag`), so test
  # discovery and reporting still show the original method name.
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # If `enable_quantized_dtypes_training` is already enabled do nothing.
    if flags.config().enable_quantized_dtypes_training.value():
      return fn(*args, **kwargs)

    flags.config().enable_quantized_dtypes_training.reset(True)
    try:
      return fn(*args, **kwargs)
    finally:
      # Restore the default (False) even if the test raises.
      flags.config().enable_quantized_dtypes_training.reset(False)

  return wrapper

1172 

1173 

def enable_eager_op_as_function(fn):
  """Returns the same fn. This will be removed once all usages are removed.

  Args:
    fn: the function to be wrapped.

  Returns:
    `fn`, unchanged.
  """
  # This decorator is a deprecated no-op. Returning `fn` directly (instead of
  # an anonymous pass-through wrapper) matches the documented contract above
  # and keeps the test method's name, docstring and signature intact.
  return fn

1188 

1189 

@tf_export("test.with_eager_op_as_function")
def with_eager_op_as_function(cls=None, only_as_function=False):  # pylint: disable=unused-argument
  """Returns the same class. This will be removed once all usages are removed.

  Args:
    cls: class to decorate.
    only_as_function: unused argument.

  Returns:
    cls
  """

  def decorator(target):
    # Deprecated no-op: the class is returned unmodified.
    return target

  # Support both bare use (`@with_eager_op_as_function`) and parameterized use
  # (`@with_eager_op_as_function(only_as_function=...)`).
  return decorator if cls is None else decorator(cls)

1209 

1210 

def enable_graph_building_optimization(fn):
  """Decorator for enabling graph_building_optimization on a test.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will enable graph_building_optimization,
  execute the test, then reset the feature flag to its default value.

  Example:

  class MyTest(test.TestCase):

    @enable_graph_building_optimization
    def testFoo(self):
      ...

  Args:
    fn: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  # functools.wraps preserves the test method's name and docstring (matching
  # the other decorators in this file, e.g. `set_xla_env_flag`), so test
  # discovery and reporting still show the original method name.
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # If `graph_building_optimization` is already enabled do nothing.
    if flags.config().graph_building_optimization.value():
      return fn(*args, **kwargs)

    flags.config().graph_building_optimization.reset(True)
    try:
      return fn(*args, **kwargs)
    finally:
      # Restore the default (False) even if the test raises.
      flags.config().graph_building_optimization.reset(False)

  return wrapper

1245 

1246 

def add_graph_building_optimization_tests(cls=None):
  """Adds methods with graph_building_optimization enabled to the test suite.

  Example:

  @test_util.add_graph_building_optimization_tests
  class FooTest(test.TestCase):

    def testBar(self):
      ...

  Generated class:
  class FooTest(test.TestCase):

    def testBar(self):
      ...

    def testBarWithGraphBuildingOptimization(self):
      // Enable graph_building_optimization
      testBar(self)
      // Disable graph_building_optimization

  Args:
    cls: class to decorate.

  Returns:
    cls with new test methods added.
  """

  def decorator(target_cls):
    # If the flag is already globally on, the duplicated methods would be
    # redundant; return the class untouched.
    if flags.config().graph_building_optimization.value():
      return target_cls

    test_prefix = unittest.TestLoader.testMethodPrefix
    # Iterate over a copy since we add attributes while scanning.
    for name, attr in target_cls.__dict__.copy().items():
      if not callable(attr):
        continue
      if name.startswith(test_prefix) or name.startswith("benchmark"):
        setattr(target_cls, name + "WithGraphBuildingOptimization",
                enable_graph_building_optimization(attr))
    return target_cls

  return decorator if cls is None else decorator(cls)

1292 

1293 

def disable_eager_op_as_function(unused_msg):
  """Decorator for a function in a with_eager_op_as_function enabled test class.

  Blocks the function from being run with eager_op_as_function enabled.

  Args:
    unused_msg: Reason for disabling. Documentation-only; the value is not
      inspected at runtime.

  Returns:
    The wrapped function with _disable_eager_op_as_function attr set to True.
  """
  # Delegates to the module-level `_disable_test` helper (defined elsewhere in
  # this file); `execute_func=False` means the wrapped test body is skipped.
  return _disable_test(execute_func=False)

1306 

1307 

def set_xla_env_flag(func=None, flag=""):
  """Decorator for setting XLA_FLAGS prior to running a test.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will allow users to set any xla flags
  exposed via the XLA_FLAGS environment variable, execute the test, then reset
  the XLA_FLAGS to the state it was in prior to this test.

  Example:

  class MyTest(test.TestCase):

    @set_xla_env_flag(flag='--xla_gpu_enable_fast_min_max=false')
    def testFoo(self):
      ...

  Args:
    func: The function to be wrapped.
    flag: The xla flag to be set in the XLA_FLAGS env variable.

  Returns:
    The wrapped function.
  """

  def decorator(f):

    @functools.wraps(f)
    def decorated(*args, **kwargs):
      saved_flags = os.environ.get("XLA_FLAGS")
      # Prepend the requested flag to any pre-existing XLA_FLAGS content.
      combined = flag + " " + saved_flags if saved_flags else flag
      os.environ["XLA_FLAGS"] = combined
      try:
        return f(*args, **kwargs)
      finally:
        # Restore the exact prior state: remove the variable if it was unset,
        # otherwise put back the original value.
        if saved_flags is None:
          del os.environ["XLA_FLAGS"]
        else:
          os.environ["XLA_FLAGS"] = saved_flags

    return decorated

  return decorator if func is None else decorator(func)

1355 

1356 

def build_as_function_and_v1_graph(func=None):
  """Run a test case in v1 graph mode and inside tf.function in eager mode.

  WARNING: This decorator can only be used in test cases that statically checks
  generated graph. Attempting to evaluate graph or function results via.
  session.run() or self.evaluate() will fail.

  WARNING: This decorator can only be used for test cases that inherit from
  absl.testing.parameterized.TestCase.

  Args:
    func: Test case function to be decorated.

  Returns:
    Decorated test case function.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError(
          "`run_in_graph_mode_and_function` only supports test methods.")

    @parameterized.named_parameters(("_v1_graph", "v1_graph"),
                                    ("_function", "function"))
    @functools.wraps(f)
    def decorated(self, run_mode, *args, **kwargs):
      if run_mode == "v1_graph":
        with ops.Graph().as_default():
          f(self, *args, **kwargs)
        return
      if run_mode != "function":
        raise ValueError("Unknown run mode %s" % run_mode)

      @def_function.function
      def function_in_eager():
        f(self, *args, **kwargs)

      # Create a new graph for the eagerly executed version of this test for
      # better isolation.
      eager_graph = ops.Graph()
      with eager_graph.as_default(), context.eager_mode():
        function_in_eager()
      ops.dismantle_graph(eager_graph)

    return decorated

  return decorator if func is None else decorator(func)

1407 

1408 

def run_in_async_and_sync_mode(f):
  """Execute the test in async mode and sync mode."""

  @parameterized.named_parameters([("Async", True), ("", False)])
  @functools.wraps(f)
  def decorator(self, async_mode, *args, **kwargs):
    # Select the execution mode once, then run the test under it.
    mode = context.ASYNC if async_mode else context.SYNC
    with context.execution_mode(mode):
      f(self, *args, **kwargs)

  return decorator

1422 

1423 

def run_in_graph_and_eager_modes(func=None,
                                 config=None,
                                 use_gpu=True,
                                 assert_no_eager_garbage=False):
  """Execute the decorated test with and without enabling eager execution.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will cause the contents of the test
  method to be executed twice - once normally, and once with eager execution
  enabled. This allows unittests to confirm the equivalence between eager
  and graph execution (see `tf.compat.v1.enable_eager_execution`).

  For example, consider the following unittest:

  ```python
  class MyTests(tf.test.TestCase):

    @run_in_graph_and_eager_modes
    def test_foo(self):
      x = tf.constant([1, 2])
      y = tf.constant([3, 4])
      z = tf.add(x, y)
      self.assertAllEqual([4, 6], self.evaluate(z))

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test validates that `tf.add()` has the same behavior when computed with
  eager execution enabled as it does when constructing a TensorFlow graph and
  executing the `z` tensor in a session.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.


  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator the can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.
    config: An optional config_pb2.ConfigProto to use to configure the session
      when executing graphs.
    use_gpu: If True, attempt to run as many operations as possible on GPU.
    assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
      collector and asserts that no extra garbage has been created when running
      the test with eager execution enabled. This will fail if there are
      reference cycles (e.g. a = []; a.append(a)). Off by default because some
      tests may create garbage for legitimate reasons (e.g. they define a class
      which inherits from `object`), and because DEBUG_SAVEALL is sticky in some
      Python interpreters (meaning that tests which rely on objects being
      collected elsewhere in the unit test file will not work). Additionally,
      checks that nothing still has a reference to Tensors that the test
      allocated.

  Returns:
    Returns a decorator that will run the decorated test method twice:
    once by constructing and executing a graph in a session and once with
    eager execution enabled.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError(
          "`run_in_graph_and_eager_modes` only supports test methods. "
          "Did you mean to use `run_all_in_graph_and_eager_modes`?")

    def decorated(self, *args, **kwargs):
      # Graph-mode pass first. A SkipTest raised here is deliberately
      # swallowed so that the eager pass below still runs.
      logging.info("Running %s in GRAPH mode.", f.__name__)
      try:
        with context.graph_mode():
          with self.test_session(use_gpu=use_gpu, config=config):
            f(self, *args, **kwargs)
      except unittest.case.SkipTest:
        pass

      # Note: closes over `args` from the enclosing call; only `kwargs` is
      # re-passed explicitly (required by the garbage-checking decorators
      # applied below, which forward keyword arguments only).
      def run_eagerly(self, **kwargs):
        logging.info("Running %s in EAGER mode.", f.__name__)
        if not use_gpu:
          with ops.device("/device:CPU:0"):
            f(self, *args, **kwargs)
        else:
          f(self, *args, **kwargs)

      if assert_no_eager_garbage:
        ops.reset_default_graph()
        run_eagerly = assert_no_new_tensors(
            assert_no_garbage_created(run_eagerly))

      # This decorator runs the wrapped test twice.
      # Reset the test environment between runs.
      self.tearDown()
      self._tempdir = None
      # Create a new graph for the eagerly executed version of this test for
      # better isolation.
      graph_for_eager_test = ops.Graph()
      with graph_for_eager_test.as_default(), context.eager_mode():
        self.setUp()
        run_eagerly(self, **kwargs)
      ops.dismantle_graph(graph_for_eager_test)

    return tf_decorator.make_decorator(f, decorated)

  if func is not None:
    return decorator(func)

  return decorator

1531 

1532 

def py_func_if_in_function(f):
  """Wraps `f` in a py_func when invoked while building a function graph."""

  def decorated(*args, **kwds):
    # Outside a function-building graph, just call through.
    if not ops.inside_function():
      return f(*args, **kwds)

    # Collect the positions of tensor-like positional arguments; only these
    # are routed through the py_func boundary.
    tensor_positions = [
        i for i, arg in enumerate(args)
        if isinstance(arg, (ops.Tensor, variables.Variable))
    ]
    tensor_args = [args[i] for i in tensor_positions]

    def inner_f(*inner_tensor_args):
      # Rebuild the original argument list, substituting the py_func-provided
      # tensors back into their original positions.
      rebuilt = list(args)
      for position, tensor in zip(tensor_positions, inner_tensor_args):
        rebuilt[position] = tensor
      return f(*rebuilt, **kwds)

    return script_ops.py_func(inner_f, tensor_args, [])

  return tf_decorator.make_decorator(f, decorated)

1555 

1556 

def also_run_as_tf_function(f):
  """Runs the decorated test twice--once as is, once inside a tf.function.

  This allows you to run a test both in eager execution and inside a
  tf.function, exercising the two execution modes supported in tf 2.0. The test
  assertions are automatically done inside tf.py_funcs, and tf.function ensures
  that they run in the proper order and with the proper side effects.

  Currently variable creation is not supported in tests annotated with this
  decorator since it's tricky to ensure the variable doesn't get repeatedly
  created when retracing the tf.function.

  Args:
    f: the test method to be decorated

  Returns:
    The decorated test method, which will run both in eager and inside a
    tf.function.
  """

  def decorated(*args, **kwds):

    def bound_f():
      f(*args, **kwds)

    with context.eager_mode():
      # First pass: plain eager execution.
      bound_f()
      # Second pass: the same body traced and run as a TF function.
      # TODO(b/121143941): Remove the autograph override.
      def_function.function(bound_f, autograph=False)()

  return decorated

1590 

1591 

def deprecated_graph_mode_only(func=None):
  """Execute the decorated test in graph mode.

  This function returns a decorator intended to be applied to tests that are not
  compatible with eager mode. When this decorator is applied, the test body will
  be run in an environment where API calls construct graphs instead of executing
  eagerly.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator the can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will run the decorated test method in graph mode.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      # Applied to a class: wrap setUp (if defined on the class itself) and
      # every test method, then return the class.
      setup = f.__dict__.get("setUp")
      if setup is not None:
        setattr(f, "setUp", decorator(setup))

      for name, value in f.__dict__.copy().items():
        if (callable(value) and
            name.startswith(unittest.TestLoader.testMethodPrefix)):
          setattr(f, name, decorator(value))

      return f

    def decorated(self, *args, **kwargs):
      # Already in graph mode: call straight through.
      if not context.executing_eagerly():
        return f(self, *args, **kwargs)
      with context.graph_mode():
        return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)

1639 

1640 

# Backwards-compatible alias: `run_deprecated_v1` is the same decorator as
# `deprecated_graph_mode_only`.
run_deprecated_v1 = deprecated_graph_mode_only

1642 

1643 

def run_all_in_deprecated_graph_mode_only(cls):
  """Execute all tests in a class in graph mode."""
  for name in dir(cls):
    # Skip non-test attributes and the test_session helper.
    if name == "test_session":
      continue
    if not name.startswith(unittest.TestLoader.testMethodPrefix):
      continue
    method = getattr(cls, name, None)
    if callable(method):
      setattr(cls, name, deprecated_graph_mode_only(method))
  return cls

1655 

1656 

def run_v1_only(reason, func=None):
  """Execute the decorated test only if running in v1 mode.

  This function is intended to be applied to tests that exercise v1 only
  functionality. If the test is run in v2 mode it will simply be skipped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    reason: string giving a reason for limiting the test to v1 only.
    func: function to be annotated. If `func` is None, this method returns a
      decorator the can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """
  if not isinstance(reason, str):
    raise ValueError("'reason' should be string, got {}".format(type(reason)))

  def decorator(f):
    if tf_inspect.isclass(f):
      # To skip an entire test suite class, we only decorate the setUp method
      # to skip all tests. There are cases when setUp is not defined (not
      # overridden in subclasses of TestCase, so not available in f.__dict__
      # below). For those cases, we walk the method resolution order list and
      # pick the first setUp method we find (usually this should be the one in
      # the parent class since that's the TestCase class).
      for klass in type.mro(f):
        setup = klass.__dict__.get("setUp")
        if setup is not None:
          setattr(f, "setUp", decorator(setup))
          break

      return f

    # Plain function: wrap it so it skips under v2.
    def decorated(self, *args, **kwargs):
      if tf2.enabled():
        self.skipTest(reason)

      return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)

1708 

1709 

def run_v2_only(func=None):
  """Execute the decorated test only if running in v2 mode.

  This function is intended to be applied to tests that exercise v2 only
  functionality. If the test is run in v1 mode it will simply be skipped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator the can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_v2_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      # Skip unless TF2 behavior is active.
      if not tf2.enabled():
        self.skipTest("Test is only compatible with v2")

      return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)

1745 

1746 

def run_gpu_only(func=None):
  """Execute the decorated test only if a GPU is available.

  This function is intended to be applied to tests that require the presence
  of a GPU. If a GPU is absent, it will simply be skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator the can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_gpu_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      # Skip when no GPU of any kind is visible to TensorFlow.
      if not is_gpu_available():
        self.skipTest("Test requires GPU")

      return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)

1778 

1779 

def run_cuda_only(func=None):
  """Execute the decorated test only if a GPU is available.

  This function is intended to be applied to tests that require the presence
  of a CUDA GPU. If a CUDA GPU is absent, it will simply be skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator the can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_cuda_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      # Skip unless specifically a CUDA-capable GPU is available.
      if not is_gpu_available(cuda_only=True):
        self.skipTest("Test requires CUDA GPU")

      return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)

1811 

1812 

def run_gpu_or_tpu(func=None):
  """Execute the decorated test only if a physical GPU or TPU is available.

  This function is intended to be applied to tests that require the presence
  of a physical GPU or TPU. It complies with the following rules:
  - If a GPU is available, the test will run on the GPU.
  - If a GPU is absent and a TPU is available, the test will run on the TPU.
  - If both GPU and TPU are absent, the test will be skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_gpu_or_tpu` only supports test methods.")

    def decorated(self, *args, **kwargs):
      # Prefer GPU over TPU; the chosen device type is passed as the first
      # positional argument of the test method.
      for device_type in ("GPU", "TPU"):
        if config.list_physical_devices(device_type):
          return f(self, device_type, *args, **kwargs)
      self.skipTest("Test requires GPU or TPU")

    return decorated

  if func is not None:
    return decorator(func)

  return decorator

1847 

1848 

def with_forward_compatibility_horizons(*horizons):
  """Executes the decorated test with the specified forward-compat horizons.

  Args:
    *horizons: A list of (year, month, day) tuples. If the list includes
      `None`, then the test will also be run with no forward-compatibility
      horizon set.

  Returns:
    A decorator that will execute the test with the specified horizons.

  Raises:
    ValueError: If no horizons are given, or if a horizon is neither `None`
      nor a length-3 sequence of ints.
  """
  if not horizons:
    raise ValueError("Expected at least one horizon.")
  for horizon in horizons:
    if not ((horizon is None) or
            (len(horizon) == 3 and all(isinstance(x, int) for x in horizon))):
      # Wrap `horizon` in a 1-tuple: it is usually itself a tuple, and
      # `"%r" % horizon` would consume its elements as separate format
      # arguments, raising TypeError instead of the intended ValueError.
      raise ValueError("Bad horizon value: %r" % (horizon,))

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`with_forward_compatibility_horizons` only "
                       "supports test methods.")
    def decorated(self, *args, **kwargs):
      # Run the test body once per horizon, inside the corresponding
      # forward-compatibility context (or no context for `None`).
      for horizon in horizons:
        if horizon is None:
          f(self, *args, **kwargs)
        else:
          (year, month, day) = horizon
          with forward_compatibility_horizon(year, month, day):
            f(self, *args, **kwargs)
    return decorated

  return decorator

1882 

1883 

@deprecation.deprecated(None,
                        "Use `tf.config.list_physical_devices('GPU')` instead.")
@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
  """Returns whether TensorFlow can access a GPU.

  Warning: if a non-GPU version of the package is installed, the function would
  also return False. Use `tf.test.is_built_with_cuda` to validate if TensorFlow
  was built with CUDA support.

  For example,
  >>> gpu_available = tf.test.is_gpu_available()
  >>> is_cuda_gpu_available = tf.test.is_gpu_available(cuda_only=True)
  >>> is_cuda_gpu_min_3 = tf.test.is_gpu_available(True, (3,0))

  Args:
    cuda_only: limit the search to CUDA GPUs.
    min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
      CUDA compute capability required, or None if no requirement.

  Note that the keyword arg name "cuda_only" is misleading (since routine will
  return true when a GPU device is available irrespective of whether TF was
  built with CUDA support or ROCm support. However no changes here because

  ++ Changing the name "cuda_only" to something more generic would break
  backward compatibility

  ++ Adding an equivalent "rocm_only" would require the implementation check
  the build type. This in turn would require doing the same for CUDA and thus
  potentially break backward compatibility

  ++ Adding a new "cuda_or_rocm_only" would not break backward compatibility,
  but would require most (if not all) callers to update the call to use
  "cuda_or_rocm_only" instead of "cuda_only"

  Returns:
    True if a GPU device of the requested kind is available.
  """

  # This was needed earlier when we had support for SYCL in TensorFlow.
  del cuda_only

  try:
    for local_device in device_lib.list_local_devices():
      if local_device.device_type == "GPU":
        gpu_info = gpu_util.compute_capability_from_device_desc(local_device)
        # Treat an unknown compute capability as (0, 0) so the comparison
        # below is always well-defined.
        cc = gpu_info.compute_capability or (0, 0)
        if not min_cuda_compute_capability or cc >= min_cuda_compute_capability:
          return True
    return False
  except errors_impl.NotFoundError as e:
    if not all(x in str(e) for x in ["CUDA", "not find"]):
      # Unrelated failure: re-raise with the original traceback (bare `raise`
      # avoids adding an extra frame, unlike `raise e`).
      raise
    else:
      # A "could not find CUDA" NotFoundError just means no GPU is present.
      logging.error(str(e))
      return False

1940 

1941 

@contextlib.contextmanager
def device(use_gpu):
  """Uses gpu when requested and available."""
  # Fall back to CPU:0 unless the caller asked for a GPU and one exists.
  dev = "/device:GPU:0" if use_gpu and is_gpu_available() else "/device:CPU:0"
  with ops.device(dev):
    yield

1951 

1952 

@contextlib.contextmanager
def use_gpu():
  """Uses gpu when requested and available."""
  # Delegates to `device(use_gpu=True)`: pins ops to GPU:0 only when a GPU
  # is actually available, otherwise falls back to CPU:0.
  with device(use_gpu=True):
    yield

1958 

1959 

@contextlib.contextmanager
def force_gpu():
  """Force the gpu to be used.

  Unlike `use_gpu`, this pins ops to GPU:0 unconditionally; if no GPU exists
  the ops will fail to place rather than fall back to CPU.
  """
  with ops.device("/device:GPU:0"):
    yield

1965 

1966 

@contextlib.contextmanager
def force_cpu():
  """Force the cpu to be used.

  Pins all ops created inside the context to CPU:0.
  """
  with ops.device("/device:CPU:0"):
    yield

1972 

1973 

@contextlib.contextmanager
def deterministic_ops():
  """Enables deterministic ops for the duration of the context.

  NOTE(review): determinism is unconditionally disabled on exit, even if it
  was already enabled before entering — confirm callers don't nest this.
  """
  try:
    config.enable_op_determinism()
    yield
  finally:
    config.disable_op_determinism()

1982 

1983 

class CapturedWrites:
  """A utility class to load the captured writes made to a stream."""

  def __init__(self, capture_location):
    # Path of the file that received the redirected writes.
    self.capture_location = capture_location

  def contents(self):
    """Get the captured writes as a single string."""
    # Reads the whole capture file on every call; equivalent to joining
    # readlines() but in one pass.
    with open(self.capture_location) as captured_file:
      return captured_file.read()

1995 

1996 

class FakeEagerSession:
  """Fake session so tests that conditionally use placeholders can use eager.

  A number of TF1-style tests build either a placeholder-based graph or a
  static-shape graph and then call `sess.run(y, feed_dict=feed_dict)` where
  the feed dict is empty in the static case:

  ```python
  with self.cached_session() as sess:
    if static_shape:
      y = math_ops.matmul(x, ...)
      feed_dict = {}
    else:
      x_ph = array_ops.placeholder(...)
      y = math_ops.matmul(x_ph, ...)
      feed_dict = {x_ph: x}
    val = sess.run(y, feed_dict=feed_dict)
  ```

  With an empty feed dict the call is equivalent to `self.evaluate(y)`, so
  this object stands in for the session and forwards to the test case's
  `evaluate`, letting such tests run eagerly with minimal edits. It should be
  considered a stop-gap solution.
  """

  def __init__(self, test_case):
    # The TensorFlowTestCase whose evaluate() will execute the fetches.
    self._test_case = test_case

  def run(self, fetches, *args, **kwargs):
    """Evaluate `fetches`.

    Fail if additional args are specified.

    Args:
      fetches: A Tensor or a nested list/tuple of Tensors.
      *args: Positional arguments
      **kwargs: Keyword arguments

    Raises:
      RuntimeError: If args or kwargs are specified.

    Returns:
      Tensors as numpy values.
    """
    # A non-empty feed dict cannot be honored without placeholders.
    if kwargs.pop("feed_dict", {}):
      raise RuntimeError(
          "feed_dict is not supported when eager execution is enabled "
          "(in this case, sess.run(t) is shorthand for t.numpy()")

    if args or kwargs:
      raise RuntimeError(
          "Optional args are not supported when eager execution is enabled "
          "(in this case, sess.run(t) is shorthand for t.numpy()")

    return self._test_case.evaluate(fetches)

2052 

2053 

class ErrorLoggingSession(session.Session):
  """Wrapper around a Session that logs errors in run()."""

  def run(self, *args, **kwargs):
    """Runs the session, logging any error (except OutOfRange) before re-raising."""
    try:
      return super().run(*args, **kwargs)
    except Exception as e:  # pylint: disable=broad-except
      # Note: disable the logging for OutOfRangeError, which makes the output
      # of tf.data tests hard to read, because OutOfRangeError is used as the
      # signal completion
      if not isinstance(e, errors.OutOfRangeError):
        logging.error(str(e))
      # Bare raise preserves the original traceback.
      raise

2067 

2068 

def disable_cudnn_autotune(func):
  """Disable autotuning during the call to this function.

  Some tests want to base assertions on a graph being isomorphic with a copy.
  To ensure this, this decorator disables autotuning.

  Args:
    func: Function to run with CuDNN autotuning turned off.

  Returns:
    Decorated function.
  """

  def decorator(f):

    def decorated(self, *args, **kwargs):
      # Remember the original settings (None means "was unset").
      original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE")
      os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false"
      original_xla_flags = os.environ.get("XLA_FLAGS")
      new_xla_flags = "--xla_gpu_autotune_level=0"
      if original_xla_flags:
        # Preserve any pre-existing XLA flags.
        new_xla_flags = original_xla_flags + " " + new_xla_flags
      os.environ["XLA_FLAGS"] = new_xla_flags

      try:
        return f(self, *args, **kwargs)
      finally:
        # Restore the environment even if the test raises, so a failing test
        # does not leak autotune settings into subsequent tests.
        if original_tf_cudnn_use_autotune is None:
          del os.environ["TF_CUDNN_USE_AUTOTUNE"]
        else:
          os.environ["TF_CUDNN_USE_AUTOTUNE"] = original_tf_cudnn_use_autotune
        if original_xla_flags is None:
          del os.environ["XLA_FLAGS"]
        else:
          os.environ["XLA_FLAGS"] = original_xla_flags

    # Wrap `f` (not the outer `func`, which is None when this decorator is
    # used with parentheses) so metadata is copied from the right function.
    return tf_decorator.make_decorator(f, decorated)

  if func is not None:
    return decorator(func)

  return decorator

2112 

2113 

# The description is just for documentation purposes.
def enable_tf_xla_constant_folding(description):
  """Returns a decorator that re-enables XLA constant folding for a test."""

  if not isinstance(description, str):
    raise ValueError("'description' should be string, got {}".format(
        type(description)))

  def enable_tf_xla_constant_folding_impl(func):
    """Enable constant folding during the call to this function.

    Some tests fail without constant folding.

    Args:
      func: Function to run with constant folding turned on.

    Returns:
      Decorated function.
    """

    def decorator(f):

      def decorated(self, *args, **kwargs):
        original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled()
        pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False)
        try:
          return f(self, *args, **kwargs)
        finally:
          # Restore the previous setting even if the test raises, so a
          # failing test does not leave constant folding misconfigured for
          # subsequent tests.
          pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var)

      return decorated

    if func is not None:
      return decorator(func)

    return decorator

  return enable_tf_xla_constant_folding_impl

2150 

2151 

# Updates test function by selectively disabling it.
def _disable_test(execute_func):

  def disable_test_impl(func):

    def decorator(f):

      def decorated(self, *args, **kwargs):
        # When disabled, the test body is skipped entirely (returns None).
        if execute_func:
          return f(self, *args, **kwargs)

      return tf_decorator.make_decorator(f, decorated)

    return decorator(func) if func is not None else decorator

  return disable_test_impl

2171 

2172 

# The description is just for documentation purposes.
def disable_xla(description):  # pylint: disable=unused-argument
  """Execute the test method only if xla is not enabled."""
  return _disable_test(not is_xla_enabled())

2178 

2179 

# The description is just for documentation purposes.
def disable_mlir_bridge(description):  # pylint: disable=unused-argument
  """Execute the test method only if MLIR bridge is not enabled."""
  return _disable_test(not is_mlir_bridge_enabled())

2185 

2186 

# The description is just for documentation purposes.
def disable_asan(description):  # pylint: disable=unused-argument
  """Execute the test method only if ASAN is not enabled."""
  return _disable_test(not is_asan_enabled())

2192 

2193 

# The description is just for documentation purposes.
def disable_msan(description):  # pylint: disable=unused-argument
  """Execute the test method only if MSAN is not enabled."""
  return _disable_test(not is_msan_enabled())

2199 

2200 

# The description is just for documentation purposes.
def disable_tsan(description):  # pylint: disable=unused-argument
  """Execute the test method only if TSAN is not enabled."""
  return _disable_test(not is_tsan_enabled())

2206 

2207 

# The description is just for documentation purposes.
def disable_ubsan(description):  # pylint: disable=unused-argument
  """Execute the test method only if UBSAN is not enabled."""
  return _disable_test(not is_ubsan_enabled())

2213 

2214 

# The description is just for documentation purposes.
def disable_tfrt(unused_description):
  """Returns a decorator that disables a test (or test class) under TFRT."""

  def disable_tfrt_impl(cls_or_func):
    """Execute the test only if tfrt is not enabled."""

    if tf_inspect.isclass(cls_or_func):
      # Decorating a class: when TFRT is enabled, returning None rebinds the
      # class name to None, removing the whole test class from the module.
      if tfrt_utils.enabled():
        return None
      else:
        return cls_or_func
    else:
      def decorator(func):

        def decorated(self, *args, **kwargs):
          # The TFRT check happens at call time, so the test silently becomes
          # a no-op (returns None) when TFRT is enabled.
          if tfrt_utils.enabled():
            return
          else:
            return func(self, *args, **kwargs)

        return decorated

      if cls_or_func is not None:
        return decorator(cls_or_func)

      return decorator

  return disable_tfrt_impl

2243 

2244 

2245def for_all_test_methods(decorator, *args, **kwargs): 

2246 """Generate class-level decorator from given method-level decorator. 

2247 

2248 It is expected for the given decorator to take some arguments and return 

2249 a method that is then called on the test method to produce a decorated 

2250 method. 

2251 

2252 Args: 

2253 decorator: The decorator to apply. 

2254 *args: Positional arguments 

2255 **kwargs: Keyword arguments 

2256 Returns: Function that will decorate a given classes test methods with the 

2257 decorator. 

2258 """ 

2259 

2260 def all_test_methods_impl(cls): 

2261 """Apply decorator to all test methods in class.""" 

2262 for name in dir(cls): 

2263 value = getattr(cls, name) 

2264 if callable(value) and name.startswith( 

2265 "test") and (name != "test_session"): 

2266 setattr(cls, name, decorator(*args, **kwargs)(value)) 

2267 return cls 

2268 

2269 return all_test_methods_impl 

2270 

2271 

# The description is just for documentation purposes.
def no_xla_auto_jit(description):  # pylint: disable=unused-argument
  """This test is not intended to be run with XLA auto jit enabled."""
  return _disable_test(not is_xla_enabled())

2277 

2278 

# The description is just for documentation purposes.
def xla_allow_fallback(description):  # pylint: disable=unused-argument
  """Returns a decorator that lets an XLA test fall back to classic TF."""

  def xla_allow_fallback_impl(func):
    """Allow fallback to TF even though testing xla."""

    def decorator(f):

      def decorated(self, *args, **kwargs):
        if not is_xla_enabled():
          return f(self, *args, **kwargs)
        # Update the global XLABuildOpsPassFlags to enable lazy compilation,
        # which allows the compiler to fall back to TF classic. Remember the
        # old value so that we can reset it.
        old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True)
        try:
          return f(self, *args, **kwargs)
        finally:
          # Restore even when the test raises, so the global flag does not
          # leak into subsequent tests.
          pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value)

      return decorated

    if func is not None:
      return decorator(func)

    return decorator

  return xla_allow_fallback_impl

2307 

2308 

2309# The description is just for documentation purposes. 

2310def run_without_tensor_float_32(description): # pylint: disable=unused-argument 

2311 """Execute test with TensorFloat-32 disabled. 

2312 

2313 While almost every real-world deep learning model runs fine with 

2314 TensorFloat-32, many tests use assertAllClose or similar methods. 

2315 TensorFloat-32 matmuls typically will cause such methods to fail with the 

2316 default tolerances. 

2317 

2318 Args: 

2319 description: A description used for documentation purposes, describing why 

2320 the test requires TensorFloat-32 to be disabled. 

2321 

2322 Returns: 

2323 Decorator which runs a test with TensorFloat-32 disabled. 

2324 """ 

2325 

2326 def decorator(f): 

2327 

2328 @functools.wraps(f) 

2329 def decorated(self, *args, **kwargs): 

2330 allowed = config.tensor_float_32_execution_enabled() 

2331 try: 

2332 config.enable_tensor_float_32_execution(False) 

2333 f(self, *args, **kwargs) 

2334 finally: 

2335 config.enable_tensor_float_32_execution(allowed) 

2336 

2337 return decorated 

2338 

2339 return decorator 

2340 

2341 

2342# The description is just for documentation purposes. 

2343def run_all_without_tensor_float_32(description): # pylint: disable=unused-argument 

2344 """Execute all tests in a class with TensorFloat-32 disabled.""" 

2345 return for_all_test_methods(run_without_tensor_float_32, description) 

2346 

2347 

def matmul_without_tf32(a, b, *args, **kwargs):
  """Run matmul but cast float32 inputs to float64 if TensorFloat-32 is enabled.

  This effectively runs matmul without TensorFloat-32. It should only be used in
  tests when verifying some other op or functions works correctly, e.g. to test
  `tf.linalg.sqrtm` by matrix multiplying the output of the op by itself. In
  such cases, the matmul itself is not being tested so it's OK to run it with
  higher precision.

  If a matmul itself is being tested, or some other op which uses matmul, use
  `run_without_tensor_float_32` instead.

  This also casts complex64 inputs to complex128, since TensorFloat-32 can also
  be used with complex64

  Args:
    a: First input to tf.linalg.matmul
    b: Second input to tf.linalg.matmul
    *args: Other positional arguments to tf.linalg.matmul
    **kwargs: Other keyword arguments to tf.linalg.matmul

  Returns:
    A tensor with the same type as `a`.
  """
  if config.tensor_float_32_execution_enabled():
    # Determine the wider dtype to compute in. Uses `==` comparisons rather
    # than a dict lookup because a TF DType compares equal to its name string
    # but does not hash like one.
    if a.dtype == "float32":
      wide_dtype = "float64"
    elif a.dtype == "complex64":
      wide_dtype = "complex128"
    else:
      wide_dtype = None
    if wide_dtype is not None:
      # Compute at higher precision, then cast back to the input dtype.
      a_wide = math_ops.cast(a, wide_dtype)
      b_wide = math_ops.cast(b, wide_dtype)
      ret = math_ops.matmul(a_wide, b_wide, *args, **kwargs)
      return math_ops.cast(ret, a.dtype)
  return math_ops.matmul(a, b, *args, **kwargs)

2384 

2385 

class EagerSessionWarner:
  """Stand-in for a session that errors on any use under eager execution.

  Returned where a real session would be when eager execution is on; any
  attribute access or method call raises an explanatory AttributeError.
  """

  def __getattr__(self, attr):
    raise AttributeError(
        "Trying to access properties or call methods on the result of "
        "self.session(), self.cached_session(), etc while eager execution "
        "is enabled. If you're porting this test case to TF 2.0, either "
        "adapt the test to work with eager execution or insert a call to "
        "tf.disable_eager_execution() in the main() function of this test "
        "file.")

2396 

2397 

2398@tf_export("test.TestCase") 

2399class TensorFlowTestCase(googletest.TestCase): 

2400 """Base class for tests that need to test TensorFlow.""" 

2401 

  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    """Initializes the test case and configures XLA / MLIR-bridge state."""
    super().__init__(methodName)
    # Make sure we get unfiltered stack traces during the test
    traceback_utils.disable_traceback_filtering()
    if is_xla_enabled():
      # Force auto-jit on ("2") with minimum cluster size 1 so every op goes
      # through XLA, and compile eagerly rather than lazily.
      pywrap_tf_session.TF_SetXlaAutoJitMode("2")
      pywrap_tf_session.TF_SetXlaMinClusterSize(1)
      pywrap_tf_session.TF_SetXlaEnableLazyCompilation(False)
      pywrap_tf_session.TF_SetTfXlaCpuGlobalJit(True)
      # Constant folding secretly runs code on TF:Classic CPU, so we also
      # disable it here.
      pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)

    # Check if the mlir bridge has been explicitly enabled or disabled. If
    # is_mlir_bridge_enabled() returns None, the user did not explictly enable
    # or disable the bridge so do not update enable_mlir_bridge.
    if is_mlir_bridge_enabled():
      context.context().enable_mlir_bridge = True
    elif is_mlir_bridge_enabled() is not None:
      context.context().enable_mlir_bridge = False

    # Threads whose termination is verified in tearDown().
    self._threads = []
    # Lazily created by get_temp_dir().
    self._tempdir = None
    # Closed and cleared by _ClearCachedSession().
    self._cached_session = None
    # Set in setUp(); used by tearDown() to log the test's duration.
    self._test_start_time = None
    # This flag provides the ability to control whether the graph mode gets
    # initialized for TF1 or not. Initializing for TF1, which is what was
    # happening earlier, was preventing enablement of 'eager mode' in the test.
    self._set_default_seed = True

2431 

  def setUp(self):
    """Resets graphs, random seeds, and summary state before each test."""
    super().setUp()
    self._ClearCachedSession()
    # Seed Python's and NumPy's RNGs for reproducibility across tests.
    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
    # Note: The following line is necessary because some test methods may error
    # out from within nested graph contexts (e.g., via assertRaises and
    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
    # under certain versions of Python. That would cause
    # ops.reset_default_graph() to throw an exception if the stack were not
    # cleared first.
    ops._default_graph_stack.reset()  # pylint: disable=protected-access
    ops.reset_default_graph()
    if self._set_default_seed:
      random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
    # Reset summary writer in case another test used set_as_default() with their
    # summary writer.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.writer = None

    # Avoiding calling setUp() for the poorly named test_session method.
    if self.id().endswith(".test_session"):
      self.skipTest("Not a test.")

    # Recorded so tearDown() can log the test's wall-clock duration.
    self._test_start_time = time.time()

2457 

  def tearDown(self):
    """Logs test duration, checks registered threads, and closes the session."""
    # If a subclass overrides setUp and doesn't call the parent class's setUp,
    # then we may not have set the start time.
    if self._test_start_time is not None:
      logging.info("time(%s): %ss", self.id(),
                   round(time.time() - self._test_start_time, 2))

    # Verify that threads registered during the test terminated cleanly.
    for thread in self._threads:
      thread.check_termination()

    self._ClearCachedSession()
    super().tearDown()

2470 

2471 def _ClearCachedSession(self): 

2472 if self._cached_session is not None: 

2473 self._cached_session.close() 

2474 self._cached_session = None 

2475 

  def get_temp_dir(self):
    """Returns a unique temporary directory for the test to use.

    If you call this method multiple times during in a test, it will return the
    same folder. However, across different runs the directories will be
    different. This will ensure that across different runs tests will not be
    able to pollute each others environment.
    If you need multiple unique directories within a single test, you should
    use tempfile.mkdtemp as follows:
      tempfile.mkdtemp(dir=self.get_temp_dir()):

    Returns:
      string, the path to the unique temporary directory created for this test.
    """
    # Created lazily on first use and memoized for the rest of the test.
    if not self._tempdir:
      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
    return self._tempdir

2493 

2494 @contextlib.contextmanager 

2495 def captureWritesToStream(self, stream): 

2496 """A context manager that captures the writes to a given stream. 

2497 

2498 This context manager captures all writes to a given stream inside of a 

2499 `CapturedWrites` object. When this context manager is created, it yields 

2500 the `CapturedWrites` object. The captured contents can be accessed by 

2501 calling `.contents()` on the `CapturedWrites`. 

2502 

2503 For this function to work, the stream must have a file descriptor that 

2504 can be modified using `os.dup` and `os.dup2`, and the stream must support 

2505 a `.flush()` method. The default python sys.stdout and sys.stderr are 

2506 examples of this. Note that this does not work in Colab or Jupyter 

2507 notebooks, because those use alternate stdout streams. 

2508 

2509 Example: 

2510 ```python 

2511 class MyOperatorTest(test_util.TensorFlowTestCase): 

2512 def testMyOperator(self): 

2513 input = [1.0, 2.0, 3.0, 4.0, 5.0] 

2514 with self.captureWritesToStream(sys.stdout) as captured: 

2515 result = MyOperator(input).eval() 

2516 self.assertStartsWith(captured.contents(), "This was printed.") 

2517 ``` 

2518 

2519 Args: 

2520 stream: The stream whose writes should be captured. This stream must have 

2521 a file descriptor, support writing via using that file descriptor, and 

2522 must have a `.flush()` method. 

2523 

2524 Yields: 

2525 A `CapturedWrites` object that contains all writes to the specified stream 

2526 made during this context. 

2527 """ 

2528 stream.flush() 

2529 fd = stream.fileno() 

2530 tmp_file, tmp_file_path = tempfile.mkstemp(dir=self.get_temp_dir()) 

2531 orig_fd = os.dup(fd) 

2532 os.dup2(tmp_file, fd) 

2533 try: 

2534 yield CapturedWrites(tmp_file_path) 

2535 finally: 

2536 os.close(tmp_file) 

2537 os.dup2(orig_fd, fd) 

2538 

  def _AssertProtoEquals(self, a, b, msg=None, relative_tolerance=None):
    """Asserts that a and b are the same proto.

    Uses ProtoEq() first, as it returns correct results
    for floating point attributes, and then use assertProtoEqual()
    in case of failure as it provides good error messages.

    Args:
      a: a proto.
      b: another proto.
      msg: Optional message to report on failure.
      relative_tolerance: float. The allowable difference between the two values
        being compared is determined by multiplying the relative tolerance by
        the maximum of the two values. If this is not provided, then all floats
        are compared using string comparison.
    """
    # Fast equality check first; only fall back to the (slower) comparison
    # that produces a readable diff when the protos actually differ.
    if not compare.ProtoEq(a, b):
      compare.assertProtoEqual(
          self,
          a,
          b,
          normalize_numbers=True,
          msg=msg,
          relative_tolerance=relative_tolerance,
      )

2564 

  def assertProtoEquals(
      self,
      expected_message_maybe_ascii,
      message,
      msg=None,
      relative_tolerance=None,
  ):
    """Asserts that message is same as parsed expected_message_ascii.

    Creates another prototype of message, reads the ascii message into it and
    then compares them using self._AssertProtoEqual().

    Args:
      expected_message_maybe_ascii: proto message in original or ascii form.
      message: the message to validate.
      msg: Optional message to report on failure.
      relative_tolerance: float. The allowable difference between the two values
        being compared is determined by multiplying the relative tolerance by
        the maximum of the two values. If this is not provided, then all floats
        are compared using string comparison.
    """
    # Already a proto of the same type: compare directly.
    if isinstance(expected_message_maybe_ascii, type(message)):
      expected_message = expected_message_maybe_ascii
      self._AssertProtoEquals(
          expected_message,
          message,
          msg=msg,
          relative_tolerance=relative_tolerance,
      )
    # Text-format proto: parse it into a fresh message of the same type first.
    elif isinstance(expected_message_maybe_ascii, (str, bytes)):
      expected_message = type(message)()
      text_format.Merge(
          expected_message_maybe_ascii,
          expected_message,
          descriptor_pool=descriptor_pool.Default())
      self._AssertProtoEquals(
          expected_message,
          message,
          msg=msg,
          relative_tolerance=relative_tolerance,
      )
    else:
      # NOTE(review): surfaces as AssertionError and is stripped under
      # `python -O`; a TypeError would be more conventional, but changing it
      # would alter what callers catch.
      assert False, ("Can't compare protos of type %s and %s." %
                     (type(expected_message_maybe_ascii), type(message)))

2609 

2610 def assertProtoEqualsVersion( 

2611 self, 

2612 expected, 

2613 actual, 

2614 producer=versions.GRAPH_DEF_VERSION, 

2615 min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER, 

2616 msg=None): 

2617 expected = "versions { producer: %d min_consumer: %d };\n%s" % ( 

2618 producer, min_consumer, expected) 

2619 self.assertProtoEquals(expected, actual, msg=msg) 

2620 

2621 def assertStartsWith(self, actual, expected_start, msg=None): 

2622 """Assert that actual.startswith(expected_start) is True. 

2623 

2624 Args: 

2625 actual: str 

2626 expected_start: str 

2627 msg: Optional message to report on failure. 

2628 """ 

2629 if not actual.startswith(expected_start): 

2630 fail_msg = "%r does not start with %r" % (actual, expected_start) 

2631 fail_msg += " : %r" % (msg) if msg else "" 

2632 self.fail(fail_msg) 

2633 

  def _eval_tensor(self, tensor):
    """Converts a single eager tensor (or composite tensor) to numpy values."""
    if tensor is None:
      return None
    elif callable(tensor):
      # A callable produces the value lazily; evaluate whatever it returns.
      return self._eval_helper(tensor())
    else:
      try:
        # for compatibility with TF1 test cases
        if sparse_tensor.is_sparse(tensor):
          return sparse_tensor.SparseTensorValue(tensor.indices.numpy(),
                                                 tensor.values.numpy(),
                                                 tensor.dense_shape.numpy())
        elif ragged_tensor.is_ragged(tensor):
          # Row splits and values are converted recursively (values may
          # themselves be ragged).
          return ragged_tensor_value.RaggedTensorValue(
              self._eval_tensor(tensor.values),
              self._eval_tensor(tensor.row_splits))
        elif isinstance(tensor, indexed_slices.IndexedSlices):
          return indexed_slices.IndexedSlicesValue(
              values=tensor.values.numpy(),
              indices=tensor.indices.numpy(),
              dense_shape=None
              if tensor.dense_shape is None else tensor.dense_shape.numpy())
        else:
          if hasattr(tensor, "numpy") and callable(tensor.numpy):
            return tensor.numpy()
          else:
            # Try our best to convert CompositeTensor components to NumPy
            # arrays. Officially, we don't support NumPy arrays as
            # CompositeTensor components. So don't be surprised if this doesn't
            # work.
            return nest.map_structure(lambda t: t.numpy(), tensor,
                                      expand_composites=True)
      except AttributeError as e:
        raise ValueError(f"Unsupported type {type(tensor).__name__!r}.") from e

2668 

  def _eval_helper(self, tensors):
    """Maps `_eval_tensor` over an arbitrarily nested structure of tensors."""
    if tensors is None:
      return None
    return nest.map_structure(self._eval_tensor, tensors)

2673 

  def evaluate(self, tensors):
    """Evaluates tensors and returns numpy values.

    Args:
      tensors: A Tensor or a nested list/tuple of Tensors.

    Returns:
      tensors numpy values.
    """
    if context.executing_eagerly():
      # Eager mode: convert directly via .numpy(), no session needed.
      return self._eval_helper(tensors)
    else:
      # Graph mode: use the default session if one is active, otherwise
      # create (or reuse) the test session.
      sess = ops.get_default_session()
      if sess is None:
        with self.test_session() as sess:
          return sess.run(tensors)
      else:
        return sess.run(tensors)

2692 

2693 # pylint: disable=g-doc-return-or-yield 

  @contextlib.contextmanager
  def session(self, graph=None, config=None, use_gpu=True, force_gpu=False):
    """A context manager for a TensorFlow Session for use in executing tests.

    Note that this will set this session and the graph as global defaults.

    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
    possible. If both `force_gpu` and `use_gpu` are False, all ops are pinned
    to the CPU.

    Example:

    ``` python
    class MyOperatorTest(test_util.TensorFlowTestCase):
      def testMyOperator(self):
        with self.session():
          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
          result = MyOperator(valid_input).eval()
          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
          invalid_input = [-1.0, 2.0, 7.0]
          with self.assertRaisesOpError("negative input not supported"):
            MyOperator(invalid_input).eval()
    ```

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      use_gpu: If True, attempt to run as many ops as possible on GPU.
      force_gpu: If True, pin all ops to `/device:GPU:0`.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    if context.executing_eagerly():
      # No sessions in eager mode: yield a warner whose methods raise a
      # helpful error if session APIs are used.
      yield EagerSessionWarner()
    else:
      # Create a fresh session, then constrain device placement per the
      # use_gpu/force_gpu flags while it is the default.
      with self._create_session(graph, config, force_gpu) as sess:
        with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
          yield sess

2737 

  @contextlib.contextmanager
  def cached_session(self,
                     graph=None,
                     config=None,
                     use_gpu=True,
                     force_gpu=False):
    """Returns a TensorFlow Session for use in executing tests.

    This method behaves differently than self.session(): for performance reasons
    `cached_session` will by default reuse the same session within the same
    test. The session returned by this function will only be closed at the end
    of the test (in the TearDown function).

    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
    possible. If both `force_gpu` and `use_gpu` are False, all ops are pinned
    to the CPU.

    Example:
    ```python
    class MyOperatorTest(test_util.TensorFlowTestCase):
      def testMyOperator(self):
        with self.cached_session() as sess:
          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
          result = MyOperator(valid_input).eval()
          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
          invalid_input = [-1.0, 2.0, 7.0]
          with self.assertRaisesOpError("negative input not supported"):
            MyOperator(invalid_input).eval()
    ```

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      use_gpu: If True, attempt to run as many ops as possible on GPU.
      force_gpu: If True, pin all ops to `/device:GPU:0`.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    if context.executing_eagerly():
      # Eager tests get a fake session that evaluates fetches eagerly.
      yield FakeEagerSession(self)
    else:
      # Reuse the per-test cached session; mismatched graph/config arguments
      # across calls within one test are an error here.
      sess = self._get_cached_session(
          graph, config, force_gpu, crash_if_inconsistent_args=True)
      with self._constrain_devices_and_set_default(sess, use_gpu,
                                                   force_gpu) as cached:
        yield cached

2789 

  @contextlib.contextmanager
  @deprecation.deprecated(None, "Use `self.session()` or "
                          "`self.cached_session()` instead.")
  def test_session(self,
                   graph=None,
                   config=None,
                   use_gpu=True,
                   force_gpu=False):
    """Use cached_session instead."""
    # "test_session" matches unittest's test-method naming pattern, so a test
    # method literally named test_session would be collected and run; skip it.
    if self.id().endswith(".test_session"):
      self.skipTest(
          "Tests that have the name \"test_session\" are automatically skipped "
          "by TensorFlow test fixture, as the name is reserved for creating "
          "sessions within tests. Please rename your test if you have a test "
          "with this name.")
    if context.executing_eagerly():
      yield None
    else:
      if graph is None:
        # Legacy behavior: reuse the cached session and tolerate inconsistent
        # arguments across calls (unlike cached_session, which crashes).
        sess = self._get_cached_session(
            graph, config, force_gpu, crash_if_inconsistent_args=False)
        with self._constrain_devices_and_set_default(sess, use_gpu,
                                                     force_gpu) as cached:
          yield cached
      else:
        # An explicit graph always gets a fresh (non-cached) session.
        with self.session(graph, config, use_gpu, force_gpu) as sess:
          yield sess

2817 

2818 # pylint: enable=g-doc-return-or-yield 

2819 

  class _CheckedThread(object):
    """A wrapper class for Thread that asserts successful completion.

    This class should be created using the TensorFlowTestCase.checkedThread()
    method.
    """

    def __init__(self, testcase, target, args=None, kwargs=None):
      """Constructs a new instance of _CheckedThread.

      Args:
        testcase: The TensorFlowTestCase for which this thread is being created.
        target: A callable object representing the code to be executed in the
          thread.
        args: A tuple of positional arguments that will be passed to target.
        kwargs: A dictionary of keyword arguments that will be passed to target.
      """
      self._testcase = testcase
      self._target = target
      self._args = () if args is None else args
      self._kwargs = {} if kwargs is None else kwargs
      # Run through _protected_run so any exception raised by `target` is
      # captured instead of being swallowed by the thread machinery.
      self._thread = threading.Thread(target=self._protected_run)
      self._exception = None

      self._is_thread_joined = False

    def _protected_run(self):
      """Target for the wrapper thread. Sets self._exception on failure."""
      try:
        self._target(*self._args, **self._kwargs)
      except Exception as e:  # pylint: disable=broad-except
        # Stash the exception; join() re-reports it as a test failure on the
        # main thread, where assertions are actually respected.
        self._exception = e

    def start(self):
      """Starts the thread's activity.

      This must be called at most once per _CheckedThread object. It arranges
      for the object's target to be invoked in a separate thread of control.
      """
      self._thread.start()

    def join(self):
      """Blocks until the thread terminates.

      Raises:
        self._testcase.failureException: If the thread terminates due to
          an exception.
      """
      self._is_thread_joined = True
      self._thread.join()
      if self._exception is not None:
        self._testcase.fail("Error in checkedThread: %s" % str(self._exception))

    def is_alive(self):
      """Returns whether the thread is alive.

      This method returns True just before the run() method starts
      until just after the run() method terminates.

      Returns:
        True if the thread is alive, otherwise False.
      """
      return self._thread.is_alive()

    def check_termination(self):
      """Returns whether the checked thread was properly used and did terminate.

      Every checked thread should be "join"ed after starting, and before the
      test tears down. If it is not joined, it is possible the thread will hang
      and cause flaky failures in tests.

      Raises:
        self._testcase.failureException: If check_termination was called before
          thread was joined.

        RuntimeError: If the thread is not terminated. This means thread was not
          joined with the main thread.
      """
      if self._is_thread_joined:
        if self.is_alive():
          # join() was called but the thread is somehow still running.
          raise RuntimeError(
              "Thread was not joined with main thread, and is still running "
              "when the test finished.")
      else:
        self._testcase.fail("A checked thread was not joined.")

2905 

2906 def checkedThread(self, target, args=None, kwargs=None): 

2907 """Returns a Thread wrapper that asserts 'target' completes successfully. 

2908 

2909 This method should be used to create all threads in test cases, as 

2910 otherwise there is a risk that a thread will silently fail, and/or 

2911 assertions made in the thread will not be respected. 

2912 

2913 Args: 

2914 target: A callable object to be executed in the thread. 

2915 args: The argument tuple for the target invocation. Defaults to (). 

2916 kwargs: A dictionary of keyword arguments for the target invocation. 

2917 Defaults to {}. 

2918 

2919 Returns: 

2920 A wrapper for threading.Thread that supports start() and join() methods. 

2921 """ 

2922 ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs) 

2923 self._threads.append(ret) 

2924 return ret 

2925 

2926 # pylint: enable=invalid-name 

2927 @py_func_if_in_function 

2928 def assertNear(self, f1, f2, err, msg=None): 

2929 """Asserts that two floats are near each other. 

2930 

2931 Checks that |f1 - f2| < err and asserts a test failure 

2932 if not. 

2933 

2934 Args: 

2935 f1: A float value. 

2936 f2: A float value. 

2937 err: A float value. 

2938 msg: An optional string message to append to the failure message. 

2939 """ 

2940 # f1 == f2 is needed here as we might have: f1, f2 = inf, inf 

2941 self.assertTrue( 

2942 f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" % 

2943 (f1, f2, err, " (%s)" % msg if msg is not None else "")) 

2944 

2945 @py_func_if_in_function 

2946 def assertArrayNear(self, farray1, farray2, err, msg=None): 

2947 """Asserts that two float arrays are near each other. 

2948 

2949 Checks that for all elements of farray1 and farray2 

2950 |f1 - f2| < err. Asserts a test failure if not. 

2951 

2952 Args: 

2953 farray1: a list of float values. 

2954 farray2: a list of float values. 

2955 err: a float value. 

2956 msg: Optional message to report on failure. 

2957 """ 

2958 self.assertEqual(len(farray1), len(farray2), msg=msg) 

2959 for f1, f2 in zip(farray1, farray2): 

2960 self.assertNear(float(f1), float(f2), err, msg=msg) 

2961 

2962 def _NDArrayNear(self, ndarray1, ndarray2, err): 

2963 return np.linalg.norm(ndarray1 - ndarray2) < err 

2964 

2965 @py_func_if_in_function 

2966 def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None): 

2967 """Asserts that two numpy arrays have near values. 

2968 

2969 Args: 

2970 ndarray1: a numpy ndarray. 

2971 ndarray2: a numpy ndarray. 

2972 err: a float. The maximum absolute difference allowed. 

2973 msg: Optional message to report on failure. 

2974 """ 

2975 self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg) 

2976 

  def _GetNdArray(self, a):
    """Converts `a` to a numpy ndarray, evaluating TF tensors as needed.

    Args:
      a: A Tensor (eager or symbolic), ndarray, or anything `np.array` can
        consume.

    Returns:
      A numpy ndarray (dtype=object for ragged/inhomogeneous inputs).
    """
    # If a is tensor-like then convert it to ndarray
    if tensor_util.is_tf_type(a):
      if isinstance(a, ops._EagerTensorBase):
        # Eager tensors can be read directly without a session.
        a = a.numpy()
      else:
        a = self.evaluate(a)
    if not isinstance(a, np.ndarray):
      try:
        return np.array(a)
      except ValueError as e:
        # TODO(b/264461299): NumPy 1.24 no longer infers dtype=object from
        # ragged sequences.
        # See:
        # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
        # Fixing this correctly requires clarifying the API contract of this
        # function with respect to ragged sequences and possibly updating all
        # users. As a backwards compatibility measure, if array
        # creation fails with an "inhomogeneous shape" error, try again with
        # an explicit dtype=object, which should restore the previous behavior.
        if "inhomogeneous shape" in str(e):
          return np.array(a, dtype=object)
        else:
          raise
    return a

3002 

3003 def evaluate_if_both_tensors(self, a, b): 

3004 if (tensor_util.is_tf_type(a) and tensor_util.is_tf_type(b) and 

3005 not isinstance(a, ops._EagerTensorBase) and 

3006 not isinstance(b, ops._EagerTensorBase)): 

3007 return self.evaluate((a, b)) 

3008 else: 

3009 return (a, b) 

3010 

  def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
    """Asserts two array-likes have identical shapes and element-wise closeness.

    Args:
      a: the expected array-like value.
      b: the actual array-like value.
      rtol: relative tolerance.
      atol: absolute tolerance.
      msg: Optional message prepended to the failure details.
    """
    (a, b) = self.evaluate_if_both_tensors(a, b)
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # When the array rank is small, print its contents. Numpy array printing is
    # implemented using inefficient recursion so prints can cause tests to
    # time out.
    if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
      shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
                            "%s.") % (a.shape, b.shape, b)
    else:
      shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
                                                                     b.shape)
    self.assertEqual(a.shape, b.shape, shape_mismatch_msg)

    msgs = [msg]
    # np.allclose does not always work for our custom bfloat16 and float8
    # extension types when type promotions are involved, so we first cast any
    # arrays of such types to float32.
    a_dtype = a.dtype
    custom_dtypes = (dtypes.bfloat16.as_numpy_dtype,
                     dtypes.float8_e5m2.as_numpy_dtype,
                     dtypes.float8_e4m3fn.as_numpy_dtype)
    a = a.astype(np.float32) if a.dtype in custom_dtypes else a
    b = b.astype(np.float32) if b.dtype in custom_dtypes else b
    if not np.allclose(a, b, rtol=rtol, atol=atol):
      # Adds more details to np.testing.assert_allclose.
      #
      # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
      # checks whether two arrays are element-wise equal within a
      # tolerance. The relative difference (rtol * abs(b)) and the
      # absolute difference atol are added together to compare against
      # the absolute difference between a and b. Here, we want to
      # tell user which elements violate such conditions.
      cond = np.logical_or(
          np.abs(a - b) > atol + rtol * np.abs(b),
          np.isnan(a) != np.isnan(b))
      if a.ndim:
        x = a[np.where(cond)]
        y = b[np.where(cond)]
        msgs.append("not close where = {}".format(np.where(cond)))
      else:
        # np.where is broken for scalars
        x, y = a, b
      msgs.append("not close lhs = {}".format(x))
      msgs.append("not close rhs = {}".format(y))
      msgs.append("not close dif = {}".format(np.abs(x - y)))
      msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
      msgs.append("dtype = {}, shape = {}".format(a_dtype, a.shape))
      # TODO(xpan): There seems to be a bug:
      # tensorflow/compiler/tests:binary_ops_test pass with float32
      # nan even though the equal_nan is False by default internally.
      np.testing.assert_allclose(
          a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)

3065 

  def _assertAllCloseRecursive(self,
                               a,
                               b,
                               rtol=1e-6,
                               atol=1e-6,
                               path=None,
                               msg=None):
    """Recursively asserts closeness over nested dicts/namedtuples/sequences.

    `path` accumulates the keys/indices visited so far, so failure messages
    can pinpoint the offending leaf (e.g. a[1]['d']).

    Args:
      a: the expected nested structure.
      b: the actual nested structure.
      rtol: relative tolerance.
      atol: absolute tolerance.
      path: list of keys/indices leading to the current sub-structure.
      msg: Optional message to report on failure.
    """
    if ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b):
      return self._assertRaggedClose(a, b, rtol, atol, msg)
    path = path or []
    path_str = (("[" + "][".join(str(p) for p in path) + "]") if path else "")
    msg = msg if msg else ""

    # Check if a and/or b are namedtuples.
    if hasattr(a, "_asdict"):
      a = a._asdict()
    if hasattr(b, "_asdict"):
      b = b._asdict()
    a_is_dict = isinstance(a, collections_abc.Mapping)
    if a_is_dict != isinstance(b, collections_abc.Mapping):
      raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
                       (path_str, path_str, msg))
    if a_is_dict:
      self.assertItemsEqual(
          a.keys(),
          b.keys(),
          msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" %
          (path_str, a.keys(), path_str, b.keys(), msg))
      for k in a:
        path.append(k)
        self._assertAllCloseRecursive(
            a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg)
        del path[-1]
    elif isinstance(a, (list, tuple)):
      # Try to directly compare a, b as ndarrays; if not work, then traverse
      # through the sequence, which is more expensive.
      try:
        (a, b) = self.evaluate_if_both_tensors(a, b)
        a_as_ndarray = self._GetNdArray(a)
        b_as_ndarray = self._GetNdArray(b)
        self._assertArrayLikeAllClose(
            a_as_ndarray,
            b_as_ndarray,
            rtol=rtol,
            atol=atol,
            msg="Mismatched value: a%s is different from b%s. %s" %
            (path_str, path_str, msg))
      except (ValueError, TypeError, NotImplementedError) as e:
        if len(a) != len(b):
          raise ValueError(
              "Mismatched length: a%s has %d items, but b%s has %d items. %s" %
              (path_str, len(a), path_str, len(b), msg))
        for idx, (a_ele, b_ele) in enumerate(zip(a, b)):
          path.append(str(idx))
          self._assertAllCloseRecursive(
              a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg)
          del path[-1]
    # a and b are ndarray like objects
    else:
      try:
        self._assertArrayLikeAllClose(
            a,
            b,
            rtol=rtol,
            atol=atol,
            msg=("Mismatched value: a%s is different from b%s. %s" %
                 (path_str, path_str, msg)))
      except TypeError as e:
        # Surface the structural path in the TypeError so the user can find
        # the mismatched leaf types.
        msg = ("Error: a%s has %s, but b%s has %s. %s" %
               (path_str, type(a), path_str, type(b), msg))
        e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
        raise

3138 

3139 @py_func_if_in_function 

3140 def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): 

3141 """Asserts that two structures of numpy arrays or Tensors, have near values. 

3142 

3143 `a` and `b` can be arbitrarily nested structures. A layer of a nested 

3144 structure can be a `dict`, `namedtuple`, `tuple` or `list`. 

3145 

3146 Note: the implementation follows 

3147 [`numpy.allclose`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html) 

3148 (and numpy.testing.assert_allclose). It checks whether two arrays are 

3149 element-wise equal within a tolerance. The relative difference 

3150 (`rtol * abs(b)`) and the absolute difference `atol` are added together 

3151 to compare against the absolute difference between `a` and `b`. 

3152 

3153 Args: 

3154 a: The expected numpy `ndarray`, or anything that can be converted into a 

3155 numpy `ndarray` (including Tensor), or any arbitrarily nested of 

3156 structure of these. 

3157 b: The actual numpy `ndarray`, or anything that can be converted into a 

3158 numpy `ndarray` (including Tensor), or any arbitrarily nested of 

3159 structure of these. 

3160 rtol: relative tolerance. 

3161 atol: absolute tolerance. 

3162 msg: Optional message to report on failure. 

3163 

3164 Raises: 

3165 ValueError: if only one of `a[p]` and `b[p]` is a dict or 

3166 `a[p]` and `b[p]` have different length, where `[p]` denotes a path 

3167 to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and 

3168 `[p] = [1]['d']`, then `a[p] = (6, 7)`. 

3169 """ 

3170 self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg) 

3171 

3172 @py_func_if_in_function 

3173 def assertAllCloseAccordingToType(self, 

3174 a, 

3175 b, 

3176 rtol=1e-6, 

3177 atol=1e-6, 

3178 float_rtol=1e-6, 

3179 float_atol=1e-6, 

3180 half_rtol=1e-3, 

3181 half_atol=1e-3, 

3182 bfloat16_rtol=1e-2, 

3183 bfloat16_atol=1e-2, 

3184 msg=None): 

3185 """Like assertAllClose, but also suitable for comparing fp16 arrays. 

3186 

3187 In particular, the tolerance is reduced to 1e-3 if at least 

3188 one of the arguments is of type float16. 

3189 

3190 Args: 

3191 a: the expected numpy ndarray or anything can be converted to one. 

3192 b: the actual numpy ndarray or anything can be converted to one. 

3193 rtol: relative tolerance. 

3194 atol: absolute tolerance. 

3195 float_rtol: relative tolerance for float32. 

3196 float_atol: absolute tolerance for float32. 

3197 half_rtol: relative tolerance for float16. 

3198 half_atol: absolute tolerance for float16. 

3199 bfloat16_rtol: relative tolerance for bfloat16. 

3200 bfloat16_atol: absolute tolerance for bfloat16. 

3201 msg: Optional message to report on failure. 

3202 """ 

3203 (a, b) = self.evaluate_if_both_tensors(a, b) 

3204 a = self._GetNdArray(a) 

3205 b = self._GetNdArray(b) 

3206 # types with lower tol are put later to overwrite previous ones. 

3207 if (a.dtype == np.float32 or b.dtype == np.float32 or 

3208 a.dtype == np.complex64 or b.dtype == np.complex64): 

3209 rtol = max(rtol, float_rtol) 

3210 atol = max(atol, float_atol) 

3211 if a.dtype == np.float16 or b.dtype == np.float16: 

3212 rtol = max(rtol, half_rtol) 

3213 atol = max(atol, half_atol) 

3214 if (a.dtype == dtypes.bfloat16.as_numpy_dtype or 

3215 b.dtype == dtypes.bfloat16.as_numpy_dtype): 

3216 rtol = max(rtol, bfloat16_rtol) 

3217 atol = max(atol, bfloat16_atol) 

3218 

3219 self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg) 

3220 

3221 @py_func_if_in_function 

3222 def assertNotAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): 

3223 """Assert that two numpy arrays, or Tensors, do not have near values. 

3224 

3225 Args: 

3226 a: The expected numpy `ndarray`, or anything that can be converted into a 

3227 numpy `ndarray` (including Tensor), or any arbitrarily nested of 

3228 structure of these. 

3229 b: The actual numpy `ndarray`, or anything that can be converted into a 

3230 numpy `ndarray` (including Tensor), or any arbitrarily nested of 

3231 structure of these. 

3232 rtol: relative tolerance. 

3233 atol: absolute tolerance. 

3234 msg: Optional message to report on failure. 

3235 

3236 Raises: 

3237 AssertionError: If `a` and `b` are unexpectedly close at all elements. 

3238 """ 

3239 try: 

3240 self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg) 

3241 except AssertionError: 

3242 return 

3243 msg = msg or "" 

3244 raise AssertionError("The two values are close at all elements. %s" % msg) 

3245 

  @py_func_if_in_function
  def assertAllEqual(self, a, b, msg=None):
    """Asserts that two numpy arrays or Tensors have the same values.

    Args:
      a: the expected numpy ndarray or anything can be converted to one.
      b: the actual numpy ndarray or anything can be converted to one.
      msg: Optional message to report on failure.
    """
    if (ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b)):
      # Ragged comparison has its own traversal logic.
      return self._assertRaggedEqual(a, b, msg)
    msg = msg if msg else ""
    (a, b) = self.evaluate_if_both_tensors(a, b)
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # Arbitrary bounds so that we don't print giant tensors.
    if (b.ndim <= 3 or b.size < 500):
      self.assertEqual(
          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
          " Contents: %r. \n%s." % (a.shape, b.shape, b, msg))
    else:
      self.assertEqual(
          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
          " %s" % (a.shape, b.shape, msg))

    same = (a == b)

    if dtypes.as_dtype(a.dtype).is_floating:
      # NaN != NaN under IEEE semantics; treat matching NaN positions as equal.
      same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
    msgs = [msg]
    if not np.all(same):
      # Adds more details to np.testing.assert_array_equal.
      diff = np.logical_not(same)
      if a.ndim:
        x = a[np.where(diff)]
        y = b[np.where(diff)]
        msgs.append("not equal where = {}".format(np.where(diff)))
      else:
        # np.where is broken for scalars
        x, y = a, b
      msgs.append("not equal lhs = %r" % x)
      msgs.append("not equal rhs = %r" % y)

      if (a.dtype.kind != b.dtype.kind and
          {a.dtype.kind, b.dtype.kind}.issubset({"U", "S", "O"})):
        # One array holds unicode strings and the other bytes/objects;
        # normalize both to bytes so assert_array_equal can compare them.
        a_list = []
        b_list = []
        # OK to flatten `a` and `b` because they are guaranteed to have the
        # same shape.
        for out_list, flat_arr in [(a_list, a.flat), (b_list, b.flat)]:
          for item in flat_arr:
            if isinstance(item, str):
              out_list.append(item.encode("utf-8"))
            else:
              out_list.append(item)
        a = np.array(a_list)
        b = np.array(b_list)

      np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))

3305 

3306 @py_func_if_in_function 

3307 def assertNotAllEqual(self, a, b, msg=None): 

3308 """Asserts that two numpy arrays or Tensors do not have the same values. 

3309 

3310 Args: 

3311 a: the expected numpy ndarray or anything can be converted to one. 

3312 b: the actual numpy ndarray or anything can be converted to one. 

3313 msg: Optional message to report on failure. 

3314 """ 

3315 try: 

3316 self.assertAllEqual(a, b) 

3317 except AssertionError: 

3318 return 

3319 msg = msg or "" 

3320 raise AssertionError("The two values are equal at all elements. %s" % msg) 

3321 

3322 @py_func_if_in_function 

3323 def assertAllGreater(self, a, comparison_target): 

3324 """Assert element values are all greater than a target value. 

3325 

3326 Args: 

3327 a: The numpy `ndarray`, or anything that can be converted into a numpy 

3328 `ndarray` (including Tensor). 

3329 comparison_target: The target value of comparison. 

3330 """ 

3331 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target) 

3332 a = self._GetNdArray(a) 

3333 self.assertGreater(np.min(a), comparison_target) 

3334 

3335 @py_func_if_in_function 

3336 def assertAllLess(self, a, comparison_target): 

3337 """Assert element values are all less than a target value. 

3338 

3339 Args: 

3340 a: The numpy `ndarray`, or anything that can be converted into a numpy 

3341 `ndarray` (including Tensor). 

3342 comparison_target: The target value of comparison. 

3343 """ 

3344 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target) 

3345 a = self._GetNdArray(a) 

3346 self.assertLess(np.max(a), comparison_target) 

3347 

3348 @py_func_if_in_function 

3349 def assertAllGreaterEqual(self, a, comparison_target): 

3350 """Assert element values are all greater than or equal to a target value. 

3351 

3352 Args: 

3353 a: The numpy `ndarray`, or anything that can be converted into a numpy 

3354 `ndarray` (including Tensor). 

3355 comparison_target: The target value of comparison. 

3356 """ 

3357 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target) 

3358 a = self._GetNdArray(a) 

3359 self.assertGreaterEqual(np.min(a), comparison_target) 

3360 

3361 @py_func_if_in_function 

3362 def assertAllLessEqual(self, a, comparison_target): 

3363 """Assert element values are all less than or equal to a target value. 

3364 

3365 Args: 

3366 a: The numpy `ndarray`, or anything that can be converted into a numpy 

3367 `ndarray` (including Tensor). 

3368 comparison_target: The target value of comparison. 

3369 """ 

3370 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target) 

3371 a = self._GetNdArray(a) 

3372 self.assertLessEqual(np.max(a), comparison_target) 

3373 

3374 def _format_subscripts(self, subscripts, value, limit=10, indent=2): 

3375 """Generate a summary of ndarray subscripts as a list of str. 

3376 

3377 If limit == N, this method will print up to the first N subscripts on 

3378 separate 

3379 lines. A line of ellipses (...) will be appended at the end if the number of 

3380 subscripts exceeds N. 

3381 

3382 Args: 

3383 subscripts: The tensor (np.ndarray) subscripts, of the same format as 

3384 np.where()'s return value, i.e., a tuple of arrays with each array 

3385 corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])). 

3386 value: (np.ndarray) value of the tensor. 

3387 limit: (int) The maximum number of indices to print. 

3388 indent: (int) Number of characters to indent at the beginning of each 

3389 line. 

3390 

3391 Returns: 

3392 (list of str) the multi-line representation of the subscripts and values, 

3393 potentially with omission at the end. 

3394 """ 

3395 lines = [] 

3396 subscripts = np.transpose(subscripts) 

3397 prefix = " " * indent 

3398 if np.ndim(value) == 0: 

3399 return [prefix + "[0] : " + str(value)] 

3400 for subscript in itertools.islice(subscripts, limit): 

3401 lines.append(prefix + str(subscript) + " : " + 

3402 str(value[tuple(subscript)])) 

3403 if len(subscripts) > limit: 

3404 lines.append(prefix + "...") 

3405 return lines 

3406 

  @py_func_if_in_function
  def assertAllInRange(self,
                       target,
                       lower_bound,
                       upper_bound,
                       open_lower_bound=False,
                       open_upper_bound=False):
    """Assert that elements in a Tensor are all in a given range.

    Args:
      target: The numpy `ndarray`, or anything that can be converted into a
        numpy `ndarray` (including Tensor).
      lower_bound: lower bound of the range
      upper_bound: upper bound of the range
      open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather
        than the default >=)
      open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather
        than the default <=)

    Raises:
      AssertionError:
        if the value tensor does not have an ordered numeric type (float* or
        int*), or
        if there are nan values, or
        if any of the elements do not fall in the specified range.
    """
    target = self._GetNdArray(target)
    if not (np.issubdtype(target.dtype, np.floating) or
            np.issubdtype(target.dtype, np.integer)):
      raise AssertionError(
          "The value of %s does not have an ordered numeric type, instead it "
          "has type: %s" % (target, target.dtype))

    # NaN compares false against any bound, so check for NaNs up front and
    # report them separately rather than as range violations.
    nan_subscripts = np.where(np.isnan(target))
    if np.size(nan_subscripts):
      raise AssertionError(
          "%d of the %d element(s) are NaN. "
          "Subscripts(s) and value(s) of the NaN element(s):\n" %
          (len(nan_subscripts[0]), np.size(target)) +
          "\n".join(self._format_subscripts(nan_subscripts, target)))

    range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " +
                 str(upper_bound) + (")" if open_upper_bound else "]"))

    # A violation of an open bound includes equality; of a closed bound, only
    # strict inequality.
    violations = (
        np.less_equal(target, lower_bound) if open_lower_bound else np.less(
            target, lower_bound))
    violations = np.logical_or(
        violations,
        np.greater_equal(target, upper_bound)
        if open_upper_bound else np.greater(target, upper_bound))
    violation_subscripts = np.where(violations)
    if np.size(violation_subscripts):
      raise AssertionError(
          "%d of the %d element(s) are outside the range %s. " %
          (len(violation_subscripts[0]), np.size(target), range_str) +
          "Subscript(s) and value(s) of the offending elements:\n" +
          "\n".join(self._format_subscripts(violation_subscripts, target)))

3465 

3466 @py_func_if_in_function 

3467 def assertAllInSet(self, target, expected_set): 

3468 """Assert that elements of a Tensor are all in a given closed set. 

3469 

3470 Args: 

3471 target: The numpy `ndarray`, or anything that can be converted into a 

3472 numpy `ndarray` (including Tensor). 

3473 expected_set: (`list`, `tuple` or `set`) The closed set that the elements 

3474 of the value of `target` are expected to fall into. 

3475 

3476 Raises: 

3477 AssertionError: 

3478 if any of the elements do not fall into `expected_set`. 

3479 """ 

3480 target = self._GetNdArray(target) 

3481 

3482 # Elements in target that are not in expected_set. 

3483 diff = np.setdiff1d(target.flatten(), list(expected_set)) 

3484 if np.size(diff): 

3485 raise AssertionError("%d unique element(s) are not in the set %s: %s" % 

3486 (np.size(diff), expected_set, diff)) 

3487 

3488 @py_func_if_in_function 

3489 def assertDTypeEqual(self, target, expected_dtype): 

3490 """Assert ndarray data type is equal to expected. 

3491 

3492 Args: 

3493 target: The numpy `ndarray`, or anything that can be converted into a 

3494 numpy `ndarray` (including Tensor). 

3495 expected_dtype: Expected data type. 

3496 """ 

3497 target = self._GetNdArray(target) 

3498 if not isinstance(target, list): 

3499 arrays = [target] 

3500 for arr in arrays: 

3501 self.assertEqual(arr.dtype, expected_dtype) 

3502 

3503 # pylint: disable=g-doc-return-or-yield 

3504 @contextlib.contextmanager 

3505 def assertRaisesWithPredicateMatch(self, exception_type, 

3506 expected_err_re_or_predicate): 

3507 """Returns a context manager to enclose code expected to raise an exception. 

3508 

3509 If the exception is an OpError, the op stack is also included in the message 

3510 predicate search. 

3511 

3512 Args: 

3513 exception_type: The expected type of exception that should be raised. 

3514 expected_err_re_or_predicate: If this is callable, it should be a function 

3515 of one argument that inspects the passed-in exception and returns True 

3516 (success) or False (please fail the test). Otherwise, the error message 

3517 is expected to match this regular expression partially. 

3518 

3519 Returns: 

3520 A context manager to surround code that is expected to raise an 

3521 exception. 

3522 """ 

3523 if callable(expected_err_re_or_predicate): 

3524 predicate = expected_err_re_or_predicate 

3525 else: 

3526 

3527 def predicate(e): 

3528 err_str = e.message if isinstance(e, errors.OpError) else str(e) 

3529 op = e.op if isinstance(e, errors.OpError) else None 

3530 while op is not None: 

3531 err_str += "\nCaused by: " + op.name 

3532 op = op._original_op # pylint: disable=protected-access 

3533 logging.info("Searching within error strings: '%s' within '%s'", 

3534 expected_err_re_or_predicate, err_str) 

3535 return re.search(expected_err_re_or_predicate, err_str) 

3536 

3537 try: 

3538 yield 

3539 self.fail(exception_type.__name__ + " not raised") 

3540 except Exception as e: # pylint: disable=broad-except 

3541 if not isinstance(e, exception_type) or not predicate(e): 

3542 raise AssertionError("Exception of type %s: %s" % 

3543 (str(type(e)), str(e))) 

3544 

3545 # pylint: enable=g-doc-return-or-yield 

3546 

  def assertRaisesOpError(self, expected_err_re_or_predicate):
    """Returns a context manager asserting that an `errors.OpError` is raised.

    Convenience wrapper around `assertRaisesWithPredicateMatch` specialized to
    `errors.OpError`.

    Args:
      expected_err_re_or_predicate: Either a regular expression that the raised
        error's message must partially match, or a callable that receives the
        exception and returns True on success.
    """
    return self.assertRaisesWithPredicateMatch(errors.OpError,
                                               expected_err_re_or_predicate)

3550 

  def assertRaisesIncompatibleShapesError(
      self, exception_type=errors.InvalidArgumentError):
    """Returns a context manager asserting a shape-incompatibility error.

    The regex matches any of the standard messages emitted when operand shapes
    cannot be combined: "Incompatible shapes", "Dimensions must be equal", or
    "required broadcastable shapes".

    Args:
      exception_type: The exception class expected to be raised; defaults to
        `errors.InvalidArgumentError`.
    """
    return self.assertRaisesWithPredicateMatch(
        exception_type, r"Incompatible shapes|Dimensions must be equal|"
        r"required broadcastable shapes")

3556 

3557 def assertShapeEqual(self, input_a, input_b, msg=None): 

3558 """Asserts that two Numpy or TensorFlow objects have the same shape. 

3559 

3560 For Tensors, this compares statically known shapes at compile time, not 

3561 dynamic shapes at runtime. 

3562 

3563 Args: 

3564 input_a: A Numpy ndarray, Numpy scalar, or a Tensor. 

3565 input_b: A Numpy ndarray, Numpy scalar, or a Tensor. 

3566 msg: Optional message to report on failure. 

3567 

3568 Raises: 

3569 TypeError: If the arguments have the wrong type. 

3570 """ 

3571 if not isinstance(input_a, (np.ndarray, np.generic, ops.Tensor)): 

3572 raise TypeError( 

3573 "input_a must be a Numpy ndarray, Numpy scalar, or a Tensor." 

3574 f"Instead received {type(input_a)}") 

3575 if not isinstance(input_b, (np.ndarray, np.generic, ops.Tensor)): 

3576 raise TypeError( 

3577 "input_b must be a Numpy ndarray, Numpy scalar, or a Tensor." 

3578 f"Instead received {type(input_b)}") 

3579 shape_a = input_a.get_shape().as_list() if isinstance( 

3580 input_a, ops.Tensor) else input_a.shape 

3581 shape_b = input_b.get_shape().as_list() if isinstance( 

3582 input_b, ops.Tensor) else input_b.shape 

3583 self.assertAllEqual(shape_a, shape_b, msg=msg) 

3584 

3585 def assertDeviceEqual(self, device1, device2, msg=None): 

3586 """Asserts that the two given devices are the same. 

3587 

3588 Args: 

3589 device1: A string device name or TensorFlow `DeviceSpec` object. 

3590 device2: A string device name or TensorFlow `DeviceSpec` object. 

3591 msg: Optional message to report on failure. 

3592 """ 

3593 device1 = pydev.canonical_name(device1) 

3594 device2 = pydev.canonical_name(device2) 

3595 self.assertEqual( 

3596 device1, device2, 

3597 "Devices %s and %s are not equal. %s" % (device1, device2, msg)) 

3598 

  @py_func_if_in_function
  def assertDictEqual(self, a, b, msg=None):
    """Assert that two given dictionary of tensors are the same.

    Args:
      a: Expected dictionary with numpy ndarray or anything else that can be
        converted to one as values.
      b: Actual dictionary with numpy ndarray or anything else that can be
        converted to one as values.
      msg: Optional message to report on failure.
    """
    # To keep backwards compatibility, we first try the base class
    # assertDictEqual. If that fails we try the tensorflow one.
    try:
      super().assertDictEqual(a, b, msg)
    except Exception:  # pylint: disable=broad-except
      self.assertSameElements(a.keys(), b.keys())  # pylint: disable=g-assert-in-except
      for k, v in a.items():
        (a_k, b_k) = self.evaluate_if_both_tensors(v, b[k])
        a_k = self._GetNdArray(a_k)
        b_k = self._GetNdArray(b_k)
        # NOTE(review): a_k/b_k are only used to choose the comparison mode;
        # the assertions below convert the original v and b[k] themselves.
        if np.issubdtype(a_k.dtype, np.floating):
          # Floating-point values get an approximate comparison.
          self.assertAllClose(v, b[k], msg=k)
        else:
          self.assertAllEqual(v, b[k], msg=k)

3624 

  def _GetPyList(self, a):
    """Converts `a` to a nested python list.

    Dispatch order matters: ragged tensors must be handled before dense
    tensors, and evaluated tensors may come back as ndarrays or as python
    scalars.

    Args:
      a: A `RaggedTensor`, `Tensor`, ndarray, `RaggedTensorValue`, or anything
        convertible via `np.array(..., dtype=object)`.

    Returns:
      A (possibly nested) python list mirroring the structure of `a`.
    """
    if isinstance(a, ragged_tensor.RaggedTensor):
      return self.evaluate(a).to_list()
    elif isinstance(a, ops.Tensor):
      a = self.evaluate(a)
      # Scalar tensors evaluate to python numbers rather than ndarrays.
      return a.tolist() if isinstance(a, np.ndarray) else a
    elif isinstance(a, np.ndarray):
      return a.tolist()
    elif isinstance(a, ragged_tensor_value.RaggedTensorValue):
      return a.to_list()
    else:
      return np.array(a, dtype=object).tolist()

3638 

3639 def _assertRaggedEqual(self, a, b, msg): 

3640 """Asserts that two ragged tensors are equal.""" 

3641 a_list = self._GetPyList(a) 

3642 b_list = self._GetPyList(b) 

3643 self.assertEqual(a_list, b_list, msg) 

3644 

3645 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))): 

3646 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0 

3647 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0 

3648 self.assertEqual(a_ragged_rank, b_ragged_rank, msg) 

3649 

3650 def _assertRaggedClose(self, a, b, rtol, atol, msg=None): 

3651 a_list = self._GetPyList(a) 

3652 b_list = self._GetPyList(b) 

3653 self._assertListCloseRecursive(a_list, b_list, rtol, atol, msg) 

3654 

3655 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))): 

3656 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0 

3657 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0 

3658 self.assertEqual(a_ragged_rank, b_ragged_rank, msg) 

3659 

3660 def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"): 

3661 self.assertEqual(type(a), type(b)) 

3662 if isinstance(a, (list, tuple)): 

3663 self.assertLen(a, len(b), "Length differs for %s" % path) 

3664 for i in range(len(a)): 

3665 self._assertListCloseRecursive(a[i], b[i], rtol, atol, msg, 

3666 "%s[%s]" % (path, i)) 

3667 else: 

3668 self._assertAllCloseRecursive(a, b, rtol, atol, path, msg) 

3669 

  # Fix Python 3+ compatibility issues
  # pylint: disable=invalid-name

  # `assertRaisesRegexp` was deprecated in favor of `assertRaisesRegex`;
  # alias the new name so existing tests don't emit deprecation warnings.
  assertRaisesRegexp = googletest.TestCase.assertRaisesRegex

  # `assertItemsEqual` was renamed `assertCountEqual` in Python 3.2; keep the
  # old name as an alias for tests that still use it.
  assertItemsEqual = googletest.TestCase.assertCountEqual

  # pylint: enable=invalid-name

3680 

  @contextlib.contextmanager
  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
    """Set the session and its graph to global default and constrain devices.

    Args:
      sess: The session to install as default (graph mode only).
      use_gpu: If True, leave device placement unconstrained so GPU kernels
        may be chosen.
      force_gpu: If True, explicitly pin ops to a GPU device.

    Yields:
      `sess`, or None when executing eagerly (sessions don't apply there).
    """
    if context.executing_eagerly():
      yield None
    else:
      with sess.graph.as_default(), sess.as_default():
        if force_gpu:
          # Use the name of an actual device if one is detected, or
          # '/device:GPU:0' otherwise
          gpu_name = gpu_device_name()
          if not gpu_name:
            gpu_name = "/device:GPU:0"
          with sess.graph.device(gpu_name):
            yield sess
        elif use_gpu:
          # No device constraint: the placer is free to pick GPU kernels.
          yield sess
        else:
          # Pin everything to the CPU to keep GPU kernels out of the test.
          with sess.graph.device("/device:CPU:0"):
            yield sess

3701 

  def _create_session(self, graph, config, force_gpu):
    """See session() for details.

    Args:
      graph: Optional graph in which to create the session.
      config: Optional config_pb2.ConfigProto to configure the session.
      force_gpu: If True, soft placement is disabled so ops must run on GPU.

    Returns:
      An `ErrorLoggingSession` built with a test-friendly config.
    """

    def prepare_config(config):
      """Returns a config for sessions.

      Args:
        config: An optional config_pb2.ConfigProto to use to configure the
          session.

      Returns:
        A config_pb2.ConfigProto object.
      """
      # TODO(b/114333779): Enforce allow_soft_placement=False when
      # use_gpu=False. Currently many tests rely on the fact that any device
      # will be used even when a specific device is supposed to be used.
      allow_soft_placement = not force_gpu
      if config is None:
        config = context.context().config
        config.allow_soft_placement = allow_soft_placement
      elif not allow_soft_placement and config.allow_soft_placement:
        # NOTE(review): despite the name, this takes the context's config
        # rather than copying the caller's `config` — presumably to avoid
        # mutating the caller's proto; confirm intent.
        config_copy = context.context().config
        config = config_copy
        config.allow_soft_placement = False
      # Don't perform optimizations for tests so we don't inadvertently run
      # gpu ops on cpu
      config.graph_options.optimizer_options.opt_level = -1
      # Disable Grappler constant folding since some tests & benchmarks
      # use constant input and become meaningless after constant folding.
      # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
      # GRAPPLER TEAM.
      config.graph_options.rewrite_options.constant_folding = (
          rewriter_config_pb2.RewriterConfig.OFF)
      config.graph_options.rewrite_options.pin_to_host_optimization = (
          rewriter_config_pb2.RewriterConfig.OFF)
      return config

    return ErrorLoggingSession(graph=graph, config=prepare_config(config))

3740 

3741 def _get_cached_session(self, 

3742 graph=None, 

3743 config=None, 

3744 force_gpu=False, 

3745 crash_if_inconsistent_args=True): 

3746 """See cached_session() for documentation.""" 

3747 if self._cached_session is None: 

3748 sess = self._create_session( 

3749 graph=graph, config=config, force_gpu=force_gpu) 

3750 self._cached_session = sess 

3751 self._cached_graph = graph 

3752 self._cached_config = config 

3753 self._cached_force_gpu = force_gpu 

3754 return sess 

3755 else: 

3756 if crash_if_inconsistent_args and self._cached_graph is not graph: 

3757 raise ValueError("The graph used to get the cached session is " 

3758 "different than the one that was used to create the " 

3759 "session. Maybe create a new session with " 

3760 "self.session()") 

3761 if crash_if_inconsistent_args and self._cached_config is not config: 

3762 raise ValueError("The config used to get the cached session is " 

3763 "different than the one that was used to create the " 

3764 "session. Maybe create a new session with " 

3765 "self.session()") 

3766 if crash_if_inconsistent_args and (self._cached_force_gpu is 

3767 not force_gpu): 

3768 raise ValueError( 

3769 "The force_gpu value used to get the cached session is " 

3770 "different than the one that was used to create the " 

3771 "session. Maybe create a new session with " 

3772 "self.session()") 

3773 return self._cached_session 

3774 

3775 

# Ports already handed out by `pick_unused_port`, so the same port is never
# returned twice within this process.
ASSIGNED_PORTS = set()
# Guards ASSIGNED_PORTS against concurrent port picking.
lock = threading.Lock()

3778 

3779 

def pick_unused_port():
  """Returns an unused and unassigned local port."""
  import portpicker  # pylint: disable=g-import-not-at-top

  global ASSIGNED_PORTS
  with lock:
    while True:
      try:
        candidate = portpicker.pick_unused_port()
      except portpicker.NoFreePortFoundError as porterror:
        raise unittest.SkipTest("Flakes in portpicker library do not represent"
                                " TensorFlow errors.") from porterror
      # Reject low ports and ports this process has already handed out.
      if candidate <= 10000 or candidate in ASSIGNED_PORTS:
        continue
      ASSIGNED_PORTS.add(candidate)
      logging.info("Using local port %r", candidate)
      return candidate

3796 

3797 

@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
                         num_ps,
                         protocol="grpc",
                         worker_config=None,
                         ps_config=None):
  """Create and start local servers and return the associated `Server` objects.

  "PS" stands for "parameter server": a task responsible for storing and
  updating the model's parameters. Other tasks send updates to these parameters
  as they work on optimizing the parameters. This particular division of labor
  between tasks is not required, but is common for distributed training.

  Read more at https://www.tensorflow.org/guide/extend/architecture

  ![components](https://www.tensorflow.org/images/diag1.svg "components")


  Figure illustrates the interaction of these components.
  "/job:worker/task:0" and "/job:ps/task:0" are both tasks with worker services.


  Example:
  ```python
  workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)

  worker_sessions = [tf.compat.v1.Session(w.target) for w in workers]

  with tf.device("/job:ps/task:0"):
    ...
  with tf.device("/job:ps/task:1"):
    ...
  with tf.device("/job:worker/task:0"):
    ...
  with tf.device("/job:worker/task:1"):
    ...

  worker_sessions[0].run(...)
  ```

  Args:
    num_workers: Number of worker servers to start.
    num_ps: Number of PS servers to start.
    protocol: Communication protocol. Allowed values are documented in the
      documentation of `tf.distribute.Server`.
    worker_config: (optional) `tf.ConfigProto` to initialize workers. Can be
      used to instantiate multiple devices etc.
    ps_config: (optional) `tf.ConfigProto` to initialize PS servers.

  Returns:
    A tuple `(worker_servers, ps_servers)`. `worker_servers` is a list
    of `num_workers` objects of type `tf.distribute.Server` (all running
    locally);
    and `ps_servers` is a list of `num_ps` objects of similar type.

  Raises:
    ImportError: if portpicker module was not found at load time
  """
  # Reserve a distinct local port for every task, then describe the cluster.
  cluster_dict = {
      "worker": [
          "localhost:%s" % pick_unused_port() for _ in range(num_workers)
      ],
      "ps": ["localhost:%s" % pick_unused_port() for _ in range(num_ps)]
  }
  cluster_spec = server_lib.ClusterSpec(cluster_dict)

  def _start_server(job_name, task_index, config):
    # Each server starts immediately (start=True) on its reserved port.
    return server_lib.Server(
        cluster_spec,
        job_name=job_name,
        protocol=protocol,
        task_index=task_index,
        config=config,
        start=True)

  workers = [
      _start_server("worker", ix, worker_config) for ix in range(num_workers)
  ]
  ps_servers = [_start_server("ps", ix, ps_config) for ix in range(num_ps)]

  return workers, ps_servers

3884 

3885 

def get_node_def_from_graph(node_name, graph_def):
  """Returns the `NodeDef` instance for given node name in the graph def.

  This method explores only the NodeDefs in `graph_def.node`.

  Args:
    node_name: Name of the NodeDef to search for.
    graph_def: An instance of `GraphDef` proto.

  Returns:
    the `NodeDef` instance whose name field matches the given node_name or None.
  """
  # First NodeDef whose name matches, or None when no node matches.
  return next(
      (node_def for node_def in graph_def.node if node_def.name == node_name),
      None)

3902 

3903 

def set_producer_version(graph, producer_version):
  """Sets graph.graph_def_versions.producer to `producer_version`.

  Args:
    graph: The `tf.Graph` whose producer version should be overridden.
    producer_version: The integer producer version to set.
  """
  # The C API doesn't expose altering GraphDefVersions. We can indirectly set
  # it via import_graph_def though.
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.producer = producer_version
  with graph.as_default():
    importer.import_graph_def(graph_def)
  # Bug fix: the original `assert x, y` treated `producer_version` as the
  # assertion *message* and only checked truthiness of the producer field;
  # compare for equality so a mismatched version is actually detected.
  assert graph.graph_def_versions.producer == producer_version

3913 

3914 

3915@contextlib.contextmanager 

3916def _fake_gradient_tape_context_manager(): 

3917 """tf.gradients(...) implemented as tf.GradientTape context manager interface. 

3918 

3919 This is useful to test tf.gradients() in tests that uses tf.GradientTape(). 

3920 

3921 Yields: 

3922 gradient tape instance that's implemented by tf.gradients() underneath. 

3923 """ 

3924 try: 

3925 class FakeGradientTape: 

3926 

3927 def watch(self, x): 

3928 pass 

3929 

3930 def gradient(self, y, x, grad_ys=None): 

3931 result = gradients_impl.gradients(y, x, grad_ys) 

3932 

3933 # Unlike `tape.gradient()`, `tf.gradients()` returns a list for a single 

3934 # element. So unpack if needed to match `tape.gradient()` behavior. 

3935 if not isinstance(x, (list, tuple)): 

3936 assert len(result) == 1 

3937 return result[0] 

3938 

3939 return result 

3940 

3941 yield FakeGradientTape() 

3942 finally: 

3943 pass 

3944 

3945 

class AbstractGradientTape:
  """Abstract GradientTape context manager that has multiple implementations.

  This is useful to test both tf.GradientTape() and tf.gradients() without
  duplicating tests.
  """

  def __init__(self, use_tape, persistent=False):
    self._use_tape = use_tape
    self._persistent = persistent

  def __enter__(self):
    # Pick the real eager tape or the tf.gradients()-backed fake, remember it
    # for __exit__, and delegate entry to it.
    if self._use_tape:
      impl = backprop.GradientTape(persistent=self._persistent)
    else:
      impl = _fake_gradient_tape_context_manager()
    self._tape_impl = impl
    return impl.__enter__()

  def __exit__(self, exc_type, exc_val, exc_tb):
    self._tape_impl.__exit__(exc_type, exc_val, exc_tb)

3966 

3967 

@contextlib.contextmanager
def run_functions_eagerly(run_eagerly):
  """Runs functions eagerly if `run_eagerly` is true.

  WARNING: Setting `run_eagerly` to True in tests running in V1 graph mode
  *WILL NOT* make the tf.function to run eagerly because eager is disabled by
  default in V1. Instead, tf.function will run as a traced graph function.

  Ensures that the state (for running functions eagerly) is back to the initial
  `def_function.RUN_FUNCTIONS_EAGERLY` state.

  Args:
    run_eagerly: Boolean determining whether to run the function eagerly or not.

  Raises:
    ValueError if `run_eagerly` is not a boolean.

  Yields:
    Nothing.
  """
  if not isinstance(run_eagerly, bool):
    raise ValueError(
        "Expected bool for `run_eagerly` but got {}".format(run_eagerly))

  is_eager = context.executing_eagerly()
  if run_eagerly and not is_eager:
    # Eager execution of tf.function can't be forced under V1 graph mode.
    logging.warning(
        "Running tf.function eagerly in V1 graph mode is not supported. "
        "tf.function will be run as a traced graph function.")

  previous_state = def_function.functions_run_eagerly()
  def_function.run_functions_eagerly(run_eagerly)
  try:
    yield
  finally:
    # Always restore whatever mode was active before entering.
    def_function.run_functions_eagerly(previous_state)

4004 

4005 

class TestDelta:
  """A utility class to track increments to test counters."""

  def __init__(self, name, label):
    # Counter identity: metric name plus label, as understood by
    # `_test_metrics_util.test_counter_value`.
    self.name = name
    self.label = label
    self.Reset()

  def Reset(self):
    """Snapshots the current counter value as the new baseline."""
    self.last_value = _test_metrics_util.test_counter_value(
        self.name, self.label)

  def Get(self):
    """Returns the counter's increase since the last `Reset()`."""
    value = _test_metrics_util.test_counter_value(self.name, self.label)
    return value - self.last_value

4022 

@tf_export("test.experimental.sync_devices")
def sync_devices():
  """Synchronizes all devices.

  By default, GPUs run asynchronously. This means that when you run an op on the
  GPU, like `tf.linalg.matmul`, the op may still be running on the GPU when the
  function returns. Non-GPU devices can also be made to run asynchronously by
  calling `tf.config.experimental.set_synchronous_execution(False)`. Calling
  `sync_devices()` blocks until pending ops have finished executing. This is
  primarily useful for measuring performance during a benchmark.

  For example, here is how you can measure how long `tf.linalg.matmul` runs:

  >>> import time
  >>> x = tf.random.normal((4096, 4096))
  >>> tf.linalg.matmul(x, x)  # Warmup.
  >>> tf.test.experimental.sync_devices()  # Block until warmup has completed.
  >>>
  >>> start = time.time()
  >>> y = tf.linalg.matmul(x, x)
  >>> tf.test.experimental.sync_devices()  # Block until matmul has completed.
  >>> end = time.time()
  >>> print(f'Time taken: {end - start}')

  If the call to `sync_devices()` was omitted, the time printed could be too
  small. This is because the op could still be running asynchronously when
  the line `end = time.time()` is executed.

  Raises:
    RuntimeError: If run outside Eager mode. This must be called in Eager mode,
      outside any `tf.function`s.
  """
  if not context.executing_eagerly():
    raise RuntimeError(
        "sync_devices() must only be called in Eager mode, outside tf.functions"
    )

  # Two sources of asynchrony exist in TensorFlow: (1) GPU kernels execute on
  # CUDA streams, which are inherently asynchronous, and (2)
  # `tf.config.experimental.set_synchronous_execution(False)` makes all ops
  # asynchronous via internal queues of pending ops. Running a SyncDevice op
  # on every device addresses (1); `context.async_wait()` addresses (2). The
  # SyncDevice ops must be issued first — otherwise a SyncDevice op itself
  # could still be pending on an internal queue when this function returns.
  for logical_device in config.list_logical_devices():
    with ops.device(logical_device.name):
      gen_sync_ops.SyncDevice()
  context.async_wait()