#!/usr/bin/python
# Copyright 2016 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import mock
import unittest

import common
from autotest_lib.client.common_lib import error
from autotest_lib.server.hosts import base_label_unittest, factory


class MockHost(object):
    """Side-effect-free stand-in for a real host class.

    The constructor only records what it was called with, so tests can inspect
    the arguments afterwards through the _init_args attribute.
    """
    def __init__(self, hostname, **args):
        # Keep every keyword argument, plus the hostname under its own key.
        self._init_args = dict(args, hostname=hostname)


    def job_start(self):
        """No-op hook; the factory calls this on the objects it creates."""


class MockConnectivity(object):
    """Side-effect-free stand-in for a real connectivity class.

    Accepts the same constructor shape as MockHost (a hostname plus arbitrary
    keyword arguments) but discards everything it is given.
    """
    def __init__(self, hostname, **args):
        # Nothing is stored; the arguments are accepted and ignored.
        return None


    def close(self):
        """No-op hook; the factory calls this on the objects it creates."""
        return None


def _gen_mock_host(name, check_host=False):
    """Create an identifiable mock host class.

    @param name: suffix of the generated class name; also stored on the class
            as _host_cls_name so tests can tell which class was used.
    @param check_host: value returned by the generated class's check_host().

    @return: new MockHost subclass named 'mock_host_<name>'.
    """
    class_name = 'mock_host_%s' % name
    class_attrs = {
        '_host_cls_name': name,
        # Ignores its arguments and always answers with the canned value.
        'check_host': staticmethod(lambda host, timeout=None: check_host),
    }
    return type(class_name, (MockHost,), class_attrs)


def _gen_mock_conn(name):
    """Create an identifiable mock connectivity class.

    @param name: suffix of the generated class name; also stored on the class
            as _conn_cls_name so tests can tell which class was used.

    @return: new MockConnectivity subclass named 'mock_conn_<name>'.
    """
    class_name = 'mock_conn_%s' % name
    class_attrs = {'_conn_cls_name': name}
    return type(class_name, (MockConnectivity,), class_attrs)


def _gen_machine_dict(hostname='localhost', labels=None, attributes=None):
    """Generate a machine dictionary with the specified parameters.

    @param hostname: hostname of machine
    @param labels: list of host labels; a new empty list is used when omitted
    @param attributes: dict of host attributes; a new empty dict is used when
            omitted

    @return: machine dict with a mocked AFE Host object and a mock sentinel
            standing in for the host info store.
    """
    # Build fresh containers on every call. Mutable default arguments would be
    # shared by every machine dict created without explicit values, so a
    # change made through one mocked AFE host could leak into later tests.
    if labels is None:
        labels = []
    if attributes is None:
        attributes = {}
    afe_host = base_label_unittest.MockAFEHost(labels, attributes)
    return {'hostname': hostname,
            'afe_host': afe_host,
            'host_info_store': mock.sentinel.dummy}


class CreateHostUnittests(unittest.TestCase):
    """Tests for create_host function."""

    def setUp(self):
        """Prevent use of real Host and connectivity objects due to potential
        side effects.

        Saves the factory attributes that the tests overwrite so tearDown can
        put them back, then replaces factory.host_types and
        factory.OS_HOST_DICT with empty containers and swaps the CrosHost,
        LocalHost and SSHHost classes for identifiable mocks.
        """
        self._orig_ssh_engine = factory.SSH_ENGINE
        self._orig_types = factory.host_types
        self._orig_dict = factory.OS_HOST_DICT
        self._orig_cros_host = factory.cros_host.CrosHost
        self._orig_local_host = factory.local_host.LocalHost
        self._orig_ssh_host = factory.ssh_host.SSHHost

        # Keep references to the replacement containers so individual tests
        # can register mock host classes in them.
        self.host_types = factory.host_types = []
        self.os_host_dict = factory.OS_HOST_DICT = {}
        factory.cros_host.CrosHost = _gen_mock_host('cros_host')
        factory.local_host.LocalHost = _gen_mock_conn('local')
        factory.ssh_host.SSHHost = _gen_mock_conn('ssh')


    def tearDown(self):
        """Clean up mocks by restoring every attribute saved in setUp."""
        factory.SSH_ENGINE = self._orig_ssh_engine
        factory.host_types = self._orig_types
        factory.OS_HOST_DICT = self._orig_dict
        factory.cros_host.CrosHost = self._orig_cros_host
        factory.local_host.LocalHost = self._orig_local_host
        factory.ssh_host.SSHHost = self._orig_ssh_host


    def test_use_specified(self):
        """Confirm that the specified host and connectivity classes are used."""
        machine = _gen_machine_dict()
        host_obj = factory.create_host(
                machine,
                _gen_mock_host('specified'),
                _gen_mock_conn('specified')
        )
        self.assertEqual(host_obj._host_cls_name, 'specified')
        self.assertEqual(host_obj._conn_cls_name, 'specified')


    def test_detect_host_by_os_label(self):
        """Confirm that the host object is selected by the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'])
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'foo')


    def test_detect_host_by_os_type_attribute(self):
        """Confirm that the host object is selected by the os_type attribute
        and that the os_type attribute is preferred over the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'],
                                    attributes={'os_type': 'bar'})
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        self.os_host_dict['bar'] = _gen_mock_host('bar')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'bar')


    def test_detect_host_by_check_host(self):
        """Confirm check_host logic chooses a host object when label/attribute
        detection fails.
        """
        machine = _gen_machine_dict()
        # Only the middle candidate reports a positive check_host result.
        self.host_types.append(_gen_mock_host('first', check_host=False))
        self.host_types.append(_gen_mock_host('second', check_host=True))
        self.host_types.append(_gen_mock_host('third', check_host=False))
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'second')


    def test_detect_host_fallback_to_cros_host(self):
        """Confirm fallback to CrosHost when all other detection fails.
        """
        machine = _gen_machine_dict()
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'cros_host')


    def test_choose_connectivity_local(self):
        """Confirm local connectivity class used when hostname is localhost.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'local')


    def test_choose_connectivity_ssh(self):
        """Confirm ssh connectivity class used when configured and hostname
        is not localhost.
        """
        # tearDown restores the original SSH_ENGINE value.
        factory.SSH_ENGINE = 'raw_ssh'
        machine = _gen_machine_dict(hostname='somehost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'ssh')


    def test_choose_connectivity_unsupported(self):
        """Confirm exception when configured for unsupported ssh engine.
        """
        # tearDown restores the original SSH_ENGINE value.
        factory.SSH_ENGINE = 'unsupported'
        machine = _gen_machine_dict(hostname='somehost')
        with self.assertRaises(error.AutoservError):
            factory.create_host(machine)


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the host object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine, foo='bar')
        self.assertEqual(host_obj._init_args['hostname'], 'localhost')
        # assertIn reports the missing key and the container on failure.
        self.assertIn('afe_host', host_obj._init_args)
        self.assertIn('host_info_store', host_obj._init_args)
        self.assertEqual(host_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.

        The values are installed as attributes of the factory module for the
        duration of the test only.
        """
        machine = _gen_machine_dict()
        ssh_globals = {
            'ssh_user': 'foo',
            'ssh_pass': 'bar',
            'ssh_port': 1,
            'ssh_verbosity_flag': 'baz',
            'ssh_options': 'zip',
        }
        # create=True allows patching attributes the module does not already
        # define. Leaving the block puts back any previous value, or removes
        # the attribute again, even when an assertion fails.
        with mock.patch.multiple(factory, create=True, **ssh_globals):
            host_obj = factory.create_host(machine)
            self.assertEqual(host_obj._init_args['user'], 'foo')
            self.assertEqual(host_obj._init_args['password'], 'bar')
            self.assertEqual(host_obj._init_args['port'], 1)
            self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'baz')
            self.assertEqual(host_obj._init_args['ssh_options'], 'zip')


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._init_args['user'], 'somebody')
        self.assertEqual(host_obj._init_args['port'], 100)
        self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(host_obj._init_args['ssh_options'], 'options')


class CreateTestbedUnittests(unittest.TestCase):
    """Tests for create_testbed function."""

    def setUp(self):
        """Mock out TestBed class to eliminate side effects.

        The original class is saved so tearDown can restore it.
        """
        self._orig_testbed = factory.testbed.TestBed
        factory.testbed.TestBed = _gen_mock_host('testbed')


    def tearDown(self):
        """Clean up mock by restoring the TestBed class saved in setUp.
        """
        factory.testbed.TestBed = self._orig_testbed


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the testbed object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        testbed_obj = factory.create_testbed(machine, foo='bar')
        self.assertEqual(testbed_obj._init_args['hostname'], 'localhost')
        # assertIn reports the missing key and the container on failure.
        self.assertIn('afe_host', testbed_obj._init_args)
        self.assertIn('host_info_store', testbed_obj._init_args)
        self.assertEqual(testbed_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.

        The values are installed as attributes of the factory module for the
        duration of the test only.
        """
        machine = _gen_machine_dict()
        ssh_globals = {
            'ssh_user': 'foo',
            'ssh_pass': 'bar',
            'ssh_port': 1,
            'ssh_verbosity_flag': 'baz',
            'ssh_options': 'zip',
        }
        # create=True allows patching attributes the module does not already
        # define. Leaving the block puts back any previous value, or removes
        # the attribute again, even when an assertion fails.
        with mock.patch.multiple(factory, create=True, **ssh_globals):
            testbed_obj = factory.create_testbed(machine)
            self.assertEqual(testbed_obj._init_args['user'], 'foo')
            self.assertEqual(testbed_obj._init_args['password'], 'bar')
            self.assertEqual(testbed_obj._init_args['port'], 1)
            self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'],
                             'baz')
            self.assertEqual(testbed_obj._init_args['ssh_options'], 'zip')


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        testbed_obj = factory.create_testbed(machine)
        self.assertEqual(testbed_obj._init_args['user'], 'somebody')
        self.assertEqual(testbed_obj._init_args['port'], 100)
        self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(testbed_obj._init_args['ssh_options'], 'options')


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()

