| # Copyright 2017 The ChromiumOS Authors |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| |
| """Test the retry_util module.""" |
| |
| import functools |
| import itertools |
| import os |
| import sys |
| import time |
| from unittest import mock |
| |
| from chromite.lib import cros_build_lib |
| from chromite.lib import cros_test_lib |
| from chromite.lib import osutils |
| from chromite.lib import retry_util |
| |
| |
class TestRetries(cros_test_lib.MockTempDirTestCase):
    """Unit tests for GenericRetry and the helpers built on top of it."""

    def testWithRetrySuccess(self) -> None:
        """A decorated function that succeeds returns its value unchanged."""

        @retry_util.WithRetry(max_retry=3)
        def _succeed():
            return 10

        result = _succeed()
        self.assertEqual(10, result)

    def testWithRetrySuccessAfterRetry(self) -> None:
        """A decorated function may fail a few times before succeeding."""
        attempts = itertools.count()

        @retry_util.WithRetry(max_retry=3)
        def _fail_twice_then_succeed():
            # Attempts 0 and 1 fail; the third attempt succeeds.
            if next(attempts) < 2:
                raise Exception()
            return 10

        result = _fail_twice_then_succeed()
        self.assertEqual(10, result)

    def testWithRetryFail(self) -> None:
        """A decorated function that always fails propagates its error."""

        @retry_util.WithRetry(max_retry=3)
        def _always_fail() -> None:
            raise Exception("Retry fail")

        with self.assertRaisesRegex(Exception, "Retry fail"):
            _always_fail()

    def testGenericRetry(self) -> None:
        """Check retry semantics and what |status_callback| is told."""
        next_value = functools.partial(next, iter(range(5)))

        def _TestMain():
            # TODO(b/236161656): Fix.
            # pylint: disable-next=assignment-from-no-return
            value = next_value()
            if value >= 4:
                return value
            raise ValueError()

        def _IsValueError(ex) -> bool:
            return isinstance(ex, ValueError)

        statuses = []

        def _RecordStatus(*status) -> None:
            statuses.append(status)

        # Values 0-3 all fail, so the initial try plus three retries is not
        # enough and the ValueError escapes.
        with self.assertRaises(ValueError):
            retry_util.GenericRetry(
                _IsValueError, 3, _TestMain, status_callback=_RecordStatus
            )
        self.assertEqual(
            statuses, [(0, False), (1, False), (2, False), (3, False)]
        )

        # The next value is 4, which succeeds on the very first attempt.
        statuses.clear()
        outcome = retry_util.GenericRetry(
            _IsValueError, 1, _TestMain, status_callback=_RecordStatus
        )
        self.assertEqual(4, outcome)
        self.assertEqual(statuses, [(0, True)])

        # The iterator is exhausted now. The handler rejects StopIteration,
        # so only a single failed attempt is reported.
        statuses.clear()
        with self.assertRaises(StopIteration):
            retry_util.GenericRetry(
                _IsValueError, 3, _TestMain, status_callback=_RecordStatus
            )
        self.assertEqual(statuses, [(0, False)])

    def testGenericRetryBadArgs(self) -> None:
        """Invalid retry arguments to GenericRetry raise ValueError."""

        def _AlwaysRaise() -> None:
            raise Exception("Not a ValueError")

        def _AlwaysRetry(_) -> bool:
            return True

        # Each entry is (max_retry, extra keyword arguments).
        invalid_calls = (
            # |max_retry| must be a non-negative number.
            (-1, {}),
            # |backoff_factor| must be 1 or greater.
            (3, {"backoff_factor": 0.9}),
            # |sleep| must be a non-negative number.
            (3, {"sleep": -1}),
        )
        for max_retry, extra_kwargs in invalid_calls:
            with self.assertRaises(ValueError):
                retry_util.GenericRetry(
                    _AlwaysRetry, max_retry, _AlwaysRaise, **extra_kwargs
                )

    def testRaisedException(self) -> None:
        """Check which exception is raised after repeated failures."""

        def _GetTestMain():
            """Build a func raising ValueError once, then AssertionError."""
            calls = itertools.count()

            def _TestMain() -> None:
                error_cls = ValueError if next(calls) == 0 else AssertionError
                raise error_cls()

            return _TestMain

        # By default the first exception seen (ValueError) is raised.
        with self.assertRaises(ValueError):
            retry_util.GenericRetry(lambda _: True, 3, _GetTestMain())

        # With raise_first_exception_on_failure=False the later
        # AssertionError surfaces instead.
        with self.assertRaises(AssertionError):
            retry_util.GenericRetry(
                lambda _: True,
                3,
                _GetTestMain(),
                raise_first_exception_on_failure=False,
            )

    class CheckException(Exception):
        """Marker exception raised by _RaiseCheckException."""

    def _RaiseCheckException(self, *_) -> None:
        """Status callback that ignores its arguments and always raises."""
        raise self.CheckException()

    def testStatustCallbackExceptionForSuccess(self) -> None:
        """Exception from |status_callback| should be raised even on success."""
        with self.assertRaises(self.CheckException):
            retry_util.GenericRetry(
                lambda _: True,
                1,
                lambda: None,
                status_callback=self._RaiseCheckException,
            )

    def testStatusCallbackExceptionForRetry(self) -> None:
        """Exception from |status_callback| should stop retry."""
        call_count = 0  # How many times _TestMain has been invoked.

        def _TestMain() -> None:
            nonlocal call_count
            call_count += 1
            raise Exception()  # Always fail so a retry would be attempted.

        with self.assertRaises(self.CheckException):
            retry_util.GenericRetry(
                lambda _: True,
                10,
                _TestMain,
                status_callback=self._RaiseCheckException,
            )
        # No retry is expected once |status_callback| raises an exception.
        self.assertEqual(call_count, 1)

    def testRetryExceptionBadArgs(self) -> None:
        """Verify we reject non-classes or tuples of classes."""
        for bad_exception_spec in ("", 123, None, [None]):
            with self.assertRaises(TypeError):
                retry_util.RetryException(bad_exception_spec, 3, map)

    def testRetryException(self) -> None:
        """Verify we retry only when certain exceptions get thrown."""
        next_value = functools.partial(next, iter(range(6)))

        def _TestMain():
            # TODO(b/236161656): Fix.
            # pylint: disable-next=assignment-from-no-return
            value = next_value()
            if value < 2:
                raise OSError()
            if value < 5:
                raise ValueError()
            return value

        retryable = (OSError, ValueError)
        # Consumes 0, 1 (OSError) and 2 (ValueError).
        with self.assertRaises(OSError):
            retry_util.RetryException(retryable, 2, _TestMain)
        # Consumes 3 and 4, both ValueError.
        with self.assertRaises(ValueError):
            retry_util.RetryException(retryable, 1, _TestMain)
        # Value 5 is returned without any failure.
        self.assertEqual(5, retry_util.RetryException(ValueError, 1, _TestMain))
        # The iterator is exhausted; StopIteration is not in the retry set.
        with self.assertRaises(StopIteration):
            retry_util.RetryException(ValueError, 3, _TestMain)

    def testRetryWithBackoff(self) -> None:
        """Sleep durations double per retry when |backoff_factor| is 2."""
        recorded_sleeps = []
        self.PatchObject(time, "sleep", new=recorded_sleeps.append)

        def _AlwaysFail() -> None:
            raise ValueError()

        with self.assertRaises(ValueError):
            retry_util.GenericRetry(
                lambda _: True, 5, _AlwaysFail, sleep=1, backoff_factor=2
            )

        self.assertEqual(recorded_sleeps, [1, 2, 4, 8, 16])

    def testBasicRetry(self) -> None:
        """Exercise RunCommandWithRetries against a counting helper script."""
        script = os.path.join(self.tempdir, "script")
        paths = {
            "stop": os.path.join(self.tempdir, "stop"),
            "store": os.path.join(self.tempdir, "store"),
        }
        # The script prints the stored counter, bumps it, and exits 0 only
        # when the printed value equals the stop value.
        script_template = (
            "import sys\n"
            "val = int(open(%(store)r).read())\n"
            "stop_val = int(open(%(stop)r).read())\n"
            "open(%(store)r, 'w').write(str(val + 1))\n"
            "print(val)\n"
            "sys.exit(0 if val == stop_val else 1)\n"
        )
        osutils.WriteFile(script, script_template % paths)
        os.chmod(script, 0o755)

        sleep_mock = self.PatchObject(time, "sleep")

        def _SetupCounters(start, stop) -> None:
            """Forget earlier sleeps and seed the counter and stop files."""
            sleep_mock.reset_mock()
            osutils.WriteFile(paths["stop"], str(stop))
            osutils.WriteFile(paths["store"], str(start))

        def _AssertCounters(sleep, sleep_cnt) -> None:
            """Check sleeps of sleep * 1 .. sleep * sleep_cnt were made."""
            expected_calls = [
                mock.call(float(sleep * n)) for n in range(1, sleep_cnt + 1)
            ]
            sleep_mock.assert_has_calls(expected_calls)

        command = [sys.executable, script]
        run_kwargs = {"stdout": True, "print_cmd": False}

        # Sanity check: a direct run succeeds when start equals stop.
        _SetupCounters(0, 0)
        result = cros_build_lib.run(command, **run_kwargs)
        self.assertEqual(result.stdout, b"0\n")
        _AssertCounters(0, 0)

        func = retry_util.RunCommandWithRetries

        # Each entry is (start, stop, max_retry, sleep, expected stdout).
        # In every case the command needs exactly |max_retry| retries.
        passing_cases = (
            (2, 2, 0, 0, b"2\n"),
            (0, 2, 2, 1, b"2\n"),
            (0, 1, 1, 2, b"1\n"),
        )
        for start, stop, max_retry, sleep, expected_stdout in passing_cases:
            _SetupCounters(start, stop)
            result = func(max_retry, command, sleep=sleep, **run_kwargs)
            self.assertEqual(result.stdout, expected_stdout)
            _AssertCounters(sleep, max_retry)

        # Two retries cannot reach a stop value of 3, so the error escapes.
        _SetupCounters(0, 3)
        with self.assertRaises(cros_build_lib.RunCommandError):
            func(2, command, sleep=3, **run_kwargs)
        _AssertCounters(3, 2)