|  | #!/usr/bin/python -u | 
|  |  | 
|  | import os, sys, re, tempfile | 
|  | from optparse import OptionParser | 
|  | import common | 
|  | from autotest_lib.client.common_lib import utils | 
|  | from autotest_lib.database import database_connection | 
|  |  | 
# Table holding the schema version; the manager keeps exactly one row in it.
MIGRATE_TABLE = 'migrate_info'

# Root of the autotest tree: the parent of the directory containing this file.
_AUTODIR = os.path.join(os.path.dirname(__file__), '..')
# Migration directories for the global-config database sections that have one.
_MIGRATIONS_DIRS = dict(
    AUTOTEST_WEB=os.path.join(_AUTODIR, 'frontend', 'migrations'),
    TKO=os.path.join(_AUTODIR, 'tko', 'migrations'),
)
# Fallback for any other section: a relative path, i.e. resolved against the
# current working directory.
_DEFAULT_MIGRATIONS_DIR = 'migrations'
|  |  | 
class Migration(object):
    """A single numbered schema migration backed by a Python module.

    The module has to define an "up" step and a "down" step. Each step is
    either a callable taking the MigrationManager (``migrate_up`` /
    ``migrate_down``) or a string of SQL statements (``UP_SQL`` /
    ``DOWN_SQL``). When both forms exist, the callable is used.
    """
    _UP_ATTRIBUTES = ('migrate_up', 'UP_SQL')
    _DOWN_ATTRIBUTES = ('migrate_down', 'DOWN_SQL')

    def __init__(self, name, version, module):
        self.name = name
        self.version = version
        self.module = module
        # Fail early if either direction is missing from the module.
        for step_attributes in (self._UP_ATTRIBUTES, self._DOWN_ATTRIBUTES):
            self._check_attributes(step_attributes)


    @classmethod
    def from_file(cls, filename):
        """Builds a Migration by importing a migration file.

        @param filename: Migration file name of the form NNN_description.py;
                the directory it lives in must already be on sys.path.

        @return A Migration object wrapping the imported module.

        """
        number = int(filename[:3])        # leading three digits
        module_name = filename[:-3]       # drop the ".py" suffix
        imported = __import__(module_name, globals(), locals(), [])
        return cls(module_name, number, imported)


    def _check_attributes(self, attributes):
        """Asserts the module defines the callable or the SQL form of a step.

        @param attributes: A (method_name, sql_name) pair.

        """
        method_name, sql_name = attributes
        assert any(hasattr(self.module, attr)
                   for attr in (method_name, sql_name))


    def _execute_migration(self, attributes, manager):
        """Runs one step, preferring the callable over the SQL string.

        @param attributes: A (method_name, sql_name) pair.
        @param manager: The MigrationManager passed to a callable step, or
                used to execute an SQL step.

        """
        method_name, sql_name = attributes
        step = getattr(self.module, method_name, None)
        if not step:
            script = getattr(self.module, sql_name)
            assert isinstance(script, basestring)
            manager.execute_script(script)
            return
        assert callable(step)
        step(manager)


    def migrate_up(self, manager):
        """Applies this migration, moving to a newer schema version.

        @param manager: A MigrationManager object.

        """
        self._execute_migration(self._UP_ATTRIBUTES, manager)


    def migrate_down(self, manager):
        """Reverts this migration, moving to an older schema version.

        @param manager: A MigrationManager object.

        """
        self._execute_migration(self._DOWN_ATTRIBUTES, manager)
|  |  | 
|  |  | 
class MigrationManager(object):
    """Manages database migrations.

    The current schema version is stored in the MIGRATE_TABLE table.
    Numbered migration modules (NNN_*.py) in migrations_dir are applied one
    at a time to move the database up or down between versions. The test and
    simulate helpers rehearse a migration on a scratch 'test_<db_name>'
    database.
    """
    # NOTE(review): connection and cursor are not used by this class; kept in
    # case external code reads them.
    connection = None
    cursor = None
    migrations_dir = None

    def __init__(self, database_connection, migrations_dir=None, force=False):
        """Creates a MigrationManager.

        @param database_connection: Connected database object used for every
                query (see get_migration_manager).
        @param migrations_dir: Directory holding the migration modules, or
                None to use the default for the connection's config section.
        @param force: If True, skip the interactive confirmation prompt.

        """
        self._database = database_connection
        self.force = force
        # A boolean, this will only be set to True if this migration should be
        # simulated rather than actually taken. For use with migrations that
        # may make destructive queries
        self.simulate = False
        self._set_migrations_dir(migrations_dir)


    def _set_migrations_dir(self, migrations_dir=None):
        """Resolves the migrations directory and makes it importable.

        @param migrations_dir: Explicit directory, or None to use the entry in
                _MIGRATIONS_DIRS for this database's config section (falling
                back to _DEFAULT_MIGRATIONS_DIR).

        """
        config_section = self._config_section()
        if migrations_dir is None:
            migrations_dir = os.path.abspath(
                _MIGRATIONS_DIRS.get(config_section, _DEFAULT_MIGRATIONS_DIR))
        # Validate before touching sys.path so a bad path is never added.
        assert os.path.exists(migrations_dir), migrations_dir + " doesn't exist"
        self.migrations_dir = migrations_dir
        # Migration.from_file imports modules by bare name, so the directory
        # must be on sys.path; skip it if already present to avoid duplicates.
        if migrations_dir not in sys.path:
            sys.path.append(migrations_dir)


    def _config_section(self):
        """Returns the global config section of the wrapped connection."""
        return self._database.global_config_section


    def get_db_name(self):
        """Gets the database name."""
        return self._database.get_database_info()['db_name']


    def execute(self, query, *parameters):
        """Executes a database query.

        @param query: The query to execute.
        @param parameters: Associated parameters for the query.

        @return The result of the query.

        """
        return self._database.execute(query, parameters)


    def execute_script(self, script):
        """Executes a set of database queries.

        Statements are split naively on ';', so a semicolon inside a string
        literal or comment will split a statement incorrectly.

        @param script: A string of semicolon-separated queries.

        """
        sql_statements = [statement.strip()
                          for statement in script.split(';')
                          if statement.strip()]
        for statement in sql_statements:
            self.execute(statement)


    def check_migrate_table_exists(self):
        """Returns True if selecting from the migration table succeeds."""
        try:
            self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
            return True
        except self._database.DatabaseError:
            # we can't check for more specifics due to differences between DB
            # backends (we can't even check for a subclass of DatabaseError)
            return False


    def create_migrate_table(self):
        """Creates (or empties) the migration table and seeds it with version 0."""
        if not self.check_migrate_table_exists():
            self.execute("CREATE TABLE %s (`version` integer)" %
                         MIGRATE_TABLE)
        else:
            self.execute("DELETE FROM %s" % MIGRATE_TABLE)
        self.execute("INSERT INTO %s VALUES (0)" % MIGRATE_TABLE)
        assert self._database.rowcount == 1


    def set_db_version(self, version):
        """Sets the database version.

        @param version: The version to which to set the database.

        """
        assert isinstance(version, int)
        self.execute("UPDATE %s SET version=%%s" % MIGRATE_TABLE,
                     version)
        assert self._database.rowcount == 1


    def get_db_version(self):
        """Gets the database version.

        @return The database version, or 0 when the migration table is
                missing or empty.

        """
        if not self.check_migrate_table_exists():
            return 0
        rows = self.execute("SELECT * FROM %s" % MIGRATE_TABLE)
        if len(rows) == 0:
            return 0
        assert len(rows) == 1 and len(rows[0]) == 1
        return rows[0][0]


    def get_migrations(self, minimum_version=None, maximum_version=None):
        """Gets the list of migrations to perform.

        @param minimum_version: The minimum database version (inclusive).
        @param maximum_version: The maximum database version (inclusive).

        @return A list of Migration objects sorted by file name.

        """
        migrate_files = [filename for filename
                         in os.listdir(self.migrations_dir)
                         if re.match(r'^\d\d\d_.*\.py$', filename)]
        migrate_files.sort()
        # NOTE(review): modules are imported by bare name, so a same-named
        # module already in sys.modules (e.g. from another migrations
        # directory loaded in this process) would be reused — confirm.
        migrations = [Migration.from_file(filename)
                      for filename in migrate_files]
        if minimum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version >= minimum_version]
        if maximum_version is not None:
            migrations = [migration for migration in migrations
                          if migration.version <= maximum_version]
        return migrations


    def do_migration(self, migration, migrate_up=True):
        """Performs a migration.

        @param migration: The Migration to perform.
        @param migrate_up: Whether to migrate up (if not, then migrates down).

        """
        direction = 'up' if migrate_up else 'down'
        print('Applying migration %s %s' % (migration.name, direction))
        if migrate_up:
            assert self.get_db_version() == migration.version - 1
            migration.migrate_up(self)
            new_version = migration.version
        else:
            assert self.get_db_version() == migration.version
            migration.migrate_down(self)
            new_version = migration.version - 1
        self.set_db_version(new_version)


    def migrate_to_version(self, version):
        """Performs a migration to a specified version.

        An uninitialized AUTOTEST_WEB database is first bootstrapped from the
        schema_051.sql snapshot.

        @param version: The version to which to migrate the database.

        """
        current_version = self.get_db_version()
        if current_version == 0 and self._config_section() == 'AUTOTEST_WEB':
            self._migrate_from_base()
            current_version = self.get_db_version()

        if current_version < version:
            lower, upper = current_version, version
            migrate_up = True
        else:
            lower, upper = version, current_version
            migrate_up = False

        migrations = self.get_migrations(lower + 1, upper)
        if not migrate_up:
            migrations.reverse()
        for migration in migrations:
            self.do_migration(migration, migrate_up)

        assert self.get_db_version() == version
        print('At version %s' % version)


    def _migrate_from_base(self):
        """Loads schema_051.sql into the database and marks it as version 51."""
        self.confirm_initialization()

        migration_script = utils.read_file(
            os.path.join(os.path.dirname(__file__), 'schema_051.sql'))
        migration_script = migration_script % (
            dict(username=self._database.get_database_info()['username']))
        self.execute_script(migration_script)

        self.create_migrate_table()
        self.set_db_version(51)


    def confirm_initialization(self):
        """Confirms with the user that we should initialize the database.

        @raises Exception, if the user chooses to abort the migration.

        """
        if not self.force:
            response = raw_input(
                'Your %s database does not appear to be initialized.  Do you '
                'want to recreate it (this will result in loss of any existing '
                'data) (yes/No)? ' % self.get_db_name())
            if response != 'yes':
                raise Exception('User has chosen to abort migration')


    def get_latest_version(self):
        """Gets the latest database version.

        @raises Exception if the migrations directory has no migrations.

        """
        migrations = self.get_migrations()
        if not migrations:
            raise Exception('No migrations found in %s' % self.migrations_dir)
        return migrations[-1].version


    def migrate_to_latest(self):
        """Migrates the database to the latest version."""
        latest_version = self.get_latest_version()
        self.migrate_to_version(latest_version)


    def initialize_test_db(self):
        """Creates test_<db_name> and switches the connection to it."""
        db_name = self.get_db_name()
        test_db_name = 'test_' + db_name
        # first, connect to no DB so we can create a test DB
        self._database.connect(db_name='')
        print('Creating test DB %s' % test_db_name)
        self.execute('CREATE DATABASE ' + test_db_name)
        self._database.disconnect()
        # now connect to the test DB
        self._database.connect(db_name=test_db_name)


    def remove_test_db(self):
        """Drops the connected test database and reconnects to the real one.

        @raises Exception if the connected database is not a test database.

        """
        db_name = self.get_db_name()
        # DROP DATABASE acts on whatever database is currently connected;
        # only ever drop one named like those created by initialize_test_db.
        if not db_name.startswith('test_'):
            raise Exception('Refusing to drop non-test database %s' % db_name)
        print('Removing test DB')
        self.execute('DROP DATABASE ' + db_name)
        # reset connection back to real DB
        self._database.disconnect()
        self._database.connect()


    def get_mysql_args(self):
        """Returns the mysql arguments as a string."""
        # NOTE(review): callers interpolate this unquoted into shell commands,
        # and it puts the password on the command line. Values containing
        # shell metacharacters would break those commands.
        return ('-u %(username)s -p%(password)s -h %(host)s %(db_name)s' %
                self._database.get_database_info())


    def _run_shell_command(self, command, description):
        """Runs a shell command and raises if it exits unsuccessfully.

        @param command: The shell command line to run.
        @param description: Short label used in the error message; the
                command itself is not echoed because it may contain the
                database password.

        @raises Exception if the command returns a nonzero status.

        """
        status = os.system(command)
        if status != 0:
            raise Exception('%s failed (status %d)' % (description, status))


    def migrate_to_version_or_latest(self, version):
        """Migrates to either a specified version, or the latest version.

        @param version: The version to which to migrate the database,
                or None in order to migrate to the latest version.

        """
        if version is None:
            self.migrate_to_latest()
        else:
            self.migrate_to_version(version)


    def do_sync_db(self, version=None):
        """Migrates the database.

        @param version: The version to which to migrate the database.

        """
        print('Migration starting for database %s' % self.get_db_name())
        self.migrate_to_version_or_latest(version)
        print('Migration complete')


    def test_sync_db(self, version=None):
        """Create a fresh database and run all migrations on it.

        @param version: The version to which to migrate the database.

        """
        self.initialize_test_db()
        try:
            print('Starting migration test on DB %s' % self.get_db_name())
            self.migrate_to_version_or_latest(version)
            # show schema to the user
            os.system('mysqldump %s --no-data=true '
                      '--add-drop-table=false' %
                      self.get_mysql_args())
        finally:
            self.remove_test_db()
        print('Test finished successfully')


    def simulate_sync_db(self, version=None):
        """Creates a fresh DB, copies existing DB to it, then synchronizes it.

        @param version: The version to which to migrate the database, or None
                for the latest version.

        """
        if version is None:
            target_version = self.get_latest_version()
        else:
            target_version = version
        # Skip only when already at the requested target; comparing against
        # the latest version alone would skip an explicit downgrade.
        if self.get_db_version() == target_version:
            print('Skipping simulation, already at version %d' % target_version)
            return
        # get existing data
        self.initialize_and_fill_test_db()
        try:
            print('Starting migration test on DB %s' % self.get_db_name())
            self.migrate_to_version_or_latest(version)
        finally:
            self.remove_test_db()
        print('Test finished successfully')


    def initialize_and_fill_test_db(self):
        """Initializes and fills up a test database.

        Dumps the current database with mysqldump, creates test_<db_name> and
        loads the dump into it. On success the connection points at the test
        database.

        @raises Exception if the dump or the load fails; a test database
                that was already created is dropped before raising.

        """
        print('Dumping existing data')
        dump_fd, dump_file = tempfile.mkstemp('.migrate_dump')
        try:
            # A failed dump must abort; otherwise the rehearsal would run on
            # an empty copy and appear to succeed.
            self._run_shell_command(
                'mysqldump %s >%s' % (self.get_mysql_args(), dump_file),
                'mysqldump')
            # fill in test DB
            self.initialize_test_db()
            filled = False
            try:
                print('Filling in test DB')
                self._run_shell_command(
                    'mysql %s <%s' % (self.get_mysql_args(), dump_file),
                    'Loading dump into test DB')
                filled = True
            finally:
                if not filled:
                    # simulate_sync_db only cleans up after this returns, so
                    # drop the half-filled test DB here.
                    self.remove_test_db()
        finally:
            # The dump is a full copy of the data; never leave it behind.
            os.close(dump_fd)
            os.remove(dump_file)
|  |  | 
|  |  | 
# Help text that main() prints when no command or an unknown command is
# given. The %s placeholder is filled with sys.argv[0] when the module is
# imported.
USAGE = """\
%s [options] sync|test|simulate|safesync [version]
Options:
-d --database   Which database to act on
-f --force      Don't ask for confirmation
--debug         Print all DB queries"""\
% sys.argv[0]
|  |  | 
|  |  | 
def main():
    """Main function for the migration script.

    Parses the command line and runs one of the commands:
      sync      migrate the database to [version] (default: latest)
      test      run the migrations on a fresh, empty test database
      simulate  run the migrations on a test copy of the current database
      safesync  simulate first, then perform the real migration
    Prints USAGE when no command or an unknown command is given, and reports
    a non-integer version as a usage error.
    """
    parser = OptionParser()
    parser.add_option("-d", "--database",
                      help="which database to act on",
                      dest="database",
                      default="AUTOTEST_WEB")
    parser.add_option("-f", "--force", help="don't ask for confirmation",
                      action="store_true")
    parser.add_option('--debug', help='print all DB queries',
                      action='store_true')
    (options, args) = parser.parse_args()

    commands = ('sync', 'test', 'simulate', 'safesync')
    if not args or args[0] not in commands:
        print(USAGE)
        return
    command = args[0]

    version = None
    if len(args) > 1:
        try:
            version = int(args[1])
        except ValueError:
            # OptionParser.error prints the message and exits with status 2,
            # instead of dumping a traceback.
            parser.error('version must be an integer, got %r' % args[1])

    # Only connect once the command line is known to be valid.
    manager = get_migration_manager(db_name=options.database,
                                    debug=options.debug, force=options.force)

    if command == 'sync':
        manager.do_sync_db(version)
    elif command == 'test':
        manager.simulate = True
        manager.test_sync_db(version)
    elif command == 'simulate':
        manager.simulate = True
        manager.simulate_sync_db(version)
    else:
        # safesync: rehearse on a copy of the data, then migrate for real.
        print('Simulating migration')
        manager.simulate = True
        manager.simulate_sync_db(version)
        print('Performing real migration')
        manager.simulate = False
        manager.do_sync_db(version)
|  |  | 
|  |  | 
def get_migration_manager(db_name, debug, force):
    """Opens a database connection and wraps it in a MigrationManager.

    @param db_name: The database to act on, as given by --database
            (e.g. 'AUTOTEST_WEB').
    @param debug: If True, set the connection's debug flag (the --debug
            option is described as printing all DB queries).
    @param force: If True, migrations will not ask for confirmation.

    @return A MigrationManager bound to the newly connected database.

    """
    conn = database_connection.DatabaseConnection(db_name)
    # Configure the connection before opening it.
    conn.debug = debug
    # NOTE(review): presumably disabled so a dropped connection fails the
    # migration instead of silently reconnecting mid-way — confirm.
    conn.reconnect_enabled = False
    conn.connect()
    return MigrationManager(conn, force=force)
|  |  | 
|  |  | 
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    main()