Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/test_util.py: 19%
1564 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
16# pylint: disable=invalid-name
17"""Test utils for tensorflow."""
18import collections
19from collections import OrderedDict
20import contextlib
21import functools
22import gc
23import itertools
24import math
25import os
26import random
27import re
28import tempfile
29import threading
30import time
31import unittest
33from absl.testing import parameterized
34import numpy as np
36from google.protobuf import descriptor_pool
37from google.protobuf import text_format
39from tensorflow.core.config import flags
40from tensorflow.core.framework import graph_pb2
41from tensorflow.core.protobuf import rewriter_config_pb2
42from tensorflow.python import pywrap_sanitizers
43from tensorflow.python import tf2
44from tensorflow.python.client import device_lib
45from tensorflow.python.client import pywrap_tf_session
46from tensorflow.python.client import session
47from tensorflow.python.compat.compat import forward_compatibility_horizon
48from tensorflow.python.eager import backprop
49from tensorflow.python.eager import context
50from tensorflow.python.eager import def_function
51from tensorflow.python.framework import _test_metrics_util
52from tensorflow.python.framework import config
53from tensorflow.python.framework import device as pydev
54from tensorflow.python.framework import dtypes
55from tensorflow.python.framework import errors
56from tensorflow.python.framework import errors_impl
57from tensorflow.python.framework import gpu_util
58from tensorflow.python.framework import importer
59from tensorflow.python.framework import indexed_slices
60from tensorflow.python.framework import ops
61from tensorflow.python.framework import random_seed
62from tensorflow.python.framework import sparse_tensor
63from tensorflow.python.framework import tensor_shape
64from tensorflow.python.framework import tensor_util
65from tensorflow.python.framework import tfrt_utils
66from tensorflow.python.framework import versions
67from tensorflow.python.ops import array_ops
68from tensorflow.python.ops import control_flow_util
69from tensorflow.python.ops import control_flow_util_v2
70from tensorflow.python.ops import gen_sync_ops
71from tensorflow.python.ops import gradients_impl
72from tensorflow.python.ops import math_ops
73from tensorflow.python.ops import script_ops
74from tensorflow.python.ops import summary_ops_v2
75from tensorflow.python.ops import variables
76from tensorflow.python.ops.ragged import ragged_ops # pylint: disable=unused-import
77from tensorflow.python.ops.ragged import ragged_tensor
78from tensorflow.python.ops.ragged import ragged_tensor_value
79from tensorflow.python.platform import _pywrap_stacktrace_handler
80from tensorflow.python.platform import googletest
81from tensorflow.python.platform import tf_logging as logging
82from tensorflow.python.training import server_lib
83from tensorflow.python.util import _pywrap_util_port
84from tensorflow.python.util import compat
85from tensorflow.python.util import deprecation
86from tensorflow.python.util import nest
87from tensorflow.python.util import tf_decorator
88from tensorflow.python.util import tf_inspect
89from tensorflow.python.util import traceback_utils
90from tensorflow.python.util.compat import collections_abc
91from tensorflow.python.util.protobuf import compare
92from tensorflow.python.util.tf_export import tf_export
def is_xla_enabled():
  """Reports whether this test binary compiles TensorFlow graphs with XLA.

  This default always answers False. When the BUILD rule makes the
  `is_xla_test_true` module importable, the import below rebinds this name to
  that module's version, which returns True.
  """
  return False


# Best-effort override: any Exception raised while importing the marker module
# (not only ImportError) leaves the default definition above in place.
with contextlib.suppress(Exception):  # pylint: disable=broad-except
  from tensorflow.python.framework.is_xla_test_true import is_xla_enabled  # pylint: disable=g-import-not-at-top, unused-import
def is_mlir_bridge_enabled():
  """Reports whether MLIR bridge compilation is selectively enabled/disabled.

  This default returns None, which is what callers see when neither the
  `is_mlir_bridge_test_false` nor the `is_mlir_bridge_test_true` module is
  importable. When one of them is made available through the BUILD rule, the
  import below rebinds this name to that module's version; the `_false` module
  is tried first.
  """
  return None


# Try the "false" marker first, then the "true" one; if neither can be
# imported, keep the default definition above.
with contextlib.suppress(ImportError):
  try:
    from tensorflow.python.framework.is_mlir_bridge_test_false import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
  except ImportError:
    from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
def is_asan_enabled():
  """Returns whether ASAN (AddressSanitizer) is enabled, per pywrap_sanitizers."""
  asan_enabled = pywrap_sanitizers.is_asan_enabled()
  return asan_enabled


def is_msan_enabled():
  """Returns whether MSAN (MemorySanitizer) is enabled, per pywrap_sanitizers."""
  msan_enabled = pywrap_sanitizers.is_msan_enabled()
  return msan_enabled


def is_tsan_enabled():
  """Returns whether TSAN (ThreadSanitizer) is enabled, per pywrap_sanitizers."""
  tsan_enabled = pywrap_sanitizers.is_tsan_enabled()
  return tsan_enabled


def is_ubsan_enabled():
  """Returns whether UBSAN (UndefinedBehaviorSanitizer) is enabled.

  The answer comes from `pywrap_sanitizers`.
  """
  ubsan_enabled = pywrap_sanitizers.is_ubsan_enabled()
  return ubsan_enabled
143def _get_object_count_by_type(exclude=()):
144 return (
145 collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) -
146 collections.Counter([type(obj).__name__ for obj in exclude]))
@tf_export("test.gpu_device_name")
def gpu_device_name():
  """Returns the name of the first local GPU device, or "" if there is none.

  This method should only be used in tests written with `tf.test.TestCase`.

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_gpu_support():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device(tf.test.gpu_device_name()):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  """
  gpu_names = (
      compat.as_str(device.name)
      for device in device_lib.list_local_devices()
      if device.device_type == "GPU")
  # Only the first match is converted and returned; "" when nothing matches.
  return next(gpu_names, "")
def assert_ops_in_graph(expected_ops, graph):
  """Checks that every expected op is present in `graph` with the right type.

  Args:
    expected_ops: `dict<string, string>` mapping op name to op type.
    graph: Graph to check.

  Returns:
    `dict<string, node>` mapping each expected node name to its node from
    `graph.as_graph_def()`.

  Raises:
    ValueError: If an expected op has a different type in the graph, or if
      any expected op is absent from the graph.
  """
  found = {}
  for node_def in graph.as_graph_def().node:
    if node_def.name not in expected_ops:
      continue
    wanted_type = expected_ops[node_def.name]
    if wanted_type != node_def.op:
      raise ValueError("Expected op for node %s is different. %s vs %s" %
                       (node_def.name, wanted_type, node_def.op))
    found[node_def.name] = node_def
  if set(expected_ops.keys()) != set(found.keys()):
    raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                     (expected_ops.keys(), found.keys()))
  return found
@tf_export("test.assert_equal_graph_def", v1=[])
def assert_equal_graph_def_v2(expected, actual):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs. Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent. Randomized
  attribute values that may appear in V2 checkpoints, and randomized
  HashTableV2 shared names, are ignored as well.

  Args:
    expected: The `GraphDef` we expected.
    actual: The `GraphDef` we have.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  # The shared helper takes (actual, expected); this public v2 API takes
  # (expected, actual), so pass by keyword to keep the mapping explicit.
  assert_equal_graph_def(
      actual=actual,
      expected=expected,
      checkpoint_v2=True,
      hash_table_shared_name=True)
@tf_export(v1=["test.assert_equal_graph_def"])
def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False,
                              hash_table_shared_name=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Nodes are matched up by name, so node naming must be consistent between the
  two graphs. Versions and the ordering of nodes, attrs and control inputs are
  not part of the comparison.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
      values that appear in V2 checkpoints.
    hash_table_shared_name: boolean determining whether to ignore randomized
      shared_names that appear in HashTableV2 op defs.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  assert_equal_graph_def(
      actual,
      expected,
      checkpoint_v2=checkpoint_v2,
      hash_table_shared_name=hash_table_shared_name)
def assert_equal_graph_def(actual, expected, checkpoint_v2=False,
                           hash_table_shared_name=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  The comparison is delegated to the native `EqualGraphDefWrapper`. A
  non-empty result from it is raised as the assertion message.

  The caller's protos are never modified. When randomized values have to be
  stripped, that happens on private copies.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: If True, ignore attributes whose single string value matches
      the randomized sharded-save temp-path pattern used by V2 checkpoints.
    hash_table_shared_name: If True, ignore randomized `shared_name` attributes
      on `HashTableV2` nodes.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for actual, got %s" %
                    type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for expected, got %s" %
                    type(expected).__name__)

  if checkpoint_v2 or hash_table_shared_name:
    # The strip helpers delete attrs in place. Work on copies so that an
    # assertion does not silently rewrite the caller's GraphDefs.
    actual_copy = graph_pb2.GraphDef()
    actual_copy.CopyFrom(actual)
    expected_copy = graph_pb2.GraphDef()
    expected_copy.CopyFrom(expected)
    actual, expected = actual_copy, expected_copy

  if checkpoint_v2:
    _strip_checkpoint_v2_randomized(actual)
    _strip_checkpoint_v2_randomized(expected)

  if hash_table_shared_name:
    _strip_hash_table_shared_name(actual)
    _strip_hash_table_shared_name(expected)

  diff = pywrap_tf_session.EqualGraphDefWrapper(actual.SerializeToString(),
                                                expected.SerializeToString())
  if diff:
    raise AssertionError(compat.as_str(diff))
def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  The comparison proceeds in three steps:

  - Collections are compared entry by entry. When
    `ops.get_collection_proto_type` knows a proto type for a collection, each
    serialized item is parsed and compared with `assertProtoEquals`.
  - Graph defs are compared with `assert_equal_graph_def`
    (`checkpoint_v2=True`), plus an explicit check of their `versions`.
  - All remaining fields are compared with `tester.assertProtoEquals`.

  Note: `a` and `b` are modified in place. Their `collection_def` and
  `graph_def` fields are cleared once they have been checked.

  Args:
    tester: Test-case object providing `assertEqual` and `assertProtoEquals`.
    a: First `MetaGraphDef`.
    b: Second `MetaGraphDef`.
  """
  # Carefully check the collection_defs
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  collection_keys = a.collection_def.keys()
  for k in collection_keys:
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same
      tester.assertEqual(
          len(a_value.bytes_list.value), len(b_value.bytes_list.value))
      for (a_value_item, b_value_item) in zip(a_value.bytes_list.value,
                                              b_value.bytes_list.value):
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # `assertEquals` is a deprecated alias that was removed in Python 3.12.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")

  # Check the graph_defs.
  assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
  # Check graph_def versions (ignored by assert_equal_graph_def).
  tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("graph_def")
  b.ClearField("graph_def")

  tester.assertProtoEquals(a, b)
# Matches attributes named via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py
_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"


def _strip_checkpoint_v2_randomized(graph_def):
  """Deletes, in place, attrs whose lone string value is a sharded temp path.

  An attr is removed when its tensor holds exactly one non-empty string that
  matches `_SHARDED_SAVE_OP_PATTERN` at its start.

  Args:
    graph_def: The `GraphDef` to edit in place.
  """
  temp_path_pattern = compat.as_bytes(_SHARDED_SAVE_OP_PATTERN)
  for node in graph_def.node:
    randomized_keys = []
    for key in node.attr:
      tensor_value = node.attr[key].tensor
      if not tensor_value or len(tensor_value.string_val) != 1:
        continue
      only_string = tensor_value.string_val[0]
      if only_string and re.match(temp_path_pattern, only_string):
        randomized_keys.append(key)
    # Delete after the scan so the attr map is not mutated while iterating.
    for key in randomized_keys:
      del node.attr[key]
_TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+"


def _strip_hash_table_shared_name(graph_def):
  """Deletes, in place, randomized `shared_name` attrs on HashTableV2 nodes.

  Args:
    graph_def: The `GraphDef` to edit in place. Only `HashTableV2` nodes whose
      `shared_name` starts with a `_TABLE_SHARED_NAME_PATTERN` match are
      touched.
  """
  shared_name_pattern = compat.as_bytes(_TABLE_SHARED_NAME_PATTERN)
  for node in graph_def.node:
    has_shared_name = (node.op == "HashTableV2" and
                       "shared_name" in node.attr)
    if has_shared_name and re.match(shared_name_pattern,
                                    node.attr["shared_name"].s):
      del node.attr["shared_name"]
def IsGoogleCudaEnabled():
  """Queries `_pywrap_util_port` for whether Google CUDA support is enabled."""
  cuda_enabled = _pywrap_util_port.IsGoogleCudaEnabled()
  return cuda_enabled


def IsBuiltWithROCm():
  """Queries `_pywrap_util_port` for whether this build uses ROCm."""
  rocm_build = _pywrap_util_port.IsBuiltWithROCm()
  return rocm_build


def IsBuiltWithXLA():
  """Queries `_pywrap_util_port` for whether this build includes XLA."""
  xla_build = _pywrap_util_port.IsBuiltWithXLA()
  return xla_build


def IsBuiltWithNvcc():
  """Queries `_pywrap_util_port` for whether this build was compiled by nvcc."""
  nvcc_build = _pywrap_util_port.IsBuiltWithNvcc()
  return nvcc_build


def GpuSupportsHalfMatMulAndConv():
  """Queries `_pywrap_util_port` for half-precision MatMul/Conv GPU support."""
  half_supported = _pywrap_util_port.GpuSupportsHalfMatMulAndConv()
  return half_supported


def IsMklEnabled():
  """Queries `_pywrap_util_port` for whether MKL is enabled."""
  mkl_enabled = _pywrap_util_port.IsMklEnabled()
  return mkl_enabled


def InstallStackTraceHandler():
  """Installs the native stack-trace handler from _pywrap_stacktrace_handler."""
  _pywrap_stacktrace_handler.InstallStacktraceHandler()
def NHWCToNCHW(input_tensor):
  """Converts the input from the NHWC format to NCHW.

  Args:
    input_tensor: a 3-, 4-, or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # Rank -> axis permutation that moves the trailing channel axis to index 1.
  new_axes = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
  if not isinstance(input_tensor, ops.Tensor):
    perm = new_axes[len(input_tensor)]
    return [input_tensor[axis] for axis in perm]
  perm = new_axes[input_tensor.shape.ndims]
  return array_ops.transpose(input_tensor, perm)
def NHWCToNCHW_VECT_C(input_shape_or_tensor):
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape.
      A shape given as a list (or any sequence, e.g. a tuple) is not modified.

  Returns:
    tensor or shape array transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
      divisible by 4.
  """
  # Keyed by rank *after* the channel axis is split into (C // 4, 4).
  permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  # Always work on a fresh list. The in-place edits below must not leak back
  # into a shape list owned by the caller, and tuples could not be edited.
  temp_shape = (
      input_shape_or_tensor.shape.as_list()
      if is_tensor else list(input_shape_or_tensor))
  if temp_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  temp_shape[-1] //= 4
  temp_shape.append(4)
  permutation = permutations[len(temp_shape)]
  if is_tensor:
    t = array_ops.reshape(input_shape_or_tensor, temp_shape)
    return array_ops.transpose(t, permutation)
  else:
    return [temp_shape[a] for a in permutation]
def NCHW_VECT_CToNHWC(input_shape_or_tensor):
  """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  if is_tensor:
    input_shape = input_shape_or_tensor.shape.as_list()
  else:
    input_shape = input_shape_or_tensor
  if input_shape[-1] != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
  permutation = permutations[len(input_shape)]
  # Drop the trailing vector axis, then fold its size back into channels.
  nhwc_shape = [input_shape[i] for i in permutation[:-1]]
  nhwc_shape[-1] *= input_shape[-1]
  if not is_tensor:
    return nhwc_shape
  transposed = array_ops.transpose(input_shape_or_tensor, permutation)
  return array_ops.reshape(transposed, nhwc_shape)
def NCHWToNHWC(input_tensor):
  """Converts the input from the NCHW format to NHWC.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # Rank -> axis permutation that moves the channel axis (index 1) to the end.
  new_axes = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
  if not isinstance(input_tensor, ops.Tensor):
    perm = new_axes[len(input_tensor)]
    return [input_tensor[axis] for axis in perm]
  perm = new_axes[input_tensor.shape.ndims]
  return array_ops.transpose(input_tensor, perm)
def skip_if(condition):
  """Skips the decorated function if condition is or evaluates to True.

  "Skipping" here means the wrapped function returns None without running its
  body. It does not report the test as skipped to unittest.

  Args:
    condition: Either an expression that can be used in "if not condition"
      statement, or a callable whose result should be a boolean. A callable is
      re-evaluated on every call of the decorated function.

  Returns:
    The wrapped function. It keeps the original's name and docstring.
  """

  def real_skip_if(fn):

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      skip = condition() if callable(condition) else condition
      if not skip:
        return fn(*args, **kwargs)
      return None

    return wrapper

  return real_skip_if
@contextlib.contextmanager
def skip_if_error(test_obj, error_type, messages=None):
  """Turns selected errors raised inside the `with` body into a skipped test.

  Note that this does not work if used in setUpClass/tearDownClass.
  Usage in setUp/tearDown works fine just like regular test methods.

  Args:
    test_obj: A test object provided as `self` in the test methods; this object
      is usually an instance of `unittest.TestCase`'s subclass and should have
      `skipTest` method.
    error_type: The error type to skip. Note that if `messages` are given, both
      `error_type` and `messages` need to match for the test to be skipped.
    messages: Optional, a string or list of strings. If empty or `None`, any
      `error_type` error skips the test. Otherwise the test is skipped only if
      one of the `messages` is a substring of the raised error's message.
      Non-matching errors are re-raised.

  Yields:
    Nothing.
  """
  if messages:
    messages = nest.flatten(messages)
  try:
    yield
  except error_type as e:
    message_matches = (not messages) or any(
        candidate in str(e) for candidate in messages)
    if not message_matches:
      raise
    test_obj.skipTest("Skipping error: {}: {}".format(type(e), str(e)))
def enable_c_shapes(fn):
  """No-op decorator kept for compatibility; hands back `fn` unchanged.

  TODO(b/74620627): Remove this.
  """
  unchanged_fn = fn
  return unchanged_fn


def with_c_shapes(cls):
  """No-op class decorator kept for compatibility; hands back `cls` unchanged.

  TODO(b/74620627): Remove this.
  """
  unchanged_cls = cls
  return unchanged_cls
def enable_control_flow_v2(fn):
  """Decorator for enabling CondV2 and WhileV2 on a test.

  Note this enables using CondV2 and WhileV2 after running the test class's
  setup/teardown methods.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function, which keeps `fn`'s name and docstring. It sets
    `control_flow_util.ENABLE_CONTROL_FLOW_V2` to True for the duration of the
    call and restores the previous value afterwards, even if `fn` raises.
  """

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
    try:
      return fn(*args, **kwargs)
    finally:
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old

  return wrapper
def with_control_flow_v2(cls):
  """Adds a control-flow-v2 variant of every test method of `cls`.

  For each callable class attribute whose name starts with the unittest test
  method prefix, a sibling named `<name>WithControlFlowV2` is added. It runs
  the original via `enable_control_flow_v2`, i.e. with CondV2 and WhileV2
  turned on after the test class's setup method has run. Methods carrying a
  true `_disable_control_flow_v2` attr (set by `@disable_control_flow_v2`)
  get no such sibling. If control flow v2 is already enabled globally, `cls`
  is returned unchanged.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  Example:

    @test_util.with_control_flow_v2
    class ControlFlowTest(test.TestCase):

      def testEnabledForV2(self):
        ...

      @test_util.disable_control_flow_v2("b/xyzabc")
      def testDisabledForV2(self):
        ...

  The class above gains a `testEnabledForV2WithControlFlowV2` method, which
  enables the V2 flags, calls `testEnabledForV2`, and restores the flags. No
  v2 variant is generated for `testDisabledForV2`.

  Args:
    cls: class to decorate

  Returns:
    cls with new test methods added
  """
  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    return cls

  test_prefix = unittest.TestLoader.testMethodPrefix
  # Snapshot the class dict, since setattr below adds entries to it.
  for attr_name, attr_value in list(cls.__dict__.items()):
    if not callable(attr_value):
      continue
    if not attr_name.startswith(test_prefix):
      continue
    if getattr(attr_value, "_disable_control_flow_v2", False):
      continue
    setattr(cls, attr_name + "WithControlFlowV2",
            enable_control_flow_v2(attr_value))
  return cls
def disable_control_flow_v2(unused_msg):
  """Opts a test method out of the `with_control_flow_v2` class decorator.

  Blocks the function from being run with v2 control flow ops by setting
  `_disable_control_flow_v2 = True` on it.

  Args:
    unused_msg: Reason for disabling. It is for documentation only and is not
      used.

  Returns:
    A decorator that sets the marker attr and returns the same function object.
  """
  del unused_msg  # Only documents why v2 is disabled.

  def mark_disabled(func):
    setattr(func, "_disable_control_flow_v2", True)
    return func

  return mark_disabled
def enable_output_all_intermediates(fn):
  """Force-enable outputting all intermediates from functional control flow ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function, which keeps `fn`'s name and docstring. It sets
    `control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE` to
    True for the duration of the call and restores the previous value
    afterwards, even if `fn` raises.
  """

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    output_all_intermediates_old = (
        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
    control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
    try:
      return fn(*args, **kwargs)
    finally:
      control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = (
          output_all_intermediates_old)

  return wrapper
def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
  """Decorator for asserting that no new Python objects persist after a test.

  Runs the test multiple times executing eagerly, first as a warmup and then to
  let objects accumulate. The warmup helps ignore caches which do not grow as
  the test is run repeatedly.

  Useful for checking that there are no missing Py_DECREFs in the C exercised by
  a bit of Python.

  Besides per-type Python object counts, the check also fails if any graph
  collection grows or if new function names remain registered in the eager
  context.

  Python's garbage collector is disabled while the check runs. It is re-enabled
  afterwards, including when the test or the check itself raises.

  Args:
    func: The function to test. If None, a decorator is returned instead, so
      this can be applied either bare or as
      `@assert_no_new_pyobjects_executing_eagerly(warmup_iters=...)`.
    warmup_iters: The number of warmup iterations, excluded from measuring.

  Returns:
    The wrapped function performing the test.
  """

  def wrap_f(f):
    def decorator(self, *args, **kwargs):
      """Warms up, gets object counts, runs the test, checks for new objects."""
      with context.eager_mode():
        gc.disable()
        # Re-enable gc even if the test or an assertion below raises; otherwise
        # every later test in the process would run with gc disabled.
        try:
          # Python 3.11 removed "errors" and "skipped" as members of
          # unittest.case._Outcome so get them from the test result object
          # instead.
          test_errors = None
          test_skipped = None
          if hasattr(self._outcome, "errors"):
            test_errors = self._outcome.errors
            test_skipped = self._outcome.skipped
          else:
            test_errors = self._outcome.result.errors
            test_skipped = self._outcome.result.skipped
          # Run the test `warmup_iters` times as warmup, in an attempt to fill
          # up caches, which should not grow as the test is run repeatedly
          # below.
          #
          # TODO(b/117156879): Running warmup twice is black magic; we have seen
          # tests that fail with 1 warmup run, and pass with 2, on various
          # versions of python2.7.x.
          for _ in range(warmup_iters):
            f(self, *args, **kwargs)
          # Since we aren't in the normal test lifecycle, we need to manually
          # run cleanups to clear out their object references.
          self.doCleanups()

          # Some objects are newly created by _get_object_count_by_type(). So
          # create and save as a dummy variable to include it as a baseline.
          obj_count_by_type = _get_object_count_by_type()
          gc.collect()

          # Make sure any registered functions are cleaned up in the C++
          # runtime.
          registered_function_names = context.context().list_function_names()

          # unittest.doCleanups adds to self._outcome with each unwound call.
          # These objects are retained across gc collections so we exclude them
          # from the object count calculation.
          obj_count_by_type = _get_object_count_by_type(
              exclude=gc.get_referents(test_errors, test_skipped))

          if ops.has_default_graph():
            collection_sizes_before = {
                collection: len(ops.get_collection(collection))
                for collection in ops.get_default_graph().collections
            }
          for _ in range(3):
            f(self, *args, **kwargs)
          # Since we aren't in the normal test lifecycle, we need to manually
          # run cleanups to clear out their object references.
          self.doCleanups()
          # Note that gc.get_objects misses anything that isn't subject to
          # garbage collection (C types). Collections are a common source of
          # leaks, so we test for collection sizes explicitly.
          if ops.has_default_graph():
            for collection_key in ops.get_default_graph().collections:
              collection = ops.get_collection(collection_key)
              size_before = collection_sizes_before.get(collection_key, 0)
              if len(collection) > size_before:
                raise AssertionError(
                    ("Collection %s increased in size from "
                     "%d to %d (current items %s).") %
                    (collection_key, size_before, len(collection), collection))
              # Make sure our collection checks don't show up as leaked memory
              # by removing references to temporary variables.
              del collection
              del collection_key
              del size_before
            del collection_sizes_before
          gc.collect()

          # There should be no new Python objects hanging around.
          obj_count_by_type = (
              _get_object_count_by_type(
                  exclude=gc.get_referents(test_errors, test_skipped)) -
              obj_count_by_type)

          # There should be no newly registered functions hanging around.
          leftover_functions = (
              context.context().list_function_names() -
              registered_function_names)
          assert not leftover_functions, (
              "The following functions were newly created: %s" %
              leftover_functions)

          # In some cases (specifically on MacOS), new_count is somehow
          # smaller than previous_count.
          # Using plain assert because not all classes using this decorator
          # have assertLessEqual
          assert not obj_count_by_type, (
              "The following objects were newly created: %s" %
              str(obj_count_by_type))
        finally:
          gc.enable()
    return decorator

  if func is None:
    return wrap_f
  else:
    return wrap_f(func)
def assert_no_new_tensors(f):
  """Decorator for asserting that no new Tensors persist after a test.

  Mainly useful for checking that code using the Python C API has correctly
  manipulated reference counts.

  Clears the caches that it knows about, runs the garbage collector, then checks
  that there are no Tensor or Tensor-like objects still around. This includes
  Tensors to which something still has a reference (e.g. from missing
  Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one
  of the objects has __del__ defined).

  "Tensor-like" here means instances of `ops.Tensor`, `variables.Variable`,
  `tensor_shape.Dimension` or `tensor_shape.TensorShape`. The test runs inside
  a fresh default graph that reuses the outer graph's key. It runs in eager
  mode if the caller was executing eagerly.

  Args:
    f: The test case to run. Positional and keyword arguments given to the
      decorated method (e.g. by parameterized tests) are forwarded to `f`.

  Returns:
    The decorated test case. It returns whatever `f` returns.
  """

  def decorator(self, *args, **kwargs):
    """Finds existing Tensors, runs the test, checks for new Tensors."""

    def _is_tensorflow_object(obj):
      try:
        return isinstance(obj,
                          (ops.Tensor, variables.Variable,
                           tensor_shape.Dimension, tensor_shape.TensorShape))
      except (ReferenceError, AttributeError):
        # If the object no longer exists, we don't care about it.
        return False

    tensors_before = set(
        id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
    outside_executed_eagerly = context.executing_eagerly()
    # Run the test in a new graph so that collections get cleared when it's
    # done, but inherit the graph key so optimizers behave.
    outside_graph_key = ops.get_default_graph()._graph_key
    with ops.Graph().as_default():
      ops.get_default_graph()._graph_key = outside_graph_key
      if outside_executed_eagerly:
        with context.eager_mode():
          result = f(self, *args, **kwargs)
      else:
        result = f(self, *args, **kwargs)
    # Make an effort to clear caches, which would otherwise look like leaked
    # Tensors.
    context.context()._clear_caches()  # pylint: disable=protected-access
    gc.collect()
    tensors_after = [
        obj for obj in gc.get_objects()
        if _is_tensorflow_object(obj) and id(obj) not in tensors_before
    ]
    if tensors_after:
      raise AssertionError(("%d Tensors not deallocated after test: %s" % (
          len(tensors_after),
          str(tensors_after),
      )))
    return result

  return decorator
def _find_reference_cycle(objects, idx):
  """Looks for a reference cycle through `objects[idx]` and logs one sample.

  Builds a graph of referrers (via `gc.get_referrers`) starting at
  `objects[idx]`, then searches it for a cycle. The first cycle found is
  logged with `logging.error`, one line per object.

  Args:
    objects: Sequence holding the object to inspect. NOTE(review): presumably
      a slice of `gc.garbage` from the caller; confirm. It is put on the
      denylist so it does not appear as a referrer itself.
    idx: Index into `objects` of the object to start from.

  Returns:
    True if a cycle was found (and logged), False otherwise.
  """

  def get_ignore_reason(obj, denylist):
    """Tests whether an object should be omitted from the dependency graph."""
    # `denylist` grows by one tuple per recursion level in build_ref_graph, so
    # its length doubles as a recursion depth limit.
    if len(denylist) > 100:
      return "<depth limit>"
    # Frames executing code from this file belong to the diagnostics itself.
    if tf_inspect.isframe(obj):
      if "test_util.py" in tf_inspect.getframeinfo(obj)[0]:
        return "<test code>"
    # Identity (not equality) checks against the denylisted objects.
    for b in denylist:
      if b is obj:
        return "<test code>"
    if obj is denylist:
      return "<test code>"
    return None

  # Note: this function is meant to help with diagnostics. Its output is purely
  # a human-readable representation, so you may freely modify it to suit your
  # needs.
  def describe(obj, denylist, leaves_only=False):
    """Returns a custom human-readable summary of obj.

    Args:
      obj: the value to describe.
      denylist: same as denylist in get_ignore_reason.
      leaves_only: boolean flag used when calling describe recursively. Useful
        for summarizing collections.

    Returns:
      A string. Lists and tuples are summarized element by element (one level
      deep, via leaves_only=True); other objects by type and id.
    """
    if get_ignore_reason(obj, denylist):
      return "{}{}".format(get_ignore_reason(obj, denylist), type(obj))
    if tf_inspect.isframe(obj):
      return "frame: {}".format(tf_inspect.getframeinfo(obj))
    elif tf_inspect.ismodule(obj):
      return "module: {}".format(obj.__name__)
    else:
      if leaves_only:
        return "{}, {}".format(type(obj), id(obj))
      elif isinstance(obj, list):
        return "list({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, tuple):
        return "tuple({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, dict):
        return "dict({}): {} keys".format(id(obj), len(obj.keys()))
      elif tf_inspect.isfunction(obj):
        return "function({}) {}; globals ID: {}".format(
            id(obj), obj.__name__, id(obj.__globals__))
      else:
        return "{}, {}".format(type(obj), id(obj))

  def build_ref_graph(obj, graph, reprs, denylist):
    """Builds a reference graph as <referrer> -> <list of referents>.

    Args:
      obj: The object to start from. The graph will be built by recursively
        adding its referrers.
      graph: Dict holding the graph to be built. To avoid creating extra
        references, the graph holds object IDs rather than actual objects.
      reprs: Auxiliary structure that maps object IDs to their human-readable
        description.
      denylist: List of objects to ignore.
    """
    referrers = gc.get_referrers(obj)
    # The referrers list itself refers to `obj`; denylist it so it is not
    # reported as part of a cycle. This also grows the depth counter.
    denylist = denylist + (referrers,)

    obj_id = id(obj)
    for r in referrers:
      if get_ignore_reason(r, denylist) is None:
        r_id = id(r)
        if r_id not in graph:
          graph[r_id] = []
        # Recurse only on edges not seen before, so cycles terminate.
        if obj_id not in graph[r_id]:
          graph[r_id].append(obj_id)
          build_ref_graph(r, graph, reprs, denylist)
          reprs[r_id] = describe(r, denylist)

  def find_cycle(el, graph, reprs, path):
    """Finds and prints a single cycle in the dependency graph.

    Depth-first search from `el`. `path` is the tuple of IDs visited so far.
    Returns True once a cycle has been logged; otherwise a falsy value
    (None when `el` has no outgoing edges).
    """
    if el not in graph:
      return
    for r in graph[el]:
      if r in path:
        logging.error("Reference cycle sample:")
        for p in path + (r,):
          logging.error(reprs.get(p, "unknown object " + str(p)))
        return True
      else:
        if find_cycle(r, graph, reprs, path + (r,)):
          return True
    return False

  obj = objects[idx]
  graph = {}  # referrer ID -> object ID
  reprs = {}  # object ID -> description
  # Seed the denylist with this function's own locals so they are not reported
  # as referrers.
  build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason,
                                      describe, build_ref_graph, find_cycle))
  for k in graph:
    if find_cycle(k, graph, reprs, ()):
      return True
  return False
def assert_no_garbage_created(f):
  """Test method decorator to assert that no garbage has been created.

  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
  cannot be un-set (i.e. will disable garbage collection for any other unit
  tests in the same file/shard).

  The collector's debug flags and its enabled/disabled state are restored when
  the decorated call finishes, including when the test itself or the final
  garbage assertion raises.

  Args:
    f: The function to decorate. It is invoked as `f(self, **kwargs)`.

  Returns:
    The decorated function, which returns whatever `f` returns.
  """

  def _safe_object_str(obj):
    # Only uses the class name and id, so a broken __str__/__repr__ on a
    # garbage object cannot interfere with the report.
    return "<%s %d>" % (obj.__class__.__name__, id(obj))

  def _log_new_garbage(previous_garbage):
    # Logs details about every object appended to gc.garbage after index
    # `previous_garbage`.
    logging.error(
        "The decorated test created work for Python's garbage collector, "
        "likely due to a reference cycle. New objects in cycle(s):")
    for i, obj in enumerate(gc.garbage[previous_garbage:]):
      try:
        logging.error("Object %d of %d", i,
                      len(gc.garbage) - previous_garbage)
        logging.error(" Object type: %s", _safe_object_str(obj))
        logging.error(
            " Referrer types: %s", ", ".join(
                [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
        logging.error(
            " Referent types: %s", ", ".join(
                [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
        logging.error(" Object attribute names: %s", dir(obj))
        logging.error(" Object __str__:")
        logging.error(obj)
        logging.error(" Object __repr__:")
        logging.error(repr(obj))
      except Exception:  # pylint: disable=broad-except
        logging.error("(Exception while printing object)")

  @functools.wraps(f)
  def decorator(self, **kwargs):
    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
    gc_was_enabled = gc.isenabled()
    gc.disable()
    previous_debug_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    try:
      gc.collect()
      previous_garbage = len(gc.garbage)
      result = f(self, **kwargs)
      gc.collect()
      new_garbage = len(gc.garbage)
      if new_garbage > previous_garbage:
        for obj in gc.garbage[previous_garbage:]:
          # Known false positive for ast.fix_missing_locations.
          if getattr(obj, "__module__", "") == "ast":
            new_garbage -= 3

      if new_garbage > previous_garbage:
        _log_new_garbage(previous_garbage)
        # When garbage is created, this call can help identify reference
        # cycles, which are typically the cause of such garbage.
        for i in range(previous_garbage, new_garbage):
          if _find_reference_cycle(gc.garbage, i):
            break

      # This will fail if any garbage has been created, typically because of a
      # reference cycle.
      self.assertEqual(previous_garbage, new_garbage)
      return result
    finally:
      # Runs even when the test or the assertion above raises, so a failing
      # test does not leave the collector disabled for every later test.
      # TODO(allenl): Figure out why this debug flag reset doesn't work. It
      # would be nice to be able to decorate arbitrary tests in a large test
      # suite and not hold on to every object in other tests.
      gc.set_debug(previous_debug_flags)
      if gc_was_enabled:
        gc.enable()

  return decorator
1027def _combine_named_parameters(**kwargs):
1028 """Generate combinations based on its keyword arguments.
1030 Two sets of returned combinations can be concatenated using +. Their product
1031 can be computed using `times()`.
1033 Args:
1034 **kwargs: keyword arguments of form `option=[possibilities, ...]` or
1035 `option=the_only_possibility`.
1037 Returns:
1038 a list of dictionaries for each combination. Keys in the dictionaries are
1039 the keyword argument names. Each key has one value - one of the
1040 corresponding keyword argument values.
1041 """
1042 sort_by_key = lambda k: k[0]
1043 combinations = []
1044 for key, values in sorted(kwargs.items(), key=sort_by_key):
1045 if not isinstance(values, list):
1046 values = [values]
1047 combinations.append([(key, value) for value in values])
1049 return [OrderedDict(result) for result in itertools.product(*combinations)]
def generate_combinations_with_testcase_name(**kwargs):
  """Generate combinations based on its keyword arguments, with test names.

  Works like `_combine_named_parameters`, but each resulting dictionary also
  gets a trailing "testcase_name" entry, which named parameterized tests
  require. The name is "_test" followed by "_<key>_<value>" for every option
  in order, keeping only the alphanumeric characters of the key and of
  `str(value)`.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`.

  Returns:
    a list of OrderedDicts, one per combination. Each maps option names to one
    chosen value and ends with a "testcase_name" key.
  """

  def _alnum_only(text):
    return "".join(ch for ch in text if ch.isalnum())

  named_combinations = []
  for combination in _combine_named_parameters(**kwargs):
    assert isinstance(combination, OrderedDict)
    name = "".join(
        "_{}_{}".format(_alnum_only(key), _alnum_only(str(value)))
        for key, value in combination.items())
    entries = list(combination.items())
    entries.append(("testcase_name", "_test{}".format(name)))
    named_combinations.append(OrderedDict(entries))
  return named_combinations
def run_all_in_graph_and_eager_modes(cls):
  """Execute all test methods in the given class with and without eager.

  Every callable attribute of `cls` that `dir` reports (inherited ones
  included) and whose name starts with the unittest test-method prefix is
  replaced by `run_in_graph_and_eager_modes(attr)`. The exceptions are
  `test_session` and names starting with "testSkipEager" or "test_skip_eager".

  Args:
    cls: the test class to modify in place.

  Returns:
    `cls`.
  """
  wrap = run_in_graph_and_eager_modes
  eager_skip_prefixes = ("testSkipEager", "test_skip_eager")
  for attr_name in dir(cls):
    is_test = attr_name.startswith(unittest.TestLoader.testMethodPrefix)
    if not is_test or attr_name == "test_session":
      continue
    if attr_name.startswith(eager_skip_prefixes):
      continue
    attr = getattr(cls, attr_name, None)
    if callable(attr):
      setattr(cls, attr_name, wrap(attr))
  return cls
def enable_nested_function_shape_inference(fn):
  """Decorator for enabling nested_function_shape_inference on a test.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will enable the
  nested_function_shape_inference flag, execute the test, then reset the flag
  to False. If the flag is already enabled when the test starts, the test runs
  as-is and the flag is left untouched.

  Example:

  class MyTest(test.TestCase):

    @enable_nested_function_shape_inference
    def testFoo(self):
      ...

  Args:
    fn: the function to be wrapped.

  Returns:
    The wrapped function. `functools.wraps` copies `fn`'s name, docstring and
    `__dict__` attributes onto it.
  """

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # If `nested_function_shape_inference` is already enabled do nothing.
    if flags.config().enable_nested_function_shape_inference.value():
      return fn(*args, **kwargs)

    flags.config().enable_nested_function_shape_inference.reset(True)
    try:
      return fn(*args, **kwargs)
    finally:
      # The early return above means the flag was False before this call.
      flags.config().enable_nested_function_shape_inference.reset(False)

  return wrapper
def enable_quantized_dtypes_training(fn):
  """Decorator for enabling quantized_dtypes_training on a test.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will enable the
  quantized_dtypes_training flag, execute the test, then reset the flag to
  False. If the flag is already enabled when the test starts, the test runs
  as-is and the flag is left untouched.

  Example:

  class MyTest(test.TestCase):

    @enable_quantized_dtypes_training
    def testFoo(self):
      ...

  Args:
    fn: the function to be wrapped.

  Returns:
    The wrapped function. `functools.wraps` copies `fn`'s name, docstring and
    `__dict__` attributes onto it.
  """

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # If `enable_quantized_dtypes_training` is already enabled do nothing.
    if flags.config().enable_quantized_dtypes_training.value():
      return fn(*args, **kwargs)

    flags.config().enable_quantized_dtypes_training.reset(True)
    try:
      return fn(*args, **kwargs)
    finally:
      # The early return above means the flag was False before this call.
      flags.config().enable_quantized_dtypes_training.reset(False)

  return wrapper
def enable_eager_op_as_function(fn):
  """Pass-through decorator kept until all usages are removed.

  The returned wrapper simply calls `fn` with the same arguments. Its
  metadata (name, docstring, `__dict__` attributes) is copied from `fn` so the
  decorated test looks like the original.

  Args:
    fn: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    return fn(*args, **kwargs)

  return wrapper
@tf_export("test.with_eager_op_as_function")
def with_eager_op_as_function(cls=None, only_as_function=False):  # pylint: disable=unused-argument
  """No-op class decorator, kept until all usages are removed.

  It can be applied bare (`@with_eager_op_as_function`) or called with
  arguments (`@with_eager_op_as_function(only_as_function=True)`); either way
  the class is returned unchanged.

  Args:
    cls: class to decorate, or None when the decorator is called with
      arguments.
    only_as_function: unused argument.

  Returns:
    `cls` itself when given, otherwise a decorator that returns its argument
    unchanged.
  """

  def _identity(cls):
    return cls

  return _identity if cls is None else _identity(cls)
def enable_graph_building_optimization(fn):
  """Decorator for enabling graph_building_optimization on a test.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will enable graph_building_optimization,
  execute the test, then reset the feature flag to False. If the flag is
  already enabled when the test starts, the test runs as-is and the flag is
  left untouched.

  Example:

  class MyTest(test.TestCase):

    @enable_graph_building_optimization
    def testFoo(self):
      ...

  Args:
    fn: the function to be wrapped.

  Returns:
    The wrapped function. `functools.wraps` copies `fn`'s name, docstring and
    `__dict__` attributes onto it.
  """

  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # If `graph_building_optimization` is already enabled do nothing.
    if flags.config().graph_building_optimization.value():
      return fn(*args, **kwargs)

    flags.config().graph_building_optimization.reset(True)
    try:
      return fn(*args, **kwargs)
    finally:
      # The early return above means the flag was False before this call.
      flags.config().graph_building_optimization.reset(False)

  return wrapper
def add_graph_building_optimization_tests(cls=None):
  """Adds methods with graph_building_optimization enabled to the test suite.

  For every callable attribute defined directly on the class whose name starts
  with the unittest test-method prefix or with "benchmark", a copy named
  `<name>WithGraphBuildingOptimization` is added. The copy runs the original
  through `enable_graph_building_optimization`. If the flag is already enabled
  when the decorator is applied, the class is returned unchanged.

  Example:

  @test_util.add_graph_building_optimization_tests
  class FooTest(test.TestCase):

    def testBar(self):
      ...

  Generated class:
  class FooTest(test.TestCase):

    def testBar(self):
      ...

    def testBarWithGraphBuildingOptimization(self):
      # Enable graph_building_optimization
      testBar(self)
      # Disable graph_building_optimization

  Args:
    cls: class to decorate, or None to get the decorator back.

  Returns:
    cls with new test methods added, or a decorator when `cls` is None.
  """

  def decorator(cls):
    # Nothing to add when the optimization is already on for every test.
    if flags.config().graph_building_optimization.value():
      return cls

    wanted_prefixes = (unittest.TestLoader.testMethodPrefix, "benchmark")
    # Iterate over a snapshot because new attributes are added to the class.
    snapshot = list(cls.__dict__.items())
    for attr_name, attr in snapshot:
      if callable(attr) and attr_name.startswith(wanted_prefixes):
        setattr(cls, attr_name + "WithGraphBuildingOptimization",
                enable_graph_building_optimization(attr))
    return cls

  if cls is None:
    return decorator
  return decorator(cls)
def disable_eager_op_as_function(unused_msg):
  """Decorator for a function in a with_eager_op_as_function enabled test class.

  Kept for backwards compatibility; the message argument is accepted but not
  used.

  Args:
    unused_msg: Reason for disabling. Ignored.

  Returns:
    The result of `_disable_test(execute_func=False)`.
    NOTE(review): presumably a decorator that keeps the wrapped test from
    executing; confirm against `_disable_test`.
  """
  del unused_msg  # Accepted only for API compatibility.
  return _disable_test(execute_func=False)
def set_xla_env_flag(func=None, flag=""):
  """Decorator for setting XLA_FLAGS prior to running a test.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will allow users to set any xla flags
  exposed via the XLA_FLAGS environment variable, execute the test, then reset
  the XLA_FLAGS to the state it was in prior to this test.

  If XLA_FLAGS already holds a non-empty value, `flag` is prepended to it
  (separated by a space). Afterwards XLA_FLAGS is restored to its previous
  value, or removed if it was previously unset, even if the test removed or
  changed it.

  Example:

  class MyTest(test.TestCase):

    @set_xla_env_flag(flag='--xla_gpu_enable_fast_min_max=false')
    def testFoo(self):
      ...

  Args:
    func: The function to be wrapped.
    flag: The xla flag to be set in the XLA_FLAGS env variable.

  Returns:
    The wrapped function.
  """

  def decorator(f):

    @functools.wraps(f)
    def decorated(*args, **kwargs):
      original_xla_flags = os.environ.get("XLA_FLAGS")
      new_xla_flags = flag
      if original_xla_flags:
        new_xla_flags = new_xla_flags + " " + original_xla_flags
      os.environ["XLA_FLAGS"] = new_xla_flags
      try:
        return f(*args, **kwargs)
      finally:
        if original_xla_flags is None:
          # pop() instead of del: the test may have removed XLA_FLAGS itself,
          # and a KeyError here would mask the test's own outcome.
          os.environ.pop("XLA_FLAGS", None)
        else:
          os.environ["XLA_FLAGS"] = original_xla_flags

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
def build_as_function_and_v1_graph(func=None):
  """Run a test case in v1 graph mode and inside tf.function in eager mode.

  WARNING: This decorator can only be used in test cases that statically check
  the generated graph. Attempting to evaluate graph or function results via
  session.run() or self.evaluate() will fail.

  WARNING: This decorator can only be used for test cases that inherit from
  absl.testing.parameterized.TestCase.

  The decorated method is parameterized into two named cases: "_v1_graph", run
  inside a fresh `ops.Graph()`, and "_function", run as a
  `def_function.function` under eager mode within a fresh graph that is
  dismantled afterwards.

  Args:
    func: Test case function to be decorated.

  Returns:
    Decorated test case function.

  Raises:
    ValueError: If applied to a class instead of a test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError(
          "`build_as_function_and_v1_graph` only supports test methods.")

    @parameterized.named_parameters(("_v1_graph", "v1_graph"),
                                    ("_function", "function"))
    @functools.wraps(f)
    def decorated(self, run_mode, *args, **kwargs):
      if run_mode == "v1_graph":
        with ops.Graph().as_default():
          f(self, *args, **kwargs)
      elif run_mode == "function":

        @def_function.function
        def function_in_eager():
          f(self, *args, **kwargs)

        # Create a new graph for the eagerly executed version of this test for
        # better isolation.
        graph_for_eager_test = ops.Graph()
        with graph_for_eager_test.as_default(), context.eager_mode():
          function_in_eager()
        ops.dismantle_graph(graph_for_eager_test)
      else:
        raise ValueError("Unknown run mode %s" % run_mode)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
def run_in_async_and_sync_mode(f):
  """Execute the test in async mode and sync mode.

  The returned test is parameterized with `parameterized.named_parameters`
  into two cases. The "Async" case runs `f` under
  `context.execution_mode(context.ASYNC)`. The case with an empty name suffix
  runs it under `context.execution_mode(context.SYNC)`.

  Args:
    f: the test method to decorate.

  Returns:
    The parameterized test method.
  """

  @parameterized.named_parameters([("Async", True), ("", False)])
  @functools.wraps(f)
  def decorator(self, async_mode, *args, **kwargs):
    execution_mode = context.ASYNC if async_mode else context.SYNC
    with context.execution_mode(execution_mode):
      f(self, *args, **kwargs)

  return decorator
def run_in_graph_and_eager_modes(func=None,
                                 config=None,
                                 use_gpu=True,
                                 assert_no_eager_garbage=False):
  """Execute the decorated test with and without enabling eager execution.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will cause the contents of the test
  method to be executed twice - once normally, and once with eager execution
  enabled. This allows unittests to confirm the equivalence between eager
  and graph execution (see `tf.compat.v1.enable_eager_execution`).

  For example, consider the following unittest:

  ```python
  class MyTests(tf.test.TestCase):

    @run_in_graph_and_eager_modes
    def test_foo(self):
      x = tf.constant([1, 2])
      y = tf.constant([3, 4])
      z = tf.add(x, y)
      self.assertAllEqual([4, 6], self.evaluate(z))

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test validates that `tf.add()` has the same behavior when computed with
  eager execution enabled as it does when constructing a TensorFlow graph and
  executing the `z` tensor in a session.

  A `SkipTest` raised during the graph-mode run only skips that run; the eager
  run still happens. Between the two runs the fixture is reset with
  `tearDown()`, `_tempdir` is cleared, and `setUp()` is called again inside
  the eager context.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.


  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.
    config: An optional config_pb2.ConfigProto to use to configure the session
      when executing graphs.
    use_gpu: If True, attempt to run as many operations as possible on GPU.
      If False, the eager run is placed on "/device:CPU:0".
    assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
      collector and asserts that no extra garbage has been created when running
      the test with eager execution enabled. This will fail if there are
      reference cycles (e.g. a = []; a.append(a)). Off by default because some
      tests may create garbage for legitimate reasons (e.g. they define a class
      which inherits from `object`), and because DEBUG_SAVEALL is sticky in some
      Python interpreters (meaning that tests which rely on objects being
      collected elsewhere in the unit test file will not work). Additionally,
      checks that nothing still has a reference to Tensors that the test
      allocated.

  Returns:
    Returns a decorator that will run the decorated test method twice:
    once by constructing and executing a graph in a session and once with
    eager execution enabled.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError(
          "`run_in_graph_and_eager_modes` only supports test methods. "
          "Did you mean to use `run_all_in_graph_and_eager_modes`?")

    def decorated(self, *args, **kwargs):
      # Pass 1: graph mode inside a test session.
      logging.info("Running %s in GRAPH mode.", f.__name__)
      try:
        with context.graph_mode(), self.test_session(
            use_gpu=use_gpu, config=config):
          f(self, *args, **kwargs)
      except unittest.case.SkipTest:
        pass

      def run_eagerly(self, **kwargs):
        logging.info("Running %s in EAGER mode.", f.__name__)
        if use_gpu:
          f(self, *args, **kwargs)
        else:
          with ops.device("/device:CPU:0"):
            f(self, *args, **kwargs)

      eager_runner = run_eagerly
      if assert_no_eager_garbage:
        ops.reset_default_graph()
        eager_runner = assert_no_new_tensors(
            assert_no_garbage_created(eager_runner))

      # Pass 2: reset the test environment, then rerun eagerly against a new
      # default graph for better isolation.
      self.tearDown()
      self._tempdir = None
      graph_for_eager_test = ops.Graph()
      with graph_for_eager_test.as_default(), context.eager_mode():
        self.setUp()
        eager_runner(self, **kwargs)
      ops.dismantle_graph(graph_for_eager_test)

    return tf_decorator.make_decorator(f, decorated)

  return decorator if func is None else decorator(func)
def py_func_if_in_function(f):
  """Wraps `f` so that, while building a function, it runs via a py_func.

  Outside of a function-building context (`ops.inside_function()` is False),
  the call is forwarded to `f` unchanged and its result returned. Inside one,
  every positional argument that is an `ops.Tensor` or `variables.Variable`
  becomes an input of `script_ops.py_func`. When the py_func executes, `f` is
  called with those positions replaced by the values the py_func passes in.
  The py_func is declared with an empty output type list. Keyword arguments
  are forwarded as-is.

  Args:
    f: the function to wrap.

  Returns:
    The wrapped function, built with `tf_decorator.make_decorator`.
  """

  def decorated(*args, **kwds):
    if not ops.inside_function():
      return f(*args, **kwds)

    tensor_positions = [
        pos for pos, arg in enumerate(args)
        if isinstance(arg, (ops.Tensor, variables.Variable))
    ]
    tensor_args = [args[pos] for pos in tensor_positions]

    def inner_f(*inner_tensor_args):
      call_args = list(args)
      for pos, value in zip(tensor_positions, inner_tensor_args):
        call_args[pos] = value
      return f(*call_args, **kwds)

    return script_ops.py_func(inner_f, tensor_args, [])

  return tf_decorator.make_decorator(f, decorated)
def also_run_as_tf_function(f):
  """Runs the decorated test twice--once as is, once inside a tf.function.

  This allows you to run a test both in eager execution and inside a
  tf.function, exercising the two execution modes supported in tf 2.0. The test
  assertions are automatically done inside tf.py_funcs, and tf.function ensures
  that they run in the proper order and with the proper side effects.

  Both runs happen under `context.eager_mode()`. The tf.function run is built
  with `autograph=False`.

  Currently variable creation is not supported in tests annotated with this
  decorator since it's tricky to ensure the variable doesn't get repeatedly
  created when retracing the tf.function.

  Args:
    f: the test method to be decorated

  Returns:
    The decorated test method, which will run both in eager and inside a
    tf.function.
  """

  def decorated(*args, **kwds):

    def bound_f():
      f(*args, **kwds)

    with context.eager_mode():
      # First pass: plain eager execution.
      bound_f()
      # Second pass: the same body traced and run as a tf.function.
      # TODO(b/121143941): Remove the autograph override.
      as_function = def_function.function(bound_f, autograph=False)
      as_function()

  return decorated
def deprecated_graph_mode_only(func=None):
  """Execute the decorated test in graph mode.

  This function returns a decorator intended to be applied to tests that are not
  compatible with eager mode. When this decorator is applied, the test body will
  be run in an environment where API calls construct graphs instead of executing
  eagerly.

  When applied to a class, the class's own `setUp` (if it defines one) and
  every callable attribute defined directly on the class whose name starts
  with the unittest test-method prefix are wrapped; the class is returned.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will run the decorated test method in graph mode.
  """

  def decorator(f):
    if not tf_inspect.isclass(f):

      def decorated(self, *args, **kwargs):
        if not context.executing_eagerly():
          return f(self, *args, **kwargs)
        with context.graph_mode():
          return f(self, *args, **kwargs)

      return decorated

    setup = f.__dict__.get("setUp")
    if setup is not None:
      setattr(f, "setUp", decorator(setup))

    test_prefix = unittest.TestLoader.testMethodPrefix
    for name, value in f.__dict__.copy().items():
      if callable(value) and name.startswith(test_prefix):
        setattr(f, name, decorator(value))

    return f

  return decorator if func is None else decorator(func)


# Older name for `deprecated_graph_mode_only`; same object.
run_deprecated_v1 = deprecated_graph_mode_only
def run_all_in_deprecated_graph_mode_only(cls):
  """Execute all tests in a class in graph mode.

  Every callable attribute reported by `dir(cls)` whose name starts with the
  unittest test-method prefix, except `test_session`, is replaced by
  `deprecated_graph_mode_only(attr)`.

  Args:
    cls: the test class to modify in place.

  Returns:
    `cls`.
  """
  wrap = deprecated_graph_mode_only
  for attr_name in dir(cls):
    wanted = (attr_name.startswith(unittest.TestLoader.testMethodPrefix) and
              attr_name != "test_session")
    if not wanted:
      continue
    attr = getattr(cls, attr_name, None)
    if callable(attr):
      setattr(cls, attr_name, wrap(attr))
  return cls
def run_v1_only(reason, func=None):
  """Execute the decorated test only if running in v1 mode.

  This function is intended to be applied to tests that exercise v1 only
  functionality. If the test is run in v2 mode it will simply be skipped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    reason: string giving a reason for limiting the test to v1 only.
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: If `reason` is not a string.
  """
  if not isinstance(reason, str):
    raise ValueError("'reason' should be string, got {}".format(type(reason)))

  def decorator(f):
    if not tf_inspect.isclass(f):
      # Plain test method: skip it at call time when tf2 is enabled.
      def decorated(self, *args, **kwargs):
        if tf2.enabled():
          self.skipTest(reason)

        return f(self, *args, **kwargs)

      return decorated

    # A whole test class is skipped by decorating only its setUp. A class that
    # does not override setUp has none in its own __dict__, so walk the MRO
    # and decorate the first setUp found (typically the TestCase one).
    for klass in type.mro(f):
      setup = klass.__dict__.get("setUp")
      if setup is None:
        continue
      setattr(f, "setUp", decorator(setup))
      break

    return f

  return decorator if func is None else decorator(func)
def run_v2_only(func=None):
  """Execute the decorated test only if running in v2 mode.

  This function is intended to be applied to tests that exercise v2 only
  functionality. If the test is run in v1 mode it will simply be skipped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: If the decorator is applied to a class.
  """

  def _wrap(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_v2_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      v2_enabled = tf2.enabled()
      if not v2_enabled:
        self.skipTest("Test is only compatible with v2")
      return f(self, *args, **kwargs)

    return decorated

  if func is None:
    return _wrap
  return _wrap(func)
def run_gpu_only(func=None):
  """Execute the decorated test only if a GPU is available.

  This function is intended to be applied to tests that require the presence
  of a GPU. If `is_gpu_available()` returns False when the test runs, the test
  is skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: If the decorator is applied to a class.
  """

  def _wrap(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_gpu_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      has_gpu = is_gpu_available()
      if not has_gpu:
        self.skipTest("Test requires GPU")
      return f(self, *args, **kwargs)

    return decorated

  if func is None:
    return _wrap
  return _wrap(func)
def run_cuda_only(func=None):
  """Execute the decorated test only if a CUDA GPU is available.

  This function is intended to be applied to tests that require the presence
  of a CUDA GPU. If `is_gpu_available(cuda_only=True)` returns False when the
  test runs, the test is skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: If the decorator is applied to a class.
  """

  def _wrap(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_cuda_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      has_cuda_gpu = is_gpu_available(cuda_only=True)
      if not has_cuda_gpu:
        self.skipTest("Test requires CUDA GPU")
      return f(self, *args, **kwargs)

    return decorated

  if func is None:
    return _wrap
  return _wrap(func)
def run_gpu_or_tpu(func=None):
  """Execute the decorated test only if a physical GPU or TPU is available.

  This function is intended to be applied to tests that require the presence
  of a physical GPU or TPU. It complies with the following rules:
  - If a GPU is available, the test will run on the GPU.
  - If a GPU is absent and a TPU is available, the test will run on the TPU.
  - If both GPU and TPU are absent, the test will be skipped.

  The chosen device type ("GPU" or "TPU") is passed to the test as the
  argument right after `self`.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: If the decorator is applied to a class.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_gpu_or_tpu` only supports test methods.")

    def decorated(self, *args, **kwargs):
      # GPU is preferred; TPU is only tried when no GPU is present.
      for device_type in ("GPU", "TPU"):
        if config.list_physical_devices(device_type):
          return f(self, device_type, *args, **kwargs)

      self.skipTest("Test requires GPU or TPU")

    return decorated

  if func is None:
    return decorator
  return decorator(func)
def with_forward_compatibility_horizons(*horizons):
  """Executes the decorated test with the specified forward-compat horizons.

  The test body is run once per horizon, in the order given. A `None` horizon
  runs it without setting a forward-compatibility horizon. Any other horizon
  runs it inside `forward_compatibility_horizon(year, month, day)`.

  Args:
    *horizons: A list of (year, month, day) tuples (lists are also accepted).
      If the list includes `None`, then the test will also be run with no
      forward-compatibility horizon set.

  Returns:
    A decorator that will execute the test with the specified horizons.

  Raises:
    ValueError: If no horizons are given, if a horizon is neither `None` nor a
      3-element tuple/list of ints, or if the decorator is applied to a class.
  """
  if not horizons:
    raise ValueError("Expected at least one horizon.")
  for horizon in horizons:
    is_valid_date = (
        isinstance(horizon, (tuple, list)) and len(horizon) == 3 and
        all(isinstance(x, int) for x in horizon))
    if not (horizon is None or is_valid_date):
      # Wrap in a 1-tuple: a bare tuple on the right of `%` is unpacked into
      # several format arguments and would raise TypeError instead.
      raise ValueError("Bad horizon value: %r" % (horizon,))

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`with_forward_compatibility_horizons` only "
                       "supports test methods.")

    @functools.wraps(f)
    def decorated(self, *args, **kwargs):
      for horizon in horizons:
        if horizon is None:
          f(self, *args, **kwargs)
        else:
          (year, month, day) = horizon
          with forward_compatibility_horizon(year, month, day):
            f(self, *args, **kwargs)

    return decorated

  return decorator
@deprecation.deprecated(None,
                        "Use `tf.config.list_physical_devices('GPU')` instead.")
@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
  """Returns whether TensorFlow can access a GPU.

  Warning: if a non-GPU version of the package is installed, the function would
  also return False. Use `tf.test.is_built_with_cuda` to validate if TensorFlow
  was built with CUDA support.

  For example,
  >>> gpu_available = tf.test.is_gpu_available()
  >>> is_cuda_gpu_available = tf.test.is_gpu_available(cuda_only=True)
  >>> is_cuda_gpu_min_3 = tf.test.is_gpu_available(True, (3,0))

  Args:
    cuda_only: limit the search to CUDA GPUs. Currently ignored (see note).
    min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
      CUDA compute capability required, or None if no requirement. A GPU that
      reports no compute capability is treated as (0, 0).

  Note that the keyword arg name "cuda_only" is misleading, since the routine
  will return true when a GPU device is available irrespective of whether TF
  was built with CUDA support or ROCm support. However no changes here because

  ++ Changing the name "cuda_only" to something more generic would break
     backward compatibility

  ++ Adding an equivalent "rocm_only" would require the implementation check
     the build type. This in turn would require doing the same for CUDA and thus
     potentially break backward compatibility

  ++ Adding a new "cuda_or_rocm_only" would not break backward compatibility,
     but would require most (if not all) callers to update the call to use
     "cuda_or_rocm_only" instead of "cuda_only"

  Returns:
    True if a GPU device of the requested kind is available. Also False when
    listing devices fails with a NotFoundError mentioning "CUDA" and
    "not find" (the error is logged).

  Raises:
    errors_impl.NotFoundError: re-raised for any other NotFoundError.
  """

  # This was needed earlier when we had support for SYCL in TensorFlow.
  del cuda_only

  try:
    gpu_devices = (
        d for d in device_lib.list_local_devices() if d.device_type == "GPU")
    for gpu_device in gpu_devices:
      gpu_info = gpu_util.compute_capability_from_device_desc(gpu_device)
      capability = gpu_info.compute_capability or (0, 0)
      meets_minimum = (not min_cuda_compute_capability or
                       capability >= min_cuda_compute_capability)
      if meets_minimum:
        return True
    return False
  except errors_impl.NotFoundError as e:
    cuda_not_found = all(x in str(e) for x in ["CUDA", "not find"])
    if not cuda_not_found:
      raise e
    logging.error(str(e))
    return False
@contextlib.contextmanager
def device(use_gpu):
  """Places ops on the first GPU when requested and one is available.

  Args:
    use_gpu: whether GPU placement is wanted. When falsy, or when
      `is_gpu_available()` returns False, ops go to "/device:CPU:0" instead of
      "/device:GPU:0". `is_gpu_available()` is only consulted when `use_gpu`
      is truthy.

  Yields:
    None, while the selected `ops.device` scope is active.
  """
  if not use_gpu or not is_gpu_available():
    target = "/device:CPU:0"
  else:
    target = "/device:GPU:0"
  with ops.device(target):
    yield
@contextlib.contextmanager
def use_gpu():
  """Places ops on the first GPU if one is available, otherwise on the CPU.

  Equivalent to `device(use_gpu=True)`.
  """
  gpu_or_cpu_scope = device(use_gpu=True)
  with gpu_or_cpu_scope:
    yield
@contextlib.contextmanager
def force_gpu():
  """Force the gpu to be used.

  Ops created inside the block are placed on "/device:GPU:0" without checking
  whether a GPU is present.
  """
  gpu_scope = ops.device("/device:GPU:0")
  with gpu_scope:
    yield
@contextlib.contextmanager
def force_cpu():
  """Force the cpu to be used.

  Ops created inside the block are placed on "/device:CPU:0".
  """
  cpu_scope = ops.device("/device:CPU:0")
  with cpu_scope:
    yield
@contextlib.contextmanager
def deterministic_ops():
  """Enables deterministic ops for the duration of the context.

  On exit, op determinism is disabled again only if it was not already
  enabled on entry. Nesting this context, or using it while determinism is
  enabled globally, therefore does not switch determinism off for the
  enclosing code.
  """
  was_enabled = config.is_op_determinism_enabled()
  try:
    config.enable_op_determinism()
    yield
  finally:
    # Restore the state observed on entry instead of unconditionally
    # disabling determinism.
    if not was_enabled:
      config.disable_op_determinism()
class CapturedWrites:
  """A utility class to load the captured writes made to a stream.

  Attributes:
    capture_location: Path of the file that the captured stream's writes were
      redirected to.
  """

  def __init__(self, capture_location):
    self.capture_location = capture_location

  def contents(self):
    """Get the captured writes as a single string.

    Returns:
      The full text of the capture file, read in text mode with `open`'s
      default encoding.
    """
    with open(self.capture_location) as tmp_file:
      return tmp_file.read()
class FakeEagerSession:
  """Fake session so tests that conditionally use placeholders can use eager.

  There are a number of tests that conditionally use placeholders for shape
  inference. The pattern is demonstrated here:

  ```python
  with self.cached_session() as sess:
    if static_shape:
      y = math_ops.matmul(x, ...)
      feed_dict = {}
    else:
      x_ph = array_ops.placeholder(...)
      y = math_ops.matmul(x_ph, ...)
      feed_dict = {x_ph: x}
    val = sess.run(y, feed_dict=feed_dict)
  ```

  Since the feed_dict is empty when not using placeholders we should be able to
  call self.evaluate(), however this requires rewriting the test case.
  This class should be considered a stop-gap solution to get tests running with
  eager with minimal changes to the actual test.
  """

  def __init__(self, test_case):
    self._test_case = test_case

  def run(self, fetches, *args, **kwargs):
    """Evaluate `fetches` using the owning test case's `evaluate()`.

    Fails if a non-empty `feed_dict` or any other additional argument is
    specified.

    Args:
      fetches: A Tensor or a nested list/tuple of Tensors.
      *args: Positional arguments. Must be empty.
      **kwargs: Keyword arguments. Only an empty (or falsy) `feed_dict` is
        accepted.

    Raises:
      RuntimeError: If args or kwargs are specified.

    Returns:
      Tensors as numpy values.
    """
    # An empty feed_dict is what the non-placeholder branch of the pattern
    # above passes, so it is tolerated.
    feed_dict = kwargs.pop("feed_dict", {})
    if feed_dict:
      raise RuntimeError(
          "feed_dict is not supported when eager execution is enabled "
          "(in this case, sess.run(t) is shorthand for t.numpy())")

    if args or kwargs:
      raise RuntimeError(
          "Optional args are not supported when eager execution is enabled "
          "(in this case, sess.run(t) is shorthand for t.numpy())")

    return self._test_case.evaluate(fetches)
class ErrorLoggingSession(session.Session):
  """A `session.Session` whose `run()` logs errors before re-raising them."""

  def run(self, *args, **kwargs):
    """Runs the session, logging every failure except `OutOfRangeError`.

    All exceptions are re-raised unchanged; logging is the only addition.
    """
    try:
      return super().run(*args, **kwargs)
    except errors.OutOfRangeError:
      # tf.data uses OutOfRangeError to signal completion; logging it would
      # make the output of tf.data tests hard to read.
      raise
    except Exception as e:  # pylint: disable=broad-except
      logging.error(str(e))
      raise
def disable_cudnn_autotune(func):
  """Disable autotuning during the call to this function.

  Some tests want to base assertions on a graph being isomorphic with a copy.
  To ensure this, this decorator disables autotuning by setting
  `TF_CUDNN_USE_AUTOTUNE=false` and appending `--xla_gpu_autotune_level=0` to
  `XLA_FLAGS`. Both environment variables are restored to their previous
  values (or removed if they were unset) after the call, even when the
  decorated function raises.

  Args:
    func: Function to run with CuDNN autotuning turned off.

  Returns:
    Decorated function.
  """

  def decorator(f):

    def decorated(self, *args, **kwargs):
      original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE")
      original_xla_flags = os.environ.get("XLA_FLAGS")

      os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false"
      new_xla_flags = "--xla_gpu_autotune_level=0"
      if original_xla_flags:
        new_xla_flags = original_xla_flags + " " + new_xla_flags
      os.environ["XLA_FLAGS"] = new_xla_flags

      try:
        return f(self, *args, **kwargs)
      finally:
        # Restore even on failure so a failing test does not leak these
        # settings into later tests. pop(..., None) tolerates the variable
        # having been removed by the test itself.
        for name, value in (("TF_CUDNN_USE_AUTOTUNE",
                             original_tf_cudnn_use_autotune),
                            ("XLA_FLAGS", original_xla_flags)):
          if value is None:
            os.environ.pop(name, None)
          else:
            os.environ[name] = value

    return tf_decorator.make_decorator(func, decorated)

  if func is not None:
    return decorator(func)

  return decorator
2114# The description is just for documentation purposes.
def enable_tf_xla_constant_folding(description):
  """Returns a decorator that enables XLA constant folding for a test method.

  Args:
    description: Why the test needs constant folding. It is used only for
      documentation, but it must be a string.

  Returns:
    A function that, given a test method, returns the decorated method, or,
    given None, returns the decorator itself.

  Raises:
    ValueError: If `description` is not a string.
  """

  if not isinstance(description, str):
    raise ValueError("'description' should be string, got {}".format(
        type(description)))

  def enable_tf_xla_constant_folding_impl(func):
    """Enable constant folding during the call to this function.

    Some tests fail without constant folding.

    Args:
      func: Function to run with constant folding turned on.

    Returns:
      Decorated function.
    """

    def decorator(f):

      def decorated(self, *args, **kwargs):
        original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled()
        pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False)
        try:
          return f(self, *args, **kwargs)
        finally:
          # Restore the previous setting even if the test raises, so it does
          # not leak into later tests.
          pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var)

      return decorated

    if func is not None:
      return decorator(func)

    return decorator

  return enable_tf_xla_constant_folding_impl
2152# Updates test function by selectively disabling it.
2153def _disable_test(execute_func):
2155 def disable_test_impl(func):
2157 def decorator(func):
2159 def decorated(self, *args, **kwargs):
2160 if execute_func:
2161 return func(self, *args, **kwargs)
2163 return tf_decorator.make_decorator(func, decorated)
2165 if func is not None:
2166 return decorator(func)
2168 return decorator
2170 return disable_test_impl
2173# The description is just for documentation purposes.
def disable_xla(description):  # pylint: disable=unused-argument
  """Makes the decorated test method a no-op when XLA is enabled.

  Args:
    description: Why the test is disabled; for documentation only.

  Returns:
    The decorator produced by `_disable_test`.
  """
  return _disable_test(not is_xla_enabled())
2180# The description is just for documentation purposes.
def disable_mlir_bridge(description):  # pylint: disable=unused-argument
  """Makes the decorated test method a no-op when the MLIR bridge is enabled.

  Args:
    description: Why the test is disabled; for documentation only.

  Returns:
    The decorator produced by `_disable_test`.
  """
  return _disable_test(not is_mlir_bridge_enabled())
2187# The description is just for documentation purposes.
def disable_asan(description):  # pylint: disable=unused-argument
  """Makes the decorated test method a no-op when ASAN is enabled.

  Args:
    description: Why the test is disabled; for documentation only.

  Returns:
    The decorator produced by `_disable_test`.
  """
  return _disable_test(not is_asan_enabled())
2194# The description is just for documentation purposes.
def disable_msan(description):  # pylint: disable=unused-argument
  """Makes the decorated test method a no-op when MSAN is enabled.

  Args:
    description: Why the test is disabled; for documentation only.

  Returns:
    The decorator produced by `_disable_test`.
  """
  return _disable_test(not is_msan_enabled())
2201# The description is just for documentation purposes.
def disable_tsan(description):  # pylint: disable=unused-argument
  """Makes the decorated test method a no-op when TSAN is enabled.

  Args:
    description: Why the test is disabled; for documentation only.

  Returns:
    The decorator produced by `_disable_test`.
  """
  return _disable_test(not is_tsan_enabled())
2208# The description is just for documentation purposes.
def disable_ubsan(description):  # pylint: disable=unused-argument
  """Makes the decorated test method a no-op when UBSAN is enabled.

  Args:
    description: Why the test is disabled; for documentation only.

  Returns:
    The decorator produced by `_disable_test`.
  """
  return _disable_test(not is_ubsan_enabled())
2215# The description is just for documentation purposes.
def disable_tfrt(unused_description):
  """Disables a test class or method when TFRT is enabled.

  Args:
    unused_description: Why the test is disabled; for documentation only.

  Returns:
    A function that, given a class, returns None (when TFRT is enabled) or the
    class unchanged. Given a method, it returns a wrapper that is a no-op
    under TFRT. Given None, it returns the method decorator itself.
  """

  def disable_tfrt_impl(cls_or_func):
    """Execute the test only if tfrt is not enabled."""
    if tf_inspect.isclass(cls_or_func):
      return None if tfrt_utils.enabled() else cls_or_func

    def decorator(func):

      def decorated(self, *args, **kwargs):
        if not tfrt_utils.enabled():
          return func(self, *args, **kwargs)
        return None

      return decorated

    return decorator if cls_or_func is None else decorator(cls_or_func)

  return disable_tfrt_impl
2245def for_all_test_methods(decorator, *args, **kwargs):
2246 """Generate class-level decorator from given method-level decorator.
2248 It is expected for the given decorator to take some arguments and return
2249 a method that is then called on the test method to produce a decorated
2250 method.
2252 Args:
2253 decorator: The decorator to apply.
2254 *args: Positional arguments
2255 **kwargs: Keyword arguments
2256 Returns: Function that will decorate a given classes test methods with the
2257 decorator.
2258 """
2260 def all_test_methods_impl(cls):
2261 """Apply decorator to all test methods in class."""
2262 for name in dir(cls):
2263 value = getattr(cls, name)
2264 if callable(value) and name.startswith(
2265 "test") and (name != "test_session"):
2266 setattr(cls, name, decorator(*args, **kwargs)(value))
2267 return cls
2269 return all_test_methods_impl
2272# The description is just for documentation purposes.
def no_xla_auto_jit(description):  # pylint: disable=unused-argument
  """Makes the decorated test method a no-op when XLA auto-jit is enabled.

  Enablement is determined by `is_xla_enabled()`.

  Args:
    description: Why the test is excluded; for documentation only.

  Returns:
    The decorator produced by `_disable_test`.
  """
  return _disable_test(not is_xla_enabled())
2279# The description is just for documentation purposes.
def xla_allow_fallback(description):  # pylint: disable=unused-argument
  """Lets an XLA test fall back to classic TF via lazy compilation.

  Args:
    description: Why the fallback is needed; for documentation only.

  Returns:
    A function that, given a test method, returns the decorated method, or,
    given None, returns the decorator itself.
  """

  def xla_allow_fallback_impl(func):
    """Allow fallback to TF even though testing xla."""

    def decorator(func):

      def decorated(self, *args, **kwargs):
        if not is_xla_enabled():
          return func(self, *args, **kwargs)
        # Update the global XLABuildOpsPassFlags to enable lazy compilation,
        # which allows the compiler to fall back to TF classic. Remember the
        # old value so that we can reset it.
        old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True)
        try:
          return func(self, *args, **kwargs)
        finally:
          # Reset even if the test raises, so the global flag does not leak
          # into later tests.
          pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value)

      return decorated

    if func is not None:
      return decorator(func)

    return decorator

  return xla_allow_fallback_impl
2309# The description is just for documentation purposes.
def run_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Execute test with TensorFloat-32 disabled.

  While almost every real-world deep learning model runs fine with
  TensorFloat-32, many tests use assertAllClose or similar methods.
  TensorFloat-32 matmuls typically will cause such methods to fail with the
  default tolerances.

  Args:
    description: A description used for documentation purposes, describing why
      the test requires TensorFloat-32 to be disabled.

  Returns:
    Decorator which runs a test with TensorFloat-32 disabled. The decorated
    function returns whatever the wrapped function returns, and the previous
    TensorFloat-32 setting is restored afterwards, even on error.
  """

  def decorator(f):

    @functools.wraps(f)
    def decorated(self, *args, **kwargs):
      allowed = config.tensor_float_32_execution_enabled()
      try:
        config.enable_tensor_float_32_execution(False)
        # Propagate the result rather than silently dropping it.
        return f(self, *args, **kwargs)
      finally:
        config.enable_tensor_float_32_execution(allowed)

    return decorated

  return decorator
2342# The description is just for documentation purposes.
def run_all_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Execute all tests in a class with TensorFloat-32 disabled.

  Args:
    description: Why TensorFloat-32 must be disabled; for documentation only.

  Returns:
    A class decorator that applies `run_without_tensor_float_32(description)`
    to every test method.
  """
  class_decorator = for_all_test_methods(run_without_tensor_float_32,
                                         description)
  return class_decorator
def matmul_without_tf32(a, b, *args, **kwargs):
  """Run matmul but cast float32 inputs to float64 if TensorFloat-32 is enabled.

  This effectively runs matmul without TensorFloat-32. It should only be used in
  tests when verifying some other op or functions works correctly, e.g. to test
  `tf.linalg.sqrtm` by matrix multiplying the output of the op by itself. In
  such cases, the matmul itself is not being tested so it's OK to run it with
  higher precision.

  If a matmul itself is being tested, or some other op which uses matmul, use
  `run_without_tensor_float_32` instead.

  This also casts complex64 inputs to complex128, since TensorFloat-32 can also
  be used with complex64.

  Args:
    a: First input to tf.linalg.matmul
    b: Second input to tf.linalg.matmul
    *args: Other positional arguments to tf.linalg.matmul
    **kwargs: Other keyword arguments to tf.linalg.matmul

  Returns:
    A tensor with the same type as `a`.
  """
  if not config.tensor_float_32_execution_enabled():
    return math_ops.matmul(a, b, *args, **kwargs)

  original_dtype = a.dtype
  if original_dtype == "float32":
    wide_dtype = "float64"
  elif original_dtype == "complex64":
    wide_dtype = "complex128"
  else:
    return math_ops.matmul(a, b, *args, **kwargs)

  a = math_ops.cast(a, wide_dtype)
  b = math_ops.cast(b, wide_dtype)
  ret = math_ops.matmul(a, b, *args, **kwargs)
  # Cast back to the *input* dtype. `a` has been rebound to the widened tensor
  # at this point, so `a.dtype` would wrongly yield the wide type.
  return math_ops.cast(ret, original_dtype)
class EagerSessionWarner:
  """Stand-in yielded by `TensorFlowTestCase.session()` in eager mode.

  Any attribute lookup that is not satisfied normally raises an
  AttributeError explaining that sessions cannot be used while executing
  eagerly.
  """

  def __getattr__(self, attr):
    message = (
        "Trying to access properties or call methods on the result of "
        "self.session(), self.cached_session(), etc while eager execution "
        "is enabled. If you're porting this test case to TF 2.0, either "
        "adapt the test to work with eager execution or insert a call to "
        "tf.disable_eager_execution() in the main() function of this test "
        "file.")
    raise AttributeError(message)
2398@tf_export("test.TestCase")
2399class TensorFlowTestCase(googletest.TestCase):
2400 """Base class for tests that need to test TensorFlow."""
2402 def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
2403 super().__init__(methodName)
2404 # Make sure we get unfiltered stack traces during the test
2405 traceback_utils.disable_traceback_filtering()
2406 if is_xla_enabled():
2407 pywrap_tf_session.TF_SetXlaAutoJitMode("2")
2408 pywrap_tf_session.TF_SetXlaMinClusterSize(1)
2409 pywrap_tf_session.TF_SetXlaEnableLazyCompilation(False)
2410 pywrap_tf_session.TF_SetTfXlaCpuGlobalJit(True)
2411 # Constant folding secretly runs code on TF:Classic CPU, so we also
2412 # disable it here.
2413 pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)
2415 # Check if the mlir bridge has been explicitly enabled or disabled. If
2416 # is_mlir_bridge_enabled() returns None, the user did not explictly enable
2417 # or disable the bridge so do not update enable_mlir_bridge.
2418 if is_mlir_bridge_enabled():
2419 context.context().enable_mlir_bridge = True
2420 elif is_mlir_bridge_enabled() is not None:
2421 context.context().enable_mlir_bridge = False
2423 self._threads = []
2424 self._tempdir = None
2425 self._cached_session = None
2426 self._test_start_time = None
2427 # This flag provides the ability to control whether the graph mode gets
2428 # initialized for TF1 or not. Initializing for TF1, which is what was
2429 # happening earlier, was preventing enablement of 'eager mode' in the test.
2430 self._set_default_seed = True
  def setUp(self):
    """Resets global random, graph and summary state before each test.

    Seeds Python's `random` and NumPy with `random_seed.DEFAULT_GRAPH_SEED`,
    closes any cached session, resets the default graph (and seeds it when
    `self._set_default_seed` is True), clears the default summary writer,
    skips the reserved `test_session` name, and records the start time that
    `tearDown` logs.
    """
    super().setUp()
    self._ClearCachedSession()
    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
    # Note: The following line is necessary because some test methods may error
    # out from within nested graph contexts (e.g., via assertRaises and
    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
    # under certain versions of Python. That would cause
    # ops.reset_default_graph() to throw an exception if the stack were not
    # cleared first.
    ops._default_graph_stack.reset()  # pylint: disable=protected-access
    ops.reset_default_graph()
    # Seeding the default graph initializes TF1 graph state; subclasses can
    # opt out through self._set_default_seed (set in __init__).
    if self._set_default_seed:
      random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
    # Reset summary writer in case another test used set_as_default() with their
    # summary writer.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.writer = None

    # Avoiding calling setUp() for the poorly named test_session method.
    if self.id().endswith(".test_session"):
      self.skipTest("Not a test.")

    # Read by tearDown() to log the test's duration.
    self._test_start_time = time.time()
2458 def tearDown(self):
2459 # If a subclass overrides setUp and doesn't call the parent class's setUp,
2460 # then we may not have set the start time.
2461 if self._test_start_time is not None:
2462 logging.info("time(%s): %ss", self.id(),
2463 round(time.time() - self._test_start_time, 2))
2465 for thread in self._threads:
2466 thread.check_termination()
2468 self._ClearCachedSession()
2469 super().tearDown()
2471 def _ClearCachedSession(self):
2472 if self._cached_session is not None:
2473 self._cached_session.close()
2474 self._cached_session = None
2476 def get_temp_dir(self):
2477 """Returns a unique temporary directory for the test to use.
2479 If you call this method multiple times during in a test, it will return the
2480 same folder. However, across different runs the directories will be
2481 different. This will ensure that across different runs tests will not be
2482 able to pollute each others environment.
2483 If you need multiple unique directories within a single test, you should
2484 use tempfile.mkdtemp as follows:
2485 tempfile.mkdtemp(dir=self.get_temp_dir()):
2487 Returns:
2488 string, the path to the unique temporary directory created for this test.
2489 """
2490 if not self._tempdir:
2491 self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
2492 return self._tempdir
2494 @contextlib.contextmanager
2495 def captureWritesToStream(self, stream):
2496 """A context manager that captures the writes to a given stream.
2498 This context manager captures all writes to a given stream inside of a
2499 `CapturedWrites` object. When this context manager is created, it yields
2500 the `CapturedWrites` object. The captured contents can be accessed by
2501 calling `.contents()` on the `CapturedWrites`.
2503 For this function to work, the stream must have a file descriptor that
2504 can be modified using `os.dup` and `os.dup2`, and the stream must support
2505 a `.flush()` method. The default python sys.stdout and sys.stderr are
2506 examples of this. Note that this does not work in Colab or Jupyter
2507 notebooks, because those use alternate stdout streams.
2509 Example:
2510 ```python
2511 class MyOperatorTest(test_util.TensorFlowTestCase):
2512 def testMyOperator(self):
2513 input = [1.0, 2.0, 3.0, 4.0, 5.0]
2514 with self.captureWritesToStream(sys.stdout) as captured:
2515 result = MyOperator(input).eval()
2516 self.assertStartsWith(captured.contents(), "This was printed.")
2517 ```
2519 Args:
2520 stream: The stream whose writes should be captured. This stream must have
2521 a file descriptor, support writing via using that file descriptor, and
2522 must have a `.flush()` method.
2524 Yields:
2525 A `CapturedWrites` object that contains all writes to the specified stream
2526 made during this context.
2527 """
2528 stream.flush()
2529 fd = stream.fileno()
2530 tmp_file, tmp_file_path = tempfile.mkstemp(dir=self.get_temp_dir())
2531 orig_fd = os.dup(fd)
2532 os.dup2(tmp_file, fd)
2533 try:
2534 yield CapturedWrites(tmp_file_path)
2535 finally:
2536 os.close(tmp_file)
2537 os.dup2(orig_fd, fd)
2539 def _AssertProtoEquals(self, a, b, msg=None, relative_tolerance=None):
2540 """Asserts that a and b are the same proto.
2542 Uses ProtoEq() first, as it returns correct results
2543 for floating point attributes, and then use assertProtoEqual()
2544 in case of failure as it provides good error messages.
2546 Args:
2547 a: a proto.
2548 b: another proto.
2549 msg: Optional message to report on failure.
2550 relative_tolerance: float. The allowable difference between the two values
2551 being compared is determined by multiplying the relative tolerance by
2552 the maximum of the two values. If this is not provided, then all floats
2553 are compared using string comparison.
2554 """
2555 if not compare.ProtoEq(a, b):
2556 compare.assertProtoEqual(
2557 self,
2558 a,
2559 b,
2560 normalize_numbers=True,
2561 msg=msg,
2562 relative_tolerance=relative_tolerance,
2563 )
2565 def assertProtoEquals(
2566 self,
2567 expected_message_maybe_ascii,
2568 message,
2569 msg=None,
2570 relative_tolerance=None,
2571 ):
2572 """Asserts that message is same as parsed expected_message_ascii.
2574 Creates another prototype of message, reads the ascii message into it and
2575 then compares them using self._AssertProtoEqual().
2577 Args:
2578 expected_message_maybe_ascii: proto message in original or ascii form.
2579 message: the message to validate.
2580 msg: Optional message to report on failure.
2581 relative_tolerance: float. The allowable difference between the two values
2582 being compared is determined by multiplying the relative tolerance by
2583 the maximum of the two values. If this is not provided, then all floats
2584 are compared using string comparison.
2585 """
2586 if isinstance(expected_message_maybe_ascii, type(message)):
2587 expected_message = expected_message_maybe_ascii
2588 self._AssertProtoEquals(
2589 expected_message,
2590 message,
2591 msg=msg,
2592 relative_tolerance=relative_tolerance,
2593 )
2594 elif isinstance(expected_message_maybe_ascii, (str, bytes)):
2595 expected_message = type(message)()
2596 text_format.Merge(
2597 expected_message_maybe_ascii,
2598 expected_message,
2599 descriptor_pool=descriptor_pool.Default())
2600 self._AssertProtoEquals(
2601 expected_message,
2602 message,
2603 msg=msg,
2604 relative_tolerance=relative_tolerance,
2605 )
2606 else:
2607 assert False, ("Can't compare protos of type %s and %s." %
2608 (type(expected_message_maybe_ascii), type(message)))
2610 def assertProtoEqualsVersion(
2611 self,
2612 expected,
2613 actual,
2614 producer=versions.GRAPH_DEF_VERSION,
2615 min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER,
2616 msg=None):
2617 expected = "versions { producer: %d min_consumer: %d };\n%s" % (
2618 producer, min_consumer, expected)
2619 self.assertProtoEquals(expected, actual, msg=msg)
2621 def assertStartsWith(self, actual, expected_start, msg=None):
2622 """Assert that actual.startswith(expected_start) is True.
2624 Args:
2625 actual: str
2626 expected_start: str
2627 msg: Optional message to report on failure.
2628 """
2629 if not actual.startswith(expected_start):
2630 fail_msg = "%r does not start with %r" % (actual, expected_start)
2631 fail_msg += " : %r" % (msg) if msg else ""
2632 self.fail(fail_msg)
  def _eval_tensor(self, tensor):
    """Converts a single eagerly computed value into NumPy form.

    Args:
      tensor: The value to convert. `None` is returned as-is. A callable is
        invoked and its result converted via `_eval_helper`. Sparse tensors,
        ragged tensors and `IndexedSlices` become `SparseTensorValue`,
        `RaggedTensorValue` and `IndexedSlicesValue` respectively. Anything
        else is converted with its `numpy()` method, or, if it has none, by
        mapping `numpy()` over its components with
        `nest.map_structure(..., expand_composites=True)`.

    Returns:
      The converted value, or `None`.

    Raises:
      ValueError: If converting a non-callable value raises AttributeError,
        e.g. because some component has no `numpy()` method.
    """
    if tensor is None:
      return None
    elif callable(tensor):
      return self._eval_helper(tensor())
    else:
      try:
        # for compatibility with TF1 test cases
        if sparse_tensor.is_sparse(tensor):
          return sparse_tensor.SparseTensorValue(tensor.indices.numpy(),
                                                 tensor.values.numpy(),
                                                 tensor.dense_shape.numpy())
        elif ragged_tensor.is_ragged(tensor):
          # values and row_splits are converted recursively, so ragged values
          # nested inside `values` are handled as well.
          return ragged_tensor_value.RaggedTensorValue(
              self._eval_tensor(tensor.values),
              self._eval_tensor(tensor.row_splits))
        elif isinstance(tensor, indexed_slices.IndexedSlices):
          return indexed_slices.IndexedSlicesValue(
              values=tensor.values.numpy(),
              indices=tensor.indices.numpy(),
              dense_shape=None
              if tensor.dense_shape is None else tensor.dense_shape.numpy())
        else:
          if hasattr(tensor, "numpy") and callable(tensor.numpy):
            return tensor.numpy()
          else:
            # Try our best to convert CompositeTensor components to NumPy
            # arrays. Officially, we don't support NumPy arrays as
            # CompositeTensor components. So don't be surprised if this doesn't
            # work.
            return nest.map_structure(lambda t: t.numpy(), tensor,
                                      expand_composites=True)
      except AttributeError as e:
        # Any AttributeError raised during conversion, including from the
        # recursive calls above, is reported as an unsupported type.
        raise ValueError(f"Unsupported type {type(tensor).__name__!r}.") from e
2669 def _eval_helper(self, tensors):
2670 if tensors is None:
2671 return None
2672 return nest.map_structure(self._eval_tensor, tensors)
2674 def evaluate(self, tensors):
2675 """Evaluates tensors and returns numpy values.
2677 Args:
2678 tensors: A Tensor or a nested list/tuple of Tensors.
2680 Returns:
2681 tensors numpy values.
2682 """
2683 if context.executing_eagerly():
2684 return self._eval_helper(tensors)
2685 else:
2686 sess = ops.get_default_session()
2687 if sess is None:
2688 with self.test_session() as sess:
2689 return sess.run(tensors)
2690 else:
2691 return sess.run(tensors)
2693 # pylint: disable=g-doc-return-or-yield
2694 @contextlib.contextmanager
2695 def session(self, graph=None, config=None, use_gpu=True, force_gpu=False):
2696 """A context manager for a TensorFlow Session for use in executing tests.
2698 Note that this will set this session and the graph as global defaults.
2700 Use the `use_gpu` and `force_gpu` options to control where ops are run. If
2701 `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
2702 `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
2703 possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
2704 the CPU.
2706 Example:
2708 ``` python
2709 class MyOperatorTest(test_util.TensorFlowTestCase):
2710 def testMyOperator(self):
2711 with self.session():
2712 valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
2713 result = MyOperator(valid_input).eval()
2714 self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
2715 invalid_input = [-1.0, 2.0, 7.0]
2716 with self.assertRaisesOpError("negative input not supported"):
2717 MyOperator(invalid_input).eval()
2718 ```
2720 Args:
2721 graph: Optional graph to use during the returned session.
2722 config: An optional config_pb2.ConfigProto to use to configure the
2723 session.
2724 use_gpu: If True, attempt to run as many ops as possible on GPU.
2725 force_gpu: If True, pin all ops to `/device:GPU:0`.
2727 Yields:
2728 A Session object that should be used as a context manager to surround
2729 the graph building and execution code in a test case.
2730 """
2731 if context.executing_eagerly():
2732 yield EagerSessionWarner()
2733 else:
2734 with self._create_session(graph, config, force_gpu) as sess:
2735 with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
2736 yield sess
2738 @contextlib.contextmanager
2739 def cached_session(self,
2740 graph=None,
2741 config=None,
2742 use_gpu=True,
2743 force_gpu=False):
2744 """Returns a TensorFlow Session for use in executing tests.
2746 This method behaves differently than self.session(): for performance reasons
2747 `cached_session` will by default reuse the same session within the same
2748 test. The session returned by this function will only be closed at the end
2749 of the test (in the TearDown function).
2751 Use the `use_gpu` and `force_gpu` options to control where ops are run. If
2752 `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
2753 `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
2754 possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
2755 the CPU.
2757 Example:
2758 ```python
2759 class MyOperatorTest(test_util.TensorFlowTestCase):
2760 def testMyOperator(self):
2761 with self.cached_session() as sess:
2762 valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
2763 result = MyOperator(valid_input).eval()
2764 self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
2765 invalid_input = [-1.0, 2.0, 7.0]
2766 with self.assertRaisesOpError("negative input not supported"):
2767 MyOperator(invalid_input).eval()
2768 ```
2770 Args:
2771 graph: Optional graph to use during the returned session.
2772 config: An optional config_pb2.ConfigProto to use to configure the
2773 session.
2774 use_gpu: If True, attempt to run as many ops as possible on GPU.
2775 force_gpu: If True, pin all ops to `/device:GPU:0`.
2777 Yields:
2778 A Session object that should be used as a context manager to surround
2779 the graph building and execution code in a test case.
2780 """
2781 if context.executing_eagerly():
2782 yield FakeEagerSession(self)
2783 else:
2784 sess = self._get_cached_session(
2785 graph, config, force_gpu, crash_if_inconsistent_args=True)
2786 with self._constrain_devices_and_set_default(sess, use_gpu,
2787 force_gpu) as cached:
2788 yield cached
2790 @contextlib.contextmanager
2791 @deprecation.deprecated(None, "Use `self.session()` or "
2792 "`self.cached_session()` instead.")
2793 def test_session(self,
2794 graph=None,
2795 config=None,
2796 use_gpu=True,
2797 force_gpu=False):
2798 """Use cached_session instead."""
2799 if self.id().endswith(".test_session"):
2800 self.skipTest(
2801 "Tests that have the name \"test_session\" are automatically skipped "
2802 "by TensorFlow test fixture, as the name is reserved for creating "
2803 "sessions within tests. Please rename your test if you have a test "
2804 "with this name.")
2805 if context.executing_eagerly():
2806 yield None
2807 else:
2808 if graph is None:
2809 sess = self._get_cached_session(
2810 graph, config, force_gpu, crash_if_inconsistent_args=False)
2811 with self._constrain_devices_and_set_default(sess, use_gpu,
2812 force_gpu) as cached:
2813 yield cached
2814 else:
2815 with self.session(graph, config, use_gpu, force_gpu) as sess:
2816 yield sess
2818 # pylint: enable=g-doc-return-or-yield
2820 class _CheckedThread(object):
2821 """A wrapper class for Thread that asserts successful completion.
2823 This class should be created using the TensorFlowTestCase.checkedThread()
2824 method.
2825 """
2827 def __init__(self, testcase, target, args=None, kwargs=None):
2828 """Constructs a new instance of _CheckedThread.
2830 Args:
2831 testcase: The TensorFlowTestCase for which this thread is being created.
2832 target: A callable object representing the code to be executed in the
2833 thread.
2834 args: A tuple of positional arguments that will be passed to target.
2835 kwargs: A dictionary of keyword arguments that will be passed to target.
2836 """
2837 self._testcase = testcase
2838 self._target = target
2839 self._args = () if args is None else args
2840 self._kwargs = {} if kwargs is None else kwargs
2841 self._thread = threading.Thread(target=self._protected_run)
2842 self._exception = None
2844 self._is_thread_joined = False
2846 def _protected_run(self):
2847 """Target for the wrapper thread. Sets self._exception on failure."""
2848 try:
2849 self._target(*self._args, **self._kwargs)
2850 except Exception as e: # pylint: disable=broad-except
2851 self._exception = e
2853 def start(self):
2854 """Starts the thread's activity.
2856 This must be called at most once per _CheckedThread object. It arranges
2857 for the object's target to be invoked in a separate thread of control.
2858 """
2859 self._thread.start()
2861 def join(self):
2862 """Blocks until the thread terminates.
2864 Raises:
2865 self._testcase.failureException: If the thread terminates with due to
2866 an exception.
2867 """
2868 self._is_thread_joined = True
2869 self._thread.join()
2870 if self._exception is not None:
2871 self._testcase.fail("Error in checkedThread: %s" % str(self._exception))
2873 def is_alive(self):
2874 """Returns whether the thread is alive.
2876 This method returns True just before the run() method starts
2877 until just after the run() method terminates.
2879 Returns:
2880 True if the thread is alive, otherwise False.
2881 """
2882 return self._thread.is_alive()
2884 def check_termination(self):
2885 """Returns whether the checked thread was properly used and did terminate.
2887 Every checked thread should be "join"ed after starting, and before the
2888 test tears down. If it is not joined, it is possible the thread will hang
2889 and cause flaky failures in tests.
2891 Raises:
2892 self._testcase.failureException: If check_termination was called before
2893 thread was joined.
2895 RuntimeError: If the thread is not terminated. This means thread was not
2896 joined with the main thread.
2897 """
2898 if self._is_thread_joined:
2899 if self.is_alive():
2900 raise RuntimeError(
2901 "Thread was not joined with main thread, and is still running "
2902 "when the test finished.")
2903 else:
2904 self._testcase.fail("A checked thread was not joined.")
2906 def checkedThread(self, target, args=None, kwargs=None):
2907 """Returns a Thread wrapper that asserts 'target' completes successfully.
2909 This method should be used to create all threads in test cases, as
2910 otherwise there is a risk that a thread will silently fail, and/or
2911 assertions made in the thread will not be respected.
2913 Args:
2914 target: A callable object to be executed in the thread.
2915 args: The argument tuple for the target invocation. Defaults to ().
2916 kwargs: A dictionary of keyword arguments for the target invocation.
2917 Defaults to {}.
2919 Returns:
2920 A wrapper for threading.Thread that supports start() and join() methods.
2921 """
2922 ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
2923 self._threads.append(ret)
2924 return ret
2926 # pylint: enable=invalid-name
2927 @py_func_if_in_function
2928 def assertNear(self, f1, f2, err, msg=None):
2929 """Asserts that two floats are near each other.
2931 Checks that |f1 - f2| < err and asserts a test failure
2932 if not.
2934 Args:
2935 f1: A float value.
2936 f2: A float value.
2937 err: A float value.
2938 msg: An optional string message to append to the failure message.
2939 """
2940 # f1 == f2 is needed here as we might have: f1, f2 = inf, inf
2941 self.assertTrue(
2942 f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" %
2943 (f1, f2, err, " (%s)" % msg if msg is not None else ""))
2945 @py_func_if_in_function
2946 def assertArrayNear(self, farray1, farray2, err, msg=None):
2947 """Asserts that two float arrays are near each other.
2949 Checks that for all elements of farray1 and farray2
2950 |f1 - f2| < err. Asserts a test failure if not.
2952 Args:
2953 farray1: a list of float values.
2954 farray2: a list of float values.
2955 err: a float value.
2956 msg: Optional message to report on failure.
2957 """
2958 self.assertEqual(len(farray1), len(farray2), msg=msg)
2959 for f1, f2 in zip(farray1, farray2):
2960 self.assertNear(float(f1), float(f2), err, msg=msg)
2962 def _NDArrayNear(self, ndarray1, ndarray2, err):
2963 return np.linalg.norm(ndarray1 - ndarray2) < err
2965 @py_func_if_in_function
2966 def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
2967 """Asserts that two numpy arrays have near values.
2969 Args:
2970 ndarray1: a numpy ndarray.
2971 ndarray2: a numpy ndarray.
2972 err: a float. The maximum absolute difference allowed.
2973 msg: Optional message to report on failure.
2974 """
2975 self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)
2977 def _GetNdArray(self, a):
2978 # If a is tensor-like then convert it to ndarray
2979 if tensor_util.is_tf_type(a):
2980 if isinstance(a, ops._EagerTensorBase):
2981 a = a.numpy()
2982 else:
2983 a = self.evaluate(a)
2984 if not isinstance(a, np.ndarray):
2985 try:
2986 return np.array(a)
2987 except ValueError as e:
2988 # TODO(b/264461299): NumPy 1.24 no longer infers dtype=object from
2989 # ragged sequences.
2990 # See:
2991 # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
2992 # Fixing this correctly requires clarifying the API contract of this
2993 # function with respect to ragged sequences and possibly updating all
2994 # users. As a backwards compatibility measure, if array
2995 # creation fails with an "inhomogeneous shape" error, try again with
2996 # an explicit dtype=object, which should restore the previous behavior.
2997 if "inhomogeneous shape" in str(e):
2998 return np.array(a, dtype=object)
2999 else:
3000 raise
3001 return a
3003 def evaluate_if_both_tensors(self, a, b):
3004 if (tensor_util.is_tf_type(a) and tensor_util.is_tf_type(b) and
3005 not isinstance(a, ops._EagerTensorBase) and
3006 not isinstance(b, ops._EagerTensorBase)):
3007 return self.evaluate((a, b))
3008 else:
3009 return (a, b)
3011 def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
3012 (a, b) = self.evaluate_if_both_tensors(a, b)
3013 a = self._GetNdArray(a)
3014 b = self._GetNdArray(b)
3015 # When the array rank is small, print its contents. Numpy array printing is
3016 # implemented using inefficient recursion so prints can cause tests to
3017 # time out.
3018 if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
3019 shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
3020 "%s.") % (a.shape, b.shape, b)
3021 else:
3022 shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
3023 b.shape)
3024 self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
3026 msgs = [msg]
3027 # np.allclose does not always work for our custom bfloat16 and float8
3028 # extension types when type promotions are involved, so we first cast any
3029 # arrays of such types to float32.
3030 a_dtype = a.dtype
3031 custom_dtypes = (dtypes.bfloat16.as_numpy_dtype,
3032 dtypes.float8_e5m2.as_numpy_dtype,
3033 dtypes.float8_e4m3fn.as_numpy_dtype)
3034 a = a.astype(np.float32) if a.dtype in custom_dtypes else a
3035 b = b.astype(np.float32) if b.dtype in custom_dtypes else b
3036 if not np.allclose(a, b, rtol=rtol, atol=atol):
3037 # Adds more details to np.testing.assert_allclose.
3038 #
3039 # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
3040 # checks whether two arrays are element-wise equal within a
3041 # tolerance. The relative difference (rtol * abs(b)) and the
3042 # absolute difference atol are added together to compare against
3043 # the absolute difference between a and b. Here, we want to
3044 # tell user which elements violate such conditions.
3045 cond = np.logical_or(
3046 np.abs(a - b) > atol + rtol * np.abs(b),
3047 np.isnan(a) != np.isnan(b))
3048 if a.ndim:
3049 x = a[np.where(cond)]
3050 y = b[np.where(cond)]
3051 msgs.append("not close where = {}".format(np.where(cond)))
3052 else:
3053 # np.where is broken for scalars
3054 x, y = a, b
3055 msgs.append("not close lhs = {}".format(x))
3056 msgs.append("not close rhs = {}".format(y))
3057 msgs.append("not close dif = {}".format(np.abs(x - y)))
3058 msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
3059 msgs.append("dtype = {}, shape = {}".format(a_dtype, a.shape))
3060 # TODO(xpan): There seems to be a bug:
3061 # tensorflow/compiler/tests:binary_ops_test pass with float32
3062 # nan even though the equal_nan is False by default internally.
3063 np.testing.assert_allclose(
3064 a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)
3066 def _assertAllCloseRecursive(self,
3067 a,
3068 b,
3069 rtol=1e-6,
3070 atol=1e-6,
3071 path=None,
3072 msg=None):
3073 if ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b):
3074 return self._assertRaggedClose(a, b, rtol, atol, msg)
3075 path = path or []
3076 path_str = (("[" + "][".join(str(p) for p in path) + "]") if path else "")
3077 msg = msg if msg else ""
3079 # Check if a and/or b are namedtuples.
3080 if hasattr(a, "_asdict"):
3081 a = a._asdict()
3082 if hasattr(b, "_asdict"):
3083 b = b._asdict()
3084 a_is_dict = isinstance(a, collections_abc.Mapping)
3085 if a_is_dict != isinstance(b, collections_abc.Mapping):
3086 raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
3087 (path_str, path_str, msg))
3088 if a_is_dict:
3089 self.assertItemsEqual(
3090 a.keys(),
3091 b.keys(),
3092 msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" %
3093 (path_str, a.keys(), path_str, b.keys(), msg))
3094 for k in a:
3095 path.append(k)
3096 self._assertAllCloseRecursive(
3097 a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg)
3098 del path[-1]
3099 elif isinstance(a, (list, tuple)):
3100 # Try to directly compare a, b as ndarrays; if not work, then traverse
3101 # through the sequence, which is more expensive.
3102 try:
3103 (a, b) = self.evaluate_if_both_tensors(a, b)
3104 a_as_ndarray = self._GetNdArray(a)
3105 b_as_ndarray = self._GetNdArray(b)
3106 self._assertArrayLikeAllClose(
3107 a_as_ndarray,
3108 b_as_ndarray,
3109 rtol=rtol,
3110 atol=atol,
3111 msg="Mismatched value: a%s is different from b%s. %s" %
3112 (path_str, path_str, msg))
3113 except (ValueError, TypeError, NotImplementedError) as e:
3114 if len(a) != len(b):
3115 raise ValueError(
3116 "Mismatched length: a%s has %d items, but b%s has %d items. %s" %
3117 (path_str, len(a), path_str, len(b), msg))
3118 for idx, (a_ele, b_ele) in enumerate(zip(a, b)):
3119 path.append(str(idx))
3120 self._assertAllCloseRecursive(
3121 a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg)
3122 del path[-1]
3123 # a and b are ndarray like objects
3124 else:
3125 try:
3126 self._assertArrayLikeAllClose(
3127 a,
3128 b,
3129 rtol=rtol,
3130 atol=atol,
3131 msg=("Mismatched value: a%s is different from b%s. %s" %
3132 (path_str, path_str, msg)))
3133 except TypeError as e:
3134 msg = ("Error: a%s has %s, but b%s has %s. %s" %
3135 (path_str, type(a), path_str, type(b), msg))
3136 e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
3137 raise
3139 @py_func_if_in_function
3140 def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
3141 """Asserts that two structures of numpy arrays or Tensors, have near values.
3143 `a` and `b` can be arbitrarily nested structures. A layer of a nested
3144 structure can be a `dict`, `namedtuple`, `tuple` or `list`.
3146 Note: the implementation follows
3147 [`numpy.allclose`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html)
3148 (and numpy.testing.assert_allclose). It checks whether two arrays are
3149 element-wise equal within a tolerance. The relative difference
3150 (`rtol * abs(b)`) and the absolute difference `atol` are added together
3151 to compare against the absolute difference between `a` and `b`.
3153 Args:
3154 a: The expected numpy `ndarray`, or anything that can be converted into a
3155 numpy `ndarray` (including Tensor), or any arbitrarily nested of
3156 structure of these.
3157 b: The actual numpy `ndarray`, or anything that can be converted into a
3158 numpy `ndarray` (including Tensor), or any arbitrarily nested of
3159 structure of these.
3160 rtol: relative tolerance.
3161 atol: absolute tolerance.
3162 msg: Optional message to report on failure.
3164 Raises:
3165 ValueError: if only one of `a[p]` and `b[p]` is a dict or
3166 `a[p]` and `b[p]` have different length, where `[p]` denotes a path
3167 to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
3168 `[p] = [1]['d']`, then `a[p] = (6, 7)`.
3169 """
3170 self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)
3172 @py_func_if_in_function
3173 def assertAllCloseAccordingToType(self,
3174 a,
3175 b,
3176 rtol=1e-6,
3177 atol=1e-6,
3178 float_rtol=1e-6,
3179 float_atol=1e-6,
3180 half_rtol=1e-3,
3181 half_atol=1e-3,
3182 bfloat16_rtol=1e-2,
3183 bfloat16_atol=1e-2,
3184 msg=None):
3185 """Like assertAllClose, but also suitable for comparing fp16 arrays.
3187 In particular, the tolerance is reduced to 1e-3 if at least
3188 one of the arguments is of type float16.
3190 Args:
3191 a: the expected numpy ndarray or anything can be converted to one.
3192 b: the actual numpy ndarray or anything can be converted to one.
3193 rtol: relative tolerance.
3194 atol: absolute tolerance.
3195 float_rtol: relative tolerance for float32.
3196 float_atol: absolute tolerance for float32.
3197 half_rtol: relative tolerance for float16.
3198 half_atol: absolute tolerance for float16.
3199 bfloat16_rtol: relative tolerance for bfloat16.
3200 bfloat16_atol: absolute tolerance for bfloat16.
3201 msg: Optional message to report on failure.
3202 """
3203 (a, b) = self.evaluate_if_both_tensors(a, b)
3204 a = self._GetNdArray(a)
3205 b = self._GetNdArray(b)
3206 # types with lower tol are put later to overwrite previous ones.
3207 if (a.dtype == np.float32 or b.dtype == np.float32 or
3208 a.dtype == np.complex64 or b.dtype == np.complex64):
3209 rtol = max(rtol, float_rtol)
3210 atol = max(atol, float_atol)
3211 if a.dtype == np.float16 or b.dtype == np.float16:
3212 rtol = max(rtol, half_rtol)
3213 atol = max(atol, half_atol)
3214 if (a.dtype == dtypes.bfloat16.as_numpy_dtype or
3215 b.dtype == dtypes.bfloat16.as_numpy_dtype):
3216 rtol = max(rtol, bfloat16_rtol)
3217 atol = max(atol, bfloat16_atol)
3219 self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
3221 @py_func_if_in_function
3222 def assertNotAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
3223 """Assert that two numpy arrays, or Tensors, do not have near values.
3225 Args:
3226 a: The expected numpy `ndarray`, or anything that can be converted into a
3227 numpy `ndarray` (including Tensor), or any arbitrarily nested of
3228 structure of these.
3229 b: The actual numpy `ndarray`, or anything that can be converted into a
3230 numpy `ndarray` (including Tensor), or any arbitrarily nested of
3231 structure of these.
3232 rtol: relative tolerance.
3233 atol: absolute tolerance.
3234 msg: Optional message to report on failure.
3236 Raises:
3237 AssertionError: If `a` and `b` are unexpectedly close at all elements.
3238 """
3239 try:
3240 self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
3241 except AssertionError:
3242 return
3243 msg = msg or ""
3244 raise AssertionError("The two values are close at all elements. %s" % msg)
3246 @py_func_if_in_function
3247 def assertAllEqual(self, a, b, msg=None):
3248 """Asserts that two numpy arrays or Tensors have the same values.
3250 Args:
3251 a: the expected numpy ndarray or anything can be converted to one.
3252 b: the actual numpy ndarray or anything can be converted to one.
3253 msg: Optional message to report on failure.
3254 """
3255 if (ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b)):
3256 return self._assertRaggedEqual(a, b, msg)
3257 msg = msg if msg else ""
3258 (a, b) = self.evaluate_if_both_tensors(a, b)
3259 a = self._GetNdArray(a)
3260 b = self._GetNdArray(b)
3261 # Arbitrary bounds so that we don't print giant tensors.
3262 if (b.ndim <= 3 or b.size < 500):
3263 self.assertEqual(
3264 a.shape, b.shape, "Shape mismatch: expected %s, got %s."
3265 " Contents: %r. \n%s." % (a.shape, b.shape, b, msg))
3266 else:
3267 self.assertEqual(
3268 a.shape, b.shape, "Shape mismatch: expected %s, got %s."
3269 " %s" % (a.shape, b.shape, msg))
3271 same = (a == b)
3273 if dtypes.as_dtype(a.dtype).is_floating:
3274 same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
3275 msgs = [msg]
3276 if not np.all(same):
3277 # Adds more details to np.testing.assert_array_equal.
3278 diff = np.logical_not(same)
3279 if a.ndim:
3280 x = a[np.where(diff)]
3281 y = b[np.where(diff)]
3282 msgs.append("not equal where = {}".format(np.where(diff)))
3283 else:
3284 # np.where is broken for scalars
3285 x, y = a, b
3286 msgs.append("not equal lhs = %r" % x)
3287 msgs.append("not equal rhs = %r" % y)
3289 if (a.dtype.kind != b.dtype.kind and
3290 {a.dtype.kind, b.dtype.kind}.issubset({"U", "S", "O"})):
3291 a_list = []
3292 b_list = []
3293 # OK to flatten `a` and `b` because they are guaranteed to have the
3294 # same shape.
3295 for out_list, flat_arr in [(a_list, a.flat), (b_list, b.flat)]:
3296 for item in flat_arr:
3297 if isinstance(item, str):
3298 out_list.append(item.encode("utf-8"))
3299 else:
3300 out_list.append(item)
3301 a = np.array(a_list)
3302 b = np.array(b_list)
3304 np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
3306 @py_func_if_in_function
3307 def assertNotAllEqual(self, a, b, msg=None):
3308 """Asserts that two numpy arrays or Tensors do not have the same values.
3310 Args:
3311 a: the expected numpy ndarray or anything can be converted to one.
3312 b: the actual numpy ndarray or anything can be converted to one.
3313 msg: Optional message to report on failure.
3314 """
3315 try:
3316 self.assertAllEqual(a, b)
3317 except AssertionError:
3318 return
3319 msg = msg or ""
3320 raise AssertionError("The two values are equal at all elements. %s" % msg)
3322 @py_func_if_in_function
3323 def assertAllGreater(self, a, comparison_target):
3324 """Assert element values are all greater than a target value.
3326 Args:
3327 a: The numpy `ndarray`, or anything that can be converted into a numpy
3328 `ndarray` (including Tensor).
3329 comparison_target: The target value of comparison.
3330 """
3331 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target)
3332 a = self._GetNdArray(a)
3333 self.assertGreater(np.min(a), comparison_target)
3335 @py_func_if_in_function
3336 def assertAllLess(self, a, comparison_target):
3337 """Assert element values are all less than a target value.
3339 Args:
3340 a: The numpy `ndarray`, or anything that can be converted into a numpy
3341 `ndarray` (including Tensor).
3342 comparison_target: The target value of comparison.
3343 """
3344 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target)
3345 a = self._GetNdArray(a)
3346 self.assertLess(np.max(a), comparison_target)
3348 @py_func_if_in_function
3349 def assertAllGreaterEqual(self, a, comparison_target):
3350 """Assert element values are all greater than or equal to a target value.
3352 Args:
3353 a: The numpy `ndarray`, or anything that can be converted into a numpy
3354 `ndarray` (including Tensor).
3355 comparison_target: The target value of comparison.
3356 """
3357 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target)
3358 a = self._GetNdArray(a)
3359 self.assertGreaterEqual(np.min(a), comparison_target)
3361 @py_func_if_in_function
3362 def assertAllLessEqual(self, a, comparison_target):
3363 """Assert element values are all less than or equal to a target value.
3365 Args:
3366 a: The numpy `ndarray`, or anything that can be converted into a numpy
3367 `ndarray` (including Tensor).
3368 comparison_target: The target value of comparison.
3369 """
3370 (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target)
3371 a = self._GetNdArray(a)
3372 self.assertLessEqual(np.max(a), comparison_target)
3374 def _format_subscripts(self, subscripts, value, limit=10, indent=2):
3375 """Generate a summary of ndarray subscripts as a list of str.
3377 If limit == N, this method will print up to the first N subscripts on
3378 separate
3379 lines. A line of ellipses (...) will be appended at the end if the number of
3380 subscripts exceeds N.
3382 Args:
3383 subscripts: The tensor (np.ndarray) subscripts, of the same format as
3384 np.where()'s return value, i.e., a tuple of arrays with each array
3385 corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])).
3386 value: (np.ndarray) value of the tensor.
3387 limit: (int) The maximum number of indices to print.
3388 indent: (int) Number of characters to indent at the beginning of each
3389 line.
3391 Returns:
3392 (list of str) the multi-line representation of the subscripts and values,
3393 potentially with omission at the end.
3394 """
3395 lines = []
3396 subscripts = np.transpose(subscripts)
3397 prefix = " " * indent
3398 if np.ndim(value) == 0:
3399 return [prefix + "[0] : " + str(value)]
3400 for subscript in itertools.islice(subscripts, limit):
3401 lines.append(prefix + str(subscript) + " : " +
3402 str(value[tuple(subscript)]))
3403 if len(subscripts) > limit:
3404 lines.append(prefix + "...")
3405 return lines
3407 @py_func_if_in_function
3408 def assertAllInRange(self,
3409 target,
3410 lower_bound,
3411 upper_bound,
3412 open_lower_bound=False,
3413 open_upper_bound=False):
3414 """Assert that elements in a Tensor are all in a given range.
3416 Args:
3417 target: The numpy `ndarray`, or anything that can be converted into a
3418 numpy `ndarray` (including Tensor).
3419 lower_bound: lower bound of the range
3420 upper_bound: upper bound of the range
3421 open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather
3422 than the default >=)
3423 open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather
3424 than the default <=)
3426 Raises:
3427 AssertionError:
3428 if the value tensor does not have an ordered numeric type (float* or
3429 int*), or
3430 if there are nan values, or
3431 if any of the elements do not fall in the specified range.
3432 """
3433 target = self._GetNdArray(target)
3434 if not (np.issubdtype(target.dtype, np.floating) or
3435 np.issubdtype(target.dtype, np.integer)):
3436 raise AssertionError(
3437 "The value of %s does not have an ordered numeric type, instead it "
3438 "has type: %s" % (target, target.dtype))
3440 nan_subscripts = np.where(np.isnan(target))
3441 if np.size(nan_subscripts):
3442 raise AssertionError(
3443 "%d of the %d element(s) are NaN. "
3444 "Subscripts(s) and value(s) of the NaN element(s):\n" %
3445 (len(nan_subscripts[0]), np.size(target)) +
3446 "\n".join(self._format_subscripts(nan_subscripts, target)))
3448 range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " +
3449 str(upper_bound) + (")" if open_upper_bound else "]"))
3451 violations = (
3452 np.less_equal(target, lower_bound) if open_lower_bound else np.less(
3453 target, lower_bound))
3454 violations = np.logical_or(
3455 violations,
3456 np.greater_equal(target, upper_bound)
3457 if open_upper_bound else np.greater(target, upper_bound))
3458 violation_subscripts = np.where(violations)
3459 if np.size(violation_subscripts):
3460 raise AssertionError(
3461 "%d of the %d element(s) are outside the range %s. " %
3462 (len(violation_subscripts[0]), np.size(target), range_str) +
3463 "Subscript(s) and value(s) of the offending elements:\n" +
3464 "\n".join(self._format_subscripts(violation_subscripts, target)))
3466 @py_func_if_in_function
3467 def assertAllInSet(self, target, expected_set):
3468 """Assert that elements of a Tensor are all in a given closed set.
3470 Args:
3471 target: The numpy `ndarray`, or anything that can be converted into a
3472 numpy `ndarray` (including Tensor).
3473 expected_set: (`list`, `tuple` or `set`) The closed set that the elements
3474 of the value of `target` are expected to fall into.
3476 Raises:
3477 AssertionError:
3478 if any of the elements do not fall into `expected_set`.
3479 """
3480 target = self._GetNdArray(target)
3482 # Elements in target that are not in expected_set.
3483 diff = np.setdiff1d(target.flatten(), list(expected_set))
3484 if np.size(diff):
3485 raise AssertionError("%d unique element(s) are not in the set %s: %s" %
3486 (np.size(diff), expected_set, diff))
3488 @py_func_if_in_function
3489 def assertDTypeEqual(self, target, expected_dtype):
3490 """Assert ndarray data type is equal to expected.
3492 Args:
3493 target: The numpy `ndarray`, or anything that can be converted into a
3494 numpy `ndarray` (including Tensor).
3495 expected_dtype: Expected data type.
3496 """
3497 target = self._GetNdArray(target)
3498 if not isinstance(target, list):
3499 arrays = [target]
3500 for arr in arrays:
3501 self.assertEqual(arr.dtype, expected_dtype)
3503 # pylint: disable=g-doc-return-or-yield
3504 @contextlib.contextmanager
3505 def assertRaisesWithPredicateMatch(self, exception_type,
3506 expected_err_re_or_predicate):
3507 """Returns a context manager to enclose code expected to raise an exception.
3509 If the exception is an OpError, the op stack is also included in the message
3510 predicate search.
3512 Args:
3513 exception_type: The expected type of exception that should be raised.
3514 expected_err_re_or_predicate: If this is callable, it should be a function
3515 of one argument that inspects the passed-in exception and returns True
3516 (success) or False (please fail the test). Otherwise, the error message
3517 is expected to match this regular expression partially.
3519 Returns:
3520 A context manager to surround code that is expected to raise an
3521 exception.
3522 """
3523 if callable(expected_err_re_or_predicate):
3524 predicate = expected_err_re_or_predicate
3525 else:
3527 def predicate(e):
3528 err_str = e.message if isinstance(e, errors.OpError) else str(e)
3529 op = e.op if isinstance(e, errors.OpError) else None
3530 while op is not None:
3531 err_str += "\nCaused by: " + op.name
3532 op = op._original_op # pylint: disable=protected-access
3533 logging.info("Searching within error strings: '%s' within '%s'",
3534 expected_err_re_or_predicate, err_str)
3535 return re.search(expected_err_re_or_predicate, err_str)
3537 try:
3538 yield
3539 self.fail(exception_type.__name__ + " not raised")
3540 except Exception as e: # pylint: disable=broad-except
3541 if not isinstance(e, exception_type) or not predicate(e):
3542 raise AssertionError("Exception of type %s: %s" %
3543 (str(type(e)), str(e)))
3545 # pylint: enable=g-doc-return-or-yield
3547 def assertRaisesOpError(self, expected_err_re_or_predicate):
3548 return self.assertRaisesWithPredicateMatch(errors.OpError,
3549 expected_err_re_or_predicate)
3551 def assertRaisesIncompatibleShapesError(
3552 self, exception_type=errors.InvalidArgumentError):
3553 return self.assertRaisesWithPredicateMatch(
3554 exception_type, r"Incompatible shapes|Dimensions must be equal|"
3555 r"required broadcastable shapes")
3557 def assertShapeEqual(self, input_a, input_b, msg=None):
3558 """Asserts that two Numpy or TensorFlow objects have the same shape.
3560 For Tensors, this compares statically known shapes at compile time, not
3561 dynamic shapes at runtime.
3563 Args:
3564 input_a: A Numpy ndarray, Numpy scalar, or a Tensor.
3565 input_b: A Numpy ndarray, Numpy scalar, or a Tensor.
3566 msg: Optional message to report on failure.
3568 Raises:
3569 TypeError: If the arguments have the wrong type.
3570 """
3571 if not isinstance(input_a, (np.ndarray, np.generic, ops.Tensor)):
3572 raise TypeError(
3573 "input_a must be a Numpy ndarray, Numpy scalar, or a Tensor."
3574 f"Instead received {type(input_a)}")
3575 if not isinstance(input_b, (np.ndarray, np.generic, ops.Tensor)):
3576 raise TypeError(
3577 "input_b must be a Numpy ndarray, Numpy scalar, or a Tensor."
3578 f"Instead received {type(input_b)}")
3579 shape_a = input_a.get_shape().as_list() if isinstance(
3580 input_a, ops.Tensor) else input_a.shape
3581 shape_b = input_b.get_shape().as_list() if isinstance(
3582 input_b, ops.Tensor) else input_b.shape
3583 self.assertAllEqual(shape_a, shape_b, msg=msg)
3585 def assertDeviceEqual(self, device1, device2, msg=None):
3586 """Asserts that the two given devices are the same.
3588 Args:
3589 device1: A string device name or TensorFlow `DeviceSpec` object.
3590 device2: A string device name or TensorFlow `DeviceSpec` object.
3591 msg: Optional message to report on failure.
3592 """
3593 device1 = pydev.canonical_name(device1)
3594 device2 = pydev.canonical_name(device2)
3595 self.assertEqual(
3596 device1, device2,
3597 "Devices %s and %s are not equal. %s" % (device1, device2, msg))
3599 @py_func_if_in_function
3600 def assertDictEqual(self, a, b, msg=None):
3601 """Assert that two given dictionary of tensors are the same.
3603 Args:
3604 a: Expected dictionary with numpy ndarray or anything else that can be
3605 converted to one as values.
3606 b: Actual dictionary with numpy ndarray or anything else that can be
3607 converted to one as values.
3608 msg: Optional message to report on failure.
3609 """
3610 # To keep backwards compatibility, we first try the base class
3611 # assertDictEqual. If that fails we try the tensorflow one.
3612 try:
3613 super().assertDictEqual(a, b, msg)
3614 except Exception: # pylint: disable=broad-except
3615 self.assertSameElements(a.keys(), b.keys()) # pylint: disable=g-assert-in-except
3616 for k, v in a.items():
3617 (a_k, b_k) = self.evaluate_if_both_tensors(v, b[k])
3618 a_k = self._GetNdArray(a_k)
3619 b_k = self._GetNdArray(b_k)
3620 if np.issubdtype(a_k.dtype, np.floating):
3621 self.assertAllClose(v, b[k], msg=k)
3622 else:
3623 self.assertAllEqual(v, b[k], msg=k)
3625 def _GetPyList(self, a):
3626 """Converts `a` to a nested python list."""
3627 if isinstance(a, ragged_tensor.RaggedTensor):
3628 return self.evaluate(a).to_list()
3629 elif isinstance(a, ops.Tensor):
3630 a = self.evaluate(a)
3631 return a.tolist() if isinstance(a, np.ndarray) else a
3632 elif isinstance(a, np.ndarray):
3633 return a.tolist()
3634 elif isinstance(a, ragged_tensor_value.RaggedTensorValue):
3635 return a.to_list()
3636 else:
3637 return np.array(a, dtype=object).tolist()
3639 def _assertRaggedEqual(self, a, b, msg):
3640 """Asserts that two ragged tensors are equal."""
3641 a_list = self._GetPyList(a)
3642 b_list = self._GetPyList(b)
3643 self.assertEqual(a_list, b_list, msg)
3645 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
3646 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
3647 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
3648 self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
3650 def _assertRaggedClose(self, a, b, rtol, atol, msg=None):
3651 a_list = self._GetPyList(a)
3652 b_list = self._GetPyList(b)
3653 self._assertListCloseRecursive(a_list, b_list, rtol, atol, msg)
3655 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
3656 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
3657 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
3658 self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
3660 def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"):
3661 self.assertEqual(type(a), type(b))
3662 if isinstance(a, (list, tuple)):
3663 self.assertLen(a, len(b), "Length differs for %s" % path)
3664 for i in range(len(a)):
3665 self._assertListCloseRecursive(a[i], b[i], rtol, atol, msg,
3666 "%s[%s]" % (path, i))
3667 else:
3668 self._assertAllCloseRecursive(a, b, rtol, atol, path, msg)
3670 # Fix Python 3+ compatibility issues
3671 # pylint: disable=invalid-name
3673 # Silence a deprecation warning
3674 assertRaisesRegexp = googletest.TestCase.assertRaisesRegex
3676 # assertItemsEqual is assertCountEqual as of 3.2.
3677 assertItemsEqual = googletest.TestCase.assertCountEqual
3679 # pylint: enable=invalid-name
3681 @contextlib.contextmanager
3682 def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
3683 """Set the session and its graph to global default and constrain devices."""
3684 if context.executing_eagerly():
3685 yield None
3686 else:
3687 with sess.graph.as_default(), sess.as_default():
3688 if force_gpu:
3689 # Use the name of an actual device if one is detected, or
3690 # '/device:GPU:0' otherwise
3691 gpu_name = gpu_device_name()
3692 if not gpu_name:
3693 gpu_name = "/device:GPU:0"
3694 with sess.graph.device(gpu_name):
3695 yield sess
3696 elif use_gpu:
3697 yield sess
3698 else:
3699 with sess.graph.device("/device:CPU:0"):
3700 yield sess
3702 def _create_session(self, graph, config, force_gpu):
3703 """See session() for details."""
3705 def prepare_config(config):
3706 """Returns a config for sessions.
3708 Args:
3709 config: An optional config_pb2.ConfigProto to use to configure the
3710 session.
3712 Returns:
3713 A config_pb2.ConfigProto object.
3714 """
3715 # TODO(b/114333779): Enforce allow_soft_placement=False when
3716 # use_gpu=False. Currently many tests rely on the fact that any device
3717 # will be used even when a specific device is supposed to be used.
3718 allow_soft_placement = not force_gpu
3719 if config is None:
3720 config = context.context().config
3721 config.allow_soft_placement = allow_soft_placement
3722 elif not allow_soft_placement and config.allow_soft_placement:
3723 config_copy = context.context().config
3724 config = config_copy
3725 config.allow_soft_placement = False
3726 # Don't perform optimizations for tests so we don't inadvertently run
3727 # gpu ops on cpu
3728 config.graph_options.optimizer_options.opt_level = -1
3729 # Disable Grappler constant folding since some tests & benchmarks
3730 # use constant input and become meaningless after constant folding.
3731 # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
3732 # GRAPPLER TEAM.
3733 config.graph_options.rewrite_options.constant_folding = (
3734 rewriter_config_pb2.RewriterConfig.OFF)
3735 config.graph_options.rewrite_options.pin_to_host_optimization = (
3736 rewriter_config_pb2.RewriterConfig.OFF)
3737 return config
3739 return ErrorLoggingSession(graph=graph, config=prepare_config(config))
3741 def _get_cached_session(self,
3742 graph=None,
3743 config=None,
3744 force_gpu=False,
3745 crash_if_inconsistent_args=True):
3746 """See cached_session() for documentation."""
3747 if self._cached_session is None:
3748 sess = self._create_session(
3749 graph=graph, config=config, force_gpu=force_gpu)
3750 self._cached_session = sess
3751 self._cached_graph = graph
3752 self._cached_config = config
3753 self._cached_force_gpu = force_gpu
3754 return sess
3755 else:
3756 if crash_if_inconsistent_args and self._cached_graph is not graph:
3757 raise ValueError("The graph used to get the cached session is "
3758 "different than the one that was used to create the "
3759 "session. Maybe create a new session with "
3760 "self.session()")
3761 if crash_if_inconsistent_args and self._cached_config is not config:
3762 raise ValueError("The config used to get the cached session is "
3763 "different than the one that was used to create the "
3764 "session. Maybe create a new session with "
3765 "self.session()")
3766 if crash_if_inconsistent_args and (self._cached_force_gpu is
3767 not force_gpu):
3768 raise ValueError(
3769 "The force_gpu value used to get the cached session is "
3770 "different than the one that was used to create the "
3771 "session. Maybe create a new session with "
3772 "self.session()")
3773 return self._cached_session
# Every port already returned by pick_unused_port() in this process; consulted
# so the same port is never handed out twice.
ASSIGNED_PORTS = set()
# Guards ASSIGNED_PORTS so concurrent pick_unused_port() calls cannot race.
lock = threading.Lock()


def pick_unused_port():
    """Returns an unused and unassigned local port.

    Candidates from `portpicker` are skipped until one is found that is above
    10000 and has not been returned before by this function.

    Returns:
      The chosen port number, which is also recorded in `ASSIGNED_PORTS`.

    Raises:
      ImportError: if the `portpicker` module is not installed (it is
        imported on each call rather than at module load time).
      unittest.SkipTest: if `portpicker` reports that no free port exists.
    """
    import portpicker  # pylint: disable=g-import-not-at-top

    with lock:
        while True:
            try:
                candidate = portpicker.pick_unused_port()
            except portpicker.NoFreePortFoundError as porterror:
                raise unittest.SkipTest("Flakes in portpicker library do not represent"
                                        " TensorFlow errors.") from porterror
            if candidate <= 10000 or candidate in ASSIGNED_PORTS:
                continue
            ASSIGNED_PORTS.add(candidate)
            logging.info("Using local port %r", candidate)
            return candidate
@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
                         num_ps,
                         protocol="grpc",
                         worker_config=None,
                         ps_config=None):
    """Create and start local servers and return the associated `Server` objects.

    "PS" stands for "parameter server": a task responsible for storing and
    updating the model's parameters. Other tasks send updates to these parameters
    as they work on optimizing the parameters. This particular division of labor
    between tasks is not required, but is common for distributed training.

    Read more at https://www.tensorflow.org/guide/extend/architecture


    Figure illustrates the interaction of these components.
    "/job:worker/task:0" and "/job:ps/task:0" are both tasks with worker services.


    Example:
    ```python
    workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)

    worker_sessions = [tf.compat.v1.Session(w.target) for w in workers]

    with tf.device("/job:ps/task:0"):
      ...
    with tf.device("/job:ps/task:1"):
      ...
    with tf.device("/job:worker/task:0"):
      ...
    with tf.device("/job:worker/task:1"):
      ...

    worker_sessions[0].run(...)
    ```

    Args:
      num_workers: Number of worker servers to start.
      num_ps: Number of PS servers to start.
      protocol: Communication protocol. Allowed values are documented in the
        documentation of `tf.distribute.Server`.
      worker_config: (optional) `tf.ConfigProto` to initialize workers. Can be
        used to instantiate multiple devices etc.
      ps_config: (optional) `tf.ConfigProto` to initialize PS servers.

    Returns:
      A tuple `(worker_servers, ps_servers)`. `worker_servers` is a list
      of `num_workers` objects of type `tf.distribute.Server` (all running
      locally);
      and `ps_servers` is a list of `num_ps` objects of similar type.

    Raises:
      ImportError: if the `portpicker` module is not installed. It is imported
        lazily when ports are picked, so this surfaces when this function is
        called, not at module load time.
      unittest.SkipTest: if `portpicker` fails to find a free port.
    """
    worker_ports = [pick_unused_port() for _ in range(num_workers)]
    ps_ports = [pick_unused_port() for _ in range(num_ps)]
    cluster_dict = {
        "worker": ["localhost:%s" % port for port in worker_ports],
        "ps": ["localhost:%s" % port for port in ps_ports],
    }
    cs = server_lib.ClusterSpec(cluster_dict)

    def start_servers(job_name, count, config):
        """Starts `count` servers for `job_name`, one per task index."""
        return [
            server_lib.Server(
                cs,
                job_name=job_name,
                protocol=protocol,
                task_index=ix,
                config=config,
                start=True) for ix in range(count)
        ]

    # Workers are started before PS servers, as before.
    workers = start_servers("worker", num_workers, worker_config)
    ps_servers = start_servers("ps", num_ps, ps_config)

    return workers, ps_servers
def get_node_def_from_graph(node_name, graph_def):
    """Looks up a `NodeDef` by name among the top-level nodes of a `GraphDef`.

    Only `graph_def.node` is scanned. Nested definitions (for example, function
    bodies in the library) are not searched.

    Args:
      node_name: Name of the NodeDef to search for.
      graph_def: An instance of `GraphDef` proto.

    Returns:
      The first `NodeDef` whose `name` equals `node_name`, or None if there is
      no such node.
    """
    matches = (candidate for candidate in graph_def.node
               if candidate.name == node_name)
    return next(matches, None)
def set_producer_version(graph, producer_version):
    """Sets graph.graph_def_versions.producer to `producer_version`.

    Args:
      graph: The graph whose producer version is updated.
      producer_version: The desired producer version.

    Raises:
      AssertionError: if, after the import, the graph's producer version is not
        `producer_version`. As with any `assert`, this check is skipped under
        `python -O`.
    """
    # The C API doesn't expose altering GraphDefVersions. We can indirectly set
    # it via import_graph_def though.
    graph_def = graph_pb2.GraphDef()
    graph_def.versions.producer = producer_version
    with graph.as_default():
        importer.import_graph_def(graph_def)
    # The original check was `assert producer, producer_version`. The comma
    # made `producer_version` the assert message, so nothing was compared.
    # NOTE(review): importing presumably merges version info instead of
    # overwriting it, so the requested value may not always take effect.
    # This check makes that failure visible.
    actual_producer = graph.graph_def_versions.producer
    assert actual_producer == producer_version, (
        "Expected graph producer version %s, got %s" %
        (producer_version, actual_producer))
3915@contextlib.contextmanager
3916def _fake_gradient_tape_context_manager():
3917 """tf.gradients(...) implemented as tf.GradientTape context manager interface.
3919 This is useful to test tf.gradients() in tests that uses tf.GradientTape().
3921 Yields:
3922 gradient tape instance that's implemented by tf.gradients() underneath.
3923 """
3924 try:
3925 class FakeGradientTape:
3927 def watch(self, x):
3928 pass
3930 def gradient(self, y, x, grad_ys=None):
3931 result = gradients_impl.gradients(y, x, grad_ys)
3933 # Unlike `tape.gradient()`, `tf.gradients()` returns a list for a single
3934 # element. So unpack if needed to match `tape.gradient()` behavior.
3935 if not isinstance(x, (list, tuple)):
3936 assert len(result) == 1
3937 return result[0]
3939 return result
3941 yield FakeGradientTape()
3942 finally:
3943 pass
class AbstractGradientTape:
    """Abstract GradientTape context manager that has multiple implementations.

    This is useful to test both tf.GradientTape() and tf.gradients() without
    duplicating tests.

    Args:
      use_tape: If True, entering creates a real `backprop.GradientTape`.
        Otherwise it uses a fake tape implemented with `tf.gradients()`.
      persistent: Forwarded to `backprop.GradientTape`. It has no effect on the
        fake tape, which is constructed without it.
    """

    def __init__(self, use_tape, persistent=False):
        self._use_tape = use_tape
        self._persistent = persistent
        # The concrete context manager; created fresh on every __enter__.
        self._tape_impl = None

    def __enter__(self):
        if self._use_tape:
            self._tape_impl = backprop.GradientTape(persistent=self._persistent)
        else:
            self._tape_impl = _fake_gradient_tape_context_manager()
        return self._tape_impl.__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Propagate the wrapped manager's result so that its exception
        # suppression decision (a truthy return) is honoured by the caller.
        return self._tape_impl.__exit__(exc_type, exc_val, exc_tb)
@contextlib.contextmanager
def run_functions_eagerly(run_eagerly):
    """Temporarily sets whether tf.functions run eagerly, then restores it.

    WARNING: Passing `run_eagerly=True` in tests that run in V1 graph mode
    *WILL NOT* make tf.function run eagerly, because eager execution is
    disabled by default in V1. In that case tf.function still runs as a traced
    graph function, and a warning is logged.

    The previous `def_function.RUN_FUNCTIONS_EAGERLY` state is restored on
    exit, even if the body raises.

    Args:
      run_eagerly: Boolean determining whether to run the function eagerly or
        not.

    Raises:
      ValueError: if `run_eagerly` is not a boolean (raised on entering the
        context).

    Yields:
      Nothing.
    """
    if not isinstance(run_eagerly, bool):
        raise ValueError(
            "Expected bool for `run_eagerly` but got {}".format(run_eagerly))

    in_eager_mode = context.executing_eagerly()
    if run_eagerly and not in_eager_mode:
        logging.warning(
            "Running tf.function eagerly in V1 graph mode is not supported. "
            "tf.function will be run as a traced graph function.")

    previous_setting = def_function.functions_run_eagerly()
    def_function.run_functions_eagerly(run_eagerly)
    try:
        yield
    finally:
        def_function.run_functions_eagerly(previous_setting)
class TestDelta:
    """Measures how much a named test counter has grown since the last reset.

    Args:
      name: Counter name passed to `_test_metrics_util.test_counter_value`.
      label: Counter label passed alongside `name`.
    """

    def __init__(self, name, label):
        self.name = name
        self.label = label
        self.Reset()

    def _read_counter(self):
        """Returns the counter's current value for (`name`, `label`)."""
        return _test_metrics_util.test_counter_value(self.name, self.label)

    def Reset(self):
        """Records the current counter value as the new baseline."""
        self.last_value = self._read_counter()

    def Get(self):
        """Returns the increase of the counter since the last `Reset()`."""
        current = self._read_counter()
        return current - self.last_value
@tf_export("test.experimental.sync_devices")
def sync_devices():
    """Synchronizes all devices.

    GPUs run asynchronously by default. When you run an op such as
    `tf.linalg.matmul` on the GPU, the op may still be executing on the device
    after the Python call returns. Non-GPU devices can also be made to run
    asynchronously by calling
    `tf.config.experimental.set_synchronous_execution(False)`. Calling
    `sync_devices()` blocks until all pending ops have finished executing. This
    is mainly useful for measuring performance in benchmarks.

    For example, here is how you can measure how long `tf.linalg.matmul` runs:

    >>> import time
    >>> x = tf.random.normal((4096, 4096))
    >>> tf.linalg.matmul(x, x)  # Warmup.
    >>> tf.test.experimental.sync_devices()  # Block until warmup has completed.
    >>>
    >>> start = time.time()
    >>> y = tf.linalg.matmul(x, x)
    >>> tf.test.experimental.sync_devices()  # Block until matmul has completed.
    >>> end = time.time()
    >>> print(f'Time taken: {end - start}')

    Without the call to `sync_devices()`, the printed time could be too small.
    The op could still be running asynchronously when `end = time.time()`
    executes.

    Raises:
      RuntimeError: If run outside Eager mode. This must be called in Eager
        mode, outside any `tf.function`s.
    """
    if not context.executing_eagerly():
        raise RuntimeError(
            "sync_devices() must only be called in Eager mode, outside tf.functions"
        )

    # TensorFlow has two sources of asynchrony:
    #
    # 1. On GPUs, kernels are launched on a CUDA stream, which is inherently
    #    asynchronous.
    # 2. `tf.config.experimental.set_synchronous_execution(False)` makes all
    #    ops asynchronous, and TensorFlow keeps internal queues of pending ops.
    #
    # The per-device SyncDevice op handles source (1), and context.async_wait()
    # handles source (2). SyncDevice must be issued before async_wait().
    # Otherwise the SyncDevice op itself could still be sitting in an internal
    # queue when this function returns.
    for logical_device in config.list_logical_devices():
        with ops.device(logical_device.name):
            gen_sync_ops.SyncDevice()
    context.async_wait()