# -*- coding: utf-8 -*-
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Test the retry_util module."""
from __future__ import print_function
import functools
import itertools
import os
import time
import sys
import mock
from chromite.lib import cros_build_lib
from chromite.lib import cros_test_lib
from chromite.lib import retry_util
from chromite.lib import osutils
class TestRetries(cros_test_lib.MockTempDirTestCase):
  """Tests of GenericRetry and relatives."""

  def testWithRetrySuccess(self):
    """Test basic retry success case."""
    @retry_util.WithRetry(max_retry=3)
    def _run():
      return 10
    self.assertEqual(10, _run())

  def testWithRetrySuccessAfterRetry(self):
    """Test basic retry success case, but failed at least once."""
    counter = itertools.count()

    @retry_util.WithRetry(max_retry=3)
    def _run():
      current = next(counter)
      # Failed twice, then success.
      if current < 2:
        raise Exception()
      return 10

    self.assertEqual(10, _run())
    # Calls 0 and 1 failed and call 2 succeeded, so the decorator must have
    # invoked the function exactly three times (the next value is 3).
    self.assertEqual(3, next(counter))

  def testWithRetryFail(self):
    """Test basic retry fail case."""
    @retry_util.WithRetry(max_retry=3)
    def _run():
      raise Exception('Retry fail')

    with self.assertRaisesRegex(Exception, 'Retry fail'):
      _run()

  def testGenericRetry(self):
    """Test basic semantics of retry and success recording."""
    source = functools.partial(next, iter(range(5)))

    def _TestMain():
      val = source()
      if val < 4:
        raise ValueError()
      return val

    handler = lambda ex: isinstance(ex, ValueError)

    # Values 0..3 all raise ValueError: 1 attempt + 3 retries, all failing.
    callback_args = []
    with self.assertRaises(ValueError):
      retry_util.GenericRetry(
          handler, 3, _TestMain,
          status_callback=lambda *args: callback_args.append(args))
    self.assertEqual(callback_args,
                     [(0, False), (1, False), (2, False), (3, False)])

    # Value 4 is returned on the first attempt.
    callback_args = []
    self.assertEqual(4, retry_util.GenericRetry(
        handler, 1, _TestMain,
        status_callback=lambda *args: callback_args.append(args)))
    self.assertEqual(callback_args, [(0, True)])

    # The source is exhausted; StopIteration is rejected by |handler|, so it
    # propagates immediately without a retry.
    callback_args = []
    with self.assertRaises(StopIteration):
      retry_util.GenericRetry(
          handler, 3, _TestMain,
          status_callback=lambda *args: callback_args.append(args))
    self.assertEqual(callback_args, [(0, False)])

  def testGenericRetryBadArgs(self):
    """Test bad retry related arguments to GenericRetry raise ValueError."""
    def _AlwaysRaise():
      raise Exception('Not a ValueError')

    # |max_retry| must be non-negative number.
    with self.assertRaises(ValueError):
      retry_util.GenericRetry(lambda _: True, -1, _AlwaysRaise)

    # |backoff_factor| must be 1 or greater.
    with self.assertRaises(ValueError):
      retry_util.GenericRetry(lambda _: True, 3, _AlwaysRaise,
                              backoff_factor=0.9)

    # Sleep must be non-negative number.
    with self.assertRaises(ValueError):
      retry_util.GenericRetry(lambda _: True, 3, _AlwaysRaise, sleep=-1)

  def testRaisedException(self):
    """Test which exception gets raised by repeated failure."""

    def _GetTestMain():
      """Get function that fails once with ValueError, then AssertionError."""
      source = itertools.count()

      def _TestMain():
        if next(source) == 0:
          raise ValueError()
        raise AssertionError()

      return _TestMain

    # By default the first exception is re-raised.
    with self.assertRaises(ValueError):
      retry_util.GenericRetry(lambda _: True, 3, _GetTestMain())

    # Opting out re-raises the last exception instead.
    with self.assertRaises(AssertionError):
      retry_util.GenericRetry(lambda _: True, 3, _GetTestMain(),
                              raise_first_exception_on_failure=False)

  class CheckException(Exception):
    """Exception thrown from the below function."""

  def _RaiseCheckException(self, *_):
    """Status callback that always raises CheckException."""
    raise TestRetries.CheckException()

  def testStatustCallbackExceptionForSuccess(self):
    """Exception from |status_callback| should be raised even on success."""
    with self.assertRaises(TestRetries.CheckException):
      retry_util.GenericRetry(lambda _: True, 1, lambda: None,
                              status_callback=self._RaiseCheckException)

  def testStatusCallbackExceptionForRetry(self):
    """Exception from |status_callback| should stop retry."""
    counter = [0]  # Counter to track how many times _functor is called.

    def _TestMain():
      counter[0] += 1
      raise Exception()  # Let it fail.

    with self.assertRaises(TestRetries.CheckException):
      retry_util.GenericRetry(lambda _: True, 10, _TestMain,
                              status_callback=self._RaiseCheckException)

    # Do not expect retry in case |status_callback| raises an exception.
    self.assertEqual(counter[0], 1)

  def testRetryExceptionBadArgs(self):
    """Verify we reject non-classes or tuples of classes"""
    with self.assertRaises(TypeError):
      retry_util.RetryException('', 3, map)
    with self.assertRaises(TypeError):
      retry_util.RetryException(123, 3, map)
    with self.assertRaises(TypeError):
      retry_util.RetryException(None, 3, map)
    with self.assertRaises(TypeError):
      retry_util.RetryException([None], 3, map)

  def testRetryException(self):
    """Verify we retry only when certain exceptions get thrown"""
    source = functools.partial(next, iter(range(6)))

    def _TestMain():
      val = source()
      if val < 2:
        raise OSError()
      if val < 5:
        raise ValueError()
      return val

    # Values 0, 1 (OSError) and 2 (ValueError); the first exception wins.
    with self.assertRaises(OSError):
      retry_util.RetryException((OSError, ValueError), 2, _TestMain)
    # Values 3 and 4 both raise ValueError.
    with self.assertRaises(ValueError):
      retry_util.RetryException((OSError, ValueError), 1, _TestMain)
    # Value 5 succeeds on the first attempt.
    self.assertEqual(5, retry_util.RetryException(ValueError, 1, _TestMain))
    # StopIteration is not a retried exception type, so it propagates.
    with self.assertRaises(StopIteration):
      retry_util.RetryException(ValueError, 3, _TestMain)

  def testRetryWithBackoff(self):
    """Verify sleep times grow geometrically with |backoff_factor|."""
    sleep_history = []
    self.PatchObject(time, 'sleep', new=sleep_history.append)

    def _AlwaysFail():
      raise ValueError()

    with self.assertRaises(ValueError):
      retry_util.GenericRetry(lambda _: True, 5, _AlwaysFail, sleep=1,
                              backoff_factor=2)

    self.assertEqual(sleep_history, [1, 2, 4, 8, 16])

  def testBasicRetry(self):
    """Verify RunCommandWithRetries against a script that fails N times."""
    path = os.path.join(self.tempdir, 'script')
    paths = {
        'stop': os.path.join(self.tempdir, 'stop'),
        'store': os.path.join(self.tempdir, 'store'),
    }
    # The script prints the counter held in |store|, increments it, and exits
    # 0 only when the printed value equals the value held in |stop|.
    osutils.WriteFile(
        path,
        'from __future__ import print_function\n'
        'import sys\n'
        'val = int(open(%(store)r).read())\n'
        'stop_val = int(open(%(stop)r).read())\n'
        "open(%(store)r, 'w').write(str(val + 1))\n"
        'print(val)\n'
        'sys.exit(0 if val == stop_val else 1)\n' % paths)
    os.chmod(path, 0o755)

    def _SetupCounters(start, stop):
      # Forget sleeps from earlier phases so each _AssertCounters call only
      # sees the sleeps caused by the command that follows this setup.
      sleep_mock.reset_mock()
      osutils.WriteFile(paths['store'], str(start))
      osutils.WriteFile(paths['stop'], str(stop))

    def _AssertCounters(sleep, sleep_cnt):
      # The expected sleeps are sleep*1, sleep*2, ..., sleep*sleep_cnt.
      calls = [mock.call(float(sleep * (x + 1))) for x in range(sleep_cnt)]
      sleep_mock.assert_has_calls(calls)

    sleep_mock = self.PatchObject(time, 'sleep')

    _SetupCounters(0, 0)
    command = [sys.executable, path]
    kwargs = {'stdout': True, 'print_cmd': False}
    self.assertEqual(cros_build_lib.run(command, **kwargs).output, b'0\n')
    _AssertCounters(0, 0)

    func = retry_util.RunCommandWithRetries

    _SetupCounters(2, 2)
    self.assertEqual(func(0, command, sleep=0, **kwargs).output, b'2\n')
    _AssertCounters(0, 0)

    _SetupCounters(0, 2)
    self.assertEqual(func(2, command, sleep=1, **kwargs).output, b'2\n')
    _AssertCounters(1, 2)

    _SetupCounters(0, 1)
    self.assertEqual(func(1, command, sleep=2, **kwargs).output, b'1\n')
    _AssertCounters(2, 1)

    _SetupCounters(0, 3)
    with self.assertRaises(cros_build_lib.RunCommandError):
      func(2, command, sleep=3, **kwargs)
    _AssertCounters(3, 2)