This CL updates the parallel job library in the au test harness to be more robust.

Specifically, we migrate to using multiprocessing.Process rather than threading.Thread
which allows us to terminate pesky threads in the event a job takes too long.  In
addition it allows us to remove some redundancy that was needed for threading.Thread.

I've also refactored the "waiting" code into the ParallelJob class and made it possible
to test with ctest --cache.

Change-Id: I7b691b676defb9d2140b006b5731874c7a08562c

BUG=chromium-os:11168, chromium-os:13027
TEST=Ran ctest on x86-generic using new CL logic that lets me do so.

Review URL: http://codereview.chromium.org/6815003
diff --git a/au_test_harness/au_test.py b/au_test_harness/au_test.py
index 5a6549e..a95fa74 100644
--- a/au_test_harness/au_test.py
+++ b/au_test_harness/au_test.py
@@ -131,7 +131,9 @@
     This test checks that we can update by updating the stateful partition
     rather than wiping it.
     """
-    self.worker.InitializeResultsDirectory()
+    self.worker.Initialize(9222)
+    # Just make sure some tests pass on original image.  Some old images
+    # don't pass many tests.
     self.worker.PrepareBase(self.base_image_path)
 
     # Update to
@@ -148,7 +150,9 @@
     This test checks that we can update successfully after wiping the
     stateful partition.
     """
-    self.worker.InitializeResultsDirectory()
+    self.worker.Initialize(9223)
+    # Just make sure some tests pass on original image.  Some old images
+    # don't pass many tests.
     self.worker.PrepareBase(self.base_image_path)
 
     # Update to
@@ -193,7 +197,7 @@
         self.data_size += len(data)
         return data
 
-    self.worker.InitializeResultsDirectory()
+    self.worker.Initialize(9224)
     self.AttemptUpdateWithFilter(InterruptionFilter(), proxy_port=8082)
 
   def testDelayedUpdate(self):
@@ -224,7 +228,7 @@
         self.data_size += len(data)
         return data
 
-    self.worker.InitializeResultsDirectory()
+    self.worker.Initialize(9225)
     self.AttemptUpdateWithFilter(DelayedFilter(), proxy_port=8083)
 
   def SimpleTest(self):
@@ -233,7 +237,7 @@
     We explicitly don't use test prefix so that isn't run by default.  Can be
     run using test_prefix option.
     """
-    self.worker.InitializeResultsDirectory()
+    self.worker.Initialize(9226)
     self.worker.PrepareBase(self.base_image_path)
     self.worker.PerformUpdate(self.target_image_path, self.base_image_path)
     self.worker.VerifyImage(self)
@@ -243,7 +247,7 @@
   # TODO(sosa): Get test to work with verbose.
   def NotestPartialUpdate(self):
     """Tests what happens if we attempt to update with a truncated payload."""
-    self.worker.InitializeResultsDirectory()
+    self.worker.Initialize(9227)
     # Preload with the version we are trying to test.
     self.worker.PrepareBase(self.target_image_path)
 
@@ -262,7 +266,7 @@
   # TODO(sosa): Get test to work with verbose.
   def NotestCorruptedUpdate(self):
     """Tests what happens if we attempt to update with a corrupted payload."""
-    self.worker.InitializeResultsDirectory()
+    self.worker.Initialize(9228)
     # Preload with the version we are trying to test.
     self.worker.PrepareBase(self.target_image_path)
 
diff --git a/au_test_harness/au_worker.py b/au_test_harness/au_worker.py
index 701a12a..a25de7f 100644
--- a/au_test_harness/au_worker.py
+++ b/au_test_harness/au_worker.py
@@ -232,9 +232,19 @@
     unittest.assertTrue(percent_passed >= percent_required_to_pass)
     return percent_passed
 
-  def InitializeResultsDirectory(self):
-    """Called by a test to initialize a results directory for this worker."""
-    # Use the name of the test.
+  def Initialize(self, port):
+    """Initializes test specific variables for each test.
+
+    Each test needs to specify a unique ssh port.
+
+    Args:
+      port:  Unique port for ssh access.
+    """
+    # Initialize port vars.
+    self._ssh_port = port
+    self._kvm_pid_file = '/tmp/kvm.%d' % port
+
+    # Initialize test results directory.
     test_name = inspect.stack()[1][3]
     self.results_directory = os.path.join(self.test_results_root, test_name)
     self.results_count = 0
diff --git a/au_test_harness/parallel_test_job.py b/au_test_harness/parallel_test_job.py
index 096ec17..89befff 100644
--- a/au_test_harness/parallel_test_job.py
+++ b/au_test_harness/parallel_test_job.py
@@ -4,58 +4,86 @@
 
 """Module containing methods/classes related to running parallel test jobs."""
 
+import multiprocessing
 import sys
-import threading
 import time
 
 import cros_build_lib as cros_lib
 
-class ParallelJob(threading.Thread):
-  """Small wrapper for threading.  Thread that releases a semaphores on exit."""
+class ParallelJobTimeoutError(Exception):
+  """Thrown when a job ran for longer than expected."""
+  pass
 
-  def __init__(self, starting_semaphore, ending_semaphore, target, args):
+
+class ParallelJob(multiprocessing.Process):
+  """Small wrapper for Process that stores output of its target method."""
+
+  MAX_TIMEOUT_SECONDS = 1800
+  SLEEP_TIMEOUT_SECONDS = 180
+
+  def __init__(self, starting_semaphore, target, args):
     """Initializes an instance of a job.
 
     Args:
       starting_semaphore: Semaphore used by caller to wait on such that
-        there isn't more than a certain number of threads running.  Should
-        be initialized to a value for the number of threads wanting to be run
-        at a time.
-      ending_semaphore:  Semaphore is released every time a job ends.  Should be
-        initialized to 0 before starting first job.  Should be acquired once for
-        each job.  Threading.Thread.join() has a bug where if the run function
-        terminates too quickly join() will hang forever.
+        there isn't more than a certain number of parallel_jobs running.  Should
+        be initialized to a value for the number of parallel_jobs wanting to be
+        run at a time.
       target:  The func to run.
       args:  Args to pass to the fun.
     """
-    threading.Thread.__init__(self, target=target, args=args)
+    super(ParallelJob, self).__init__(target=target, args=args)
     self._target = target
     self._args = args
     self._starting_semaphore = starting_semaphore
-    self._ending_semaphore = ending_semaphore
-    self._output = None
-    self._completed = False
 
   def run(self):
     """Thread override.  Runs the method specified and sets output."""
     try:
-      self._output = self._target(*self._args)
+      self._target(*self._args)
     finally:
-      # Our own clean up.
-      self._Cleanup()
-      self._completed = True
-      # From threading.py to avoid a refcycle.
-      del self._target, self._args
+      self._starting_semaphore.release()
 
-  def GetOutput(self):
-    """Returns the output of the method run."""
-    assert self._completed, 'GetOutput called before thread was run.'
-    return self._output
+  @classmethod
+  def WaitUntilJobsComplete(cls, parallel_jobs):
+    """Waits until all parallel_jobs have completed before returning.
 
-  def _Cleanup(self):
-    """Releases semaphores for a waiting caller."""
-    self._starting_semaphore.release()
-    self._ending_semaphore.release()
+    Given an array of parallel_jobs, returns once all parallel_jobs have
+    completed or a max timeout is reached.
+
+    Raises:
+      ParallelJobTimeoutError:  if max timeout is reached.
+    """
+    def GetCurrentActiveCount():
+      """Returns the (number of active jobs, first active job)."""
+      active_count = 0
+      active_job = None
+      for parallel_job in parallel_jobs:
+        if parallel_job.is_alive():
+          active_count += 1
+          if not active_job:
+            active_job = parallel_job
+
+      return (active_count, active_job)
+
+    start_time = time.time()
+    while (time.time() - start_time) < cls.MAX_TIMEOUT_SECONDS:
+      (active_count, active_job) = GetCurrentActiveCount()
+      if active_count == 0:
+        return
+      else:
+        print >> sys.stderr, (
+            'Process Pool Active: Waiting on %d/%d jobs to complete' %
+            (active_count, len(parallel_jobs)))
+        active_job.join(cls.SLEEP_TIMEOUT_SECONDS)
+        time.sleep(5) # Prevents lots of printing out as job is ending.
+
+    for parallel_job in parallel_jobs:
+      if parallel_job.is_alive():
+        parallel_job.terminate()
+
+    raise ParallelJobTimeoutError('Exceeded max time of %d seconds to wait for '
+                                  'job completion.' % cls.MAX_TIMEOUT_SECONDS)
 
   def __str__(self):
     return '%s(%s)' % (self._target, self._args)
@@ -66,44 +94,44 @@
   """Runs set number of specified jobs in parallel.
 
   Args:
-    number_of_simultaneous_jobs:  Max number of threads to be run in parallel.
+    number_of_simultaneous_jobs:  Max number of parallel_jobs to be run in
+      parallel.
     jobs:  Array of methods to run.
     jobs_args:  Array of args associated with method calls.
     print_status: True if you'd like this to print out .'s as it runs jobs.
   Returns:
-    Returns an array of results corresponding to each thread.
+    Returns an array of results corresponding to each parallel_job.
   """
-  def _TwoTupleize(x, y):
-    return (x, y)
+  def ProcessOutputWrapper(func, args, output):
+    """Simple function wrapper that puts the output of a function in a queue."""
+    output.put(func(*args))
 
-  threads = []
-  job_start_semaphore = threading.Semaphore(number_of_simultaneous_jobs)
-  join_semaphore = threading.Semaphore(0)
   assert len(jobs) == len(jobs_args), 'Length of args array is wrong.'
-
-  # Create the parallel jobs.
-  for job, args in map(_TwoTupleize, jobs, jobs_args):
-    thread = ParallelJob(job_start_semaphore, join_semaphore, target=job,
-                         args=args)
-    threads.append(thread)
-
   # Cache sudo access.
   cros_lib.RunCommand(['sudo', 'echo', 'Caching sudo credentials'],
                       print_cmd=False, redirect_stdout=True,
                       redirect_stderr=True)
 
+  parallel_jobs = []
+  output_array = []
+
+  # Semaphore used to create a Process Pool.
+  job_start_semaphore = multiprocessing.Semaphore(number_of_simultaneous_jobs)
+
+  # Create the parallel jobs.
+  for job, args in map(lambda x, y: (x, y), jobs, jobs_args):
+    output = multiprocessing.Queue()
+    parallel_job = ParallelJob(job_start_semaphore,
+                               target=ProcessOutputWrapper,
+                               args=(job, args, output))
+    parallel_jobs.append(parallel_job)
+    output_array.append(output)
+
   # We use a semaphore to ensure we don't run more jobs than required.
-  # After each thread finishes, it releases (increments semaphore).
-  # Acquire blocks of num jobs reached and continues when a thread finishes.
-  for next_thread in threads:
-    job_start_semaphore.acquire(blocking=True)
-    next_thread.start()
+  # After each parallel_job finishes, it releases (increments semaphore).
+  for next_parallel_job in parallel_jobs:
+    job_start_semaphore.acquire(block=True)
+    next_parallel_job.start()
 
-  # Wait on the rest of the threads to finish.
-  for thread in threads:
-    while not join_semaphore.acquire(blocking=False):
-      time.sleep(5)
-      if print_status:
-        print >> sys.stderr, '.',
-
-  return [thread.GetOutput() for thread in threads]
+  ParallelJob.WaitUntilJobsComplete(parallel_jobs)
+  return [output.get() for output in output_array]
diff --git a/au_test_harness/vm_au_worker.py b/au_test_harness/vm_au_worker.py
index b5925cc..f64515e 100644
--- a/au_test_harness/vm_au_worker.py
+++ b/au_test_harness/vm_au_worker.py
@@ -5,7 +5,6 @@
 """Module containing implementation of an au_worker for virtual machines."""
 
 import os
-import threading
 import unittest
 
 import cros_build_lib as cros_lib
@@ -16,10 +15,6 @@
 class VMAUWorker(au_worker.AUWorker):
   """Test harness for updating virtual machines."""
 
-  # Class variables used to acquire individual VM variables per test.
-  _vm_lock = threading.Lock()
-  _next_port = 9222
-
   def __init__(self, options, test_results_root):
     """Processes vm-specific options."""
     au_worker.AUWorker.__init__(self, options, test_results_root)
@@ -27,9 +22,6 @@
     if options.no_graphics: self.graphics_flag = '--no_graphics'
     if not self.board: cros_lib.Die('Need board to convert base image to vm.')
 
-    self._AcquireUniquePortAndPidFile()
-    self._KillExistingVM(self._kvm_pid_file)
-
   def _KillExistingVM(self, pid_file):
     """Kills an existing VM specified by the pid_file."""
     if os.path.exists(pid_file):
@@ -40,13 +32,6 @@
 
     assert not os.path.exists(pid_file)
 
-  def _AcquireUniquePortAndPidFile(self):
-    """Acquires unique ssh port and pid file for VM."""
-    with VMAUWorker._vm_lock:
-      self._ssh_port = VMAUWorker._next_port
-      self._kvm_pid_file = '/tmp/kvm.%d' % self._ssh_port
-      VMAUWorker._next_port += 1
-
   def CleanUp(self):
     """Stop the vm after a test."""
     self._KillExistingVM(self._kvm_pid_file)
diff --git a/ctest/ctest.py b/ctest/ctest.py
index 7799a1d..ed11019 100755
--- a/ctest/ctest.py
+++ b/ctest/ctest.py
@@ -283,8 +283,9 @@
   update_engine_path = os.path.join(crosutils_root, '..', 'platform',
                                     'update_engine')
 
-  private_key_path = os.path.join(update_engine_path, 'unittest_key.pem')
-  public_key_path = GeneratePublicKey(private_key_path)
+  if clean:
+    private_key_path = os.path.join(update_engine_path, 'unittest_key.pem')
+    public_key_path = GeneratePublicKey(private_key_path)
 
   cmd = ['bin/cros_au_test_harness',
          '--base_image=%s' % base_image,
@@ -292,12 +293,14 @@
          '--board=%s' % board,
          '--type=%s' % type,
          '--remote=%s' % remote,
-         '--private_key=%s' % private_key_path,
-         '--public_key=%s' % public_key_path,
          ]
   if test_results_root: cmd.append('--test_results_root=%s' % test_results_root)
   if no_graphics: cmd.append('--no_graphics')
-  if clean: cmd.append('--clean')
+  # Using keys is only compatible with clean.
+  if clean:
+    cmd.append('--clean')
+    cmd.append('--private_key=%s' % private_key_path)
+    cmd.append('--public_key=%s' % public_key_path)
 
   cros_lib.RunCommand(cmd, cwd=crosutils_root)