# Copyright 2012 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
# Adapted from portage/getbinpkg.py -- Portage binary-package helper functions
# Copyright 2003-2004 Gentoo Foundation
# Distributed under the terms of the GNU General Public License v2

"""Helpers for binpkg Package index files and managing binhosts."""

import collections
import io
import logging
import math
import operator
import os
import tempfile
import time
from typing import Any, Dict, List, Optional
import urllib.error
import urllib.request

from chromite.cbuildbot import cbuildbot_alerts
from chromite.lib import build_target_lib
from chromite.lib import cros_build_lib
from chromite.lib import gerrit
from chromite.lib import git
from chromite.lib import gs
from chromite.lib import osutils
from chromite.lib import parallel
from chromite.lib import sysroot_lib
from chromite.utils import key_value_store


# Age limit, in seconds, after which already-uploaded prebuilts are no longer
# reused by PackageIndex.ResolveDuplicateUploads.
TWO_WEEKS = 14 * 24 * 60 * 60
# HTTP status codes handled specially when fetching a remote Packages file.
HTTP_FORBIDDEN_CODES = (401, 403)
HTTP_NOT_FOUND_CODES = (404, 410)

# An already-uploaded package: its MTIME (int), full URI, and whether debug
# symbols were uploaded with it.
_Package = collections.namedtuple("_Package", "mtime uri debug_symbols")


class PackageIndex(object):
    """A parser for the Portage Packages index file.

    The Portage Packages index file serves to keep track of what packages are
    included in a tree. It contains the following sections:
        1) The header. The header tracks general key/value pairs that don't
            apply to any specific package. E.g., it tracks the base URL of the
            packages file, and the number of packages included in the file. The
            header is terminated by a blank line.
        2) The body. The body is a list of packages. Each package contains a
            list of key/value pairs. Packages are either terminated by a blank
            line or by the end of the file. Every package has a CPV entry, which
            serves as a unique identifier for the package. Body entries without
            a CPV key are skipped when reading.
    """

    def __init__(self):
        """Constructor."""

        # The header tracks general key/value pairs that don't apply to any
        # specific package. E.g., it tracks the base URL of the packages.
        self.header = {}

        # A list of packages (stored as a list of dictionaries).
        self.packages = []

        # Whether the PackageIndex has been modified since the last time it
        # was written.
        self.modified = False

    def _PopulateDuplicateDB(self, db, expires):
        """Populate db with SHA1 -> URL mapping for packages.

        Only packages that have both a SHA1 and an MTIME strictly newer than
        |expires| are recorded. When several packages share a SHA1, the entry
        with the newest MTIME wins, including entries already present in |db|.

        Args:
            db: Dictionary to populate with SHA1 -> _Package(mtime, uri,
                debug_symbols) entries.
            expires: The time at which prebuilts expire from the binhost.
        """

        uri = gs.CanonicalizeURL(self.header["URI"])
        for pkg in self.packages:
            cpv, sha1, mtime = pkg["CPV"], pkg.get("SHA1"), pkg.get("MTIME")
            oldpkg = db.get(sha1, _Package(0, None, False))
            if sha1 and mtime and int(mtime) > max(expires, oldpkg.mtime):
                # Packages without an explicit PATH live at <CPV>.tbz2.
                path = pkg.get("PATH", cpv + ".tbz2")
                db[sha1] = _Package(
                    int(mtime),
                    "%s/%s" % (uri.rstrip("/"), path),
                    pkg.get("DEBUG_SYMBOLS") == "yes",
                )

    def _ReadPkgIndex(self, pkgfile):
        """Read a list of key/value pairs from the Packages file into a dict.

        Both header entries and package entries are lists of key/value pairs, so
        they can both be read by this function. Entries can be terminated by
        empty lines or by the end of the file.

        This function will read lines from the specified file until it
        encounters a blank line or the end of the file.

        Keys and values in the Packages file are separated by a colon and a
        space. Keys may contain capital letters, numbers, and underscores, but
        may not contain colons. Values may contain any character except a
        newline. In particular, it is normal for values to contain colons.

        Lines that have content, and do not contain a valid key/value pair, are
        ignored. This is for compatibility with the Portage package parser, and
        to allow for future extensions to the Packages file format.

        All entries must contain at least one key/value pair: a blank line
        seen before any pair triggers an AssertionError. If the end of the file
        is reached before any pair is read, an empty dictionary is returned;
        pairs read before the end of the file are returned as usual.

        Args:
            pkgfile: A python file object.

        Returns:
            The dictionary of key-value pairs that was read from the file.
        """
        d = {}
        for line in pkgfile:
            line = line.rstrip("\n")
            if not line:
                assert (
                    d
                ), "Packages entry must contain at least one key/value pair"
                break
            # Split only on the first ": " so values may contain colons.
            line = line.split(": ", 1)
            if len(line) == 2:
                k, v = line
                d[k] = v
        return d

    def _FormatPkgIndex(self, entry: dict) -> List[str]:
        """Formats lines for _WritePkgIndex.

        Args:
            entry: The key/value pairs to write.

        Returns:
            "key: value" lines sorted by key, omitting keys whose value is
            falsy (e.g. "" or None), followed by a single "" terminator line.
            An empty list if no key/value pairs remain.
        """
        lines = ["%s: %s" % (k, v) for k, v in sorted(entry.items()) if v]
        if lines:
            # Temporary means of simplifying formatting and ensuring a blank
            # line after each entry. This whole system needs to be cleaned up.
            lines.append("")

        return lines

    def _WritePkgIndex(self, pkgfile, entry):
        """Write header entry or package entry to packages file.

        The keys and values will be separated by a colon and a space. The entry
        will be terminated by a blank line.

        Args:
            pkgfile: A python file object.
            entry: A dictionary of the key/value pairs to write.
        """
        lines = self._FormatPkgIndex(entry)
        pkgfile.write("%s\n" % "\n".join(lines))

    def _ReadHeader(self, pkgfile):
        """Read header of packages file.

        Args:
            pkgfile: A python file object.
        """
        assert not self.header, "Should only read header once."
        self.header = self._ReadPkgIndex(pkgfile)

    def _ReadBody(self, pkgfile):
        """Read body of packages file.

        Before calling this function, you must first read the header (using
        _ReadHeader).

        Args:
            pkgfile: A python file object.
        """
        assert self.header, "Should read header first."
        assert not self.packages, "Should only read body once."

        # Read all of the sections in the body by looping until we reach the end
        # of the file.
        while True:
            d = self._ReadPkgIndex(pkgfile)
            if not d:
                break
            # Entries without a CPV cannot be identified; skip them.
            if "CPV" in d:
                self.packages.append(d)

    def Read(self, pkgfile):
        """Read the entire packages file.

        Args:
            pkgfile: A python file object.
        """
        self._ReadHeader(pkgfile)
        self._ReadBody(pkgfile)

    def ReadFilePath(self, pkgfile_path: str):
        """Read the packages file path.

        Args:
            pkgfile_path: The path to the file.
        """
        with open(pkgfile_path, encoding="utf-8") as f:
            self.Read(f)

    def RemoveFilteredPackages(self, filter_fn):
        """Remove packages which match filter_fn.

        Sets the 'modified' flag only if at least one package was removed.

        Args:
            filter_fn: A function which operates on packages. If it returns
                True, the package should be removed.
        """

        filtered = [p for p in self.packages if not filter_fn(p)]
        # |filtered| is a subsequence of self.packages, so a length change is
        # exactly the "something was removed" condition; there is no need to
        # compare every package dict.
        if len(filtered) != len(self.packages):
            self.modified = True
            self.packages = filtered

    def ResolveDuplicateUploads(self, pkgindexes):
        """Point packages at files that have already been uploaded.

        For each package in our index, check if there is an existing package
        that has already been uploaded to the same base URI, and that is no
        older than two weeks. If so, point that package at the existing file, so
        that we don't have to upload the file.

        Args:
            pkgindexes: A list of PackageIndex objects containing info about
                packages that have already been uploaded.

        Returns:
            A list of the packages that still need to be uploaded. Their MTIME
            is set to the current time.
        """
        db = {}
        now = int(time.time())
        expires = now - TWO_WEEKS
        base_uri = gs.CanonicalizeURL(self.header["URI"])
        for pkgindex in pkgindexes:
            if gs.CanonicalizeURL(pkgindex.header["URI"]) == base_uri:
                # pylint: disable=protected-access
                pkgindex._PopulateDuplicateDB(db, expires)

        uploads = []
        # PATH values are relative to this index's own (raw) URI. dup.uri is
        # built from the canonicalized URI, so if the two spellings differ the
        # prefix test below fails and the package is simply re-uploaded.
        base_uri = self.header["URI"]
        for pkg in self.packages:
            sha1 = pkg.get("SHA1")
            dup = db.get(sha1)

            # If the debug symbols are available locally but are not available
            # in the remote binhost, re-upload them.
            # Note: this should never happen as we would have pulled the debug
            # symbols from said binhost.
            if (
                sha1
                and dup
                and dup.uri.startswith(base_uri)
                and (pkg.get("DEBUG_SYMBOLS") != "yes" or dup.debug_symbols)
            ):
                pkg["PATH"] = dup.uri[len(base_uri) :].lstrip("/")
                pkg["MTIME"] = str(dup.mtime)

                if dup.debug_symbols:
                    pkg["DEBUG_SYMBOLS"] = "yes"
            else:
                pkg["MTIME"] = str(now)
                uploads.append(pkg)
        return uploads

    def SetUploadLocation(self, base_uri, path_prefix):
        """Set upload location to base_uri + path_prefix.

        Args:
            base_uri: Base URI for all packages in the file. We set
                self.header['URI'] to this value, so all packages must live
                under this directory.
            path_prefix: Path prefix to use for all current packages in the
                file. This will be added to the beginning of the path for every
                package.
        """
        self.header["URI"] = base_uri.rstrip("/")
        for pkg in self.packages:
            path = pkg["CPV"] + ".tbz2"
            pkg["PATH"] = "%s/%s" % (path_prefix.rstrip("/"), path)

    def Write(self, pkgfile):
        """Write the packages index to a file object.

        If 'modified' flag is set, the TIMESTAMP and PACKAGES fields in the
        header will be updated before writing. Packages are written sorted by
        CPV.

        Args:
            pkgfile: A python file object.
        """
        self._ModifiedHeaderUpdate()
        self._WritePkgIndex(pkgfile, self.header)
        for metadata in sorted(self.packages, key=operator.itemgetter("CPV")):
            self._WritePkgIndex(pkgfile, metadata)

    def WriteToNamedTemporaryFile(self):
        """Write pkgindex to a temporary file.

        Returns:
            A temporary file containing the packages from pkgindex, flushed and
            rewound to the start. The caller is responsible for closing it.
        """
        # pylint: disable=R1732
        # This method returns an open file, so we cannot use a 'with' here
        # (without changing the behavior and breaking the unittest).
        f = tempfile.NamedTemporaryFile(
            prefix="chromite.binpkg.pkgidx.", mode="w+"
        )
        self.Write(f)
        f.flush()
        f.seek(0)
        return f

    def WriteFile(self, file_path, sudo=False):
        """Like Write, but takes a file path.

        Args:
            file_path: Destination path.
            sudo: Passed through to osutils.WriteFile.
        """
        self._ModifiedHeaderUpdate()
        lines = self._FormatPkgIndex(self.header)
        for metadata in sorted(self.packages, key=operator.itemgetter("CPV")):
            lines.extend(self._FormatPkgIndex(metadata))

        # Adding trailing \n to force this method to produce the same output as
        # the other write method. This is unnecessary and can be removed when
        # this is refactored and simplified.
        osutils.WriteFile(file_path, "%s\n" % "\n".join(lines), sudo=sudo)

    def _ModifiedHeaderUpdate(self):
        """Refresh TIMESTAMP and PACKAGES in the header if modified.

        Clears the 'modified' flag afterwards, so the refresh happens at most
        once per modification.
        """
        if self.modified:
            self.header["TIMESTAMP"] = str(math.trunc(time.time()))
            self.header["PACKAGES"] = str(len(self.packages))
            self.modified = False


class PackageIndexInfo(object):
    """A parser for PackageIndex metadata.

    Attributes:
        snapshot_sha (str): The git SHA of the manifest snapshot.
        snapshot_number (int): The snapshot number.
        build_target (build_target_lib.BuildTarget): The build_target.
        profile (Profile): The profile.
        location (str): The GS path for the prebuilts directory.
    """

    def __init__(
        self,
        snapshot_sha="",
        snapshot_number=0,
        build_target=None,
        profile=None,
        location="",
    ):
        self.snapshot_sha = snapshot_sha
        self.snapshot_number = snapshot_number
        # Falsy build_target/profile are replaced with empty defaults.
        self.build_target = build_target or build_target_lib.BuildTarget(
            name=""
        )
        self.profile = profile or sysroot_lib.Profile()
        self.location = location

    def __eq__(self, other):
        """Check equality.

        Only the build target's name is compared, not the whole BuildTarget.
        Comparing against a non-PackageIndexInfo (e.g. None) returns
        NotImplemented, so Python falls back to its default handling
        (``==`` evaluates to False) instead of raising AttributeError.

        Note: defining __eq__ without __hash__ leaves instances unhashable.
        """
        if not isinstance(other, PackageIndexInfo):
            return NotImplemented
        # BuildTarget is in the process of dropping Profile and root (which
        # properly belong to the Sysroot, not the BuildTarget. As such, they are
        # handled separately here.
        return (
            self.snapshot_sha == other.snapshot_sha
            and self.snapshot_number == other.snapshot_number
            and self.build_target.name == other.build_target.name
            and self.profile == other.profile
            and self.location == other.location
        )


def _RetryUrlOpen(url, tries=3):
    """Open the specified url, retrying if we run into temporary errors.

    We retry for both network errors and 5xx Server Errors. We do not retry
    for HTTP errors with a non-5xx code. Between attempts we sleep for 10
    seconds.

    Args:
        url: The specified url.
        tries: The number of times to try. If less than 1, no attempt is made
            and None is returned.

    Returns:
        The open response object from urllib.request.urlopen(url). The caller
        owns it and is responsible for closing it.

    Raises:
        urllib.error.HTTPError: On a non-5xx HTTP error, or a 5xx error on the
            final attempt. The url is appended to the error's message.
        urllib.error.URLError: If the final attempt fails with a network error.
    """
    for i in range(tries):
        try:
            # Return the response unclosed: wrapping urlopen() in a `with`
            # would close it before the caller could read the body.
            # pylint: disable=consider-using-with
            return urllib.request.urlopen(url)
        except urllib.error.HTTPError as e:
            # HTTPError subclasses URLError, so it must be handled first.
            if i + 1 >= tries or e.code < 500:
                e.msg += "\nwhile processing %s" % url
                raise
            else:
                print("Cannot GET %s: %s" % (url, str(e)))
        except urllib.error.URLError as e:
            if i + 1 >= tries:
                raise
            else:
                print("Cannot GET %s: %s" % (url, str(e)))
        print("Sleeping for 10 seconds before retrying...")
        time.sleep(10)


def GrabRemotePackageIndex(binhost_url, **kwargs):
    """Grab the latest binary package database from the specified URL.

    Args:
        binhost_url: Base URL of remote packages (PORTAGE_BINHOST). Only URLs
            starting with "http" or "gs://" are supported.
        **kwargs: Additional RunCommand parameters, forwarded to
            GSContext.Cat for gs:// URLs (unused for http URLs).

    Returns:
        A PackageIndex object, if the Packages file can be retrieved. If the
        packages file cannot be retrieved (forbidden, not found, a gs:// error,
        or an unsupported URL scheme), then None is returned. If the index has
        no URI header, it defaults to |binhost_url|.

    Raises:
        urllib.error.HTTPError: For HTTP errors other than forbidden/not-found.
        urllib.error.URLError: If the http server cannot be reached.
    """
    url = "%s/Packages" % binhost_url.rstrip("/")
    if binhost_url.startswith("http"):
        try:
            response = _RetryUrlOpen(url)
        except urllib.error.HTTPError as e:
            if e.code in HTTP_FORBIDDEN_CODES:
                cbuildbot_alerts.PrintBuildbotStepWarnings()
                logging.error("Cannot GET %s: %s", url, e)
                return None
            # Not found errors are normal if old prebuilts were cleaned out.
            if e.code in HTTP_NOT_FOUND_CODES:
                return None
            raise
        # urlopen() yields bytes, but PackageIndex parses text lines.
        f = io.TextIOWrapper(response, encoding="utf-8")
    elif binhost_url.startswith("gs://"):
        try:
            gs_context = gs.GSContext()
            output = gs_context.Cat(url, encoding="utf-8", **kwargs)
        except (cros_build_lib.RunCommandError, gs.GSNoSuchKey) as e:
            cbuildbot_alerts.PrintBuildbotStepWarnings()
            logging.error("Cannot GET %s: %s", url, e)
            return None
        f = io.StringIO(output)
    else:
        return None

    pkgindex = PackageIndex()
    # Close the stream (and any underlying HTTP response) even if parsing
    # fails.
    with f:
        pkgindex.Read(f)
    pkgindex.header.setdefault("URI", binhost_url)
    return pkgindex


def GrabLocalPackageIndex(package_path):
    """Load the Packages file found in |package_path| into a PackageIndex.

    The DEBUG_SYMBOLS flag of every package is recomputed from what is on
    disk. It is set to "yes" only when a file named "<CPV>.debug.tbz2" exists
    under |package_path| (searched recursively). Otherwise any existing flag is
    dropped.

    Args:
        package_path: Directory containing Packages file.

    Returns:
        A PackageIndex object.
    """
    index_path = os.path.join(package_path, "Packages")
    with open(index_path, encoding="utf-8") as index_file:
        pkgindex = PackageIndex()
        pkgindex.Read(index_file)

    # Debug-symbol tarballs, keyed by their path relative to package_path with
    # the ".debug.tbz2" suffix stripped, so they can be matched against CPVs.
    suffix = ".debug.tbz2"
    symbols = set()
    for dirpath, _, filenames in os.walk(package_path):
        for name in filenames:
            if not name.endswith(suffix):
                continue
            rel = os.path.relpath(os.path.join(dirpath, name), package_path)
            symbols.add(rel[: -len(suffix)])

    for pkg in pkgindex.packages:
        # Drop any stale flag first, then set it only if symbols exist.
        pkg.pop("DEBUG_SYMBOLS", None)
        if pkg["CPV"] in symbols:
            pkg["DEBUG_SYMBOLS"] = "yes"

    return pkgindex


def _DownloadURLs(urls, dest_dir):
    """Copy URLs into the specified |dest_dir|.

    Args:
        urls: Iterable of URLs to fetch (a list, a dict values view, etc.).
        dest_dir: Destination directory.
    """
    # Materialize first: a dict view or generator cannot be concatenated onto
    # a list, and len() is needed below.
    urls = list(urls)
    gs_ctx = gs.GSContext()
    cmd = ["cp"] + urls + [dest_dir]
    # Only request parallel mode when there is more than one URL to copy.
    gs_ctx.DoCommand(cmd, parallel=len(urls) > 1)


def FetchTarballs(binhost_urls, pkgdir):
    """Prefetch the specified |binhost_urls| to the specified |pkgdir|.

    This function fetches the tarballs from the specified list of binhost
    URLs to disk. It does not populate the Packages file -- we leave that
    to Portage. Binhosts whose Packages index cannot be retrieved are skipped
    with a warning. If the same CPV appears in several binhosts, the URL from
    the later binhost is used.

    Args:
        binhost_urls: List of binhost URLs to fetch.
        pkgdir: Location to store the fetched packages.
    """
    categories = {}
    for binhost_url in binhost_urls:
        pkgindex = GrabRemotePackageIndex(binhost_url)
        if pkgindex is None:
            # Unsupported scheme, or the index is missing/forbidden.
            logging.warning(
                "Skipping %s: no Packages index could be fetched.", binhost_url
            )
            continue
        # Strip a trailing "/" like the rest of this module does, so we never
        # build "base//path" URIs.
        base_uri = pkgindex.header["URI"].rstrip("/")
        for pkg in pkgindex.packages:
            cpv = pkg["CPV"]
            path = pkg.get("PATH", "%s.tbz2" % cpv)
            uri = "/".join([base_uri, path])
            category = cpv.partition("/")[0]
            fetches = categories.setdefault(category, {})
            fetches[cpv] = uri

    with parallel.BackgroundTaskRunner(_DownloadURLs) as queue:
        for category, urls in categories.items():
            category_dir = os.path.join(pkgdir, category)
            # exist_ok avoids the race between an exists() check and mkdir.
            os.makedirs(category_dir, exist_ok=True)
            # Pass a real list: a dict values view cannot be concatenated onto
            # a list by _DownloadURLs, nor pickled if the task runner ships
            # arguments to another process.
            queue.put((list(urls.values()), category_dir))


def UpdateAndSubmitKeyValueFile(
    filename: str,
    data: Dict[str, str],
    report: Optional[Dict[str, Any]] = None,
    dryrun: bool = False,
) -> None:
    """Update a key/value file, commit it, and submit the change.

    The keys in |data| are written into |filename| on a local push branch
    named "prebuilt_branch" and committed. A Gerrit patch is then created via
    CreateGerritPatch, labeled Bot-Commit+1, and submitted.

    Args:
        filename: File to modify that is in a git repo already.
        data: A dict of key/values to update in |filename|.
        report: Dict in which to collect information to report to the user.
            The created CL's link is appended to report["created_cls"].
        dryrun: Forwarded only to the Gerrit SetReview and SubmitChange calls.
            The local commit and the CreateGerritPatch call still happen when
            this is True.
    """
    if report is None:
        report = {}
    prebuilt_branch = "prebuilt_branch"
    cwd = os.path.abspath(os.path.dirname(filename))
    # NOTE(review): presumably `git remote` prints one remote per line, so this
    # assumes the repo has exactly one remote configured.
    remote_name = git.RunGit(cwd, ["remote"]).stdout.strip()
    gerrit_helper = gerrit.GetGerritHelper(remote_name)
    remote_url = git.RunGit(
        cwd, ["config", "--get", f"remote.{remote_name}.url"]
    ).stdout.strip()
    description = "%s: updating %s" % (
        os.path.basename(filename),
        ", ".join(data),
    )
    # UpdateKeyInLocalFile will print out the keys/values for us.
    print("Revving git file %s" % filename)
    git.CreatePushBranch(prebuilt_branch, cwd)
    for key, value in data.items():
        key_value_store.UpdateKeyInLocalFile(filename, key, value)
    git.RunGit(cwd, ["add", filename])
    git.RunGit(cwd, ["commit", "-m", description])

    tracking_info = git.GetTrackingBranch(
        cwd, prebuilt_branch, for_push=True, for_checkout=False
    )
    gpatch = gerrit_helper.CreateGerritPatch(
        cwd, remote_url, ref=tracking_info.ref, notify="NONE"
    )
    report.setdefault("created_cls", []).append(gpatch.PatchLink())
    gerrit_helper.SetReview(
        gpatch, labels={"Bot-Commit": 1}, dryrun=dryrun, notify="NONE"
    )
    gerrit_helper.SubmitChange(gpatch, dryrun=dryrun, notify="NONE")
