blob: 1790f796c8e13c277ea8781cfebae0fb6b7da106 [file] [log] [blame]
#!/usr/bin/python
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import Queue
import array
import collections
import os
import shutil
import tempfile
import threading
import unittest
from contextlib import contextmanager
from multiprocessing import connection
import common
from autotest_lib.site_utils import lxc
from autotest_lib.site_utils.lxc import unittest_setup
from autotest_lib.site_utils.lxc.container_pool import message
from autotest_lib.site_utils.lxc.container_pool import service
from autotest_lib.site_utils.lxc.container_pool import unittest_client
# Lightweight stand-in for the host-directory object that the tests pass to
# service.Service; only its `path` field is populated.
FakeHostDir = collections.namedtuple('FakeHostDir', 'path')
class ServiceTests(unittest.TestCase):
    """Unit tests for the Service class."""

    @classmethod
    def setUpClass(cls):
        """Creates a directory for running the unit tests."""
        # Explicitly use /tmp as the tmpdir. Board specific TMPDIRs inside of
        # the chroot are set to a path that causes the socket address to exceed
        # the maximum allowable length.
        cls.test_dir = tempfile.mkdtemp(prefix='service_unittest_', dir='/tmp')

    @classmethod
    def tearDownClass(cls):
        """Deletes the test directory, including all per-test subdirectories."""
        shutil.rmtree(cls.test_dir)

    def setUp(self):
        """Per-test setup: a fresh directory and socket path for each test."""
        # Put each test in its own test dir, so it's hermetic.  Note this sets
        # an instance attribute that shadows the class-level test_dir; the
        # class attribute (the parent dir) is what tearDownClass removes.
        self.test_dir = tempfile.mkdtemp(dir=ServiceTests.test_dir)
        self.host_dir = FakeHostDir(self.test_dir)
        self.address = os.path.join(self.test_dir,
                                    lxc.DEFAULT_CONTAINER_POOL_SOCKET)

    def testConnection(self):
        """Tests a simple connection to the pool service."""
        with self.run_service():
            self.assertTrue(self._pool_is_healthy())

    def testAbortedConnection(self):
        """Tests that a closed connection doesn't crash the service."""
        with self.run_service():
            client = connection.Client(self.address)
            client.close()
            self.assertTrue(self._pool_is_healthy())

    def testCorruptedMessage(self):
        """Tests that corrupted messages don't crash the service."""
        with self.run_service(), self.create_client() as client:
            # Send a raw array of bytes. This will cause an unpickling error.
            client.send_bytes(array.array('i', range(1, 10)))
            # Verify that the container pool closed the connection.
            with self.assertRaises(EOFError):
                client.recv()
            # Verify that the main container pool service is still alive.
            self.assertTrue(self._pool_is_healthy())

    def testInvalidMessageClass(self):
        """Tests that bad messages don't crash the service."""
        with self.run_service(), self.create_client() as client:
            # Send a valid object but not of the right Message class.
            client.send('foo')
            # Verify that the container pool closed the connection.
            with self.assertRaises(EOFError):
                client.recv()
            # Verify that the main container pool service is still alive.
            self.assertTrue(self._pool_is_healthy())

    def testInvalidMessageType(self):
        """Tests that messages with a bad type don't crash the service."""
        with self.run_service(), self.create_client() as client:
            # Send a genuine Message object, but with a type ('foo') that the
            # service does not recognize.
            client.send(message.Message('foo', None))
            # Verify that the container pool closed the connection.
            with self.assertRaises(EOFError):
                client.recv()
            # Verify that the main container pool service is still alive.
            self.assertTrue(self._pool_is_healthy())

    def testStop(self):
        """Tests stopping the service."""
        with self.run_service() as svc, self.create_client() as client:
            self.assertTrue(svc.is_running())
            client.send(message.shutdown())
            client.recv()  # wait for ack
            self.assertFalse(svc.is_running())

    def testStatus(self):
        """Tests querying service status."""
        pool = MockPool()
        with self.run_service(pool), self.create_client() as client:
            client.send(message.status())
            self._assert_status_matches(client.recv(), pool)

            # Change some values, ensure the changes are reflected.
            pool.capacity = 42
            pool.size = 19
            pool.worker_count = 3
            error_count = 8
            for e in range(error_count):
                pool.errors.put(e)
            client.send(message.status())
            self._assert_status_matches(client.recv(), pool)

    def testGet(self):
        """Tests getting a container from the pool."""
        test_pool = MockPool()
        fake_container = MockContainer()
        test_id = lxc.ContainerId.create(42)
        test_pool.containers.put(fake_container)
        with self.run_service(test_pool):
            with self.create_client() as client:
                client.send(message.get(test_id))
                test_container = client.recv()
                # The container starts with id None, so a match means the
                # service assigned the requested ID.
                self.assertEqual(test_id, test_container.id)

    def testGet_timeoutImmediate(self):
        """Tests that a get with no explicit timeout on an empty pool
        returns None."""
        test_id = lxc.ContainerId.create(42)
        with self.run_service():
            with self.create_client() as client:
                client.send(message.get(test_id))
                test_container = client.recv()
                self.assertIsNone(test_container)

    def testGet_timeoutDelayed(self):
        """Tests that a get with a 1-second timeout on an empty pool
        returns None once the timeout expires."""
        test_id = lxc.ContainerId.create(42)
        with self.run_service():
            with self.create_client() as client:
                client.send(message.get(test_id, timeout=1))
                test_container = client.recv()
                self.assertIsNone(test_container)

    def testMultipleClients(self):
        """Tests multiple simultaneous connections."""
        with self.run_service():
            with self.create_client() as client0:
                with self.create_client() as client1:
                    msg0 = 'two driven jocks help fax my big quiz'
                    msg1 = 'how quickly daft jumping zebras vex'
                    client0.send(message.echo(msg0))
                    client1.send(message.echo(msg1))
                    echo0 = client0.recv()
                    echo1 = client1.recv()
                    self.assertEqual(msg0, echo0)
                    self.assertEqual(msg1, echo1)

    def _assert_status_matches(self, status, pool):
        """Asserts that a status reply reflects the service and pool state.

        @param status: The status mapping received in reply to a status
                       message.
        @param pool: The pool object the service was started with.
        """
        self.assertTrue(status['running'])
        self.assertEqual(self.address, status['socket_path'])
        self.assertEqual(pool.capacity, status['pool capacity'])
        self.assertEqual(pool.size, status['pool size'])
        self.assertEqual(pool.worker_count, status['pool worker count'])
        self.assertEqual(pool.errors.qsize(), status['pool errors'])

    def _pool_is_healthy(self):
        """Verifies that the pool service is still functioning.

        Sends an echo message and tests for a response. This is a stronger
        signal of aliveness than checking Service.is_running, but a False return
        value does not necessarily indicate that the pool service shut down
        cleanly. Use Service.is_running to check that.

        @return: True if the service echoed the message back unchanged.
        """
        with self.create_client() as client:
            msg = 'foobar'
            client.send(message.echo(msg))
            return client.recv() == msg

    @contextmanager
    def run_service(self, pool=None):
        """Creates and cleans up a Service instance.

        Service.start runs on a background thread for the duration of the
        with-block. On exit the service is stopped and the thread is joined
        for up to 1 second.

        @param pool: The pool backing the service.  Defaults to a new, empty
                     MockPool.

        @yield: The Service instance (already started on the background
                thread).
        """
        if pool is None:
            pool = MockPool()
        svc = service.Service(self.host_dir, pool)
        thread = threading.Thread(name='service', target=svc.start)
        # The join below is time-bounded. If the service ever fails to stop,
        # a non-daemon thread would keep the interpreter alive and hang the
        # whole test run, so mark it as a daemon.
        thread.daemon = True
        thread.start()
        try:
            yield svc
        finally:
            svc.stop()
            thread.join(1)

    @contextmanager
    def create_client(self):
        """Creates and cleans up a client connection.

        @yield: A client connected to this test's service socket address.
        """
        client = unittest_client.connect(self.address)
        try:
            yield client
        finally:
            client.close()
class MockPool(object):
    """A mock pool class for testing the service.

    Holds the values the tests compare against the service's status reply
    (capacity, size, worker_count, errors) as plain, writable attributes,
    and hands out whatever objects tests put on the containers queue.
    """

    def __init__(self):
        """Initializes a mock empty pool."""
        self.capacity = 0
        self.size = 0
        self.worker_count = 0
        # Tests compare errors.qsize() with the status reply's error count;
        # the queued items themselves are arbitrary.
        self.errors = Queue.Queue()
        # Objects returned, in FIFO order, by get().
        self.containers = Queue.Queue()

    def cleanup(self):
        """Required by pool interface. Does nothing."""
        pass

    def get(self, timeout=0):
        """Required by pool interface.

        @param timeout: Seconds to wait for a container.  0 (the default), a
                        negative value, or None means do not wait at all.

        @return: The next container from the containers queue, or None if
                 none was available within the timeout.
        """
        # Compute "should block" explicitly. timeout=None then means "don't
        # wait", which is what `None > 0` evaluated to under Python 2, instead
        # of raising TypeError on Python 3.
        wait = timeout is not None and timeout > 0
        try:
            return self.containers.get(block=wait, timeout=timeout)
        except Queue.Empty:
            return None
class MockContainer(object):
    """Stand-in container object that tests place on a MockPool.

    Only carries the two attributes the tests touch: an `id` (initially
    unset, so a test can check that the service assigned one) and a fixed
    `name`.
    """

    def __init__(self):
        """Builds a container named 'test_container' with no ID yet."""
        self.id, self.name = None, 'test_container'
def _main():
    """Prepares the unittest environment (no sudo needed) and runs the tests."""
    unittest_setup.setup(require_sudo=False)
    unittest.main()


if __name__ == '__main__':
    _main()