Accept an optional testing ssh private key

BUG=brillo:1200
TEST=unit tests

Change-Id: Ifbba4eb334ce919a96b5ee5dfce1eefba3edcc03
Reviewed-on: https://chromium-review.googlesource.com/286934
Reviewed-by: Prathmesh Prabhu <pprabhu@chromium.org>
Commit-Queue: Daniel Wang <wonderfly@google.com>
Tested-by: Daniel Wang <wonderfly@google.com>
diff --git a/au_test_harness/au_test.py b/au_test_harness/au_test.py
index a04a6e7..a4c5758 100644
--- a/au_test_harness/au_test.py
+++ b/au_test_harness/au_test.py
@@ -35,7 +35,7 @@
       options: options class to be parsed from main class.
     """
     cls.base_image_path = options.base_image
-    cls.private_key = options.private_key
+    cls.payload_signing_key = options.payload_signing_key
     cls.target_image_path = options.target_image
     cls.test_results_root = options.test_results_root
     if options.type == 'vm':
@@ -173,10 +173,10 @@
     self.worker.Initialize(9226)
     signed_target_image_path = self.worker.PrepareBase(self.target_image_path,
                                                        signed_base=True)
-    if self.private_key:
-      self.worker.PerformUpdate(self.target_image_path,
-                                signed_target_image_path,
-                                private_key_path=self.private_key)
+    if self.payload_signing_key:
+      self.worker.PerformUpdate(
+          self.target_image_path, signed_target_image_path,
+          payload_signing_key=self.payload_signing_key)
     else:
       logging.info('No key found to use for signed testing.')
 
diff --git a/au_test_harness/au_worker.py b/au_test_harness/au_worker.py
index d137b70..26b7278 100644
--- a/au_test_harness/au_worker.py
+++ b/au_test_harness/au_worker.py
@@ -39,6 +39,9 @@
     else:
       self.verify_suite = 'suite:%s' % (options.verify_suite_name or 'smoke')
 
+    # An optional ssh private key for testing.
+    self.ssh_private_key = options.ssh_private_key
+
   def CleanUp(self):
     """Called at the end of every test."""
 
@@ -68,7 +71,7 @@
     """
 
   def UpdateImage(self, image_path, src_image_path='', stateful_change='old',
-                  proxy_port=None, private_key_path=None):
+                  proxy_port=None, payload_signing_key=None):
     """Implementation of an actual update.
 
     Subclasses must override this method with the correct update procedure for
@@ -113,7 +116,7 @@
   # --- INTERFACE TO AU_TEST ---
 
   def PerformUpdate(self, image_path, src_image_path='', stateful_change='old',
-                    proxy_port=None, private_key_path=None):
+                    proxy_port=None, payload_signing_key=None):
     """Performs an update using  _UpdateImage and reports any error.
 
     Subclasses should not override this method but override _UpdateImage
@@ -130,11 +133,11 @@
             exception of code needed for ssh.
       proxy_port:  Port to have the client connect to. For use with
         CrosTestProxy.
-      private_key_path:  Path to a private key to use with update payload.
+      payload_signing_key: Path to the private key to use to sign payloads.
     Raises an update_exception.UpdateException if _UpdateImage returns an error.
     """
     if not self.use_delta_updates: src_image_path = ''
-    key_to_use = private_key_path
+    key_to_use = payload_signing_key
 
     self.UpdateImage(image_path, src_image_path, stateful_change, proxy_port,
                      key_to_use)
@@ -176,7 +179,7 @@
     return stateful_change_flag
 
   def AppendUpdateFlags(self, cmd, image_path, src_image_path, proxy_port,
-                        private_key_path, for_vm=False):
+                        payload_signing_key, for_vm=False):
     """Appends common args to an update cmd defined by an array.
 
     Modifies cmd in places by appending appropriate items given args.
@@ -188,7 +191,7 @@
     """
     if proxy_port: cmd.append('--proxy_port=%s' % proxy_port)
     update_id = dev_server_wrapper.GenerateUpdateId(
-        image_path, src_image_path, private_key_path,
+        image_path, src_image_path, payload_signing_key,
         for_vm=for_vm)
     cache_path = self.update_cache.get(update_id)
     if cache_path:
diff --git a/au_test_harness/cros_au_test_harness.py b/au_test_harness/cros_au_test_harness.py
index c7d189d..8dd3ce3 100755
--- a/au_test_harness/cros_au_test_harness.py
+++ b/au_test_harness/cros_au_test_harness.py
@@ -125,10 +125,9 @@
     leftover_args: Args left after parsing.
   """
 
-  _IMAGE_PATH_REQUIREMENT = """
-  For vm and real types, the image should be a local file and for gce, the image
-  path has to be a valid Google Cloud Storage path.
-  """
+  _IMAGE_PATH_REQUIREMENT = ('For vm and real types, the image must be a local '
+                             'file. For gce, the image path has to be a valid '
+                             'Google Cloud Storage URI.')
 
   if leftover_args: parser.error('Found unsupported flags ' + leftover_args)
   if not options.type in ['real', 'vm', 'gce']:
@@ -142,20 +141,28 @@
             os.path.isfile(image))
 
   if not _IsValidImage(options.target_image):
-    parser.error('Testing requires a valid target image. Given %s. %s' %
-                 (options.target_image, _IMAGE_PATH_REQUIREMENT))
+    parser.error('Testing requires a valid target image.\n'
+                 '%s\n'
+                 'Given: type=%s, target_image=%s.' %
+                 (_IMAGE_PATH_REQUIREMENT, options.type, options.target_image))
 
   if not options.base_image:
     logging.info('No base image supplied.  Using target as base image.')
     options.base_image = options.target_image
 
   if not _IsValidImage(options.base_image):
-    parser.error('Testing requires a valid base image. Given: %s. %s' %
-                 (options.base_image, _IMAGE_PATH_REQUIREMENT))
+    parser.error('Testing requires a valid base image.\n'
+                 '%s\n'
+                 'Given: type=%s, base_image=%s.' %
+                 (_IMAGE_PATH_REQUIREMENT, options.type, options.base_image))
 
-  if options.private_key and not os.path.isfile(options.private_key):
+  if (options.payload_signing_key and not
+      os.path.isfile(options.payload_signing_key)):
     parser.error('Testing requires a valid path to the private key.')
 
+  if options.ssh_private_key and not os.path.isfile(options.ssh_private_key):
+    parser.error('Testing requires a valid path to the ssh private key.')
+
   if options.test_results_root:
     if not 'chroot/tmp' in options.test_results_root:
       parser.error('Must specify a test results root inside tmp in a chroot.')
@@ -183,7 +190,7 @@
                     help='Disable graphics for the vm test.')
   parser.add_option('-j', '--jobs', default=test_helper.CalculateDefaultJobs(),
                     type=int, help='Number of simultaneous jobs')
-  parser.add_option('--private_key', default=None,
+  parser.add_option('--payload_signing_key', default=None,
                     help='Path to the private key used to sign payloads with.')
   parser.add_option('-q', '--quick_test', default=False, action='store_true',
                     help='Use a basic test to verify image.')
@@ -211,6 +218,9 @@
                     action='store_true',
                     help='Run multiple test stages in parallel (applies only '
                          'to vm tests). Default: False')
+  parser.add_option('--ssh_private_key', default=None,
+                    help='Path to the private key to use to ssh into the image '
+                    'as the root user.')
   (options, leftover_args) = parser.parse_args()
 
   CheckOptions(parser, options, leftover_args)
diff --git a/au_test_harness/cros_au_test_harness_unittest.py b/au_test_harness/cros_au_test_harness_unittest.py
index 9b06e4b..c16bcd0 100755
--- a/au_test_harness/cros_au_test_harness_unittest.py
+++ b/au_test_harness/cros_au_test_harness_unittest.py
@@ -64,7 +64,7 @@
     self.assertIn(self.INVALID_IMAGE_PATH, cm.exception.result.error)
 
     cmd = [os.path.join(constants.CROSUTILS_DIR, 'bin', 'cros_au_test_harness'),
-           '--type=vm',
+           '--type=gce',
            '--target_image=%s' % gs_path
           ]
     with self.assertRaises(cros_build_lib.RunCommandError) as cm:
diff --git a/au_test_harness/gce_au_worker.py b/au_test_harness/gce_au_worker.py
index 1cc0ccb..74ac082 100644
--- a/au_test_harness/gce_au_worker.py
+++ b/au_test_harness/gce_au_worker.py
@@ -113,14 +113,15 @@
       test = self.verify_suite
 
     self.TestInfo('Running test %s to verify image.' % test)
-    result = cros_build_lib.RunCommand(
-        ['test_that',
-         '--no-quickmerge',
-         '--results_dir=%s' % test_directory,
-         self.instance_ip,
-         test
-        ], error_code_ok=True, enter_chroot=True, redirect_stdout=True,
-        cwd=constants.CROSUTILS_DIR)
+
+    cmd = ['test_that', '--no-quickmerge', '--results_dir=%s' % test_directory,
+           self.instance_ip, test]
+    if self.ssh_private_key is not None:
+      cmd.append('--ssh_private_key=%s' % self.ssh_private_key)
+
+    result = cros_build_lib.RunCommand(cmd, error_code_ok=True,
+                                       enter_chroot=True, redirect_stdout=True,
+                                       cwd=constants.CROSUTILS_DIR)
     ret = self.AssertEnoughTestsPassed(unittest, result.output,
                                        percent_required_to_pass)
     if not ret:
diff --git a/au_test_harness/real_au_worker.py b/au_test_harness/real_au_worker.py
index df5e698..ba7aa04 100644
--- a/au_test_harness/real_au_worker.py
+++ b/au_test_harness/real_au_worker.py
@@ -24,7 +24,7 @@
     return self.PrepareRealBase(image_path, signed_base)
 
   def UpdateImage(self, image_path, src_image_path='', stateful_change='old',
-                  proxy_port=None, private_key_path=None):
+                  proxy_port=None, payload_signing_key=None):
     """Updates a remote image using image_to_live.sh."""
     stateful_change_flag = self.GetStatefulChangeFlag(stateful_change)
     cmd = ['%s/image_to_live.sh' % constants.CROSUTILS_DIR,
@@ -33,7 +33,7 @@
            '--verify',
           ]
     self.AppendUpdateFlags(cmd, image_path, src_image_path, proxy_port,
-                           private_key_path)
+                           payload_signing_key)
     self.RunUpdateCmd(cmd)
 
   def UpdateUsingPayload(self, update_path, stateful_change='old',
@@ -54,14 +54,15 @@
     test_directory, _ = self.GetNextResultsPath('autotest_tests')
     if not test: test = self.verify_suite
 
-    result = cros_build_lib.RunCommand(
-        ['test_that',
-         '--no-quickmerge',
-         '--results_dir=%s' % test_directory,
-         self.remote,
-         test
-        ], error_code_ok=True, enter_chroot=True, redirect_stdout=True,
-        cwd=constants.CROSUTILS_DIR)
+    cmd = ['test_that', '--no-quickmerge', '--results_dir=%s' % test_directory,
+           self.remote, test]
+    if self.ssh_private_key is not None:
+      cmd.append('--ssh_private_key=%s' % self.ssh_private_key)
+
+    result = cros_build_lib.RunCommand(cmd, error_code_ok=True,
+                                       enter_chroot=True, redirect_stdout=True,
+                                       cwd=constants.CROSUTILS_DIR)
+
     return self.AssertEnoughTestsPassed(unittest, result.output,
                                         percent_required_to_pass)
 
diff --git a/au_test_harness/vm_au_worker.py b/au_test_harness/vm_au_worker.py
index c093070..7d2427f 100644
--- a/au_test_harness/vm_au_worker.py
+++ b/au_test_harness/vm_au_worker.py
@@ -85,7 +85,7 @@
       logging.warning('Ignoring errors while copying VM files: %s', e)
 
   def UpdateImage(self, image_path, src_image_path='', stateful_change='old',
-                  proxy_port='', private_key_path=None):
+                  proxy_port='', payload_signing_key=None):
     """Updates VM image with image_path."""
     log_directory, fail_directory = self.GetNextResultsPath('update')
     stateful_change_flag = self.GetStatefulChangeFlag(stateful_change)
@@ -100,7 +100,7 @@
            stateful_change_flag,
           ]
     self.AppendUpdateFlags(cmd, image_path, src_image_path, proxy_port,
-                           private_key_path)
+                           payload_signing_key)
     self.TestInfo(self.GetUpdateMessage(image_path, src_image_path, True,
                                         proxy_port))
     try:
@@ -134,7 +134,7 @@
       raise
 
   def AppendUpdateFlags(self, cmd, image_path, src_image_path, proxy_port,
-                        private_key_path, for_vm=False):
+                        payload_signing_key, for_vm=False):
     """Appends common args to an update cmd defined by an array.
 
     Calls super function with for_vm set to True.
@@ -143,7 +143,7 @@
       See AppendUpdateFlags for description of args.
     """
     super(VMAUWorker, self).AppendUpdateFlags(
-        cmd, image_path, src_image_path, proxy_port, private_key_path,
+        cmd, image_path, src_image_path, proxy_port, payload_signing_key,
         for_vm=True)
 
   def VerifyImage(self, unittest, percent_required_to_pass=100, test=''):
@@ -175,6 +175,9 @@
     if self.graphics_flag: command.append(self.graphics_flag)
     if self.whitelist_chrome_crashes:
       command.append('--whitelist_chrome_crashes')
+    if self.ssh_private_key is not None:
+      logging.warning('Flag "--ssh_private_key" set but not yet supported for '
+                      '"VMAUWorker". Default test key will be used.')
     self.TestInfo('Running smoke suite to verify image.')
     result = cros_build_lib.RunCommand(
         command, print_cmd=False, combine_stdout_stderr=True,
diff --git a/ctest/ctest.py b/ctest/ctest.py
index e7454c4..940077f 100755
--- a/ctest/ctest.py
+++ b/ctest/ctest.py
@@ -36,7 +36,7 @@
     jobs: Numbers of threads to run in parallel.
     no_graphics: boolean: If True, disable graphics during vm test.
     nplus1_archive_dir: Archive directory to store nplus1 payloads.
-    private_key: Signs payloads with this key.
+    payload_signing_key: Signs payloads with this key.
     public_key: Loads key to verify signed payloads.
     remote: ip address for real test harness run.
     sign_payloads: Build some payloads with signed keys.
@@ -68,22 +68,25 @@
 
     self.public_key = None
     if self.sign_payloads:
-      self.private_key = os.path.realpath(
+      self.payload_signing_key = os.path.realpath(
           os.path.join(self.crosutils_root, '..', 'platform', 'update_engine',
                        'unittest_key.pem'))
     else:
-      self.private_key = None
+      self.payload_signing_key = None
 
     self.jobs = options.jobs
     self.nplus1_archive_dir = options.nplus1_archive_dir
 
+    # An optional ssh private key used for testing.
+    self.ssh_private_key = options.ssh_private_key
+
   def GeneratePublicKey(self):
     """Returns the path to a generated public key from the UE private key."""
     # Just output to local directory.
     public_key_path = 'public_key.pem'
     logging.info('Generating public key from private key.')
     cros_build_lib.RunCommand(
-        ['openssl', 'rsa', '-in', self.private_key, '-pubout',
+        ['openssl', 'rsa', '-in', self.payload_signing_key, '-pubout',
          '-out', public_key_path], print_cmd=False)
     self.public_key = public_key_path
 
@@ -136,7 +139,7 @@
       # This only is compatible with payload signing.
       if self.sign_payloads:
         cmd.append('--public_key=%s' % self.public_key)
-        cmd.append('--private_key=%s' % self.private_key)
+        cmd.append('--private_key=%s' % self.payload_signing_key)
     else:
       cmd.append('--basic_suite')
 
@@ -175,6 +178,9 @@
            '--jobs=%d' % self.jobs,
           ]
 
+    if self.ssh_private_key is not None:
+      cmd.append('--ssh_private_key=%s' % self.ssh_private_key)
+
     if suite:
       cmd.append('--verify_suite_name=%s' % suite)
 
@@ -190,7 +196,7 @@
 
     # We did not generate signed payloads if this is a |quick_update| test.
     if not quick_update and self.sign_payloads:
-      cmd.append('--private_key=%s' % self.private_key)
+      cmd.append('--payload_signing_key=%s' % self.payload_signing_key)
 
     res = cros_build_lib.RunCommand(cmd, cwd=self.crosutils_root,
                                     error_code_ok=True)
@@ -237,6 +243,9 @@
   parser.add_option('--whitelist_chrome_crashes', default=False,
                     dest='whitelist_chrome_crashes', action='store_true',
                     help='Treat Chrome crashes as non-fatal.')
+  parser.add_option('--ssh_private_key', default=None,
+                    help='Path to the private key to use to ssh into the image '
+                    'as the root user')
 
   # Set the usage to include flags.
   def _ParserError(msg):