| # -*- coding: utf-8 -*- |
| # Copyright (c) 2012 The Chromium OS Authors. All rights reserved. |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| |
| """Test the commandline module.""" |
| |
| from __future__ import print_function |
| |
| import argparse |
| import pickle |
| import signal |
| import os |
| import sys |
| |
| from chromite.cli import command |
| from chromite.lib import build_target_lib |
| from chromite.lib import commandline |
| from chromite.lib import cros_build_lib |
| from chromite.lib import cros_logging as logging |
| from chromite.lib import cros_test_lib |
| from chromite.lib import gs |
| from chromite.lib import path_util |
| |
| |
| class TestShutDownException(cros_test_lib.TestCase): |
| """Test that ShutDownException can be pickled.""" |
| |
| def testShutDownException(self): |
| """Test that ShutDownException can be pickled.""" |
| # pylint: disable=protected-access |
| ex = commandline._ShutDownException(signal.SIGTERM, 'Received SIGTERM') |
| ex2 = pickle.loads(pickle.dumps(ex)) |
| self.assertEqual(ex.signal, ex2.signal) |
| self.assertEqual(str(ex), str(ex2)) |
| |
| |
| class GSPathTest(cros_test_lib.OutputTestCase): |
| """Test type=gs_path normalization functionality.""" |
| |
| GS_REL_PATH = 'bucket/path/to/artifacts' |
| |
| @staticmethod |
| def _ParseCommandLine(argv): |
| parser = commandline.ArgumentParser() |
| parser.add_argument('-g', '--gs-path', type='gs_path', |
| help='GS path that contains the chrome to deploy.') |
| return parser.parse_args(argv) |
| |
| def _RunGSPathTestCase(self, raw, parsed): |
| options = self._ParseCommandLine(['--gs-path', raw]) |
| self.assertEqual(options.gs_path, parsed) |
| |
| def testNoGSPathCorrectionNeeded(self): |
| """Test case where GS path correction is not needed.""" |
| gs_path = '%s/%s' % (gs.BASE_GS_URL, self.GS_REL_PATH) |
| self._RunGSPathTestCase(gs_path, gs_path) |
| |
| def testTrailingSlashRemoval(self): |
| """Test case where GS path ends with /.""" |
| gs_path = '%s/%s/' % (gs.BASE_GS_URL, self.GS_REL_PATH) |
| self._RunGSPathTestCase(gs_path, gs_path.rstrip('/')) |
| |
| def testDuplicateSlashesRemoved(self): |
| """Test case where GS path contains many / in a row.""" |
| self._RunGSPathTestCase( |
| '%s/a/dir/with//////////slashes' % gs.BASE_GS_URL, |
| '%s/a/dir/with/slashes' % gs.BASE_GS_URL) |
| |
| def testRelativePathsRemoved(self): |
| """Test case where GS path contain /../ logic.""" |
| self._RunGSPathTestCase( |
| '%s/a/dir/up/here/.././../now/down/there' % gs.BASE_GS_URL, |
| '%s/a/dir/now/down/there' % gs.BASE_GS_URL) |
| |
| def testCorrectionNeeded(self): |
| """Test case where GS path correction is needed.""" |
| self._RunGSPathTestCase( |
| '%s/%s/' % (gs.PRIVATE_BASE_HTTPS_URL, self.GS_REL_PATH), |
| '%s/%s' % (gs.BASE_GS_URL, self.GS_REL_PATH)) |
| |
| def testInvalidPath(self): |
| """Path cannot be normalized.""" |
| with self.OutputCapturer(): |
| self.assertRaises2( |
| SystemExit, self._RunGSPathTestCase, 'http://badhost.com/path', '', |
| check_attrs={'code': 2}) |
| |
| |
| class BoolTest(cros_test_lib.TestCase): |
| """Test type='bool' functionality.""" |
| |
| @staticmethod |
| def _ParseCommandLine(argv): |
| parser = commandline.ArgumentParser() |
| parser.add_argument('-e', '--enable', type='bool', |
| help='Boolean Argument.') |
| return parser.parse_args(argv) |
| |
| def _RunBoolTestCase(self, enable, expected): |
| options = self._ParseCommandLine(['--enable', enable]) |
| self.assertEqual(options.enable, expected) |
| |
| def testBoolTrue(self): |
| """Test case setting the value to true.""" |
| self._RunBoolTestCase('True', True) |
| self._RunBoolTestCase('1', True) |
| self._RunBoolTestCase('true', True) |
| self._RunBoolTestCase('yes', True) |
| self._RunBoolTestCase('TrUe', True) |
| |
| def testBoolFalse(self): |
| """Test case setting the value to false.""" |
| self._RunBoolTestCase('False', False) |
| self._RunBoolTestCase('0', False) |
| self._RunBoolTestCase('false', False) |
| self._RunBoolTestCase('no', False) |
| self._RunBoolTestCase('FaLse', False) |
| |
| |
| class BuildTargetTest(cros_test_lib.TestCase): |
| """Test type='build_target' functionality.""" |
| |
| @staticmethod |
| def _ParseCommandLine(argv): |
| parser = commandline.ArgumentParser() |
| parser.add_argument('--build-target', type='build_target', |
| help='Build Target Argument.') |
| return parser.parse_args(argv) |
| |
| def _RunBuildTargetTestCase(self, build_target_name, expected): |
| options = self._ParseCommandLine(['--build-target', build_target_name]) |
| self.assertEqual(options.build_target, expected) |
| |
| def _CheckBuildTargetParseFails(self, build_target_name): |
| """Checks that parsing a device input fails.""" |
| with self.assertRaises(SystemExit): |
| self._ParseCommandLine(['--build-target', build_target_name]) |
| |
| def test_valid_build_target(self): |
| """Test valid build target name.""" |
| self._RunBuildTargetTestCase( |
| 'build-target-name', build_target_lib.BuildTarget('build-target-name')) |
| |
| def test_invalid_build_target(self): |
| """Test invalid build target name.""" |
| self._CheckBuildTargetParseFails('invalid-name-!@#$%^&*()+=') |
| |
| |
| class DeviceParseTest(cros_test_lib.OutputTestCase): |
| """Test device parsing functionality.""" |
| |
| _ALL_SCHEMES = (commandline.DEVICE_SCHEME_FILE, |
| commandline.DEVICE_SCHEME_SERVO, |
| commandline.DEVICE_SCHEME_SSH, |
| commandline.DEVICE_SCHEME_USB) |
| |
| def _CheckDeviceParse(self, device_input, scheme, username=None, |
| hostname=None, port=None, path=None, serial=None): |
| """Checks that parsing a device input gives the expected result. |
| |
| Args: |
| device_input (str): Input specifying a device. |
| scheme (str): Expected scheme. |
| username (str|None): Expected username. |
| hostname (str|None): Expected hostname. |
| port (int|None): Expected port. |
| path (str|None): Expected path. |
| serial (str|None): Expected serial number. |
| """ |
| parser = commandline.ArgumentParser() |
| parser.add_argument('device', type=commandline.DeviceParser(scheme)) |
| device = parser.parse_args([device_input]).device |
| self.assertEqual(device.scheme, scheme) |
| self.assertEqual(device.username, username) |
| self.assertEqual(device.hostname, hostname) |
| self.assertEqual(device.port, port) |
| self.assertEqual(device.path, path) |
| self.assertEqual(device.serial_number, serial) |
| |
| def _CheckDeviceParseFails(self, device_input, schemes=_ALL_SCHEMES): |
| """Checks that parsing a device input fails. |
| |
| Args: |
| device_input: String input specifying a device. |
| schemes: A scheme or list of allowed schemes, by default allows all. |
| """ |
| parser = commandline.ArgumentParser() |
| parser.add_argument('device', type=commandline.DeviceParser(schemes)) |
| with self.OutputCapturer(): |
| self.assertRaises2(SystemExit, parser.parse_args, [device_input]) |
| |
| def testNoDevice(self): |
| """Verify that an empty device specification fails.""" |
| self._CheckDeviceParseFails('') |
| |
| def testSshScheme(self): |
| """Verify that SSH scheme-only device specification fails.""" |
| self._CheckDeviceParseFails('ssh://') |
| |
| def testInvalidSshScheme(self): |
| """Verify that invalid ssh specification fails.""" |
| self._CheckDeviceParseFails('sssssh://localhost:22') |
| |
| def testSshHostname(self): |
| """Test SSH hostname-only device specification.""" |
| self._CheckDeviceParse('192.168.1.200', |
| scheme=commandline.DEVICE_SCHEME_SSH, |
| hostname='192.168.1.200') |
| |
| def testSshHostnamePort(self): |
| """Test SSH hostname and port device specification.""" |
| self._CheckDeviceParse('192.168.1.200:9999', |
| scheme=commandline.DEVICE_SCHEME_SSH, |
| hostname='192.168.1.200', |
| port=9999) |
| |
| def testSshUsernameHostname(self): |
| """Test SSH username and hostname device specification.""" |
| self._CheckDeviceParse('me@foo_host', |
| scheme=commandline.DEVICE_SCHEME_SSH, |
| username='me', |
| hostname='foo_host') |
| |
| def testSshUsernameHostnamePort(self): |
| """Test SSH username, hostname, and port device specification.""" |
| self._CheckDeviceParse('me@foo_host:4500', |
| scheme=commandline.DEVICE_SCHEME_SSH, |
| username='me', |
| hostname='foo_host', |
| port=4500) |
| |
| def testSshSchemeUsernameHostnamePort(self): |
| """Test SSH scheme, username, hostname, and port device specification.""" |
| self._CheckDeviceParse('ssh://me@foo_host:4500', |
| scheme=commandline.DEVICE_SCHEME_SSH, |
| username='me', |
| hostname='foo_host', |
| port=4500) |
| |
| def testEmptyServoScheme(self): |
| """Test empty servo scheme.""" |
| # Everything should be None so the underlying programs (e.g. dut-control) |
| # can use their defaults. |
| self._CheckDeviceParseFails('servo:') |
| |
| def testServoPort(self): |
| """Test valid servo port values.""" |
| self._CheckDeviceParse('servo:port', |
| scheme=commandline.DEVICE_SCHEME_SERVO, |
| port=None) |
| self._CheckDeviceParse('servo:port:1', |
| scheme=commandline.DEVICE_SCHEME_SERVO, |
| port=1) |
| self._CheckDeviceParse('servo:port:12345', |
| scheme=commandline.DEVICE_SCHEME_SERVO, |
| port=12345) |
| self._CheckDeviceParse('servo:port:65535', |
| scheme=commandline.DEVICE_SCHEME_SERVO, |
| port=65535) |
| |
| def testInvalidServoPort(self): |
| """Invalid port provided.""" |
| self._CheckDeviceParseFails('servo:port:0') |
| self._CheckDeviceParseFails('servo:port:65536') |
| # Some serial numbers. |
| self._CheckDeviceParseFails('servo:port:C1234567890') |
| self._CheckDeviceParseFails('servo:port:123456-12345') |
| |
| def testServoSerialNumber(self): |
| """Test servo serial number.""" |
| # Some known serial number formats. |
| self._CheckDeviceParse('servo:serial:C1234567890', |
| scheme=commandline.DEVICE_SCHEME_SERVO, |
| serial='C1234567890') |
| self._CheckDeviceParse('servo:serial:123456-12345', |
| scheme=commandline.DEVICE_SCHEME_SERVO, |
| serial='123456-12345') |
| # Make sure we don't fall back to a port when it looks like one. |
| self._CheckDeviceParse('servo:serial:12345', |
| scheme=commandline.DEVICE_SCHEME_SERVO, |
| serial='12345') |
| |
| def testInvalidServoSerialNumber(self): |
| """Invalid serial number value provided.""" |
| self._CheckDeviceParseFails('servo:serial:') |
| |
| def testUsbScheme(self): |
| """Test USB scheme-only device specification.""" |
| self._CheckDeviceParse('usb://', scheme=commandline.DEVICE_SCHEME_USB) |
| |
| def testUsbSchemePath(self): |
| """Test USB scheme and path device specification.""" |
| self._CheckDeviceParse('usb://path/to/my/device', |
| scheme=commandline.DEVICE_SCHEME_USB, |
| path='path/to/my/device') |
| |
| def testFileScheme(self): |
| """Verify that file scheme-only device specification fails.""" |
| self._CheckDeviceParseFails('file://') |
| |
| def testFileSchemePath(self): |
| """Test file scheme and path device specification.""" |
| self._CheckDeviceParse('file://foo/bar', |
| scheme=commandline.DEVICE_SCHEME_FILE, |
| path='foo/bar') |
| |
| def testAbsolutePath(self): |
| """Verify that an absolute path defaults to file scheme.""" |
| self._CheckDeviceParse('/path/to/my/device', |
| scheme=commandline.DEVICE_SCHEME_FILE, |
| path='/path/to/my/device') |
| |
| def testUnsupportedScheme(self): |
| """Verify that an unsupported scheme fails.""" |
| self._CheckDeviceParseFails('ssh://192.168.1.200', |
| schemes=commandline.DEVICE_SCHEME_USB) |
| self._CheckDeviceParseFails('usb://path/to/my/device', |
| schemes=[commandline.DEVICE_SCHEME_SSH, |
| commandline.DEVICE_SCHEME_FILE]) |
| |
| def testUnknownScheme(self): |
| """Verify that an unknown scheme fails.""" |
| self._CheckDeviceParseFails('ftp://192.168.1.200') |
| |
| def testSchemeCaseInsensitive(self): |
| """Verify that schemes are case-insensitive.""" |
| self._CheckDeviceParse('SSH://foo_host', |
| scheme=commandline.DEVICE_SCHEME_SSH, |
| hostname='foo_host') |
| |
| |
| class AppendOptionTest(cros_test_lib.TestCase): |
| """Verify append_option/append_option_value actions.""" |
| |
| def setUp(self): |
| """Create a standard parser for the tests.""" |
| self.parser = commandline.ArgumentParser() |
| self.parser.add_argument('--flag', action='append_option') |
| self.parser.add_argument('--value', action='append_option_value') |
| self.parser.add_argument('-x', '--shared_flag', dest='shared', |
| action='append_option') |
| self.parser.add_argument('-y', '--shared_value', dest='shared', |
| action='append_option_value') |
| |
| def testNone(self): |
| """Test results when no arguments are passed in.""" |
| result = self.parser.parse_args([]) |
| self.assertDictContainsSubset( |
| {'flag': None, 'value': None, 'shared': None}, |
| vars(result), |
| ) |
| |
| def testSingles(self): |
| """Test results when no argument is used more than once.""" |
| result = self.parser.parse_args( |
| ['--flag', '--value', 'foo', '--shared_flag', '--shared_value', 'bar'] |
| ) |
| |
| self.assertDictContainsSubset( |
| { |
| 'flag': ['--flag'], |
| 'value': ['--value', 'foo'], |
| 'shared': ['--shared_flag', '--shared_value', 'bar'], |
| }, |
| vars(result), |
| ) |
| |
| def testMultiples(self): |
| """Test results when no arguments are used more than once.""" |
| result = self.parser.parse_args([ |
| '--flag', '--value', 'v1', |
| '-x', '-y', 's1', |
| '--shared_flag', '--shared_value', 's2', |
| '--flag', '--value', 'v2', |
| ]) |
| |
| self.assertDictContainsSubset( |
| { |
| 'flag': ['--flag', '--flag'], |
| 'value': ['--value', 'v1', '--value', 'v2'], |
| 'shared': ['-x', '-y', 's1', '--shared_flag', |
| '--shared_value', 's2'], |
| }, |
| vars(result), |
| ) |
| |
| |
| class SplitExtendActionTest(cros_test_lib.TestCase): |
| """Verify _SplitExtendAction/split_extend action.""" |
| |
| def _CheckArgs(self, cliargs, expected): |
| """Check |cliargs| produces |expected|.""" |
| parser = commandline.ArgumentParser() |
| parser.add_argument('-x', action='split_extend', default=[]) |
| opts = parser.parse_args( |
| cros_build_lib.iflatten_instance(['-x', x] for x in cliargs)) |
| self.assertEqual(opts.x, expected) |
| |
| def testDefaultNone(self): |
| """Verify default=None works.""" |
| parser = commandline.ArgumentParser() |
| parser.add_argument('-x', action='split_extend', default=None) |
| |
| opts = parser.parse_args([]) |
| self.assertIs(opts.x, None) |
| |
| opts = parser.parse_args(['-x', '']) |
| self.assertEqual(opts.x, []) |
| |
| opts = parser.parse_args(['-x', 'f']) |
| self.assertEqual(opts.x, ['f']) |
| |
| def testNoArgs(self): |
| """This is more of a sanity check for resting state.""" |
| self._CheckArgs([], []) |
| |
| def testEmptyArg(self): |
| """Make sure '' produces nothing.""" |
| self._CheckArgs(['', ''], []) |
| |
| def testEmptyWhitespaceArg(self): |
| """Make sure whitespace produces nothing.""" |
| self._CheckArgs([' ', '\t', ' \t '], []) |
| |
| def testSingleSingleArg(self): |
| """Verify splitting one arg works.""" |
| self._CheckArgs(['a'], ['a']) |
| |
| def testMultipleSingleArg(self): |
| """Verify splitting one arg works.""" |
| self._CheckArgs(['a b c\td '], ['a', 'b', 'c', 'd']) |
| |
| def testMultipleMultipleArgs(self): |
| """Verify splitting multiple args works.""" |
| self._CheckArgs(['a b c', '', 'x', ' k '], ['a', 'b', 'c', 'x', 'k']) |
| |
| |
| class CacheTest(cros_test_lib.MockTempDirTestCase): |
| """Test cache dir default / override functionality.""" |
| |
| CACHE_DIR = '/fake/cache/dir' |
| |
| def setUp(self): |
| self.PatchObject(commandline.ArgumentParser, 'ConfigureCacheDir') |
| dir_struct = [ |
| 'repo/.repo/', |
| ] |
| cros_test_lib.CreateOnDiskHierarchy(self.tempdir, dir_struct) |
| self.repo_root = os.path.join(self.tempdir, 'repo') |
| self.cwd_mock = self.PatchObject(os, 'getcwd') |
| self.parser = commandline.ArgumentParser(caching=True) |
| |
| def _CheckCall(self, cwd_retval, args_to_parse, expected, assert_func): |
| self.cwd_mock.return_value = cwd_retval |
| self.parser.parse_args(args_to_parse) |
| cache_dir_mock = self.parser.ConfigureCacheDir |
| self.assertEqual(1, cache_dir_mock.call_count) |
| assert_func(cache_dir_mock.call_args[0][0], expected) |
| |
| def testRepoRootNoOverride(self): |
| """Test default cache location when in a repo checkout.""" |
| self._CheckCall(self.repo_root, [], self.repo_root, self.assertStartsWith) |
| |
| def testRepoRootWithOverride(self): |
| """User provided cache location overrides repo checkout default.""" |
| self._CheckCall(self.repo_root, ['--cache-dir', self.CACHE_DIR], |
| self.CACHE_DIR, self.assertEqual) |
| |
| |
| class ParseArgsTest(cros_test_lib.TestCase): |
| """Test parse_args behavior of our custom argument parsing classes.""" |
| |
| def _CreateOptionParser(self, cls): |
| """Create a class of optparse.OptionParser with prepared config. |
| |
| Args: |
| cls: Some subclass of optparse.OptionParser. |
| |
| Returns: |
| The created OptionParser object. |
| """ |
| usage = 'usage: some usage' |
| parser = cls(usage=usage) |
| |
| # Add some options. |
| parser.add_option('-x', '--xxx', action='store_true', default=False, |
| help='Gimme an X') |
| parser.add_option('-y', '--yyy', action='store_true', default=False, |
| help='Gimme a Y') |
| parser.add_option('-a', '--aaa', type='string', default='Allan', |
| help='Gimme an A') |
| parser.add_option('-b', '--bbb', type='string', default='Barry', |
| help='Gimme a B') |
| parser.add_option('-c', '--ccc', type='string', default='Connor', |
| help='Gimme a C') |
| |
| return parser |
| |
| def _CreateArgumentParser(self, cls): |
| """Create a class of argparse.ArgumentParser with prepared config. |
| |
| Args: |
| cls: Some subclass of argparse.ArgumentParser. |
| |
| Returns: |
| The created ArgumentParser object. |
| """ |
| usage = 'usage: some usage' |
| parser = cls(usage=usage) |
| |
| # Add some options. |
| parser.add_argument('-x', '--xxx', action='store_true', default=False, |
| help='Gimme an X') |
| parser.add_argument('-y', '--yyy', action='store_true', default=False, |
| help='Gimme a Y') |
| parser.add_argument('-a', '--aaa', type=str, default='Allan', |
| help='Gimme an A') |
| parser.add_argument('-b', '--bbb', type=str, default='Barry', |
| help='Gimme a B') |
| parser.add_argument('-c', '--ccc', type=str, default='Connor', |
| help='Gimme a C') |
| parser.add_argument('args', type=str, nargs='*', help='args') |
| |
| return parser |
| |
| def _TestParser(self, parser): |
| """Test the given parser with a prepared argv.""" |
| argv = ['-x', '--bbb', 'Bobby', '-c', 'Connor', 'foobar'] |
| |
| parsed = parser.parse_args(argv) |
| |
| if isinstance(parser, commandline.FilteringParser): |
| # optparse returns options and args separately. |
| options, args = parsed |
| self.assertEqual(['foobar'], args) |
| else: |
| # argparse returns just options. Options configured above to have the |
| # args stored at option "args". |
| options = parsed |
| self.assertEqual(['foobar'], parsed.args) |
| |
| self.assertTrue(options.xxx) |
| self.assertFalse(options.yyy) |
| |
| self.assertEqual('Allan', options.aaa) |
| self.assertEqual('Bobby', options.bbb) |
| self.assertEqual('Connor', options.ccc) |
| |
| self.assertRaises(AttributeError, getattr, options, 'xyz') |
| |
| # Now try altering option values. |
| options.aaa = 'Arick' |
| self.assertEqual('Arick', options.aaa) |
| |
| # Now freeze the options and try altering again. |
| options.Freeze() |
| self.assertRaises(commandline.attrs_freezer.Error, |
| setattr, options, 'aaa', 'Arnold') |
| self.assertEqual('Arick', options.aaa) |
| |
| def testFilterParser(self): |
| self._TestParser(self._CreateOptionParser(commandline.FilteringParser)) |
| |
| def testArgumentParser(self): |
| self._TestParser(self._CreateArgumentParser(commandline.ArgumentParser)) |
| |
| def testDisableCommonLogging(self): |
| """Verify we can elide common logging options.""" |
| parser = commandline.ArgumentParser(logging=False) |
| |
| # Sanity check it first. |
| opts = parser.parse_args([]) |
| self.assertFalse(hasattr(opts, 'log_level')) |
| |
| # Now add our own logging options. If the options were added, |
| # argparse would throw duplicate flag errors for us. |
| parser.add_argument('--log-level') |
| parser.add_argument('--nocolor') |
| |
| def testCommonBaseDefaults(self): |
| """Make sure common options work with just a base parser.""" |
| parser = commandline.ArgumentParser(logging=True, default_log_level='info') |
| |
| # Make sure the default works. |
| opts = parser.parse_args([]) |
| self.assertEqual(opts.log_level, 'info') |
| self.assertEqual(opts.color, None) |
| |
| # Then we can set up our own values. |
| opts = parser.parse_args(['--nocolor', '--log-level=notice']) |
| self.assertEqual(opts.log_level, 'notice') |
| self.assertEqual(opts.color, False) |
| |
| def testCommonBaseAndSubDefaults(self): |
| """Make sure common options work between base & sub parsers.""" |
| parser = commandline.ArgumentParser(logging=True, default_log_level='info') |
| |
| sub_parsers = parser.add_subparsers(title='Subs') |
| sub_parsers.add_parser('cmd1') |
| sub_parsers.add_parser('cmd2') |
| |
| # Make sure the default works. |
| opts = parser.parse_args(['cmd1']) |
| self.assertEqual(opts.log_level, 'info') |
| self.assertEqual(opts.color, None) |
| |
| # Make sure options passed to base parser work. |
| opts = parser.parse_args(['--nocolor', '--log-level=notice', 'cmd2']) |
| self.assertEqual(opts.log_level, 'notice') |
| self.assertEqual(opts.color, False) |
| |
| # Make sure options passed to sub parser work. |
| opts = parser.parse_args(['cmd2', '--nocolor', '--log-level=notice']) |
| self.assertEqual(opts.log_level, 'notice') |
| self.assertEqual(opts.color, False) |
| |
| |
| class ScriptWrapperMainTest(cros_test_lib.MockTestCase): |
| """Test the behavior of the ScriptWrapperMain function.""" |
| |
| def setUp(self): |
| self.PatchObject(sys, 'exit') |
| self.lastTargetFound = None |
| |
| SYS_ARGV = ['/cmd', '/cmd', 'arg1', 'arg2'] |
| CMD_ARGS = ['/cmd', 'arg1', 'arg2'] |
| # The exact flags here don't matter as we don't invoke the underlying script. |
| # Lets pick something specifically invalid just in case we do. |
| CHROOT_ARGS = ['--some-option', 'foo'] |
| |
| def testRestartInChrootPreserveArgs(self): |
| """Verify args to ScriptWrapperMain are passed through to chroot..""" |
| # Setup Mocks/Fakes |
| rc = self.StartPatcher(cros_test_lib.RunCommandMock()) |
| rc.SetDefaultCmdResult() |
| |
| def findTarget(target): |
| """ScriptWrapperMain needs a function to find a function to run.""" |
| def raiseChrootRequiredError(args): |
| raise commandline.ChrootRequiredError(args) |
| |
| self.lastTargetFound = target |
| return raiseChrootRequiredError |
| |
| # Run Test |
| commandline.ScriptWrapperMain(findTarget, self.SYS_ARGV) |
| |
| # Verify Results |
| rc.assertCommandContains(enter_chroot=True) |
| rc.assertCommandContains(self.CMD_ARGS) |
| self.assertEqual('/cmd', self.lastTargetFound) |
| |
| def testRestartInChrootWithChrootArgs(self): |
| """Verify args and chroot args from exception are used.""" |
| # Setup Mocks/Fakes |
| rc = self.StartPatcher(cros_test_lib.RunCommandMock()) |
| rc.SetDefaultCmdResult() |
| |
| def findTarget(_): |
| """ScriptWrapperMain needs a function to find a function to run.""" |
| def raiseChrootRequiredError(_args): |
| raise commandline.ChrootRequiredError(self.CMD_ARGS, self.CHROOT_ARGS) |
| |
| return raiseChrootRequiredError |
| |
| # Run Test |
| commandline.ScriptWrapperMain(findTarget, ['unrelated']) |
| |
| # Verify Results |
| rc.assertCommandContains(enter_chroot=True) |
| rc.assertCommandContains(self.CMD_ARGS) |
| rc.assertCommandContains(chroot_args=self.CHROOT_ARGS) |
| |
| |
| class TestRunInsideChroot(cros_test_lib.MockTestCase): |
| """Test commandline.RunInsideChroot().""" |
| |
| def setUp(self): |
| self.orig_argv = sys.argv |
| sys.argv = ['/cmd', 'arg1', 'arg2'] |
| |
| self.mockFromHostToChrootPath = self.PatchObject( |
| path_util, 'ToChrootPath', return_value='/inside/cmd') |
| |
| # Return values for these two should be set by each test. |
| self.mock_inside_chroot = self.PatchObject(cros_build_lib, 'IsInsideChroot') |
| |
| # Mocked CliCommand object to pass to RunInsideChroot. |
| self.cmd = command.CliCommand(argparse.Namespace()) |
| self.cmd.options.log_level = 'info' |
| |
| def teardown(self): |
| sys.argv = self.orig_argv |
| |
| def _VerifyRunInsideChroot(self, expected_cmd, expected_chroot_args=None, |
| log_level_args=None, **kwargs): |
| """Run RunInsideChroot, and verify it raises with expected values. |
| |
| Args: |
| expected_cmd: Command that should be executed inside the chroot. |
| expected_chroot_args: Args that should be passed as chroot args. |
| log_level_args: Args that set the log level of cros_sdk. |
| kwargs: Additional args to pass to RunInsideChroot(). |
| """ |
| with self.assertRaises(commandline.ChrootRequiredError) as cm: |
| commandline.RunInsideChroot(self.cmd, **kwargs) |
| |
| if log_level_args is None: |
| if self.cmd is not None: |
| log_level_args = ['--log-level', self.cmd.options.log_level] |
| else: |
| log_level_args = [] |
| |
| if expected_chroot_args is not None: |
| log_level_args.extend(expected_chroot_args) |
| expected_chroot_args = log_level_args |
| else: |
| expected_chroot_args = log_level_args |
| |
| self.assertEqual(expected_cmd, cm.exception.cmd) |
| self.assertEqual(expected_chroot_args, cm.exception.chroot_args) |
| |
| def testRunInsideChroot(self): |
| """Test we can restart inside the chroot.""" |
| self.mock_inside_chroot.return_value = False |
| self._VerifyRunInsideChroot(['/inside/cmd', 'arg1', 'arg2']) |
| |
| def testRunInsideChrootWithoutCommand(self): |
| """Test that RunInsideChroot can get by without the |command| parameter.""" |
| self.mock_inside_chroot.return_value = False |
| self.cmd = None |
| self._VerifyRunInsideChroot(['/inside/cmd', 'arg1', 'arg2']) |
| |
| def testRunInsideChrootLogLevel(self): |
| """Test chroot restart with properly inherited log-level.""" |
| self.cmd.options.log_level = 'notice' |
| self.mock_inside_chroot.return_value = False |
| self._VerifyRunInsideChroot(['/inside/cmd', 'arg1', 'arg2'], |
| log_level_args=['--log-level', 'notice']) |
| |
| def testRunInsideChrootAlreadyInside(self): |
| """Test we don't restart inside the chroot if we are already there.""" |
| self.mock_inside_chroot.return_value = True |
| |
| # Since we are in the chroot, it should return, doing nothing. |
| commandline.RunInsideChroot(self.cmd) |
| |
| |
| class DeprecatedActionTest(cros_test_lib.MockTestCase): |
| """Test the _DeprecatedAction integration.""" |
| |
| def setUp(self): |
| self.warning_patch = self.PatchObject(logging, 'warning') |
| |
| # Setup arguments for a handful of actions. |
| self.argument_parser = commandline.ArgumentParser() |
| self.argument_parser.add_argument('--store') |
| self.argument_parser.add_argument('--store-true', action='store_true') |
| self.argument_parser.add_argument('--append', action='append', type=int) |
| self.argument_parser.add_argument('--dep-store', |
| deprecated='Deprecated store') |
| self.argument_parser.add_argument('--dep-store-true', action='store_true', |
| deprecated='Deprecated store true') |
| self.argument_parser.add_argument('--dep-append', action='append', type=int, |
| deprecated='Deprecated append') |
| |
| self.not_deprecated = ['--store', 'a', '--store-true', '--append', '1', |
| '--append', '2'] |
| self.deprecated = ['--dep-store', 'b', '--dep-store-true', |
| '--dep-append', '3', '--dep-append', '4'] |
| self.mixed = self.not_deprecated + self.deprecated |
| |
| self.store_expected = 'a' |
| self.append_expected = [1, 2] |
| self.dep_store_expected = 'b' |
| self.dep_append_expected = [3, 4] |
| |
| def testNonDeprecatedParsing(self): |
| """Test normal parsing is not affected.""" |
| opts = self.argument_parser.parse_args(self.not_deprecated) |
| |
| self.assertFalse(self.warning_patch.called) |
| |
| self.assertEqual(self.store_expected, opts.store) |
| self.assertTrue(opts.store_true) |
| self.assertEqual(self.append_expected, opts.append) |
| |
| self.assertIsNone(opts.dep_store) |
| self.assertFalse(opts.dep_store_true) |
| self.assertIsNone(opts.dep_append) |
| |
| def testDeprecatedParsing(self): |
| """Test deprecated parsing logs the warning but parses normally.""" |
| opts = self.argument_parser.parse_args(self.deprecated) |
| |
| self.assertTrue(self.warning_patch.called) |
| |
| self.assertIsNone(opts.store) |
| self.assertFalse(opts.store_true) |
| self.assertIsNone(opts.append) |
| |
| self.assertEqual(self.dep_store_expected, opts.dep_store) |
| self.assertTrue(opts.dep_store_true) |
| self.assertEqual(self.dep_append_expected, opts.dep_append) |
| |
| def testMixedParsing(self): |
| """Test parsing a mix of arguments.""" |
| opts = self.argument_parser.parse_args(self.mixed) |
| |
| self.assertTrue(self.warning_patch.called) |
| |
| self.assertEqual(self.store_expected, opts.store) |
| self.assertTrue(opts.store_true) |
| self.assertEqual(self.append_expected, opts.append) |
| |
| self.assertEqual(self.dep_store_expected, opts.dep_store) |
| self.assertTrue(opts.dep_store_true) |
| self.assertEqual(self.dep_append_expected, opts.dep_append) |