# Copyright 2017 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Test the retry_util module."""

import functools
import itertools
import os
import sys
import time
from unittest import mock

from chromite.lib import cros_build_lib
from chromite.lib import cros_test_lib
from chromite.lib import osutils
from chromite.lib import retry_util


class TestRetries(cros_test_lib.MockTempDirTestCase):
    """Tests of GenericRetry and relatives.

    Covers the WithRetry decorator, GenericRetry, RetryException and
    RunCommandWithRetries from retry_util.
    """

    def testWithRetrySuccess(self) -> None:
        """Test basic retry success case."""

        @retry_util.WithRetry(max_retry=3)
        def _run():
            return 10

        self.assertEqual(10, _run())

    def testWithRetrySuccessAfterRetry(self) -> None:
        """Test basic retry success case, but failed at least once."""
        counter = itertools.count()

        @retry_util.WithRetry(max_retry=3)
        def _run():
            current = next(counter)
            # Failed twice, then success.
            if current < 2:
                raise Exception()
            return 10

        self.assertEqual(10, _run())

    def testWithRetryFail(self) -> None:
        """Test basic retry fail case."""

        @retry_util.WithRetry(max_retry=3)
        def _run() -> None:
            raise Exception("Retry fail")

        with self.assertRaisesRegex(Exception, "Retry fail"):
            _run()

    def testGenericRetry(self) -> None:
        """Test basic semantics of retry and success recording."""
        # One shared iterator yields 0..4 across all three GenericRetry calls
        # below, so the order of those calls matters.
        source = functools.partial(next, iter(range(5)))

        def _TestMain():
            # TODO(b/236161656): Fix.
            # pylint: disable-next=assignment-from-no-return
            val = source()
            if val < 4:
                raise ValueError()
            return val

        def _IsValueError(ex) -> bool:
            """Retry handler: only ValueError is considered retriable."""
            return isinstance(ex, ValueError)

        # Values 0..3 all fail; with max_retry=3 that is four attempts, each
        # reported to the status callback as a failure.
        callback_args = []
        with self.assertRaises(ValueError):
            retry_util.GenericRetry(
                _IsValueError,
                3,
                _TestMain,
                status_callback=lambda *args: callback_args.append(args),
            )
        self.assertEqual(
            callback_args, [(0, False), (1, False), (2, False), (3, False)]
        )

        # The next value is 4, so the very first attempt succeeds.
        callback_args = []
        self.assertEqual(
            4,
            retry_util.GenericRetry(
                _IsValueError,
                1,
                _TestMain,
                status_callback=lambda *args: callback_args.append(args),
            ),
        )
        self.assertEqual(callback_args, [(0, True)])

        # The iterator is now exhausted; StopIteration is rejected by the
        # handler, so it propagates after a single attempt.
        callback_args = []
        with self.assertRaises(StopIteration):
            retry_util.GenericRetry(
                _IsValueError,
                3,
                _TestMain,
                status_callback=lambda *args: callback_args.append(args),
            )
        self.assertEqual(callback_args, [(0, False)])

    def testGenericRetryBadArgs(self) -> None:
        """Test bad retry related arguments to GenericRetry raise ValueError."""

        def _AlwaysRaise() -> None:
            raise Exception("Not a ValueError")

        # |max_retry| must be non-negative number.
        with self.assertRaises(ValueError):
            retry_util.GenericRetry(lambda _: True, -1, _AlwaysRaise)

        # |backoff_factor| must be 1 or greater.
        with self.assertRaises(ValueError):
            retry_util.GenericRetry(
                lambda _: True, 3, _AlwaysRaise, backoff_factor=0.9
            )

        # Sleep must be non-negative number.
        with self.assertRaises(ValueError):
            retry_util.GenericRetry(lambda _: True, 3, _AlwaysRaise, sleep=-1)

    def testRaisedException(self) -> None:
        """Test which exception gets raised by repeated failure."""

        def _GetTestMain():
            """Get func that fails once with ValueError, then AssertionError."""
            source = itertools.count()

            def _TestMain() -> None:
                if next(source) == 0:
                    raise ValueError()
                raise AssertionError()

            return _TestMain

        # By default the first exception seen is the one re-raised.
        with self.assertRaises(ValueError):
            retry_util.GenericRetry(lambda _: True, 3, _GetTestMain())

        # With raise_first_exception_on_failure=False a later one is raised.
        with self.assertRaises(AssertionError):
            retry_util.GenericRetry(
                lambda _: True,
                3,
                _GetTestMain(),
                raise_first_exception_on_failure=False,
            )

    class CheckException(Exception):
        """Exception thrown from the below function."""

    def _RaiseCheckException(self, *_) -> None:
        """Status callback that ignores its arguments and always raises."""
        raise TestRetries.CheckException()

    # NOTE(review): "Statust" is a typo for "Status"; the name is kept so that
    # existing test IDs/filters keep matching.
    def testStatustCallbackExceptionForSuccess(self) -> None:
        """Exception from |status_callback| should be raised even on success."""
        with self.assertRaises(TestRetries.CheckException):
            retry_util.GenericRetry(
                lambda _: True,
                1,
                lambda: None,
                status_callback=self._RaiseCheckException,
            )

    def testStatusCallbackExceptionForRetry(self) -> None:
        """Exception from |status_callback| should stop retry."""
        call_count = 0  # Tracks how many times _TestMain is called.

        def _TestMain() -> None:
            nonlocal call_count
            call_count += 1
            raise Exception()  # Let it fail.

        with self.assertRaises(TestRetries.CheckException):
            retry_util.GenericRetry(
                lambda _: True,
                10,
                _TestMain,
                status_callback=self._RaiseCheckException,
            )
        # Do not expect retry in case |status_callback| raises an exception.
        self.assertEqual(call_count, 1)

    def testRetryExceptionBadArgs(self) -> None:
        """Verify we reject non-classes or tuples of classes"""
        with self.assertRaises(TypeError):
            retry_util.RetryException("", 3, map)
        with self.assertRaises(TypeError):
            retry_util.RetryException(123, 3, map)
        with self.assertRaises(TypeError):
            retry_util.RetryException(None, 3, map)
        with self.assertRaises(TypeError):
            retry_util.RetryException([None], 3, map)

    def testRetryException(self) -> None:
        """Verify we retry only when certain exceptions get thrown"""
        # One shared iterator yields 0..5 across all four RetryException calls
        # below, so the order of those calls matters.
        source = functools.partial(next, iter(range(6)))

        def _TestMain():
            # TODO(b/236161656): Fix.
            # pylint: disable-next=assignment-from-no-return
            val = source()
            if val < 2:
                raise OSError()
            if val < 5:
                raise ValueError()
            return val

        # Consumes 0, 1, 2: OSError, OSError, ValueError.  The first exception
        # (OSError) is the one raised.
        with self.assertRaises(OSError):
            retry_util.RetryException((OSError, ValueError), 2, _TestMain)
        # Consumes 3, 4: both ValueError.
        with self.assertRaises(ValueError):
            retry_util.RetryException((OSError, ValueError), 1, _TestMain)
        # Consumes 5: succeeds on the first attempt.
        self.assertEqual(5, retry_util.RetryException(ValueError, 1, _TestMain))
        # Iterator exhausted: StopIteration is not in the retry set.
        with self.assertRaises(StopIteration):
            retry_util.RetryException(ValueError, 3, _TestMain)

    def testRetryWithBackoff(self) -> None:
        """Test that |backoff_factor| multiplies the sleep between retries."""
        # Record the requested sleep durations instead of actually sleeping.
        sleep_history = []
        self.PatchObject(time, "sleep", new=sleep_history.append)

        def _AlwaysFail() -> None:
            raise ValueError()

        with self.assertRaises(ValueError):
            retry_util.GenericRetry(
                lambda _: True, 5, _AlwaysFail, sleep=1, backoff_factor=2
            )

        # Five retries, each sleeping twice as long as the previous one.
        self.assertEqual(sleep_history, [1, 2, 4, 8, 16])

    def testBasicRetry(self) -> None:
        """Test RunCommandWithRetries using a script that fails until ready.

        The helper script reads an integer from the "store" file, writes back
        that value plus one, prints the value it read, and exits 0 only when
        the value equals the integer in the "stop" file.
        """
        path = os.path.join(self.tempdir, "script")
        paths = {
            "stop": os.path.join(self.tempdir, "stop"),
            "store": os.path.join(self.tempdir, "store"),
        }
        osutils.WriteFile(
            path,
            "import sys\n"
            "val = int(open(%(store)r).read())\n"
            "stop_val = int(open(%(stop)r).read())\n"
            "open(%(store)r, 'w').write(str(val + 1))\n"
            "print(val)\n"
            "sys.exit(0 if val == stop_val else 1)\n" % paths,
        )

        os.chmod(path, 0o755)

        # Patch before defining the helpers below, which close over the mock.
        sleep_mock = self.PatchObject(time, "sleep")

        def _SetupCounters(start, stop) -> None:
            """Reset the sleep mock and (re)write the store/stop files."""
            sleep_mock.reset_mock()
            osutils.WriteFile(paths["store"], str(start))
            osutils.WriteFile(paths["stop"], str(stop))

        def _AssertCounters(sleep, sleep_cnt) -> None:
            """Check sleeps of sleep * 1, ..., sleep * sleep_cnt were made.

            NOTE(review): assert_has_calls tolerates extra calls, so this is
            vacuous when |sleep_cnt| is 0.
            """
            calls = [
                mock.call(float(sleep * (x + 1))) for x in range(sleep_cnt)
            ]
            sleep_mock.assert_has_calls(calls)

        # Confirm the script itself works when run directly, without retries.
        _SetupCounters(0, 0)
        command = [sys.executable, path]
        kwargs = {"stdout": True, "print_cmd": False}
        self.assertEqual(cros_build_lib.run(command, **kwargs).stdout, b"0\n")
        _AssertCounters(0, 0)

        func = retry_util.RunCommandWithRetries

        # Succeeds immediately with no retries allowed.
        _SetupCounters(2, 2)
        self.assertEqual(func(0, command, sleep=0, **kwargs).stdout, b"2\n")
        _AssertCounters(0, 0)

        # Fails twice (0, 1), then succeeds on the last permitted attempt.
        _SetupCounters(0, 2)
        self.assertEqual(func(2, command, sleep=1, **kwargs).stdout, b"2\n")
        _AssertCounters(1, 2)

        # Fails once (0), then succeeds.
        _SetupCounters(0, 1)
        self.assertEqual(func(1, command, sleep=2, **kwargs).stdout, b"1\n")
        _AssertCounters(2, 1)

        # Three attempts (0, 1, 2) never reach the stop value of 3.
        _SetupCounters(0, 3)
        with self.assertRaises(cros_build_lib.RunCommandError):
            func(2, command, sleep=3, **kwargs)
        _AssertCounters(3, 2)
