# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Test the remote_access module."""

from __future__ import print_function

import os
import re
import socket

from chromite.lib import cros_build_lib
from chromite.lib import cros_build_lib_unittest
from chromite.lib import cros_test_lib
from chromite.lib import debug_link
from chromite.lib import mdns
from chromite.lib import mdns_unittest
from chromite.lib import osutils
from chromite.lib import partial_mock
from chromite.lib import remote_access


# pylint: disable=W0212


class TestNormalizePort(cros_test_lib.TestCase):
  """Checks the port normalization done by remote_access.NormalizePort()."""

  def testNormalizePortStrOK(self):
    """A numeric string is converted to the matching integer."""
    port = remote_access.NormalizePort('123')
    self.assertEqual(port, 123)

  def testNormalizePortStrNotOK(self):
    """A string port raises ValueError when str_ok=False."""
    with self.assertRaises(ValueError):
      remote_access.NormalizePort('123', str_ok=False)

  def testNormalizePortOutOfRange(self):
    """Out-of-range ports (given as string or int) raise ValueError."""
    for bad_port in ('-1', 99999):
      with self.assertRaises(ValueError):
        remote_access.NormalizePort(bad_port)


class TestRemoveKnownHost(cros_test_lib.MockTempDirTestCase):
  """Exercises remote_access.RemoveKnownHost()."""

  # ssh-keygen does not validate hostnames, so pick one that cannot appear in
  # the user's real known_hosts; that way their file is never modified.
  _HOST = '0.0.0.0.0.0'

  # A single known_hosts entry (host + RSA public key) for _HOST.
  _HOST_KEY = (
      _HOST + ' ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCjysPTaDAtRaxRaW1JjqzCHp2'
      '88gvlUgtJxd2Jt/v63fkqZ5zzLLoeoAMwv0oYSRU82qhLimXpHxXRkrMC5nrpz5zJch+ktql'
      '0rSRgo+dqc1GzmyOOAq5NkQsgBb3hefxMxCZRV8Dv0n7qaindZRxE8MnRJmVUoj8Wq8wryab'
      'p+fUBkesBwaJhPXa4WBJeI5d+rO5tEBSNkvIp0USU6Ku3Ct0q2sZbOkY5g1VFAUYm4wyshCf'
      'oWvU8ivMFp0pCezMISGstKpkIQApq2dLUb6EmeIgnhHzZXOn7doxIGD33JUfFmwNi0qfk3vV'
      '6vKRVDEZD68+ix6gjKpicY5upA/9P\n')

  def testRemoveKnownHostDefaultFile(self):
    """RemoveKnownHost() succeeds against the default known_hosts file.

    Plain `ssh-keygen -R` fails inside the chroot because the default
    known_hosts file is bind mounted.
    """
    # Whether the host is actually listed is irrelevant; the only requirement
    # is that the call does not raise. Inside the chroot the default
    # known_hosts file always exists thanks to the bind mount.
    remote_access.RemoveKnownHost(self._HOST)

  def testRemoveKnownHostCustomFile(self):
    """RemoveKnownHost() strips the entry from a caller-supplied file."""
    known_hosts = os.path.join(self.tempdir, 'known_hosts')
    osutils.WriteFile(known_hosts, self._HOST_KEY)
    remote_access.RemoveKnownHost(self._HOST, known_hosts_path=known_hosts)
    remaining = osutils.ReadFile(known_hosts)
    self.assertEqual(remaining, '')

  def testRemoveKnownHostNonexistentFile(self):
    """RemoveKnownHost() does not raise when the file is missing."""
    missing_file = os.path.join(self.tempdir, 'known_hosts')
    remote_access.RemoveKnownHost(self._HOST, known_hosts_path=missing_file)


class TestCompileSSHConnectSettings(cros_test_lib.TestCase):
  """Checks the option list built by CompileSSHConnectSettings()."""

  def testCustomSettingIncluded(self):
    """A caller-supplied setting shows up as a -o option."""
    settings = remote_access.CompileSSHConnectSettings(
        NumberOfPasswordPrompts=100)
    self.assertIn('-oNumberOfPasswordPrompts=100', settings)

  def testNoneSettingOmitted(self):
    """Passing None for a default setting removes it from the output."""
    protocol_flag = '-oProtocol=2'
    default_settings = remote_access.CompileSSHConnectSettings()
    self.assertIn(protocol_flag, default_settings)
    without_protocol = remote_access.CompileSSHConnectSettings(Protocol=None)
    self.assertNotIn(protocol_flag, without_protocol)


class RemoteShMock(partial_mock.PartialCmdMock):
  """Partial mock that replaces RemoteAccess.RemoteSh."""
  TARGET = 'chromite.lib.remote_access.RemoteAccess'
  ATTRS = ('RemoteSh',)
  DEFAULT_ATTR = 'RemoteSh'

  def RemoteSh(self, inst, cmd, *args, **kwargs):
    """Stands in for RemoteSh, feeding it a pre-registered command result.

    The result registered for |cmd| is looked up, then the original RemoteSh
    is executed while RunCommand is mocked to produce that result.

    Args:
      inst: The RemoteAccess instance RemoteSh was invoked on.
      cmd: The command passed to RemoteSh.
      *args: Extra positional arguments forwarded to the original RemoteSh.
      **kwargs: Extra keyword arguments forwarded to the original RemoteSh.

    Returns:
      The CommandResult from the original RemoteSh, with an extra |rc_mock|
      member so callers can inspect the underlying RunCommand() invocation.
    """
    canned = self._results['RemoteSh'].LookupResult(
        (cmd,), hook_args=(inst, cmd) + args, hook_kwargs=kwargs)

    # Any RunCommand call made by the original RemoteSh yields the canned
    # values.
    run_mock = cros_build_lib_unittest.RunCommandMock()
    run_mock.AddCmdResult(
        partial_mock.Ignore(), canned.returncode, canned.output, canned.error)

    with run_mock:
      real_result = self.backup['RemoteSh'](inst, cmd, *args, **kwargs)
    real_result.rc_mock = run_mock
    return real_result


class RemoteDeviceMock(partial_mock.PartialMock):
  """Mocks the RemoteDevice function."""

  TARGET = 'chromite.lib.remote_access.RemoteDevice'
  ATTRS = ('Pingable',)

  def Pingable(self, _):
    """Replacement for RemoteDevice.Pingable that always returns True."""
    return True


class RemoteAccessTest(cros_test_lib.MockTempDirTestCase):
  """Common fixture for RemoteAccess tests: RemoteSh is mocked out."""

  def setUp(self):
    """Patches RemoteSh and builds the RemoteAccess object under test."""
    patcher = RemoteShMock()
    self.rsh_mock = self.StartPatcher(patcher)
    self.host = remote_access.RemoteAccess('foon', self.tempdir)


class RemoteShTest(RemoteAccessTest):
  """Tests of basic RemoteSh functions."""
  # Command and canned result used by every test in this class.
  TEST_CMD = 'ls'
  RETURN_CODE = 0
  OUTPUT = 'witty'
  ERROR = 'error'

  def assertRemoteShRaises(self, **kwargs):
    """Asserts that RunCommandError is raised when running TEST_CMD.

    Args:
      **kwargs: Keyword arguments forwarded to RemoteSh.
    """
    self.assertRaises(cros_build_lib.RunCommandError, self.host.RemoteSh,
                      self.TEST_CMD, **kwargs)

  def assertRemoteShRaisesSSHConnectionError(self, **kwargs):
    """Asserts that SSHConnectionError is raised when running TEST_CMD.

    Args:
      **kwargs: Keyword arguments forwarded to RemoteSh.
    """
    self.assertRaises(remote_access.SSHConnectionError, self.host.RemoteSh,
                      self.TEST_CMD, **kwargs)

  def SetRemoteShResult(self, returncode=RETURN_CODE, output=OUTPUT,
                        error=ERROR):
    """Sets the RemoteSh command results.

    Args:
      returncode: Exit code to register for TEST_CMD.
      output: Output to register for TEST_CMD.
      error: Error output to register for TEST_CMD.
    """
    self.rsh_mock.AddCmdResult(self.TEST_CMD, returncode=returncode,
                               output=output, error=error)

  def testNormal(self):
    """Test normal functionality."""
    self.SetRemoteShResult()
    result = self.host.RemoteSh(self.TEST_CMD)
    # assertEqual rather than the deprecated assertEquals alias, matching the
    # rest of this file.
    self.assertEqual(result.returncode, self.RETURN_CODE)
    self.assertEqual(result.output.strip(), self.OUTPUT)
    self.assertEqual(result.error.strip(), self.ERROR)

  def testRemoteCmdFailure(self):
    """Test failure in remote cmd."""
    self.SetRemoteShResult(returncode=1)
    # A failing remote command raises unless error_code_ok is set;
    # ssh_error_ok alone does not suppress it.
    self.assertRemoteShRaises()
    self.assertRemoteShRaises(ssh_error_ok=True)
    self.host.RemoteSh(self.TEST_CMD, error_code_ok=True)
    self.host.RemoteSh(self.TEST_CMD, ssh_error_ok=True, error_code_ok=True)

  def testSshFailure(self):
    """Test failure in ssh command."""
    self.SetRemoteShResult(returncode=remote_access.SSH_ERROR_CODE)
    # An SSH-level failure raises unless ssh_error_ok is set;
    # error_code_ok alone does not suppress it.
    self.assertRemoteShRaisesSSHConnectionError()
    self.assertRemoteShRaisesSSHConnectionError(error_code_ok=True)
    self.host.RemoteSh(self.TEST_CMD, ssh_error_ok=True)
    self.host.RemoteSh(self.TEST_CMD, ssh_error_ok=True, error_code_ok=True)

  def testEnvLcMessagesSet(self):
    """Test that LC_MESSAGES is set to 'C' for an SSH command."""
    self.SetRemoteShResult()
    result = self.host.RemoteSh(self.TEST_CMD)
    # Keyword arguments of the most recent RunCommand call.
    rc_kwargs = result.rc_mock.call_args_list[-1][1]
    self.assertEqual(rc_kwargs['extra_env']['LC_MESSAGES'], 'C')

  def testEnvLcMessagesOverride(self):
    """Test that LC_MESSAGES is overridden to 'C' for an SSH command."""
    self.SetRemoteShResult()
    result = self.host.RemoteSh(self.TEST_CMD, extra_env={'LC_MESSAGES': 'fr'})
    # Keyword arguments of the most recent RunCommand call.
    rc_kwargs = result.rc_mock.call_args_list[-1][1]
    self.assertEqual(rc_kwargs['extra_env']['LC_MESSAGES'], 'C')


class CheckIfRebootedTest(RemoteAccessTest):
  """Exercises RemoteAccess._CheckIfRebooted()."""

  def MockCheckReboot(self, returncode):
    """Registers |returncode| for any command containing REBOOT_MARKER."""
    marker_pattern = '.*%s.*' % re.escape(remote_access.REBOOT_MARKER)
    self.rsh_mock.AddCmdResult(partial_mock.Regex(marker_pattern), returncode)

  def testSuccess(self):
    """Return code 0 means the reboot completed."""
    self.MockCheckReboot(0)
    rebooted = self.host._CheckIfRebooted()
    self.assertTrue(rebooted)

  def testRemoteFailure(self):
    """Return code 1 means the reboot is still pending."""
    self.MockCheckReboot(1)
    rebooted = self.host._CheckIfRebooted()
    self.assertFalse(rebooted)

  def testSshFailure(self):
    """The SSH error code (connection down) is reported as not rebooted."""
    self.MockCheckReboot(remote_access.SSH_ERROR_CODE)
    rebooted = self.host._CheckIfRebooted()
    self.assertFalse(rebooted)

  def testInvalidErrorCode(self):
    """An unexpected return code makes _CheckIfRebooted raise."""
    self.MockCheckReboot(2)
    with self.assertRaises(Exception):
      self.host._CheckIfRebooted()


class USBDeviceTestCase(mdns_unittest.mDnsTestCase):
  """Shared fixture for tests that deal with USB-connected devices."""

  def setUp(self):
    """Patches RemoteDevice.Pingable and debug_link.InitializeDebugLink."""
    device_patcher = RemoteDeviceMock()
    self.StartPatcher(device_patcher)
    self.initializedebuglink_mock = self.PatchObject(
        debug_link, 'InitializeDebugLink')


class TestGetUSBConnectedDevices(USBDeviceTestCase):
  """Tests of the GetUSBConnectedDevices() function."""

  def testDebugLinkInitialization(self):
    """Test case to make sure the Debug Link is initialized."""
    self.PatchObject(mdns, 'FindServices')
    remote_access.GetUSBConnectedDevices()
    # Check call_count explicitly: on mock versions that predate
    # assert_called_once(), that name is just an auto-created attribute and
    # calling it passes without verifying anything.
    self.assertEqual(self.initializedebuglink_mock.call_count, 1)

  def testEnumeration(self):
    """Test case to check correct enumeration results."""
    services = [
        mdns.Service('d1.local', '1.1.1.1', 0, 'd1.a.local', {'alias': 'd1'}),
        mdns.Service('d2.local', '2.2.2.2', 0, 'd2.a.local', {'alias': 'd2'})]
    self._MockNetworkResponse(services)
    devices = remote_access.GetUSBConnectedDevices()
    # Lengths are compared first so the pairwise zip() below cannot silently
    # drop unmatched entries.
    self.assertEqual(len(devices), len(services))
    for device, service in zip(devices, services):
      self.assertEqual(device.hostname, service.ip)
      self.assertEqual(device.alias, service.text['alias'])


class TestGetDefaultDevice(USBDeviceTestCase):
  """Exercises _GetDefaultService() and default-device selection."""

  # Two distinct mDNS services used as canned network responses.
  SERVICE_1 = mdns.Service('d1.local', '1.1.1.1', 0, 'd1.a.local',
                           {'alias': 'd1'})
  SERVICE_2 = mdns.Service('d2.local', '2.2.2.2', 0, 'd2.a.local',
                           {'alias': 'd2'})

  def testNoDevices(self):
    """DefaultDeviceError is raised when no service responds."""
    self._MockNetworkResponse([])
    self.assertRaises(remote_access.DefaultDeviceError,
                      remote_access._GetDefaultService)

  def testOneDevice(self):
    """The only responding service is returned as the default."""
    expected = self.SERVICE_1
    self._MockNetworkResponse([expected])
    found = remote_access._GetDefaultService()
    self.assertEqual(expected.ip, found.ip)
    alias_key = remote_access.BRILLO_DEVICE_PROPERTY_ALIAS
    self.assertEqual(expected.text[alias_key], found.text[alias_key])

  def testMultipleDevices(self):
    """DefaultDeviceError is raised when more than one service responds."""
    self._MockNetworkResponse([self.SERVICE_1, self.SERVICE_2])
    self.assertRaises(remote_access.DefaultDeviceError,
                      remote_access._GetDefaultService)

  def testDefaultChromiumOSDevice(self):
    """ChromiumOSDevice(None, ...) picks up the single available service."""
    expected = self.SERVICE_1
    self._MockNetworkResponse([expected])
    device = remote_access.ChromiumOSDevice(None, ping=False, connect=False)
    self.assertEqual(expected.ip, device.hostname)
    expected_alias = expected.text[remote_access.BRILLO_DEVICE_PROPERTY_ALIAS]
    self.assertEqual(expected_alias, device.alias)


class TestUSBDeviceIP(USBDeviceTestCase):
  """Tests of the GetUSBDeviceIP() function."""

  def testEmptyAlias(self):
    """Test empty alias should resolve to None."""
    ip = remote_access.GetUSBDeviceIP('')
    self.assertIsNone(ip)

  def testDebugLinkInitialization(self):
    """Test case to make sure the Debug Link is initialized."""
    self.PatchObject(mdns, 'FindServices')
    remote_access.GetUSBDeviceIP('dut')
    # Check call_count explicitly: on mock versions that predate
    # assert_called_once(), that name is just an auto-created attribute and
    # calling it passes without verifying anything.
    self.assertEqual(self.initializedebuglink_mock.call_count, 1)

  def testSuccessfulResolution(self):
    """Test successful resolution of alias to IP."""
    services = [
        mdns.Service('d1.local', '1.1.1.1', 0, 'd1.a.local', {'alias': 'd1'}),
        mdns.Service('d2.local', '2.2.2.2', 0, 'd2.a.local', {'alias': 'd2'}),
        mdns.Service('d3.local', '3.3.3.3', 0, 'd3.a.local', {'alias': 'd3'})]
    self._MockNetworkResponse(services)
    ip = remote_access.GetUSBDeviceIP('d2')
    self.assertEqual(ip, '2.2.2.2')

  def testDuplicateAlias(self):
    """Test resolution of alias to IP when duplicate aliases exist."""
    services = [
        mdns.Service('d1.local', '1.1.1.1', 0, 'd1.a.local', {'alias': 'd1'}),
        mdns.Service('d2.local', '2.2.2.2', 0, 'd2.a.local', {'alias': 'd2'}),
        mdns.Service('d2.local', '3.3.3.3', 0, 'd2.a.local', {'alias': 'd2'})]
    self._MockNetworkResponse(services)
    ip = remote_access.GetUSBDeviceIP('d2')
    # Make sure the IP belongs to the first response that matches the alias.
    self.assertEqual(ip, '2.2.2.2')

  def testFailedResolution(self):
    """Test failed resolution of alias to IP."""
    services = [
        mdns.Service('d1.local', '1.1.1.1', 0, 'd1.a.local', {'alias': 'd1'}),
        mdns.Service('d2.local', '2.2.2.2', 0, 'd2.a.local', {'alias': 'd2'})]
    self._MockNetworkResponse(services)
    ip = remote_access.GetUSBDeviceIP('d3')
    self.assertIsNone(ip)


class TestChromiumOSDeviceHostnameResolution(USBDeviceTestCase):
  """Checks how ChromiumOSDevice resolves the hostname it is given."""

  def testHostnameAsNetworkName(self):
    """A name that getaddrinfo accepts is kept as the hostname."""
    self.PatchObject(socket, 'getaddrinfo')
    network_name = 'good-hostname'
    device = remote_access.ChromiumOSDevice(network_name, connect=False)
    self.assertEqual(device.hostname, network_name)

  def testHostnameAsAlias(self):
    """A name that fails DNS lookup but matches an alias resolves to its IP."""
    alias = 'good-alias'
    alias_ip = '1.1.1.1'
    self.PatchObject(socket, 'getaddrinfo', side_effect=socket.gaierror)
    self.PatchObject(remote_access, 'GetUSBDeviceIP', return_value=alias_ip)
    device = remote_access.ChromiumOSDevice(alias, connect=False)
    self.assertEqual(device.hostname, alias_ip)
    self.assertEqual(device._alias, alias)

  def testInvalidHostname(self):
    """A name that is neither resolvable nor an alias is left unchanged."""
    unresolvable = 'bad'
    self.PatchObject(socket, 'getaddrinfo', side_effect=socket.gaierror)
    self.PatchObject(remote_access, 'GetUSBDeviceIP', return_value=None)
    device = remote_access.ChromiumOSDevice(unresolvable, connect=False)
    # With no resolution available, the original name must be preserved.
    self.assertEqual(device.hostname, unresolvable)

class TestChromiumOSDeviceSetAlias(USBDeviceTestCase):
  """Checks alias validation in ChromiumOSDevice.SetAlias()."""

  def setUp(self):
    """Creates an unconnected device to set aliases on."""
    self.device = remote_access.ChromiumOSDevice('1.1.1.1', connect=False)

  def testAliasTooLong(self):
    """An alias one character over the maximum length is rejected."""
    over_limit = remote_access.BRILLO_DEVICE_PROPERTY_MAX_LEN + 1
    oversized_alias = 'a' * over_limit
    self.assertRaises(remote_access.InvalidDevicePropertyError,
                      self.device.SetAlias, oversized_alias)

  def testAliasWithBadChars(self):
    """An alias containing a disallowed character is rejected."""
    self.assertRaises(remote_access.InvalidDevicePropertyError,
                      self.device.SetAlias, 'bad*')

  def testGoodAlias(self):
    """A valid alias results in a RunCommand call on the device."""
    run_mock = self.PatchObject(remote_access.ChromiumOSDevice, 'RunCommand')
    self.device.SetAlias('good')
    self.assertTrue(run_mock.called)
