#!/usr/bin/env python3

import errno
import os
import sys
import tempfile

from posix_parity import cleanup_dir
from posix_parity import fail
from posix_parity import mergerfs_mount
from posix_parity import join
from posix_parity import temp_dir


def open_tmpfile(dirpath):
    """Open an unnamed temporary file inside *dirpath* and return its fd.

    The descriptor is requested read/write with ``O_TMPFILE`` and a
    creation mode of ``0o600``. Any ``OSError`` raised by ``os.open``
    propagates to the caller.
    """
    flags = os.O_RDWR | os.O_TMPFILE
    return os.open(dirpath, flags, mode=0o600)


def errno_of(call):
    """Run *call* and report the errno it produced, or 0 on success.

    *call* is a zero-argument callable expected to return a file
    descriptor, which is closed immediately. An ``OSError`` raised by
    either the call or the close is converted to its ``errno`` value;
    any other exception type propagates unchanged.
    """
    try:
        os.close(call())
    except OSError as err:
        return err.errno
    return 0


def _compare_tmpfile_writes(merge_dir, native_dir):
    """Open an O_TMPFILE file in each directory and compare write counts.

    Returns 0 when both writes report the same byte count, otherwise the
    value returned by ``fail``. Both descriptors are always closed, even
    if closing one of them raises.
    """
    try:
        mfd = open_tmpfile(merge_dir)
    except OSError as exc:
        # The probe succeeded a moment ago; report instead of crashing.
        return fail(f"O_TMPFILE: mergerfs open failed after probe succeeded errno={exc.errno}")
    try:
        try:
            nfd = open_tmpfile(native_dir)
        except OSError:
            return fail("O_TMPFILE: native open failed after merge succeeded")
        try:
            mw = os.write(mfd, b"tmpfile-data")
            nw = os.write(nfd, b"tmpfile-data")
            if mw != nw:
                return fail(f"O_TMPFILE write count mismatch mergerfs={mw} native={nw}")
            return 0
        finally:
            os.close(nfd)
    finally:
        # Nested finally blocks: a failing close of one fd cannot leak the other.
        os.close(mfd)


def main():
    """Compare O_TMPFILE behaviour on a mergerfs mount with a native directory.

    Returns:
        0 when ``os`` has no ``O_TMPFILE``, when the mergerfs probe reports
        ``EOPNOTSUPP``, when both sides agree on an "unsupported" errno
        (``EISDIR``, ``EINVAL``, ``ENOSYS``), or when both sides open and
        write successfully with equal byte counts.
        The value of ``fail(...)`` on any mismatch or unexpected error.
        77 when a ``RuntimeError`` is raised inside the mount context
        (presumably the harness "skipped" status -- confirm against the
        test runner).
    """
    if not hasattr(os, "O_TMPFILE"):
        return 0

    try:
        with mergerfs_mount() as (mount, _), tempfile.TemporaryDirectory() as native:
            merge_dir = temp_dir(mount)
            try:
                # Mirror the mergerfs directory name inside the native temp dir.
                native_dir = join(native, os.path.basename(merge_dir))
                os.makedirs(native_dir, exist_ok=True)

                m_err = errno_of(lambda: open_tmpfile(merge_dir))
                n_err = errno_of(lambda: open_tmpfile(native_dir))

                # mergerfs lacking O_TMPFILE support is accepted regardless
                # of what the native filesystem reports.
                if m_err == errno.EOPNOTSUPP:
                    return 0
                if m_err != n_err:
                    return fail(f"O_TMPFILE support mismatch mergerfs_errno={m_err} native_errno={n_err}")

                # Both sides agree from here on (m_err == n_err).
                if m_err in (errno.EISDIR, errno.EINVAL, errno.ENOSYS):
                    return 0
                if m_err != 0:
                    # Same unexpected errno on both sides: opening again
                    # would only raise, so report it explicitly.
                    return fail(f"O_TMPFILE probe failed on both sides errno={m_err}")

                return _compare_tmpfile_writes(merge_dir, native_dir)
            finally:
                cleanup_dir(merge_dir)
    except RuntimeError as exc:
        print(str(exc), end="")
        return 77


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
