#!/usr/bin/env python3

import contextlib
import errno
import hashlib
import os
import shutil
import sys
import tempfile

from posix_parity import cleanup_dir
from posix_parity import compare_calls
from posix_parity import fail
from posix_parity import join
from posix_parity import mergerfs_mount
from posix_parity import temp_dir
from posix_parity import touch


def _native_tmp():
    """Create a fresh ``native_*`` scratch directory and return its path.

    The directory is placed inside ``.test_tmp`` next to this script
    (located via the symlink-resolved ``__file__``); ``.test_tmp`` itself
    is created on demand.
    """
    script_dir = os.path.dirname(os.path.realpath(__file__))
    scratch_root = os.path.join(script_dir, ".test_tmp")
    os.makedirs(scratch_root, exist_ok=True)
    return tempfile.mkdtemp(dir=scratch_root, prefix="native_")


class NativeTmp:
    """Context manager that provides a throwaway native scratch directory.

    ``__enter__`` creates the directory with ``_native_tmp()``, remembers it
    on ``self.path`` and returns the path. ``__exit__`` removes the tree,
    ignoring removal errors, and does not suppress exceptions raised in the
    ``with`` body.
    """

    def __enter__(self):
        created = _native_tmp()
        self.path = created
        return created

    def __exit__(self, *exc_info):
        # Best-effort cleanup: a failed removal must not mask the test result.
        shutil.rmtree(self.path, ignore_errors=True)


# errno values meaning the host cannot hold the large sparse fixture (or a
# copy of it); the large-file check is then skipped instead of failing.
_LARGE_FILE_SKIP_ERRNOS = (errno.ENOSPC, errno.EFBIG)


def _open_fds(stack, *specs):
    """Open each ``(path, flags)`` pair in *specs*; return the fds in order.

    Each descriptor is registered with *stack* (a ``contextlib.ExitStack``)
    right after it is opened. If a later ``os.open`` raises, the descriptors
    already opened are still closed when the stack unwinds.
    """
    fds = []
    for path, flags in specs:
        fd = os.open(path, flags)
        stack.callback(os.close, fd)
        fds.append(fd)
    return fds


def _copy_until_done(src_fd, dst_fd, length):
    """Call ``os.copy_file_range`` until *length* bytes are copied or EOF.

    No explicit offsets are passed, so both descriptors' file offsets are
    used and advanced. Returns the total number of bytes copied.
    """
    total = 0
    while total < length:
        copied = os.copy_file_range(src_fd, dst_fd, length - total)
        if copied == 0:  # source exhausted
            break
        total += copied
    return total


def _sha256_fd(fd, size, chunk_size=1 << 20):
    """Return the SHA-256 digest of the first *size* bytes readable from *fd*.

    Reads use ``os.pread`` with explicit offsets, so the descriptor's file
    offset is left untouched. Stops early if fewer than *size* bytes exist.
    """
    digest = hashlib.sha256()
    offset = 0
    while offset < size:
        chunk = os.pread(fd, min(size - offset, chunk_size), offset)
        if not chunk:
            break
        digest.update(chunk)
        offset += len(chunk)
    return digest.digest()


def _check_small_copy(merge_base, native_base):
    """Copy 16 bytes from a small file on both sides and compare outcomes.

    Creates ``src`` (36-byte payload) and an empty ``dst`` under each base,
    compares the ``copy_file_range`` results via ``compare_calls``, then
    compares the resulting ``dst`` contents. Returns an error message, or a
    falsy value on parity.
    """
    payload = b"0123456789abcdefghijklmnopqrstuvwxyz"
    merge_src = join(merge_base, "src")
    merge_dst = join(merge_base, "dst")
    native_src = join(native_base, "src")
    native_dst = join(native_base, "dst")

    touch(merge_src, payload)
    touch(native_src, payload)
    touch(merge_dst, b"")
    touch(native_dst, b"")

    with contextlib.ExitStack() as stack:
        msfd, mdfd, nsfd, ndfd = _open_fds(
            stack,
            (merge_src, os.O_RDONLY),
            (merge_dst, os.O_WRONLY),
            (native_src, os.O_RDONLY),
            (native_dst, os.O_WRONLY),
        )
        err = compare_calls(
            "copy_file_range success",
            lambda: os.copy_file_range(msfd, mdfd, 16),
            lambda: os.copy_file_range(nsfd, ndfd, 16),
            lambda a, b: a == b,
        )
        if err:
            return err

    with open(merge_dst, "rb") as mf, open(native_dst, "rb") as nf:
        mdata = mf.read()
        ndata = nf.read()
    if mdata != ndata:
        return f"copy_file_range dst mismatch mergerfs={mdata!r} native={ndata!r}"
    return None


def _check_ebadf_src(merge_base, native_base):
    """Compare ``copy_file_range`` with an invalid (-1) source fd on both sides.

    Reuses the ``dst`` files created by ``_check_small_copy``. Returns the
    ``compare_calls`` result (an error message, or falsy on parity).
    """
    with contextlib.ExitStack() as stack:
        mdfd, ndfd = _open_fds(
            stack,
            (join(merge_base, "dst"), os.O_WRONLY),
            (join(native_base, "dst"), os.O_WRONLY),
        )
        return compare_calls(
            "copy_file_range EBADF src",
            lambda: os.copy_file_range(-1, mdfd, 1),
            lambda: os.copy_file_range(-1, ndfd, 1),
        )


def _check_large_sparse(merge_base, native_base):
    """Copy a >5 GiB sparse file on both sides; compare size and content.

    Returns an error message, or None on parity. Also returns None when
    fixture creation or the copy fails with ENOSPC/EFBIG (host too small).
    """
    # Just over 5 GiB. NOTE(review): presumably sized to exceed 32-bit
    # lengths/offsets; confirm.
    large_len = (5 << 30) + 4096
    merge_src = join(merge_base, "src-large")
    merge_dst = join(merge_base, "dst-large")
    native_src = join(native_base, "src-large")
    native_dst = join(native_base, "dst-large")
    for path in (merge_src, merge_dst, native_src, native_dst):
        touch(path, b"")

    with contextlib.ExitStack() as stack:
        msfd, mdfd, nsfd, ndfd = _open_fds(
            stack,
            (merge_src, os.O_RDWR),
            (merge_dst, os.O_RDWR),
            (native_src, os.O_RDWR),
            (native_dst, os.O_RDWR),
        )
        try:
            # Extend to large_len (a hole) and write one byte at the very end.
            os.ftruncate(msfd, large_len)
            os.ftruncate(nsfd, large_len)
            os.pwrite(msfd, b"z", large_len - 1)
            os.pwrite(nsfd, b"z", large_len - 1)
            # The copy may materialise the hole, so it can run out of space too.
            _copy_until_done(msfd, mdfd, large_len)
            _copy_until_done(nsfd, ndfd, large_len)
        except OSError as exc:
            if exc.errno in _LARGE_FILE_SKIP_ERRNOS:
                return None
            raise

        msize = os.fstat(mdfd).st_size
        nsize = os.fstat(ndfd).st_size
        if msize != nsize:
            return (
                "copy_file_range large sparse dst size mismatch "
                f"mergerfs={msize} native={nsize}"
            )

        # Also compare sizes looked up by path rather than by fd.
        if os.path.getsize(merge_dst) != os.path.getsize(native_dst):
            return "copy_file_range large: final sizes differ"

        if _sha256_fd(mdfd, msize) != _sha256_fd(ndfd, nsize):
            return "copy_file_range large: content hash mismatch"
    return None


def main():
    """Run the copy_file_range parity checks against mergerfs and a native dir.

    Returns the process exit status:
    - 0 when every check passes (or the large-file check is skipped for lack
      of space);
    - 77 when the test cannot run;
    - otherwise, the value returned by ``fail``.
    """
    if not hasattr(os, "copy_file_range"):
        # Nothing can be compared on this platform. Report a skip (77, the
        # same status used below for RuntimeError) rather than a pass.
        print("os.copy_file_range is not available", end="")
        return 77

    try:
        with mergerfs_mount() as (mount, _), NativeTmp() as native:
            merge_base = temp_dir(mount)
            try:
                native_base = join(native, os.path.basename(merge_base))
                os.makedirs(native_base, exist_ok=True)
                # Order matters: the EBADF check reuses the "dst" files
                # that the small-copy check creates.
                for check in (_check_small_copy, _check_ebadf_src, _check_large_sparse):
                    err = check(merge_base, native_base)
                    if err:
                        return fail(err)
                return 0
            finally:
                cleanup_dir(merge_base)
    except RuntimeError as exc:
        # NOTE(review): presumably raised when the mergerfs mount cannot be
        # set up; treated as a skip.
        print(str(exc), end="")
        return 77


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())