#!/usr/bin/env python2

from __future__ import print_function
from pprint import pprint

import argparse
import ast
import os
import sys

# Placeholder that stands in for the variable part of an event name, i.e.
# whatever follows the first ':' (for example 'property-set:<something>').
SOMETHING = '<something>'

# Command-line interface of the script.
parser = argparse.ArgumentParser()

# The two switches below write to the same destination. When neither is
# given, argparse uses the default of the first one registered for that
# destination, so never_handled ends up False.
parser.add_argument(
    '--never-handled',
    dest='never_handled',
    action='store_true',
    help='mark never handled events',
)

parser.add_argument(
    '--no-never-handled',
    dest='never_handled',
    action='store_false',
    help='do not mark never handled events',
)

# Root of the tree that is walked looking for Python sources.
parser.add_argument(
    'directory',
    metavar='DIRECTORY',
    help='directory to search for .py files',
)

class Event(object):
    """One named event together with the places that fire and handle it.

    ``fired`` and ``handled`` are lists of ``(filename, lineno)`` tuples,
    kept in the order in which they were recorded.
    """

    def __init__(self, events, name):
        """Create an empty record for event *name*.

        *events* is stored as given; it is the collection this record
        belongs to.
        """
        self.name = name
        self.events = events
        self.handled = []
        self.fired = []

    def fire(self, filename, lineno):
        """Record that the event is fired at *filename*, line *lineno*."""
        location = (filename, lineno)
        self.fired.append(location)

    def handle(self, filename, lineno):
        """Record that the event is handled at *filename*, line *lineno*."""
        location = (filename, lineno)
        self.handled.append(location)

    def print_summary_one(self, stream, attr, colour, never=True):
        """Write the locations stored under attribute *attr* to *stream*.

        Each location becomes one line tagged with the first letter of
        *attr*, coloured with the ANSI SGR parameters in *colour*.  If there
        are no locations, a yellow "never <attr>" line is written instead,
        unless *never* is false, in which case nothing is written.
        """
        locations = getattr(self, attr)

        if not locations:
            if never:
                stream.write('  \033[1;33mnever {}\033[0m\n'.format(attr))
            return

        template = '  \033[{}m{}\033[0m {} +{}\n'
        for filename, lineno in locations:
            stream.write(template.format(colour, attr[0], filename, lineno))

    def print_summary(self, stream, never_handled):
        """Write the full report for this event to *stream*.

        The event name is printed in bold, followed by the fired locations
        (red) and the handled locations (green).  "never fired" is always
        reported; "never handled" only when *never_handled* is true.
        """
        stream.write('\033[1m{}\033[0m\n'.format(self.name))

        sections = (
            ('fired', '1;31', True),
            ('handled', '1;32', never_handled),
        )
        for attr, colour, report_never in sections:
            self.print_summary_one(stream, attr, colour, never=report_never)


class Events(dict):
    """Mapping of event name to Event that creates entries on first access."""

    def __missing__(self, key):
        # Called by dict for an unknown key: create the record, store it so
        # later lookups find the same object, and hand it back.
        created = Event(self, key)
        self[key] = created
        return created


class EventVisitor(ast.NodeVisitor):
    """AST visitor that records where events are fired and handled.

    Every call whose dotted name ends in ``.fire_event`` or
    ``.fire_event_pre`` is recorded as firing the event named by its first
    positional argument.  Every call to ``qubes.events.handler`` or
    ``qubes.ext.handler`` is recorded as handling each event named by its
    positional arguments.  Results are stored in *events*, a mapping from
    event name to an object with ``fire(filename, lineno)`` and
    ``handle(filename, lineno)`` methods.

    The visitor always descends into the children of a call, so calls nested
    inside the arguments of another call are inspected too.
    """

    def __init__(self, events, filename, *args, **kwargs):
        super(EventVisitor, self).__init__(*args, **kwargs)
        self.events = events
        self.filename = filename

    @staticmethod
    def _literal_str(node):
        """Return the value of string-literal *node*, or None otherwise.

        Interpreters from 3.8 on represent string literals as
        ``ast.Constant``; older ones (including Python 2) use ``ast.Str``,
        which newer releases deprecate and eventually drop.
        """
        if sys.version_info >= (3, 8):
            if isinstance(node, ast.Constant) and isinstance(node.value, str):
                return node.value
            return None
        if isinstance(node, ast.Str):
            return node.s
        return None

    @staticmethod
    def _has_starargs(node):
        """Tell whether call *node* passes ``*args``.

        Older interpreters expose this as ``node.starargs``; newer ones put
        an ``ast.Starred`` element into ``node.args`` instead.
        """
        if getattr(node, 'starargs', None) is not None:
            return True
        starred = getattr(ast, 'Starred', None)
        if starred is None:
            return False
        return any(isinstance(arg, starred) for arg in node.args)

    @staticmethod
    def _normalise(event):
        """Replace everything after the first ':' in *event* with SOMETHING."""
        if ':' in event:
            return ':'.join((event.split(':', 1)[0], SOMETHING))
        return event

    def resolve_attr(self, node):
        """Return the dotted name (``a.b.c``) spelled by *node*.

        Raises TypeError if any part of the path is something else than a
        plain name or attribute access.
        """
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Attribute):
            return '{}.{}'.format(self.resolve_attr(node.value), node.attr)
        raise TypeError('resolve_attr() does not support {!r}'.format(node))

    def _record_fired(self, node):
        """Record the event fired by call *node*.

        The event name is the first positional argument; sometimes it is
        expressed as 'event-stem:' + some_variable, in which case the left
        operand is used.  Raises AssertionError if no name can be extracted.
        """
        if not node.args:
            # previously an uninformative IndexError
            raise AssertionError(
                'event fired without positional event name in {} +{}'.format(
                    self.filename, node.lineno))

        eventnode = node.args[0]
        event = self._literal_str(eventnode)
        if event is None and isinstance(eventnode, ast.BinOp):
            event = self._literal_str(eventnode.left)
        if event is None:
            raise AssertionError('fishy event {!r} in {} +{}'.format(
                eventnode, self.filename, node.lineno))

        self.events[self._normalise(event)].fire(self.filename, node.lineno)

    def _record_handled(self, node):
        """Record every event handled by handler-decorator call *node*.

        Event names (there may be more than one) are all positional
        arguments.  Raises AssertionError on ``*args`` or on any argument
        that is not a string literal.
        """
        if self._has_starargs(node):
            raise AssertionError(
                'event handler with *args in {} +{}'.format(
                    self.filename, node.lineno))

        for arg in node.args:
            event = self._literal_str(arg)
            if event is None:
                raise AssertionError(
                    'event handler with non-string arg in {} +{}'.format(
                        self.filename, node.lineno))

            self.events[self._normalise(event)].handle(
                self.filename, node.lineno)

    def visit_Call(self, node):
        """Inspect one call node, then descend into its children."""
        try:
            name = self.resolve_attr(node.func)
        except TypeError:
            # name got something else than identifier in the attribute path;
            # this may have been 'str'.format() for example; we can't call
            # events this way
            name = None

        if name is not None:
            if name.endswith(('.fire_event', '.fire_event_pre')):
                self._record_fired(node)
            elif name in ('qubes.events.handler', 'qubes.ext.handler'):
                self._record_handled(node)

        # Always descend: the arguments of any call -- including one whose
        # name could not be resolved -- may themselves contain calls that
        # fire events.
        self.generic_visit(node)


def main():
    """Scan DIRECTORY for .py files and print a summary of all events.

    Every ``.py`` file below the directory given on the command line is
    parsed and searched for fired and handled events; afterwards one summary
    per event is written to standard output, sorted by event name.
    """
    args = parser.parse_args()

    events = Events()

    for dirpath, _dirnames, filenames in os.walk(args.directory):
        for filename in filenames:
            if not filename.endswith('.py'):
                continue
            filepath = os.path.join(dirpath, filename)

            # Close the file deterministically instead of leaking the handle
            # until garbage collection.  Reading bytes lets the parser apply
            # the file's own source-encoding declaration rather than
            # depending on the locale.
            with open(filepath, 'rb') as stream:
                source = stream.read()

            EventVisitor(events, filepath).visit(ast.parse(source, filepath))

    for event in sorted(events):
        events[event].print_summary(
            sys.stdout, never_handled=args.never_handled)

# Run the scan only when executed as a script, not when imported.
if __name__ == '__main__':
    main()