# blob: fc8ef344b238618ef261b5eba66ab0adde9f128b [file] [log] [blame]
#!/usr/bin/python
# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import os
import shutil
import socket
import tempfile
import unittest
from multiprocessing import connection
import common
from autotest_lib.client.common_lib import error
from autotest_lib.site_utils.lxc.container_pool import async_listener
from autotest_lib.site_utils.lxc.container_pool import unittest_client
from autotest_lib.site_utils.lxc import utils
# Namespace object for parsing cmd line options.
# NOTE(review): nothing visible in this file assigns or reads `options` after
# this initialization; presumably an external test runner populates it --
# confirm whether it is still needed.
options = None
class AsyncListenerTests(utils.LXCTests):
    """Unit tests for the AsyncListener class.

    Every test gets a fresh AsyncListener bound to a socket path inside a
    private temporary directory.
    """

    def setUp(self):
        """Creates a temp dir and a listener bound to a socket inside it."""
        self.tmpdir = tempfile.mkdtemp()
        # Registered as a cleanup (instead of being done in tearDown) so the
        # directory is removed even when constructing the listener below
        # fails, or when closing the listener in tearDown raises.
        self.addCleanup(shutil.rmtree, self.tmpdir)
        self.address = os.path.join(self.tmpdir, 'socket')
        self.authkey = 'foo'
        self.listener = async_listener.AsyncListener(self.address)

    def tearDown(self):
        """Closes the listener.

        The temp dir is removed afterwards by the cleanup registered in setUp.
        """
        self.listener.close()

    def testInvalidAddresses(self):
        """Verifies that invalid socket paths raise errors."""
        # The 'foo' subdirectory is never created, so the path is unusable.
        invalid_address = os.path.join(self.tmpdir, 'foo', 'socket')
        # Only the constructor call sits inside the assertRaises block, so
        # nothing else can satisfy the expectation by accident.
        with self.assertRaises(socket.error):
            async_listener.AsyncListener(invalid_address)

    def testAddressConflict(self):
        """Verifies that address conflicts raise errors."""
        # self.address is already in use by the listener created in setUp.
        with self.assertRaises(socket.error):
            async_listener.AsyncListener(self.address)

    def testInetAddresses(self):
        """Verifies that inet addresses raise errors."""
        with self.assertRaises(TypeError):
            async_listener.AsyncListener(('127.0.0.1', 0))

    def testStartStop(self):
        """Tests starting and stopping the async listener."""
        try:
            self.assertFalse(self.listener.is_running())
            self.listener.start()
            self.assertTrue(self.listener.is_running())

            # Establish a connection to verify that the listener is actually
            # alive.
            host, client = self._make_connection()
            client.close()
            host.close()

            self.assertTrue(self.listener.stop())
            self.assertFalse(self.listener.is_running())
        except Exception:
            # Report any error as a test failure carrying the formatted
            # traceback.  KeyboardInterrupt/SystemExit are deliberately not
            # caught so that the test run can still be aborted.
            self.fail(error.format_error())

    def testStop_notRunning(self):
        """Tests stopping the async listener when it's not running."""
        try:
            self.assertFalse(self.listener.is_running())
            # Verify that calling stop just returns False when the listener
            # isn't running.
            self.assertFalse(self.listener.stop())
        except Exception:
            # See testStartStop: errors become failures, interrupts propagate.
            self.fail(error.format_error())

    def testConnection(self):
        """Tests starting a connection."""
        self.listener.start()

        # Verify no pending connections.
        self.assertIsNone(self.listener.get_connection(timeout=1))

        # Start a client connection, verify that the listener now returns a
        # pending connection.  Keep a reference to the client, so it doesn't
        # get GC'd while the test is in progress.
        client = connection.Client(self.address)
        self.addCleanup(client.close)

        host = self.listener.get_connection(timeout=1)
        self.assertIsNotNone(host)
        # Close the host side once the test is over instead of leaking it.
        self.addCleanup(host.close)

    def testHostToClientComms(self):
        """Tests that client/host connections can exchange data."""
        self.listener.start()
        host, client = self._make_connection()
        self.addCleanup(client.close)
        self.addCleanup(host.close)

        # Send a message on the host side socket, check that the client
        # receives it.
        msg = 'quickly moving jedi zap a brown sith fox'
        host.send(msg)
        received_msg = client.recv()
        self.assertEqual(msg, received_msg)

    def testMultipleConnections(self):
        """Tests for correct pairing of multiple host/client connections."""
        self.listener.start()

        # Create two sets of host/client connections.  Verify that messages
        # sent on each client are received by the correct respective host.
        host0, client0 = self._make_connection()
        self.addCleanup(client0.close)
        self.addCleanup(host0.close)
        host1, client1 = self._make_connection()
        self.addCleanup(client1.close)
        self.addCleanup(host1.close)

        msg0 = 'a large fawn jumped quickly over white zinc boxes'
        msg1 = 'six big devils from japan quickly forgot how to waltz'
        client0.send(msg0)
        client1.send(msg1)

        received_msg0 = host0.recv()
        received_msg1 = host1.recv()
        self.assertEqual(msg0, received_msg0)
        self.assertEqual(msg1, received_msg1)

    def _make_connection(self):
        """Creates pairs of host/client connections.

        Connects a client to self.address, then asks the listener for the
        matching pending connection (waiting up to 1 second).

        @return: A host, client connection pair.
        """
        client = unittest_client.connect(self.address)
        host = self.listener.get_connection(timeout=1)
        return host, client
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()