| #!/usr/bin/python -u |
| |
| import os, sys, re, tempfile |
| from optparse import OptionParser |
| import common |
| from autotest_lib.client.common_lib import utils |
| from autotest_lib.database import database_connection |
| |
# Name of the table that stores the current schema version.
MIGRATE_TABLE = 'migrate_info'

# Parent of the directory that contains this script.
_AUTODIR = os.path.join(os.path.dirname(__file__), '..')

# Maps a database config section name to the directory holding its
# migration modules.
_MIGRATIONS_DIRS = dict(
    (section, os.path.join(_AUTODIR, subdir, 'migrations'))
    for section, subdir in (('AUTOTEST_WEB', 'frontend'), ('TKO', 'tko')))

# Fallback for config sections not listed above; relative, so it is
# resolved against the current working directory.
_DEFAULT_MIGRATIONS_DIR = 'migrations'
| |
class Migration(object):
    """A single database migration backed by an imported module.

    The module must provide, for each direction, either a callable
    (migrate_up / migrate_down) or a SQL script string (UP_SQL / DOWN_SQL).
    """
    _UP_ATTRIBUTES = ('migrate_up', 'UP_SQL')
    _DOWN_ATTRIBUTES = ('migrate_down', 'DOWN_SQL')

    def __init__(self, name, version, module):
        self.name = name
        self.version = version
        self.module = module
        # Fail early if the module cannot migrate in both directions.
        for attribute_pair in (self._UP_ATTRIBUTES, self._DOWN_ATTRIBUTES):
            self._check_attributes(attribute_pair)


    @classmethod
    def from_file(cls, filename):
        """Builds a Migration by importing the module behind a file name.

        @param filename: Name of a migration file; the first three
                characters are parsed as the version and the last three
                characters (the extension) are dropped to get the module
                name.

        @return The Migration wrapping the imported module.

        """
        version = int(filename[:3])
        module_name = filename[:-3]
        imported = __import__(module_name, globals(), locals(), [])
        return cls(module_name, version, imported)


    def _check_attributes(self, attributes):
        # The module needs at least one of the two ways to migrate.
        method_name, sql_name = attributes
        has_method = hasattr(self.module, method_name)
        assert has_method or hasattr(self.module, sql_name)


    def _execute_migration(self, attributes, manager):
        # Prefer the callable; fall back to running the SQL script.
        method_name, sql_name = attributes
        handler = getattr(self.module, method_name, None)
        if handler:
            assert callable(handler)
            handler(manager)
            return
        script = getattr(self.module, sql_name)
        assert isinstance(script, basestring)
        manager.execute_script(script)


    def migrate_up(self, manager):
        """Runs this migration towards the newer version.

        @param manager: A MigrationManager object.

        """
        self._execute_migration(self._UP_ATTRIBUTES, manager)


    def migrate_down(self, manager):
        """Runs this migration towards the older version.

        @param manager: A MigrationManager object.

        """
        self._execute_migration(self._DOWN_ATTRIBUTES, manager)
| |
| |
class MigrationManager(object):
    """Manages database migrations.

    Applies the numbered migration modules found in a migrations directory
    to the database behind the given connection object, and records the
    resulting schema version in the MIGRATE_TABLE table.
    """
    connection = None
    cursor = None
    migrations_dir = None

    def __init__(self, database_connection, migrations_dir=None, force=False):
        """Initializes the manager.

        @param database_connection: Connection object used for all queries.
        @param migrations_dir: Directory holding the migration modules, or
                None to choose one from the connection's config section.
        @param force: If True, confirm_initialization() does not prompt.

        """
        self._database = database_connection
        self.force = force
        # A boolean, this will only be set to True if this migration should be
        # simulated rather than actually taken.  For use with migrations that
        # may make destructive queries
        self.simulate = False
        self._set_migrations_dir(migrations_dir)


    def _set_migrations_dir(self, migrations_dir=None):
        """Resolves the migrations directory and makes it importable.

        @param migrations_dir: Explicit directory, or None to look the
                directory up by config section (falling back to
                _DEFAULT_MIGRATIONS_DIR).

        """
        if migrations_dir is None:
            migrations_dir = os.path.abspath(
                    _MIGRATIONS_DIRS.get(self._config_section(),
                                         _DEFAULT_MIGRATIONS_DIR))
        # Validate before touching sys.path so a bad directory never ends up
        # on the import path.
        assert os.path.exists(migrations_dir), migrations_dir + " doesn't exist"
        self.migrations_dir = migrations_dir
        # Migration modules are imported by bare name, so the directory must
        # be importable; avoid piling up duplicates across managers.
        if migrations_dir not in sys.path:
            sys.path.append(migrations_dir)


    def _config_section(self):
        return self._database.global_config_section


    def get_db_name(self):
        """Gets the database name."""
        return self._database.get_database_info()['db_name']


    def execute(self, query, *parameters):
        """Executes a database query.

        @param query: The query to execute.
        @param parameters: Associated parameters for the query.

        @return The result of the query.

        """
        return self._database.execute(query, parameters)


    def execute_script(self, script):
        """Executes a set of database queries.

        @param script: A string of semicolon-separated queries.

        """
        # NOTE: this is a plain split on ';', so a semicolon inside a string
        # literal in the script would also split the statement.
        stripped = (statement.strip() for statement in script.split(';'))
        for statement in [statement for statement in stripped if statement]:
            self.execute(statement)


    def check_migrate_table_exists(self):
        """Checks whether the migration table exists.

        @return True if selecting from the table succeeds, False if the
                connection raises its DatabaseError.

        """
        try:
            self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        except self._database.DatabaseError:
            # we can't check for more specifics due to differences between DB
            # backends (we can't even check for a subclass of DatabaseError)
            return False
        return True


    def create_migrate_table(self):
        """Creates (or empties) the migration table and sets version 0."""
        if not self.check_migrate_table_exists():
            self.execute("CREATE TABLE %s (`version` integer)" %
                         MIGRATE_TABLE)
        else:
            self.execute("DELETE FROM %s" % MIGRATE_TABLE)
        self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
        assert self._database.rowcount == 1


    def set_db_version(self, version):
        """Sets the database version.

        @param version: The version to which to set the database.

        """
        assert isinstance(version, int)
        self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
                     version)
        assert self._database.rowcount == 1


    def get_db_version(self):
        """Gets the database version.

        @return The database version, or 0 if the migration table is
                missing or empty.

        """
        if not self.check_migrate_table_exists():
            return 0
        rows = self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        if len(rows) == 0:
            return 0
        assert len(rows) == 1 and len(rows[0]) == 1
        return rows[0][0]


    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Gets the list of migrations to perform.

        Every file in the migrations directory named NNN_<something>.py is
        imported, in filename order.

        @param minimum_version: The minimum database version (inclusive).
        @param maximum_version: The maximum database version (inclusive).

        @return A list of Migration objects.

        """
        migrate_files = [filename for filename
                         in os.listdir(self.migrations_dir)
                         if re.match(r'^\d\d\d_.*\.py$', filename)]
        migrate_files.sort()
        migrations = [Migration.from_file(filename)
                      for filename in migrate_files]
        if minimum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version >= minimum_version]
        if maximum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version <= maximum_version]
        return migrations


    def do_migration(self, migration, migrate_up=True):
        """Performs a migration.

        @param migration: The Migration to perform.
        @param migrate_up: Whether to migrate up (if not, then migrates down).

        """
        direction = 'up' if migrate_up else 'down'
        print('Applying migration %s %s' % (migration.name, direction))
        if migrate_up:
            assert self.get_db_version() == migration.version - 1
            migration.migrate_up(self)
            new_version = migration.version
        else:
            assert self.get_db_version() == migration.version
            migration.migrate_down(self)
            new_version = migration.version - 1
        self.set_db_version(new_version)


    def migrate_to_version(self, version):
        """Performs a migration to a specified version.

        @param version: The version to which to migrate the database.

        """
        current_version = self.get_db_version()
        if current_version == 0 and self._config_section() == 'AUTOTEST_WEB':
            self._migrate_from_base()
            current_version = self.get_db_version()

        if current_version < version:
            lower, upper = current_version, version
            migrate_up = True
        else:
            lower, upper = version, current_version
            migrate_up = False

        # Migration N moves the schema between versions N-1 and N, so the
        # migrations to run are lower+1 .. upper.
        migrations = self.get_migrations(lower + 1, upper)
        if not migrate_up:
            migrations.reverse()
        for migration in migrations:
            self.do_migration(migration, migrate_up)

        assert self.get_db_version() == version
        print('At version %s' % (version,))


    def _migrate_from_base(self):
        """Loads the schema_051.sql baseline and records version 51."""
        self.confirm_initialization()

        migration_script = utils.read_file(
                os.path.join(os.path.dirname(__file__), 'schema_051.sql'))
        migration_script = migration_script % (
                dict(username=self._database.get_database_info()['username']))
        self.execute_script(migration_script)

        self.create_migrate_table()
        self.set_db_version(51)


    def confirm_initialization(self):
        """Confirms with the user that we should initialize the database.

        Does nothing when the manager was created with force=True.

        @raises Exception, if the user chooses to abort the migration.

        """
        if self.force:
            return
        response = raw_input(
            'Your %s database does not appear to be initialized.  Do you '
            'want to recreate it (this will result in loss of any existing '
            'data) (yes/No)? ' % self.get_db_name())
        if response != 'yes':
            raise Exception('User has chosen to abort migration')


    def get_latest_version(self):
        """Gets the latest database version.

        @return The version of the last migration in the migrations
                directory.

        @raises IndexError, if the migrations directory has no migrations.

        """
        migrations = self.get_migrations()
        if not migrations:
            # Same exception type as indexing the empty list would give, but
            # with a message that says what is actually wrong.
            raise IndexError('No migrations found in %s' %
                             (self.migrations_dir,))
        return migrations[-1].version


    def migrate_to_latest(self):
        """Migrates the database to the latest version."""
        latest_version = self.get_latest_version()
        self.migrate_to_version(latest_version)


    def initialize_test_db(self):
        """Creates a test database and connects to it.

        The test database is named 'test_' followed by the current
        database name.

        """
        db_name = self.get_db_name()
        test_db_name = 'test_' + db_name
        # first, connect to no DB so we can create a test DB
        self._database.connect(db_name='')
        print('Creating test DB %s' % (test_db_name,))
        self.execute('CREATE DATABASE ' + test_db_name)
        self._database.disconnect()
        # now connect to the test DB
        self._database.connect(db_name=test_db_name)


    def remove_test_db(self):
        """Drops the currently selected database and reconnects.

        Intended to be called while connected to a test database.

        """
        print('Removing test DB')
        self.execute('DROP DATABASE ' + self.get_db_name())
        # reset connection back to real DB
        self._database.disconnect()
        self._database.connect()


    def get_mysql_args(self):
        """Returns the mysql arguments as a string."""
        # NOTE(review): the values are interpolated unquoted and the result
        # is handed to a shell via os.system() elsewhere in this class, so
        # credentials containing shell metacharacters would break (or alter)
        # those commands.
        return ('-u %(username)s -p%(password)s -h %(host)s %(db_name)s' %
                self._database.get_database_info())


    def migrate_to_version_or_latest(self, version):
        """Migrates to either a specified version, or the latest version.

        @param version: The version to which to migrate the database,
                or None in order to migrate to the latest version.

        """
        if version is None:
            self.migrate_to_latest()
        else:
            self.migrate_to_version(version)


    def do_sync_db(self, version=None):
        """Migrates the database.

        @param version: The version to which to migrate the database.

        """
        print('Migration starting for database %s' % (self.get_db_name(),))
        self.migrate_to_version_or_latest(version)
        print('Migration complete')


    def test_sync_db(self, version=None):
        """Create a fresh database and run all migrations on it.

        @param version: The version to which to migrate the database.

        """
        self.initialize_test_db()
        try:
            print('Starting migration test on DB %s' % (self.get_db_name(),))
            self.migrate_to_version_or_latest(version)
            # show schema to the user
            os.system('mysqldump %s --no-data=true '
                      '--add-drop-table=false' %
                      self.get_mysql_args())
        finally:
            self.remove_test_db()
        print('Test finished successfully')


    def simulate_sync_db(self, version=None):
        """Creates a fresh DB, copies existing DB to it, then synchronizes it.

        @param version: The version to which to migrate the database, or
                None for the latest version.

        """
        db_version = self.get_db_version()
        latest_version = self.get_latest_version()
        target_version = latest_version if version is None else version
        # Only skip when the DB is at the latest version AND that is also the
        # target; an explicit older target is a downgrade and must still be
        # simulated.
        if db_version == latest_version and target_version == latest_version:
            print('Skipping simulation, already at latest version')
            return
        # get existing data
        self.initialize_and_fill_test_db()
        try:
            print('Starting migration test on DB %s' % (self.get_db_name(),))
            self.migrate_to_version_or_latest(version)
        finally:
            self.remove_test_db()
        print('Test finished successfully')


    def initialize_and_fill_test_db(self):
        """Initializes and fills up a test database.

        Dumps the current database to a temporary file with mysqldump,
        creates the test database, then loads the dump into it with mysql.
        The temporary file is removed even if a step fails.

        """
        print('Dumping existing data')
        dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
        # The shell redirections below open the file by name, so our own
        # descriptor is not needed; close it now so it cannot leak.
        os.close(dump_fd)
        try:
            os.system('mysqldump %s >%s' %
                      (self.get_mysql_args(), dump_file))
            # fill in test DB
            self.initialize_test_db()
            print('Filling in test DB')
            os.system('mysql %s <%s' % (self.get_mysql_args(), dump_file))
        finally:
            os.remove(dump_file)
| |
| |
# Usage text for the command line; main() prints it when the command is
# missing or not recognized.  The leading %s is filled with sys.argv[0].
USAGE = """\
%s [options] sync|test|simulate|safesync [version]
Options:
-d --database Which database to act on
-f --force Don't ask for confirmation
--debug Print all DB queries"""\
% sys.argv[0]
| |
| |
def main():
    """Main function for the migration script.

    Parses the command line and runs one of the commands:
      sync:     migrate the real database (do_sync_db).
      test:     run the migrations on a fresh test database (test_sync_db).
      simulate: run the migrations on a copy of the database
                (simulate_sync_db).
      safesync: simulate first, then perform the real migration.
    An optional second positional argument selects the target version.
    Prints USAGE when the command is missing or not recognized.

    """
    parser = OptionParser()
    parser.add_option("-d", "--database",
                      help="which database to act on",
                      dest="database",
                      default="AUTOTEST_WEB")
    parser.add_option("-f", "--force", help="don't ask for confirmation",
                      action="store_true")
    parser.add_option('--debug', help='print all DB queries',
                      action='store_true')
    (options, args) = parser.parse_args()

    # Validate the command line before connecting, so that usage help works
    # even when the database is unreachable.
    if not args or args[0] not in ('sync', 'test', 'simulate', 'safesync'):
        print(USAGE)
        return
    command = args[0]

    version = None
    if len(args) > 1:
        try:
            version = int(args[1])
        except ValueError:
            # Exits with a usage error instead of a raw traceback.
            parser.error('version must be an integer, got %r' % (args[1],))

    manager = get_migration_manager(db_name=options.database,
                                    debug=options.debug, force=options.force)

    if command == 'sync':
        manager.do_sync_db(version)
    elif command == 'test':
        manager.simulate = True
        manager.test_sync_db(version)
    elif command == 'simulate':
        manager.simulate = True
        manager.simulate_sync_db(version)
    else:
        # safesync: rehearse on a copy first, then migrate for real.
        print('Simluating migration')
        manager.simulate = True
        manager.simulate_sync_db(version)
        print('Performing real migration')
        manager.simulate = False
        manager.do_sync_db(version)
| |
| |
def get_migration_manager(db_name, debug, force):
    """Builds a MigrationManager on top of a freshly opened connection.

    @param db_name: Passed to database_connection.DatabaseConnection to
            select the database.
    @param debug: Stored on the connection's debug attribute.
    @param force: Forwarded to MigrationManager as its force argument.

    @return The new MigrationManager.

    """
    connection = database_connection.DatabaseConnection(db_name)
    connection.debug = debug
    connection.reconnect_enabled = False
    connection.connect()
    manager = MigrationManager(connection, force=force)
    return manager
| |
| |
# Run the command-line entry point only when executed as a script, not
# when this module is imported.
if __name__ == '__main__':
    main()