# -*- coding: utf-8 -*-
# Copyright 2016 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Module that contains unittests for auth module."""

from __future__ import print_function

import sys
import time

import mock

from chromite.lib import auth
from chromite.lib import cros_test_lib


# Refuse to load under interpreters older than Python 3.6.
assert sys.version_info[:2] >= (3, 6), 'This module requires Python 3.6+'


class AuthTest(cros_test_lib.RunCommandTestCase):
  """Tests for the Login/Token/GitCreds helpers exposed by auth."""

  def setUp(self):
    # Patch out time.sleep and make the binary lookups return fixed names.
    self.PatchObject(time, 'sleep')
    for getter, binary in (('GetLuciAuth', 'luci-auth'),
                           ('GetLuciGitCreds', 'git-credential-luci')):
      self.PatchObject(auth, getter, return_value=binary)

  def testLoginFailed(self):
    """Login raises AccessTokenError when the command returns 1."""
    login_cmd = ['luci-auth', 'login']
    self.rc.AddCmdResult(login_cmd, stderr='', returncode=1)
    with self.assertRaises(auth.AccessTokenError):
      auth.Login()

  def testLoginPassed(self):
    """Login returns None when the command succeeds."""
    login_cmd = ['luci-auth', 'login']
    self.rc.AddCmdResult(login_cmd, stdout=None)
    outcome = auth.Login()
    self.assertIsNone(outcome)

  def testTokenFailed(self):
    """Token raises AccessTokenError when the command returns 1."""
    token_cmd = ['luci-auth', 'token']
    self.rc.AddCmdResult(token_cmd, stderr='', returncode=1)
    with self.assertRaises(auth.AccessTokenError):
      auth.Token()

  def testTokenPassed(self):
    """Token returns the stdout of the command."""
    token_cmd = ['luci-auth', 'token']
    self.rc.AddCmdResult(token_cmd, stdout='token')
    fetched = auth.Token()
    self.assertEqual(fetched, 'token')

  def testTokenAndLoginIfNeed(self):
    """Test _TokenAndLoginIfNeed with working and failing Token calls."""
    # pylint: disable=protected-access
    key_path = '/tmp/service_account.json'
    expected = {'service_account_json': key_path}

    login = self.PatchObject(auth, 'Login', return_value=None)
    token = self.PatchObject(auth, 'Token', return_value='token')

    # Token works and no renewal is forced: Login must not be called.
    auth._TokenAndLoginIfNeed(service_account_json=key_path)
    token.assert_called_once_with(**expected)
    login.assert_not_called()

    # Forced renewal: Token and Login are each called exactly once.
    login.reset_mock()
    token.reset_mock()
    auth._TokenAndLoginIfNeed(
        service_account_json=key_path, force_token_renew=True)
    token.assert_called_once_with(**expected)
    login.assert_called_once_with(**expected)

    # From here on Token always raises; the error must reach the caller.
    login.reset_mock()
    token = self.PatchObject(
        auth, 'Token', side_effect=auth.AccessTokenError())
    with self.assertRaises(auth.AccessTokenError):
      auth._TokenAndLoginIfNeed(service_account_json=key_path)
    token.assert_called_once_with(**expected)
    login.assert_called_once_with(**expected)

    # Same failing Token, this time with a forced renewal.
    login.reset_mock()
    token.reset_mock()
    with self.assertRaises(auth.AccessTokenError):
      auth._TokenAndLoginIfNeed(
          service_account_json=key_path, force_token_renew=True)
    token.assert_called_once_with(**expected)
    login.assert_called_once_with(**expected)

  def testGetAccessToken(self):
    """Test GetAccessToken call counts when Token always fails."""
    login = self.PatchObject(auth, 'Login', return_value=None)
    token = self.PatchObject(auth, 'Token',
                             side_effect=auth.AccessTokenError())

    def _CallAndCheck(**kwargs):
      """Call GetAccessToken and verify the Login/Token call counts."""
      auth.GetAccessToken(**kwargs)
      self.assertEqual(login.call_count, auth.RETRY_GET_ACCESS_TOKEN + 1)
      self.assertEqual(token.call_count, auth.RETRY_GET_ACCESS_TOKEN + 1)

    _CallAndCheck()

    login.reset_mock()
    token.reset_mock()
    _CallAndCheck(force_token_renew=True)

  def testGitCredsFailed(self):
    """GitCreds raises AccessTokenError when the command returns 1."""
    creds_cmd = ['git-credential-luci', 'get']
    self.rc.AddCmdResult(creds_cmd, stderr='', returncode=1)
    with self.assertRaises(auth.AccessTokenError):
      auth.GitCreds()

  def testGitCredsPassed(self):
    """GitCreds returns the value of the password= line."""
    lines = [
        'user=some-luci-user',
        'password=some-git-password',
    ]
    creds_cmd = ['git-credential-luci', 'get']
    self.rc.AddCmdResult(creds_cmd, stdout='\n'.join(lines))
    password = auth.GitCreds()
    self.assertEqual(password, 'some-git-password')

  def testGitCredsNoPassword(self):
    """GitCreds raises AccessTokenError on unrecognized output."""
    lines = [
        'unknown stdout format, line #1',
        'unknown stdout format, line #2',
    ]
    creds_cmd = ['git-credential-luci', 'get']
    self.rc.AddCmdResult(creds_cmd, stdout='\n'.join(lines))
    with self.assertRaises(auth.AccessTokenError):
      auth.GitCreds()


class AuthorizedHttp(cros_test_lib.MockTestCase):
  """Tests for auth.AuthorizedHttp driven through a mocked HTTP object."""

  def setUp(self):
    self.account_json = 'service_account_json'
    self.mock_http = mock.Mock()
    self.mock_resp = mock.Mock()
    # auth.GetAccessToken is replaced so token fetches can be counted.
    self.mock_get_token = self.PatchObject(
        auth, 'GetAccessToken', return_value='token')

  def _RequestEachMethod(self, status, **auth_kwargs):
    """Send one GET, PUT and POST through a new auth.AuthorizedHttp.

    Args:
      status: Value assigned to the status of the mocked HTTP response.
      **auth_kwargs: Extra keyword arguments passed to auth.AuthorizedHttp.
    """
    self.mock_resp.status = status
    self.mock_http.request.return_value = (self.mock_resp, 'content')

    client = auth.AuthorizedHttp(
        auth.GetAccessToken, self.mock_http, **auth_kwargs)
    for method in ('GET', 'PUT', 'POST'):
      client.request('url', method, body={}, headers={})

  def testAuthorize(self):
    """With status 200, expect 1 token fetch and 3 HTTP requests."""
    self._RequestEachMethod(200, service_account_json=self.account_json)

    self.assertEqual(1, self.mock_get_token.call_count)
    self.assertEqual(3, self.mock_http.request.call_count)
    self.mock_get_token.assert_called_with(
        service_account_json=self.account_json)

  def testAuthorize2(self):
    """With status 401, expect 4 token fetches and 6 HTTP requests."""
    self._RequestEachMethod(401, service_account_json=self.account_json)

    self.assertEqual(4, self.mock_get_token.call_count)
    self.assertEqual(6, self.mock_http.request.call_count)
    # The most recent token fetch must have asked for a renewal.
    self.mock_get_token.assert_called_with(
        service_account_json=self.account_json, force_token_renew=True)

  def testAuthorize3(self):
    """With status 500 and no extra kwargs, expect 1 bare token fetch."""
    self._RequestEachMethod(500)

    self.assertEqual(1, self.mock_get_token.call_count)
    self.mock_get_token.assert_called_with()
