Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/absl/testing/absltest.py: 21%
943 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The Abseil Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""Base functionality for Abseil Python tests.
17This module contains base classes and high-level functions for Abseil-style
18tests.
19"""
21from collections import abc
22import contextlib
23import difflib
24import enum
25import errno
26import faulthandler
27import getpass
28import inspect
29import io
30import itertools
31import json
32import os
33import random
34import re
35import shlex
36import shutil
37import signal
38import stat
39import subprocess
40import sys
41import tempfile
42import textwrap
43import typing
44from typing import Any, AnyStr, BinaryIO, Callable, ContextManager, IO, Iterator, List, Mapping, MutableMapping, MutableSequence, NoReturn, Optional, Sequence, Text, TextIO, Tuple, Type, Union
45import unittest
46from unittest import mock # pylint: disable=unused-import Allow absltest.mock.
47from urllib import parse
49from absl import app # pylint: disable=g-import-not-at-top
50from absl import flags
51from absl import logging
52from absl.testing import _pretty_print_reporter
53from absl.testing import xml_reporter
# Use an if-type-checking block to prevent leakage of type-checking only
# symbols. We don't want people relying on these at runtime.
if typing.TYPE_CHECKING:
  # Unbounded TypeVar for general usage (e.g. the `ContextManager[_T] -> _T`
  # type comments on TestCase.enter_context).
  _T = typing.TypeVar('_T')

  import unittest.case  # pylint: disable=g-import-not-at-top,g-bad-import-order

  # Type used in the `# type:` comment for TestCase._outcome; it names a
  # private unittest class, hence only imported for type checking.
  _OutcomeType = unittest.case._Outcome  # pytype: disable=module-attr


# pylint: enable=g-import-not-at-top

# Re-export a bunch of unittest functions we support so that people don't
# have to import unittest to get them
# pylint: disable=invalid-name
skip = unittest.skip
skipIf = unittest.skipIf
skipUnless = unittest.skipUnless
SkipTest = unittest.SkipTest
expectedFailure = unittest.expectedFailure
# pylint: enable=invalid-name

# End unittest re-exports

# Module-level alias for absl's global flag registry.
FLAGS = flags.FLAGS

# The text and binary string types, as a tuple usable with isinstance().
# NOTE(review): not referenced in this part of the file; presumably used by
# assertion helpers further down — confirm.
_TEXT_OR_BINARY_TYPES = (str, bytes)

# Suppress surplus entries in AssertionError stack traces.
# (unittest omits traceback frames whose module globals define `__unittest`.)
__unittest = True  # pylint: disable=invalid-name
def expectedFailureIf(condition, reason):  # pylint: disable=invalid-name
  """Expects the test to fail if the run condition is True.

  Example usage::

    @expectedFailureIf(sys.version_info.major == 2, "Not yet working in py2")
    def test_foo(self):
      ...

  Args:
    condition: bool, whether to expect failure or not.
    reason: Text, the reason to expect failure. It only documents intent at
      the call site; it is not forwarded to unittest.
  Returns:
    Decorator function: `unittest.expectedFailure` when `condition` is truthy,
    otherwise an identity decorator that returns the test unchanged.
  """
  del reason  # Unused
  if condition:
    return unittest.expectedFailure
  else:
    return lambda f: f
class TempFileCleanup(enum.Enum):
  """Policy for deleting paths made by create_tempfile()/create_tempdir().

  Temp paths live under absltest.TEST_TMPDIR.value, so keeping them around
  after a failure lets you inspect their contents.
  """

  ALWAYS = 'always'  # Delete once the test finishes, whatever its outcome.
  SUCCESS = 'success'  # Delete only when the test passes; keep on failure.
  OFF = 'never'  # Never delete.
121# Many of the methods in this module have names like assertSameElements.
122# This kind of name does not comply with PEP8 style,
123# but it is consistent with the naming of methods in unittest.py.
124# pylint: disable=invalid-name
def _get_default_test_random_seed():
  # type: () -> int
  """Returns int($TEST_RANDOM_SEED), or 301 when unset or not an integer."""
  raw_seed = os.environ.get('TEST_RANDOM_SEED', '')
  try:
    return int(raw_seed)
  except ValueError:
    return 301
def get_default_test_srcdir():
  # type: () -> Text
  """Returns $TEST_SRCDIR, or the empty string when it is not set."""
  env = os.environ
  return env['TEST_SRCDIR'] if 'TEST_SRCDIR' in env else ''
def get_default_test_tmpdir():
  # type: () -> Text
  """Returns $TEST_TMPDIR, or <system temp dir>/absl_testing if it is empty."""
  return (os.environ.get('TEST_TMPDIR', '')
          or os.path.join(tempfile.gettempdir(), 'absl_testing'))
def _get_default_randomize_ordering_seed():
  # type: () -> int
  """Returns default seed to use for randomizing test order.

  The value is taken from the --test_randomize_ordering_seed flag when it was
  given explicitly, otherwise from the TEST_RANDOMIZE_ORDERING_SEED
  environment variable. Its interpretation:
    * (not set) or empty: disable test randomization
    * 0: disable test randomization
    * 'random': choose a random seed in [1, 4294967295] for test order
      randomization
    * positive integer: use this seed for test order randomization

  (The values used are patterned after
  https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED).

  Returning None when nothing is overridden would be simpler, but the python
  random module has no `get_seed()`, only `getstate()`, which returns far more
  data than we want to pass via an environment variable or flag.

  Returns:
    A default value for test case randomization (int). 0 means do not
    randomize.

  Raises:
    ValueError: Raised when the flag or env value is not one of the options
      above.
  """
  if FLAGS['test_randomize_ordering_seed'].present:
    setting = FLAGS.test_randomize_ordering_seed
  else:
    setting = os.environ.get('TEST_RANDOMIZE_ORDERING_SEED', '')

  # Unset, empty and literal '0' all mean "keep the declared order".
  if not setting or setting == '0':
    return 0
  if setting == 'random':
    return random.Random().randint(1, 4294967295)

  try:
    parsed = int(setting)
  except ValueError:
    parsed = 0  # Not an integer; treated like a non-positive value below.
  if parsed > 0:
    return parsed
  raise ValueError(
      'Unknown test randomization seed value: {}'.format(setting))
# Flags shared by all absltest-based tests. Their defaults are computed once,
# when this module is imported.

# Root of the test source tree; defaults to $TEST_SRCDIR.
TEST_SRCDIR = flags.DEFINE_string(
    'test_srcdir',
    get_default_test_srcdir(),
    'Root of directory tree where source files live',
    allow_override_cpp=True)
# Base directory for test temp files (TestCase._get_tempdir_path_cls reads
# TEST_TMPDIR.value); defaults to $TEST_TMPDIR or <tempdir>/absl_testing.
TEST_TMPDIR = flags.DEFINE_string(
    'test_tmpdir',
    get_default_test_tmpdir(),
    'Directory for temporary testing files',
    allow_override_cpp=True)

# Defaults to int($TEST_RANDOM_SEED), or 301 when unset/unparsable.
flags.DEFINE_integer(
    'test_random_seed',
    _get_default_test_random_seed(),
    'Random seed for testing. Some test frameworks may '
    'change the default value of this flag between runs, so '
    'it is not appropriate for seeding probabilistic tests.',
    allow_override_cpp=True)
# Consulted by _get_default_randomize_ordering_seed() only when explicitly
# present; otherwise the TEST_RANDOMIZE_ORDERING_SEED env variable is used.
flags.DEFINE_string(
    'test_randomize_ordering_seed',
    '',
    'If positive, use this as a seed to randomize the '
    'execution order for test cases. If "random", pick a '
    'random seed to use. If 0 or not set, do not randomize '
    'test case execution order. This flag also overrides '
    'the TEST_RANDOMIZE_ORDERING_SEED environment variable.',
    allow_override_cpp=True)
# NOTE(review): not read in this part of the file; presumably consumed by the
# test runner to enable XML result reporting — confirm.
flags.DEFINE_string('xml_output_file', '', 'File to store XML test results')
# We might need to monkey-patch TestResult so that it stops considering an
# unexpected pass as a "successful result". For details, see
# http://bugs.python.org/issue20165
def _monkey_patch_test_result_for_unexpected_passes():
  # type: () -> None
  """Workaround for <http://bugs.python.org/issue20165>."""

  def wasSuccessful(self):
    # type: () -> bool
    """Reports success only if nothing failed, errored or passed unexpectedly.

    Args:
      self: The TestResult instance.

    Returns:
      True when there are no failures, errors or unexpected successes.
    """
    return not any(map(len, (self.failures, self.errors,
                             self.unexpectedSuccesses)))

  # Probe the current behavior with a result holding one unexpected success.
  probe = unittest.TestResult()
  probe.addUnexpectedSuccess(unittest.FunctionTestCase(lambda: None))
  if not probe.wasSuccessful():
    return  # Interpreter already counts unexpected passes as failures.

  unittest.TestResult.wasSuccessful = wasSuccessful
  if probe.wasSuccessful():  # The patch did not take effect; tell the user.
    sys.stderr.write('unittest.result.TestResult monkey patch to report'
                     ' unexpected passes as failures did not work.\n')


_monkey_patch_test_result_for_unexpected_passes()
def _open(filepath, mode, _open_func=open):
  # type: (Text, Text, Callable[..., IO]) -> IO
  """Opens `filepath` as UTF-8 via a captured reference to the real open().

  Binding the builtin as a default argument at definition time means real
  files can still be opened even if a test later stubs out open().

  Args:
    filepath: A filepath.
    mode: A mode.
    _open_func: A built-in open() function.

  Returns:
    The opened file object.
  """
  open_kwargs = {'encoding': 'utf-8'}
  return _open_func(filepath, mode, **open_kwargs)
class _TempDir(object):
  """A temporary directory belonging to a test.

  Only this module constructs instances; calling the public methods is fine.
  `__fspath__` makes instances `os.PathLike[str]`, so they can be handed
  directly to functions such as `os.path.join()`.
  """

  def __init__(self, path):
    # type: (Text) -> None
    """Module-private: do not instantiate outside module."""
    self._path = path

  @property
  def full_path(self):
    # type: () -> Text
    """The directory's path as a string.

    TIP: `os.path.join(temp_dir)` works just like
    `os.path.join(temp_dir.full_path)` because `__fspath__()` is implemented.
    """
    return self._path

  def __fspath__(self):
    # type: () -> Text
    """See os.PathLike."""
    return self.full_path

  def create_file(self, file_path=None, content=None, mode='w', encoding='utf8',
                  errors='strict'):
    # type: (Optional[Text], Optional[AnyStr], Text, Text, Text) -> _TempFile
    """Creates a file inside this directory.

    NOTE: An existing file at that path is made writable and overwritten.

    Args:
      file_path: Optional path of the file, relative to this directory. A
        unique name is generated when omitted. Slashes are allowed, and
        missing intermediate directories are created. NOTE: cleanup covers
        the whole leading path component, e.g. 'foo/bar/baz.txt' means
        `rm -r foo`.
      content: Optional text or bytes to write initially. Without it the
        file is created empty.
      mode: Mode used to write `content`; ignored when `content` is empty.
      encoding: Encoding for text `content`; ignored for bytes.
      errors: Text-to-bytes error handling for text `content`; ignored for
        bytes.

    Returns:
      A _TempFile for the new file.
    """
    created = _TempFile._create(self._path, file_path, content, mode, encoding,
                                errors)
    return created[0]

  def mkdir(self, dir_path=None):
    # type: (Optional[Text]) -> _TempDir
    """Creates a sub-directory of this directory.

    Args:
      dir_path: Optional relative path of the new directory; a unique name is
        generated when omitted.

    Returns:
      A _TempDir for the new (or already existing) directory.
    """
    path = (os.path.join(self._path, dir_path) if dir_path
            else tempfile.mkdtemp(dir=self._path))
    # Nothing is cleared here: the enclosing temp dir is expected to have been
    # cleared when it was created (e.g. by TestCase.create_tempdir).
    os.makedirs(path, exist_ok=True)
    return _TempDir(path)
class _TempFile(object):
  """A temporary file belonging to a test.

  Only this module constructs instances (through `_create`); calling the
  public methods is fine. `__fspath__` makes instances `os.PathLike[str]`, so
  they can be handed directly to functions such as `os.path.join()`.
  """

  def __init__(self, path):
    # type: (Text) -> None
    """Private: use _create instead."""
    self._path = path

  # pylint: disable=line-too-long
  @classmethod
  def _create(cls, base_path, file_path, content, mode, encoding, errors):
    # type: (Text, Optional[Text], AnyStr, Text, Text, Text) -> Tuple[_TempFile, Text]
    # pylint: enable=line-too-long
    """Module-private: create a tempfile instance.

    Args:
      base_path: Directory in which the file is created.
      file_path: Optional path relative to `base_path`. When falsy, a unique
        file is generated with `tempfile.mkstemp`.
      content: Optional text or bytes to write. Falsy content yields an empty
        file.
      mode: Mode used when writing non-empty `content`.
      encoding: Encoding used when `content` is text.
      errors: Encoding error handler used when `content` is text.

    Returns:
      A `(temp_file, cleanup_path)` tuple. `cleanup_path` is the generated file
      itself, or `base_path` joined with `_get_first_part(file_path)` (for
      'foo/bar/baz.txt' that is the 'foo' directory).
    """
    if not file_path:
      os.makedirs(base_path, exist_ok=True)
      fd, path = tempfile.mkstemp(dir=str(base_path))
      os.close(fd)
      cleanup_path = path
    else:
      cleanup_path = os.path.join(base_path, _get_first_part(file_path))
      path = os.path.join(base_path, file_path)
      os.makedirs(os.path.dirname(path), exist_ok=True)
      # A pre-existing read-only file must be made writable before it can be
      # truncated and rewritten below.
      if os.path.exists(path) and not os.access(path, os.W_OK):
        os.chmod(path, os.stat(path).st_mode | stat.S_IWUSR)

    temp_file = cls(path)
    if not content:
      temp_file.write_bytes(b'')
    elif isinstance(content, str):
      temp_file.write_text(content, mode=mode, encoding=encoding, errors=errors)
    else:
      temp_file.write_bytes(content, mode)
    return temp_file, cleanup_path

  @property
  def full_path(self):
    # type: () -> Text
    """The file's path as a string.

    TIP: `os.path.join(temp_file)` works just like
    `os.path.join(temp_file.full_path)` because `__fspath__()` is implemented.
    """
    return self._path

  def __fspath__(self):
    # type: () -> Text
    """See os.PathLike."""
    return self.full_path

  def read_text(self, encoding='utf8', errors='strict'):
    # type: (Text, Text) -> Text
    """Reads and returns the whole file decoded as text."""
    with self.open_text(encoding=encoding, errors=errors) as handle:
      contents = handle.read()
    return contents

  def read_bytes(self):
    # type: () -> bytes
    """Reads and returns the whole file as bytes."""
    with self.open_bytes() as handle:
      contents = handle.read()
    return contents

  def write_text(self, text, mode='w', encoding='utf8', errors='strict'):
    # type: (Text, Text, Text, Text) -> None
    """Writes text to the file.

    Args:
      text: Text to write.
      mode: Mode used to open the file; "t" is added when missing and "b" is
        rejected (see `open_text`).
      encoding: Encoding used to write the text.
      errors: How text-to-bytes encoding errors are handled.
    """
    with self.open_text(mode, encoding=encoding, errors=errors) as handle:
      handle.write(text)

  def write_bytes(self, data, mode='wb'):
    # type: (bytes, Text) -> None
    """Writes bytes to the file.

    Args:
      data: bytes to write.
      mode: Mode used to open the file; "b" is added when missing and "t" is
        rejected (see `open_bytes`).
    """
    with self.open_bytes(mode) as handle:
      handle.write(data)

  def open_text(self, mode='rt', encoding='utf8', errors='strict'):
    # type: (Text, Text, Text) -> ContextManager[TextIO]
    """Returns a context manager that opens the file in text mode.

    Args:
      mode: Open mode. "t" is appended when absent; "b" is not allowed.
      encoding: Encoding used to open the file.
      errors: How decoding errors are handled.

    Returns:
      Context manager that yields an open file.

    Raises:
      ValueError: if `mode` contains "b".
    """
    if 'b' in mode:
      raise ValueError('Invalid mode {!r}: "b" flag not allowed when opening '
                       'file in text mode'.format(mode))
    text_mode = mode if 't' in mode else mode + 't'
    return self._open(text_mode, encoding, errors)

  def open_bytes(self, mode='rb'):
    # type: (Text) -> ContextManager[BinaryIO]
    """Returns a context manager that opens the file in binary mode.

    Args:
      mode: Open mode. "b" is appended when absent; "t" is not allowed.

    Returns:
      Context manager that yields an open file.

    Raises:
      ValueError: if `mode` contains "t".
    """
    if 't' in mode:
      raise ValueError('Invalid mode {!r}: "t" flag not allowed when opening '
                       'file in binary mode'.format(mode))
    binary_mode = mode if 'b' in mode else mode + 'b'
    return self._open(binary_mode, encoding=None, errors=None)

  # TODO(b/123775699): Once pytype supports typing.Literal, use overload and
  # Literal to express more precise return types. The contained type is
  # currently `Any` to avoid [bad-return-type] errors in the open_* methods.
  @contextlib.contextmanager
  def _open(
      self,
      mode: str,
      encoding: Optional[str] = 'utf8',
      errors: Optional[str] = 'strict',
  ) -> Iterator[Any]:
    handle = io.open(
        self.full_path, mode=mode, encoding=encoding, errors=errors)
    with handle:
      yield handle
class _method(object):
  """A decorator that supports both instance and classmethod invocations.

  Using similar semantics to the @property builtin, this decorator can augment
  an instance method to support conditional logic when invoked on a class
  object. This breaks support for invoking an instance method via the class
  (e.g. Cls.method(self, ...)) but is still situationally useful.
  """

  def __init__(self, finstancemethod):
    # type: (Callable[..., Any]) -> None
    self._finstancemethod = finstancemethod
    self._fclassmethod = None
    # Store the documentation as a plain string on the instance. Defining
    # `__doc__` as a method in the class body would make `obj.__doc__` a bound
    # method and would also replace this class's own docstring.
    self.__doc__ = self._get_doc()

  def classmethod(self, fclassmethod):
    # type: (Callable[..., Any]) -> _method
    self._fclassmethod = classmethod(fclassmethod)
    self.__doc__ = self._get_doc()
    return self

  def _get_doc(self):
    # type: () -> str
    """Returns the instance method's docstring, else the classmethod's, else ''."""
    instance_doc = getattr(self._finstancemethod, '__doc__', None)
    if instance_doc:
      return instance_doc
    if self._fclassmethod is not None:
      # Read the wrapped function: before Python 3.10 a `classmethod` object's
      # own `__doc__` is the generic docstring of the classmethod type.
      class_doc = getattr(self._fclassmethod.__func__, '__doc__', None)
      if class_doc:
        return class_doc
    return ''

  def __get__(self, obj, type_):
    # type: (Optional[Any], Optional[Type[Any]]) -> Callable[..., Any]
    # Accessed on the class (obj is None) -> classmethod; else instance method.
    func = self._fclassmethod if obj is None else self._finstancemethod
    return func.__get__(obj, type_)  # pytype: disable=attribute-error
564class TestCase(unittest.TestCase):
565 """Extension of unittest.TestCase providing more power."""
567 # When to cleanup files/directories created by our `create_tempfile()` and
568 # `create_tempdir()` methods after each test case completes. This does *not*
569 # affect e.g., files created outside of those methods, e.g., using the stdlib
570 # tempfile module. This can be overridden at the class level, instance level,
571 # or with the `cleanup` arg of `create_tempfile()` and `create_tempdir()`. See
572 # `TempFileCleanup` for details on the different values.
573 # TODO(b/70517332): Remove the type comment and the disable once pytype has
574 # better support for enums.
575 tempfile_cleanup = TempFileCleanup.ALWAYS # type: TempFileCleanup # pytype: disable=annotation-type-mismatch
577 maxDiff = 80 * 20
578 longMessage = True
580 # Exit stacks for per-test and per-class scopes.
581 if sys.version_info < (3, 11):
582 _exit_stack = None
583 _cls_exit_stack = None
585 def __init__(self, *args, **kwargs):
586 super(TestCase, self).__init__(*args, **kwargs)
587 # This is to work around missing type stubs in unittest.pyi
588 self._outcome = getattr(self, '_outcome') # type: Optional[_OutcomeType]
590 def setUp(self):
591 super(TestCase, self).setUp()
592 # NOTE: Only Python 3 contextlib has ExitStack and
593 # Python 3.11+ already has enterContext.
594 if hasattr(contextlib, 'ExitStack') and sys.version_info < (3, 11):
595 self._exit_stack = contextlib.ExitStack()
596 self.addCleanup(self._exit_stack.close)
598 @classmethod
599 def setUpClass(cls):
600 super(TestCase, cls).setUpClass()
601 # NOTE: Only Python 3 contextlib has ExitStack, only Python 3.8+ has
602 # addClassCleanup and Python 3.11+ already has enterClassContext.
603 if (
604 hasattr(contextlib, 'ExitStack')
605 and hasattr(cls, 'addClassCleanup')
606 and sys.version_info < (3, 11)
607 ):
608 cls._cls_exit_stack = contextlib.ExitStack()
609 cls.addClassCleanup(cls._cls_exit_stack.close)
611 def create_tempdir(self, name=None, cleanup=None):
612 # type: (Optional[Text], Optional[TempFileCleanup]) -> _TempDir
613 """Create a temporary directory specific to the test.
615 NOTE: The directory and its contents will be recursively cleared before
616 creation. This ensures that there is no pre-existing state.
618 This creates a named directory on disk that is isolated to this test, and
619 will be properly cleaned up by the test. This avoids several pitfalls of
620 creating temporary directories for test purposes, as well as makes it easier
621 to setup directories and verify their contents. For example::
623 def test_foo(self):
624 out_dir = self.create_tempdir()
625 out_log = out_dir.create_file('output.log')
626 expected_outputs = [
627 os.path.join(out_dir, 'data-0.txt'),
628 os.path.join(out_dir, 'data-1.txt'),
629 ]
630 code_under_test(out_dir)
631 self.assertTrue(os.path.exists(expected_paths[0]))
632 self.assertTrue(os.path.exists(expected_paths[1]))
633 self.assertEqual('foo', out_log.read_text())
635 See also: :meth:`create_tempfile` for creating temporary files.
637 Args:
638 name: Optional name of the directory. If not given, a unique
639 name will be generated and used.
640 cleanup: Optional cleanup policy on when/if to remove the directory (and
641 all its contents) at the end of the test. If None, then uses
642 :attr:`tempfile_cleanup`.
644 Returns:
645 A _TempDir representing the created directory; see _TempDir class docs
646 for usage.
647 """
648 test_path = self._get_tempdir_path_test()
650 if name:
651 path = os.path.join(test_path, name)
652 cleanup_path = os.path.join(test_path, _get_first_part(name))
653 else:
654 os.makedirs(test_path, exist_ok=True)
655 path = tempfile.mkdtemp(dir=test_path)
656 cleanup_path = path
658 _rmtree_ignore_errors(cleanup_path)
659 os.makedirs(path, exist_ok=True)
661 self._maybe_add_temp_path_cleanup(cleanup_path, cleanup)
663 return _TempDir(path)
665 # pylint: disable=line-too-long
666 def create_tempfile(self, file_path=None, content=None, mode='w',
667 encoding='utf8', errors='strict', cleanup=None):
668 # type: (Optional[Text], Optional[AnyStr], Text, Text, Text, Optional[TempFileCleanup]) -> _TempFile
669 # pylint: enable=line-too-long
670 """Create a temporary file specific to the test.
672 This creates a named file on disk that is isolated to this test, and will
673 be properly cleaned up by the test. This avoids several pitfalls of
674 creating temporary files for test purposes, as well as makes it easier
675 to setup files, their data, read them back, and inspect them when
676 a test fails. For example::
678 def test_foo(self):
679 output = self.create_tempfile()
680 code_under_test(output)
681 self.assertGreater(os.path.getsize(output), 0)
682 self.assertEqual('foo', output.read_text())
684 NOTE: This will zero-out the file. This ensures there is no pre-existing
685 state.
686 NOTE: If the file already exists, it will be made writable and overwritten.
688 See also: :meth:`create_tempdir` for creating temporary directories, and
689 ``_TempDir.create_file`` for creating files within a temporary directory.
691 Args:
692 file_path: Optional file path for the temp file. If not given, a unique
693 file name will be generated and used. Slashes are allowed in the name;
694 any missing intermediate directories will be created. NOTE: This path is
695 the path that will be cleaned up, including any directories in the path,
696 e.g., ``'foo/bar/baz.txt'`` will ``rm -r foo``.
697 content: Optional string or
698 bytes to initially write to the file. If not
699 specified, then an empty file is created.
700 mode: Mode string to use when writing content. Only used if `content` is
701 non-empty.
702 encoding: Encoding to use when writing string content. Only used if
703 `content` is text.
704 errors: How to handle text to bytes encoding errors. Only used if
705 `content` is text.
706 cleanup: Optional cleanup policy on when/if to remove the directory (and
707 all its contents) at the end of the test. If None, then uses
708 :attr:`tempfile_cleanup`.
710 Returns:
711 A _TempFile representing the created file; see _TempFile class docs for
712 usage.
713 """
714 test_path = self._get_tempdir_path_test()
715 tf, cleanup_path = _TempFile._create(test_path, file_path, content=content,
716 mode=mode, encoding=encoding,
717 errors=errors)
718 self._maybe_add_temp_path_cleanup(cleanup_path, cleanup)
719 return tf
  @_method
  def enter_context(self, manager):
    # type: (ContextManager[_T]) -> _T
    """Returns the CM's value after registering it with the exit stack.

    Entering a context pushes it onto a stack of contexts. When `enter_context`
    is called on the test instance (e.g. `self.enter_context`), the context is
    exited after the test case's tearDown call. When called on the test class
    (e.g. `TestCase.enter_context`), the context is exited after the test
    class's tearDownClass call.

    Contexts are exited in the reverse order of entering. They will always
    be exited, regardless of test failure/success.

    This is useful to eliminate per-test boilerplate when context managers
    are used. For example, instead of decorating every test with `@mock.patch`,
    simply do `self.foo = self.enter_context(mock.patch(...))` in `setUp()`.

    NOTE: The context managers will always be exited without any error
    information. This is an unfortunate implementation detail due to some
    internals of how unittest runs tests.

    Args:
      manager: The context manager to enter.

    Returns:
      The value produced by entering `manager`.

    Raises:
      AssertionError: on Python < 3.11 when the per-test exit stack was not
        set up (i.e. TestCase.setUp() was not called).
    """
    # Python 3.11+ provides this natively on unittest.TestCase.
    if sys.version_info >= (3, 11):
      return self.enterContext(manager)

    # Older Pythons: use the ExitStack created in setUp().
    if not self._exit_stack:
      raise AssertionError(
          'self._exit_stack is not set: enter_context is Py3-only; also make '
          'sure that AbslTest.setUp() is called.')
    return self._exit_stack.enter_context(manager)

  # Class-level variant, used when called as `TestCase.enter_context(...)`;
  # contexts are exited after tearDownClass.
  @enter_context.classmethod
  def enter_context(cls, manager):  # pylint: disable=no-self-argument
    # type: (ContextManager[_T]) -> _T
    # Python 3.11+ provides this natively on unittest.TestCase.
    if sys.version_info >= (3, 11):
      return cls.enterClassContext(manager)

    # Older Pythons: use the ExitStack created in setUpClass().
    if not cls._cls_exit_stack:
      raise AssertionError(
          'cls._cls_exit_stack is not set: cls.enter_context requires '
          'Python 3.8+; also make sure that AbslTest.setUpClass() is called.')
    return cls._cls_exit_stack.enter_context(manager)
767 @classmethod
768 def _get_tempdir_path_cls(cls):
769 # type: () -> Text
770 return os.path.join(TEST_TMPDIR.value,
771 cls.__qualname__.replace('__main__.', ''))
773 def _get_tempdir_path_test(self):
774 # type: () -> Text
775 return os.path.join(self._get_tempdir_path_cls(), self._testMethodName)
777 def _get_tempfile_cleanup(self, override):
778 # type: (Optional[TempFileCleanup]) -> TempFileCleanup
779 if override is not None:
780 return override
781 return self.tempfile_cleanup
783 def _maybe_add_temp_path_cleanup(self, path, cleanup):
784 # type: (Text, Optional[TempFileCleanup]) -> None
785 cleanup = self._get_tempfile_cleanup(cleanup)
786 if cleanup == TempFileCleanup.OFF:
787 return
788 elif cleanup == TempFileCleanup.ALWAYS:
789 self.addCleanup(_rmtree_ignore_errors, path)
790 elif cleanup == TempFileCleanup.SUCCESS:
791 self._internal_add_cleanup_on_success(_rmtree_ignore_errors, path)
792 else:
793 raise AssertionError('Unexpected cleanup value: {}'.format(cleanup))
795 def _internal_add_cleanup_on_success(
796 self,
797 function: Callable[..., Any],
798 *args: Any,
799 **kwargs: Any,
800 ) -> None:
801 """Adds `function` as cleanup when the test case succeeds."""
802 outcome = self._outcome
803 assert outcome is not None
804 previous_failure_count = (
805 len(outcome.result.failures)
806 + len(outcome.result.errors)
807 + len(outcome.result.unexpectedSuccesses)
808 )
809 def _call_cleaner_on_success(*args, **kwargs):
810 if not self._internal_ran_and_passed_when_called_during_cleanup(
811 previous_failure_count):
812 return
813 function(*args, **kwargs)
814 self.addCleanup(_call_cleaner_on_success, *args, **kwargs)
  def _internal_ran_and_passed_when_called_during_cleanup(
      self,
      previous_failure_count: int,
  ) -> bool:
    """Returns whether test is passed. Expected to be called during cleanup.

    Args:
      previous_failure_count: Combined count of failures, errors and unexpected
        successes on the result when the cleanup was registered (see
        `_internal_add_cleanup_on_success`). Only used on Python 3.11+.

    Returns:
      True if this test has not (yet) recorded a failure, error or unexpected
      success.
    """
    outcome = self._outcome
    if sys.version_info[:2] >= (3, 11):
      assert outcome is not None
      # On 3.11+ problems are already on the result at cleanup time (see the
      # note in the else branch), so a passing test leaves the total unchanged.
      current_failure_count = (
          len(outcome.result.failures)
          + len(outcome.result.errors)
          + len(outcome.result.unexpectedSuccesses)
      )
      return current_failure_count == previous_failure_count
    else:
      # Before Python 3.11 https://github.com/python/cpython/pull/28180, errors
      # were buffered in _Outcome before calling cleanup.
      # Replay the buffered errors into a scratch result and ask it instead.
      # NOTE(review): unlike the branch above there is no `outcome is not None`
      # assertion here — presumably always set while cleanups run; confirm.
      result = self.defaultTestResult()
      self._feedErrorsToResult(result, outcome.errors)  # pytype: disable=attribute-error
      return result.wasSuccessful()
837 def shortDescription(self):
838 # type: () -> Text
839 """Formats both the test method name and the first line of its docstring.
841 If no docstring is given, only returns the method name.
843 This method overrides unittest.TestCase.shortDescription(), which
844 only returns the first line of the docstring, obscuring the name
845 of the test upon failure.
847 Returns:
848 desc: A short description of a test method.
849 """
850 desc = self.id()
852 # Omit the main name so that test name can be directly copy/pasted to
853 # the command line.
854 if desc.startswith('__main__.'):
855 desc = desc[len('__main__.'):]
857 # NOTE: super() is used here instead of directly invoking
858 # unittest.TestCase.shortDescription(self), because of the
859 # following line that occurs later on:
860 # unittest.TestCase = TestCase
861 # Because of this, direct invocation of what we think is the
862 # superclass will actually cause infinite recursion.
863 doc_first_line = super(TestCase, self).shortDescription()
864 if doc_first_line is not None:
865 desc = '\n'.join((desc, doc_first_line))
866 return desc
868 def assertStartsWith(self, actual, expected_start, msg=None):
869 """Asserts that actual.startswith(expected_start) is True.
871 Args:
872 actual: str
873 expected_start: str
874 msg: Optional message to report on failure.
875 """
876 if not actual.startswith(expected_start):
877 self.fail('%r does not start with %r' % (actual, expected_start), msg)
879 def assertNotStartsWith(self, actual, unexpected_start, msg=None):
880 """Asserts that actual.startswith(unexpected_start) is False.
882 Args:
883 actual: str
884 unexpected_start: str
885 msg: Optional message to report on failure.
886 """
887 if actual.startswith(unexpected_start):
888 self.fail('%r does start with %r' % (actual, unexpected_start), msg)
890 def assertEndsWith(self, actual, expected_end, msg=None):
891 """Asserts that actual.endswith(expected_end) is True.
893 Args:
894 actual: str
895 expected_end: str
896 msg: Optional message to report on failure.
897 """
898 if not actual.endswith(expected_end):
899 self.fail('%r does not end with %r' % (actual, expected_end), msg)
901 def assertNotEndsWith(self, actual, unexpected_end, msg=None):
902 """Asserts that actual.endswith(unexpected_end) is False.
904 Args:
905 actual: str
906 unexpected_end: str
907 msg: Optional message to report on failure.
908 """
909 if actual.endswith(unexpected_end):
910 self.fail('%r does end with %r' % (actual, unexpected_end), msg)
912 def assertSequenceStartsWith(self, prefix, whole, msg=None):
913 """An equality assertion for the beginning of ordered sequences.
915 If prefix is an empty sequence, it will raise an error unless whole is also
916 an empty sequence.
918 If prefix is not a sequence, it will raise an error if the first element of
919 whole does not match.
921 Args:
922 prefix: A sequence expected at the beginning of the whole parameter.
923 whole: The sequence in which to look for prefix.
924 msg: Optional message to report on failure.
925 """
926 try:
927 prefix_len = len(prefix)
928 except (TypeError, NotImplementedError):
929 prefix = [prefix]
930 prefix_len = 1
932 if isinstance(whole, abc.Mapping) or isinstance(whole, abc.Set):
933 self.fail(
934 'For whole: Mapping or Set objects are not supported, found type: %s'
935 % type(whole),
936 msg,
937 )
938 try:
939 whole_len = len(whole)
940 except (TypeError, NotImplementedError):
941 self.fail('For whole: len(%s) is not supported, it appears to be type: '
942 '%s' % (whole, type(whole)), msg)
944 assert prefix_len <= whole_len, self._formatMessage(
945 msg,
946 'Prefix length (%d) is longer than whole length (%d).' %
947 (prefix_len, whole_len)
948 )
950 if not prefix_len and whole_len:
951 self.fail('Prefix length is 0 but whole length is %d: %s' %
952 (len(whole), whole), msg)
954 try:
955 self.assertSequenceEqual(prefix, whole[:prefix_len], msg)
956 except AssertionError:
957 self.fail('prefix: %s not found at start of whole: %s.' %
958 (prefix, whole), msg)
def assertEmpty(self, container, msg=None):
  """Fails unless ``container`` has a length of zero.

  Args:
    container: Any object implementing ``collections.abc.Sized``.
    msg: Optional message to report on failure.
  """
  if not isinstance(container, abc.Sized):
    self.fail('Expected a Sized object, got: '
              '{!r}'.format(type(container).__name__), msg)

  # Compare the length rather than the object's truth value: some Sized
  # types (e.g. numpy.ndarray) have unusual __bool__ behavior.
  size = len(container)
  if size:
    self.fail('{!r} has length of {}.'.format(container, size), msg)
def assertNotEmpty(self, container, msg=None):
  """Fails if ``container`` has a length of zero.

  Args:
    container: Any object implementing ``collections.abc.Sized``.
    msg: Optional message to report on failure.
  """
  if not isinstance(container, abc.Sized):
    self.fail('Expected a Sized object, got: '
              '{!r}'.format(type(container).__name__), msg)

  # Compare the length rather than the object's truth value: some Sized
  # types (e.g. numpy.ndarray) have unusual __bool__ behavior.
  if len(container) == 0:
    self.fail('{!r} has length of 0.'.format(container), msg)
def assertLen(self, container, expected_len, msg=None):
  """Fails unless ``len(container) == expected_len``.

  Args:
    container: Any object implementing ``collections.abc.Sized``.
    expected_len: The length the container is expected to have.
    msg: Optional message to report on failure.
  """
  if not isinstance(container, abc.Sized):
    self.fail('Expected a Sized object, got: '
              '{!r}'.format(type(container).__name__), msg)
  actual_len = len(container)
  if actual_len != expected_len:
    # safe_repr guards against containers whose repr() raises.
    container_repr = unittest.util.safe_repr(container)  # pytype: disable=module-attr
    self.fail('{} has length of {}, expected {}.'.format(
        container_repr, actual_len, expected_len), msg)
def assertSequenceAlmostEqual(self, expected_seq, actual_seq, places=None,
                              msg=None, delta=None):
  """An approximate equality assertion for ordered sequences.

  Fail if the two sequences are unequal as determined by their value
  differences rounded to the given number of decimal places (default 7) and
  comparing to zero, or by comparing that the difference between each value
  in the two sequences is more than the given delta.

  Note that decimal places (from zero) are usually not the same as significant
  digits (measured from the most significant digit).

  If the two sequences compare equal then they will automatically compare
  almost equal.

  All per-element failures are collected and reported in one failure (the
  first 30, followed by '...'); ``msg`` is attached once to that report.

  Args:
    expected_seq: A sequence containing elements we are expecting.
    actual_seq: The sequence that we are testing.
    places: The number of decimal places to compare.
    msg: The message to be printed if the test fails.
    delta: The OK difference between compared values.
  """
  if len(expected_seq) != len(actual_seq):
    self.fail('Sequence size mismatch: {} vs {}'.format(
        len(expected_seq), len(actual_seq)), msg)

  max_reported_errors = 30
  err_list = []
  for idx, (exp_elem, act_elem) in enumerate(zip(expected_seq, actual_seq)):
    try:
      # assertAlmostEqual should be called with at most one of `places` and
      # `delta`. However, it's okay for assertSequenceAlmostEqual to pass
      # both because we want the latter to fail if the former does.
      # `msg` is intentionally not forwarded: it is attached once to the
      # combined report below instead of being repeated for every element.
      # pytype: disable=wrong-keyword-args
      self.assertAlmostEqual(exp_elem, act_elem, places=places, delta=delta)
      # pytype: enable=wrong-keyword-args
    except self.failureException as err:
      err_list.append('At index {}: {}'.format(idx, err))

  if err_list:
    if len(err_list) > max_reported_errors:
      err_list = err_list[:max_reported_errors] + ['...']
    self.fail('\n'.join(err_list), msg)
def assertContainsSubset(self, expected_subset, actual_set, msg=None):
  """Checks whether actual iterable is a superset of expected iterable.

  Both arguments are converted to sets, so their elements must be hashable.
  """
  expected_elements = set(expected_subset)
  actual_elements = set(actual_set)
  missing = expected_elements - actual_elements
  if missing:
    self.fail('Missing elements %s\nExpected: %s\nActual: %s' % (
        missing, expected_subset, actual_set), msg)
def assertNoCommonElements(self, expected_seq, actual_seq, msg=None):
  """Checks whether actual iterable and expected iterable are disjoint.

  Both arguments are converted to sets, so their elements must be hashable.
  """
  expected_elements = set(expected_seq)
  actual_elements = set(actual_seq)
  common = expected_elements & actual_elements
  if common:
    self.fail('Common elements %s\nExpected: %s\nActual: %s' % (
        common, expected_seq, actual_seq), msg)
def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
  """Deprecated alias of assertCountEqual; please call that instead.

  Args:
    expected_seq: A sequence containing elements we are expecting.
    actual_seq: The sequence that we are testing.
    msg: The message to be printed if the test fails.
  """
  # Dispatches through super(), so any assertCountEqual override defined on
  # this class (or on classes before it in the MRO) is bypassed.
  super().assertCountEqual(expected_seq, actual_seq, msg)
def assertSameElements(self, expected_seq, actual_seq, msg=None):
  """Asserts that two sequences have the same elements (in any order).

  This method, unlike assertCountEqual, doesn't care about any
  duplicates in the expected and actual sequences::

    # Doesn't raise an AssertionError
    assertSameElements([1, 1, 1, 0, 0, 0], [0, 1])

  If possible, you should use assertCountEqual instead of
  assertSameElements.

  Elements are compared by hash when all are hashable, otherwise by sorting,
  and finally by pairwise ``==`` when they are neither hashable nor
  orderable. The reported missing/unexpected lists are sorted when possible.

  Args:
    expected_seq: A sequence containing elements we are expecting.
    actual_seq: The sequence that we are testing.
    msg: The message to be printed if the test fails.
  """
  # `unittest2.TestCase` used to have assertSameElements, but it was
  # removed in favor of assertItemsEqual. As there's a unit test
  # that explicitly checks this behavior, I am leaving this method
  # alone.
  # Fail on strings: empirically, passing strings to this test method
  # is almost always a bug. If comparing the character sets of two strings
  # is desired, cast the inputs to sets or lists explicitly.
  if (isinstance(expected_seq, _TEXT_OR_BINARY_TYPES) or
      isinstance(actual_seq, _TEXT_OR_BINARY_TYPES)):
    self.fail('Passing string/bytes to assertSameElements is usually a bug. '
              'Did you mean to use assertEqual?\n'
              'Expected: %s\nActual: %s' % (expected_seq, actual_seq), msg)

  def _sorted_if_possible(items):
    """Returns items sorted, or unchanged if they are not mutually orderable."""
    try:
      return sorted(items)
    except TypeError:
      return items

  # Materialize once so one-shot iterators are not partially consumed by a
  # failed first attempt before a fallback gets to see them.
  expected = list(expected_seq)
  actual = list(actual_seq)
  try:
    expected_keys = dict.fromkeys(expected)
    actual_keys = dict.fromkeys(actual)
  except TypeError:
    # At least one element is unhashable.
    try:
      missing, unexpected = _sorted_list_difference(sorted(expected),
                                                    sorted(actual))
    except TypeError:
      # Neither hashable nor orderable: fall back to O(n*m) == comparisons,
      # de-duplicating the results like the other branches do.
      missing = []
      for element in expected:
        if element not in actual and element not in missing:
          missing.append(element)
      unexpected = []
      for element in actual:
        if element not in expected and element not in unexpected:
          unexpected.append(element)
  else:
    missing = _sorted_if_possible(
        [element for element in expected_keys if element not in actual_keys])
    unexpected = _sorted_if_possible(
        [element for element in actual_keys if element not in expected_keys])

  errors = []
  if msg:
    errors.extend((msg, ':\n'))
  if missing:
    errors.append('Expected, but missing:\n %r\n' % (missing,))
  if unexpected:
    errors.append('Unexpected, but present:\n %r\n' % (unexpected,))
  if missing or unexpected:
    self.fail(''.join(errors))
# unittest.TestCase.assertMultiLineEqual works very similarly, but it
# has a different error format. However, I find this slightly more readable.
def assertMultiLineEqual(self, first, second, msg=None, **kwargs):
  """Asserts that two multi-line strings are equal.

  On failure the message is a line-by-line ``difflib.ndiff`` of the inputs.

  Args:
    first: str, the expected value.
    second: str, the actual value.
    msg: Optional message placed before the diff on failure.
    **kwargs: Only ``line_limit`` is accepted. If non-zero, at most that many
      diff lines are shown, followed by a count of the omitted lines.

  Raises:
    TypeError: if an unsupported keyword argument is passed.
    self.failureException: if either argument is not a str, or the strings
      differ.
  """
  # Explicit checks instead of `assert`, which is removed under `python -O`.
  if not isinstance(first, str):
    self.fail('First argument is not a string: %r' % (first,))
  if not isinstance(second, str):
    self.fail('Second argument is not a string: %r' % (second,))
  line_limit = kwargs.pop('line_limit', 0)
  if kwargs:
    raise TypeError('Unexpected keyword args {}'.format(tuple(kwargs)))

  if first == second:
    return
  if msg:
    failure_message = [msg + ':\n']
  else:
    failure_message = ['\n']
  if line_limit:
    # The header entry does not count toward the caller's limit.
    line_limit += len(failure_message)
  for line in difflib.ndiff(first.splitlines(True), second.splitlines(True)):
    # Keep exactly one list entry per diff line so line_limit counts lines,
    # not the extra newline fragments.
    failure_message.append(line if line.endswith('\n') else line + '\n')
  if line_limit and len(failure_message) > line_limit:
    n_omitted = len(failure_message) - line_limit
    failure_message = failure_message[:line_limit]
    failure_message.append(
        '(... and {} more delta lines omitted for brevity.)\n'.format(
            n_omitted))

  raise self.failureException(''.join(failure_message))
def assertBetween(self, value, minv, maxv, msg=None):
  """Asserts that ``minv <= value`` and ``maxv >= value`` (bounds inclusive).

  Args:
    value: The value under test.
    minv: The inclusive lower bound.
    maxv: The inclusive upper bound.
    msg: Optional message to report on failure.
  """
  failure_msg = self._formatMessage(
      msg,
      '"%r" unexpectedly not between "%r" and "%r"' % (value, minv, maxv))
  lower_ok = minv <= value
  self.assertTrue(lower_ok, failure_msg)
  upper_ok = maxv >= value
  self.assertTrue(upper_ok, failure_msg)
def assertRegexMatch(self, actual_str, regexes, message=None):
  r"""Asserts that at least one regex in regexes matches str.

  If possible you should use `assertRegex`, which is a simpler
  version of this method. `assertRegex` takes a single regular
  expression (a string or re compiled object) instead of a list.

  Notes:

  1. This function uses substring matching, i.e. the matching
     succeeds if *any* substring of the error message matches *any*
     regex in the list. This is more convenient for the user than
     full-string matching.

  2. If regexes is the empty list, the matching will always fail.

  3. Use regexes=[''] for a regex that will always pass.

  4. '.' matches any single character *except* the newline. To
     match any character, use '(.|\n)'.

  5. '^' matches the beginning of each line, not just the beginning
     of the string. Similarly, '$' matches the end of each line.

  6. An exception will be thrown if regexes contains an invalid
     regex.

  7. regexes may be any iterable (e.g. a set or a generator); iterables
     that are not sequences are collected into a list first.

  Args:
    actual_str: The string we try to match with the items in regexes.
    regexes: The regular expressions we want to match against str.
      All entries must be of the same type (str or bytes); they are
      converted to the type of actual_str when the two differ.
      See "Notes" above for detailed notes on how this is interpreted.
    message: The message to be printed if the test fails.
  """
  if isinstance(regexes, _TEXT_OR_BINARY_TYPES):
    self.fail('regexes is string or bytes; use assertRegex instead.',
              message)
  if not isinstance(regexes, abc.Sequence):
    # The checks below need truthiness and indexing, which sets lack and
    # which a generator would fake (it is always truthy).
    regexes = list(regexes)
  if not regexes:
    self.fail('No regexes specified.', message)

  regex_type = type(regexes[0])
  for regex in regexes[1:]:
    if type(regex) is not regex_type:  # pylint: disable=unidiomatic-typecheck
      self.fail('regexes list must all be the same type.', message)

  if regex_type is bytes and isinstance(actual_str, str):
    regexes = [regex.decode('utf-8') for regex in regexes]
    regex_type = str
  elif regex_type is str and isinstance(actual_str, bytes):
    regexes = [regex.encode('utf-8') for regex in regexes]
    regex_type = bytes

  # Combine all alternatives into one non-capturing alternation.
  if regex_type is str:
    regex = u'(?:%s)' % u')|(?:'.join(regexes)
  elif regex_type is bytes:
    regex = b'(?:' + (b')|(?:'.join(regexes)) + b')'
  else:
    self.fail('Only know how to deal with unicode str or bytes regexes.',
              message)

  if not re.search(regex, actual_str, re.MULTILINE):
    self.fail('"%s" does not contain any of these regexes: %s.' %
              (actual_str, regexes), message)
def assertCommandSucceeds(self, command, regexes=(b'',), env=None,
                          close_fds=True, msg=None):
  """Asserts that a shell command succeeds (i.e. exits with code 0).

  Args:
    command: List or string representing the command to run.
    regexes: List of regular expression byte strings that match success.
      str entries are accepted and encoded as UTF-8, and may be mixed with
      bytes entries.
    env: Dictionary of environment variable settings. If None, no environment
      variables will be set for the child process. This is to make tests
      more hermetic. NOTE: this behavior is different than the standard
      subprocess module.
    close_fds: Whether or not to close all open fd's in the child after
      forking.
    msg: Optional message to report on failure.
  """
  (ret_code, err) = get_command_stderr(command, env, close_fds)

  # We need bytes regexes here because `err` is bytes.
  # Accommodate code which listed their output regexes w/o the b'' prefix by
  # converting them to bytes for the user. Entries are converted one by one,
  # so mixed str/bytes lists work, and an empty list reaches
  # assertRegexMatch (which reports it) instead of raising IndexError here.
  if any(isinstance(regex, str) for regex in regexes):
    regexes = [regex.encode('utf-8') if isinstance(regex, str) else regex
               for regex in regexes]

  command_string = get_command_string(command)
  self.assertEqual(
      ret_code, 0,
      self._formatMessage(msg,
                          'Running command\n'
                          '%s failed with error code %s and message\n'
                          '%s' % (_quote_long_string(command_string),
                                  ret_code,
                                  _quote_long_string(err)))
  )
  self.assertRegexMatch(
      err,
      regexes,
      message=self._formatMessage(
          msg,
          'Running command\n'
          '%s failed with error code %s and message\n'
          '%s which matches no regex in %s' % (
              _quote_long_string(command_string),
              ret_code,
              _quote_long_string(err),
              regexes)))
def assertCommandFails(self, command, regexes, env=None, close_fds=True,
                       msg=None):
  """Asserts a shell command fails and the error matches a regex in a list.

  Args:
    command: List or string representing the command to run.
    regexes: the list of regular expression strings. str entries are encoded
      as UTF-8 and may be mixed with bytes entries.
    env: Dictionary of environment variable settings. If None, no environment
      variables will be set for the child process. This is to make tests
      more hermetic. NOTE: this behavior is different than the standard
      subprocess module.
    close_fds: Whether or not to close all open fd's in the child after
      forking.
    msg: Optional message to report on failure.
  """
  (ret_code, err) = get_command_stderr(command, env, close_fds)

  # We need bytes regexes here because `err` is bytes.
  # Accommodate code which listed their output regexes w/o the b'' prefix by
  # converting them to bytes for the user. Entries are converted one by one,
  # so mixed str/bytes lists work, and an empty list reaches
  # assertRegexMatch (which reports it) instead of raising IndexError here.
  if any(isinstance(regex, str) for regex in regexes):
    regexes = [regex.encode('utf-8') if isinstance(regex, str) else regex
               for regex in regexes]

  command_string = get_command_string(command)
  self.assertNotEqual(
      ret_code, 0,
      self._formatMessage(msg, 'The following command succeeded '
                          'while expected to fail:\n%s' %
                          _quote_long_string(command_string)))
  self.assertRegexMatch(
      err,
      regexes,
      message=self._formatMessage(
          msg,
          'Running command\n'
          '%s failed with error code %s and message\n'
          '%s which matches no regex in %s' % (
              _quote_long_string(command_string),
              ret_code,
              _quote_long_string(err),
              regexes)))
class _AssertRaisesContext(object):
  """Context manager that checks the type of a raised exception.

  An exception matching ``expected_exception`` is passed to ``test_func``
  and then suppressed. Any other exception propagates unchanged. If nothing
  is raised, ``test_case.fail`` is called.
  """

  def __init__(self, expected_exception, test_case, test_func, msg=None):
    # expected_exception: an exception class or a tuple of classes (anything
    # accepted by issubclass()).
    self.expected_exception = expected_exception
    self.test_case = test_case
    self.test_func = test_func
    self.msg = msg

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_value, tb):
    if exc_type is None:
      # A tuple of exception classes has no __name__; fall back to str().
      try:
        exc_name = self.expected_exception.__name__
      except AttributeError:
        exc_name = str(self.expected_exception)
      self.test_case.fail(exc_name + ' not raised', self.msg)
    if not issubclass(exc_type, self.expected_exception):
      # Returning False lets the unexpected exception propagate.
      return False
    self.test_func(exc_value)
    if exc_value is not None:
      # Stored without its traceback; presumably so it does not keep the
      # raising frames alive after the with-block ends.
      self.exception = exc_value.with_traceback(None)
    # Returning True suppresses the expected exception.
    return True
@typing.overload
def assertRaisesWithPredicateMatch(
    self, expected_exception, predicate) -> _AssertRaisesContext:
  # The purpose of this return statement is to work around
  # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
  return self._AssertRaisesContext(None, None, None)

@typing.overload
def assertRaisesWithPredicateMatch(
    self, expected_exception, predicate, callable_obj: Callable[..., Any],
    *args, **kwargs) -> None:
  # The purpose of this return statement is to work around
  # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
  return self._AssertRaisesContext(None, None, None)

def assertRaisesWithPredicateMatch(self, expected_exception, predicate,
                                   callable_obj=None, *args, **kwargs):
  """Asserts that an exception is raised and satisfies ``predicate``.

  Args:
    expected_exception: Exception class expected to be raised.
    predicate: Callable taking the raised exception and returning True when
      it is acceptable, False to fail the test.
    callable_obj: Function to be called, or None to get a context manager.
    *args: Positional arguments forwarded to callable_obj.
    **kwargs: Keyword arguments forwarded to callable_obj.

  Returns:
    A context manager if callable_obj is None. Otherwise, None.

  Raises:
    self.failureException if callable_obj does not raise a matching exception.
  """
  def _check_predicate(exc):
    matched = predicate(exc)
    self.assertTrue(matched,
                    '%r does not match predicate %r' % (exc, predicate))

  checker = self._AssertRaisesContext(expected_exception, self,
                                      _check_predicate)
  if callable_obj is not None:
    with checker:
      callable_obj(*args, **kwargs)
    return None
  return checker
@typing.overload
def assertRaisesWithLiteralMatch(
    self, expected_exception, expected_exception_message
) -> _AssertRaisesContext:
  # The purpose of this return statement is to work around
  # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
  return self._AssertRaisesContext(None, None, None)

@typing.overload
def assertRaisesWithLiteralMatch(
    self, expected_exception, expected_exception_message,
    callable_obj: Callable[..., Any], *args, **kwargs) -> None:
  # The purpose of this return statement is to work around
  # https://github.com/PyCQA/pylint/issues/5273; it is otherwise ignored.
  return self._AssertRaisesContext(None, None, None)

def assertRaisesWithLiteralMatch(self, expected_exception,
                                 expected_exception_message,
                                 callable_obj=None, *args, **kwargs):
  """Asserts that a raised exception's str() equals the given string exactly.

  Unlike assertRaisesRegex, this method takes a literal string, not
  a regular expression.

    with self.assertRaisesWithLiteralMatch(ExType, 'message'):
      DoSomething()

  Args:
    expected_exception: Exception class expected to be raised.
    expected_exception_message: String message expected in the raised
      exception. For a raise exception e, expected_exception_message must
      equal str(e).
    callable_obj: Function to be called, or None to return a context.
    *args: Positional arguments forwarded to callable_obj.
    **kwargs: Keyword arguments forwarded to callable_obj.

  Returns:
    A context manager if callable_obj is None. Otherwise, None.

  Raises:
    self.failureException if callable_obj does not raise a matching exception.
  """
  def _message_matches(exc):
    observed = str(exc)
    self.assertTrue(expected_exception_message == observed,
                    'Exception message does not match.\n'
                    'Expected: %r\n'
                    'Actual: %r' % (expected_exception_message,
                                    observed))

  ctx = self._AssertRaisesContext(expected_exception, self, _message_matches)
  if callable_obj is not None:
    with ctx:
      callable_obj(*args, **kwargs)
    return None
  return ctx
def assertContainsInOrder(self, strings, target, msg=None):
  """Asserts that the strings provided are found in the target in order.

  Each string must be found starting at or after the end of the previous
  string's match, so repeated or overlapping entries each require their own
  occurrence in target.

  This may be useful for checking HTML output.

  Args:
    strings: A list of strings, such as [ 'fox', 'dog' ]. A single str or
      bytes value is treated as a one-element list. Other entries are
      converted with str() (and encoded as UTF-8 when target is bytes).
    target: A target string in which to look for the strings, such as
      'The quick brown fox jumped over the lazy dog'. May be str or bytes.
    msg: Optional message to report on failure.
  """
  if isinstance(strings, (bytes, str)):
    strings = (strings,)

  current_index = 0
  last_string = None
  for position, string in enumerate(strings):
    if isinstance(target, bytes):
      # bytes.find() needs a bytes needle; str() would produce "b'...'".
      if isinstance(string, bytes):
        needle = string
      else:
        needle = str(string).encode('utf-8')
    else:
      needle = str(string)
    index = target.find(needle, current_index)
    if index == -1:
      if position == 0:
        self.fail("Did not find '%s' in '%s'" %
                  (string, target), msg)
      else:
        self.fail("Did not find '%s' after '%s' in '%s'" %
                  (string, last_string, target), msg)
    last_string = string
    # Resume after this match so the next string cannot reuse or overlap
    # the text that was just matched.
    current_index = index + len(needle)
def assertContainsSubsequence(self, container, subsequence, msg=None):
  """Asserts that "container" contains "subsequence" as a subsequence.

  Asserts that "container" contains all the elements of "subsequence", in
  order, but possibly with other elements interspersed. For example, [1, 2, 3]
  is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0].

  Elements are matched by identity or ``==`` (like the ``in`` operator), and
  the check runs in a single pass over container.

  Args:
    container: the list we're testing for subsequence inclusion.
    subsequence: the list we hope will be a subsequence of container.
    msg: Optional message to report on failure.
  """
  subsequence = list(subsequence)
  remaining = iter(container)
  for element in subsequence:
    # any() stops right after the first match, so `remaining` resumes from
    # that point for the next element, which enforces ordering. A flag-free
    # check also works when the unmatched element is None.
    if not any(item is element or item == element for item in remaining):
      self.fail('%s not a subsequence of %s. First non-matching element: %s' %
                (subsequence, container, element), msg)
def assertContainsExactSubsequence(self, container, subsequence, msg=None):
  """Asserts that "container" contains "subsequence" as a contiguous run.

  Asserts that "container" contains all the elements of "subsequence", in
  order, and without other elements interspersed. For example, [1, 2, 3] is an
  exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0].

  Args:
    container: the list we're testing for subsequence inclusion.
    subsequence: the list we hope will be an exact subsequence of container.
    msg: Optional message to report on failure.
  """
  haystack = list(container)
  needle = list(subsequence)
  needle_len = len(needle)
  best = 0

  # Try every alignment of needle within haystack, remembering the longest
  # run of leading needle elements matched at any offset.
  for offset in range(len(haystack) - needle_len + 1):
    if best == needle_len:
      break
    matched = 0
    while (matched < needle_len and
           needle[matched] == haystack[offset + matched]):
      matched += 1
    best = max(best, matched)

  if best < needle_len:
    self.fail('%s not an exact subsequence of %s. '
              'Longest matching prefix: %s' %
              (needle, haystack, needle[:best]), msg)
def assertTotallyOrdered(self, *groups, **kwargs):
  """Asserts that total ordering has been implemented correctly.

  For example, say you have a class A that compares only on its attribute x.
  Comparators other than ``__lt__`` are omitted for brevity::

    class A(object):
      def __init__(self, x, y):
        self.x = x
        self.y = y

      def __hash__(self):
        return hash(self.x)

      def __lt__(self, other):
        try:
          return self.x < other.x
        except AttributeError:
          return NotImplemented

  assertTotallyOrdered will check that instances can be ordered correctly.
  For example::

    self.assertTotallyOrdered(
        [None],  # None should come before everything else.
        [1],  # Integers sort earlier.
        [A(1, 'a')],
        [A(2, 'b')],  # 2 is after 1.
        [A(3, 'c'), A(3, 'd')],  # The second argument is irrelevant.
        [A(4, 'z')],
        ['foo'])  # Strings sort last.

  Args:
    *groups: A list of groups of elements. Each group of elements is a list
      of objects that are equal. The elements in each group must be less
      than the elements in the group after it. For example, these groups are
      totally ordered: ``[None]``, ``[1]``, ``[2, 2]``, ``[3]``.
    **kwargs: optional msg keyword argument can be passed.
  """
  msg = kwargs.get('msg')

  # Note: _formatMessage takes (user_msg, standard_msg) in that order.

  def CheckOrder(small, big):
    """Ensures small is ordered before big."""
    self.assertFalse(small == big,
                     self._formatMessage(msg, '%r unexpectedly equals %r' %
                                         (small, big)))
    self.assertTrue(small != big,
                    self._formatMessage(msg, '%r unexpectedly equals %r' %
                                        (small, big)))
    self.assertLess(small, big, msg)
    self.assertFalse(big < small,
                     self._formatMessage(msg,
                                         '%r unexpectedly less than %r' %
                                         (big, small)))
    self.assertLessEqual(small, big, msg)
    self.assertFalse(big <= small, self._formatMessage(
        msg,
        '%r unexpectedly less than or equal to %r' % (big, small)))
    self.assertGreater(big, small, msg)
    self.assertFalse(small > big,
                     self._formatMessage(msg,
                                         '%r unexpectedly greater than %r' %
                                         (small, big)))
    self.assertGreaterEqual(big, small, msg)
    self.assertFalse(small >= big, self._formatMessage(
        msg,
        '%r unexpectedly greater than or equal to %r' % (small, big)))

  def CheckEqual(a, b):
    """Ensures that a and b are equal."""
    self.assertEqual(a, b, msg)
    self.assertFalse(a != b,
                     self._formatMessage(msg, '%r unexpectedly unequals %r' %
                                         (a, b)))

    # Objects that compare equal must hash to the same value, but this only
    # applies if both objects are hashable.
    if (isinstance(a, abc.Hashable) and
        isinstance(b, abc.Hashable)):
      self.assertEqual(
          hash(a), hash(b),
          self._formatMessage(
              msg, 'hash %d of %r unexpectedly not equal to hash %d of %r' %
              (hash(a), a, hash(b), b)))

    self.assertFalse(a < b,
                     self._formatMessage(msg,
                                         '%r unexpectedly less than %r' %
                                         (a, b)))
    self.assertFalse(b < a,
                     self._formatMessage(msg,
                                         '%r unexpectedly less than %r' %
                                         (b, a)))
    self.assertLessEqual(a, b, msg)
    self.assertLessEqual(b, a, msg)  # pylint: disable=arguments-out-of-order
    self.assertFalse(a > b,
                     self._formatMessage(msg,
                                         '%r unexpectedly greater than %r' %
                                         (a, b)))
    self.assertFalse(b > a,
                     self._formatMessage(msg,
                                         '%r unexpectedly greater than %r' %
                                         (b, a)))
    self.assertGreaterEqual(a, b, msg)
    self.assertGreaterEqual(b, a, msg)  # pylint: disable=arguments-out-of-order

  # For every combination of elements, check the order of every pair of
  # elements.
  for elements in itertools.product(*groups):
    elements = list(elements)
    for index, small in enumerate(elements[:-1]):
      for big in elements[index + 1:]:
        CheckOrder(small, big)

  # Check that every element in each group is equal.
  for group in groups:
    # The (a, a) pairs are also covered by the product below; checking them
    # first reports a non-reflexive element on its own.
    for a in group:
      CheckEqual(a, a)
    for a, b in itertools.product(group, group):
      CheckEqual(a, b)
def assertDictEqual(self, a, b, msg=None):
  """Raises AssertionError if a and b are not equal dictionaries.

  The failure message starts with deterministic reprs of both dicts, then
  lists entries present only in b ("Unexpected"), entries whose values
  differ, and entries present only in a ("Missing").

  Args:
    a: A dict, the expected value.
    b: A dict, the actual value.
    msg: An optional str, the associated message.

  Raises:
    AssertionError: if the dictionaries are not equal.
  """
  self.assertIsInstance(a, dict, self._formatMessage(
      msg,
      'First argument is not a dictionary'
  ))
  self.assertIsInstance(b, dict, self._formatMessage(
      msg,
      'Second argument is not a dictionary'
  ))

  def _maybe_sorted(items):
    # Keys of mixed, unorderable types keep their insertion order.
    try:
      return sorted(items)
    except TypeError:
      return items

  if a == b:
    return
  expected_items = _maybe_sorted(list(a.items()))
  actual_items = _maybe_sorted(list(b.items()))

  safe_repr = unittest.util.safe_repr  # pytype: disable=module-attr

  def _deterministic_repr(mapping):
    """Renders a dict with entries ordered by their repr text."""
    # Ordering by repr rather than by value keeps the output stable across
    # runs for types whose natural order is not deterministic.
    pairs = sorted((safe_repr(key), safe_repr(value))
                   for key, value in mapping.items())
    return '{%s}' % (', '.join('%s: %s' % pair for pair in pairs))

  suffix = ' (%s)' % msg if msg else ''
  message = ['%s != %s%s' % (_deterministic_repr(a), _deterministic_repr(b),
                             suffix)]

  # The standard library default output confounds lexical difference with
  # value difference; treat them separately.
  missing = []
  different = []
  for key, expected_value in expected_items:
    if key not in b:
      missing.append((key, expected_value))
    elif expected_value != b[key]:
      different.append((key, expected_value, b[key]))

  unexpected = [(key, value) for key, value in actual_items if key not in a]

  if unexpected:
    rows = ''.join('%s: %s\n' % (safe_repr(key), safe_repr(value))
                   for key, value in unexpected)
    message.append('Unexpected, but present entries:\n%s' % rows)

  if different:
    rows = ''.join('%s: %s != %s\n' % (safe_repr(key), safe_repr(left),
                                       safe_repr(right))
                   for key, left, right in different)
    message.append('repr() of differing entries:\n%s' % rows)

  if missing:
    rows = ''.join('%s: %s\n' % (safe_repr(key), safe_repr(value))
                   for key, value in missing)
    message.append('Missing entries:\n%s' % rows)

  raise self.failureException('\n'.join(message))
def assertUrlEqual(self, a, b, msg=None):
  """Asserts that urls are equal, ignoring ordering of query params.

  Scheme, netloc, path and fragment must match exactly; the ';'-separated
  params are compared as sorted lists, and the query strings are compared
  as parsed dicts (blank values kept).
  """
  first = parse.urlparse(a)
  second = parse.urlparse(b)
  for component in ('scheme', 'netloc', 'path', 'fragment'):
    self.assertEqual(getattr(first, component), getattr(second, component),
                     msg)
  first_params = sorted(first.params.split(';'))
  second_params = sorted(second.params.split(';'))
  self.assertEqual(first_params, second_params, msg)
  first_query = parse.parse_qs(first.query, keep_blank_values=True)
  second_query = parse.parse_qs(second.query, keep_blank_values=True)
  self.assertDictEqual(first_query, second_query, msg)
def assertSameStructure(self, a, b, aname='a', bname='b', msg=None):
  """Asserts that two values contain the same structural content.

  The two arguments should be data trees consisting of trees of dicts and
  lists. They will be deeply compared by walking into the contents of dicts
  and lists; other items will be compared using the == operator.
  If the two structures differ in content, the failure message will indicate
  the location within the structures where the first difference is found.
  This may be helpful when comparing large structures.

  Mixed Sequence and Set types are supported. Mixed Mapping types are
  supported, but the order of the keys will not be considered in the
  comparison.

  When ``self.maxDiff`` is not None, the report is capped at about
  ``self.maxDiff // 80`` entries, but at least one concrete problem is always
  shown.

  Args:
    a: The first structure to compare.
    b: The second structure to compare.
    aname: Variable name to use for the first structure in assertion messages.
    bname: Variable name to use for the second structure.
    msg: Additional text to include in the failure message.
  """

  # Accumulate all the problems found so we can report all of them at once
  # rather than just stopping at the first.
  problems = []

  _walk_structure_for_problems(a, b, aname, bname, problems,
                               self.assertEqual, self.failureException)

  # Avoid spamming the user too much.
  if self.maxDiff is not None:
    max_problems_to_show = self.maxDiff // 80
    # Reserve one slot for the '...' marker but always keep at least one real
    # problem. A plain `problems[:max_problems_to_show - 1]` turns into
    # `problems[:-1]` when maxDiff < 80 and barely truncates anything.
    keep = max(max_problems_to_show - 1, 1)
    if len(problems) > max_problems_to_show and len(problems) > keep:
      problems = problems[:keep] + ['...']

  if problems:
    self.fail('; '.join(problems), msg)
def assertJsonEqual(self, first, second, msg=None):
  """Asserts that the JSON documents in two strings decode to equal values.

  Both strings are decoded with ``json.loads`` and the results are compared
  with ``assertSameStructure`` (named ``first`` and ``second`` in its
  messages), so a failure lists every structural difference.

  Args:
    first: A string containing JSON to decode and compare to second.
    second: A string containing JSON to decode and compare to first.
    msg: Additional text to include in the failure message.

  Raises:
    ValueError: if either string is not valid JSON. The message names which
      argument failed to decode and includes ``msg`` via _formatMessage.
  """

  def _decode(text, template):
    # Decodes one argument, turning a decode error into a ValueError that
    # carries the caller's msg and the offending text.
    try:
      return json.loads(text)
    except ValueError as err:
      raise ValueError(self._formatMessage(msg, template % (text, err)))

  first_value = _decode(first, 'could not decode first JSON value %s: %s')
  second_value = _decode(second, 'could not decode second JSON value %s: %s')

  self.assertSameStructure(first_value, second_value,
                           aname='first', bname='second', msg=msg)
def _getAssertEqualityFunc(self, first, second):
  # type: (Any, Any) -> Callable[..., None]
  """Returns unittest's type-specific equality function for two values.

  Works even when ``unittest.TestCase.__init__`` was never run: in that case
  the first lookup raises AttributeError, the base initializer is run lazily,
  and the lookup is retried.
  """
  base = super(TestCase, self)
  try:
    return base._getAssertEqualityFunc(first, second)
  except AttributeError:
    # Usually means somebody subclassed TestCase just for the assertions and
    # overrode __init__ without calling it. "assertTrue" is a method name that
    # always exists, so __init__ will not raise ValueError when it is used.
    method_name = getattr(self, '_testMethodName', 'assertTrue')
    base.__init__(method_name)
  return base._getAssertEqualityFunc(first, second)
def fail(self, msg=None, user_msg=None) -> NoReturn:
  """Fails the test immediately.

  Args:
    msg: The standard failure message (e.g. produced by an assertion).
    user_msg: Optional caller-supplied message, combined with ``msg`` via
      ``self._formatMessage``.

  Raises:
    The test case's failure exception, via ``unittest.TestCase.fail``.
  """
  combined_message = self._formatMessage(user_msg, msg)
  return super(TestCase, self).fail(combined_message)
1832def _sorted_list_difference(expected, actual):
1833 # type: (List[_T], List[_T]) -> Tuple[List[_T], List[_T]]
1834 """Finds elements in only one or the other of two, sorted input lists.
1836 Returns a two-element tuple of lists. The first list contains those
1837 elements in the "expected" list but not in the "actual" list, and the
1838 second contains those elements in the "actual" list but not in the
1839 "expected" list. Duplicate elements in either input list are ignored.
1841 Args:
1842 expected: The list we expected.
1843 actual: The list we actually got.
1844 Returns:
1845 (missing, unexpected)
1846 missing: items in expected that are not in actual.
1847 unexpected: items in actual that are not in expected.
1848 """
1849 i = j = 0
1850 missing = []
1851 unexpected = []
1852 while True:
1853 try:
1854 e = expected[i]
1855 a = actual[j]
1856 if e < a:
1857 missing.append(e)
1858 i += 1
1859 while expected[i] == e:
1860 i += 1
1861 elif e > a:
1862 unexpected.append(a)
1863 j += 1
1864 while actual[j] == a:
1865 j += 1
1866 else:
1867 i += 1
1868 try:
1869 while expected[i] == e:
1870 i += 1
1871 finally:
1872 j += 1
1873 while actual[j] == a:
1874 j += 1
1875 except IndexError:
1876 missing.extend(expected[i:])
1877 unexpected.extend(actual[j:])
1878 break
1879 return missing, unexpected
1882def _are_both_of_integer_type(a, b):
1883 # type: (object, object) -> bool
1884 return isinstance(a, int) and isinstance(b, int)
1887def _are_both_of_sequence_type(a, b):
1888 # type: (object, object) -> bool
1889 return isinstance(a, abc.Sequence) and isinstance(
1890 b, abc.Sequence) and not isinstance(
1891 a, _TEXT_OR_BINARY_TYPES) and not isinstance(b, _TEXT_OR_BINARY_TYPES)
1894def _are_both_of_set_type(a, b):
1895 # type: (object, object) -> bool
1896 return isinstance(a, abc.Set) and isinstance(b, abc.Set)
1899def _are_both_of_mapping_type(a, b):
1900 # type: (object, object) -> bool
1901 return isinstance(a, abc.Mapping) and isinstance(
1902 b, abc.Mapping)
def _walk_structure_for_problems(
    a, b, aname, bname, problem_list, leaf_assert_equal_func, failure_exception
):
  """The recursive comparison behind assertSameStructure.

  Walks ``a`` and ``b`` in parallel and appends a human-readable description
  of every mismatch to ``problem_list``, in traversal order. Nothing is
  returned.

  Args:
    a: First value (or sub-tree) to compare.
    b: Second value (or sub-tree) to compare.
    aname: Display name of ``a`` used in messages, e.g. ``a[0]['k']``.
    bname: Display name of ``b`` used in messages.
    problem_list: List that mismatch descriptions are appended to.
    leaf_assert_equal_func: Called as ``func(a, b)`` on non-container leaves;
      it signals inequality by raising ``failure_exception``.
    failure_exception: Exception type that marks a leaf mismatch.
  """
  # Values of different concrete types are only walked further when both
  # belong to one compatible family: ints (int subclasses such as bool count
  # as plain ints), Sequences, Sets or Mappings.
  comparable = (
      type(a) == type(b)  # pylint: disable=unidiomatic-typecheck
      or _are_both_of_integer_type(a, b)
      or _are_both_of_sequence_type(a, b)
      or _are_both_of_set_type(a, b)
      or _are_both_of_mapping_type(a, b))
  if not comparable:
    # Different types: report once, there is no point descending further.
    problem_list.append('%s is a %r but %s is a %r' %
                        (aname, type(a), bname, type(b)))
    return

  if isinstance(a, abc.Set):
    for item in a:
      if item not in b:
        problem_list.append('%s has %r but %s does not' % (aname, item, bname))
    for item in b:
      if item not in a:
        problem_list.append('%s lacks %r but %s has it' % (aname, item, bname))
    return

  if isinstance(a, abc.Mapping):
    # a or b could be a defaultdict: only index keys already known to be
    # present, so the traversal never inserts entries.
    for key in a:
      if key not in b:
        problem_list.append(
            "%s has [%r] with value %r but it's missing in %s" %
            (aname, key, a[key], bname))
        continue
      _walk_structure_for_problems(
          a[key], b[key], '%s[%r]' % (aname, key), '%s[%r]' % (bname, key),
          problem_list, leaf_assert_equal_func, failure_exception)
    for key in b:
      if key not in a:
        problem_list.append(
            '%s lacks [%r] but %s has it with value %r' %
            (aname, key, bname, b[key]))
    return

  # Strings/bytes are Sequences too, but they are compared as leaves below.
  if (isinstance(a, abc.Sequence) and
      not isinstance(a, _TEXT_OR_BINARY_TYPES)):
    shared_len = min(len(a), len(b))
    for idx in range(shared_len):
      _walk_structure_for_problems(
          a[idx], b[idx], '%s[%d]' % (aname, idx), '%s[%d]' % (bname, idx),
          problem_list, leaf_assert_equal_func, failure_exception)
    for idx in range(shared_len, len(a)):
      problem_list.append('%s has [%i] with value %r but %s does not' %
                          (aname, idx, a[idx], bname))
    for idx in range(shared_len, len(b)):
      problem_list.append('%s lacks [%i] but %s has it with value %r' %
                          (aname, idx, bname, b[idx]))
    return

  # Anything else is a leaf: defer to the supplied equality assertion.
  try:
    leaf_assert_equal_func(a, b)
  except failure_exception:
    problem_list.append('%s is %r but %s is %r' % (aname, a, bname, b))
def get_command_string(command):
  """Returns an escaped string that can be used as a shell command.

  Args:
    command: List or string representing the command to run. A string is
      returned unchanged.

  Returns:
    A string suitable for use as a shell command. On Windows (``os.name ==
    'nt'``) list elements are joined with single spaces and no quoting.
    Elsewhere every element is wrapped in single quotes, with each embedded
    single quote written as ``'"'"'``. Unlike ``shlex.quote``, words are
    quoted even when they contain only shell-safe characters.
  """
  if isinstance(command, str):
    return command
  if os.name == 'nt':
    return ' '.join(command)
  # Close the single-quoted span, emit a double-quoted ', then reopen it.
  escaped_quote = "'\"'\"'"
  return ' '.join(
      "'" + word.replace("'", escaped_quote) + "'" for word in command)
def get_command_stderr(command, env=None, close_fds=True):
  """Runs the given shell command and returns a tuple.

  Args:
    command: List or string representing the command to run. A string is run
      through the shell (``shell=True``); a list is executed directly.
    env: Dictionary of environment variable settings. If None, no environment
      variables will be set for the child process. This is to make tests
      more hermetic. NOTE: this behavior is different than the standard
      subprocess module.
    close_fds: Whether or not to close all open fd's in the child after forking.
      On Windows, this is ignored and close_fds is always False.

  Returns:
    Tuple of (exit status, output). ``output`` is the raw ``bytes`` the command
    wrote; stderr is merged into stdout, so both appear interleaved.
  """
  if env is None:
    env = {}
  if os.name == 'nt':
    # Windows does not support setting close_fds to True while also redirecting
    # standard handles.
    close_fds = False

  use_shell = isinstance(command, str)
  # Using Popen as a context manager closes the pipe and waits for the child
  # even if communicate() is interrupted by an exception.
  with subprocess.Popen(
      command,
      close_fds=close_fds,
      env=env,
      shell=use_shell,
      stderr=subprocess.STDOUT,
      stdout=subprocess.PIPE) as process:
    output = process.communicate()[0]
    exit_status = process.wait()
  return (exit_status, output)
2024def _quote_long_string(s):
2025 # type: (Union[Text, bytes, bytearray]) -> Text
2026 """Quotes a potentially multi-line string to make the start and end obvious.
2028 Args:
2029 s: A string.
2031 Returns:
2032 The quoted string.
2033 """
2034 if isinstance(s, (bytes, bytearray)):
2035 try:
2036 s = s.decode('utf-8')
2037 except UnicodeDecodeError:
2038 s = str(s)
2039 return ('8<-----------\n' +
2040 s + '\n' +
2041 '----------->8\n')
def print_python_version():
  # type: () -> None
  """Writes the interpreter version and executable path to stderr.

  Having this in the test output logs by default helps debugging when all
  you've got is the log and no other idea of which Python was used. When
  ``sys.executable`` is empty, ``embedded.`` is printed in its place.
  """
  interpreter = sys.executable if sys.executable else 'embedded.'
  banner = ('Running tests under Python {0[0]}.{0[1]}.{0[2]}: '
            '{1}\n').format(sys.version_info, interpreter)
  sys.stderr.write(banner)
def main(*args, **kwargs):
  # type: (Text, Any) -> None
  """Executes a set of Python unit tests.

  Usually this is called without arguments. The unittest.TestProgram instance
  then gets the default settings and runs every test method of every TestCase
  class in the ``__main__`` module.

  The interpreter version is printed to stderr first. The tests are then run
  through ``run_tests``, inside ``app.run`` (which is started here if it is
  not already active).

  Args:
    *args: Positional arguments passed through to
      ``unittest.TestProgram.__init__``.
    **kwargs: Keyword arguments passed through to
      ``unittest.TestProgram.__init__``.
  """
  print_python_version()
  _run_in_app(function=run_tests, args=args, kwargs=kwargs)
def _is_in_app_main():
  # type: () -> bool
  """Returns True iff app.run is active.

  "Active" means some frame on the caller's stack is executing the code
  object of ``app.run``.
  """
  caller = sys._getframe().f_back  # pylint: disable=protected-access

  def _frames_outward(frame):
    # Yields `frame` and each of its callers, innermost first.
    while frame:
      yield frame
      frame = frame.f_back

  return any(frame.f_code == app.run.__code__
             for frame in _frames_outward(caller))
def _register_sigterm_with_faulthandler():
  # type: () -> None
  """Have faulthandler dump stacks on SIGTERM. Useful to diagnose timeouts.

  Best effort: does nothing where ``faulthandler.register`` is unavailable,
  and a failed registration is reported on stderr rather than raised.
  """
  # faulthandler.register is not available on Windows.
  # faulthandler.enable() is already called by app.run.
  if not getattr(faulthandler, 'register', None):
    return
  try:
    faulthandler.register(signal.SIGTERM, chain=True)  # pytype: disable=module-attr
  except Exception as e:  # pylint: disable=broad-except
    sys.stderr.write('faulthandler.register(SIGTERM) failed '
                     '%r; ignoring.\n' % e)
def _run_in_app(function, args, kwargs):
  # type: (Callable[..., None], Sequence[Text], Mapping[Text, Any]) -> None
  """Executes a set of Python unit tests, ensuring app.run.

  This is a private function, users should call absltest.main().

  _run_in_app calculates argv to be the command-line arguments of this program
  (without the flags), sets the default of FLAGS.alsologtostderr to True,
  then it calls function(argv, args, kwargs), making sure that ``function``
  will get called within app.run(). _run_in_app does this by checking whether
  it is called by app.run(), or by calling app.run() explicitly.

  The reason why app.run has to be ensured is to make sure that
  flags are parsed and stripped properly, and other initializations done by
  the app module are also carried out, no matter if absltest.run() is called
  from within or outside app.run().

  If _run_in_app is called from within app.run(), then it will reparse
  sys.argv and pass the result without command-line flags into the argv
  argument of ``function``. The reason why this parsing is needed is that
  __main__.main() calls absltest.main() without passing its argv. So the
  only way _run_in_app could get to know the argv without the flags is that
  it reparses sys.argv. During that reparse every flag's ``parse`` method is
  temporarily replaced with a no-op, so flag values are not parsed again.

  _run_in_app changes the default of FLAGS.alsologtostderr to True so that the
  test program's stderr will contain all the log messages unless otherwise
  specified on the command-line. This overrides any explicit assignment to
  FLAGS.alsologtostderr by the test program prior to the call to _run_in_app()
  (e.g. in __main__.main).
  NOTE(review): the inline comment in the body says an earlier
  ``FLAGS.alsologtostderr = False`` is kept, which contradicts the previous
  sentence; which one holds depends on ``FlagValues.set_default`` — confirm.

  Please note that _run_in_app (and the function it calls) is allowed to make
  changes to kwargs.

  Args:
    function: absltest.run_tests or a similar function. It will be called as
      function(argv, args, kwargs) where argv is a list containing the
      elements of sys.argv without the command-line flags.
    args: Positional arguments passed through to unittest.TestProgram.__init__.
    kwargs: Keyword arguments passed through to unittest.TestProgram.__init__.
  """
  if _is_in_app_main():
    _register_sigterm_with_faulthandler()

    # Change the default of alsologtostderr from False to True, so the test
    # program's stderr will contain all the log messages.
    # If --alsologtostderr=false is specified in the command-line, or user
    # has called FLAGS.alsologtostderr = False before, then the value is kept
    # False.
    FLAGS.set_default('alsologtostderr', True)

    # Here we only want to get the `argv` without the flags. To avoid any
    # side effects of parsing flags, we temporarily stub out the `parse`
    # method of every flag.
    stored_parse_methods = {}
    noop_parse = lambda _: None
    for name in FLAGS:
      # Remember each flag's real parse method so it can be restored below.
      stored_parse_methods[name] = FLAGS[name].parse
    # This must be a separate loop since multiple flag names (short_name=) can
    # point to the same flag object; stubbing inside the loop above would
    # record the no-op as the "original" for the second name.
    for name in FLAGS:
      FLAGS[name].parse = noop_parse
    try:
      argv = FLAGS(sys.argv)
    finally:
      # Restore the real parse methods even if the reparse raised.
      for name in FLAGS:
        FLAGS[name].parse = stored_parse_methods[name]
      sys.stdout.flush()

    function(argv, args, kwargs)
  else:
    # Send logging to stderr. Use --alsologtostderr instead of --logtostderr
    # in case tests are reading their own logs.
    FLAGS.set_default('alsologtostderr', True)

    def main_function(argv):
      # app.run passes the flag-stripped argv; register the SIGTERM stack
      # dumper from inside app.run, then run the tests.
      _register_sigterm_with_faulthandler()
      function(argv, args, kwargs)

    app.run(main=main_function)
2178def _is_suspicious_attribute(testCaseClass, name):
2179 # type: (Type, Text) -> bool
2180 """Returns True if an attribute is a method named like a test method."""
2181 if name.startswith('Test') and len(name) > 4 and name[4].isupper():
2182 attr = getattr(testCaseClass, name)
2183 if inspect.isfunction(attr) or inspect.ismethod(attr):
2184 args = inspect.getfullargspec(attr)
2185 return (len(args.args) == 1 and args.args[0] == 'self' and
2186 args.varargs is None and args.varkw is None and
2187 not args.kwonlyargs)
2188 return False
def skipThisClass(reason):
  # type: (Text) -> Callable[[_T], _T]
  """Skip tests in the decorated TestCase, but not any of its subclasses.

  This decorator indicates that this class should skip all its tests, but not
  any of its subclasses. Useful for if you want to share testMethod or setUp
  implementations between a number of concrete testcase classes.

  Example usage, showing how you can share some common test methods between
  subclasses. In this example, only ``BaseTest`` will be marked as skipped, and
  not RealTest or SecondRealTest::

    @absltest.skipThisClass("Shared functionality")
    class BaseTest(absltest.TestCase):
      def test_simple_functionality(self):
        self.assertEqual(self.system_under_test.method(), 1)

    class RealTest(BaseTest):
      def setUp(self):
        super().setUp()
        self.system_under_test = MakeSystem(argument)

      def test_specific_behavior(self):
        ...

    class SecondRealTest(BaseTest):
      def setUp(self):
        super().setUp()
        self.system_under_test = MakeSystem(other_arguments)

      def test_other_behavior(self):
        ...

  The skip is implemented by replacing the decorated class's ``setUpClass``.
  The replacement raises SkipTest only when ``cls`` is the decorated class
  itself. For subclasses it forwards to the decorated class's own
  ``setUpClass`` (if it defined one) or to the parent's ``setUpClass``.

  Args:
    reason: The reason we have a skip in place. For instance: 'shared test
      methods' or 'shared assertion methods'.

  Returns:
    Decorator function that will cause a class to be skipped.

  Raises:
    TypeError: if ``reason`` is a class, e.g. when used as a bare
      ``@skipThisClass`` without a reason. The returned decorator also raises
      TypeError when applied to something that is not a
      ``unittest.TestCase`` subclass.
  """
  if isinstance(reason, type):
    raise TypeError('Got {!r}, expected reason as string'.format(reason))

  def _skip_class(test_case_class):
    if not issubclass(test_case_class, unittest.TestCase):
      raise TypeError(
          'Decorating {!r}, expected TestCase subclass'.format(test_case_class))

    # Only shadow the setUpClass method if it is directly defined. If it is
    # in the parent class we invoke it via a super() call instead of holding
    # a reference to it.
    shadowed_setupclass = test_case_class.__dict__.get('setUpClass', None)

    @classmethod
    def replacement_setupclass(cls, *args, **kwargs):
      # Skip this class if it is the one that was decorated with @skipThisClass
      if cls is test_case_class:
        raise SkipTest(reason)
      if shadowed_setupclass:
        # Pass along `cls` so the MRO chain doesn't break.
        # The original method is a `classmethod` descriptor, which can't
        # be directly called, but `__func__` has the underlying function.
        return shadowed_setupclass.__func__(cls, *args, **kwargs)
      else:
        # Because there's no setUpClass() defined directly on test_case_class,
        # we call super() ourselves to continue execution of the inheritance
        # chain.
        return super(test_case_class, cls).setUpClass(*args, **kwargs)

    test_case_class.setUpClass = replacement_setupclass
    return test_case_class

  return _skip_class
class TestLoader(unittest.TestLoader):
  """A test loader which supports common test features.

  Supported features include:
   * Banning untested methods with test-like names: methods attached to this
     testCase with names starting with `Test` are ignored by the test runner,
     and often represent mistakenly-omitted test cases. This loader will raise
     a TypeError when attempting to load a TestCase with such methods.
   * Randomization of test case execution order (optional), driven by the
     seed returned by ``_get_default_randomize_ordering_seed()``.
  """

  _ERROR_MSG = textwrap.dedent("""Method '%s' is named like a test case but
  is not one. This is often a bug. If you want it to be a test method,
  name it with 'test' in lowercase. If not, rename the method to not begin
  with 'Test'.""")

  def __init__(self, *args, **kwds):
    super(TestLoader, self).__init__(*args, **kwds)
    seed = _get_default_randomize_ordering_seed()
    # A falsy seed disables shuffling. Otherwise keep a dedicated RNG seeded
    # with it, so the order can be reproduced from the seed alone.
    self._randomize_ordering_seed = seed if seed else None
    self._random = random.Random(seed) if seed else None

  def getTestCaseNames(self, testCaseClass):  # pylint:disable=invalid-name
    """Validates and returns a (possibly randomized) list of test case names.

    Raises:
      TypeError: if ``testCaseClass`` has a method named like a test (see
        ``_is_suspicious_attribute``) that would not be collected as one.
    """
    for attr_name in dir(testCaseClass):
      if _is_suspicious_attribute(testCaseClass, attr_name):
        raise TypeError(TestLoader._ERROR_MSG % attr_name)

    names = list(super(TestLoader, self).getTestCaseNames(testCaseClass))
    seed = self._randomize_ordering_seed
    if seed is None:
      return names

    logging.info('Randomizing test order with seed: %d', seed)
    logging.info(
        'To reproduce this order, re-run with '
        '--test_randomize_ordering_seed=%d', seed)
    self._random.shuffle(names)
    return names
def get_default_xml_output_filename():
  # type: () -> Optional[Text]
  """Returns the XML report path implied by the environment, or None.

  Sources are consulted in priority order (empty values count as unset):

  1. ``XML_OUTPUT_FILE``: used verbatim.
  2. ``RUNNING_UNDER_TEST_DAEMON``: ``test_detail.xml`` inside
     ``os.path.dirname(TEST_TMPDIR.value)``.
  3. ``TEST_XMLOUTPUTDIR``: ``<dir>/<program basename without extension>.xml``.
  """
  environ = os.environ
  explicit_file = environ.get('XML_OUTPUT_FILE')
  if explicit_file:
    return explicit_file
  if environ.get('RUNNING_UNDER_TEST_DAEMON'):
    return os.path.join(os.path.dirname(TEST_TMPDIR.value), 'test_detail.xml')
  output_dir = environ.get('TEST_XMLOUTPUTDIR')
  if output_dir:
    program, _ = os.path.splitext(os.path.basename(sys.argv[0]))
    return os.path.join(output_dir, program + '.xml')
  return None
2320def _setup_filtering(argv: MutableSequence[str]) -> bool:
2321 """Implements the bazel test filtering protocol.
2323 The following environment variable is used in this method:
2325 TESTBRIDGE_TEST_ONLY: string, if set, is forwarded to the unittest
2326 framework to use as a test filter. Its value is split with shlex, then:
2327 1. On Python 3.6 and before, split values are passed as positional
2328 arguments on argv.
2329 2. On Python 3.7+, split values are passed to unittest's `-k` flag. Tests
2330 are matched by glob patterns or substring. See
2331 https://docs.python.org/3/library/unittest.html#cmdoption-unittest-k
2333 Args:
2334 argv: the argv to mutate in-place.
2336 Returns:
2337 Whether test filtering is requested.
2338 """
2339 test_filter = os.environ.get('TESTBRIDGE_TEST_ONLY')
2340 if argv is None or not test_filter:
2341 return False
2343 filters = shlex.split(test_filter)
2344 if sys.version_info[:2] >= (3, 7):
2345 filters = ['-k=' + test_filter for test_filter in filters]
2347 argv[1:1] = filters
2348 return True
2351def _setup_test_runner_fail_fast(argv):
2352 # type: (MutableSequence[Text]) -> None
2353 """Implements the bazel test fail fast protocol.
2355 The following environment variable is used in this method:
2357 TESTBRIDGE_TEST_RUNNER_FAIL_FAST=<1|0>
2359 If set to 1, --failfast is passed to the unittest framework to return upon
2360 first failure.
2362 Args:
2363 argv: the argv to mutate in-place.
2364 """
2366 if argv is None:
2367 return
2369 if os.environ.get('TESTBRIDGE_TEST_RUNNER_FAIL_FAST') != '1':
2370 return
2372 argv[1:1] = ['--failfast']
2375def _setup_sharding(
2376 custom_loader: Optional[unittest.TestLoader] = None,
2377) -> Tuple[unittest.TestLoader, Optional[int]]:
2378 """Implements the bazel sharding protocol.
2380 The following environment variables are used in this method:
2382 TEST_SHARD_STATUS_FILE: string, if set, points to a file. We write a blank
2383 file to tell the test runner that this test implements the test sharding
2384 protocol.
2386 TEST_TOTAL_SHARDS: int, if set, sharding is requested.
2388 TEST_SHARD_INDEX: int, must be set if TEST_TOTAL_SHARDS is set. Specifies
2389 the shard index for this instance of the test process. Must satisfy:
2390 0 <= TEST_SHARD_INDEX < TEST_TOTAL_SHARDS.
2392 Args:
2393 custom_loader: A TestLoader to be made sharded.
2395 Returns:
2396 A tuple of ``(test_loader, shard_index)``. ``test_loader`` is for
2397 shard-filtering or the standard test loader depending on the sharding
2398 environment variables. ``shard_index`` is the shard index, or ``None`` when
2399 sharding is not used.
2400 """
2402 # It may be useful to write the shard file even if the other sharding
2403 # environment variables are not set. Test runners may use this functionality
2404 # to query whether a test binary implements the test sharding protocol.
2405 if 'TEST_SHARD_STATUS_FILE' in os.environ:
2406 try:
2407 with open(os.environ['TEST_SHARD_STATUS_FILE'], 'w') as f:
2408 f.write('')
2409 except IOError:
2410 sys.stderr.write('Error opening TEST_SHARD_STATUS_FILE (%s). Exiting.'
2411 % os.environ['TEST_SHARD_STATUS_FILE'])
2412 sys.exit(1)
2414 base_loader = custom_loader or TestLoader()
2415 if 'TEST_TOTAL_SHARDS' not in os.environ:
2416 # Not using sharding, use the expected test loader.
2417 return base_loader, None
2419 total_shards = int(os.environ['TEST_TOTAL_SHARDS'])
2420 shard_index = int(os.environ['TEST_SHARD_INDEX'])
2422 if shard_index < 0 or shard_index >= total_shards:
2423 sys.stderr.write('ERROR: Bad sharding values. index=%d, total=%d\n' %
2424 (shard_index, total_shards))
2425 sys.exit(1)
2427 # Replace the original getTestCaseNames with one that returns
2428 # the test case names for this shard.
2429 delegate_get_names = base_loader.getTestCaseNames
2431 bucket_iterator = itertools.cycle(range(total_shards))
2433 def getShardedTestCaseNames(testCaseClass):
2434 filtered_names = []
2435 # We need to sort the list of tests in order to determine which tests this
2436 # shard is responsible for; however, it's important to preserve the order
2437 # returned by the base loader, e.g. in the case of randomized test ordering.
2438 ordered_names = delegate_get_names(testCaseClass)
2439 for testcase in sorted(ordered_names):
2440 bucket = next(bucket_iterator)
2441 if bucket == shard_index:
2442 filtered_names.append(testcase)
2443 return [x for x in ordered_names if x in filtered_names]
2445 base_loader.getTestCaseNames = getShardedTestCaseNames
2446 return base_loader, shard_index
def _run_and_get_tests_result(
    argv: MutableSequence[str],
    args: Sequence[Any],
    kwargs: MutableMapping[str, Any],
    xml_test_runner_class: Type[unittest.TextTestRunner],
) -> Tuple[unittest.TestResult, bool]:
  """Same as run_tests, but it doesn't exit.

  Applies the bazel filtering, fail-fast and sharding protocols to argv and
  kwargs. It then chooses the test runner (an XML-capable one when an XML
  output file is configured, otherwise the pretty-print runner), makes sure
  TEST_TMPDIR exists, and runs ``unittest.TestProgram`` with ``exit=False``.
  When an XML report is produced, it is buffered in memory and written to
  the output file after the run, even if the run raised.

  Args:
    argv: sys.argv with the command-line flags removed from the front, i.e. the
      argv with which :func:`app.run()<absl.app.run>` has called
      ``__main__.main``. It is passed to
      ``unittest.TestProgram.__init__(argv=)``, which does its own flag parsing.
      It is ignored if kwargs contains an argv entry.
    args: Positional arguments passed through to
      ``unittest.TestProgram.__init__``.
    kwargs: Keyword arguments passed through to
      ``unittest.TestProgram.__init__``. Mutated in place (``argv`` is popped;
      ``testLoader``, ``testRunner``, ``argv`` and ``exit`` are set).
    xml_test_runner_class: The type of the test runner class.

  Returns:
    A tuple of ``(test_result, fail_when_no_tests_ran)``.
    ``fail_when_no_tests_ran`` indicates whether the test should fail when
    no tests ran.
  """

  # The entry from kwargs overrides argv.
  argv = kwargs.pop('argv', argv)

  if sys.version_info[:2] >= (3, 12):
    # Python 3.12 unittest changed the behavior from PASS to FAIL in
    # https://github.com/python/cpython/pull/102051. absltest follows this.
    fail_when_no_tests_ran = True
  else:
    # Historically, absltest and unittest before Python 3.12 pass if no tests
    # ran.
    fail_when_no_tests_ran = False

  # Set up test filtering if requested in environment.
  if _setup_filtering(argv):
    # When test filtering is requested, ideally we also want to fail when no
    # tests ran. However, the test filters are usually done when running bazel.
    # When you run multiple targets, e.g. `bazel test //my_dir/...
    # --test_filter=MyTest`, you don't necessarily want individual tests to fail
    # because no tests match in that particular target.
    # Due to this use case, we don't fail when test filtering is requested via
    # the environment variable from bazel.
    fail_when_no_tests_ran = False

  # Set up --failfast as requested in environment
  _setup_test_runner_fail_fast(argv)

  # Shard the (default or custom) loader if sharding is turned on.
  kwargs['testLoader'], shard_index = _setup_sharding(
      kwargs.get('testLoader', None)
  )
  if shard_index is not None and shard_index > 0:
    # When sharding is requested, all the shards except the first one shall not
    # fail when no tests ran. This happens when the shard count is greater than
    # the test case count.
    fail_when_no_tests_ran = False

  # XML file name is based upon (sorted by priority):
  # --xml_output_file flag, XML_OUTPUT_FILE variable,
  # RUNNING_UNDER_TEST_DAEMON variable or TEST_XMLOUTPUTDIR variable.
  if not FLAGS.xml_output_file:
    FLAGS.xml_output_file = get_default_xml_output_filename()
  xml_output_file = FLAGS.xml_output_file

  xml_buffer = None
  if xml_output_file:
    xml_output_dir = os.path.dirname(xml_output_file)
    if xml_output_dir and not os.path.isdir(xml_output_dir):
      try:
        os.makedirs(xml_output_dir)
      except OSError as e:
        # File exists error can occur with concurrent tests
        if e.errno != errno.EEXIST:
          raise
    # Fail early if we can't write to the XML output file. This is so that we
    # don't waste people's time running tests that will just fail anyways.
    with _open(xml_output_file, 'w'):
      pass

    # We can reuse testRunner if it supports XML output (e. g. by inheriting
    # from xml_reporter.TextAndXMLTestRunner). Otherwise we need to use
    # xml_reporter.TextAndXMLTestRunner.
    if (kwargs.get('testRunner') is not None
        and not hasattr(kwargs['testRunner'], 'set_default_xml_stream')):
      # NOTE(review): the `%` operand below is not a 1-tuple, so a tuple-valued
      # testRunner would break the formatting; the message also lacks a
      # trailing newline.
      sys.stderr.write('WARNING: XML_OUTPUT_FILE or --xml_output_file setting '
                       'overrides testRunner=%r setting (possibly from --pdb)'
                       % (kwargs['testRunner']))
      # Passing a class object here allows TestProgram to initialize
      # instances based on its kwargs and/or parsed command-line args.
      kwargs['testRunner'] = xml_test_runner_class
    if kwargs.get('testRunner') is None:
      kwargs['testRunner'] = xml_test_runner_class
    # Use an in-memory buffer (not backed by the actual file) to store the XML
    # report, because some tools modify the file (e.g., create a placeholder
    # with partial information, in case the test process crashes).
    xml_buffer = io.StringIO()
    kwargs['testRunner'].set_default_xml_stream(xml_buffer)  # pytype: disable=attribute-error

    # If we've used a seed to randomize test case ordering, we want to record it
    # as a top-level attribute in the `testsuites` section of the XML output.
    randomize_ordering_seed = getattr(
        kwargs['testLoader'], '_randomize_ordering_seed', None)
    setter = getattr(kwargs['testRunner'], 'set_testsuites_property', None)
    if randomize_ordering_seed and setter:
      setter('test_randomize_ordering_seed', randomize_ordering_seed)
  elif kwargs.get('testRunner') is None:
    # No XML output requested: default to the pretty-printing text runner.
    kwargs['testRunner'] = _pretty_print_reporter.TextTestRunner

  if FLAGS.pdb_post_mortem:
    runner = kwargs['testRunner']
    # testRunner can be a class or an instance, which must be tested for
    # differently.
    # Overriding testRunner isn't uncommon, so only enable the debugging
    # integration if the runner claims it does; we don't want to accidentally
    # clobber something on the runner.
    if ((isinstance(runner, type) and
         issubclass(runner, _pretty_print_reporter.TextTestRunner)) or
        isinstance(runner, _pretty_print_reporter.TextTestRunner)):
      # When runner is a class, this sets a class attribute.
      runner.run_for_debugging = True

  # Make sure tmpdir exists.
  if not os.path.isdir(TEST_TMPDIR.value):
    try:
      os.makedirs(TEST_TMPDIR.value)
    except OSError as e:
      # Concurrent test might have created the directory.
      if e.errno != errno.EEXIST:
        raise

  # Let unittest.TestProgram.__init__ do its own argv parsing, e.g. for '-v',
  # on argv, which is sys.argv without the command-line flags.
  kwargs['argv'] = argv

  # Request unittest.TestProgram to not exit. The exit will be handled by
  # `absltest.run_tests`.
  kwargs['exit'] = False

  try:
    test_program = unittest.TestProgram(*args, **kwargs)
    return test_program.result, fail_when_no_tests_ran
  finally:
    # Flush the buffered XML report to the real file whether or not the run
    # raised, and always release the buffer.
    if xml_buffer:
      try:
        with _open(xml_output_file, 'w') as f:
          f.write(xml_buffer.getvalue())
      finally:
        xml_buffer.close()
def run_tests(
    argv: MutableSequence[Text],
    args: Sequence[Any],
    kwargs: MutableMapping[Text, Any],
) -> None:
  """Executes a set of Python unit tests and exits the process.

  Most users should call absltest.main() instead of run_tests.

  Please note that run_tests should be called from app.run.
  Calling absltest.main() would ensure that.

  Please note that run_tests is allowed to make changes to kwargs.

  The process exit status is 5 when no tests ran and that counts as a failure
  (see ``_run_and_get_tests_result``). Otherwise it is 0 on success and 1 on
  failure.

  Args:
    argv: sys.argv with the command-line flags removed from the front, i.e. the
      argv with which :func:`app.run()<absl.app.run>` has called
      ``__main__.main``. It is passed to
      ``unittest.TestProgram.__init__(argv=)``, which does its own flag parsing.
      It is ignored if kwargs contains an argv entry.
    args: Positional arguments passed through to
      ``unittest.TestProgram.__init__``.
    kwargs: Keyword arguments passed through to
      ``unittest.TestProgram.__init__``.
  """
  result, fail_when_no_tests_ran = _run_and_get_tests_result(
      argv, args, kwargs, xml_reporter.TextAndXMLTestRunner
  )
  if not fail_when_no_tests_ran or result.testsRun != 0:
    sys.exit(not result.wasSuccessful())
  # Python 3.12 unittest exits with 5 when no tests ran. The code comes from
  # pytest which does the same thing.
  sys.exit(5)
2638def _rmtree_ignore_errors(path):
2639 # type: (Text) -> None
2640 if os.path.isfile(path):
2641 try:
2642 os.unlink(path)
2643 except OSError:
2644 pass
2645 else:
2646 shutil.rmtree(path, ignore_errors=True)
2649def _get_first_part(path):
2650 # type: (Text) -> Text
2651 parts = path.split(os.sep, 1)
2652 return parts[0]