#!/usr/bin/python3
# debci job scheduler triggered by repository changes
#
# Monitors an APT repository for package changes, and schedules autopkgtests.
#
# Author: Christian Kastner <ckk@kvr.at>
# License: MIT


from collections import defaultdict
import dataclasses
import email.parser
import fcntl
import itertools
import logging
import os
import re
import sqlite3
import subprocess
import sys
import tempfile
from typing import Optional

import apt
import apt_pkg
import yaml


class ConflictedStateError(Exception):
    """Raised when APT files are in a conflicted state.

    For example, when a source package in Sources lists a number of binaries,
    but the Packages index does not (yet) contain all of them.

    This has been frequently observed in experimental.
    """
    def __init__(self, message):
        super().__init__(message)


@dataclasses.dataclass(frozen=True)
class SourcePackage:
    """Debian source package state."""
    name: str
    dist: str
    arch: str
    version: Optional[str] = None


@dataclasses.dataclass(frozen=True)
class TestParams:
    """Additional debci parameters for a test."""
    # TODO: hard-coded until callers are cleaned up
    dist: str = "unstable"
    requestor: str = "rocm-debci"
    backend: str = "qemu+rocm"
    trigger: Optional[str] = None
    extra_apt_sources: tuple[str, ...] = dataclasses.field(default_factory=tuple)
    pin_packages: tuple[tuple[str, str], ...] = dataclasses.field(default_factory=tuple)


@dataclasses.dataclass
class ScheduleParams:
    """Maps a deb822-like stanza with a schedule for a package."""
    source: str
    dists: list[str]
    limit_to_bins: list[str]
    also_trigger_on: list[str]
    skip_trigger_on: list[str]


class VersionDB:
    """Mini-ORM for a single table of package versions.

    This program is just a prototype, so there's no point yet in expending a
    lot of energy on doing this more cleanly.

    Parameters
    ----------
    dbfile : str
        Path to the backing sqlite3 database file.
    """
    def __init__(self, dbfile: str = '/var/lib/debci/data/debci.sqlite3') -> None:
        self.dbfile = dbfile
        self.conn = sqlite3.connect(dbfile)
        self.conn.row_factory = sqlite3.Row
        csr = self.conn.cursor()
        csr.execute("""CREATE TABLE IF NOT EXISTS versions(
                            dist       text NOT NULL,
                            package    text NOT NULL,
                            arch       text NOT NULL,
                            version    text NOT NULL,
                            updated_at datetime DEFAULT CURRENT_TIMESTAMP,
                            PRIMARY KEY(dist, package, arch)
                    )""")
        self.conn.commit()

    def is_newer(self, package: SourcePackage) -> bool:
        """Checks if a package version is newer than the last seen one."""
        csr = self.conn.cursor()
        qparams = (package.dist, package.name, package.arch)
        csr.execute("SELECT * FROM versions "
                    "WHERE dist=? and package=? and arch=?", qparams)
        row = csr.fetchone()
        return not row or apt_pkg.version_compare(package.version, row['version']) > 0

    @staticmethod
    def _normalize_pins(pin_packages: list[tuple[str, str]]) -> dict[str, set[str]]:
        """Return package pins as sets of package names, keyed by distribution."""
        dists: dict[str, set[str]] = {}
        for pins, dist in pin_packages:
            # Merge rather than overwrite, in case multiple pin entries target
            # the same distribution
            dists.setdefault(dist, set()).update(pins.split(","))
        return dists
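
    # A minimal sketch of the normalization (data shapes assumed from
    # make_testparams and from the YAML that debci stores for queued jobs):
    #
    #   VersionDB._normalize_pins([("src:a,src:b", "experimental"),
    #                              ("src:c", "experimental")])
    #   -> {"experimental": {"src:a", "src:b", "src:c"}}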

    def is_queued(self, package: SourcePackage, params: TestParams) -> bool:
        """Checks whether an equivalent test has already been queued.

        This deliberately ignores the version, as queued jobs don't have one:
        whatever is in the distribution at test time will be tested. Queued
        jobs are looked up in debci's own ``jobs`` and ``packages`` tables,
        which live in the same database file.
        """
        csr = self.conn.cursor()
        qparams = (params.dist, package.name, package.arch)
        csr.execute("SELECT * FROM jobs j "
                    "JOIN packages p ON p.id = j.package_id "
                    "WHERE j.status IS NULL "
                      "AND j.suite=? "
                      "AND p.name=? "
                      "AND j.arch=? ", qparams)
        rows = csr.fetchall()
        pins_params = self._normalize_pins(params.pin_packages)
        # TODO: extra_apt_sources, but debci-enqueue seems to have issues
        for row in rows:
            # We're getting YAML with randomly ordered stuff, so to test for
            # equivalence, we need to normalize everything
            pins_db = self._normalize_pins(yaml.safe_load(row['pin_packages'] or "[]"))
            if pins_params == pins_db:
                return True
        return False

    def upsert(self, package: SourcePackage) -> None:
        """Updates a package version, creating it if it doesn't exist yet."""
        csr = self.conn.cursor()
        qparams = (package.dist, package.name, package.arch, package.version)
        csr.execute("INSERT INTO versions(dist, package, arch, version) "
                    "VALUES(?, ?, ?, ?) "
                    "ON CONFLICT(dist, package, arch) "
                    "DO UPDATE SET version=excluded.version, updated_at=CURRENT_TIMESTAMP",
                    qparams)
        self.conn.commit()
        logging.debug("Upserted: dist=%s package=%s arch=%s version=%s", *qparams)


class APTSingleDistCache:
    """APT cache for a single distribution, with custom root directory.

    This instantiates a custom APT cache, rather than using the system cache.
    For now, all sources are trusted, meaning no signature verification of
    Release files is performed. After all, this cache is only used to discover
    package relationships.

    Parameters
    ----------
    root : str
        Root directory in which the cache will be created.
    uri : str
        Base URI of the APT repository to fetch from.
    dist : str
        Distribution for which to fetch data.
    components : list[str]
        Repository components (main, contrib, non-free, non-free-firmware).
    """
    def __init__(
        self,
        root: str,
        uri: str = "http://deb.debian.org/debian",
        dist: str = "unstable",
        # pylint: disable=dangerous-default-value   # only read from, never written to
        components: list[str] = ["main"],
    ) -> None:
        self.root = root
        self.uri = uri
        self.dist = dist
        self.components = components[:]
        # First, create the files and directories needed by apt_pkg.Cache
        dirs = [
            "etc/apt/preferences.d",
            "etc/apt/sources.list.d",
            "var/cache/apt/archives/partial",
            "var/lib/apt/lists/partial",
            "var/lib/dpkg",
        ]
        for dir_ in dirs:
            os.makedirs(os.path.join(root, dir_))
        open(os.path.join(root, "var/lib/dpkg/status"), "w").close()

        # Then, initialize the config
        apt_pkg.init_config()
        apt_pkg.config.set("Dir", root)
        apt_pkg.config.set("Dir::State::status", os.path.join(root, "var/lib/dpkg/status"))
        apt_pkg.init_system()

        # Then, load the data into it
        with open(os.path.join(root, "etc/apt/sources.list"), "w") as fobj:
            fobj.write(f"deb      [allow-insecure=yes] {uri} {dist} {' '.join(components)}\n")
            fobj.write(f"deb-src  [allow-insecure=yes] {uri} {dist} {' '.join(components)}\n")
        # The python-apt API is... peculiar
        sourcelist = apt_pkg.SourceList()
        sourcelist.read_main_list()
        cache = apt_pkg.Cache(None)
        cache.update(apt.progress.base.AcquireProgress(), sourcelist)
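        # cache.update() fetches the index files; a fresh Cache must be built
        # afterwards so that the newly downloaded lists are actually read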
        self.cache = apt_pkg.Cache(None)
        self.prec = apt_pkg.PackageRecords(self.cache)
        self.srec = apt_pkg.SourceRecords()
        self._memo_source_version = {}
        self._memo_source_rdeps = {}
        self._memo_source_deps = {}

    def get_source_version(self, name: str) -> Optional[str]:
        """Return the version of a source package, or None if not found."""
        if name in self._memo_source_version:
            return self._memo_source_version[name]
        self.srec.restart()
        version = None
        while self.srec.lookup(name):
            if not version or apt_pkg.version_compare(self.srec.version, version) > 0:
                version = self.srec.version
        self._memo_source_version[name] = version
        return self._memo_source_version[name]
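
    # Note: SourceRecords.lookup() steps through successive matches, so the
    # loop above effectively selects the highest version when a source name
    # appears more than once in the index.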

    def get_source_rdeps(self, name: str) -> list[str]:
        """Get the list of reverse dependencies for a source package.

        We define these as follows:
          (1) the list of all source packages, excluding <name> itself,
          (2) which build at least one binary package that depends on a binary
              built by <name>.

        Consequently,
          * src:rocrand is a reverse dependency of src:rocm-hipamd because
            bin:librocrand1 depends on bin:libamdhip64-5
          * src:rocrand is not a reverse dependency of itself, even though
            bin:librocrand1-tests depends on bin:librocrand1
        """
        if name in self._memo_source_rdeps:
            return self._memo_source_rdeps[name]
        rdeps = set()
        self.srec.restart()
        if not self.srec.lookup(name):
            return []
        rdls = []
        for binname in self.srec.binaries:
            try:
                rdls.append(self.cache[binname].rev_depends_list)
            except KeyError as exc:
                raise ConflictedStateError(
                    f"src:{name}: cache for {self.dist} is missing bin:{binname}"
                ) from exc
        for dep in itertools.chain.from_iterable(rdls):
            # Again... peculiar API
            parent_pkg = self.cache[dep.parent_pkg.name]
            if not self.prec.lookup(parent_pkg.version_list[0].file_list[0]):
                raise NotImplementedError("Failed source lookup. Can this happen?")
            source_pkg = self.prec.source_pkg or self.prec.name
            if source_pkg == name:
                continue
            rdeps.add(source_pkg)
        self._memo_source_rdeps[name] = sorted(rdeps)
        return self._memo_source_rdeps[name]

    def get_source_deps(
        self, name: str,
        limit_to_bins: Optional[set[str]] = None,
    ) -> list[str]:
        """Get the list of dependencies for a source package.

        We define these as follows:
          (1) the list of all source packages, excluding <name> itself,
          (2) which build at least one binary package upon which a binary
              ``BINNAME`` built by <name> depends,
          (3) as long as either ``limit_to_bins`` is None, or a set of strings
              and ``BINNAME`` is a member of that set.

        Consequently,
          * src:glibc is a dependency of src:rocm-hipamd because
            bin:libamdhip64-5 depends on bin:libc6
          * src:rocm-hipamd is a dependency of src:rocrand because
            bin:librocrand1 depends on bin:libamdhip64-5
          * src:rocrand is not a dependency of itself, even though
            bin:librocrand1-tests depends on bin:librocrand1
          * However, if limit_to_bins were set to {'librocrand1-tests',
            'libhiprand1-tests'}, then src:sphinx-rtd-theme would not
            be a dependency of src:rocrand because the binaries that
            depend on it, librocrand-doc and libhiprand-doc, are not
            in the ``limit_to_bins`` list.
        """
        binnmu_re = re.compile(r"\+b\d+$")
        limit_to_bins = frozenset(limit_to_bins or [])
        if (name, limit_to_bins) in self._memo_source_deps:
            return self._memo_source_deps[(name, limit_to_bins)]
        depends = set()
        self.srec.restart()
        if not self.srec.lookup(name):
            return []
        for binname in self.srec.binaries:
            if limit_to_bins and binname not in limit_to_bins:
                continue
            try:
                pkg = self.cache[binname]
            except KeyError as exc:
                raise ConflictedStateError(
                    f"src:{name}: cache for {self.dist} is missing bin:{binname}"
                ) from exc
            try:
                ver = [v for v in pkg.version_list
                        if binnmu_re.sub("", v.ver_str) == self.srec.version][0]
            except IndexError as exc:
                raise ConflictedStateError(
                    f"src:{name}: cache for {self.dist} has bin:{binname} but not "
                    f"matching source version={self.srec.version}",
                ) from exc
            for dep in ver.depends_list.get('Depends', []):
                # For now, we only consider the first alternative of each
                # dependency: given "Depends: foo | bar", we look at foo
                dep = dep[0]
                target_pkg = self.cache[dep.target_pkg.name]
                # A binary package can be both real and virtual... we test both
                if target_pkg.has_versions:
                    if not self.prec.lookup(target_pkg.version_list[0].file_list[0]):
                        raise NotImplementedError("Failed source lookup. Can this happen?")
                    source_pkg = self.prec.source_pkg or self.prec.name
                    if source_pkg == name:
                        continue
                    depends.add(source_pkg)
                if target_pkg.has_provides:
                    # We trigger on change to any of the providing packages
                    for _, _, version in target_pkg.provides_list:
                        if not self.prec.lookup(version.file_list[0]):
                            raise NotImplementedError("Failed source lookup. Can this happen?")
                        source_pkg = self.prec.source_pkg or self.prec.name
                        if source_pkg == name:
                            continue
                        depends.add(source_pkg)
        self._memo_source_deps[(name, limit_to_bins)] = sorted(depends)
        return self._memo_source_deps[(name, limit_to_bins)]


class DebCIConfig:
    """debci config wrapper.

    Exposes some of the debci configuration variables as attributes.

    Parameters
    ----------
    config_name : str
        Name of the debci configuration, which is assumed to exist in
        ``/etc/debci/<config_name>/``.
    """
    def __init__(self, config_name: Optional[str] = None) -> None:
        self.env = os.environ.copy()
        if config_name:
            self.env["debci_config_dir"] = f"/etc/debci/{config_name}"
        self.config_dir = self.get_string("config_dir")
        self.arch_list = self.get_list("arch_list")
        self.suite_list = self.get_list("suite_list")

    def get_string(self, option: str) -> str:
        """Gets an option, interpreting its value as a string."""
        cmd = ["debci", "config", "-v", option]
        out = subprocess.check_output(cmd, text=True, env=self.env)
        return out.split("\n", maxsplit=1)[0]

    def get_list(self, option: str) -> list[str]:
        """Gets an option, interpreting its value as a list."""
        return self.get_string(option).split()
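
    # Sketch of the expected interaction (output shape assumed): `debci config
    # -v arch_list` prints the resolved value on its first line, e.g.
    # "amd64 arm64", so get_list("arch_list") returns ["amd64", "arm64"].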


class DebCITestScheduler:
    """Check releases for package changes, and schedule debci test jobs.

    Given a list of packages, this monitors one or more APT repositories for
    changes to those packages or to their binary dependencies, and schedules
    debci tests for the package and its reverse dependencies.

    Parameters
    ----------
    schedule_conffile : str
        Path to the schedule configuration file that lists the source packages
        to track and how to trigger them.
    """
    def __init__(self, schedule_conffile: str = "/etc/debci/scheduler.conf") -> None:
        self.tmpdir = tempfile.TemporaryDirectory()
        self.config = DebCIConfig()
        self.vdb = VersionDB()
        self.vdb_update_needed = set()
        self.test_needed = set()
        self.archs = []
        self.dists = []
        self.addon_dists = {}
        self.schedules = {}

        for block in self._read_blocks(schedule_conffile):
            parser = email.parser.HeaderParser()
            hdrs = parser.parsestr(block)

            # Control stanza
            if set(("Architectures", "Distributions", "Addon-Distributions")) & set(hdrs.keys()):
                self.archs = hdrs.get("Architectures", "").split()
                self.dists = hdrs.get("Distributions", "").split()
                pairs = hdrs.get("Addon-Distributions", "").split()
                self.addon_dists = dict(p.split(">") for p in pairs)
                continue

            # Schedule stanza
            sched = ScheduleParams(
                source=hdrs["Source"],
                dists=hdrs.get("Distributions", "").split(),
                limit_to_bins=hdrs.get("Limit-To-Binaries", "").split(),
                also_trigger_on=hdrs.get("Also-Trigger-On", "").split(),
                skip_trigger_on=hdrs.get("Skip-Trigger-On", "").split(),
            )
            self.schedules[hdrs["Source"]] = sched

        self.caches = {}
        for dist in self.config.suite_list + list(self.addon_dists.keys()):
            self.caches[dist] = APTSingleDistCache(
                os.path.join(self.tmpdir.name, dist),
                dist=dist,
                components=["main", "non-free-firmware"],
            )

    @staticmethod
    def _read_blocks(file_path):
        """Read blocks of text (separated by blank lines) from a file."""
        with open(file_path, encoding="utf-8") as fobj:
            lines = []
            for line in fobj:
                if line.strip():
                    if line.startswith("#"):
                        continue
                    if line.startswith((" ", "\t")):
                        lines[-1] += line.strip()
                    else:
                        lines.append(line.strip())
                elif lines:
                    yield '\n'.join(lines)
                    lines = []
            if lines:
                yield '\n'.join(lines)
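
    # An illustrative scheduler.conf as handled by _read_blocks and __init__
    # (the field names are the ones parsed above; the values are made up):
    #
    #   # control stanza
    #   Architectures: amd64
    #   Distributions: unstable
    #   Addon-Distributions: experimental>unstable
    #
    #   # schedule stanza
    #   Source: rocrand
    #   Limit-To-Binaries: librocrand1-tests
    #   Also-Trigger-On: rocm-hipamd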

    def get_deps_to_check(self, package: SourcePackage) -> list[SourcePackage]:
        """Produce a list of all dependencies to check for changes."""
        names = self.caches[package.dist].get_source_deps(
            package.name,
            self.schedules[package.name].limit_to_bins,
        )
        if package.dist in self.addon_dists:
            base = self.caches[self.addon_dists[package.dist]].get_source_deps(
                package.name,
                self.schedules[package.name].limit_to_bins,
            )
            names = list(set(base) | set(names))
        names += [t for t in self.schedules[package.name].also_trigger_on if t not in names]
        names = [d for d in names if d not in self.schedules[package.name].skip_trigger_on]
        deps = []
        for name in names:
            dist = package.dist
            version = self.caches[dist].get_source_version(name)
            if dist in self.addon_dists:
                bversion = self.caches[self.addon_dists[dist]].get_source_version(name)
                # e.g. a package in experimental can be stale (older than in unstable)
                if bversion and apt_pkg.version_compare(bversion, version or '') > 0:
                    dist = self.addon_dists[dist]
                    version = bversion
            deps.append(SourcePackage(name, dist, package.arch, version))
        return deps

    def get_rdeps_to_test(self, package: SourcePackage) -> list[SourcePackage]:
        """Produce a list of all reverse dependencies to schedule tests for."""
        try:
            names = self.caches[package.dist].get_source_rdeps(package.name)
        except ConflictedStateError as err:
            logging.warning("Skipping %s", err)
            return []
        if package.dist in self.addon_dists:
            base = self.caches[self.addon_dists[package.dist]].get_source_rdeps(package.name)
            names = list(set(base) | set(names))
        rdeps = []
        for name in names:
            dist = package.dist
            version = self.caches[dist].get_source_version(name)
            if dist in self.addon_dists:
                bversion = self.caches[self.addon_dists[dist]].get_source_version(name)
                # e.g. a package in experimental can be stale (older than in unstable)
                if bversion and apt_pkg.version_compare(bversion, version or '') > 0:
                    dist = self.addon_dists[dist]
                    version = bversion
            rdeps.append(SourcePackage(name, dist, package.arch, version))
        return rdeps

    def make_testparams(self, package: SourcePackage, trigger: SourcePackage) -> TestParams:
        """Generate test parameters from a package and its test trigger."""
        pin_packages = []
        if package.dist in self.addon_dists:
            pin_packages.append((f"src:{package.name}", package.dist))
        if package != trigger and trigger.dist in self.addon_dists:
            pin_packages.append((f"src:{trigger.name}", trigger.dist))
        return TestParams(
            pin_packages=tuple(pin_packages),
            dist=self.addon_dists.get(package.dist, package.dist),
            trigger=f"{trigger.name}={trigger.version or ''}",
        )
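
    # Sketch (version made up): for a package seen in experimental whose base
    # distribution is unstable, this yields something like
    #   TestParams(dist="unstable", trigger="rocrand=5.3.3-1",
    #              pin_packages=(("src:rocrand", "experimental"),))
    # i.e. the test runs against the base dist with the package pinned from
    # the addon dist.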

    def collect(
        self,
        dist: Optional[str] = None,
        arch: Optional[str] = None,
        name: Optional[str] = None,
    ) -> None:
        """Identify all the wantlist packages that need testing.

        If dist, arch, or name (or any combination thereof) are given, then the
        results will be filtered by that value.
        """
        product = itertools.product(
            [dist] if dist else self.dists + list(self.addon_dists.keys()),
            [arch] if arch else self.archs,
            [name] if name else list(self.schedules.keys()),
        )

        candidates = []
        for pdist, parch, pname in product:
            if self.schedules[pname].dists and pdist not in self.schedules[pname].dists:
                continue
            pversion = self.caches[pdist].get_source_version(pname)
            if pversion:
                candidates.append(SourcePackage(pname, pdist, parch, pversion))

        triggered = {}
        for package in candidates:
            try:
                deps_to_check = self.get_deps_to_check(package)
            except ConflictedStateError as err:
                logging.warning("Skipping %s", err)
                continue

            # Case 1: test triggered because the package was updated
            if self.vdb.is_newer(package):
                self.vdb_update_needed.add(package)
                if package not in triggered:
                    triggered[package] = package
                    logging.debug("Triggered: %s by %s", package, package)
                # We can't skip the next step (Case 2), as we *must* process
                # all dependencies, to update their last seen versions

            # Case 2: test triggered because a dependency or other trigger changed
            for dep in deps_to_check:
                if self.vdb.is_newer(dep):
                    self.vdb_update_needed.add(dep)
                    if package not in triggered:
                        triggered[package] = dep
                        logging.debug("Triggered: %s by %s", package, dep)

        # Then, add all triggered packages to the needed set, plus their
        # reverse dependencies
        processed = set()
        while triggered:
            package, trigger = triggered.popitem()
            processed.add(package)
            self.test_needed.add((package, self.make_testparams(package, trigger)))
            for rdep in self.get_rdeps_to_test(package):
                # Transitively test all wanted packages
                if rdep.name in self.schedules:
                    # Don't re-trigger a package that is pending or already
                    # processed (guards against dependency cycles)
                    if rdep not in triggered and rdep not in processed:
                        triggered[rdep] = trigger
                        logging.debug("Triggered: %s by %s", rdep, trigger)
                else:
                    self.test_needed.add((rdep, self.make_testparams(rdep, trigger)))

    def enqueue(self, package: SourcePackage, params: TestParams) -> int:
        """Calls `debci enqueue` for a package.

        For any omitted arguments, debci will use the values from its configuration
        in /etc/debci.conf.
        """
        if self.vdb.is_queued(package, params):
            return 0
        self.vdb_update_needed.add(package)

        cmd = ["debci", "enqueue"]
        cmd += ["--suite", params.dist]
        cmd += ["--arch", package.arch]
        if params.requestor:
            cmd += ["--requestor", params.requestor]
        if params.backend:
            cmd += ["--backend", params.backend]
        if params.trigger:
            cmd += ["--trigger", params.trigger]
        if params.extra_apt_sources:
            cmd += ["--extra-apt-sources", ','.join(params.extra_apt_sources)]
        if params.pin_packages:
            pins = defaultdict(list)
            # Collect packages from the same dist
            for pin, dist in params.pin_packages:
                pins[dist].append(pin)
            # Flatten list to a comma-separated string
            for dist in pins:
                pins[dist] = ",".join(pins[dist])
            # Flatten dists
            cmd += ["-p", ",".join(f"{dist}={pins[dist]}" for dist in pins)]
        cmd.append(package.name)

        try:
            subprocess.check_output(cmd, text=True, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as err:
            log_args = (err.returncode, " ".join(cmd), err.output)
            logging.error("enqueue: FAIL (exitcode=%s): %s: %s", *log_args)
            return err.returncode
        logging.info("enqueue: OK: %s", " ".join(cmd))
        return 0
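
    # For illustration, a resulting invocation could look like this (package
    # names and version made up):
    #
    #   debci enqueue --suite unstable --arch amd64 --requestor rocm-debci \
    #       --backend qemu+rocm --trigger rocrand=5.3.3-1 \
    #       -p experimental=src:rocrand rocblas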

    def submit(self):
        """Submit all collected tests to their queues."""
        for package, params in self.test_needed:
            self.enqueue(package, params)
        for package in self.vdb_update_needed:
            self.vdb.upsert(package)


def main() -> None:
    """Main"""
    logging.basicConfig(
        filename="/var/log/debci/scheduler.log",
        level=logging.DEBUG,
        format="%(asctime)s %(levelname)s %(message)s",
    )
    logging.info("debci-scheduler started")

    lockfile = "/tmp/debci-scheduler.lock"
    try:
        lockfileobj = open(lockfile, "a")
    except OSError as err:
        logging.fatal("Could not open lockfile: %s", err)
        sys.exit(1)
    try:
        fcntl.flock(lockfileobj.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
    except OSError:
        logging.fatal("Another instance of debci-scheduler is already running.")
        sys.exit(1)

    scheduler = DebCITestScheduler()
    scheduler.collect()
    scheduler.submit()

    logging.info("debci-scheduler finished")
    fcntl.flock(lockfileobj.fileno(), fcntl.LOCK_UN)
    lockfileobj.close()
    os.remove(lockfile)


if __name__ == '__main__':
    main()
