#!/usr/bin/python3

import abc
import zipfile
import sys
import re
import xml.dom.minidom as xmd
import xml.dom as xdom
import argparse
import logging
from contextlib import ExitStack
from collections import defaultdict
from pathlib import Path
from typing import NamedTuple, Self, Generator, Any, IO, Iterable

# Logger used for this script's progress and error messages
log = logging.getLogger("wrep-importtable")

# Replacement spellings for BUFR unit names. Keys are upper case because
# normalise_bufr_unit upper-cases the unit before looking it up here
BUFR_UNIT_MAP: dict[str, str] = {"CCITT IA5": "CCITTIA5"}
# Replacement spellings for CREX unit names, looked up by normalise_crex_unit
# the same way; currently no replacement is defined
CREX_UNIT_MAP: dict[str, str] = {}


class Fail(BaseException):
    """
    Error whose message is meant to be shown to the user as-is.

    It derives from BaseException, so ``except Exception`` handlers do not
    intercept it.
    """


class EntryB(NamedTuple):
    """B table entry."""

    fxy: str
    name: str
    bufr_unit: str
    bufr_scale: int
    bufr_reference: int
    bufr_bits: int
    crex_unit: str | None
    crex_scale: int | None
    crex_chars: int | None

    def has_crex(self) -> bool:
        """Check if this entry has CREX data."""
        return not (
            self.crex_unit is None
            or self.crex_scale is None
            or self.crex_chars is None
        )

    def wreport_entry(self) -> str:
        """Make a wreport table entry line."""
        if not self.has_crex():
            return " %-6.6s %-64.64s %-24.24s %3d %12d %3d" % (
                self.fxy,
                self.name,
                self.bufr_unit,
                self.bufr_scale,
                self.bufr_reference,
                self.bufr_bits,
            )
        else:
            return " %-6.6s %-64.64s %-24.24s %3d %12d %3d %-24.24s %2d %9d" % (
                self.fxy,
                self.name,
                self.bufr_unit,
                self.bufr_scale,
                self.bufr_reference,
                self.bufr_bits,
                self.crex_unit,
                self.crex_scale,
                self.crex_chars,
            )


class EntryD(NamedTuple):
    """One D table record: a D code together with its expansion."""

    # FXY code of the sequence being defined
    fxy: str
    # FXY codes the sequence expands to, in order
    expanded: list[str]


def normalise_bufr_unit(unit):
    """
    Return the upper-cased BUFR unit, replaced by its BUFR_UNIT_MAP
    spelling when one is defined
    """
    key = unit.upper()
    if key in BUFR_UNIT_MAP:
        return BUFR_UNIT_MAP[key]
    return key


def normalise_crex_unit(unit):
    """
    Return the upper-cased CREX unit, replaced by its CREX_UNIT_MAP
    spelling when one is defined
    """
    key = unit.upper()
    if key in CREX_UNIT_MAP:
        return CREX_UNIT_MAP[key]
    return key


def read_text(el):
    """
    Join the data of every non-empty text node found directly under el.

    Text inside nested elements is not included.
    """
    text_type = xdom.Node.TEXT_NODE
    return "".join(
        child.data
        for child in el.childNodes
        if child.nodeType == text_type and child.data
    )


def read_mapping(node):
    """
    Turn a DOM element shaped like <foo>bar</foo> into the pair (foo, bar)
    """
    key = node.nodeName
    value = read_text(node)
    return key, value


def parse_xml_b(text: str) -> Generator[EntryB, None, None]:
    """Yield an EntryB for each TableB record of an XML document."""
    crex_keys = ("CREX_Unit", "CREX_Scale", "CREX_DataWidth_Char")
    document = xmd.parseString(text)
    for record in document.documentElement.childNodes:
        is_element = record.nodeType == xdom.Node.ELEMENT_NODE
        if not is_element or "TableB" not in record.nodeName:
            continue

        # Flatten the <key>value</key> children into a dict; if a key is
        # repeated, its last value wins
        data: dict[str, str] = dict(
            read_mapping(child)
            for child in record.childNodes
            if child.nodeType == xdom.Node.ELEMENT_NODE
        )

        # The element name is looked up under "ElementName_en" first, then
        # under "ElementName_E"
        if "ElementName_en" in data:
            name = data["ElementName_en"]
        else:
            name = data["ElementName_E"]

        fxy = data["FXY"]
        bufr_unit = normalise_bufr_unit(data["BUFR_Unit"])
        bufr_scale = int(data["BUFR_Scale"])
        bufr_reference = int(data["BUFR_ReferenceValue"])
        bufr_bits = int(data["BUFR_DataWidth_Bits"])

        # CREX details are all-or-nothing: one missing field drops them all
        crex_unit: str | None = None
        crex_scale: int | None = None
        crex_chars: int | None = None
        if all(key in data for key in crex_keys):
            crex_unit = normalise_crex_unit(data["CREX_Unit"])
            crex_scale = int(data["CREX_Scale"])
            crex_chars = int(data["CREX_DataWidth_Char"])

        yield EntryB(
            fxy=fxy,
            name=name,
            bufr_unit=bufr_unit,
            bufr_scale=bufr_scale,
            bufr_reference=bufr_reference,
            bufr_bits=bufr_bits,
            crex_unit=crex_unit,
            crex_scale=crex_scale,
            crex_chars=crex_chars,
        )


def parse_xml_d(text: str) -> Generator[EntryD, None, None]:
    """
    Yield an EntryD for each D code found in an XML document.

    The whole document is read before the first entry is yielded, because
    the expansion of a D code is spread over many records.
    """
    document = xmd.parseString(text)
    expansions: dict[str, list[str]] = {}
    for record in document.documentElement.childNodes:
        is_element = record.nodeType == xdom.Node.ELEMENT_NODE
        if not is_element or "TableD" not in record.nodeName:
            continue

        # Each record pairs a D code (FXY1) with a single item of its
        # expansion (FXY2), rather than listing the full expansion at once
        code: str | None = None
        item: str | None = None
        for child in record.childNodes:
            if child.nodeType != xdom.Node.ELEMENT_NODE:
                continue
            if child.nodeName == "FXY1":
                code = read_text(child).strip()
            elif child.nodeName == "FXY2":
                item = read_text(child).strip()

        # Records missing either code are ignored
        if code is None or item is None:
            continue
        expansions.setdefault(code, []).append(item)

    for code, items in expansions.items():
        yield EntryD(code, items)


class ZipSource(ExitStack):
    """
    Tables downloaded as ZIP file.
    """

    re_version = re.compile(
        r"(?:va_)?BUFR(?:CREX|4)"
        r"(?:-v?|_)(?P<v>\d+)"
        r"(?:[._](?P<sv>\d+))?"
        r"(?:[._](?P<lv>\d+))?"
    )

    def __init__(self, path: Path) -> None:
        super().__init__()
        self.path = path
        self.version: int
        self.subversion: int
        self.localversion: int
        self.zip = self.enter_context(zipfile.ZipFile(self.path, "r"))

        # Detect zip layout and table versions
        self.root: str | None
        if root := self.find_root():
            if mo := self.re_version.match(root):
                pass
            else:
                raise Fail("Cannot detect version from zipfile root path")
        else:
            if mo := self.re_version.match(path.name):
                root = None
            else:
                raise Fail(
                    "Cannot detect version from zipfile root path or filename"
                )
        self.root = root
        self.version = int(mo["v"])
        self.subversion = int(mo["sv"]) if mo["sv"] else 0
        self.localversion = int(mo["lv"]) if mo["lv"] else 0

    def find_root(self) -> str | None:
        """Get the name of the zip root directory."""
        roots: set[str] = set()
        for name in self.zip.namelist():
            if "/" in name:
                root = name.split("/")[0]
            else:
                root = name
            if not self.re_version.match(name):
                continue
            roots.add(root)
        if len(roots) > 1:
            raise Fail(
                "Found multiple BUFR* root paths in zipfile:"
                + ", ".join(sorted(roots))
            )
        if not roots:
            return None
        return next(iter(roots))

    def find_csv(self, name_regex: str) -> list[Path]:
        """Look for CSV tables inside the file."""
        if self.root is None:
            return []

        pattern = re.compile(re.escape(self.root) + f"/{name_regex}$")
        found: list[Path] = []
        for name in self.zip.namelist():
            if pattern.match(name):
                found.append(Path(name))
        return found

    def find_xml(self, name: str) -> Path | None:
        """Look for an xml file."""
        paths: list[str] = []

        path = self.root + "/" if self.root else ""
        path += "xml/" + name
        paths.append(path)

        path = self.root + "/" if self.root else ""
        path += name
        paths.append(path)

        for candidate in paths:
            try:
                self.zip.getinfo(candidate)
            except KeyError:
                pass
            else:
                return Path(candidate)
        return None

    def read(self, path: Path) -> str:
        """Read the contents of a file."""
        with self.zip.open(path.as_posix(), "r") as fd:
            return fd.read().decode()


class Importer(abc.ABC):
    """Interface for reading B and D table entries out of a ZipSource."""

    # Zip file the tables are read from
    source: ZipSource

    @abc.abstractmethod
    def read_b(self) -> Generator[EntryB, None, None]:
        """Parse B table entries."""

    @abc.abstractmethod
    def read_bufr_d(self) -> Generator[EntryD, None, None]:
        """Parse BUFR D table entries."""

    @abc.abstractmethod
    def read_crex_d(self) -> Generator[EntryD, None, None]:
        """Parse CREX D table entries."""

    @classmethod
    def from_path(cls, path: Path) -> "Importer":
        """
        Instantiate an Importer from a path.

        The zip file stays open for as long as the returned importer's
        ``source`` is not closed. If no importer can be built, the zip file
        is closed before the error propagates.

        :raises Fail: if the zip file layout or contents are inconsistent
        :raises NotImplementedError: if the zip file has no XML tables
        """
        src = ZipSource(path)
        try:
            csv_importer = CSVImporter.from_zip(src)
            xml_importer = XMLImporter.from_zip(src)
            # TODO: prioritize CSV?
            if xml_importer is not None:
                return xml_importer
            if csv_importer is not None:
                raise NotImplementedError(
                    f"{path}: only CSV tables found, and importing CSV"
                    " tables is not implemented"
                )
            raise NotImplementedError(f"{path}: no supported tables found")
        except BaseException:
            # Do not leave the archive open when no importer is returned.
            # BaseException is caught because Fail does not derive from
            # Exception.
            src.close()
            raise


class CSVImporter(Importer):
    """
    Import tables from CSV data inside a .zip file.

    Only the detection of CSV tables is implemented: the read methods all
    raise NotImplementedError.
    """

    def __init__(
        self,
        source: ZipSource,
        b_tables: list[Path],
        bufr_d_tables: list[Path],
        crex_d_tables: list[Path],
    ) -> None:
        """Store the source and the paths of the CSV tables found in it."""
        self.source = source
        self.b_tables = b_tables
        self.bufr_d_tables = bufr_d_tables
        self.crex_d_tables = crex_d_tables

    def read_b(self) -> Generator[EntryB, None, None]:
        """Parse B table entries (not implemented)."""
        raise NotImplementedError()

    def read_bufr_d(self) -> Generator[EntryD, None, None]:
        """Parse BUFR D table entries (not implemented)."""
        raise NotImplementedError()

    def read_crex_d(self) -> Generator[EntryD, None, None]:
        """Parse CREX D table entries (not implemented)."""
        raise NotImplementedError()

    @classmethod
    def from_zip(cls, source: ZipSource) -> Self | None:
        """
        Instantiate the importer from a zip source.

        :returns: None if csv tables are not present
        :raises Fail: if CSV B tables are present but D tables are missing
        """
        # The dot before the extension is escaped so that it only matches a
        # literal "."
        b_tables = source.find_csv(r"BUFRCREX_TableB_en_\d+\.csv")
        if not b_tables:
            return None
        bufr_d_tables = source.find_csv(r"BUFR_TableD_en_\d+\.csv")
        if not bufr_d_tables:
            raise Fail("zip file contains CSV B tables but no BUFR D tables")
        crex_d_tables = source.find_csv(r"CREX_TableD_en_\d+\.csv")
        if not crex_d_tables:
            raise Fail("zip file contains CSV B tables but no CREX D tables")
        return cls(source, b_tables, bufr_d_tables, crex_d_tables)


class XMLImporter(Importer):
    """Import tables from XML data inside a .zip file."""

    def __init__(
        self,
        source: ZipSource,
        b_table: Path,
        bufr_d_table: Path,
        crex_d_table: Path,
    ) -> None:
        """Store the source and the paths of the XML tables found in it."""
        self.source = source
        self.b_table = b_table
        self.bufr_d_table = bufr_d_table
        self.crex_d_table = crex_d_table

    def read_b(self) -> Generator[EntryB, None, None]:
        """Parse the XML B table of the source and yield its entries."""
        yield from parse_xml_b(self.source.read(self.b_table))

    def read_bufr_d(self) -> Generator[EntryD, None, None]:
        """Parse the XML BUFR D table of the source and yield its entries."""
        yield from parse_xml_d(self.source.read(self.bufr_d_table))

    def read_crex_d(self) -> Generator[EntryD, None, None]:
        """Parse the XML CREX D table of the source and yield its entries."""
        yield from parse_xml_d(self.source.read(self.crex_d_table))

    @classmethod
    def _from_names(
        cls,
        source: ZipSource,
        b_name: str,
        bufr_d_name: str,
        crex_d_name: str,
    ) -> Self | None:
        """
        Instantiate the importer from tables with the given file names.

        :returns: None if the B table is not present
        :raises Fail: if the B table is present but a D table is missing
        """
        b_table = source.find_xml(b_name)
        if b_table is None:
            return None
        bufr_d_table = source.find_xml(bufr_d_name)
        if bufr_d_table is None:
            raise Fail("zip file contains an XML B table but no BUFR D table")
        crex_d_table = source.find_xml(crex_d_name)
        if crex_d_table is None:
            raise Fail("zip file contains an XML B table but no CREX D table")
        return cls(source, b_table, bufr_d_table, crex_d_table)

    @classmethod
    def from_zip1(cls, source: ZipSource) -> Self | None:
        """Instantiate the importer from unversioned table file names."""
        return cls._from_names(
            source,
            "BUFRCREX_TableB_en.xml",
            "BUFR_TableD_en.xml",
            "CREX_TableD_en.xml",
        )

    @classmethod
    def from_zip2(cls, source: ZipSource) -> Self | None:
        """
        Instantiate the importer from ``va_``-prefixed, versioned file names.

        The version in the file names is the one detected for the source,
        e.g. ``va_BUFRCREX_34_0_0_TableB_en.xml`` for version 34.0.0.
        """
        version = (
            f"{source.version}_{source.subversion}_{source.localversion}"
        )
        return cls._from_names(
            source,
            f"va_BUFRCREX_{version}_TableB_en.xml",
            f"va_BUFR_{version}_TableD_en.xml",
            f"va_CREX_{version}_TableD_en.xml",
        )

    @classmethod
    def from_zip(cls, source: ZipSource) -> Self | None:
        """
        Instantiate the importer from a zip source.

        Unversioned file names are tried first, then versioned ones.

        :returns: None if XML tables are not present
        :raises Fail: if an XML B table is present but a D table is missing
        """
        if (res := cls.from_zip1(source)) is not None:
            return res
        return cls.from_zip2(source)


# class Table:
#    def __init__(self, v, sv, lv, text=None):
#        # Version
#        self.v = v
#        # Subversion
#        self.sv = sv
#        # Local version
#        self.lv = lv
#        if text is not None:
#            self.parse(text)
#
#    @classmethod
#    def from_zip(cls, source: ZipSource) -> Self | None:
#        """
#        Instantiate the importer from a zip source.
#
#        :returns: None if csv tables are not present
#        """
#        b_table = source.find_xml("BUFRCREX_TableB_en.xml")
#        if not b_table:
#            return None
#        bufr_d_table = source.find_xml("BUFR_TableD_en.xml")
#        if not bufr_d_tables:
#            raise Fail("zip file contains CSV B tables but no BUFR D tables")
#        crex_d_tables = source.find_csv(r"CREX_TableD_en_\d+.csv")
#        if not crex_d_tables:
#            raise Fail("zip file contains CSV B tables but no CREX D tables")
#        return cls(source, b_tables, bufr_d_tables, crex_d_tables)
#
#    def find_xml_crex_d(self) -> Path | None:
#        """Enumerate XML tables inside the file."""
#        return self._find_xml("CREX_TableD_en.xml")


def write_table_b(importer: Importer) -> None:
    """
    Write the BUFR and CREX B table files for wreport.

    Both files are created in the current directory and named after the
    version and subversion of the importer's source. Every entry goes to
    the BUFR file; only entries with CREX data also go to the CREX file.
    """
    source = importer.source
    bufr_name = f"B0000000000000{source.version:03d}{source.subversion:03d}.txt"
    crex_name = f"B00{source.version:02d}{source.subversion:02d}.txt"
    bufr_total = 0
    crex_total = 0
    with open(bufr_name, "wt") as bufr_out, open(crex_name, "wt") as crex_out:
        entries = list(importer.read_b())
        entries.sort(key=lambda item: item.fxy)
        for entry in entries:
            text = entry.wreport_entry() + "\n"
            bufr_out.write(text)
            bufr_total += 1
            if not entry.has_crex():
                continue
            crex_out.write(text)
            crex_total += 1
    log.info("%s: %d entries", bufr_name, bufr_total)
    log.info("%s: %d entries", crex_name, crex_total)


def write_table_d(entries: Iterable[EntryD], out: IO[str]) -> None:
    """
    Output a D table from a list of entries.

    Entries are written sorted by code. The first line of an entry holds
    its code, the length of its expansion and the first expanded code; each
    further expanded code goes on a line of its own.

    :param entries: D table entries, each with a non-empty expansion
    :param out: text stream to write to
    """
    count = 0
    for entry in sorted(entries, key=lambda e: e.fxy):
        print(
            " %-6.6s %2d %-6.6s"
            % (entry.fxy, len(entry.expanded), entry.expanded[0]),
            file=out,
        )
        for fxy in entry.expanded[1:]:
            print("           %-6.6s" % fxy, file=out)
        count += 1
    # Not every text stream has a name (io.StringIO does not)
    log.info("%s: %d entries", getattr(out, "name", "<stream>"), count)


def write_bufr_table_d(importer: Importer) -> None:
    """
    Write the BUFR D table file for wreport.

    The file is created in the current directory and named after the
    version and subversion of the importer's source.
    """
    source = importer.source
    file_name = f"D0000000000000{source.version:03d}{source.subversion:03d}.txt"
    with open(file_name, "wt") as out:
        write_table_d(importer.read_bufr_d(), out)


def write_crex_table_d(importer: Importer) -> None:
    """
    Write the CREX D table file for wreport.

    The file is created in the current directory and named after the
    version and subversion of the importer's source.
    """
    src = importer.source
    with open(f"D00{src.version:02d}{src.subversion:02d}.txt", "wt") as out:
        # The CREX file must be filled from the CREX D table, not the BUFR one
        write_table_d(importer.read_crex_d(), out)


def main():
    """
    Parse the command line, set up logging and write the table files.

    Table files are written in the current directory. Without --zipfile
    there is nothing to convert: a warning is logged and nothing is written.
    """
    parser = argparse.ArgumentParser(
        description=(
            "builds ECMWF-style BUFR/CREX table files from zipped WMO XML table"
            " information."
            " Files are written in the current directory."
            " New files can be found at "
            " https://community.wmo.int/activity-areas/wis/latest-version"
        )
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="verbose output"
    )
    parser.add_argument("--debug", action="store_true", help="debug output")
    parser.add_argument(
        "--zipfile", action="store", type=Path, help="zip file to open"
    )
    # parser.add_argument(
    #     "--btable",
    #     action="store",
    #     type=Path,
    #     help="xml file to open for B table",
    # )
    # parser.add_argument(
    #     "--dtable",
    #     action="store",
    #     type=Path,
    #     help="xml file to open for D table",
    # )
    args = parser.parse_args()

    log_format = "%(asctime)-15s %(levelname)s %(message)s"
    level = logging.WARN
    if args.debug:
        level = logging.DEBUG
    elif args.verbose:
        level = logging.INFO
    logging.basicConfig(level=level, stream=sys.stderr, format=log_format)

    if args.zipfile:
        importer = Importer.from_path(args.zipfile)
        # The importer's source keeps the zip file open: close it once all
        # tables have been written, or if writing fails
        with importer.source:
            write_table_b(importer)
            write_bufr_table_d(importer)
            write_crex_table_d(importer)
    # elif args.btable:
    #     with open(args.btable, "rt") as fd:
    #         table = TableB(255, 255, 255, text=fd.read())
    #         table.write_files()
    # elif args.dtable:
    #     with open(args.dtable, "rt") as fd:
    #         table = TableDBUFR(255, 255, 255, text=fd.read())
    #         table.write_files()
    else:
        # Do not claim success when no table file has been written
        log.warning("No input given (see --zipfile): nothing to do")
        return

    log.info(
        "Done. You can copy the table files to the table directory"
        " (see wrep --info for its location)."
    )


if __name__ == "__main__":
    try:
        main()
    except Fail as e:
        # Expected failures: show only the message
        print(e, file=sys.stderr)
        sys.exit(1)
    except Exception:
        # Unexpected failures: log the traceback, and still report the
        # failure through the exit status
        log.exception("uncaught exception")
        sys.exit(1)
