# Copyright 2017 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Virtualenv management"""

from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals

import collections
import distutils.util
import errno
import functools
import hashlib
import itertools
import json
import os
import platform
import re
import shutil
import subprocess
import sys
import sysconfig
import tempfile
import warnings

from cros_venv import constants
from cros_venv import flock

# Directory of local package files.  It is passed to virtualenv as
# --extra-search-dir and to pip as a find-links (-f) URL; virtualenv runs
# with --never-download and pip with --no-index, so requirements must be
# satisfiable from local sources.
_PACKAGE_DIR = os.path.join(constants.REPO_DIR, 'pip_packages')
# virtualenv executable name; no directory is given, so it is looked up on
# PATH when executed.
_VIRTUALENV_COMMAND = 'virtualenv'

# BASE_DEPENDENCIES are pip requirements automatically included in every
# virtualenv so we have control over which versions of bootstrap
# packages are installed.  They are written ahead of the spec's own
# requirements in the temporary requirements file given to pip.
_BASE_DEPENDENCIES = ('setuptools==44.0.0', 'pip==20.0.2')


class Error(Exception):
    """Common base class for the exceptions defined by this module."""


class _VenvPaths(object):

    """Computes the well-known file locations inside one virtualenv."""

    def __init__(self, venvdir):
        """Remember the virtualenv root.

        venvdir is the absolute path to a virtualenv.
        """
        self.venvdir = venvdir

    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.venvdir)

    def _under(self, *parts):
        """Join path components onto the virtualenv root."""
        return os.path.join(self.venvdir, *parts)

    @property
    def python(self):
        """The virtualenv's own Python interpreter."""
        return self._under('bin', 'python')

    @property
    def lockfile(self):
        """File locked while the virtualenv is being checked or changed."""
        return self._under('change.lock')

    @property
    def logfile(self):
        """Log written while the virtualenv is being created."""
        return self._under('create.log')

    @property
    def spec(self):
        """Saved spec describing how the virtualenv was built."""
        return self._under('spec.json')


class VersionedVenv(object):

    """Versioned virtualenv, specified by a _VenvSpec.

    This class provides a method for ensuring the versioned virtualenv
    is created.  The directory is derived from the spec, and a copy of
    the spec is saved inside it once creation succeeds so later calls can
    detect whether the existing virtualenv matches.
    """

    def __init__(self, spec):
        """Initialize instance.

        spec is a _VenvSpec.
        """
        self._spec = spec
        self._paths = _VenvPaths(_get_venvdir(spec))

    def __repr__(self):
        return '{cls}({this._spec!r})'.format(
            cls=type(self).__name__,
            this=self,
        )

    def ensure(self):
        """Ensure that the virtualenv exists.

        Creates the virtualenv directory if needed, then, while holding
        flock.FileLock on the directory's lock file, either validates the
        saved spec or creates the virtualenv.

        Returns the path to the virtualenv directory.
        Raises SpecMismatchError if an existing virtualenv in the same
        directory was saved with a different spec.
        """
        _makedirs_exist_ok(self._paths.venvdir)
        with flock.FileLock(self._paths.lockfile):
            self._check_or_create()
        return self._paths.venvdir

    @property
    def logfile(self):
        """Show path to internal log file."""
        return self._paths.logfile

    @property
    def logdata(self):
        """Get any internal logged data.

        Returns the log file contents, or a placeholder string if the log
        file does not exist.  Other I/O errors are re-raised.
        """
        try:
            with open(self._paths.logfile) as fp:
                return fp.read()
        except IOError as e:
            if e.errno == errno.ENOENT:
                return '<log does not exist>'
            raise

    def _log_env(self, logfile):
        """Log details about the active runtime for debugging."""
        def get_var(var):
            return sysconfig.get_config_var(var)
        logfile.writelines([
            'venv spec: %s\n' % (self._spec,),
            'Distutils platform tag: %s\n' % (distutils.util.get_platform(),),
            'Python implementation: %s\n' % (platform.python_implementation(),),
        ])
        for var in ('py_version_nodot', 'SOABI', 'Py_DEBUG', 'WITH_PYMALLOC',
                    'Py_UNICODE_SIZE'):
            logfile.write('sysconfig %s: %s\n' % (var, get_var(var)))

    def _check_or_create(self):
        """Check virtualenv, creating it if it is not created.

        A missing spec file means the virtualenv was never successfully
        created.  An unparseable spec file (e.g. one truncated by an
        interrupted write in an older version of this code) is treated
        the same way; recreating is safe because virtualenv is invoked
        with --clear.
        """
        try:
            existing_spec = self._load_spec()
        except (IOError, ValueError):
            self._create()
        else:
            self._check(existing_spec)

    def _create(self):
        """Create virtualenv.

        Writes a fresh create.log, creates the virtualenv, installs the
        requirements, and only then saves the spec.  If any earlier step
        fails no spec is saved, so the next ensure() tries again.
        """
        with open(self._paths.logfile, 'w') as logfile, \
             _make_reqs_file(self._spec) as reqs_file:
            self._log_env(logfile)
            _create_venv(venvdir=self._paths.venvdir,
                         logfile=logfile)
            _install_reqs_file(python_path=self._paths.python,
                               reqs_path=reqs_file.name,
                               logfile=logfile)
        self._dump_spec()

    def _check(self, spec):
        """Check if the given spec matches our spec.

        Raise SpecMismatchError if check fails.
        """
        if spec != self._spec:
            raise SpecMismatchError

    def _dump_spec(self):
        """Save the _VenvSpec to the virtualenv on disk.

        The spec is written to a temporary file and renamed into place so
        an interrupted write cannot leave a truncated spec.json behind.
        This runs under the lock taken in ensure(), so the fixed
        temporary name does not collide with another writer.
        """
        tmp_path = self._paths.spec + '.tmp'
        with open(tmp_path, 'w') as f:
            _dump_spec(self._spec, f)
        # os.replace (Python 3.3+) also overwrites on Windows; os.rename
        # replaces atomically on POSIX and is the Python 2 fallback.
        replace = getattr(os, 'replace', os.rename)
        replace(tmp_path, self._paths.spec)

    def _load_spec(self):
        """Return the _VenvSpec for the virtualenv on disk."""
        with open(self._paths.spec, 'r') as f:
            return _load_spec(f)


class SpecMismatchError(Error):
    """Raised when a virtualenv's saved spec differs from the requested one."""


# Identity of a versioned virtualenv: the Python version string and the
# full requirements text.
_VenvSpec = collections.namedtuple('_VenvSpec', ['py_version', 'reqs'])


def make_spec(f):
    """Build a _VenvSpec for the running Python from a requirements file.

    f is a readable file object; all of its remaining contents become the
    spec's requirements text.
    """
    py_version = _get_python_version()
    return _VenvSpec(py_version=py_version, reqs=f.read())


def _get_reqs_hash(spec):
    """Return the hex MD5 digest of a _VenvSpec's requirements text.

    The digest only gives each requirements set its own directory name.
    A collision is caught later because VersionedVenv compares the full
    spec saved in the directory with its own.
    """
    return hashlib.md5(spec.reqs.encode('utf-8')).hexdigest()


def _get_venvdir(spec):
    """Map a _VenvSpec to its virtualenv directory in the cache dir.

    The directory name combines the spec's Python version with a hash of
    its requirements: venv-<py_version>-<reqs hash>.

    Returns a path inside _get_cache_dir(), which is expected to be
    absolute.
    """
    base = _get_cache_dir()
    leaf = 'venv-%s-%s' % (spec.py_version, _get_reqs_hash(spec))
    return os.path.join(base, leaf)


def _dump_spec(spec, f):
    """Serialize a _VenvSpec into file object f as a JSON array."""
    f.write(json.dumps(spec))


def _load_spec(f):
    """Read back a _VenvSpec that _dump_spec wrote to file object f.

    Raises ValueError if the contents are not valid JSON.
    """
    fields = json.load(f)
    return _VenvSpec._make(fields)


def _make_reqs_file(spec):
    """Write the complete requirements for a virtualenv spec to a temp file.

    The file holds one line per entry of _BASE_DEPENDENCIES followed by
    the spec's own requirements text.  It is flushed before returning so
    that a subprocess can read it by name.

    The return value is a tempfile.NamedTemporaryFile, which cleans
    up on close.  The filename is accessible via the name attribute.
    """
    reqs_file = tempfile.NamedTemporaryFile('w')
    base_lines = ''.join('%s\n' % (dep,) for dep in _BASE_DEPENDENCIES)
    reqs_file.write(base_lines)
    reqs_file.write(spec.reqs)
    reqs_file.flush()
    return reqs_file


def _virtualenv_version():
    """Return a version tuple for virtualenv.

    Runs "virtualenv --version" and parses the output with
    _parse_virtualenv_version.

    Raises VirtualenvMissingError if the virtualenv command is not found,
    and Error if no version can be parsed from its output.
    """
    try:
        result = subprocess.check_output([_VIRTUALENV_COMMAND, '--version'],
                                         stderr=subprocess.STDOUT)
    except OSError as e:
        if e.errno == errno.ENOENT:
            raise VirtualenvMissingError(e)
        raise
    return _parse_virtualenv_version(result)


def _parse_virtualenv_version(output):
    """Parse the bytes printed by "virtualenv --version" into an int tuple.

    Versions before 20.0.0 printed just the version.  The 20.0.0 series
    switched to:
    virtualenv 20.0.13 from .../virtualenv/__init__.py

    Suffixes such as "rc1" are dropped, which is fine for our comparisons.

    Raises Error if no version number is found.  An explicit exception is
    used instead of assert so the check survives python -O.
    """
    # Require digits between dots so a stray trailing "." cannot produce
    # an empty component, which int() would reject.
    m = re.match(br'^(virtualenv *)?([0-9]+(?:\.[0-9]+)*)', output)
    if not m:
        raise Error('Could not find version in "%s"' %
                    (output.decode('utf-8', 'replace'),))
    return tuple(int(x) for x in m.group(2).split(b'.'))


def _create_venv(venvdir, logfile):
    """Create a virtualenv at the given path.

    venvdir is the target directory.  virtualenv is run with --clear, so
    existing contents are replaced, and with --never-download plus
    _PACKAGE_DIR as an extra search dir.  logfile is a writable file
    object that receives the command line and virtualenv's stdout.

    Raises Error if the installed virtualenv is older than 1.10,
    VirtualenvMissingError if virtualenv cannot be executed, and
    subprocess.CalledProcessError if it exits with a non-zero status.
    """
    def dotted(version):
        return '.'.join(str(part) for part in version)

    ver = _virtualenv_version()
    min_ver = (1, 10)
    # Explicit check instead of assert so it also runs under python -O.
    if ver < min_ver:
        raise Error('virtualenv %s or newer is required, but %s was found' %
                    (dotted(min_ver), dotted(ver)))

    env = os.environ.copy()
    if ver < (20, 0):
        # TODO(crbug.com/808434): upstream virtualenv has a bug when this is
        # set.  See also https://github.com/pypa/virtualenv/issues/565
        env.pop('VIRTUALENV_ALWAYS_COPY', None)

    command = [_VIRTUALENV_COMMAND, venvdir, '-p', sys.executable,
               '--extra-search-dir', _PACKAGE_DIR, '--clear',
               '--never-download', '-vvvvv']
    try:
        _log_check_call(command, logfile=logfile, env=env)
    except OSError as e:
        raise VirtualenvMissingError(e)


class VirtualenvMissingError(Error):
    """Virtualenv is missing."""

    def __init__(self, cause):
        """Initialize instance.

        cause is an object that describes the underlying cause of the
        error.  This is usually an exception, but can be any object.
        The object's string representation is used in this exception's
        string representation.
        """
        super(VirtualenvMissingError, self).__init__(cause)
        self.cause = cause

    def __str__(self):
        # Wrap cause in a 1-tuple; a bare "% (self.cause)" would treat a
        # tuple cause as the format arguments and raise TypeError.
        return 'virtualenv is not installed (caused by %s)' % (self.cause,)

    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.cause)


def _install_reqs_file(python_path, reqs_path, logfile):
    """Run pip in the virtualenv to install everything in reqs_path.

    pip is invoked through python_path with --no-index, and _PACKAGE_DIR
    is the only find-links location.  Output goes to logfile.
    """
    find_links = 'file://' + _PACKAGE_DIR
    pip_args = ['install', '-vvv', '--no-index', '-f', find_links,
                '-r', reqs_path]
    _log_check_call([python_path, '-m', 'pip'] + pip_args, logfile=logfile)


def _add_call_logging(call_func):
    """Wrap a subprocess-style call with logging.

    call_func takes a command argument list plus subprocess.Popen-style
    keyword arguments (e.g. subprocess.check_call).  The returned wrapper
    takes an extra logfile argument, records the command in it, sends
    the child's stdout there, and returns call_func's result.
    """
    @functools.wraps(call_func)
    def wrapped_command(args, logfile, **kwargs):
        """Logging-wrapped call.

        Arguments are similar to subprocess.Popen, depending on the
        underlying call.  There is an extra parameter logfile, which
        takes a file object.  Returns whatever the underlying call
        returns.
        """
        logfile.write('Running %r\n' % (args,))
        # Flush our buffered line before the child writes to the same
        # file, so the log stays in order.
        logfile.flush()
        return call_func(args, stdout=logfile, **kwargs)
    return wrapped_command


# subprocess.check_call with command logging: call as
# _log_check_call(args, logfile, **popen_kwargs).  The child's stdout goes
# to logfile; a non-zero exit raises subprocess.CalledProcessError.
_log_check_call = _add_call_logging(subprocess.check_call)


def _get_python_version():
    """Return the running interpreter's version as "major.minor.micro"."""
    major, minor, micro = sys.version_info[:3]
    return '%d.%d.%d' % (major, minor, micro)


def _get_cache_dir():
    """Get cache dir to use for cros_venv.

    The CROS_VENV_CACHE environment variable overrides the default of
    ~/.cache/cros_venv.  An empty value is treated as unset.

    Returns absolute path.
    """
    cache_dir = (os.environ.get('CROS_VENV_CACHE') or
                 os.path.expanduser('~/.cache/cros_venv'))
    # A relative override would otherwise produce relative virtualenv
    # paths, contradicting the absolute-path contract documented above.
    return os.path.abspath(cache_dir)


def _makedirs_exist_ok(path):
    """Create path and any missing parents; an existing directory is fine.

    Re-raises the OSError if creation fails and path is not a directory
    afterwards (for example, when a regular file is in the way).
    """
    try:
        os.makedirs(path)
    except OSError:
        if os.path.isdir(path):
            return
        raise  # pragma: no cover
