#!/usr/bin/env python3

import ctypes
import errno
import os
import sys
import tempfile

from posix_parity import cleanup_dir
from posix_parity import join
from posix_parity import mergerfs_mount
from posix_parity import temp_dir


# Call readlink(3) directly instead of os.readlink so that the exact return
# value, errno and (possibly truncated) buffer contents for a caller-chosen
# bufsiz are all observable.
libc = ctypes.CDLL(None, use_errno=True)
libc.readlink.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_size_t]
libc.readlink.restype = ctypes.c_ssize_t


def readlink_raw(path, bufsiz):
    """Invoke libc readlink() on *path* with a zero-filled buffer of *bufsiz* bytes.

    *path* may be str, bytes or os.PathLike; it is encoded with os.fsencode,
    the same encoding os.symlink/open use, so any path this script created
    can be passed back here.

    Returns a tuple ``(rv, err, buf)``:
      - rv:  the raw readlink() return value (byte count, or -1 on failure),
      - err: errno observed after the call (reset to 0 beforehand, so 0 on
             success),
      - buf: the entire buffer as bytes, including untouched trailing zero
             bytes, so callers can detect writes past ``rv``.
    """
    raw_path = os.fsencode(path)
    buf = ctypes.create_string_buffer(bufsiz)
    # Clear errno first so a successful call reports 0 rather than a stale value.
    ctypes.set_errno(0)
    rv = libc.readlink(raw_path, buf, bufsiz)
    err = ctypes.get_errno()

    return rv, err, bytes(buf)


def compare_case(name,
                 merge_path,
                 native_path,
                 bufsiz,
                 expect_errno=None):
    """Run readlink on a mergerfs path and its native twin and compare results.

    Both paths are read with the same *bufsiz* (mergerfs first, then native).
    The results must agree on return value and errno. On success they must
    also agree on the full buffer contents. When *expect_errno* is given, the
    native errno must equal it.

    Returns None if every check passes, otherwise a single message prefixed
    with *name* describing the first failed check.
    """
    merged_rv, merged_errno, merged_buf = readlink_raw(merge_path, bufsiz)
    native_rv, native_errno, native_buf = readlink_raw(native_path, bufsiz)

    # Checks are listed in priority order; the first failing one is reported.
    checks = (
        (merged_rv != native_rv,
         f"return mismatch mergerfs={merged_rv} native={native_rv}"),
        (merged_errno != native_errno,
         f"errno mismatch mergerfs={merged_errno} native={native_errno}"),
        (merged_rv >= 0 and merged_buf != native_buf,
         f"buffer mismatch mergerfs={merged_buf!r} native={native_buf!r}"),
        (expect_errno is not None and native_errno != expect_errno,
         f"expected errno={expect_errno}, got errno={native_errno}"),
    )

    for failed, detail in checks:
        if failed:
            return f"{name}: {detail}"

    return None


def _fixture_paths(merge_base, native_base):
    """Return the mergerfs/native fixture paths, keyed by "<side>_<role>"."""
    paths = {
        "merge_link": join(merge_base, "readlink-semantics-link"),
        "native_link": join(native_base, "readlink-semantics-link"),
        "merge_regular": join(merge_base, "readlink-semantics-regular"),
        "native_regular": join(native_base, "readlink-semantics-regular"),
        "merge_notdir": join(merge_base, "readlink-semantics-notdir"),
        "native_notdir": join(native_base, "readlink-semantics-notdir"),
        "merge_loop": join(merge_base, "readlink-semantics-loop"),
        "native_loop": join(native_base, "readlink-semantics-loop"),
        "merge_private_dir": join(merge_base, "readlink-semantics-private"),
        "native_private_dir": join(native_base, "readlink-semantics-private"),
    }
    paths["merge_private_link"] = os.path.join(paths["merge_private_dir"], "link")
    paths["native_private_link"] = os.path.join(paths["native_private_dir"], "link")
    return paths


def _create_fixtures(paths, target):
    """Create identical fixtures on the mergerfs and native sides."""
    # Remove leftovers on the mergerfs side only; the native side lives in a
    # freshly created TemporaryDirectory.
    for p in (
        paths["merge_private_link"],
        paths["merge_link"],
        paths["merge_regular"],
        paths["merge_notdir"],
        paths["merge_loop"],
    ):
        try:
            os.unlink(p)
        except FileNotFoundError:
            pass

    try:
        os.rmdir(paths["merge_private_dir"])
    except OSError:
        # Missing or otherwise not removable; makedirs(exist_ok=True) below
        # copes with an existing directory.
        pass

    os.symlink(target, paths["merge_link"])
    os.symlink(target, paths["native_link"])

    with open(paths["merge_regular"], "w", encoding="ascii"):
        pass
    with open(paths["native_regular"], "w", encoding="ascii"):
        pass

    with open(paths["merge_notdir"], "w", encoding="ascii"):
        pass
    with open(paths["native_notdir"], "w", encoding="ascii"):
        pass

    # Self-referencing symlinks: resolving "<loop>/child" must hit ELOOP.
    os.symlink("readlink-semantics-loop", paths["merge_loop"])
    os.symlink("readlink-semantics-loop", paths["native_loop"])

    os.makedirs(paths["merge_private_dir"], exist_ok=True)
    os.makedirs(paths["native_private_dir"], exist_ok=True)
    os.symlink(target, paths["merge_private_link"])
    os.symlink(target, paths["native_private_link"])


def _build_cases(paths, merge_base, native_base, target):
    """Return compare_case argument tuples: (name, merge, native, bufsiz, errno)."""
    # Buffer sizes around the link length exercise truncation; 0 must be EINVAL.
    cases = [
        (
            f"success/truncation bufsiz={bufsiz}",
            paths["merge_link"],
            paths["native_link"],
            bufsiz,
            errno.EINVAL if bufsiz == 0 else None,
        )
        for bufsiz in (0, 1, 2, 5, 8, len(target), len(target) + 1)
    ]

    cases.extend([
        (
            "EINVAL non-symlink",
            paths["merge_regular"],
            paths["native_regular"],
            128,
            errno.EINVAL,
        ),
        (
            "ENOENT missing path",
            join(merge_base, "readlink-semantics-missing"),
            join(native_base, "readlink-semantics-missing"),
            128,
            errno.ENOENT,
        ),
        (
            "ENOTDIR non-directory prefix",
            os.path.join(paths["merge_notdir"], "child"),
            os.path.join(paths["native_notdir"], "child"),
            128,
            errno.ENOTDIR,
        ),
        (
            "ELOOP prefix symlink loop",
            os.path.join(paths["merge_loop"], "child"),
            os.path.join(paths["native_loop"], "child"),
            128,
            errno.ELOOP,
        ),
        (
            "ENAMETOOLONG long pathname",
            join(merge_base, "a" * 8192),
            join(native_base, "a" * 8192),
            128,
            errno.ENAMETOOLONG,
        ),
    ])

    # Root bypasses the directory search-permission check, so only a
    # non-root run can demand EACCES; root still requires parity.
    cases.append((
        "EACCES unreadable path prefix",
        paths["merge_private_link"],
        paths["native_private_link"],
        128,
        errno.EACCES if os.geteuid() != 0 else None,
    ))
    return cases


def _restore_private_dirs(paths):
    """Best-effort chmod 0o700 of the private dirs before cleanup runs.

    Keys missing from *paths* (setup failed before they were computed) are
    skipped.
    """
    for key in ("merge_private_dir", "native_private_dir"):
        private_dir = paths.get(key)
        if private_dir is None:
            continue
        try:
            os.chmod(private_dir, 0o700)
        except (FileNotFoundError, PermissionError):
            pass


def main():
    """Check readlink(2) parity between a mergerfs mount and a native dir.

    Returns 0 if every case matches, 1 on the first mismatch (after printing
    a description without a trailing newline), or 77 if a RuntimeError
    escapes; the RuntimeError message is printed without a trailing newline.
    """
    try:
        with mergerfs_mount() as (mount, _):
            target = "target-abcdefghijklmnopqrstuvwxyz"

            with tempfile.TemporaryDirectory() as native_dir:
                merge_base = temp_dir(mount)
                # Bound before the try so the finally clause can always use
                # it, even if setup fails before the fixture paths exist.
                paths = {}
                try:
                    native_base = join(native_dir, os.path.basename(merge_base))
                    os.makedirs(native_base, exist_ok=True)

                    paths = _fixture_paths(merge_base, native_base)
                    _create_fixtures(paths, target)
                    cases = _build_cases(paths, merge_base, native_base, target)

                    # Drop the search (x) bit so resolving a path through
                    # these directories is denied for non-root users.
                    os.chmod(paths["merge_private_dir"], 0o600)
                    os.chmod(paths["native_private_dir"], 0o600)

                    for case in cases:
                        err = compare_case(*case)
                        if err is not None:
                            print(err, end="")
                            return 1

                    # Beyond parity: a 1-byte buffer must succeed with one byte.
                    rv, err, _ = readlink_raw(paths["merge_link"], 1)
                    if rv != 1 or err != 0:
                        print(f"bufsiz=1 expected rv=1 errno=0, got rv={rv} errno={err}", end="")
                        return 1

                    return 0
                finally:
                    _restore_private_dirs(paths)
                    cleanup_dir(merge_base)
    except RuntimeError as exc:
        print(str(exc), end="")
        return 77


if __name__ == "__main__":
    # sys.exit raises SystemExit carrying main()'s status code.
    sys.exit(main())
