# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
#
# Copyright (C) 2016-2021 GEM Foundation
#
# OpenQuake is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# OpenQuake is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with OpenQuake. If not, see <http://www.gnu.org/licenses/>.
import os
import sys
import re
import time
import runpy
import urllib.request
import logging
import importlib
import sqlite3
class DuplicatedVersion(RuntimeError):
    """
    Raised when two upgrade scripts in the same directory share the
    same version number.
    """
class VersionTooSmall(RuntimeError):
    """
    Raised when a not yet applied upgrade script has a version number
    lower than or equal to the current version of the database.
    """
class VersioningNotInstalled(RuntimeError):
    """
    Raised when the versions cannot be read from the versioning table.
    """
# DDL template for the versioning table; the %s placeholder is replaced with
# the table name. Each row stores the version of an applied script, the
# script name and an `executed` timestamp defaulting to CURRENT_TIMESTAMP.
CREATE_VERSIONING = '''\
CREATE TABLE %s(
version TEXT PRIMARY KEY,
scriptname TEXT NOT NULL,
executed TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);
'''
class WrappedConnection(object):
    """
    This is a utility class that wraps a DB API-2 connection
    providing a couple of convenient features.

    1) it is possible to set a debug flag to print on stdout
       the executed queries;
    2) there is a .run method to run a query with a dedicated
       cursor; it returns the cursor, which can be iterated over

    Any attribute not defined here is looked up on the wrapped connection.

    :param conn: a DB API2-compatible connection
    :param debug: if True, print the queries executed by .run
    """
    def __init__(self, conn, debug=False):
        self._conn = conn
        self.debug = debug

    def __getattr__(self, name):
        # __getattr__ is invoked only when the normal lookup fails; if _conn
        # itself is missing (instance not initialized) fail cleanly instead
        # of recursing forever through self._conn
        if name == '_conn':
            raise AttributeError(name)
        return getattr(self._conn, name)

    def run(self, templ, *args):
        """
        A simple utility to run SQL queries.

        If the cursor provides a `mogrify` method, the arguments are
        interpolated by it and the resulting query is executed; otherwise
        (for instance with sqlite3 cursors, which have no `mogrify`) the
        template and the arguments are passed to `cursor.execute` so that
        the driver binds the parameters itself.

        :param templ: a query or query template
        :param args: the arguments (or the empty tuple)
        :returns: the DB API 2 cursor used to run the query
        """
        curs = self._conn.cursor()
        mogrify = getattr(curs, 'mogrify', None)
        if mogrify is None:
            # no client-side interpolation available: use driver binding
            if self.debug:
                print(templ, args)
            curs.execute(templ, args)
        else:
            query = mogrify(templ, args)
            if self.debug:
                print(query)
            curs.execute(query)
        return curs
# not used right now
def check_script(upgrade, conn, dry_run=True, debug=True):
    """
    A utility to debug upgrade scripts written in Python.

    The upgrade procedure is called with the connection wrapped in a
    :class:`WrappedConnection`. If it raises, the transaction is rolled
    back and the exception is propagated; if it succeeds, the transaction
    is rolled back when `dry_run` is true and committed otherwise.

    :param upgrade: upgrade procedure, called with the wrapped connection
    :param conn: a DB API 2 connection
    :param dry_run: if True, do not change the database
    :param debug: if True, print the queries which are executed
    """
    wrapped = WrappedConnection(conn, debug=debug)
    try:
        upgrade(wrapped)
    except Exception:
        wrapped.rollback()
        raise
    # the upgrade succeeded: keep the changes only if this is a real run
    if dry_run:
        wrapped.rollback()
    else:
        wrapped.commit()
def apply_sql_script(conn, fname):
    """
    Apply the given SQL script to the database.

    The script is split into queries on blank lines (two consecutive
    newlines) and each query is executed with `conn.execute`. If anything
    fails, the name of the script is logged and the exception is re-raised.

    :param conn: a DB API 2 connection
    :param fname: full path to the creation script
    """
    # read the whole script up front, making sure the file gets closed
    with open(fname) as script:
        sql = script.read()
    try:
        # we cannot use conn.executescript which is non transactional
        for query in sql.split('\n\n'):
            conn.execute(query)
    except Exception:
        logging.error('Error executing %s', fname)
        raise
# errors are not trapped on purpose, since transactions should be managed
# in client code
class UpgradeManager(object):
    r"""
    Manage the upgrade scripts (SQL or Python) found in a directory and
    record the applied versions in a versioning table.

    The package containing the upgrade scripts should contain an instance
    of the UpgradeManager called `upgrader` in the __init__.py file. It
    should also specify the initialization parameters

    :param upgrade_dir:
        the directory where the upgrade scripts reside
    :param version_table:
        the name of the versioning table (default revision_info)
    :param version_pattern:
        a regular expression for the script version number (\d\d\d\d)
    :param flag_pattern:
        a regular expression for the optional flag following the version
        number in the script name (by default -slow or -danger)
    :raises ValueError:
        if `version_table` contains anything but word characters and dots
    """
    ENGINE_URL = 'https://github.com/gem/oq-engine/tree/master/'
    UPGRADES = 'openquake/server/db/schema/upgrades/'

    def __init__(self, upgrade_dir, version_table='revision_info',
                 version_pattern=r'\d\d\d\d', flag_pattern='(-slow|-danger)?'):
        # the table name is interpolated into SQL statements, therefore the
        # *whole* name must be validated, not just its prefix
        if re.fullmatch(r'[\w_\.]+', version_table) is None:
            raise ValueError(version_table)
        self.upgrade_dir = upgrade_dir
        self.version_table = version_table
        self.version_pattern = version_pattern
        self.flag_pattern = flag_pattern
        # groups: version, optional flag, name, extension
        self.pattern = r'^(%s)%s-([\w\-_]+)\.(sql|py)$' % (
            version_pattern, flag_pattern)
        self.upgrades_url = self.ENGINE_URL + self.UPGRADES
        self.starting_version = None  # will be updated after the run
        self.ending_version = None  # set by upgrade() if scripts are applied

    def _insert_script(self, script, conn):
        # record the version and the name of the script in the version table
        conn.cursor().execute(
            'INSERT INTO {} (version, scriptname) VALUES (?, ?)'.format(
                self.version_table),
            (script['version'], script['name']))

    def install_versioning(self, conn):
        """
        Create the version table into an already populated database
        and insert the base script.

        :param conn: a DB API 2 connection
        """
        logging.info('Creating the versioning table %s', self.version_table)
        conn.executescript(CREATE_VERSIONING % self.version_table)
        self._insert_script(self.read_scripts()[0], conn)

    def init(self, conn):
        """
        Create the version table and run the base script on an empty database.

        :param conn: a DB API 2 connection
        """
        # the base script is the first one in version order
        base = self.read_scripts()[0]['fname']
        logging.info('Creating the initial schema from %s',  base)
        apply_sql_script(conn, os.path.join(self.upgrade_dir, base))
        self.install_versioning(conn)

    def upgrade(self, conn, skip_versions=()):
        '''
        Upgrade the database from the current version to the maximum
        version in the upgrade scripts.

        Sets `starting_version` to the highest version stored in the
        database and, if there are scripts to apply, `ending_version`
        to the highest version among them.

        :param conn: a DBAPI 2 connection
        :param skip_versions: the versions to skip
        :returns: the list of the versions applied (possibly empty)
        '''
        db_versions = self.get_db_versions(conn)
        self.starting_version = max(db_versions)
        # only membership is needed, so a set is enough
        to_skip = db_versions | set(skip_versions)
        scripts = self.read_scripts(None, None, to_skip)
        if not scripts:  # no new scripts to apply
            return []
        self.ending_version = max(s['version'] for s in scripts)
        return self._upgrade(conn, scripts)

    def _upgrade(self, conn, scripts):
        # run the given scripts in order, recording each one in the version
        # table right after it has been executed
        conn = WrappedConnection(conn)
        versions_applied = []
        for script in scripts:  # script is a dictionary
            fullname = os.path.join(self.upgrade_dir, script['fname'])
            logging.info('Executing %s', fullname)
            if script['ext'] == 'py':  # Python script with a upgrade(conn)
                globs = runpy.run_path(fullname)
                globs['upgrade'](conn)
            else:  # SQL script
                # notice that this prints the file name in case of error
                apply_sql_script(conn, fullname)
            self._insert_script(script, conn)
            versions_applied.append(script['version'])
        return versions_applied

    def check_versions(self, conn):
        """
        :param conn: a DB API 2 connection
        :returns: a message with the versions that will be applied or None
        """
        scripts = self.read_scripts(skip_versions=self.get_db_versions(conn))
        versions = [s['version'] for s in scripts]
        if not versions:
            return None
        return ('Your database is not updated. You can update it by '
                'running oq engine --upgrade-db which will process the '
                'following new versions: %s' % versions)

    def get_db_versions(self, conn):
        """
        Get all the versions stored in the database as a set.

        :param conn: a DB API 2 connection
        :raises VersioningNotInstalled:
            if the versions cannot be read from the version table
        """
        curs = conn.cursor()
        query = 'select version from {}'.format(self.version_table)
        try:
            curs.execute(query)
            return set(version for version, in curs.fetchall())
        except Exception as exc:
            # keep the original error as the cause to ease debugging
            raise VersioningNotInstalled(
                'Run oq engine --upgrade-db') from exc

    def parse_script_name(self, script_name):
        '''
        Parse a script name and return a dictionary with fields
        fname, name, version, flag, ext and url (or None if the name
        does not match the pattern of the upgrade scripts).

        :param script_name: name of the script
        '''
        match = re.match(self.pattern, script_name)
        if not match:
            return None
        version, flag, name, ext = match.groups()
        return dict(fname=script_name, version=version, name=name,
                    flag=flag, ext=ext, url=self.upgrades_url + script_name)

    def read_scripts(self, minversion=None, maxversion=None, skip_versions=()):
        """
        Extract the upgrade scripts from a directory as a list of
        dictionaries, ordered by version.

        :param minversion: the minimum version to consider (excluded)
        :param maxversion: the maximum version to consider (included)
        :param skip_versions: the versions to skip
        :raises DuplicatedVersion:
            if two collected scripts have the same version
        """
        scripts = []
        versions = {}  # a script is unique per version
        for scriptname in sorted(os.listdir(self.upgrade_dir)):
            match = self.parse_script_name(scriptname)
            if not match:
                continue  # not an upgrade script
            version = match['version']
            if version in skip_versions:
                continue  # do not collect scripts already applied
            if minversion and version <= minversion:
                continue  # do not collect versions too old
            if maxversion and version > maxversion:
                continue  # do not collect versions too new
            if version in versions:
                raise DuplicatedVersion(
                    'Duplicated versions {%s,%s}' %
                    (scriptname, versions[version]))
            versions[version] = scriptname
            scripts.append(match)
        return scripts

    @classmethod
    def instance(cls, conn, pkg_name='openquake.server.db.schema.upgrades'):
        """
        Return an :class:`UpgradeManager` instance.

        The instance is the `upgrader` object of the given package. If the
        database has no versioning table yet, the base script is run, the
        versioning table is created and the connection is committed.

        :param conn: a DB API 2 connection
        :param str pkg_name: the name of the package with the upgrade scripts
        :raises SystemExit:
            if the package cannot be imported or contains no upgrade scripts
        """
        try:
            # upgrader is an UpgradeManager instance defined in the __init__.py
            upgrader = importlib.import_module(pkg_name).upgrader
        except ImportError:
            raise SystemExit(
                'Could not import %s (not in the PYTHONPATH?)' % pkg_name)
        if not upgrader.read_scripts():
            raise SystemExit(
                'The upgrade_dir does not contain scripts matching '
                'the pattern %s' % upgrader.pattern)
        curs = conn.cursor()
        # check if there is already a versioning table; the name is bound
        # as a parameter instead of being interpolated in the query
        curs.execute("SELECT name FROM sqlite_master WHERE name=?",
                     (upgrader.version_table,))
        versioning_table = curs.fetchall()
        # if not, run the base script and create the versioning table
        if not versioning_table:
            upgrader.init(conn)
            conn.commit()
        return upgrader
def upgrade_db(conn, pkg_name='openquake.server.db.schema.upgrades',
               skip_versions=()):
    """
    Upgrade a database by running several scripts in a single transaction.

    If any script fails the transaction is rolled back and the error is
    propagated; otherwise the transaction is committed and the elapsed
    time is logged.

    :param conn: a DB API 2 connection
    :param str pkg_name: the name of the package with the upgrade scripts
    :param list skip_versions: the versions to skip
    :returns: the version numbers of the new scripts applied to the database
    """
    upgrader = UpgradeManager.instance(conn, pkg_name)
    t0 = time.time()
    # run the upgrade scripts
    try:
        versions_applied = upgrader.upgrade(conn, skip_versions)
    except BaseException:
        # roll back on any failure, including interruptions, then re-raise
        conn.rollback()
        raise
    conn.commit()
    dt = time.time() - t0
    logging.info('Upgrade completed in %s seconds', dt)
    return versions_applied
def db_version(conn, pkg_name='openquake.server.db.schema.upgrades'):
    """
    Return the highest version stored in the versioning table.

    :param conn: a DB API 2 connection
    :param str pkg_name: the name of the package with the upgrade scripts
    :returns: the current version of the database
    """
    upgrader = UpgradeManager.instance(conn, pkg_name)
    applied_versions = upgrader.get_db_versions(conn)
    return max(applied_versions)
def what_if_I_upgrade(conn, pkg_name='openquake.server.db.schema.upgrades',
                      extract_scripts='extract_upgrade_scripts'):
    """
    Describe what an upgrade would do, without applying any script.

    The candidate scripts not yet applied are classified as safe, slow or
    dangerous according to their flag, and a message listing their URLs is
    built.

    :param conn:
        a DB API 2 connection
    :param str pkg_name:
        the name of the package with the upgrade scripts
    :param extract_scripts:
        name of the method to extract the scripts; it is looked up on the
        upgrader and called without arguments.
        NOTE(review): the upgrader in use must provide a method with this
        name - confirm that the default one is available.
    :returns:
        a message describing the scripts that would be applied
    :raises VersionTooSmall:
        if a script not yet applied has a version lower than the current
        version of the database
    """
    msg_safe_ = '\nThe following scripts can be applied safely:\n%s'
    msg_slow_ = '\nPlease note that the following scripts could be slow:\n%s'
    msg_danger_ = ('\nPlease note that the following scripts are potentially '
                   'dangerous and could destroy your data:\n%s')
    upgrader = UpgradeManager.instance(conn, pkg_name)
    applied_versions = upgrader.get_db_versions(conn)
    current_version = max(applied_versions)
    slow = []
    danger = []
    safe = []
    for script in getattr(upgrader, extract_scripts)():
        url = script['url']
        if script['version'] in applied_versions:
            continue
        elif script['version'] <= current_version:
            # you cannot apply a script with a version number lower than the
            # current db version: ensure that upgrades are strictly incremental
            raise VersionTooSmall(
                'Your database is at version %s but you want to apply %s??'
                % (current_version, script['fname']))
        elif script['flag'] == '-slow':
            slow.append(url)
        elif script['flag'] == '-danger':
            danger.append(url)
        else:  # safe script
            safe.append(url)
    if not safe and not slow and not danger:
        return ('Your database is already updated at version %s.'
                % current_version)
    header = 'Your database is at version %s.' % current_version
    msg_safe = msg_safe_ % '\n'.join(safe)
    msg_slow = msg_slow_ % '\n'.join(slow)
    msg_danger = msg_danger_ % '\n'.join(danger)
    # include only the sections that actually list some script
    msg = (header + (msg_safe if safe else '') + (msg_slow if slow else '')
           + (msg_danger if danger else ''))
    msg += ('\nClick on the links if you want to know what exactly the '
            'scripts are doing.')
    if slow:
        msg += ('\nEven slow script can be fast if your database is small or'
                ' the upgrade affects tables that are empty.')
    if danger:
        msg += ('\nEven dangerous scripts are fine if they '
                'affect empty tables or data you are not interested in.')
    return msg
# Script entry point: upgrade the SQLite database whose path is given as
# first command line argument, creating its directory if needed.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    path = sys.argv[1]
    dirname = os.path.dirname(path)
    # dirname is '' when path is a bare file name: in that case there is
    # no directory to create and os.makedirs('') would fail
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    conn = sqlite3.connect(path)
    try:
        upgrade_db(conn)
    finally:
        conn.close()