gce_au_worker: Allow custom tests and flags

One major advantage of running tests on GCE is the potential to run tests with
different GCE instance property permutations, for example, some tests may
require an instance to start with a certain metadata, or a benchmark may desire
to run tests with different types of machines, e.g., n1-standard-1 or
n1-standard-8. This CL enables this ability by allowing board overlays to
declare their requirement of instance properties in a JSON file, and to group
tests based on their environmental request.

Board overlays will have a JSON file, i.e., <overlay>/scripts/gce_tests.json,
that define the list of tests that they would like to run to verify the built
image, in the format of,

{
  tests: [
    {
      name: test_or_suite_name,
      flags: {}
    },

  ]
}

, where each test object can be any valid Autotest test case, or a suite of
tests, e.g., "suite:smoke". Each test object will have two members - name
and flags. flags is a JSON object, containing key-value pairs that follow
the exact schemas of GCE Instance Resource properties as listed here:
https://cloud.google.com/compute/docs/reference/latest/instances#resource. Any
mis-formed key-value pair will result in an instance creation failure.

On the absence of such a JSON file, the smoke suite will be run on an instance
created with default flags.

As all tests share a GCP project thats managed by cros-infras team, and
allowing arbitrary instance properties may introduce undesired surprises, we
only whitelist trusted boards to use this feature.

This CL also updates lib/gce.py with a thread-safe service client, which adds to
memory footprint, so the original non-thread-safe one is kept as default.

BUG=brillo:1285
TEST=au_test_harness/gce_au_worker_unittest.py and trybot run against
lakitu-pre-cq

Change-Id: I1b412a28df9bfe4558fc53b85d6ccf991c29d638
Reviewed-on: https://chromium-review.googlesource.com/308536
Commit-Ready: Daniel Wang <wonderfly@google.com>
Tested-by: Daniel Wang <wonderfly@google.com>
Reviewed-by: Aditya Kali <adityakali@google.com>
Reviewed-by: Daniel Wang <wonderfly@google.com>
diff --git a/au_test_harness/au_worker.py b/au_test_harness/au_worker.py
index b2062cb..88ec805 100644
--- a/au_test_harness/au_worker.py
+++ b/au_test_harness/au_worker.py
@@ -231,7 +231,7 @@
     Returns:
       percent that passed.
     """
-    percent_passed = self._ParseGenerateTestReportOutput(output)
+    percent_passed = self.ParseGeneratedTestOutput(output)
     self.TestInfo('Percent passed: %d vs. Percent required: %d' % (
         percent_passed, percent_required_to_pass))
     if percent_passed < percent_required_to_pass:
@@ -290,9 +290,7 @@
 
     return results_dir, fail_dir
 
-  # --- PRIVATE HELPER FUNCTIONS ---
-
-  def _ParseGenerateTestReportOutput(self, output):
+  def ParseGeneratedTestOutput(self, output):
     """Returns the percentage of tests that passed based on output.
 
     Args:
diff --git a/au_test_harness/gce_au_worker.py b/au_test_harness/gce_au_worker.py
index ed9c82d..f2441cd 100644
--- a/au_test_harness/gce_au_worker.py
+++ b/au_test_harness/gce_au_worker.py
@@ -2,21 +2,81 @@
 # Use of this source code is governed by a BSD-style license that can be
 # found in the LICENSE file.
 
-"""Module containing class that implements an au_worker for GCE instances."""
+"""Module containing a class that implements an au_worker for GCE instances.
+
+By default GCEAUWorker creates a GCE instance with 'Default Instance Properties'
+(detailed below), and runs the gce-smoke suite to verify an image. However it
+allows customized test/suite list and instance properties, through an overlay
+specific JSON file.
+
+Default Instance Properties:
+  project: constants.GCE_PROJECT
+  zone: constants.GCE_DEFAULT_ZONE
+  machine_type: n1-standard-8
+  network: constants.GCE_DEFAULT_NETWORK
+  other properties: GCE default.
+    https://cloud.google.com/compute/docs/reference/latest/instances/insert
+
+To run tests/suites other than the gce-smoke suite, and to specify the instance
+properties, add gce_tests.json under <overlay>/scripts. Refer to _LoadTests for
+the exact requirement of this file, but here is a short example:
+  {
+    "tests": [
+      {
+        "name": "suite:suite1",
+        "flags": {
+          "metadata": {
+            "items": [
+              {
+                "key": "key1",
+                "value": "value1"
+              }
+            ]
+          }
+        }
+      },
+      {
+        "name": "foo_Test",
+        "flags": {}
+      }
+    ]
+  }
+
+"flags" must strictly follow the schema of the Instance Resource
+(https://cloud.google.com/compute/docs/reference/latest/instances#resource).
+
+GCEAUWorker respects most of the properties except instance name, boot_disk,
+network and zone. The enforced values of these special properties are:
+  instance_name: managed name
+  boot_disk: a disk with the image being verified
+  network: the network that has required firewall set up
+  zone: project selected default zone
+
+Some of the properties of the Instance Resource are set by the GCE
+backend so trying to set them at the client may result in noops or GCE errors,
+which will be wrapped into an UpdateException.
+
+Note that some properties like 'disks' that depend on the existence of other
+resources are not supported yet.
+"""
 
 from __future__ import print_function
 
 import datetime
+import json
 import os
 import shutil
 import time
 
+from functools import partial
 from multiprocessing import Process
 
 from chromite.lib import cros_build_lib
 from chromite.lib import cros_logging as logging
 from chromite.lib import gs
+from chromite.lib import parallel
 from chromite.lib import path_util
+from chromite.lib import portage_util
 from crostestutils.au_test_harness import au_worker
 from crostestutils.au_test_harness import constants
 from crostestutils.au_test_harness import update_exception
@@ -30,41 +90,48 @@
     gce_context: An utility for GCE operations.
     gscontext: An utility for GCS operations.
     gcs_bucket: The GCS bucket to upload image tarballs to.
-    instance: A single VM instance associated with a worker.
-    image: A single GCE image associated with a worker.
     tarball_local: Local path to the tarball of test image.
     tarball_remote: GCS path to the tarball of test image.
+    image: A single GCE image associated with a worker.
+    image_link: The URL to the image created.
+    instances: GCE VM instances associated with a worker.
     bg_delete_processes:
       Background processes that delete stale instances and images.
   """
 
-  INSTANCE_PREFIX = 'test-instance-'
-  IMAGE_PREFIX = 'test-image-'
   GS_PATH_COMMON_PREFIX = 'gs://'
   GS_URL_COMMON_PREFIX = 'https://storage.googleapis.com/'
+  IMAGE_PREFIX = 'test-image-'
+  INSTANCE_PREFIX = 'test-instance-'
 
   def __init__(self, options, test_results_root,
                project=constants.GCE_PROJECT,
                zone=constants.GCE_DEFAULT_ZONE,
                network=constants.GCE_DEFAULT_NETWORK,
+               machine_type=constants.GCE_DEFAULT_MACHINE_TYPE,
                json_key_file=constants.GCE_JSON_KEY,
                gcs_bucket=constants.GCS_BUCKET):
     """Processes GCE-specific options."""
     super(GCEAUWorker, self).__init__(options, test_results_root)
-    self.gce_context = gce.GceContext.ForServiceAccount(
-        project, zone, network, json_key_file=json_key_file)
+    self.gce_context = gce.GceContext.ForServiceAccountThreadSafe(
+        project, zone, network, machine_type, json_key_file=json_key_file)
     self.gscontext = gs.GSContext()
     self.gcs_bucket = gcs_bucket
     self.tarball_local = None
     self.tarball_remote = None
-    self.instance = None
     self.image = None
+    self.image_link = None
+    # One instance per test.
+    self.instances = {}
 
     # Background processes that delete throw-away instances.
     self.bg_delete_processes = []
 
+    # Load test specifications from <overlay>/scripts/gce_tests.json, if any.
+    self._LoadTests()
+
   def CleanUp(self):
-    """Deletes throw-away instances and images"""
+    """Deletes throw-away instances and images."""
     logging.info('Waiting for all instances and images to be deleted.')
 
     def _WaitForBackgroundDeleteProcesses():
@@ -74,46 +141,211 @@
 
     _WaitForBackgroundDeleteProcesses()
     # Delete the instance/image created by the last call to UpdateImage.
-    self._DeleteInstanceIfExists()
+    self._DeleteInstancesIfExist()
     _WaitForBackgroundDeleteProcesses()
     logging.info('All instances/images are deleted.')
 
-  def _DeleteInstanceIfExists(self):
-    """Deletes existing instances if any."""
-    def _DeleteInstanceAndImage():
-      self.gscontext.DoCommand(['rm', self.tarball_remote])
-      self.gce_context.DeleteInstance(self.instance)
-      self.gce_context.DeleteImage(self.image)
-
-    if self.instance:
-      logging.info('Existing instance %s found. Deleting...', self.instance)
-      bg_delete = Process(target=_DeleteInstanceAndImage)
-      bg_delete.start()
-      self.bg_delete_processes.append(bg_delete)
-
   def PrepareBase(self, image_path, signed_base=False):
     """Auto-update to base image to prepare for test."""
     return self.PrepareRealBase(image_path, signed_base)
 
   def UpdateImage(self, image_path, src_image_path='', stateful_change='old',
                   proxy_port=None, payload_signing_key=None):
-    """Updates the image on a GCE instance.
+    """Updates the image on all GCE instances.
 
-    Unlike real_au_worker, this method always creates a new instance.
+    There may be multiple instances created with different gcloud flags that
+    will be used by different tests or suites.
     """
+    self._CreateImage(image_path)
+    self._CreateInstances()
+
+  def VerifyImage(self, unittest, percent_required_to_pass=100, test=''):
+    """Verifies the image by running all the required tests.
+
+    Run the test targets as specified in <overlay>/scripts/gce_gce_tests.json or
+    the default 'gce-smoke' suite if none. Multiple test targets are run in
+    parallel. Test results are joined and printed after all tests finish. Note
+    that a dedicated instance has been created for each test target.
+
+    Args:
+      unittest: (unittest.TestCase) The test case to report results back to.
+      percent_required_to_pass: (int) The required minimum pass rate. Not used.
+      test: (str) The specific test to run. Not used.
+
+    Returns:
+      True if all tests pass, or False otherwise.
+    """
+    log_directory_base, fail_directory_base = self.GetNextResultsPath(
+        'autotest_tests')
+
+    steps = []
+    for test in self.tests:
+      remote = self.gce_context.GetInstanceIP(self.instances[test['name']])
+      # Prefer partial to lambda because of Python's late binding.
+      steps.append(partial(self._RunTest, test['name'], remote,
+                           log_directory_base, fail_directory_base))
+    return_values = parallel.RunParallelSteps(steps, return_values=True)
+
+    passed = True
+    outputs = {}
+    for test, percent_passed, output in return_values:
+      passed &= (percent_passed == 100)
+      outputs[test] = output
+
+    if not passed:
+      self._HandleFail(log_directory_base, fail_directory_base)
+      if unittest is not None:
+        unittest.fail('Not all tests passed')
+      for test, output in outputs.iteritems():
+        print ('\nTest: %s\n' % test)
+        print (output)
+    return passed
+
+  # --- PRIVATE HELPER FUNCTIONS ---
+
+  def _RunTest(self, test, remote, log_directory_base, fail_directory_base):
+    """Runs a test or a suite of tests on a given remote.
+
+    Runs a test target, whether an individual test or a suite of tests, with
+    'test_that'.
+
+    Args:
+      test: (str) The test or suite to run.
+      remote: (str) The hostname of the remote DUT.
+      log_directory_base:
+          (str) The base directory to store test logs. A sub directory specific
+          to this test will be created there.
+      fail_directory_base:
+          (str) The base directory to store test logs in case of a test failure.
+
+    Returns:
+      test:
+          (str) Same as |test|. This is useful when the caller wants to
+          correlate results to the test name.
+      percent_passed: (int) Pass rate.
+      output: (str): Original test output.
+    """
+    log_directory, _ = self._GetResultsDirectoryForTest(
+        test, log_directory_base, fail_directory_base)
+    log_directory_in_chroot = log_directory.rpartition('chroot')[2]
+
+    cmd = ['test_that', '-b', self.board, '--no-quickmerge',
+           '--results_dir=%s' % log_directory_in_chroot, remote, test]
+    if self.ssh_private_key is not None:
+      cmd.append('--ssh_private_key=%s' %
+                 path_util.ToChrootPath(self.ssh_private_key))
+
+      result = cros_build_lib.RunCommand(cmd, error_code_ok=True,
+                                         enter_chroot=True,
+                                         redirect_stdout=True,
+                                         cwd=constants.CROSUTILS_DIR)
+      percent_passed = self.ParseGeneratedTestOutput(result.output)
+    return test, percent_passed, result.output
+
+  def _GetResultsDirectoryForTest(self, test, log_directory_base,
+                                  fail_directory_base):
+    """Gets the log and fail directories for a particular test.
+
+    Args:
+      test: (str) The test or suite to get directories for.
+      log_directory_base:
+          (str) The base directory where all test results are saved.
+      fail_directory_base:
+          (str) The base directory where all test failures are recorded.
+    """
+    log_directory = os.path.join(log_directory_base, test)
+    fail_directory = os.path.join(fail_directory_base, test)
+
+    if not os.path.exists(log_directory):
+      os.makedirs(log_directory)
+    return log_directory, fail_directory
+
+  def _LoadTests(self):
+    """Loads the tests to run from <overlay>/scripts/gce_tests.json.
+
+    If the JSON file exists, loads the tests and flags to create instance for
+    each test with. The JSON file should contain a "tests" object, which is an
+    array of objects, each of which has only two keys: "name" and "flags".
+
+    "name" could be any valid Autotest test name, or a suite name, in the form
+    of "suite:<suite_name>", e.g., "suite:gce-smoke".
+
+    "flags" is a JSON object whose members must be valid proterties of the GCE
+    Instance Resource, as specificed at:
+    https://cloud.google.com/compute/docs/reference/latest/instances#resource.
+
+    These flags will be used to create instances. Each flag must strictly follow
+    the property schema as defined in the Instance Resource. Failure to do so
+    will result in instance creation failures.
+
+    Note that a dedicated instance will be created for every test object
+    specified in scripts/gce_tests.json. So group test cases that require
+    similar instance properties together as suites whenever possible.
+
+    An example scripts/gce_tests.json may look like:
+    {
+      "tests": [
+        {
+          "name": "suite:gce-smoke",
+          "flags": []
+        },
+        {
+          "name": "suite:cloud-init",
+          "flags": {
+              "description": "Test instance",
+              "metadata": {
+                "items": [
+                  {
+                    "key": "fake_key",
+                    "value": "fake_value"
+                  }
+                ]
+              }
+          }
+        }
+      ]
+    }
+
+    If the JSON file does not exist, the 'gce-smoke' suite will be used to
+    verify the image.
+    """
+    # Defaults to run the gce-smoke suite if no custom tests are given.
+    tests = [dict(name="suite:gce-smoke", flags=dict())]
+
+    custom_tests = None
+    try:
+      custom_tests = portage_util.ReadOverlayFile(
+          'scripts/gce_tests.json', board=self.board)
+    except portage_util.MissingOverlayException as e:
+      logging.warn('Board overlay not found. Error: %r', e)
+
+    if custom_tests is not None:
+      if self.board not in constants.TRUSTED_BOARDS:
+        logging.warn('Custom tests and flags are not allowed for this board '
+                     '(%s)!', self.board)
+      else:
+        # Read the list of tests.
+        try:
+          json_file = json.loads(custom_tests)
+          tests = json_file.get('tests')
+        except ValueError as e:
+          logging.warn('scripts/gce_tests.json contains invalid JSON content. '
+                       'Default tests will be run and default flags will be '
+                       'used to create instances. Error: %r', e)
+    self.tests = tests
+
+  def _CreateImage(self, image_path):
+    """Uploads the gce tarball and creates an image with it."""
     self.tarball_local = image_path
     log_directory, fail_directory = self.GetNextResultsPath('update')
-    self._DeleteInstanceIfExists()
+    self._DeleteInstancesIfExist()
     ts = datetime.datetime.fromtimestamp(time.time()).strftime(
         '%Y-%m-%d-%H-%M-%S')
     image = '%s%s' % (self.IMAGE_PREFIX, ts)
-    instance = '%s%s' % (self.INSTANCE_PREFIX, ts)
     gs_directory = ('gs://%s/%s' % (self.gcs_bucket, ts))
 
     # Upload the GCE tarball to Google Cloud Storage.
     try:
-      logging.info('Uploading GCE tarball %s to %s ...' , self.tarball_local,
-                   gs_directory)
       self.gscontext.CopyInto(self.tarball_local, gs_directory)
       self.tarball_remote = '%s/%s' % (gs_directory,
                                        os.path.basename(self.tarball_local))
@@ -124,38 +356,43 @@
 
     # Create an image from |image_path| and an instance from the image.
     try:
-      image_link = self.gce_context.CreateImage(
+      self.image_link = self.gce_context.CreateImage(
           image, self._GsPathToUrl(self.tarball_remote))
-      self.gce_context.CreateInstance(instance, image_link)
+      self.image = image
     except gce.Error as e:
       self._HandleFail(log_directory, fail_directory)
-      raise update_exception.UpdateException(1, 'Update failed. Error: %s' % e)
-    self.instance = instance
-    self.image = image
+      raise update_exception.UpdateException(1, 'Update failed. Error: %r' % e)
 
-  def VerifyImage(self, unittest, percent_required_to_pass=100, test=''):
-    """Verifies an image using test_that with verification suite."""
-    log_directory, fail_directory = self.GetNextResultsPath('autotest_tests')
-    log_directory_in_chroot = log_directory.rpartition('chroot')[2]
-    instance_ip = self.gce_context.GetInstanceIP(self.instance)
-    test_suite = test or self.verify_suite
+  def _CreateInstances(self):
+    """Creates instances with custom flags as specificed in |self.tests|."""
+    steps = []
+    for test in self.tests:
+      ts = datetime.datetime.fromtimestamp(time.time()).strftime(
+          '%Y-%m-%d-%H-%M-%S')
+      instance = '%s%s' % (self.INSTANCE_PREFIX, ts)
+      kwargs = test['flags'].copy()
+      kwargs['description'] = 'For test %s' % test['name']
+      steps.append(partial(self.gce_context.CreateInstance, instance,
+                           self.image_link, **kwargs))
+      self.instances[test['name']] = instance
+    parallel.RunParallelSteps(steps)
 
-    cmd = ['test_that', '-b', self.board, '--no-quickmerge',
-           '--results_dir=%s' % log_directory_in_chroot, instance_ip,
-           test_suite]
-    if self.ssh_private_key is not None:
-      cmd.append('--ssh_private_key=%s' %
-                 path_util.ToChrootPath(self.ssh_private_key))
+  def _DeleteInstancesIfExist(self):
+    """Deletes existing instances if any."""
+    def _DeleteInstancesAndImage():
+      steps = [
+          lambda: self.gscontext.DoCommand(['rm', self.tarball_remote]),
+          lambda: self.gce_context.DeleteImage(self.image),
+      ]
+      for instance in self.instances.values():
+        steps.append(partial(self.gce_context.DeleteInstance, instance))
+      parallel.RunParallelSteps(steps)
 
-    result = cros_build_lib.RunCommand(cmd, error_code_ok=True,
-                                       enter_chroot=True, redirect_stdout=True,
-                                       cwd=constants.CROSUTILS_DIR)
-    ret = self.AssertEnoughTestsPassed(unittest, result.output,
-                                       percent_required_to_pass)
-    if not ret:
-      self._HandleFail(log_directory, fail_directory)
-
-    return ret
+    if self.instances:
+      logging.info('Deleting instances...')
+      bg_delete = Process(target=_DeleteInstancesAndImage)
+      bg_delete.start()
+      self.bg_delete_processes.append(bg_delete)
 
   def _HandleFail(self, log_directory, fail_directory):
     """Handles test failures.
@@ -187,7 +424,7 @@
     except shutil.Error as e:
       logging.warning('Ignoring errors while copying GCE tarball: %s', e)
 
-    self._DeleteInstanceIfExists()
+    self._DeleteInstancesIfExist()
 
   def _GsPathToUrl(self, gs_path):
     """Converts a gs:// path to a URL.
diff --git a/au_test_harness/gce_au_worker_unittest.py b/au_test_harness/gce_au_worker_unittest.py
index 76a7280..0ea7531 100755
--- a/au_test_harness/gce_au_worker_unittest.py
+++ b/au_test_harness/gce_au_worker_unittest.py
@@ -8,6 +8,7 @@
 
 from __future__ import print_function
 
+import mock
 import os
 import sys
 import unittest
@@ -19,19 +20,23 @@
 from chromite.lib import cros_build_lib
 from chromite.lib import cros_test_lib
 from chromite.lib import osutils
+from chromite.lib import parallel
 from chromite.lib import path_util
+from chromite.lib import portage_util
+from crostestutils.au_test_harness.au_worker import AUWorker
 from crostestutils.au_test_harness.gce_au_worker import GCEAUWorker
 from crostestutils.lib.gce import GceContext
 
+
 class Options(object):
   """A fake class to hold command line options."""
 
   def __init__(self):
-    self.board = 'fake-board'
+    self.board = 'lakitu'
     self.delta = False
     self.verbose = False
     self.quick_test = False
-    self.verify_suite_name = 'smoke'
+    self.verify_suite_name = 'gce-smoke'
 
 
 class GceAuWorkerTest(cros_test_lib.MockTempDirTestCase):
@@ -48,13 +53,14 @@
     # Fake out environment.
     options = Options()
     options.ssh_private_key = os.path.join(self.tempdir, 'ssh-private-key')
-    self.ssh_private_key = options.ssh_private_key
-    osutils.Touch(self.ssh_private_key)
+    osutils.Touch(options.ssh_private_key)
+    self.options = options
 
     test_results_root = os.path.join(self.tempdir, 'test-results')
     self.test_results_all = os.path.join(test_results_root, 'all')
     self.test_results_failed = os.path.join(test_results_root, 'failed')
     osutils.SafeMakedirs(self.test_results_all)
+    self.test_results_root = test_results_root
 
     self.json_key_file = os.path.join(self.tempdir, 'service_account.json')
     osutils.Touch(self.json_key_file)
@@ -62,90 +68,220 @@
     self.image_path = os.path.join(self.tempdir, self.GCE_TARBALL)
     osutils.Touch(self.image_path)
 
-    self.PatchObject(GceContext, 'ForServiceAccount', autospec=True)
-    self.worker = GCEAUWorker(options, test_results_root, project=self.PROJECT,
-                              zone=self.ZONE, network=self.NETWORK,
-                              gcs_bucket=self.BUCKET,
-                              json_key_file=self.json_key_file)
-
-    # Mock out methods.
-    for cmd in ['CreateInstance', 'CreateImage', 'GetInstanceIP',
-                'DeleteInstance', 'DeleteImage', 'ListInstances', 'ListImages']:
-      self.PatchObject(self.worker.gce_context, cmd, autospec=True)
-
-    for cmd in ['CopyInto', 'DoCommand']:
-      self.PatchObject(self.worker.gscontext, cmd, autospec=True)
-
-    self.PatchObject(self.worker, 'GetNextResultsPath', autospec=True,
+    # Mock out model or class level methods.
+    self.PatchObject(AUWorker, 'GetNextResultsPath', autospec=True,
                      return_value=(self.test_results_all,
                                    self.test_results_failed))
+    self.PatchObject(GceContext, 'ForServiceAccountThreadSafe',
+                     spec=GceContext.ForServiceAccountThreadSafe)
 
-  def testUpdateImage(self):
-    """Tests that UpdateImage creates a GCE VM using the given tarball."""
+  def testUpdateImageWithoutCustomTests(self):
+    """Tests UpdateImage's behavior when no custom tests are specified.
 
-    def _CopyInto(src, _):
-      self.assertEqual(self.image_path, src)
+    This test verifies that when no custom gce_tests.json is found, the
+    gce-smoke suite will be used as verification test and no special flags will
+    be used at instance creation time.
+    """
+    # Fake an empty gce_tests.json.
+    self.PatchObject(portage_util, 'ReadOverlayFile', autospec=True,
+                     return_value=None)
 
-    self.PatchObject(self.worker.gscontext, 'CopyInto', autospec=True,
-                     side_effect=_CopyInto)
-    self.PatchObject(self.worker, '_DeleteInstanceIfExists', autospec=True)
-    self.PatchObject(self.worker, 'GetNextResultsPath', autospec=True,
-                     return_value=('test-resultsi-all', 'test-results-failed'))
-    self.worker.UpdateImage(self.image_path)
+    # Initialize GCEAUWorker. gce_tests.json will be loaded.
+    worker = GCEAUWorker(self.options, self.test_results_root,
+                         project=self.PROJECT, zone=self.ZONE,
+                         network=self.NETWORK, gcs_bucket=self.BUCKET,
+                         json_key_file=self.json_key_file)
 
-    #pylint: disable=protected-access
-    self.worker._DeleteInstanceIfExists.assert_called_once_with()
-    #pylint: enable=protected-access
-    self.assertNotEqual(self.worker.instance, '')
-    self.assertNotEqual(self.worker.image, '')
-    self.assertTrue(self.worker.gscontext.CopyInto.called)
+    # There are no custom tests specified. The gce-smoke suite will be run, and
+    # no special flags will be used at instance creation.
+    self.assertListEqual([dict(name="suite:gce-smoke", flags=dict())],
+                         worker.tests)
+
+    # Call UpdateImage.
+    self.PatchObject(worker.gce_context, 'CreateInstance', autospec=True)
+    self.PatchObject(worker, '_CreateImage', autospec=True)
+    worker.UpdateImage(self.image_path)
+
+    # Verify that only one instance is created and no additional kwargs are
+    # passed to CreateInstance.
+    worker.gce_context.CreateInstance.assert_called_once_with(
+        mock.ANY, mock.ANY, mock.ANY, network=self.NETWORK, zone=self.ZONE)
+
+  def testUpdateImageWithCustomTests(self):
+    """Tests UpdateImage's behavior with custom tests.
+
+    This tests verifies that when a custom gce_tests.json is provided, tests
+    specified in it will be used to verify the image, and instances will be
+    created for each test target as specificed, with specificed GCE flags.
+    """
+    # Fake gce_tests.json.
+    tests_json = """
+    {
+        "tests": [
+            {
+              "name": "suite:suite1",
+              "flags": {
+                  "foo": "bar"
+              }
+            },
+            {
+              "name": "suite:suite2",
+              "flags": {
+                  "bar": "foo"
+              }
+            },
+            {
+              "name": "foo_test",
+              "flags": {}
+            }
+        ]
+    }
+    """
+    self.PatchObject(portage_util, 'ReadOverlayFile', autospec=True,
+                     return_value=tests_json)
+
+    # Initialize GCEAUWorker. It should load gce_tests.json.
+    worker = GCEAUWorker(self.options, self.test_results_root,
+                         project=self.PROJECT, zone=self.ZONE,
+                         network=self.NETWORK, gcs_bucket=self.BUCKET,
+                         json_key_file=self.json_key_file)
+
+    # Assert that tests specificed in gce_tests.json are loaded and will be run
+    # later to verify the image.
+    self.assertSetEqual(
+        set([test['name'] for test in worker.tests]),
+        set(['suite:suite1', 'suite:suite2', 'foo_test'])
+    )
+
+    # UpdateImage is expected to create instances for each test with correct
+    # flags.
+    self.PatchObject(worker.gce_context, 'CreateInstance', autospec=True)
+    self.PatchObject(worker, '_CreateImage', autospec=True)
+    worker.UpdateImage(self.image_path)
+
+    # Assert that instances are created for each test.
+    self.assertSetEqual(
+        set(worker.instances.keys()),
+        set(['suite:suite1', 'suite:suite2', 'foo_test'])
+    )
+
+    # Assert that correct flags are applied.
+    worker.gce_context.CreateInstance.assert_called_with(
+        mock.ANY, mock.ANY, mock.ANY, network=self.NETWORK, zone=self.ZONE,
+        foo='bar')
+    worker.gce_context.CreateInstance.assert_called_with(
+        mock.ANY, mock.ANY, mock.ANY, network=self.NETWORK, zone=self.ZONE,
+        bar='foo')
+    worker.gce_context.CreateInstance.assert_called_with(
+        mock.ANY, mock.ANY, mock.ANY, network=self.NETWORK, zone=self.ZONE)
+
+  def testVerifyImage(self):
+    """Verifies that VerifyImage runs required tests on correct instances."""
+    worker = GCEAUWorker(self.options, self.test_results_root,
+                         project=self.PROJECT, zone=self.ZONE,
+                         network=self.NETWORK, gcs_bucket=self.BUCKET,
+                         json_key_file=self.json_key_file)
+    # Fake tests and instances.
+    worker.tests = [
+        dict(name='suite:suite1', flags=dict(foo='bar')),
+        dict(name='suite:suite2', flags=dict(bar='foo')),
+        dict(name='foo_test', flags=dict()),
+    ]
+    worker.instances = {
+        'suite:suite1': 'instance_1',
+        'suite:suite2': 'instance_2',
+        'foo_test': 'instance_3',
+    }
+
+    expected_tests_run = [
+        dict(remote='1.1.1.1', test='suite:suite1'),
+        dict(remote='2.2.2.2', test='suite:suite2'),
+        dict(remote='3.3.3.3', test='foo_test'),
+    ]
+    actual_tests_run = []
+
+    def _OverrideGetInstanceIP(instance, *unused_args, **unused_kwargs):
+      if instance == 'instance_1':
+        return '1.1.1.1'
+      elif instance == 'instance_2':
+        return '2.2.2.2'
+      else:
+        return '3.3.3.3'
+
+    def _OverrideRunCommand(cmd, *unused_args, **unused_kwargs):
+      remote = cmd[-3]
+      test = cmd[-2]
+      actual_tests_run.append(dict(remote=remote, test=test))
+      return cros_build_lib.CommandResult()
+
+    def _OverrideRunParallelSteps(steps, *unused_args, **unused_kwargs):
+      """Run steps sequentially."""
+      return_values = []
+      for step in steps:
+        ret = step()
+        return_values.append(ret)
+      return return_values
+
+    self.PatchObject(worker.gce_context, 'CreateInstance', autospec=True)
+    self.PatchObject(path_util, 'ToChrootPath', autospec=True,
+                     return_value='x/y/z')
+    self.PatchObject(worker.gce_context, 'GetInstanceIP',
+                     autospec=True,
+                     side_effect=_OverrideGetInstanceIP)
+    self.PatchObject(cros_build_lib, 'RunCommand',
+                     autospec=True,
+                     side_effect=_OverrideRunCommand)
+    self.PatchObject(AUWorker, 'ParseGeneratedTestOutput', autospec=True,
+                     return_value=100)
+    self.PatchObject(parallel, 'RunParallelSteps', autospec=True,
+                     side_effect=_OverrideRunParallelSteps)
+
+    # VerifyImage should run all expected tests.
+    worker.VerifyImage(None)
+
+    # Assert that expected and only expected tests are run.
+    self.assertEqual(len(expected_tests_run), len(actual_tests_run))
+    for test in expected_tests_run:
+      self.assertIn(test, actual_tests_run)
 
   def testCleanUp(self):
     """Tests that CleanUp deletes all instances and doesn't leak processes."""
+    worker = GCEAUWorker(self.options, self.test_results_root,
+                         project=self.PROJECT, zone=self.ZONE,
+                         network=self.NETWORK, gcs_bucket=self.BUCKET,
+                         json_key_file=self.json_key_file)
+    for cmd in ['CopyInto', 'DoCommand']:
+      self.PatchObject(worker.gscontext, cmd, autospec=True)
+
+    self.PatchObject(worker.gce_context, 'DeleteInstance', autospec=True)
+
     for _ in range(3):
-      self.worker.UpdateImage(self.image_path)
-    self.assertEqual(len(self.worker.bg_delete_processes), 2)
+      worker.UpdateImage(self.image_path)
+    self.assertEqual(len(worker.bg_delete_processes), 2)
 
-    self.worker.CleanUp()
-    self.assertEqual(len(self.worker.bg_delete_processes), 0)
-
-  def testVerifyImage(self):
-    """Tests that VerifyImage calls out to test_that with correct args."""
-
-    def _RunCommand(cmd, *args, **kwargs):
-      expected_cmd = ['test_that', '-b', 'fake-board', '--no-quickmerge',
-                      '--results_dir=%s' % self.test_results_all, '1.2.3.4',
-                      'suite:smoke']
-      for i, arg in enumerate(expected_cmd):
-        self.assertEqual(arg, cmd[i])
-
-      return cros_build_lib.CommandResult()
-
-    self.PatchObject(cros_build_lib, 'RunCommand', autospec=True,
-                     side_effect=_RunCommand)
-    self.PatchObject(self.worker, 'AssertEnoughTestsPassed', autospec=True)
-    self.PatchObject(self.worker, '_DeleteInstanceIfExists', autospec=True)
-    self.PatchObject(self.worker.gce_context, 'GetInstanceIP', autospec=True,
-                     return_value='1.2.3.4')
-    self.PatchObject(path_util, 'ToChrootPath', autospec=True,
-                     return_value='x/y/z')
-    self.worker.UpdateImage(self.image_path)
-    self.worker.VerifyImage(None)
-    self.assertTrue(cros_build_lib.RunCommand.called)
+    worker.CleanUp()
+    self.assertEqual(len(worker.bg_delete_processes), 0)
 
   def testHandleFail(self):
     """Tests that _HandleFail copies necessary files for repro."""
+    worker = GCEAUWorker(self.options, self.test_results_root,
+                         project=self.PROJECT, zone=self.ZONE,
+                         network=self.NETWORK, gcs_bucket=self.BUCKET,
+                         json_key_file=self.json_key_file)
+    for cmd in ['CopyInto', 'DoCommand']:
+      self.PatchObject(worker.gscontext, cmd, autospec=True)
     self.PatchObject(cros_build_lib, 'RunCommand', autospec=True)
-    self.PatchObject(self.worker, '_DeleteInstanceIfExists', autospec=True)
+    self.PatchObject(worker, '_DeleteInstancesIfExist', autospec=True)
     self.PatchObject(path_util, 'ToChrootPath', autospec=True,
                      return_value='x/y/z')
-    self.PatchObject(self.worker, 'AssertEnoughTestsPassed', autospec=True,
-                     return_value=False)
-    self.worker.UpdateImage(self.image_path)
-    self.worker.VerifyImage(None)
+    self.PatchObject(worker, '_RunTest', autospec=True,
+                     return_value=(0, None, None))
+    worker.UpdateImage(self.image_path)
+    worker.VerifyImage(None)
     self.assertExists(os.path.join(self.test_results_failed, self.GCE_TARBALL))
-    self.assertExists(os.path.join(self.test_results_failed,
-                                   os.path.basename(self.ssh_private_key)))
+    self.assertExists(os.path.join(
+        self.test_results_failed,
+        os.path.basename(self.options.ssh_private_key)))
 
 
 if __name__ == '__main__':
diff --git a/lib/constants.py b/lib/constants.py
index 66e9ac7..14e3063 100644
--- a/lib/constants.py
+++ b/lib/constants.py
@@ -23,5 +23,10 @@
 GCE_PROJECT = 'cros-autotest-bots'
 GCE_DEFAULT_ZONE = 'us-central1-a'
 GCE_DEFAULT_NETWORK = 'network-prod'
+GCE_DEFAULT_MACHINE_TYPE = 'n1-standard-8'
 GCE_JSON_KEY = '/creds/service_accounts/service-account-cros-autotest-bots.json'
 GCS_BUCKET = 'chromeos-test-gce-tarballs'
+
+TRUSTED_BOARDS = [
+    'lakitu'
+]
diff --git a/lib/gce.py b/lib/gce.py
index 609c90c..5955319 100644
--- a/lib/gce.py
+++ b/lib/gce.py
@@ -10,9 +10,12 @@
 
 from __future__ import print_function
 
+import httplib2
+
 from chromite.lib import cros_logging as logging
 from chromite.lib import timeout_util
 from googleapiclient.discovery import build
+from googleapiclient.http import HttpRequest
 from googleapiclient import errors
 from oauth2client.client import GoogleCredentials
 
@@ -54,22 +57,38 @@
   DEFAULT_MACHINE_TYPE = 'n1-standard-8'
   DEFAULT_TIMEOUT_SEC = 5 * 60
 
-  def __init__(self, project, zone, network, credentials):
+  def __init__(self, project, zone, network, machine_type, credentials,
+               thread_safe=False):
     """Initializes GceContext.
 
     Args:
       project: The GCP project to create instances in.
       zone: The default zone to create instances in.
       network: The default network to create instances in.
+      machine_type: The default machine type to use.
       credentials: The credentials used to call the GCE API.
+      thread_safe: Whether the client is expected to be thread safe.
     """
     self.project = project
     self.zone = zone
     self.network = network
-    self.gce_client = build('compute', 'v1', credentials=credentials)
+    self.machine_type = machine_type
+
+    def BuildRequest(_, *args, **kwargs):
+      """Create a new Http() object for every request."""
+      http = httplib2.Http()
+      http = credentials.authorize(http)
+      return HttpRequest(http, *args, **kwargs)
+
+    if thread_safe:
+      self.gce_client = build('compute', 'v1', credentials=credentials,
+                              requestBuilder=BuildRequest)
+    else:
+      self.gce_client = build('compute', 'v1', credentials=credentials)
 
   @classmethod
-  def ForServiceAccount(cls, project, zone, network, json_key_file):
+  def ForServiceAccount(cls, project, zone, network, machine_type,
+                        json_key_file):
     """Creates a GceContext using service account credentials.
 
     About service account:
@@ -79,6 +98,7 @@
       project: The GCP project to create images and instances in.
       zone: The default zone to create instances in.
       network: The default network to create instances in.
+      machine_type: The default machine type to use.
       json_key_file: Path to the service account JSON key.
 
     Returns:
@@ -86,10 +106,33 @@
     """
     credentials = GoogleCredentials.from_stream(json_key_file).create_scoped(
         cls.GCE_SCOPES)
-    return GceContext(project, zone, network, credentials)
+    return GceContext(project, zone, network, machine_type, credentials)
 
-  def CreateInstance(self, name, image, machine_type=DEFAULT_MACHINE_TYPE,
-                     network=None, zone=None):
+  @classmethod
+  def ForServiceAccountThreadSafe(cls, project, zone, network, machine_type,
+                                  json_key_file):
+    """Creates a thread-safe GceContext using service account credentials.
+
+    About service account:
+    https://developers.google.com/api-client-library/python/auth/service-accounts
+
+    Args:
+      project: The GCP project to create images and instances in.
+      zone: The default zone to create instances in.
+      network: The default network to create instances in.
+      machine_type: The default machine type to use.
+      json_key_file: Path to the service account JSON key.
+
+    Returns:
+      GceContext.
+    """
+    credentials = GoogleCredentials.from_stream(json_key_file).create_scoped(
+        cls.GCE_SCOPES)
+    return GceContext(project, zone, network, machine_type, credentials,
+                      thread_safe=True)
+
+  def CreateInstance(self, name, image, machine_type=None, network=None,
+                     zone=None, **kwargs):
     """Creates an instance with the given image and waits until it's ready.
 
     Args:
@@ -106,14 +149,22 @@
       zone:
         The zone to create the instance in. Default zone will be used if
         omitted.
+      kwargs:
+        Other possible Instance Resource properties.
+        https://cloud.google.com/compute/docs/reference/latest/instances#resource
 
     Returns:
       URL to the created instance.
     """
+    machine_type = 'zones/%s/machineTypes/%s' % (
+        zone or self.zone, machine_type or self.machine_type)
+    # Allow machineType overriding.
+    if 'machineType' in kwargs.keys():
+      machine_type = kwargs['machineType']
+
     config = {
         'name': name,
-        'machineType': 'zones/%s/machineTypes/%s' % (zone or self.zone,
-                                                     machine_type),
+        'machineType': machine_type,
         'disks': [
             {
                 'boot': True,
@@ -132,6 +183,7 @@
                 }
             ]
         }
+    config.update(**kwargs)
     operation = self.gce_client.instances().insert(
         project=self.project,
         zone=zone or self.zone,
@@ -227,7 +279,6 @@
     except (KeyError, IndexError):
       raise Error('Failed to get IP address for instance %s' % instance)
 
-
   def _WaitForZoneOperation(self, operation, zone=None, timeout_handler=None):
     get_request = self.gce_client.zoneOperations().get(
         project=self.project, zone=zone or self.zone, operation=operation)