#!/usr/bin/env python3

import errno
import os
import sys
import tempfile

from posix_parity import cleanup_dir
from posix_parity import fail
from posix_parity import join
from posix_parity import mergerfs_mount
from posix_parity import temp_dir
from posix_parity import touch


def errno_name(err):
    """Render an errno value as its symbolic name for diagnostics.

    ``None`` (no errno recorded) is rendered as the string ``"None"``; a code
    that ``errno.errorcode`` does not know is rendered as its decimal string.
    """
    known = errno.errorcode
    return "None" if err is None else known.get(err, str(err))


def invoke(callable_):
    """Call *callable_* and capture its outcome as ``(ok, value, errno)``.

    Returns ``(True, result, 0)`` when the call returns normally, or
    ``(False, None, exc.errno)`` when it raises ``OSError``. Any other
    exception propagates to the caller.
    """
    try:
        result = callable_()
    except OSError as exc:
        return False, None, exc.errno
    return True, result, 0


def compare(name, m_call, n_call, value_cmp=None):
    """Run one operation via mergerfs and natively, and report any divergence.

    Args:
        name: Label prefixed to the returned message.
        m_call: Zero-argument callable performing the operation via mergerfs.
        n_call: Zero-argument callable performing the same operation natively.
        value_cmp: Optional ``(mergerfs_value, native_value) -> bool`` predicate;
            it is consulted only when both calls succeed.

    Returns:
        ``None`` when the outcomes agree, otherwise a one-line description of
        the first difference found (success flag, then errno, then value).
    """
    # The mergerfs side is always executed before the native side.
    m_ok, m_val, m_errno = invoke(m_call)
    n_ok, n_val, n_errno = invoke(n_call)

    m_desc = f"{m_errno}:{errno_name(m_errno)}"
    n_desc = f"{n_errno}:{errno_name(n_errno)}"

    if m_ok != n_ok:
        return (
            f"{name}: success mismatch mergerfs={m_ok} native={n_ok} "
            f"(mergerfs_errno={m_desc} native_errno={n_desc})"
        )
    if m_errno != n_errno:
        return f"{name}: errno mismatch mergerfs={m_desc} native={n_desc}"

    values_differ = m_ok and value_cmp is not None and not value_cmp(m_val, n_val)
    if values_differ:
        return f"{name}: value mismatch mergerfs={m_val!r} native={n_val!r}"
    return None


def getxattr_cmp(path_m, path_n, xname):
    """Check that reading xattr *xname* behaves identically on both paths.

    Returns ``compare``'s mismatch message (success, errno or value), or
    ``None`` on parity. The label uses the basename of the mergerfs path.
    """
    label = f"getxattr {xname} {os.path.basename(path_m)}"

    def reader(path):
        return lambda: os.getxattr(path, xname)

    def same_value(m_value, n_value):
        return m_value == n_value

    return compare(label, reader(path_m), reader(path_n), same_value)


def list_has_cmp(path_m, path_n, xname):
    """Check that *xname* is listed (or not) consistently on both paths.

    Only membership of *xname* in each ``listxattr`` result is compared, so
    unrelated attributes present on one side do not cause a mismatch.
    Returns ``compare``'s mismatch message, or ``None`` on parity.
    """
    label = f"listxattr has {xname} {os.path.basename(path_m)}"

    def lister(path):
        return lambda: os.listxattr(path)

    def same_membership(m_names, n_names):
        return (xname in m_names) == (xname in n_names)

    return compare(label, lister(path_m), lister(path_n), same_membership)


def lsetxattr(path, name, value):
    """Set xattr *name* to *value* on *path* itself, without following a symlink.

    Uses ``os.lsetxattr`` when the interpreter provides it, otherwise
    ``os.setxattr(..., follow_symlinks=False)``.
    """
    if not hasattr(os, "lsetxattr"):
        return os.setxattr(path, name, value, follow_symlinks=False)
    return os.lsetxattr(path, name, value)


def lgetxattr(path, name):
    """Read xattr *name* from *path* itself, without following a symlink.

    Uses ``os.lgetxattr`` when the interpreter provides it, otherwise
    ``os.getxattr(..., follow_symlinks=False)``.
    """
    if not hasattr(os, "lgetxattr"):
        return os.getxattr(path, name, follow_symlinks=False)
    return os.lgetxattr(path, name)


def llistxattr(path):
    """List the xattr names of *path* itself, without following a symlink.

    Uses ``os.llistxattr`` when the interpreter provides it, otherwise
    ``os.listxattr(..., follow_symlinks=False)``.
    """
    if not hasattr(os, "llistxattr"):
        return os.listxattr(path, follow_symlinks=False)
    return os.llistxattr(path)


def lremovexattr(path, name):
    """Remove xattr *name* from *path* itself, without following a symlink.

    Uses ``os.lremovexattr`` when the interpreter provides it, otherwise
    ``os.removexattr(..., follow_symlinks=False)``.
    """
    if not hasattr(os, "lremovexattr"):
        return os.removexattr(path, name, follow_symlinks=False)
    return os.lremovexattr(path, name)


def _unlink_if_present(path):
    """Remove *path*; a path that is already absent is not an error."""
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass


def _build_tree(root):
    """Create the objects under test inside *root*.

    Creates ``file`` (written via ``touch`` with ``b"x"``), a ``dir``
    directory, and a relative symlink ``link -> file``. Any stale ``link``
    is removed before the symlink is created.

    Returns:
        ``(file_path, dir_path, link_path)``.
    """
    file_path = join(root, "file")
    dir_path = join(root, "dir")
    link_path = join(root, "link")

    touch(file_path, b"x")
    os.makedirs(dir_path, exist_ok=True)
    _unlink_if_present(link_path)
    os.symlink("file", link_path)
    return file_path, dir_path, link_path


def _on_both(op, path_m, path_n):
    """Return the ``(mergerfs, native)`` zero-argument callables ``compare`` takes."""
    return (lambda: op(path_m)), (lambda: op(path_n))


def _object_checks(obj_name, path_m, path_n, xname):
    """Yield xattr parity results for one (mergerfs, native) object pair.

    Each yielded value is ``None`` on parity or a mismatch message. The checks
    run lazily and in order, and later checks depend on the attribute state
    left by earlier ones (set, overwrite, then remove twice). A consumer that
    stops at the first mismatch therefore stops executing operations there too.
    """
    yield compare(
        f"setxattr default {obj_name}",
        *_on_both(lambda p: os.setxattr(p, xname, b"v1"), path_m, path_n),
    )
    yield getxattr_cmp(path_m, path_n, xname)
    yield list_has_cmp(path_m, path_n, xname)

    if hasattr(os, "XATTR_CREATE"):
        # xname already exists at this point, so this exercises the
        # create-only error path.
        yield compare(
            f"setxattr CREATE existing {obj_name}",
            *_on_both(
                lambda p: os.setxattr(p, xname, b"v2", os.XATTR_CREATE),
                path_m,
                path_n,
            ),
        )

    if hasattr(os, "XATTR_REPLACE"):
        yield compare(
            f"setxattr REPLACE existing {obj_name}",
            *_on_both(
                lambda p: os.setxattr(p, xname, b"v3", os.XATTR_REPLACE),
                path_m,
                path_n,
            ),
        )
        yield getxattr_cmp(path_m, path_n, xname)

        # A name that was never set exercises the replace-only error path.
        missing_name = f"{xname}.missing"
        yield compare(
            f"setxattr REPLACE missing {obj_name}",
            *_on_both(
                lambda p: os.setxattr(p, missing_name, b"v", os.XATTR_REPLACE),
                path_m,
                path_n,
            ),
        )

    yield compare(
        f"removexattr existing {obj_name}",
        *_on_both(lambda p: os.removexattr(p, xname), path_m, path_n),
    )
    # Same name again: it was just removed, so this exercises the
    # missing-attribute error path.
    yield compare(
        f"removexattr missing {obj_name}",
        *_on_both(lambda p: os.removexattr(p, xname), path_m, path_n),
    )


def _symlink_checks(link_m, link_n, lxname):
    """Yield parity results for the no-follow (``l*xattr``) calls on a symlink pair."""
    yield compare(
        "lsetxattr symlink default",
        *_on_both(lambda p: lsetxattr(p, lxname, b"lv1"), link_m, link_n),
    )
    yield compare(
        "lgetxattr symlink",
        *_on_both(lambda p: lgetxattr(p, lxname), link_m, link_n),
        value_cmp=lambda a, b: a == b,
    )
    yield compare(
        "llistxattr has symlink key",
        *_on_both(llistxattr, link_m, link_n),
        value_cmp=lambda a, b: (lxname in a) == (lxname in b),
    )
    yield compare(
        "lremovexattr symlink",
        *_on_both(lambda p: lremovexattr(p, lxname), link_m, link_n),
    )


def _run_matrix(root_m, root_n):
    """Build identical trees under both roots and run every parity check.

    Returns:
        The first mismatch message, or ``None`` when every check agrees.
    """
    file_m, dir_m, link_m = _build_tree(root_m)
    file_n, dir_n, link_n = _build_tree(root_n)
    xbase = "user.xattr_matrix"

    def checks():
        pairs = (("file", file_m, file_n), ("dir", dir_m, dir_n))
        for obj_name, path_m, path_n in pairs:
            yield from _object_checks(obj_name, path_m, path_n, f"{xbase}.{obj_name}")
        yield from _symlink_checks(link_m, link_n, f"{xbase}.link")

    # filter() and the generators are lazy, so no operation runs after the
    # first mismatch is found.
    return next(filter(None, checks()), None)


def main():
    """Compare xattr behaviour on a mergerfs mount against a native directory.

    A scratch directory is created on the mount with ``temp_dir`` and mirrored
    by an identically named directory under a fresh ``TemporaryDirectory``.
    The mount scratch directory is always cleaned up with ``cleanup_dir``.

    Returns:
        0 when every check agrees, the result of ``fail(msg)`` for the first
        mismatch, or 77 when a ``RuntimeError`` escapes. In that last case the
        exception message is printed without a trailing newline; 77 is
        presumably the harness's "skipped" status (confirm in posix_parity).
    """
    try:
        with mergerfs_mount() as (mount, _):
            with tempfile.TemporaryDirectory() as native:
                merge_base = temp_dir(mount)
                try:
                    native_base = join(native, os.path.basename(merge_base))
                    os.makedirs(native_base, exist_ok=True)

                    err = _run_matrix(merge_base, native_base)
                    if err:
                        return fail(err)
                    return 0
                finally:
                    cleanup_dir(merge_base)
    except RuntimeError as exc:
        print(str(exc), end="")
        return 77


if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    sys.exit(main())
