#!/usr/bin/python3

import re
import os
import sys
import tarfile
import shutil
import logging
import argparse

# Root logger; its handler and level are set up by logging.basicConfig() in main()
log = logging.getLogger()
# Matches file names ending in one of the data file extensions. Files that
# match are archived as "holes" (size only, no contents).
# Every dot is escaped so that only a literal ".ext" suffix matches.
re_data = re.compile(r"(?:\.grib|\.grib1|\.grib2|\.vm2|\.odimh5|\.bufr|\.repack)$")


def filter_year_limit(value: str, limit) -> bool:
    """
    Returns True if value is allowed by the given year limit

    value is a directory name. Names that are not made only of digits are
    not considered years, and are always allowed.

    limit is either None (no limit) or a (min, max) pair, extremes included,
    where either element can be None to leave that side open.

    Years are compared numerically, so that names with a different number of
    digits than the bounds are still ordered correctly. If a bound cannot be
    parsed as an integer, the comparison falls back to comparing strings.
    """
    if not value.isdigit():
        return True
    if limit is None:
        return True
    vmin, vmax = limit
    try:
        year = int(value)
        lower = None if vmin is None else int(vmin)
        upper = None if vmax is None else int(vmax)
    except ValueError:
        # Non-numeric bound (or a digit that int() does not accept): keep the
        # plain string comparison instead of failing
        year, lower, upper = value, vmin, vmax
    if lower is not None and year < lower:
        return False
    if upper is not None and year > upper:
        return False
    return True


def do_archive(src, rootpath=None, skip_archives=False, year_limit=None):
    """
    Write a gzip-compressed tar archive (PAX format) of the directory tree at
    src to standard output.

    src is the dataset directory. If rootpath is given, src is looked up
    inside rootpath, and member names in the archive are relative to rootpath.

    If skip_archives is True, directories named ".archive" are not descended
    into.

    year_limit is None or a (min, max) pair as accepted by filter_year_limit:
    it is applied only to the directories found directly inside src.

    Files whose name matches re_data are stored without their contents: the
    member is written with size 0 and the original size is saved in a "hole"
    pax header. All other files are stored in full.
    """
    if rootpath is not None:
        src = os.path.join(rootpath, src)

    # Both a missing limit and a fully open one mean that no filtering is needed
    filter_years = year_limit is not None and year_limit != (None, None)

    with tarfile.open(mode="w:gz", format=tarfile.PAX_FORMAT, fileobj=sys.stdout.buffer) as out:
        for index, (root, dirnames, fnames, rootfd) in enumerate(os.fwalk(src)):
            # Editing dirnames in place tells os.fwalk not to descend into the
            # directories that get removed
            if skip_archives and ".archive" in dirnames:
                dirnames[:] = [d for d in dirnames if d != ".archive"]

            if rootpath is not None:
                root = os.path.relpath(root, rootpath)

            # os.fwalk is top-down by default, so the first directory it yields
            # is the dataset root. Checking the position instead of comparing
            # paths keeps working when src is not normalised (e.g. "name/")
            if filter_years and index == 0:
                dirnames[:] = [d for d in dirnames if filter_year_limit(d, year_limit)]

            for fname in fnames:
                # Open relative to the directory file descriptor given by fwalk
                hfd = os.open(fname, os.O_RDONLY, dir_fd=rootfd)
                with os.fdopen(hfd, "rb") as fd:
                    info = out.gettarinfo(
                            arcname=os.path.join(root, fname),
                            fileobj=fd)

                    if re_data.search(fname):
                        # Record only the size: with size set to 0, addfile
                        # writes the header and no file data
                        info.pax_headers["hole"] = str(info.size)
                        info.size = 0
                        log.info("h %s", info.name)
                    else:
                        log.info("a %s", info.name)

                    out.addfile(info, fd)


def do_extract(src, rootpath=None):
    """
    Extract the tar archive at src, recreating "hole" members as files of the
    recorded size.

    Members are written to the path given by their name, joined to rootpath
    if it is not None. Missing parent directories are created.

    A member with a "hole" pax header is created as a new file and extended
    with ftruncate to the size stored in the header, without writing any
    data. The file is opened with O_EXCL, so this raises FileExistsError if
    the destination already exists.

    Any other member has its contents copied to the destination, overwriting
    it if it exists. Members for which the archive provides no file contents
    (extractfile returns None) are skipped.

    Mode and modification time of each extracted file are set from the
    archive member.
    """
    with tarfile.open(src, mode="r:*") as tf:
        for info in tf:
            dest = info.name
            if rootpath is not None:
                dest = os.path.join(rootpath, dest)
            # A member with no directory component has an empty dirname, and
            # os.makedirs("") would raise FileNotFoundError
            parent = os.path.dirname(dest)
            if parent:
                os.makedirs(parent, exist_ok=True)

            hole = info.pax_headers.get("hole", None)
            if hole is not None:
                log.info("h %s", dest)
                fd = os.open(dest, os.O_WRONLY | os.O_EXCL | os.O_CREAT)
                try:
                    os.ftruncate(fd, int(hole))
                finally:
                    # Do not leak the descriptor if the header value is not a
                    # valid size or ftruncate fails
                    os.close(fd)
            else:
                log.info("a %s", dest)
                srcfd = tf.extractfile(info)
                if srcfd is None:
                    continue
                with srcfd, open(dest, "wb") as dstfd:
                    shutil.copyfileobj(srcfd, dstfd)

            os.chmod(dest, info.mode)
            os.utime(dest, times=(info.mtime, info.mtime))


def _parse_year_limit(years):
    """
    Convert the value of --years into a (min, max) pair of strings, extremes
    included, with None marking an open side.

    Accepts yyyy, yyyy:, :yyyy and yyyy:yyyy. Returns None if years is empty
    or missing.
    """
    if not years:
        return None
    if ":" not in years:
        return (years, years)
    # An empty string on either side of the colon means that side is open
    first, _, last = years.partition(":")
    return (first or None, last or None)


def main():
    """
    Command line entry point: parse arguments, configure logging to standard
    error, then extract the archive if -x is given, or create one otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Create a .tar archive with a mock version of a dataset")
    parser.add_argument("src", action="store", help="dataset to archive or archive to extract")
    parser.add_argument("-x", "--extract", action="store_true", help="extract an archive")
    parser.add_argument("-c", "--create", action="store_true", help="create archive (default)")
    parser.add_argument("-C", dest="root", action="store", help="root directory to use (default: current directory)")
    parser.add_argument("--debug", action="store_true", help="debugging output")
    parser.add_argument("-v", "--verbose", action="store_true", help="verbose output")
    parser.add_argument("--skip-archives", action="store_true", help="skip .archive contents")
    parser.add_argument("--years", action="store",
                        help="limit data outside .archive to the given years"
                             " (can be yyyy, yyyy:, :yyyy, or yyyy:yyyy, extremes included)")
    args = parser.parse_args()

    log_format = "%(asctime)-15s %(levelname)s %(message)s"
    level = logging.WARN
    if args.debug:
        level = logging.DEBUG
    elif args.verbose:
        level = logging.INFO
    # Log to stderr: when creating, the archive itself is written to stdout
    logging.basicConfig(level=level, stream=sys.stderr, format=log_format)

    if args.extract:
        do_extract(args.src, args.root)
    else:
        year_limit = _parse_year_limit(args.years)
        do_archive(args.src, args.root, skip_archives=args.skip_archives, year_limit=year_limit)


# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module
if __name__ == "__main__":
    main()
