# blob: 73e48ac3b12cccf0dcc6c7501ea800aff262f94f [file] [log] [blame]
#!/usr/bin/python3
import unittest, time
import common
from autotest_lib.client.common_lib import global_config
from autotest_lib.client.common_lib.test_utils import mock
from autotest_lib.database import database_connection
# Global-config section that the tests populate with connection settings.
_CONFIG_SECTION = 'AUTOTEST_WEB'

# Dummy connection settings shared by every test case.
_HOST = 'myhost'
_USER = 'myuser'
_PASS = 'mypass'
_DB_NAME = 'mydb'
_DB_TYPE = 'mydbtype'

# Keyword arguments the fake backend's connect() is expected to receive.
_CONNECT_KWARGS = {
    'host': _HOST,
    'username': _USER,
    'password': _PASS,
    'db_name': _DB_NAME,
}

# Value assigned to reconnect_delay_sec on the connection under test, and the
# argument the stubbed time.sleep is expected to be called with.
_RECONNECT_DELAY = 10
class FakeDatabaseError(Exception):
    """Stand-in exception installed on the fake backend for its DB errors."""
class DatabaseConnectionTest(unittest.TestCase):
    """Tests for database_connection.DatabaseConnection using a mock backend.

    Every test obtains its connection from _get_database_connection(), which
    swaps the real backend for a mock so that the calls made to it
    (connect / disconnect / execute) can be recorded up front and verified
    with self.god.check_playback().
    """

    def setUp(self):
        """Create the mock controller and stub out time.sleep."""
        self.god = mock.mock_god()
        # With sleep stubbed, reconnect delays become expected calls that the
        # tests can verify instead of real pauses.
        self.god.stub_function(time, 'sleep')

    def tearDown(self):
        """Undo config overrides and remove every stub installed by a test."""
        global_config.global_config.reset_config_values()
        self.god.unstub_all()

    def _get_database_connection(self, config_section=_CONFIG_SECTION):
        """Return a DatabaseConnection wired to a mock backend.

        Side effects: stores the mock backend in self._fake_backend and, once
        the connection asks for its backend, the requested type in
        self._db_type.

        @param config_section: section name passed to DatabaseConnection.
                When it equals _CONFIG_SECTION the global config is first
                populated with the test settings.
        @return: the DatabaseConnection under test.
        """
        if config_section == _CONFIG_SECTION:
            self._override_config()
        db = database_connection.DatabaseConnection(config_section)

        self._fake_backend = self.god.create_mock_class(
                database_connection._GenericBackend, 'fake_backend')
        # Every exception attribute named in _DB_EXCEPTIONS resolves to the
        # same fake error type on the mock backend.
        for exception in database_connection._DB_EXCEPTIONS:
            setattr(self._fake_backend, exception, FakeDatabaseError)
        self._fake_backend.rowcount = 0

        def get_fake_backend(db_type):
            # Remember which backend type was requested so tests can assert
            # on it, then hand back the mock.
            self._db_type = db_type
            return self._fake_backend

        self.god.stub_with(db, '_get_backend', get_fake_backend)
        db.reconnect_delay_sec = _RECONNECT_DELAY
        return db

    def _override_config(self):
        """Write the test connection settings into the global config."""
        c = global_config.global_config
        overrides = (
            ('host', _HOST),
            ('user', _USER),
            ('password', _PASS),
            ('database', _DB_NAME),
            ('db_type', _DB_TYPE),
        )
        for key, value in overrides:
            c.override_config_value(_CONFIG_SECTION, key, value)

    def test_connect(self):
        """connect() with explicit arguments passes them to the backend."""
        db = self._get_database_connection(config_section=None)
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)

        db.connect(db_type=_DB_TYPE, host=_HOST, username=_USER,
                   password=_PASS, db_name=_DB_NAME)

        self.assertEqual(self._db_type, _DB_TYPE)
        self.god.check_playback()

    def test_global_config(self):
        """connect() without arguments uses the global config settings."""
        db = self._get_database_connection()
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)

        db.connect()

        self.assertEqual(self._db_type, _DB_TYPE)
        self.god.check_playback()

    def _expect_reconnect(self, fail=False):
        """Record one disconnect followed by one connect on the backend.

        @param fail: if True, the recorded connect raises FakeDatabaseError.
        """
        self._fake_backend.disconnect.expect_call()
        call = self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
        if fail:
            call.and_raises(FakeDatabaseError())

    def _expect_fail_and_reconnect(self, num_reconnects, fail_last=False):
        """Record a failing connect followed by a series of reconnects.

        Each reconnect is preceded by an expected time.sleep(_RECONNECT_DELAY).
        Every reconnect except the final one fails.

        @param num_reconnects: number of reconnect attempts to record.
        @param fail_last: whether the final reconnect attempt fails as well.
        """
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS).and_raises(
                FakeDatabaseError())
        for attempt in range(num_reconnects):
            time.sleep.expect_call(_RECONNECT_DELAY)
            is_last = attempt == num_reconnects - 1
            self._expect_reconnect(fail=(not is_last) or fail_last)

    def test_connect_retry(self):
        """A failed connect is retried unless reconnecting is turned off."""
        db = self._get_database_connection()

        # Default behavior: one failure, then a successful reconnect.
        self._expect_fail_and_reconnect(1)
        db.connect()
        self.god.check_playback()

        # try_reconnecting=False: the failure propagates without a retry.
        # A disconnect is expected before the connect attempt.
        self._fake_backend.disconnect.expect_call()
        self._expect_fail_and_reconnect(0)
        self.assertRaises(FakeDatabaseError, db.connect,
                          try_reconnecting=False)
        self.god.check_playback()

        # reconnect_enabled=False on the object has the same effect.
        db.reconnect_enabled = False
        self._fake_backend.disconnect.expect_call()
        self._expect_fail_and_reconnect(0)
        self.assertRaises(FakeDatabaseError, db.connect)
        self.god.check_playback()

    def test_max_reconnect(self):
        """connect() raises once max_reconnect_attempts retries all fail."""
        db = self._get_database_connection()
        db.max_reconnect_attempts = 5
        self._expect_fail_and_reconnect(5, fail_last=True)

        self.assertRaises(FakeDatabaseError, db.connect)
        self.god.check_playback()

    def test_reconnect_forever(self):
        """With RECONNECT_FOREVER, connect() keeps retrying until success."""
        db = self._get_database_connection()
        db.max_reconnect_attempts = database_connection.RECONNECT_FOREVER
        # 30 attempts, of which only the last one succeeds.
        self._expect_fail_and_reconnect(30)

        db.connect()
        self.god.check_playback()

    def _simple_connect(self, db):
        """Connect db with a single successful backend connect and verify."""
        self._fake_backend.connect.expect_call(**_CONNECT_KWARGS)
        db.connect()
        self.god.check_playback()

    def test_disconnect(self):
        """disconnect() is forwarded to the backend."""
        db = self._get_database_connection()
        self._simple_connect(db)
        self._fake_backend.disconnect.expect_call()

        db.disconnect()
        self.god.check_playback()

    def test_execute(self):
        """execute() forwards the query and its parameters to the backend."""
        db = self._get_database_connection()
        self._simple_connect(db)
        # An opaque sentinel: the expectation only matches if the same
        # parameters object is passed through.
        params = object()
        self._fake_backend.execute.expect_call('query', params)

        db.execute('query', params)
        self.god.check_playback()

    def test_execute_retry(self):
        """A failed execute reconnects and retries unless told not to."""
        db = self._get_database_connection()
        self._simple_connect(db)

        # Default behavior: failure, reconnect, then a successful retry.
        self._fake_backend.execute.expect_call('query', None).and_raises(
                FakeDatabaseError())
        self._expect_reconnect()
        self._fake_backend.execute.expect_call('query', None)
        db.execute('query')
        self.god.check_playback()

        # try_reconnecting=False: the failure propagates without a retry.
        self._fake_backend.execute.expect_call('query', None).and_raises(
                FakeDatabaseError())
        self.assertRaises(FakeDatabaseError, db.execute, 'query',
                          try_reconnecting=False)
        # Verify this final phase too, as every other scenario does.
        self.god.check_playback()
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()