#!/usr/bin/env python3

import tempfile
import ctypes
import mmap
import os
import resource
import sys

from posix_parity import mergerfs_mount


# ctypes handle on the C library already loaded into this process (CDLL(None)).
# use_errno=True makes ctypes save errno around each foreign call so that
# ctypes.get_errno() can report why read(2) failed.
libc = ctypes.CDLL(None, use_errno=True)
# Prototype for: ssize_t read(int fd, void *buf, size_t count)
# A signed return type is needed so that the -1 error result is observable.
libc.read.restype = ctypes.c_ssize_t
libc.read.argtypes = [ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t]


def aligned_read(fd, size):
    """Read up to *size* bytes from *fd* into a page-aligned buffer.

    The data is read with libc ``read(2)`` (through the module-level ``libc``
    handle) straight into an anonymous ``mmap``. ``mmap`` returns
    page-aligned memory, which is the kind of buffer O_DIRECT I/O needs.
    ``os.read`` offers no control over buffer alignment. The caller is
    responsible for *size* and the current file offset also meeting the
    filesystem's O_DIRECT rules.

    Returns the bytes actually read. This may be fewer than *size*, and is
    ``b""`` at EOF.
    Raises OSError (with errno set from read(2)) when the read fails.
    """
    buf = mmap.mmap(-1, size)
    try:
        # Keep the ctypes view in a named variable for the duration of the
        # call, then drop it explicitly. While the view exists it holds a
        # buffer export on the mmap, and mmap.close() refuses to close with
        # exports outstanding.
        view = ctypes.c_char.from_buffer(buf)
        try:
            ctypes.set_errno(0)
            rv = libc.read(fd, ctypes.c_void_p(ctypes.addressof(view)), size)
            if rv < 0:
                err = ctypes.get_errno()
                raise OSError(err, os.strerror(err))
        finally:
            del view
        # Copy the result out as bytes before the mapping is released.
        return buf[:rv]
    finally:
        buf.close()


def _check_direct_roundtrip(fd, size):
    """Write *size* bytes to *fd*, read them back, and compare them.

    *fd* is expected to be open O_RDWR|O_DIRECT on an empty file. *size* is
    expected to satisfy the O_DIRECT alignment rules (main passes the page
    size).

    Returns 0 when the round trip matches. Otherwise it prints a diagnostic
    (no trailing newline) and returns 1.
    """
    chunk = b"mergerfs-o-direct"
    # Repeat the marker until it covers the buffer, then trim to exactly size.
    pattern = (chunk * (size // len(chunk) + 1))[:size]

    # An anonymous mmap gives a page-aligned source buffer for the O_DIRECT
    # write. os.write consumes the whole mapping regardless of its position.
    buf = mmap.mmap(-1, size)
    try:
        buf[:] = pattern
        written = os.write(fd, buf)
    finally:
        buf.close()
    if written != size:
        print(f"O_DIRECT short write: wrote {written}, expected {size}", end="")
        return 1

    os.lseek(fd, 0, os.SEEK_SET)
    data = aligned_read(fd, size)
    if len(data) != size:
        print(f"O_DIRECT short read: read {len(data)}, expected {size}", end="")
        return 1
    if data != pattern:
        print("O_DIRECT data mismatch", end="")
        return 1
    return 0


def main():
    """Check that an O_DIRECT write/read round trip works on a mergerfs mount.

    Returns a process exit status:
      * 0 on success,
      * 1 when a round-trip check fails,
      * 77 when the test cannot run. That is the case when ``mergerfs_mount``
        raises RuntimeError, or when this platform's C library does not
        define O_DIRECT. 77 is presumably the harness's "skipped" code;
        TODO confirm against the runner.
    """
    # os.O_DIRECT only exists when the C library defines it.
    o_direct = getattr(os, "O_DIRECT", None)
    if o_direct is None:
        print("O_DIRECT is not available on this platform", end="")
        return 77

    try:
        with mergerfs_mount() as (mount, _):
            fd, filepath = tempfile.mkstemp(dir=mount)
            os.close(fd)

            try:
                fd = os.open(filepath, os.O_RDWR | o_direct | os.O_TRUNC)
            except OSError:
                # If the O_DIRECT open is rejected, don't leave the temp file
                # behind on the mount.
                os.unlink(filepath)
                raise

            # From here on, the fd is closed on every path before the mount
            # context exits. An fd left open on the mount could otherwise
            # interfere with whatever teardown mergerfs_mount performs.
            try:
                # Unlink immediately. The open fd keeps the inode alive for
                # the I/O below.
                os.unlink(filepath)
                return _check_direct_roundtrip(fd, resource.getpagesize())
            finally:
                os.close(fd)
    except RuntimeError as exc:
        print(str(exc), end="")
        return 77


if __name__ == "__main__":
    # Propagate main()'s integer status (0 pass, 1 fail, 77 skip/unavailable)
    # as the process exit code; sys.exit raises SystemExit(status).
    sys.exit(main())
