blob: eca776b5a4c40a574d17985cf07c8475cfa1cd9b [file] [log] [blame] [edit]
#!/usr/bin/python -u
import os, sys, re, tempfile
from optparse import OptionParser
import common
from autotest_lib.client.common_lib import utils
from autotest_lib.database import database_connection
# Name of the table holding the single-row schema version record.
MIGRATE_TABLE = 'migrate_info'

# Parent of the directory that contains this script.
_AUTODIR = os.path.join(os.path.dirname(__file__), '..')

# Migration directories, keyed by the database's global config section.
_MIGRATIONS_DIRS = dict(
    AUTOTEST_WEB=os.path.join(_AUTODIR, 'frontend', 'migrations'),
    TKO=os.path.join(_AUTODIR, 'tko', 'migrations'),
)

# Fallback for any other config section; relative, so it resolves against CWD.
_DEFAULT_MIGRATIONS_DIR = 'migrations'
class Migration(object):
    """A single database migration backed by an imported Python module.

    For each direction the module must provide either a callable
    (migrate_up / migrate_down) or a SQL script string (UP_SQL / DOWN_SQL).
    """

    # (callable attribute name, SQL attribute name) for each direction.
    _UP_ATTRIBUTES = ('migrate_up', 'UP_SQL')
    _DOWN_ATTRIBUTES = ('migrate_down', 'DOWN_SQL')

    def __init__(self, name, version, module):
        """Stores the migration's identity and validates its module.

        @param name: Name of the migration (the module name).
        @param version: Integer version this migration brings the DB to.
        @param module: The imported migration module.
        """
        self.name = name
        self.version = version
        self.module = module
        # Both directions must be implemented, up first.
        for attribute_pair in (self._UP_ATTRIBUTES, self._DOWN_ATTRIBUTES):
            self._check_attributes(attribute_pair)

    @classmethod
    def from_file(cls, filename):
        """Builds a Migration by importing the named migration file.

        The first three characters of the filename are parsed as the version
        and the trailing three characters (the '.py' extension) are dropped
        to obtain the module name.

        @param filename: Name of a migration file.
        @return An instantiated Migration object.
        """
        # Parse the version before importing so that a malformed name fails
        # without any import side effects.
        version = int(filename[:3])
        module_name = filename[:-3]
        imported = __import__(module_name, globals(), locals(), [])
        return cls(module_name, version, imported)

    def _check_attributes(self, attributes):
        """Asserts that the module implements one direction.

        @param attributes: A (method_name, sql_name) pair.
        """
        method_name, sql_name = attributes
        has_method = hasattr(self.module, method_name)
        assert has_method or hasattr(self.module, sql_name)

    def _execute_migration(self, attributes, manager):
        """Runs one direction, preferring the callable over the SQL script.

        @param attributes: A (method_name, sql_name) pair.
        @param manager: A MigrationManager object.
        """
        method_name, sql_name = attributes
        handler = getattr(self.module, method_name, None)
        if not handler:
            # No callable provided: fall back to the SQL script attribute.
            script = getattr(self.module, sql_name)
            assert isinstance(script, basestring)
            manager.execute_script(script)
            return
        assert callable(handler)
        handler(manager)

    def migrate_up(self, manager):
        """Applies this migration (moves to a newer version).

        @param manager: A MigrationManager object.
        """
        self._execute_migration(self._UP_ATTRIBUTES, manager)

    def migrate_down(self, manager):
        """Reverts this migration (moves to an older version).

        @param manager: A MigrationManager object.
        """
        self._execute_migration(self._DOWN_ATTRIBUTES, manager)
class MigrationManager(object):
    """Manages database migrations.

    Wraps a database connection object and applies the numbered migration
    modules found in a migrations directory, recording the current schema
    version in the MIGRATE_TABLE table.
    """
    connection = None
    cursor = None
    migrations_dir = None

    def __init__(self, database_connection, migrations_dir=None, force=False):
        """Sets up the manager.

        @param database_connection: Connection object used for all queries.
        @param migrations_dir: Directory holding the migration modules, or
                None to choose one from the connection's config section.
        @param force: If True, confirm_initialization() does not prompt.
        """
        self._database = database_connection
        self.force = force
        # A boolean, this will only be set to True if this migration should be
        # simulated rather than actually taken. For use with migrations that
        # may make destructive queries
        self.simulate = False
        self._set_migrations_dir(migrations_dir)

    def _set_migrations_dir(self, migrations_dir=None):
        """Records the migrations directory and makes it importable.

        @param migrations_dir: Directory to use, or None to look one up in
                _MIGRATIONS_DIRS by config section (falling back to
                _DEFAULT_MIGRATIONS_DIR).
        """
        config_section = self._config_section()
        if migrations_dir is None:
            migrations_dir = os.path.abspath(
                _MIGRATIONS_DIRS.get(config_section, _DEFAULT_MIGRATIONS_DIR))
        self.migrations_dir = migrations_dir
        # Migration modules are imported by bare name, so the directory has to
        # be on the import path.
        sys.path.append(migrations_dir)
        assert os.path.exists(migrations_dir), migrations_dir + " doesn't exist"

    def _config_section(self):
        """Returns the global config section of the underlying connection."""
        return self._database.global_config_section

    def get_db_name(self):
        """Gets the database name."""
        return self._database.get_database_info()['db_name']

    def execute(self, query, *parameters):
        """Executes a database query.

        @param query: The query to execute.
        @param parameters: Associated parameters for the query.
        @return The result of the query.
        """
        return self._database.execute(query, parameters)

    def execute_script(self, script):
        """Executes a set of database queries.

        The script is split on every ';' character; blank statements are
        skipped and the rest are executed in order.

        @param script: A string of semicolon-separated queries.
        """
        sql_statements = [statement.strip()
                          for statement in script.split(';')
                          if statement.strip()]
        for statement in sql_statements:
            self.execute(statement)

    def check_migrate_table_exists(self):
        """Checks whether the migration table exists.

        @return True if a SELECT on the table succeeds, False if it raises
                the connection's DatabaseError.
        """
        try:
            self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
            return True
        except self._database.DatabaseError:
            # we can't check for more specifics due to differences between DB
            # backends (we can't even check for a subclass of DatabaseError)
            return False

    def create_migrate_table(self):
        """Creates the migration table, or empties it if it already exists.

        Afterwards the table holds exactly one row, with version 0.
        """
        if not self.check_migrate_table_exists():
            self.execute("CREATE TABLE %s (`version` integer)" %
                         MIGRATE_TABLE)
        else:
            self.execute("DELETE FROM %s" % MIGRATE_TABLE)
        self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
        assert self._database.rowcount == 1

    def set_db_version(self, version):
        """Sets the database version.

        @param version: The version to which to set the database.
        """
        assert isinstance(version, int)
        # '%%s' leaves a literal '%s' placeholder for the DB layer to fill.
        self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
                     version)
        assert self._database.rowcount == 1

    def get_db_version(self):
        """Gets the database version.

        @return The database version, or 0 if the migration table is missing
                or empty.
        """
        if not self.check_migrate_table_exists():
            return 0
        rows = self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        if len(rows) == 0:
            return 0
        assert len(rows) == 1 and len(rows[0]) == 1
        return rows[0][0]

    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Gets the list of migrations to perform.

        Every file in the migrations directory named like 'NNN_*.py' is
        imported, in sorted filename order.

        @param minimum_version: The minimum database version (inclusive),
                or None for no lower bound.
        @param maximum_version: The maximum database version (inclusive),
                or None for no upper bound.
        @return A list of Migration objects.
        """
        migrate_files = [filename for filename
                         in os.listdir(self.migrations_dir)
                         if re.match(r'^\d\d\d_.*\.py$', filename)]
        migrate_files.sort()
        migrations = [Migration.from_file(filename)
                      for filename in migrate_files]
        if minimum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version >= minimum_version]
        if maximum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version <= maximum_version]
        return migrations

    def do_migration(self, migration, migrate_up=True):
        """Performs a migration and records the resulting version.

        @param migration: The Migration to perform.
        @param migrate_up: Whether to migrate up (if not, then migrates down).
        """
        if migrate_up:
            direction = 'up'
        else:
            direction = 'down'
        print('Applying migration %s %s' % (migration.name, direction))
        if migrate_up:
            # Migrations must be applied strictly in sequence.
            assert self.get_db_version() == migration.version - 1
            migration.migrate_up(self)
            new_version = migration.version
        else:
            assert self.get_db_version() == migration.version
            migration.migrate_down(self)
            new_version = migration.version - 1
        self.set_db_version(new_version)

    def migrate_to_version(self, version):
        """Performs a migration to a specified version.

        @param version: The version to which to migrate the database.
        """
        current_version = self.get_db_version()
        # An AUTOTEST_WEB database at version 0 is first loaded from the base
        # schema instead of replaying the early migrations.
        if current_version == 0 and self._config_section() == 'AUTOTEST_WEB':
            self._migrate_from_base()
            current_version = self.get_db_version()
        if current_version < version:
            lower, upper = current_version, version
            migrate_up = True
        else:
            lower, upper = version, current_version
            migrate_up = False
        # The migration numbered `lower` is already applied (going up) or must
        # stay applied (going down), hence lower + 1.
        migrations = self.get_migrations(lower + 1, upper)
        if not migrate_up:
            migrations.reverse()
        for migration in migrations:
            self.do_migration(migration, migrate_up)
        assert self.get_db_version() == version
        print('At version %s' % version)

    def _migrate_from_base(self):
        """Loads the base schema and marks the database as version 51.

        Reads schema_051.sql from this script's directory, substitutes the
        connection's username into it and executes it.
        """
        self.confirm_initialization()
        migration_script = utils.read_file(
            os.path.join(os.path.dirname(__file__), 'schema_051.sql'))
        migration_script = migration_script % (
            dict(username=self._database.get_database_info()['username']))
        self.execute_script(migration_script)
        self.create_migrate_table()
        self.set_db_version(51)

    def confirm_initialization(self):
        """Confirms with the user that we should initialize the database.

        Does nothing when the manager was created with force=True.

        @raises Exception, if the user chooses to abort the migration.
        """
        if not self.force:
            response = raw_input(
                'Your %s database does not appear to be initialized.  Do you '
                'want to recreate it (this will result in loss of any existing '
                'data) (yes/No)? ' % self.get_db_name())
            if response != 'yes':
                raise Exception('User has chosen to abort migration')

    def get_latest_version(self):
        """Gets the latest database version.

        @return The version of the last migration in the migrations directory.
        @raises IndexError, if the directory contains no migrations.
        """
        migrations = self.get_migrations()
        if not migrations:
            # Same exception type as indexing an empty list, but with a
            # message that says what is actually wrong.
            raise IndexError('No migrations found in %s' % self.migrations_dir)
        return migrations[-1].version

    def migrate_to_latest(self):
        """Migrates the database to the latest version."""
        latest_version = self.get_latest_version()
        self.migrate_to_version(latest_version)

    def initialize_test_db(self):
        """Initializes a test database.

        Creates a database named 'test_' + the current database name and
        reconnects to it.
        """
        db_name = self.get_db_name()
        test_db_name = 'test_' + db_name
        # first, connect to no DB so we can create a test DB
        self._database.connect(db_name='')
        print('Creating test DB %s' % test_db_name)
        self.execute('CREATE DATABASE ' + test_db_name)
        self._database.disconnect()
        # now connect to the test DB
        self._database.connect(db_name=test_db_name)

    def remove_test_db(self):
        """Removes a test database.

        Drops whichever database get_db_name() currently reports, then
        reconnects with the connection's default settings.
        """
        print('Removing test DB')
        self.execute('DROP DATABASE ' + self.get_db_name())
        # reset connection back to real DB
        self._database.disconnect()
        self._database.connect()

    def get_mysql_args(self):
        """Returns the mysql arguments as a string.

        NOTE(review): the values (including the password) are interpolated
        unquoted and later passed through a shell by os.system(); values
        containing spaces or shell metacharacters will not survive.
        """
        return ('-u %(username)s -p%(password)s -h %(host)s %(db_name)s' %
                self._database.get_database_info())

    def migrate_to_version_or_latest(self, version):
        """Migrates to either a specified version, or the latest version.

        @param version: The version to which to migrate the database,
                        or None in order to migrate to the latest version.
        """
        if version is None:
            self.migrate_to_latest()
        else:
            self.migrate_to_version(version)

    def do_sync_db(self, version=None):
        """Migrates the database.

        @param version: The version to which to migrate the database.
        """
        print('Migration starting for database %s' % self.get_db_name())
        self.migrate_to_version_or_latest(version)
        print('Migration complete')

    def test_sync_db(self, version=None):
        """Create a fresh database and run all migrations on it.

        @param version: The version to which to migrate the database.
        """
        self.initialize_test_db()
        try:
            print('Starting migration test on DB %s' % self.get_db_name())
            self.migrate_to_version_or_latest(version)
            # show schema to the user
            os.system('mysqldump %s --no-data=true '
                      '--add-drop-table=false' %
                      self.get_mysql_args())
        finally:
            self.remove_test_db()
        print('Test finished successfully')

    def simulate_sync_db(self, version=None):
        """Creates a fresh DB, copies existing DB to it, then synchronizes it.

        @param version: The version to which to migrate the database.
        """
        db_version = self.get_db_version()
        # don't do anything if we're already at the latest version
        if db_version == self.get_latest_version():
            print('Skipping simulation, already at latest version')
            return
        # get existing data
        self.initialize_and_fill_test_db()
        try:
            print('Starting migration test on DB %s' % self.get_db_name())
            self.migrate_to_version_or_latest(version)
        finally:
            self.remove_test_db()
        print('Test finished successfully')

    def initialize_and_fill_test_db(self):
        """Initializes and fills up a test database.

        Dumps the current database to a temporary file, creates the test
        database and loads the dump into it.  The temporary file is removed
        even if one of those steps raises.
        """
        print('Dumping existing data')
        dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
        # Only the path is needed: the shell redirections below open the file
        # themselves, so release our descriptor right away.
        os.close(dump_fd)
        try:
            # NOTE(review): the exit status of mysqldump/mysql is not checked,
            # so a failed dump or load goes unnoticed here.
            os.system('mysqldump %s >%s' %
                      (self.get_mysql_args(), dump_file))
            # fill in test DB
            self.initialize_test_db()
            print('Filling in test DB')
            os.system('mysql %s <%s' % (self.get_mysql_args(), dump_file))
        finally:
            os.remove(dump_file)
# Usage text printed by main() when no command, or an unrecognized command,
# is given.  The program name is filled in from sys.argv[0].
USAGE = """\
%s [options] sync|test|simulate|safesync [version]
Options:
-d --database Which database to act on
-f --force Don't ask for confirmation
--debug Print all DB queries"""\
% sys.argv[0]
def main():
    """Command-line entry point: parses options and runs one command.

    Recognized commands are sync, test, simulate and safesync, each taking an
    optional integer version.  With no command, or an unknown one, the usage
    text is printed.
    """
    parser = OptionParser()
    parser.add_option('-d', '--database',
                      dest='database',
                      default='AUTOTEST_WEB',
                      help='which database to act on')
    parser.add_option('-f', '--force',
                      action='store_true',
                      help="don't ask for confirmation")
    parser.add_option('--debug',
                      action='store_true',
                      help='print all DB queries')
    options, args = parser.parse_args()

    # The manager (and its DB connection) is created before the command is
    # inspected.
    manager = get_migration_manager(db_name=options.database,
                                    debug=options.debug,
                                    force=options.force)

    if len(args) == 0:
        print(USAGE)
        return

    # The version is parsed before the command name is validated.
    version = None
    if len(args) > 1:
        version = int(args[1])

    command = args[0]
    if command == 'sync':
        manager.do_sync_db(version)
        return
    if command == 'test':
        manager.simulate = True
        manager.test_sync_db(version)
        return
    if command == 'simulate':
        manager.simulate = True
        manager.simulate_sync_db(version)
        return
    if command == 'safesync':
        # Dry-run against a copy first, then migrate the real database.
        print('Simluating migration')
        manager.simulate = True
        manager.simulate_sync_db(version)
        print('Performing real migration')
        manager.simulate = False
        manager.do_sync_db(version)
        return

    print(USAGE)
def get_migration_manager(db_name, debug, force):
    """Builds a MigrationManager on top of a newly opened connection.

    The connection is created with reconnects disabled and is connected
    before the manager is constructed.

    @param db_name: Identifier handed to
            database_connection.DatabaseConnection (presumably a global
            config section such as 'AUTOTEST_WEB' -- confirm against that
            class).
    @param debug: Value assigned to the connection's debug attribute.
    @param force: Whether to force migration without asking for confirmation.
    @return A created MigrationManager object.
    """
    connection = database_connection.DatabaseConnection(db_name)
    connection.debug = debug
    connection.reconnect_enabled = False
    connection.connect()
    manager = MigrationManager(connection, force=force)
    return manager
# Run the command-line entry point only when executed as a script, not when
# imported as a module.
if __name__ == '__main__':
    main()