#!/usr/bin/fspython

import os
import sys
import pathlib
import argparse
import textwrap
import surfa as sf
from synthmorph import utils


# Argument settings.
# Fallback values for options omitted on the command line.
default = dict(
    model='joint',
    hyper=0.5,
    extent=256,
    steps=7,
    method='linear',
    type='float32',
    fill=0,
)

# Accepted values for options restricted to a fixed set.
choices = dict(
    model=('joint', 'deform', 'affine', 'rigid'),
    extent=(192, 256),
    method=('linear', 'nearest'),
    type=('uint8', 'uint16', 'int16', 'int32', 'float32'),
)

# Lower bounds checked after parsing.
limits = dict(steps=5)

# Options whose values pass through `utils.resolve_abbrev` when parsed.
resolve = ('model', 'method')


# Documentation.
# Terminal codes to reset (n), embolden (b), and underline (u) text. All three
# are empty strings unless standard output is a terminal.
n, b, u = ('\033[0m', '\033[1m', '\033[4m') if sys.stdout.isatty() else ('',) * 3

# Name under which the script was invoked, without its directory.
prog = os.path.basename(sys.argv[0])


# References.
# Citation text for the SynthMorph papers. The script appends it to the general
# help text and prints it at the end of a run. The underlined words use the
# terminal codes `u` and `n`.
# NOTE(review): most lines end in a tab escape, presumably a hard-line-break
# marker for `utils.rewrap_text` - confirm against that helper.
ref = f'''
SynthMorph: learning contrast-invariant registration without acquired images\t
Hoffmann M, Billot B, Greve DN, Iglesias JE, Fischl B, Dalca AV\t
IEEE Transactions on Medical Imaging, 41 (3), 543-558, 2022\t
https://doi.org/10.1109/TMI.2021.3116879

Anatomy-specific acquisition-agnostic {u}affine{n} registration learned from fictitious images\t
Hoffmann M, Hoopes A, Fischl B*, Dalca AV* (*equal contribution)\t
SPIE Medical Imaging: Image Processing, 12464, 1246402, 2023\t
https://doi.org/10.1117/12.2653251\t
https://synthmorph.io/#papers (PDF)

Anatomy-aware and acquisition-agnostic {u}joint{n} registration with SynthMorph\t
Hoffmann M, Hoopes A, Greve DN, Fischl B*, Dalca AV* (*equal contribution)\t
Imaging Neuroscience, 2, 1-33, 2024\t
https://doi.org/10.1162/imag_a_00197

Website: https://synthmorph.io
'''

# Manual page for the top-level command, installed as the help text of the
# main parser. It lists the available commands and ends with the reference
# list, indented by eight spaces to line up with the other sections.
help_general = f'''{prog}

{b}NAME{n}
        {b}{prog}{n} - register 3D brain images without preprocessing

{b}SYNOPSIS{n}
        {b}{prog}{n} [-h] {u}command{n} [options]

{b}DESCRIPTION{n}
        SynthMorph is a deep-learning tool for symmetric, acquisition-agnostic
        registration of single-frame brain MRI of any geometry. The
        registration is anatomy-aware, removing the need for skull-stripping,
        and you can control the warp smoothness.

        Pass an option or {u}command{n} from the following list. You can omit
        trailing characters, as long as there is no ambiguity.

        {b}register{n}
                Register 3D brain images without preprocessing.

        {b}apply{n}
                Apply an existing transform to another 3D image or label map.

        {b}-h{n}
                Print this help text and exit.

{b}IMAGE FORMAT{n}
        The registration supports single-frame image volumes of any size,
        resolution, and orientation. The moving and the fixed image geometries
        can differ. The accepted image file formats are: MGH (.mgz) and NIfTI
        (.nii.gz, .nii).

        Internally, the registration converts image buffers to: isotropic 1-mm
        voxels, intensities min-max normalized into the interval [0, 1], and
        left-inferior-anterior (LIA) axes. This conversion requires intact
        image-to-world matrices. That is, the head must have the correct
        anatomical orientation in a viewer like {b}freeview{n}.

{b}TRANSFORM FORMAT{n}
        SynthMorph transformations operate in physical RAS space. We save
        matrix transforms as text in LTA format (.lta) and displacement fields
        as images with three frames indicating shifts in RAS direction.

{b}ENVIRONMENT{n}
        The following environment variables affect {b}{prog}{n}:

        SUBJECTS_DIR
                Ignored unless {b}{prog}{n} runs inside a container. Mounts the
                host directory SUBJECTS_DIR to {u}/mnt{n} inside the container.
                Defaults to the current working directory.

{b}SEE ALSO{n}
        For converting, composing, and applying transforms, consider FreeSurfer
        tools {b}lta_convert{n}, {b}mri_warp_convert{n},
        {b}mri_concatenate_lta{n}, {b}mri_concatenate_gcam{n}, and
        {b}mri_convert{n}.

{b}CONTACT{n}
        Reach out to freesurfer@nmr.mgh.harvard.edu or at
        https://voxelmorph.net.

{b}REFERENCES{n}
        If you use SynthMorph in a publication, please cite us!
''' + textwrap.indent(ref, prefix=' ' * 8)

# Manual page for the `register` command: options, environment variables, and
# examples. The accepted values, defaults, and limits quoted in the text are
# interpolated from the `choices`, `default`, and `limits` settings, so the
# text follows any change to them.
help_register = f'''{prog}-register

{b}NAME{n}
        {b}{prog}-register{n} - register 3D brain images without preprocessing

{b}SYNOPSIS{n}
        {b}{prog} register{n} [options] {u}moving{n} {u}fixed{n}

{b}DESCRIPTION{n}
        SynthMorph is a deep-learning tool for symmetric, acquisition-agnostic
        registration of brain MRI with any volume size, resolution, and
        orientation. The registration is anatomy-aware, removing the need for
        skull-stripping, and you can control the warp smoothness.

        SynthMorph registers a {u}moving{n} (source) image to a {u}fixed{n}
        (target) image. Their geometries can differ. The options are as
        follows:

        {b}-m{n} {u}model{n}
                Transformation model ({', '.join(choices['model'])}). Defaults
                to {default['model']}. Joint includes affine and deformable but
                differs from running both in sequence in that it applies the
                deformable step in an affine mid-space to guarantee symmetric
                joint transforms. Deformable assumes prior affine alignment or
                initialization with {b}-i{n}.

        {b}-o{n} {u}image{n}
                Save {u}moving{n} registered to {u}fixed{n}.

        {b}-O{n} {u}image{n}
                Save {u}fixed{n} registered to {u}moving{n}.

        {b}-H{n}
                Update the voxel-to-world matrix instead of resampling when
                saving images with {b}-o{n} and {b}-O{n}. For matrix transforms
                only. Not all software supports headers with shear from affine
                registration.

        {b}-t{n} {u}trans{n}
                Save the transform from {u}moving{n} to {u}fixed{n}, including
                any initialization.

        {b}-T{n} {u}trans{n}
                Save the transform from {u}fixed{n} to {u}moving{n}, including
                any initialization.

        {b}-i{n} {u}trans{n}
                Apply an initial matrix transform to {u}moving{n} before the
                registration.

        {b}-M{n}
                Apply half the initial matrix transform to {u}moving{n} and
                (the inverse of) the other half to {u}fixed{n}, for symmetry.
                This will make running the deformable after an affine step
                equivalent to joint registration. Requires {b}-i{n}.

        {b}-j{n} {u}threads{n}
                Number of TensorFlow threads. System default if unspecified.

        {b}-g{n}
                Use the GPU in environment variable CUDA_VISIBLE_DEVICES or GPU
                0 if the variable is unset or empty.

        {b}-r{n} {u}lambda{n}
                Regularization parameter in the open interval (0, 1) for
                deformable registration. Higher values lead to smoother warps.
                Defaults to {default['hyper']}.

        {b}-n{n} {u}steps{n}
                Integration steps for deformable registration. Lower numbers
                improve speed and memory use but can lead to inaccuracies and
                folding voxels. Defaults to {default['steps']}. Should not be
                less than {limits['steps']}.

        {b}-e{n} {u}extent{n}
                Isotropic extent of the registration space in unit voxels
                {choices['extent']}. Lower values improve speed and memory use
                but may crop the anatomy of interest. Defaults to
                {default['extent']}.

        {b}-w{n} {u}weights{n}
                Use alternative model weights, exclusively. Repeat the flag
                to set affine and deformable weights for joint registration,
                or the result will disappoint.

        {b}-h{n}
                Print this help text and exit.

{b}ENVIRONMENT{n}
        The following environment variables affect {b}{prog}-register{n}:

        CUDA_VISIBLE_DEVICES
                Use a specific GPU. If unset or empty, passing {b}-g{n} will
                select GPU 0. Ignored without {b}-g{n}.

        FREESURFER_HOME
                Load model weights from directory {u}FREESURFER_HOME/models{n}.
                Ignored when specifying weights with {b}-w{n}.

{b}EXAMPLES{n}
        Joint affine-deformable registration, saving the moved image:
                # {prog} register -o out.nii mov.nii fix.nii

        Joint registration at 25% warp smoothness:
                # {prog} register -r 0.25 -o out.nii mov.nii fix.nii

        Affine registration saving the transform:
                # {prog} register -m affine -t aff.lta mov.nii.gz fix.nii.gz

        Deformable registration only, assuming prior affine alignment:
                # {prog} register -m deform -t def.mgz mov.mgz fix.mgz

        Deformable step initialized with an affine transform:
                # {prog} reg -m def -i aff.lta -o out.mgz mov.mgz fix.mgz

        Rigid registration, setting the output image header (no resampling):
                # {prog} register -m rigid -Ho out.mgz mov.mgz fix.mgz
'''

# Manual page for the `apply` command: options and examples. The accepted
# values and defaults quoted in the text are interpolated from the `choices`
# and `default` settings.
help_apply = f'''{prog}-apply

{b}NAME{n}
        {b}{prog}-apply{n} - apply an existing SynthMorph transform

{b}SYNOPSIS{n}
        {b}{prog} apply{n} [options] {u}trans{n} {u}image{n} {u}output{n}
        [{u}image{n} {u}output{n} ...]

{b}DESCRIPTION{n}
        Apply a spatial transform {u}trans{n} estimated by SynthMorph to a 3D
        {u}image{n} and write the result to {u}output{n}. You can pass any
        number of image-output pairs to be processed in the same way.

        The following options identically affect all input-output pairs.

        {b}-H{n}
                Update the voxel-to-world matrix of {u}output{n} instead of
                resampling. For matrix transforms only. Not all software and
                file formats support headers with shear from affine
                registration.

        {b}-m{n} {u}method{n}
                Interpolation method ({', '.join(choices['method'])}). Defaults
                to {default['method']}. Choose linear for images and nearest
                for label (segmentation) maps.

        {b}-t{n} {u}type{n}
                Output data type ({', '.join(choices['type'])}). Defaults to
                {default['type']}. Casting to a type other than
                {default['type']} after linear interpolation may result in
                information loss.

        {b}-f{n} {u}fill{n}
                Extrapolation fill value for areas outside the field-of-view of
                {u}image{n}. Defaults to {default['fill']}.

        {b}-h{n}
                Print this help text and exit.

{b}EXAMPLES{n}
        Apply an affine transform to an image:
                # {prog} apply affine.lta image.nii out.nii.gz

        Apply a warp to an image, saving the output in floating-point format:
                # {prog} apply -t float32 warp.nii image.nii out.nii

        Apply the same transform to two images:
                # {prog} app warp.mgz im_1.mgz out_1.mgz im_2.mgz out_2.mgz

        Transform a label map:
                # {prog} apply -m nearest warp.nii labels.nii out.nii
'''


# Command-line parsing.
# Main parser with one sub-parser per command. The name of the chosen command
# ends up in the `command` attribute of the parsed arguments.
p = argparse.ArgumentParser()
sub = p.add_subparsers(dest='command')
commands = dict((name, sub.add_parser(name)) for name in ('register', 'apply'))

# Show the hand-written manual page instead of the auto-generated help.
p.format_help = lambda: utils.rewrap_text(help_general, end='\n\n')


def add_flags(f):
    """Collect the `add_argument` keywords derived from the settings of `f`.

    Looks up option name `f` in the module-level `choices`, `resolve`, and
    `default` settings and returns a dictionary with up to three keys:
    `choices` if the option is restricted to a set of values, `type` if those
    values additionally pass through `utils.resolve_abbrev`, and `default` if
    the option has a fallback value. The dictionary is empty for options
    without any settings.
    """
    flags = {}

    restricted = f in choices
    if restricted:
        flags['choices'] = choices[f]
    if restricted and f in resolve:
        flags['type'] = lambda x: utils.resolve_abbrev(x, strings=choices[f])

    if f in default:
        flags['default'] = default[f]

    return flags


# Registration arguments.
r = commands['register']
r.format_help = lambda: utils.rewrap_text(help_register, end='\n\n')

# Input images, in the order expected on the command line.
r.add_argument('moving')
r.add_argument('fixed')

# Options as pairs of flag and `add_argument` keywords. Their order sets the
# order in the auto-generated usage message. The help text does not describe
# `-v` and `-d`.
register_options = (
    ('-m', dict(dest='model', **add_flags('model'))),
    ('-o', dict(dest='out_moving', metavar='image')),
    ('-O', dict(dest='out_fixed', metavar='image')),
    ('-H', dict(dest='header_only', action='store_true')),
    ('-t', dict(dest='trans', metavar='trans')),
    ('-T', dict(dest='inverse', metavar='trans')),
    ('-i', dict(dest='init', metavar='trans')),
    ('-M', dict(dest='mid_space', action='store_true')),
    ('-j', dict(dest='threads', metavar='threads', type=int)),
    ('-g', dict(dest='gpu', action='store_true')),
    ('-r', dict(dest='hyper', metavar='lambda', type=float, **add_flags('hyper'))),
    ('-n', dict(dest='steps', metavar='steps', type=int, **add_flags('steps'))),
    ('-e', dict(dest='extent', type=int, **add_flags('extent'))),
    ('-w', dict(dest='weights', metavar='weights', action='append')),
    ('-v', dict(dest='verbose', action='store_true')),
    ('-d', dict(dest='out_dir', metavar='dir', type=pathlib.Path)),
)
for flag, keywords in register_options:
    r.add_argument(flag, **keywords)

# Transformation arguments.
a = commands['apply']
a.format_help = lambda: utils.rewrap_text(help_apply, end='\n\n')

# The transform, followed by a flat list of alternating input and output
# images. The list is split into pairs after parsing.
a.add_argument('trans')
a.add_argument('pairs', metavar='image output', nargs='+')

# Options as pairs of flag and `add_argument` keywords. Their order sets the
# order in the auto-generated usage message.
apply_options = (
    ('-H', dict(dest='header_only', action='store_true')),
    ('-m', dict(dest='method', **add_flags('method'))),
    ('-t', dict(dest='type', **add_flags('type'))),
    ('-f', dict(dest='fill', metavar='fill', type=float, **add_flags('fill'))),
)
for flag, keywords in apply_options:
    a.add_argument(flag, **keywords)


# Parse arguments.
# Without any arguments, print the brief usage message and exit successfully.
# Use `sys.exit` rather than the `exit` helper, which the `site` module only
# injects for interactive use and which is missing under `python -S`.
if len(sys.argv) == 1:
    p.print_usage()
    sys.exit(0)

# Command resolution. Pass the first argument through `utils.resolve_abbrev`
# and write the result back, so that argparse sees the resolved command. A
# command without further arguments prints its usage message.
c = sys.argv[1] = utils.resolve_abbrev(sys.argv[1], commands)
if len(sys.argv) == 2 and c in commands:
    commands[c].print_usage()
    sys.exit(0)

# Default command. Anything other than a known command or the help flag is
# treated as the first argument to `register`.
if c not in (*commands, '-h'):
    sys.argv.insert(1, 'register')

arg = p.parse_args()


# Command.
if arg.command == 'register':

    # Argument checking.
    if arg.header_only and arg.model not in ('affine', 'rigid'):
        sf.system.fatal('-H is not compatible with deformable registration')

    if arg.mid_space and not arg.init:
        sf.system.fatal('-M requires matrix initialization')

    if not 0 < arg.hyper < 1:
        sf.system.fatal('regularization strength not in open interval (0, 1)')

    if arg.steps < limits['steps']:
        sf.system.fatal('too few integration steps')

    # TensorFlow setup. With `-g`, use the GPU named in CUDA_VISIBLE_DEVICES
    # and fall back to GPU 0 if the variable is unset or empty, as the help
    # text promises. A plain `get` default would keep an empty value and leave
    # the GPU unused. Without `-g`, hide all GPUs.
    gpu = os.environ.get('CUDA_VISIBLE_DEVICES') or '0'
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu if arg.gpu else ''
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' if arg.verbose else '3'
    os.environ['NEURITE_BACKEND'] = 'tensorflow'
    os.environ['VXM_BACKEND'] = 'tensorflow'

    # Import only after setting the environment above, presumably so that the
    # backend libraries see these values when they load.
    from synthmorph import registration

    registration.register(arg)


if arg.command == 'apply':

    # Argument checking. Only file names ending in `.lta` are loaded as matrix
    # transforms. Anything else is loaded as a warp.
    is_matrix = arg.trans.endswith('.lta')
    if arg.header_only and not is_matrix:
        sf.system.fatal('-H is not compatible with deformable transforms')

    if len(arg.pairs) % 2 != 0:
        sf.system.fatal('list of input-output pairs not of even length')

    # Transform. Even positions of the flat list hold inputs, odd positions the
    # corresponding outputs.
    load_trans = sf.load_affine if is_matrix else sf.load_warp
    trans = load_trans(arg.trans)
    resample = not arg.header_only
    for inp, out in zip(arg.pairs[0::2], arg.pairs[1::2]):
        moved = sf.load_volume(inp).transform(
            trans, method=arg.method, resample=resample, fill=arg.fill,
        )
        moved.astype(arg.type).save(out)


# Closing message with the reference list.
print('Thank you for choosing SynthMorph. Please cite us!')
citation = utils.rewrap_text(ref)
print(citation)
