#!/bin/bash
# freesurfer-synth-setup — one-shot bootstrap of the FreeSurfer Python
# environment used by SynthSeg / SynthSR / SynthStrip / SynthMorph / etc.
#
# Detects the host's accelerator (NVIDIA / AMD / CPU+AVX2 / CPU) and
# builds a uv-managed venv with the matching wheels. Idempotent.
#
# Defaults to a system-wide venv under /var/lib/freesurfer/synth-venv;
# set FREESURFER_SYNTH_VENV to install per-user instead.

# Strict mode: abort on command failure, on use of an unset variable, and
# when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# Script identity and the directory holding the packaged requirements files.
PROGNAME="$(basename "$0")"
DATADIR="/usr/lib/freesurfer/synth"

# Tunables; each can be replaced by a command-line option further down.
VENV="${FREESURFER_SYNTH_VENV:-/var/lib/freesurfer/synth-venv}"
PYTHON_VERSION_PIN="3.12"
FORCE_ACCEL=

# Mode flags: 0 = off, 1 = on.
REBUILD=0 CHECK=0 DRYRUN=0

# Print the help text on stdout. Reads PROGNAME, VENV and PYTHON_VERSION_PIN
# at call time, so the defaults shown reflect any overrides applied so far.
usage() {
    printf '%s\n' \
        "$PROGNAME — set up the FreeSurfer synth-tools Python environment." \
        "" \
        "Usage:" \
        "    $PROGNAME [--accel auto|cpu|cpu-avx2|cuda121|cuda124|rocm60]" \
        "                  [--venv PATH] [--python X.Y] [--rebuild] [--check] [--dry-run]" \
        "" \
        "Options:" \
        "    --accel TYPE     Force a specific accelerator profile (default: auto-detect)." \
        "    --venv PATH      Override the venv location (default: $VENV)." \
        "    --python X.Y     Python version inside the venv (default: $PYTHON_VERSION_PIN)." \
        "    --rebuild        Discard the existing venv and recreate it." \
        "    --check          Verify the venv matches detected hardware; print drift." \
        "    --dry-run        Show what would happen without changing anything." \
        "" \
        "Environment:" \
        "    FREESURFER_SYNTH_VENV   default venv path (overridden by --venv)." \
        "" \
        "Examples:" \
        "    sudo $PROGNAME                                  # system-wide, auto-detect" \
        "    sudo $PROGNAME --accel cuda124                  # force CUDA 12.4 wheels" \
        "    FREESURFER_SYNTH_VENV=\$HOME/.fs-venv $PROGNAME   # per-user"
}

# Report a fatal error on stderr, prefixed with the program name, and
# terminate the script with exit status 1.
die() {
    printf '%s\n' "$PROGNAME: error: $*" >&2
    exit 1
}

# Print a progress message on stdout, prefixed with the program name.
info() {
    printf '%s\n' "$PROGNAME: $*"
}

# Parse command-line options into the configuration variables.
#
# Options that take a value are checked for a following argument first:
# under `set -u` an unchecked "$2" aborts the script with a bare
# "unbound variable" message that gives the user nothing to act on.
while [ $# -gt 0 ]; do
    case "$1" in
        --accel|--venv|--python)
            [ $# -ge 2 ] || die "option $1 requires an argument (try --help)"
            case "$1" in
                # An empty --accel value is allowed and means auto-detect.
                --accel)
                    FORCE_ACCEL="$2"
                    ;;
                --venv)
                    # An empty path would make later paths resolve to
                    # "/.synth-setup-stamp" and friends.
                    [ -n "$2" ] || die "option --venv requires a non-empty path"
                    VENV="$2"
                    ;;
                --python)
                    [ -n "$2" ] || die "option --python requires a version (e.g. 3.12)"
                    PYTHON_VERSION_PIN="$2"
                    ;;
            esac
            shift 2
            ;;
        --rebuild)  REBUILD=1; shift ;;
        --check)    CHECK=1; shift ;;
        --dry-run)  DRYRUN=1; shift ;;
        -h|--help)  usage; exit 0 ;;
        *)          die "unknown option: $1 (try --help)" ;;
    esac
done

# Every venv and package operation below goes through uv, so stop early
# when it is not on PATH.
if ! command -v uv >/dev/null; then
    die "uv is required but not installed (dnf install uv)"
fi

# Print the accelerator profile matching this host on stdout.
#
# Outputs exactly one of: cuda124, cuda121, rocm60, cpu-avx2, cpu.
# Probe order: NVIDIA (device node plus nvidia-smi on PATH), AMD ROCm
# (rocminfo on PATH), the avx2 flag in /proc/cpuinfo, then plain CPU.
detect_accel() {
    # NVIDIA: device node + driver
    if [ -e /dev/nvidia0 ] && command -v nvidia-smi >/dev/null 2>&1; then
        local drv major
        # A failing nvidia-smi is tolerated here; whatever it printed is
        # validated before use.
        drv=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader 2>/dev/null) || true
        drv=${drv%%$'\n'*}   # one line per GPU: keep the first
        major=${drv%%.*}
        # Compare only when the major version is purely numeric. Arbitrary
        # text (an nvidia-smi error message, "N/A") inside (( )) would be
        # evaluated as an arithmetic expression and raise errors.
        # 10# forces base 10 so a leading zero is not read as octal.
        if [[ $major =~ ^[0-9]+$ ]] && (( 10#$major >= 560 )); then
            echo cuda124
        else
            # Older or undeterminable driver: cuda121 is the fallback profile.
            echo cuda121
        fi
        return
    fi
    # AMD ROCm
    if command -v rocminfo >/dev/null 2>&1; then
        echo rocm60
        return
    fi
    # CPU with AVX2 capability gets the perf-tuned wheels where available
    if grep -qw avx2 /proc/cpuinfo 2>/dev/null; then
        echo cpu-avx2
        return
    fi
    echo cpu
}

# Resolve the accelerator profile. Both an empty --accel and the value
# "auto" (which the usage text lists as a valid TYPE) mean auto-detect;
# treating "auto" as a literal profile would look for requirements-auto.txt.
if [ -n "$FORCE_ACCEL" ] && [ "$FORCE_ACCEL" != auto ]; then
    ACCEL="$FORCE_ACCEL"
else
    # Fail with a message rather than letting `set -e` exit silently.
    ACCEL=$(detect_accel) || die "accelerator auto-detection failed"
    [ -n "$ACCEL" ] || die "accelerator auto-detection returned nothing"
fi
info "accelerator: $ACCEL"

# Each profile has its own pinned requirements file under DATADIR.
REQS="$DATADIR/requirements-${ACCEL}.txt"
[ -r "$REQS" ] || die "no requirements file for accel='$ACCEL' at $REQS"

# Choose the PyTorch wheel index that matches the accelerator profile.
# Any profile without a dedicated GPU index uses the CPU wheels.
if [ "$ACCEL" = cuda124 ]; then
    EXTRA_INDEX="https://download.pytorch.org/whl/cu124"
elif [ "$ACCEL" = cuda121 ]; then
    EXTRA_INDEX="https://download.pytorch.org/whl/cu121"
elif [ "$ACCEL" = rocm60 ]; then
    EXTRA_INDEX="https://download.pytorch.org/whl/rocm6.0"
else
    EXTRA_INDEX="https://download.pytorch.org/whl/cpu"
fi

# --check: compare the accel recorded in the venv stamp against the resolved
# accel. Exit status: 0 = match, 2 = no stamp (venv not built), 3 = drift.
if [ "$CHECK" = 1 ]; then
    if [ ! -f "$VENV/.synth-setup-stamp" ]; then
        echo "no venv at $VENV — run without --check to build it"
        exit 2
    fi
    # Take the value of the first accel= line. sed exits 0 when nothing
    # matches; the former grep|cut pipeline returned 1 in that case, which
    # under `set -e` + pipefail ended the script silently with status 1
    # instead of reporting drift.
    cur=$(sed -n '/^accel=/{s///p;q;}' "$VENV/.synth-setup-stamp" 2>/dev/null) || cur=""
    # A stamp without an accel entry counts as drift.
    cur=${cur:-unknown}
    if [ "$cur" = "$ACCEL" ]; then
        info "venv accel ($cur) matches detected hardware ($ACCEL) — ok"
        exit 0
    fi
    echo "drift: venv was built for accel=$cur, hardware is now $ACCEL"
    echo "  run: $PROGNAME --rebuild --accel $ACCEL"
    exit 3
fi

# --rebuild: discard the existing venv so it is recreated from scratch.
if [ "$REBUILD" = 1 ] && [ -d "$VENV" ]; then
    # VENV is user-supplied and the usage examples run this under sudo, so
    # refuse to recursively delete a non-empty directory that carries
    # neither our stamp nor a pyvenv.cfg (the marker file of a Python venv).
    if [ ! -f "$VENV/.synth-setup-stamp" ] && [ ! -f "$VENV/pyvenv.cfg" ] \
        && [ -n "$(ls -A -- "$VENV" 2>/dev/null)" ]; then
        die "refusing to remove $VENV: directory is not empty and does not look like a venv"
    fi
    info "removing existing venv at $VENV"
    # `--` stops option parsing for paths starting with '-'; `:?` aborts
    # should VENV ever be empty.
    [ "$DRYRUN" = 1 ] || rm -rf -- "${VENV:?}"
fi

# A stamp means an earlier run completed; without --rebuild there is
# nothing to do.
if [ -f "$VENV/.synth-setup-stamp" ] && [ "$REBUILD" = 0 ]; then
    info "venv already initialised at $VENV (use --rebuild to recreate)"
    exit 0
fi

# Step 1: create the virtual environment with the pinned interpreter.
# In dry-run mode only the message is printed.
info "creating venv at $VENV (python $PYTHON_VERSION_PIN)"
if [ "$DRYRUN" != 1 ]; then
    uv venv --python "$PYTHON_VERSION_PIN" "$VENV"
fi

# Step 2: install the pinned requirements into that venv, using PyPI plus
# the accelerator-specific PyTorch wheel index. Skipped in dry-run mode.
info "installing pinned requirements from $REQS"
info "  primary index: pypi.org"
info "  extra index:   $EXTRA_INDEX"
[ "$DRYRUN" != 0 ] || uv pip install \
    --python "$VENV/bin/python" \
    --index-url https://pypi.org/simple \
    --extra-index-url "$EXTRA_INDEX" \
    -r "$REQS"

# Write a stamp so fspython can refuse to launch against a half-built venv.
#
# The stamp is built in a sibling temp file and renamed into place, so it
# never exists in a partial state. Written in place, a failed write (e.g. a
# full disk after the wheel install) would leave an empty stamp that later
# runs take for a finished venv.
if [ "$DRYRUN" = 0 ]; then
    stamp_tmp="$VENV/.synth-setup-stamp.tmp"
    # The && chain makes any single failed write fail the whole group;
    # errexit is not active inside an `if` condition.
    if ! {
        echo "accel=$ACCEL" &&
        echo "python=$PYTHON_VERSION_PIN" &&
        echo "requirements=$REQS" &&
        echo "built=$(date -uIs)" &&
        echo "uv_version=$(uv --version 2>/dev/null || echo unknown)"
    } > "$stamp_tmp"; then
        rm -f -- "$stamp_tmp"
        die "could not write stamp file in $VENV"
    fi
    mv -f -- "$stamp_tmp" "$VENV/.synth-setup-stamp" \
        || die "could not install stamp file in $VENV"
fi

info "done. fspython will resolve to $VENV/bin/python"
