#!/usr/bin/env python3

import ctypes
import errno
import os
import sys

from posix_parity import cleanup_dir
from posix_parity import fail
from posix_parity import join
from posix_parity import mergerfs_fullpath
from posix_parity import mergerfs_mount
from posix_parity import should_compare_inode
from posix_parity import temp_dir
from posix_parity import touch


AT_FDCWD = -100  # dirfd meaning "resolve relative paths against the cwd"
AT_SYMLINK_NOFOLLOW = 0x100  # do not follow a trailing symlink
# statx(2) request-mask bits (values from <linux/stat.h>).
STATX_TYPE = 0x0001
STATX_MODE = 0x0002
STATX_NLINK = 0x0004
STATX_UID = 0x0008
STATX_GID = 0x0010
STATX_INO = 0x0100
STATX_SIZE = 0x0200
# NOTE: a subset of the kernel's STATX_BASIC_STATS (0x7ff) covering the fields
# the comparison helpers in this file read. STATX_INO is requested because
# statx(2) only guarantees fields that were asked for, and the inode
# comparison reads stx_ino.
STATX_BASIC_STATS = (
    STATX_TYPE | STATX_MODE | STATX_NLINK | STATX_UID | STATX_GID | STATX_INO | STATX_SIZE
)


class StatxTimestamp(ctypes.Structure):
    """ctypes mirror of the kernel's ``struct statx_timestamp`` (16 bytes)."""

    _fields_ = [
        ("tv_sec", ctypes.c_int64),  # __s64, offset 0x0
        ("tv_nsec", ctypes.c_uint32),  # __u32, offset 0x8
        ("__reserved", ctypes.c_int32),  # __s32, offset 0xc
    ]


class Statx(ctypes.Structure):
    """ctypes mirror of the kernel's ``struct statx``.

    Field order and widths must match the kernel layout exactly; with the
    trailing spare space the structure is 256 bytes.
    """

    _fields_ = [
        # 0x00
        ("stx_mask", ctypes.c_uint32),
        ("stx_blksize", ctypes.c_uint32),
        ("stx_attributes", ctypes.c_uint64),
        # 0x10
        ("stx_nlink", ctypes.c_uint32),
        ("stx_uid", ctypes.c_uint32),
        ("stx_gid", ctypes.c_uint32),
        ("stx_mode", ctypes.c_uint16),
        ("__spare0", ctypes.c_uint16),
        # 0x20
        ("stx_ino", ctypes.c_uint64),
        ("stx_size", ctypes.c_uint64),
        ("stx_blocks", ctypes.c_uint64),
        ("stx_attributes_mask", ctypes.c_uint64),
        # 0x40
        ("stx_atime", StatxTimestamp),
        ("stx_btime", StatxTimestamp),
        ("stx_ctime", StatxTimestamp),
        ("stx_mtime", StatxTimestamp),
        # 0x80
        ("stx_rdev_major", ctypes.c_uint32),
        ("stx_rdev_minor", ctypes.c_uint32),
        ("stx_dev_major", ctypes.c_uint32),
        ("stx_dev_minor", ctypes.c_uint32),
        # 0x90
        ("stx_mnt_id", ctypes.c_uint64),
        ("stx_dio_mem_align", ctypes.c_uint32),
        ("stx_dio_offset_align", ctypes.c_uint32),
        # 0xa0
        ("stx_subvol", ctypes.c_uint64),
        ("stx_atomic_write_unit_min", ctypes.c_uint32),
        ("stx_atomic_write_unit_max", ctypes.c_uint32),
        # 0xb0
        ("stx_atomic_write_segments_max", ctypes.c_uint32),
        ("stx_dio_read_offset_align", ctypes.c_uint32),
        # 0xb8 .. 0x100
        ("__spare3", ctypes.c_uint64 * 9),
    ]


# Look statx up among the symbols already loaded into this process.
# use_errno=True lets callers read errno afterwards via ctypes.get_errno().
libc = ctypes.CDLL(None, use_errno=True)
if getattr(libc, "statx", None) is None:
    # No statx symbol available: exit with status 0 without running any checks.
    raise SystemExit(0)

# int statx(int dirfd, const char *path, int flags, unsigned int mask, struct statx *buf)
libc.statx.argtypes = [
    ctypes.c_int,
    ctypes.c_char_p,
    ctypes.c_int,
    ctypes.c_uint,
    ctypes.POINTER(Statx),
]
libc.statx.restype = ctypes.c_int


def statx_call(path, flags=0, mask=STATX_BASIC_STATS):
    """Call statx(2) on *path* (resolved relative to the cwd) and return the Statx.

    *path* may be str, bytes or os.PathLike. It is converted with os.fsencode,
    so str names carrying surrogate-escaped (non-UTF-8) bytes are passed
    through unchanged instead of raising UnicodeEncodeError.

    Raises OSError with the errno reported by the call when it returns < 0.
    """
    encoded = os.fsencode(path)
    st = Statx()
    # Reset ctypes' private errno copy so only a value set by this call is read back.
    ctypes.set_errno(0)
    rv = libc.statx(AT_FDCWD, encoded, flags, mask, ctypes.byref(st))
    err = ctypes.get_errno()
    if rv < 0:
        raise OSError(err, os.strerror(err), path)
    return st


def errno_name(err):
    """Return a printable name for *err*.

    Returns "None" for None, the symbolic name (e.g. "ENOENT") for a known
    errno, and otherwise the value converted with str().
    """
    if err is not None:
        try:
            return errno.errorcode[err]
        except KeyError:
            return str(err)
    return "None"


def statx_summary(st):
    """Render the interesting fields of a statx result as one brace-wrapped line."""
    mode = st.stx_mode
    parts = [
        f"mask=0x{st.stx_mask:x}",
        f"mode={mode:o}",
        f"type=0x{mode & 0xF000:x}",  # file-type bits
        f"perm=0o{mode & 0x0FFF:o}",  # permission + setuid/setgid/sticky bits
        f"nlink={st.stx_nlink}",
        f"uid={st.stx_uid}",
        f"gid={st.stx_gid}",
        f"size={st.stx_size}",
        f"ino={st.stx_ino}",
        f"blocks={st.stx_blocks}",
        f"blksize={st.stx_blksize}",
    ]
    return "{" + ",".join(parts) + "}"


def call_pair(name, merge_call, native_call):
    """Run the mergerfs call, then the native call, and compare their outcomes.

    Only OSError is caught; any other exception propagates.

    Returns a 3-tuple ``(merge_value, native_value, error_message)``:
      * both succeeded            -> (merge_value, native_value, None)
      * both raised, same errno   -> (None, None, None)
      * errnos differ             -> (None, None, "<name>: errno mismatch ...")
    """

    def attempt(fn):
        # Returns (value, None) on success or (None, errno) on OSError.
        try:
            return fn(), None
        except OSError as exc:
            return None, exc.errno

    merge_value, merge_errno = attempt(merge_call)
    native_value, native_errno = attempt(native_call)

    if merge_errno != native_errno:
        return None, None, (
            f"{name}: errno mismatch\n"
            f"  mergerfs errno: {merge_errno} ({errno_name(merge_errno)})\n"
            f"  native errno:   {native_errno} ({errno_name(native_errno)})"
        )
    if merge_errno is None:
        return merge_value, native_value, None
    return None, None, None


def cmp_statx_basic(a, b):
    """Return True when two statx results agree on the compared fields.

    The fields are compared in this order, stopping at the first difference:
    file-type bits, permission bits, link count, uid, gid and size.
    """
    projections = (
        lambda st: st.stx_mode & 0xF000,  # file-type bits
        lambda st: st.stx_mode & 0x0FFF,  # permission + setuid/setgid/sticky bits
        lambda st: st.stx_nlink,
        lambda st: st.stx_uid,
        lambda st: st.stx_gid,
        lambda st: st.stx_size,
    )
    return all(project(a) == project(b) for project in projections)


def cmp_statx_basic_with_inode(a, b):
    """Return True when cmp_statx_basic(a, b) holds and the inode numbers also match."""
    # NOTE(review): statx(2) only guarantees stx_ino when STATX_INO is part of
    # the requested mask; confirm the mask used to produce a/b includes it.
    if not cmp_statx_basic(a, b):
        return False
    return a.stx_ino == b.stx_ino


def expect_same_errno(name, mcall, ncall):
    """Run *mcall* (mergerfs), then *ncall* (native), and compare their errnos.

    Return values are discarded. Returns None when both calls succeed or both
    raise OSError with the same errno. Otherwise returns the errno-mismatch
    message. Exceptions other than OSError propagate.
    """
    # Delegate to call_pair so the errno comparison and the mismatch message
    # live in one place instead of being duplicated here.
    _, _, err = call_pair(name, mcall, ncall)
    return err


def _mismatch_message(title, merge_path, native_path, mst, nst):
    """Build the failure report for a statx result that differs between the two paths."""
    return (
        f"{title}\n"
        f"  mergerfs path: {merge_path}\n"
        f"  native path:   {native_path}\n"
        f"  mergerfs statx: {statx_summary(mst)}\n"
        f"  native statx:   {statx_summary(nst)}"
    )


def _resolve_native_link(merge_link, native_file):
    """Return the branch-side path of *merge_link*, or None if it cannot be located.

    Tries mergerfs_fullpath first. If that path is not a symlink, falls back
    to a "link" entry next to *native_file*.
    """
    try:
        native_link = mergerfs_fullpath(merge_link)
    except OSError:
        return None
    if os.path.islink(native_link):
        return native_link
    # The reported path is not a symlink; assume the link was created in the
    # same branch directory as the regular file.
    fallback = os.path.join(os.path.dirname(native_file), "link")
    return fallback if os.path.islink(fallback) else None


def main():
    """Check that statx(2) through mergerfs matches statx(2) on the branch.

    Exercises a regular file, a symlink with AT_SYMLINK_NOFOLLOW, a missing
    entry, and a path that goes through a non-directory.

    Returns:
      * 0 on success, or when the branch-side paths cannot be resolved;
      * the result of fail(...) on the first mismatch;
      * 77 if a RuntimeError escapes (its message is printed).
    """
    try:
        with mergerfs_mount() as (mount, _):
            stcmp = cmp_statx_basic_with_inode if should_compare_inode(mount) else cmp_statx_basic

            merge_base = temp_dir(mount)

            try:
                merge_file = join(merge_base, "file")
                merge_link = join(merge_base, "link")

                touch(merge_file, b"hello")

                try:
                    native_file = mergerfs_fullpath(merge_file)
                except OSError:
                    return 0

                try:
                    os.unlink(merge_link)
                except FileNotFoundError:
                    pass
                os.symlink("file", merge_link)

                native_link = _resolve_native_link(merge_link, native_file)
                if native_link is None:
                    return 0

                native_missing = os.path.join(os.path.dirname(native_file), "missing")

                # call_pair returns (None, None, None) when both sides fail with
                # the same errno. That is parity, so only compare actual values.
                mst, nst, err = call_pair(
                    "statx regular",
                    lambda: statx_call(merge_file),
                    lambda: statx_call(native_file),
                )
                if err:
                    return fail(err)
                if mst is not None and not stcmp(mst, nst):
                    return fail(
                        _mismatch_message(
                            "statx basic mismatch", merge_file, native_file, mst, nst
                        )
                    )

                mst, nst, err = call_pair(
                    "statx symlink nofollow",
                    lambda: statx_call(merge_link, flags=AT_SYMLINK_NOFOLLOW),
                    lambda: statx_call(native_link, flags=AT_SYMLINK_NOFOLLOW),
                )
                if err:
                    return fail(err)
                if mst is not None and (mst.stx_mode & 0xF000) != (nst.stx_mode & 0xF000):
                    return fail(
                        _mismatch_message(
                            "statx symlink type mismatch", merge_link, native_link, mst, nst
                        )
                    )

                err = expect_same_errno(
                    "statx ENOENT",
                    lambda: statx_call(join(merge_base, "missing")),
                    lambda: statx_call(native_missing),
                )
                if err:
                    return fail(err)

                err = expect_same_errno(
                    "statx ENOTDIR",
                    lambda: statx_call(join(merge_file, "child")),
                    lambda: statx_call(join(native_file, "child")),
                )
                if err:
                    return fail(err)

                return 0
            finally:
                cleanup_dir(merge_base)
    except RuntimeError as exc:
        print(str(exc), end="")
        return 77


if __name__ == "__main__":
    # Propagate main()'s integer result as the process exit status.
    sys.exit(main())
