#!/usr/bin/env python3

import os
import sys
import tempfile
import threading

from posix_parity import mergerfs_mount

# Number of worker threads that race to open the same file. Each worker owns
# one TEST_DATA-sized slot, so the test file is NUM_THREADS * len(TEST_DATA)
# bytes long.
NUM_THREADS = 50
# Payload every worker writes into (and reads back from) its own slot.
TEST_DATA = b"race condition test data\n"


def open_and_operate(filepath, barrier, results, index):
    """Open *filepath* concurrently with the other workers and exercise it.

    Every worker blocks on *barrier* so that all of them call ``os.open`` at
    (roughly) the same moment. Each worker then writes ``TEST_DATA`` into its
    own slot at ``index * len(TEST_DATA)`` and reads it back through the same
    descriptor.

    The outcome is stored in ``results[index]`` as a ``(status, detail)``
    tuple:

    * ``("success", fd)`` -- the still-open descriptor; ownership passes to
      the caller, which must close it.
    * ``("write_error", message)`` -- ``os.write`` reported a short count.
    * ``("read_error", message)`` -- the bytes read back did not match.
    * ``("exception", message)`` -- any exception raised along the way,
      including a broken/aborted barrier.

    On every non-success path the descriptor is closed here, exactly once.
    """
    fd = None
    try:
        barrier.wait()

        fd = os.open(filepath, os.O_RDWR)

        offset = index * len(TEST_DATA)
        os.lseek(fd, offset, os.SEEK_SET)
        bytes_written = os.write(fd, TEST_DATA)
        if bytes_written != len(TEST_DATA):
            results[index] = ("write_error", "expected {} bytes, wrote {}".format(len(TEST_DATA), bytes_written))
            return

        os.lseek(fd, offset, os.SEEK_SET)
        read_data = os.read(fd, len(TEST_DATA))
        if read_data != TEST_DATA:
            results[index] = ("read_error", "data mismatch at offset {}".format(offset))
            return

        results[index] = ("success", fd)
        # Ownership of fd now belongs to the caller; don't close it below.
        fd = None
    except Exception as e:
        results[index] = ("exception", str(e))
    finally:
        if fd is not None:
            # Close at most once: retrying a failed close() in a threaded
            # program could close a descriptor number that another worker
            # has just been handed by os.open(). A failure has already been
            # recorded, so a close error here is deliberately ignored.
            try:
                os.close(fd)
            except OSError:
                pass


def _start_and_join(threads, barrier):
    """Start every thread in *threads*, then join every one that started.

    If a start fails (for example the process hits its thread limit),
    *barrier* is aborted first. Workers already parked in ``barrier.wait()``
    then raise ``BrokenBarrierError`` and exit instead of blocking forever
    and keeping the interpreter alive. The original exception is re-raised.
    """
    started = []
    try:
        for thread in threads:
            thread.start()
            started.append(thread)
    except BaseException:
        barrier.abort()
        raise
    finally:
        for thread in started:
            thread.join()


def _collect_fds(results):
    """Print each failed worker result and gather the successful fds.

    Returns ``(fds, failed)``. *fds* holds the descriptors from
    ``("success", fd)`` entries, in worker order. *failed* is True if any
    worker reported a failure or left no result.
    """
    fds = []
    failed = False
    for i, result in enumerate(results):
        if result is None:
            print("thread {} returned no result".format(i))
            failed = True
        elif result[0] == "success":
            fds.append(result[1])
        else:
            print("thread {} failed: {} - {}".format(i, result[0], result[1]))
            failed = True
    return fds, failed


def _close_all(fds):
    """Close every descriptor in *fds*, then re-raise the first OSError.

    Every descriptor is attempted even if an earlier close fails, so one bad
    close does not leak the rest.
    """
    first_error = None
    for fd in fds:
        try:
            os.close(fd)
        except OSError as exc:
            if first_error is None:
                first_error = exc
    if first_error is not None:
        raise first_error


def _verify_contents(fd, expected):
    """Return True if reading from offset 0 of *fd* yields exactly *expected*."""
    os.lseek(fd, 0, os.SEEK_SET)
    actual = os.read(fd, len(expected))
    if actual != expected:
        print("final verification failed: data mismatch")
        print("expected {} bytes, got {} bytes".format(len(expected), len(actual)))
        return False
    return True


def _verify_cross_fd(fd_a, fd_b):
    """Return True if data written through *fd_a* is read back via *fd_b*."""
    new_data = b"cross-fd visibility test\n"
    os.lseek(fd_a, 0, os.SEEK_SET)
    written = os.write(fd_a, new_data)
    if written != len(new_data):
        # Report a short write as such rather than as a visibility problem.
        print("cross-fd write failed: expected {} bytes, wrote {}".format(len(new_data), written))
        return False

    os.lseek(fd_b, 0, os.SEEK_SET)
    read_back = os.read(fd_b, len(new_data))
    if read_back != new_data:
        print("cross-fd visibility failed: write from fd_a not visible on fd_b")
        return False
    return True


def _run_race(filepath, open_fds):
    """Race NUM_THREADS workers on *filepath* and verify what they did.

    Every descriptor handed back by a worker is appended to *open_fds* so
    the caller can close it whatever happens. Returns 0 on success and 1 on
    any detected failure.
    """
    barrier = threading.Barrier(NUM_THREADS)
    results = [None] * NUM_THREADS
    threads = [
        threading.Thread(target=open_and_operate, args=(filepath, barrier, results, i))
        for i in range(NUM_THREADS)
    ]
    _start_and_join(threads, barrier)

    fds, failed = _collect_fds(results)
    open_fds.extend(fds)
    if failed:
        return 1

    if fds and not _verify_contents(fds[0], TEST_DATA * NUM_THREADS):
        return 1

    if len(fds) >= 2 and not _verify_cross_fd(fds[0], fds[1]):
        return 1

    return 0


def main():
    """Run the concurrent-open race test inside a mergerfs mount.

    Returns 0 on success and 1 on failure. If ``mergerfs_mount`` (or
    anything else in the test) raises RuntimeError, its message is printed
    and 77 is returned, presumably the harness's "skipped" code (confirm
    against the test runner). The temp file and every worker descriptor
    are released on all exit paths, including unexpected exceptions.
    """
    try:
        with mergerfs_mount() as (mount, _):
            (fd, filepath) = tempfile.mkstemp(dir=mount)
            open_fds = []
            try:
                try:
                    # Pre-size the file so each worker's slot already exists.
                    os.ftruncate(fd, NUM_THREADS * len(TEST_DATA))
                finally:
                    os.close(fd)
                return _run_race(filepath, open_fds)
            finally:
                try:
                    _close_all(open_fds)
                finally:
                    os.unlink(filepath)
    except RuntimeError as exc:
        print(str(exc), end="")
        return 77


if __name__ == "__main__":
    # Propagate main()'s return code (0 pass, 1 fail, 77 from the
    # RuntimeError path) as the process exit status.
    sys.exit(main())
