#!/usr/bin/env python3

import os
import stat
import sys
import tempfile

from posix_parity import cleanup_dir
from posix_parity import cleanup_paths
from posix_parity import compare_calls
from posix_parity import fail
from posix_parity import join
from posix_parity import mergerfs_mount
from posix_parity import should_compare_inode
from posix_parity import temp_dir
from posix_parity import touch


def st_cmp(lhs, rhs, compare_ino=False):
    """Return True when two stat results agree on the fields this test checks.

    File type, permission bits and hard-link count are always compared.
    Size is compared only when neither side is a directory (presumably
    because directory sizes are filesystem-specific). Inode numbers are
    compared only when *compare_ino* is true.
    """
    lhs_mode, rhs_mode = lhs.st_mode, rhs.st_mode

    if stat.S_IFMT(lhs_mode) != stat.S_IFMT(rhs_mode):
        return False
    if stat.S_IMODE(lhs_mode) != stat.S_IMODE(rhs_mode):
        return False
    if lhs.st_nlink != rhs.st_nlink:
        return False

    involves_dir = stat.S_ISDIR(lhs_mode) or stat.S_ISDIR(rhs_mode)
    if not involves_dir and lhs.st_size != rhs.st_size:
        return False

    if compare_ino and lhs.st_ino != rhs.st_ino:
        return False
    return True


def _fstat_compare(merge_path, native_path, cmp):
    """Compare ``os.fstat`` on freshly opened read-only fds for both paths.

    Both descriptors are closed on every exit path, including when opening
    *native_path* raises after *merge_path* was already opened.

    Returns whatever ``compare_calls`` returns for the "fstat open fd" check.
    """
    mfd = os.open(merge_path, os.O_RDONLY)
    try:
        nfd = os.open(native_path, os.O_RDONLY)
        try:
            return compare_calls(
                "fstat open fd",
                lambda: os.fstat(mfd),
                lambda: os.fstat(nfd),
                cmp,
            )
        finally:
            os.close(nfd)
    finally:
        os.close(mfd)


def main():
    """Check that lstat/fstat on a mergerfs mount match a native directory.

    Creates a regular file, a directory and a non-directory entry under a
    fresh temp dir on the mount, and mirrors them in a native temporary
    directory. Successful lstat/fstat results are compared with ``st_cmp``.
    The ENOENT/ENOTDIR lstat cases are compared via ``compare_calls`` with
    no comparator.

    Returns:
        0 when every check passes; ``fail(err)`` for the first failing check;
        77 after printing the message if a RuntimeError escapes (presumably
        the harness's "skipped" exit code, e.g. when the mount cannot be set
        up — confirm against posix_parity).
    """
    try:
        with mergerfs_mount() as (mount, _):
            compare_ino = should_compare_inode(mount)

            def same_stat(lhs, rhs):
                # Apply the mount-specific inode policy to every stat check.
                return st_cmp(lhs, rhs, compare_ino)

            with tempfile.TemporaryDirectory() as native:
                merge_base = temp_dir(mount)
                try:
                    native_base = join(native, os.path.basename(merge_base))
                    os.makedirs(native_base, exist_ok=True)

                    merge_file = join(merge_base, "file")
                    native_file = join(native_base, "file")
                    merge_dir = join(merge_base, "dir")
                    native_dir = join(native_base, "dir")
                    merge_missing = join(merge_base, "missing")
                    native_missing = join(native_base, "missing")
                    merge_notdir = join(merge_base, "notdir")
                    native_notdir = join(native_base, "notdir")

                    # Presumably clears leftovers on the mount side before creating fixtures.
                    cleanup_paths([merge_file, merge_dir, merge_notdir])

                    touch(merge_file, b"abc", 0o640)
                    touch(native_file, b"abc", 0o640)
                    os.makedirs(merge_dir, exist_ok=True)
                    os.makedirs(native_dir, exist_ok=True)
                    touch(merge_notdir, b"x", 0o644)
                    touch(native_notdir, b"x", 0o644)

                    # Each thunk runs one check and returns compare_calls' result;
                    # they run in order and the first truthy result fails the test.
                    checks = (
                        lambda: compare_calls(
                            "lstat regular",
                            lambda: os.lstat(merge_file),
                            lambda: os.lstat(native_file),
                            same_stat,
                        ),
                        lambda: compare_calls(
                            "lstat directory",
                            lambda: os.lstat(merge_dir),
                            lambda: os.lstat(native_dir),
                            same_stat,
                        ),
                        lambda: _fstat_compare(merge_file, native_file, same_stat),
                        # Error paths: no comparator is passed, so compare_calls'
                        # default handling applies (presumably errno comparison).
                        lambda: compare_calls(
                            "lstat ENOENT",
                            lambda: os.lstat(merge_missing),
                            lambda: os.lstat(native_missing),
                        ),
                        lambda: compare_calls(
                            "lstat ENOTDIR",
                            lambda: os.lstat(join(merge_notdir, "child")),
                            lambda: os.lstat(join(native_notdir, "child")),
                        ),
                    )
                    for check in checks:
                        err = check()
                        if err:
                            return fail(err)

                    return 0
                finally:
                    cleanup_dir(merge_base)
    except RuntimeError as exc:
        print(str(exc), end="")
        return 77


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
