#!/usr/bin/env python3

import argparse
from e3.fs import sync_tree
from e3.gaia.api import GAIA
from e3.testsuite.report.rewriting import BaseBaselineRewriter, RewritingError
from e3.testsuite.report.index import ReportIndex
from e3.testsuite.utils import ColorConfig
from adacore_gitlab_scripts.util.e3_baseline_updates import (
    Gitlab,
    GitLabProjectDescription,
)
import os.path
import tarfile
import tempfile
import yaml

# Description text handed to argparse.ArgumentParser in parse_options().
descr = """This script updates the outputs from a previous run of the
testsuite."""


def parse_options():
    """Parse the command line and return the resulting namespace.

    At most one of ``--output-dir``, ``--from-gaia`` and ``--from-gitlab``
    may be given. When several are provided, a message is printed and the
    script exits with status 1.
    """
    parser = argparse.ArgumentParser(description=descr)
    parser.add_argument(
        "-o",
        "--output-dir",
        dest="outputdir",
        help="Specify the output dir to read outputs from",
    )
    parser.add_argument(
        "--from-gaia",
        dest="gaia_id",
        help="Specify the GAIA event id to read outputs from",
    )
    parser.add_argument(
        "--from-gitlab",
        dest="gitlab_id",
        help="Specify the Gitlab pipeline id to read outputs from",
    )
    parser.add_argument(
        "--compute-time",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Update the 'large' flag of tests from the recorded run times",
    )
    args = parser.parse_args()
    # The three result sources are mutually exclusive: count how many were
    # actually provided instead of testing every pair by hand.
    sources = (args.outputdir, args.gaia_id, args.gitlab_id)
    if sum(bool(source) for source in sources) > 1:
        print("please provide only one of (--output-dir, --from-gaia, --from-gitlab)")
        # SystemExit is raised directly because the ``exit`` builtin is only
        # injected by the ``site`` module and is not always available.
        raise SystemExit(1)
    return args


class BaselineRewriter(BaseBaselineRewriter):
    """Baseline rewriter that locates each test's ``test.out`` file."""

    def baseline_filename(self, test_name: str) -> str:
        """Return the path of the ``test.out`` baseline for ``test_name``.

        The test directory is looked up under ``tests``, ``internal`` and
        ``sparklib``, in that order, relative to the current directory.

        Raises:
            RewritingError: if none of these directories contains the test.
        """
        for root in ("tests", "internal", "sparklib"):
            candidate = os.path.join(root, test_name)
            if os.path.exists(candidate):
                return os.path.join(candidate, "test.out")
        raise RewritingError(f"cannot find test {test_name!r}")


def get_single_subdir(d):
    """Return the path of the first entry listed in directory ``d``.

    The directory is expected to hold a single entry; if it holds several,
    the first one reported by ``os.listdir`` is used.

    Raises:
        IndexError: if ``d`` contains no entry at all.
    """
    entries = os.listdir(d)
    if not entries:
        # Same exception type as plain indexing, but with a usable message.
        raise IndexError(f"no entry found in directory {d!r}")
    return os.path.join(d, entries[0])


def download_from_gaia(gaia_id):
    """Download and unpack the ``tests_log`` attachment of a GAIA event.

    The attachment is saved to a temporary ``.tgz`` file, extracted into a
    fresh temporary directory, and the temporary archive is then removed.

    Returns:
        The path of the first entry found in the extraction directory. The
        extraction directory itself is left in place for the caller to use.
    """
    response = GAIA().request("GET", f"event/attachment/{gaia_id}/tests_log")
    # delete=False so that the archive can be reopened by name once closed;
    # it is removed explicitly in the ``finally`` clause below.
    archive = tempfile.NamedTemporaryFile("wb", suffix=".tgz", delete=False)
    try:
        with archive:
            for chunk in response:
                archive.write(chunk)
        tmpdir = tempfile.mkdtemp()
        # NOTE(review): extractall trusts the member paths stored in the
        # archive; consider an extraction filter if the supported Python
        # versions provide one.
        with tarfile.open(archive.name, mode="r:gz") as tar:
            tar.extractall(tmpdir)
    finally:
        os.remove(archive.name)
    return get_single_subdir(tmpdir)


def update_sessions_from_gaia(gaia_id):
    """Synchronize the session files attached to a GAIA event into ``../..``.

    The ``sessions.tgz`` attachment is downloaded, extracted into a temporary
    directory and merged into ``../..`` with ``sync_tree`` (without deleting
    anything). Every updated path that is a regular file is reported.

    This is a best-effort operation: any failure is reported as a warning
    and does not stop the script.
    """
    try:
        response = GAIA().request("GET", f"event/attachment/{gaia_id}/sessions.tgz")
        # delete=False so that the archive can be reopened by name once
        # closed; it is removed explicitly in the ``finally`` clause below.
        archive = tempfile.NamedTemporaryFile("wb", suffix=".tgz", delete=False)
        try:
            with archive:
                for chunk in response:
                    archive.write(chunk)
            # The extracted tree is only needed while syncing, so it is
            # removed as soon as the block exits.
            with tempfile.TemporaryDirectory() as tmpdir:
                # NOTE(review): extractall trusts the member paths stored in
                # the archive; consider an extraction filter if the supported
                # Python versions provide one.
                with tarfile.open(archive.name, mode="r:gz") as tar:
                    tar.extractall(tmpdir)
                updated, deleted = sync_tree(tmpdir, "../..", delete=False)
                assert deleted == []
                for fn in updated:
                    if os.path.isfile(fn):
                        print(f"updated {fn}")
        finally:
            os.remove(archive.name)
    except Exception as e:
        print(
            "warning: failed to update sessions, "
            f"possibly GAIA event doesn't have session files: {str(e)}"
        )


def find_test_by_name(name):
    """Return the directory of test ``name``, or None if it does not exist.

    The test is searched under ``tests``, ``internal`` and ``sparklib``, in
    that order, relative to the current directory.
    """
    candidates = (
        os.path.join(root, name) for root in ("tests", "internal", "sparklib")
    )
    return next((path for path in candidates if os.path.exists(path)), None)


def is_large(name):
    """Tell whether test ``name`` is flagged as large in its ``test.yaml``.

    Returns:
        True if the yaml file has a truthy ``large`` entry, False if the
        file is missing, empty or has no such entry, and None when the test
        directory cannot be found.
    """
    testdir = find_test_by_name(name)
    if testdir is None:
        return None
    yamlfile = os.path.join(testdir, "test.yaml")
    if not os.path.exists(yamlfile):
        return False
    with open(yamlfile) as f:
        data = yaml.safe_load(f)
    # An empty file loads as None, and a document that is not a mapping
    # cannot carry a "large" key: neither marks the test as large.
    if not isinstance(data, dict):
        return False
    return bool(data.get("large"))


def read_write_yaml(testname, fn):
    """Rewrite the ``test.yaml`` of test ``testname`` through ``fn``.

    ``fn`` receives the current content of the yaml file (an empty dict when
    the file is missing or empty) and returns the new content. A non-empty
    result is written back to the file; an empty result removes the file if
    it exists.

    If the test directory cannot be found, a message is printed and nothing
    is modified.
    """
    testdir = find_test_by_name(testname)
    if testdir is None:
        print(f"Unable to find test dir to modify yaml file for {testname}")
        return
    yamlfile = os.path.join(testdir, "test.yaml")
    data = {}
    if os.path.exists(yamlfile):
        with open(yamlfile, "r") as f:
            # safe_load returns None for an empty file
            data = yaml.safe_load(f) or {}
    data = fn(data)
    if data:
        with open(yamlfile, "w") as f:
            yaml.dump(data, f)
    elif os.path.exists(yamlfile):
        # Nothing left to store: drop the file, if there is one to drop.
        os.remove(yamlfile)


def make_large(testname):
    """Flag test ``testname`` as large in its ``test.yaml``."""

    def with_large_flag(settings):
        # Work on a copy so that the caller's mapping is left untouched.
        updated = dict(settings)
        updated.update(large=True)
        return updated

    print(f"Moving {testname} to large")
    read_write_yaml(testname, with_large_flag)


def make_nonlarge(testname):
    """Remove the ``large`` flag from the ``test.yaml`` of test ``testname``.

    A yaml file without a ``large`` entry is left without one rather than
    triggering an error.
    """

    def erase_large(d):
        # Work on a copy so that the caller's mapping is left untouched.
        result = dict(d)
        result.pop("large", None)
        return result

    print(f"Removing {testname} from large")
    read_write_yaml(testname, erase_large)


def compute_time(outputdir, large_threshold=60):
    """Update the ``large`` flag of tests from the times recorded in a report.

    The report index found in ``outputdir`` is read. Each test whose recorded
    time exceeds ``large_threshold`` is flagged as large unless it already
    is, and each test at or below the threshold loses its flag if it has one.
    Tests without time information are reported and left untouched.

    Args:
        outputdir: directory holding the testsuite report index.
        large_threshold: time above which a test is considered large
            (presumably in seconds -- confirm against the report format).
    """
    report = ReportIndex.read(outputdir)
    for test_name, entry in report.entries.items():
        if entry.time is None:
            print(f"No time information for test {test_name}")
            continue
        large = is_large(test_name)
        slow = entry.time > large_threshold
        if slow and not large:
            make_large(test_name)
        elif large and not slow:
            make_nonlarge(test_name)


def main():
    """Entry point: update the ``large`` flags and rewrite test baselines.

    Results are read from a GAIA event, a GitLab pipeline, an explicit
    output directory or, by default, ``out/new``.
    """
    args = parse_options()
    # Directory extracted from GAIA, kept so the tests log is fetched once.
    gaia_dir = None
    # NOTE: --compute-time is only acted upon with --from-gaia or
    # --output-dir; it has no effect for the other result sources.
    if args.compute_time:
        if args.gaia_id:
            gaia_dir = download_from_gaia(args.gaia_id)
            print(gaia_dir)
            compute_time(gaia_dir)
        elif args.outputdir:
            compute_time(args.outputdir)
    if args.gaia_id:
        update_sessions_from_gaia(args.gaia_id)
        if gaia_dir is None:
            gaia_dir = download_from_gaia(args.gaia_id)
        outputdirs = [gaia_dir]
    elif args.gitlab_id:
        p = GitLabProjectDescription("168", ["spark2014"])
        gl = Gitlab()
        outputdirs = gl.download_test_results(p, args.gitlab_id)
        # artifact contains a coverage subdir which should be ignored
        outputdirs = [item for item in outputdirs if "coverage" not in item]
    elif args.outputdir:
        outputdirs = [args.outputdir]
    else:
        outputdirs = [os.path.join("out", "new")]
    rewriter = BaselineRewriter(ColorConfig())
    for d in outputdirs:
        summary = rewriter.rewrite(d)
        for elt in summary.new_baselines:
            print(f"updated baseline for {elt}")
        for elt in summary.deleted_baselines:
            print(f"deleted empty baseline for {elt}")
        for elt in summary.errors:
            print(f"error updating baseline for {elt}")


# Only run when executed as a script, so that importing this module does not
# parse the command line or touch any file.
if __name__ == "__main__":
    main()
