#!/usr/bin/python3
# PYTHON_ARGCOMPLETE_OK
import argparse
import configparser
import json
import locale
import os.path
import re
import subprocess
import sys
import urllib.request


class RemoteException(Exception):
    """Base class of the exceptions defined by this tool.

    Raised directly by Rpc when the server answers a call with HTTP 404,
    i.e. when the requested remote procedure does not exist.
    """
    pass

class RemoteUserException(RemoteException):
    """Error returned by the server whose 'type' is 'user'.

    handle_cmd_line() reports this exception as a plain error message.
    """
    pass

class RemoteSystemException(RemoteException):
    """Error returned by the server whose 'type' is anything but 'user'.

    Also raised by Rpc when the server response carries no result.
    """
    pass

class ExternalException(RemoteException):
    """Failure of something outside of this tool.

    Raised when the server cannot be reached and when git cannot be run.
    handle_cmd_line() reports this exception as a plain error message.
    """
    pass

class Rpc:
    """Minimal JSON-over-HTTP client for the Kerneloscope server.

    Any attribute that is not found on the instance is turned into a remote
    procedure of the same name: ``rpc.get_trees(key=value)`` POSTs the
    keyword arguments as a JSON object to ``<url>get_trees/`` and returns
    the ``result`` member of the JSON response.

    Attributes:
        url: base URL of the server; a trailing slash is added if missing.
        info: top-level keys of the server responses other than ``error``,
            ``warnings``, ``type`` and ``result``, kept across calls.
        _dump: parsed dump format or None. When set, the selected fields of
            each result are printed to stdout after every call.
    """

    def __init__(self, url):
        self.url = url
        if not self.url.endswith('/'):
            self.url += '/'
        self.info = {}
        self._dump = None

    def __getattr__(self, name):
        """Return a function performing the remote procedure ``name``.

        The returned function accepts keyword arguments only and raises:
            RemoteException: the server answered with HTTP 404.
            urllib.error.HTTPError: any other HTTP error status.
            ExternalException: the server could not be reached.
            RemoteUserException: the response contains an error of type 'user'.
            RemoteSystemException: the response contains any other error or
                no result at all.
        """
        def rpc_call(**kwargs):
            data = json.dumps(kwargs).encode('utf-8')
            req = urllib.request.Request('{}{}/'.format(self.url, name),
                                         data=data, method='POST')
            req.add_header('Content-Type', 'application/json')
            try:
                http_resp = urllib.request.urlopen(req)
            except urllib.error.HTTPError as e:
                if e.code != 404:
                    raise
                raise RemoteException('Remote procedure {} not found.'.format(name)) from None
            except urllib.error.URLError as e:
                raise ExternalException(
                    'Unable to connect to the Kerneloscope server ({}). Please verify your network and VPN connection.'.format(e)) from None

            # Make sure the connection is released once the body is read.
            with http_resp:
                resp = json.loads(http_resp.read().decode('utf-8'))

            result = None
            for k in resp:
                if k == 'error':
                    # An error without an explicit type is not a user error.
                    is_user_error = resp.get('type') == 'user'
                    remote_exc_type = RemoteUserException if is_user_error else RemoteSystemException
                    raise remote_exc_type(resp['error'])
                elif k == 'warnings':
                    for r in resp[k]:
                        print('Warning: {}'.format(r), file=sys.stderr)
                elif k == 'type':
                    pass
                elif k == 'result':
                    result = resp['result']
                else:
                    # Anything else is additional information from the server.
                    self.info[k] = resp[k]
            if result is None:
                raise RemoteSystemException('Could not understand server response')
            if self._dump:
                dump = self._dump
                if dump[0][0] is None:
                    # No index given: apply the key list to every returned item.
                    keys_for_all = dump[0][1]
                    dump = [(idx, keys_for_all) for idx in range(len(result))]
                for idx, keys in dump:
                    str_idx = '{}: '.format(idx)
                    # Continuation lines are aligned below the first key.
                    str_cont = ' ' * len(str_idx)
                    r = result[idx]
                    if not keys:
                        # An empty key list means all fields of the item.
                        keys = r.keys()
                    for i, key in enumerate(keys):
                        print('{}{}: {}'.format(str_idx if i == 0 else str_cont, key, r[key]))
            return result

        return rpc_call


class Git:
    """Thin wrapper around the git command line tool."""

    def _exec(self, *args):
        """Run the command given by ``args`` and return its standard output.

        The output is decoded to str; undecodable bytes are replaced.

        Raises:
            ExternalException: the command could not be started (e.g. git
                is not installed) or it exited with a non-zero status.
        """
        try:
            data = subprocess.check_output(args)
        except (subprocess.CalledProcessError, OSError):
            # OSError covers a missing or non-executable git binary, which
            # would otherwise escape as a traceback.
            raise ExternalException('cannot run git') from None
        return data.decode(errors='replace')

    def range(self, spec):
        """Return the list of full commit ids in the revision range
        ``spec``, oldest first."""
        return self._exec('git', 'log', '--pretty=format:%H%n', '--reverse', spec).split()

    def message(self, oid):
        """Return the raw commit message (subject and body) of commit ``oid``."""
        return self._exec('git', 'log', '-1', '--pretty=format:%B', oid)


# Preferred encoding of the user's locale; draw_str() checks strings against it.
encoding = locale.getpreferredencoding()


def draw_str(s, fallback):
    """Return ``s`` when it can be encoded in the locale's preferred
    encoding, ``fallback`` otherwise."""
    representable = True
    try:
        s.encode(encoding)
    except UnicodeEncodeError:
        representable = False
    return s if representable else fallback


class BaseCommand:
    """Base class of all subcommands.

    Instantiating a subclass registers a subparser named after the class
    (a leading 'Command' is dropped and the rest is hyphenated and
    lowercased, e.g. CommandFilterBackported becomes 'filter-backported')
    and binds handle() as its handler. Subclasses set the class attributes
    below and override the hook methods as needed.
    """
    # One line summary; used as the subcommand help and in its description.
    overview = ''
    # Optional text appended to the overview in the description.
    details = ''
    # Name of the remote procedure called by handle_rpc().
    cmd = ''

    def __init__(self, subparsers):
        """Create the subparser in ``subparsers`` (the object returned by
        ArgumentParser.add_subparsers) and let the subclass add its
        arguments."""
        self.name = type(self).__name__
        if self.name.startswith('Command'):
            self.name = self.name[len('Command'):]
        self.name = re.sub(r'([a-z])([A-Z]+)', r'\1-\2', self.name).lower()
        desc = self.overview
        if self.details:
            desc = '{} {}'.format(self.overview, self.details)
        self.parser = subparsers.add_parser(self.name, help=self.overview,
                                            description=desc)
        self.parser.set_defaults(func=self.handle)
        # Maps argparse destinations to the rpc key and optional rpc values.
        self.arg_map = {}
        self.add_arguments()
        self.format_options = {}

    def add_rpc_argument(self, *args, group=None, rpc_values=None, **kwargs):
        """Adds the given argument to the parser. The argument will be
        automatically passed to rpc. The parameters to this function are
        identical to argparse.ArgumentParser.add_argument. To ease use of
        rpc, it is possible to specify the rpc argument name as the first
        parameter. This is not necessary if the rpc argument is the same as
        the command line argument.
        The 'rpc_values' parameter, if specified, must be a tuple of two
        values. The first value will be used as the rpc value if the
        argument was specified, the second value will be used if the
        argument was not specified. 'None' is accepted in both cases,
        causing the key to not be sent.
        The 'group' parameter, if specified, is the argument group the
        argument is added to instead of the subcommand parser."""
        dest = kwargs.get('dest')
        key = None
        ldest = None
        sdest = None
        # A leading non-option name followed by more names is the rpc key.
        if not args[0].startswith('-') and len(args) > 1:
            key = args[0]
            args = args[1:]
        # Work out the destination the same way argparse does: an explicit
        # dest, the positional name, the first long option, the first short
        # option, in this order of preference.
        for a in args:
            if a.startswith('--'):
                if not ldest:
                    ldest = a[2:].replace('-', '_')
            elif a.startswith('-'):
                if not sdest:
                    sdest = a[1:]
            else:
                if not dest:
                    dest = a
        if not dest:
            dest = ldest
        if not dest:
            dest = sdest
        if not key:
            key = dest

        if not group:
            group = self.parser
        group.add_argument(*args, **kwargs)

        self.arg_map[dest] = {'key': key, 'values': rpc_values}

    def add_local_argument(self, *args, group=None, **kwargs):
        """Adds the given argument to the parser. It won't be passed to
        rpc."""
        if not group:
            group = self.parser
        group.add_argument(*args, **kwargs)

    def add_arguments(self):
        """Hook: add the command line arguments of the subcommand."""
        pass

    def preprocess_arguments(self):
        """Hook: adjust self.args before they are translated to rpc
        arguments."""
        pass

    def postprocess_arguments(self):
        """Hook: adjust self.rpc_args after the translation."""
        pass

    def set_format_options(self):
        """Hook: fill in self.format_options before the rpc is called."""
        pass

    def print(self, *args):
        """Prints the arguments to the requested output stream (usually stdout)."""
        print(*args, file=output_stream)

    def format(self):
        """Print the response stored in self.response."""
        self.print(self.response)

    def handle_rpc(self):
        """Call the remote procedure self.cmd with self.rpc_args, store the
        result in self.response and print it using format()."""
        f = getattr(rpc, self.cmd)
        self.response = f(**self.rpc_args)
        self.format()

    def handle(self, args):
        """Run the subcommand with the parsed command line ``args``."""
        self.args = args
        self.preprocess_arguments()
        # translate the args to keys from self.arg_map; any argument not in
        # self.arg_map is ignored
        args = vars(args)
        self.rpc_args = {}
        for args_key, rpc_data in self.arg_map.items():
            values = rpc_data['values']
            if args.get(args_key):
                # present on the command line
                val = args[args_key] if values is None else values[0]
            else:
                # not present on the command line
                val = None if values is None else values[1]
            # Only None suppresses the key; falsy rpc values such as 0 or
            # False are sent.
            if val is not None:
                self.rpc_args[rpc_data['key']] = val
        self.postprocess_arguments()
        self.set_format_options()
        self.handle_rpc()


class BaseCommandCommitArg(BaseCommand):
    """Mixin adding a positional 'commit' argument to the subcommand."""

    def add_arguments(self):
        super().add_arguments()
        # No separate rpc name is given, so the value is sent as 'commit'.
        self.add_rpc_argument('commit', help='Git commit id')


class BaseCommandTreeArg(BaseCommand):
    """Mixin adding the --tree/-t option.

    Subclasses may set required_tree to make the option mandatory and
    help_tree_arg to change its help text.
    """
    required_tree = False
    help_tree_arg = 'tree to look for backported commits (or URL#branch)'

    def add_arguments(self):
        super().add_arguments()
        tree_options = {
            'required': self.required_tree,
            'help': self.help_tree_arg,
        }
        self.add_rpc_argument('--tree', '-t', **tree_options)

    def postprocess_arguments(self):
        super().postprocess_arguments()
        tree = self.args.tree
        if not tree or '#' not in tree:
            return
        # A value containing '#' is a URL#branch, sent as a tagged pair.
        self.rpc_args['tree'] = ['url', tree]


class BaseCommandUpstreamCommitListArg(BaseCommandTreeArg):
    """Mixin for subcommands taking a list of upstream commits.

    The commits are given on the command line, read from stdin (when the
    only commit is '-') or, with --git, expanded from revision ranges
    using git.
    """

    def add_arguments(self):
        super().add_arguments()
        self.add_local_argument(
            '--git', '-g', action='store_true',
            help='get and parse the upstream commits from a git tree in the current directory')
        commits_help = ('upstream commit id(s) or (with --git) a revision range. '
                        'If - is specified, the commit ids are read from stdin, '
                        'one commit id per line and the git --oneline format is accepted '
                        'as input')
        self.add_rpc_argument('commits', 'commit', nargs='+', help=commits_help)

    def postprocess_arguments(self):
        super().postprocess_arguments()
        specs = self.rpc_args['commits']
        if self.args.git:
            # Every argument is a revision range; send the id and the
            # message of each commit in it.
            git = Git()
            self.rpc_args['commits'] = [
                {'id': oid, 'message': git.message(oid)}
                for spec in specs
                for oid in git.range(spec)
            ]
            return
        if len(specs) != 1 or specs[0] != '-':
            return
        # Read the ids from stdin; only the first word of each non-blank
        # line is used.
        ids = []
        for line in sys.stdin:
            words = line.split()
            if words:
                ids.append(words[0])
        self.rpc_args['commits'] = ids


class BaseCommandCommitListOutput(BaseCommand):
    """Mixin for subcommands whose response is a list of commits, printed
    one commit per line."""

    class SkipCommit(Exception):
        """Raised while formatting a commit to leave it out of the output."""

    def add_arguments(self):
        super().add_arguments()
        # 'commit_with_subject' is sent as True unless --id-only is given.
        self.add_rpc_argument('commit_with_subject', '--id-only', '-i',
                              action='store_true', rpc_values=(None, True),
                              help='print only commit ids, unabbreviated')

    def get_format_flags(self, commit, index):
        """Return the flags to print for ``commit`` or None for no flags.
        May raise SkipCommit. ``index`` is the number of commits printed
        so far."""
        return None

    def format_commit(self, commit, flags):
        """Print one line for ``commit``: the full id with --id-only,
        otherwise the abbreviated id, the trees, the flags and the
        subject."""
        cid = commit['commit']
        if self.args.id_only:
            self.print('{}{}{}{}'.format(cid, '', '', ''))
            return

        trees = ''
        tree_format = self.format_options.get('trees', 'auto')
        if tree_format != 'never':
            # Unless forced, the vanilla tree is not listed.
            shown = []
            for t in commit['trees']:
                if tree_format == 'force' or t != rpc.info['vanilla']:
                    shown.append(t)
            if shown:
                trees = ' (in {})'.format(', '.join(shown))
        subject = ' ' + commit['subject']
        flags = '' if flags is None else ' ' + flags
        self.print('{}{}{}{}'.format(cid[:12], trees, flags, subject))

    def format(self):
        """Print the commits in self.response, stopping once the 'limit'
        format option (if set) is reached."""
        limit = self.format_options.get('limit')
        printed = 0
        for d in self.response:
            if limit and printed >= limit:
                break
            try:
                flags = self.get_format_flags(d, printed)
                self.format_commit(d, flags)
            except self.SkipCommit:
                continue
            printed += 1


class BaseCommandContinuousOutput(BaseCommand):
    """Mixin for subcommands that fetch their output in several rpc calls.

    The call is repeated with the 'next' rpc argument set to the
    ``next_field`` of the last received item, until an empty response
    arrives or --limit items were received.
    """
    next_field = 'commit'

    def add_arguments(self):
        super().add_arguments()
        self.add_local_argument(
            '--limit', '-n', type=int, default=0,
            help='limit the output to the given number of items; 0 means no limit (the default)')

    def handle_rpc(self):
        limit = self.args.limit
        received = 0
        while True:
            if limit:
                # Tell format() how many items are still allowed.
                self.format_options['limit'] = limit - received
            super().handle_rpc()
            batch = self.response
            if not batch:
                return
            received += len(batch)
            if limit and received >= limit:
                return
            self.rpc_args['next'] = batch[-1][self.next_field]


class CommandTrees(BaseCommand):
    """Subcommand 'trees': list the trees returned by the get_trees rpc."""
    overview = 'Get list of available trees.'
    cmd = 'get_trees'

    def add_arguments(self):
        super().add_arguments()
        self.add_local_argument('--name-only', '-N', action='store_true',
                                help='print only tree names')
        # --downstream and --upstream cannot be combined.
        exclusive = self.parser.add_mutually_exclusive_group()
        self.add_local_argument('--downstream', '-d', group=exclusive,
                                action='store_true',
                                help='show only downstream (RHEL) trees')
        self.add_local_argument('--upstream', '-u', group=exclusive,
                                action='store_true',
                                help='show only upstream trees')

    def _filtered_out(self, tree):
        """Tell whether ``tree`` is excluded by --downstream/--upstream."""
        if self.args.downstream and tree['type'] != 'RHEL':
            return True
        return bool(self.args.upstream and tree['type'] != 'upstream')

    def format(self):
        for tree in self.response:
            if self._filtered_out(tree):
                continue
            self.print(tree['name'])
            if self.args.name_only:
                continue
            origin = ' based on {}'.format(tree['origin']) if 'origin' in tree else ''
            self.print('\t{} tree{}'.format(tree['type'], origin))
            self.print('\t{}'.format(tree['url']))


class CommandTree(BaseCommandCommitArg):
    """Subcommand 'tree': list the trees returned by the get_tree rpc for
    the given commit."""
    overview = 'Get trees a given commit is in.'
    cmd = 'get_tree'

    def add_arguments(self):
        super().add_arguments()
        for long_opt, short_opt, text in (
                ('--name-only', '-N', 'print only tree names'),
                ('--url-only', '-u', 'print only tree URLs')):
            self.add_local_argument(long_opt, short_opt, action='store_true',
                                    help=text)

    def format(self):
        for tree in self.response:
            if self.args.url_only:
                self.print(tree['url'])
                continue
            self.print(tree['name'])
            if self.args.name_only:
                continue
            self.print('\t{} tree'.format(tree['type']))
            self.print('\t{}'.format(tree['url']))


class CommandUpstream(BaseCommandCommitListOutput, BaseCommandCommitArg):
    """Subcommand 'upstream': calls the get_upstream rpc with the given
    commit and prints the returned commit list."""
    overview = 'Get upstream commit for a given RHEL commit.'
    cmd = 'get_upstream'


class CommandDownstream(BaseCommandCommitListOutput, BaseCommandTreeArg, BaseCommandCommitArg):
    """Subcommand 'downstream': calls the get_downstream rpc with the given
    commit (and optionally a tree) and prints the returned commit list."""
    overview = 'Get downstream (RHEL) commit for a given upstream commit.'
    cmd = 'get_downstream'


class CommandDiff(BaseCommandCommitArg):
    """Subcommand 'diff': calls the get_diff rpc with the given commit and
    prints the response as is (format() is not overridden)."""
    overview = 'Get the diff for a given commit.'
    cmd = 'get_diff'


class CommandMR(BaseCommandCommitArg):
    """Subcommand 'mr': print the name and link returned by the get_series
    rpc for the given commit."""
    overview = 'Get the merge request info for a given commit.'
    cmd = 'get_series'

    def add_arguments(self):
        super().add_arguments()
        self.add_local_argument('--url-only', '-u', action='store_true',
                                help='print only the merge request URL')

    def format(self):
        # The link is always printed; the name only without --url-only.
        fields = ('link',) if self.args.url_only else ('name', 'link')
        for field in fields:
            self.print(self.response[field])

class CommandFilterBackported(BaseCommandCommitListOutput, BaseCommandUpstreamCommitListArg):
    """Subcommand 'filter-backported': print the commits returned by the
    filter_backported rpc; --tree is mandatory."""
    overview = 'Filter the given list of upstream commits, keeping only those that are not yet applied.'
    cmd = 'filter_backported'
    required_tree = True

    def add_arguments(self):
        super().add_arguments()
        self.add_local_argument('--include-partial', '-p', action='store_true',
                                help='include partially backported commits')

    def postprocess_arguments(self):
        super().postprocess_arguments()
        if not self.args.include_partial:
            return
        self.rpc_args['extended'] = True

    def get_format_flags(self, commit, index):
        if not self.args.include_partial:
            return ''
        if commit['downstream']:
            if commit['is_partial']:
                return 'p'
            # Has a downstream commit that is not partial: leave it out.
            raise self.SkipCommit
        return ' '

class CommandFixes(BaseCommandCommitListOutput, BaseCommandUpstreamCommitListArg):
    """Subcommand 'fixes': print the commits returned by the
    get_missing_fixes rpc, flagged by their kind."""
    overview = 'Find missing upstream fixes for the given list of upstream commits, recursively.'
    details = 'Fixes are marked with +, merges with M, mentions with ?'
    cmd = 'get_missing_fixes'

    def add_arguments(self):
        super().add_arguments()
        for long_opt, short_opt, text in (
                ('--missing-only', '-m', 'print only the found missing commits'),
                ('--explicit-fixes', '-f', 'print only commits with explicit Fixes: tag'),
                ('--no-merges', '-M', 'do not print merge commits')):
            self.add_local_argument(long_opt, short_opt, action='store_true',
                                    help=text)

    def get_format_flags(self, commit, index):
        opts = self.args
        skip = ((opts.missing_only and not commit['added'])
                or (opts.no_merges and commit['merge'])
                or (opts.explicit_fixes and not commit['kind']))
        if skip:
            raise self.SkipCommit
        if not commit['added']:
            return ' '
        if commit['merge']:
            return 'M'
        return '+' if commit['kind'] else '?'


class CommandFixedCommits(BaseCommandCommitListOutput, BaseCommandUpstreamCommitListArg):
    """Subcommand 'fixed-commits': print the commits returned by the
    get_missing_fixed_commits rpc, marking the added ones with +."""
    overview = 'Find commits that are being fixed by the given list of upstream commits, yet were not backported.'
    details = ('This allows spotting fixes that are fixing something that is not backported. '
               'Found commits are marked with +.')
    cmd = 'get_missing_fixed_commits'

    def add_arguments(self):
        super().add_arguments()
        self.add_local_argument('--missing-only', '-m', action='store_true',
                                help='print only the found missing commits')

    def get_format_flags(self, commit, index):
        if commit['added']:
            return '+'
        if self.args.missing_only:
            raise self.SkipCommit
        return ' '


class CommandSeries(BaseCommandCommitListOutput, BaseCommandUpstreamCommitListArg):
    """Subcommand 'series': print the commits returned by the
    get_missing_series rpc, with a bracket drawn along each series."""
    overview = 'Find commits belonging to the same series (patchset, merge request) as the given commits.'
    details = ('New commits are marked with + and already included commits with -. '
               'Note that for already included commits to be detected, -t/--tree should be specified.')
    cmd = 'get_missing_series'

    def add_arguments(self):
        super().add_arguments()
        # --missing-only and --all cannot be combined.
        exclusive = self.parser.add_mutually_exclusive_group()
        self.add_local_argument('--missing-only', '-m', group=exclusive,
                                action='store_true',
                                help='print only the found missing commits')
        self.add_local_argument('--all', '-a', group=exclusive,
                                action='store_true',
                                help='print also commits that have been already applied')
        self.add_local_argument('--no-merges', '-M', action='store_true',
                                help='do not print merge commits')

    def get_format_flags(self, commit, index):
        commits = self.response
        series = commit['series']
        # A commit starts (ends) a series when its neighbour in the
        # displayed list belongs to a different one.
        start = index == 0 or commits[index - 1]['series'] != series
        end = index == len(commits) - 1 or commits[index + 1]['series'] != series
        if start:
            line = ' ' if end else draw_str('┐', '\\')
        elif end:
            line = draw_str('┘', '/')
        else:
            line = draw_str('│', '|')

        if commit['included']:
            mark = ' -'
        elif commit['merge']:
            mark = ' M'
        elif commit['added']:
            mark = ' +'
        else:
            mark = '  '
        return line + mark

    def format(self):
        # The filtering has to happen here rather than by raising
        # SkipCommit in get_format_flags: the line drawing needs
        # self.response to contain exactly the displayed commits.
        commits = self.response
        if self.args.missing_only:
            commits = [c for c in commits
                       if c['added'] and not (c['included'] or c['merge'])]
        elif not self.args.all:
            commits = [c for c in commits if not c['included']]
        if self.args.no_merges:
            commits = [c for c in commits if not c['merge']]
        self.response = commits
        super().format()


class CommandLog(BaseCommandTreeArg, BaseCommandCommitListOutput, BaseCommandContinuousOutput):
    """Subcommand 'log': print the commits returned by the commit_list
    rpc, fetched in batches, with one flag column per commit property."""
    overview = 'Show commit list.'
    details = ('Patches that are part of a series are marked with "s", '
               'unclean backports with "u" and partial backports with "p", '
               'patches with kABI workarounds with "k", '
               'missing fixes with "!" and patches reverted upstream with "R".')
    cmd = 'commit_list'
    help_tree_arg = 'limit the commits to a particular tree'
    # Each one-character entry is the short form of the entry before it.
    options = ('partial', 'p', 'unclean', 'u', 'kabi', 'k', 'unfixed', '!', 'fixes')

    def add_arguments(self):
        super().add_arguments()
        add = self.add_rpc_argument
        add('tree_special', '--upstream', '-u', action='store_true',
            rpc_values=('UPSTREAM', None),
            help='limit to commits from upstream trees (incompatible with -t)')
        add('notin', '--exclude-tree', '-T', metavar='TREE',
            help='do not include commits from this tree')
        add('notin_special', '--exclude-upstream', '-U', action='store_true',
            rpc_values=('UPSTREAM', None),
            help='do not include commits from upstream trees (incompatible with -T)')
        add('--author', '-A', metavar='EMAIL_REGEX',
            help=('limit commits to those authored by a person with email '
                  'matching the regex'))
        add('--path', '-p', metavar='REGEX', action='append',
            help=('limit commits to those touching the given path. '
                  'Can be specified multiple times; the parameters will be ORed'))
        add('excl', '--exclude-path', '-P', metavar='REGEX', action='append',
            help=('"subtract" this regex from the one given in --path. '
                  'May be also used without specifying --path'))
        add('options', '--option', '-o', action='append', choices=self.options,
            help='limit to commits with this flag set')
        add('top', 'commit', nargs='?',
            help=('top most commit. If specified, the output is filtered '
                  'to trees this commit belongs to'))

    def preprocess_arguments(self):
        super().preprocess_arguments()
        opts = self.args
        if opts.option is not None:
            # Expand the short forms and drop duplicates.
            opts.option = list({
                self.options[self.options.index(o) - 1] if len(o) == 1 else o
                for o in opts.option
            })
        # Several regexes are ORed into a single one.
        for attr in ('path', 'exclude_path'):
            patterns = getattr(opts, attr)
            if patterns is not None:
                setattr(opts, attr, '|'.join(patterns))

    def set_format_options(self):
        super().set_format_options()
        if self.args.tree:
            self.format_options['trees'] = 'never'
        else:
            self.format_options['trees'] = 'force'

    def get_format_flags(self, commit, index):
        markers = (('in_series', 's'), ('is_unclean', 'u'), ('kabi', 'k'),
                   ('missing_fixes', '!'), ('is_reverted', 'R'))
        columns = []
        for key, flag in markers:
            if key == 'is_unclean' and commit.get('is_partial'):
                # partial is always unclean; both share one column
                columns.append('p')
            elif commit.get(key):
                columns.append(flag)
            else:
                columns.append(' ')
        trees = ['{{{}}}'.format(', '.join(commit[side]))
                 for side in ('upstream', 'downstream') if commit[side]]
        return '{} {}'.format(''.join(columns), ' '.join(trees) if trees else ' ')


def read_config(filename):
    """Load the configuration and create the global rpc client.

    Sets the globals ``config`` (the 'config' section of the parsed
    configuration) and ``rpc`` (an Rpc for the configured server). The
    config files have no section header; one is prepended when parsing.

    If ``filename`` is given, only that file is read and it has to exist.
    Otherwise the default, machine specific and user config files are read
    in this order, missing ones being skipped; when an option appears in
    several files, the last one wins.
    """
    global config, rpc

    parser = configparser.ConfigParser(defaults={
        'server': 'https://kerneloscope.engineering.redhat.com/rpc/',
    })

    def load(path):
        with open(path, 'r') as f:
            parser.read_string('[config]\n' + f.read())

    if filename:
        load(filename)
    else:
        # Make sure the section exists even when no file is found.
        parser.read_string('[config]\n')
        xdg_conf = os.environ.get('XDG_CONFIG_HOME') or os.path.expanduser('~/.config')
        candidates = (
            '/usr/lib/kerneloscope/kerneloscope.conf',
            '/etc/kerneloscope.conf',
            os.path.join(xdg_conf, 'kerneloscope.conf'),
            os.path.expanduser('~/.kerneloscope'),
        )
        for path in candidates:
            try:
                load(path)
            except FileNotFoundError:
                continue

    config = parser['config']
    rpc = Rpc(config['server'])


def parse_dump_format(dump):
    """Parse the value of --dump into a list of [index, keys] pairs.

    The format is [INDEX:][FIELD][,FIELD]*[,INDEX:FIELD[,FIELD]*]. Without
    any INDEX, a single pair with index None is returned. Keys are
    stripped of whitespace; an empty key list means all fields.

    Raises ValueError when an index is not a number or when an index is
    missing between two colons.
    """
    head, *rest = dump.split(':')
    if not rest:
        parsed = [[None, head.split(',')]]
    else:
        parsed = [[int(head.strip()), []]]
        *middle, last = rest
        for chunk in middle:
            # The keys of the current item are followed by the next index.
            keys, idx = chunk.rsplit(',', 1)
            parsed[-1][1] = keys.split(',')
            parsed.append([int(idx.strip()), []])
        parsed[-1][1] = last.split(',')

    for entry in parsed:
        fields = [key.strip() for key in entry[1]]
        entry[1] = [] if fields == [''] else fields
    return parsed


def get_commands():
    """Return the module-level objects whose name starts with 'Command'."""
    found = []
    for name, obj in globals().items():
        if name.startswith('Command'):
            found.append(obj)
    return found


def handle_cmd_line():
    """Parse the command line, load the configuration and run the selected
    subcommand.

    Every object returned by get_commands() is instantiated to register its
    subparser. Without a subcommand, the help is printed and the process
    exits with status 0. Errors while reading the config, an invalid --dump
    value, RemoteUserException and ExternalException are reported on stderr
    and the function returns; other exceptions raised by the subcommand
    propagate to the caller.
    """
    # NOTE(review): after reporting an error the function simply returns, so
    # the process still exits with status 0 - confirm whether a non-zero
    # exit status is wanted.
    parser = argparse.ArgumentParser(description='A tool to ease kernel backports and review.')
    parser.add_argument('-c', '--config', action='store',
                        help='specify a config file to use')
    # Do not print the result; instead, dump the selected parts of the response. The format
    # of the --dump value is NOT STABLE and is currently:
    #     [INDEX:][FIELD][,FIELD]*[,INDEX:FIELD[,FIELD]*]
    # where INDEX is the index of the returned item and FIELD is the key in that item.
    # If INDEX is omitted, FIELDs from all items are dumped. If FIELDs are omitted, all
    # fields from the item are dumped.
    # If the value is empty, all fields from all items are dumped.
    # This is used by the Kerneloscope test suite.
    parser.add_argument('--dump', action='store', help=argparse.SUPPRESS)

    subparsers = parser.add_subparsers(dest='command', title='available commands', metavar='COMMAND')
    # Instantiating a command registers its subparser and arguments.
    for c in get_commands():
        c(subparsers)

    # Shell completion is optional; it is skipped when argcomplete is not
    # installed.
    try:
        import argcomplete
        argcomplete.autocomplete(parser)
    except ImportError:
        pass

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        sys.exit(0)
    try:
        read_config(args.config)
    except (FileNotFoundError, configparser.Error) as e:
        print('Error reading config: {}'.format(str(e)), file=sys.stderr)
        return
    if args.dump is not None:
        global output_stream, rpc
        # With --dump the regular output of the command is discarded; the
        # dump itself is printed by Rpc.
        output_stream = open(os.path.devnull, 'w')
        try:
            rpc._dump = parse_dump_format(args.dump)
        except (ValueError, IndexError):
            print('Invalid dump format: {}'.format(args.dump), file=sys.stderr)
            return
    try:
        # 'func' is the handle() method of the selected command.
        args.func(args)
    except (RemoteUserException, ExternalException) as e:
        print('Error: {}'.format(str(e)), file=sys.stderr)
    except BrokenPipeError:
        # prevent Python printing an error message (this happens even though
        # the exception is caught)
        sys.stderr.close()
    except KeyboardInterrupt:
        pass


# Stream BaseCommand.print() writes to; handle_cmd_line() replaces it with
# os.devnull when --dump is given.
output_stream = sys.stdout
if __name__ == '__main__':
    handle_cmd_line()
