# blob: 1be2801a632616970eb0f67130d416c8e39c99cc [file] [log] [blame]
# Copyright 2017 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Test the retry_util module."""
import functools
import itertools
import os
import sys
import time
from unittest import mock
from chromite.lib import cros_build_lib
from chromite.lib import cros_test_lib
from chromite.lib import osutils
from chromite.lib import retry_util
class TestRetries(cros_test_lib.MockTempDirTestCase):
    """Exercise GenericRetry and the helpers layered on top of it."""

    def testWithRetrySuccess(self):
        """A decorated function that succeeds right away returns its value."""

        @retry_util.WithRetry(max_retry=3)
        def _Succeed():
            return 10

        result = _Succeed()
        self.assertEqual(10, result)

    def testWithRetrySuccessAfterRetry(self):
        """A decorated function may fail a couple of times, then succeed."""
        attempts = itertools.count()

        @retry_util.WithRetry(max_retry=3)
        def _FailTwiceThenSucceed():
            # Invocations numbered 0 and 1 fail; the third one returns.
            if next(attempts) < 2:
                raise Exception()
            return 10

        self.assertEqual(10, _FailTwiceThenSucceed())

    def testWithRetryFail(self):
        """A decorated function that always fails propagates its exception."""

        @retry_util.WithRetry(max_retry=3)
        def _AlwaysFail():
            raise Exception("Retry fail")

        with self.assertRaisesRegex(Exception, "Retry fail"):
            _AlwaysFail()

    def testGenericRetry(self):
        """Check retry semantics and what |status_callback| is told."""
        # Hands out 0..4 across all the phases below, then StopIteration.
        next_value = functools.partial(next, iter(range(5)))

        def _TestMain():
            # TODO(b/236161656): Fix.
            # pylint: disable-next=assignment-from-no-return
            current = next_value()
            if current >= 4:
                return current
            raise ValueError()

        def _IsValueError(exc):
            return isinstance(exc, ValueError)

        statuses = []

        def _RecordStatus(*args):
            statuses.append(args)

        # Phase 1: values 0..3 all raise ValueError, using up every attempt.
        with self.assertRaises(ValueError):
            retry_util.GenericRetry(
                _IsValueError,
                3,
                _TestMain,
                status_callback=_RecordStatus,
            )
        self.assertEqual(
            statuses, [(0, False), (1, False), (2, False), (3, False)]
        )

        # Phase 2: the next value is 4, so the first attempt succeeds.
        statuses.clear()
        outcome = retry_util.GenericRetry(
            _IsValueError,
            1,
            _TestMain,
            status_callback=_RecordStatus,
        )
        self.assertEqual(4, outcome)
        self.assertEqual(statuses, [(0, True)])

        # Phase 3: the iterator is exhausted; StopIteration is not a
        # ValueError, so the handler rejects it.
        statuses.clear()
        with self.assertRaises(StopIteration):
            retry_util.GenericRetry(
                _IsValueError,
                3,
                _TestMain,
                status_callback=_RecordStatus,
            )
        self.assertEqual(statuses, [(0, False)])

    def testGenericRetryBadArgs(self):
        """Invalid retry parameters must be rejected with ValueError."""

        def _AlwaysRaise():
            raise Exception("Not a ValueError")

        def _RetryAll(_exc):
            return True

        # |max_retry| must be a non-negative number.
        with self.assertRaises(ValueError):
            retry_util.GenericRetry(_RetryAll, -1, _AlwaysRaise)

        # |backoff_factor| must be 1 or greater.
        with self.assertRaises(ValueError):
            retry_util.GenericRetry(
                _RetryAll, 3, _AlwaysRaise, backoff_factor=0.9
            )

        # |sleep| must be a non-negative number.
        with self.assertRaises(ValueError):
            retry_util.GenericRetry(_RetryAll, 3, _AlwaysRaise, sleep=-1)

    def testRaisedException(self):
        """Check which exception surfaces once retries are exhausted."""

        def _GetTestMain():
            """Build a func raising ValueError once, then AssertionError."""
            invocations = itertools.count()

            def _TestMain():
                if next(invocations) != 0:
                    raise AssertionError()
                raise ValueError()

            return _TestMain

        def _RetryAll(_exc):
            return True

        # With default settings the ValueError from the first call is raised.
        with self.assertRaises(ValueError):
            retry_util.GenericRetry(_RetryAll, 3, _GetTestMain())

        # With the flag off, the AssertionError from later calls is raised.
        with self.assertRaises(AssertionError):
            retry_util.GenericRetry(
                _RetryAll,
                3,
                _GetTestMain(),
                raise_first_exception_on_failure=False,
            )

    class CheckException(Exception):
        """Marker exception raised by _RaiseCheckException below."""

    def _RaiseCheckException(self, *_):
        """Status callback stand-in: ignore the arguments and always raise."""
        raise TestRetries.CheckException()

    def testStatustCallbackExceptionForSuccess(self):
        """A |status_callback| exception propagates even if functor succeeds."""
        with self.assertRaises(TestRetries.CheckException):
            retry_util.GenericRetry(
                lambda _: True,
                1,
                lambda: None,
                status_callback=self._RaiseCheckException,
            )

    def testStatusCallbackExceptionForRetry(self):
        """A |status_callback| exception aborts any further retrying."""
        call_count = 0  # How many times the functor has been invoked.

        def _TestMain():
            nonlocal call_count
            call_count += 1
            raise Exception()  # Let it fail.

        with self.assertRaises(TestRetries.CheckException):
            retry_util.GenericRetry(
                lambda _: True,
                10,
                _TestMain,
                status_callback=self._RaiseCheckException,
            )

        # No retry is expected once |status_callback| raises.
        self.assertEqual(call_count, 1)

    def testRetryExceptionBadArgs(self):
        """Verify we reject non-classes or tuples of classes."""
        for bad_spec in ("", 123, None, [None]):
            with self.assertRaises(
                TypeError, msg=f"accepted exception spec {bad_spec!r}"
            ):
                retry_util.RetryException(bad_spec, 3, map)

    def testRetryException(self):
        """Verify we retry only when certain exceptions get thrown."""
        # Hands out 0..5 across all the calls below, then StopIteration.
        next_value = functools.partial(next, iter(range(6)))

        def _TestMain():
            # TODO(b/236161656): Fix.
            # pylint: disable-next=assignment-from-no-return
            current = next_value()
            if current < 2:
                raise OSError()
            if current < 5:
                raise ValueError()
            return current

        retry_on_both = (OSError, ValueError)

        with self.assertRaises(OSError):
            retry_util.RetryException(retry_on_both, 2, _TestMain)

        with self.assertRaises(ValueError):
            retry_util.RetryException(retry_on_both, 1, _TestMain)

        outcome = retry_util.RetryException(ValueError, 1, _TestMain)
        self.assertEqual(5, outcome)

        with self.assertRaises(StopIteration):
            retry_util.RetryException(ValueError, 3, _TestMain)

    def testRetryWithBackoff(self):
        """Check the sleep durations requested when |backoff_factor| is 2."""
        recorded_sleeps = []
        # Capture each requested duration instead of actually sleeping.
        self.PatchObject(time, "sleep", new=recorded_sleeps.append)

        def _AlwaysFail():
            raise ValueError()

        with self.assertRaises(ValueError):
            retry_util.GenericRetry(
                lambda _: True, 5, _AlwaysFail, sleep=1, backoff_factor=2
            )

        self.assertEqual(recorded_sleeps, [1, 2, 4, 8, 16])

    def testBasicRetry(self):
        """Run a counting helper script through RunCommandWithRetries."""
        path = os.path.join(self.tempdir, "script")
        paths = {
            name: os.path.join(self.tempdir, name)
            for name in ("stop", "store")
        }

        # The script prints the stored counter, bumps it on disk, and exits 0
        # only when the printed value equals the stop value.
        script_template = (
            "import sys\n"
            "val = int(open(%(store)r).read())\n"
            "stop_val = int(open(%(stop)r).read())\n"
            "open(%(store)r, 'w').write(str(val + 1))\n"
            "print(val)\n"
            "sys.exit(0 if val == stop_val else 1)\n"
        )
        osutils.WriteFile(path, script_template % paths)
        os.chmod(path, 0o755)

        sleep_mock = self.PatchObject(time, "sleep")

        def _SetupCounters(start, stop):
            """Forget recorded sleeps and seed the script's two files."""
            sleep_mock.reset_mock()
            osutils.WriteFile(paths["store"], str(start))
            osutils.WriteFile(paths["stop"], str(stop))

        def _AssertCounters(sleep, sleep_cnt):
            """Check the mock saw sleeps of sleep * 1 .. sleep * sleep_cnt."""
            expected_calls = [
                mock.call(float(sleep * attempt))
                for attempt in range(1, sleep_cnt + 1)
            ]
            sleep_mock.assert_has_calls(expected_calls)

        command = [sys.executable, path]
        kwargs = {"stdout": True, "print_cmd": False}

        # Sanity check the script itself with a plain run().
        _SetupCounters(0, 0)
        plain_result = cros_build_lib.run(command, **kwargs)
        self.assertEqual(plain_result.stdout, b"0\n")
        _AssertCounters(0, 0)

        func = retry_util.RunCommandWithRetries

        # (start, stop, max_retry, sleep, expected sleep count, stdout)
        passing_cases = (
            (2, 2, 0, 0, 0, b"2\n"),
            (0, 2, 2, 1, 2, b"2\n"),
            (0, 1, 1, 2, 1, b"1\n"),
        )
        for start, stop, max_retry, sleep, sleep_cnt, output in passing_cases:
            _SetupCounters(start, stop)
            result = func(max_retry, command, sleep=sleep, **kwargs)
            self.assertEqual(result.stdout, output)
            _AssertCounters(sleep, sleep_cnt)

        # Three runs (values 0, 1, 2) never reach the stop value of 3.
        _SetupCounters(0, 3)
        with self.assertRaises(cros_build_lib.RunCommandError):
            func(2, command, sleep=3, **kwargs)
        _AssertCounters(3, 2)