#!/usr/bin/env python3

import itertools
import os
import re
import sys
import tokenize

# An ``import ...`` statement starting at column 0; group 1 is the rest of
# that line after ``import ``.
re_import = re.compile(r'^import (.*?)$', flags=re.MULTILINE)
# A ``from X import ...`` statement starting at column 0; group 1 is X, the
# text between ``from `` and the first `` import ``.
re_import_from = re.compile(r'^from (.*?) import .*?$', flags=re.MULTILINE)

class Import(object):
    """A directed import edge between two modules, renderable as a DOT edge.

    Extra keyword arguments become Graphviz edge attributes layered on top
    of :attr:`defstyle`.  Edges hash and compare by the pair of endpoint
    names.
    """

    # Default Graphviz attributes for every edge; copied per instance.
    defstyle = {'arrowhead': 'open', 'arrowtail': 'none'}

    # Tuple of dotted-name prefixes; a red edge touching a module whose name
    # starts with one of them is still emitted commented out (for example
    # ``('qubes.tests', 'qubes.tools')``).  Empty by default.
    hidden_prefixes = ()

    def __init__(self, importing, imported, **kwargs):
        self.importing = importing
        self.imported = imported
        self.style = dict(self.defstyle, **kwargs)

    def __str__(self):
        attrs = ', '.join('{}="{}"'.format(key, value)
                          for key, value in self.style.items())
        return '{}"{}" -> "{}" [{}];'.format(
            '//' if self.commented else '',
            self.importing,
            self.imported,
            attrs)

    def __repr__(self):
        return '<{} {} -> {}>'.format(
            type(self).__name__, self.importing, self.imported)

    def __eq__(self, other):
        if not isinstance(other, Import):
            return NotImplemented
        return ((self.importing.name, self.imported.name)
                == (other.importing.name, other.imported.name))

    def __hash__(self):
        return hash((self.importing.name, self.imported.name))

    @property
    def commented(self):
        """Whether this edge is written as a ``//`` comment in DOT output.

        Only edges whose ``color`` is ``red`` are emitted uncommented, and
        then only if neither endpoint's name starts with one of
        :attr:`hidden_prefixes`.
        """
        if self.style.get('color', '') != 'red':
            return True
        if not self.hidden_prefixes:
            return False
        return any(module.name.startswith(self.hidden_prefixes)
                   for module in (self.importing, self.imported))


class Module(set):
    """One ``.py`` file of a package, holding its outgoing import edges.

    The instance is itself a set of ``Import`` objects, filled in by
    :meth:`process`.  Identity is the dotted :attr:`name`: modules hash,
    compare and order by name (all six comparison operators are defined
    here so none of ``set``'s subset comparisons leak through).  As a set,
    a module that imports nothing is falsy.
    """

    def __init__(self, package, path):
        # package: owner providing ``.root``, ``.name`` and lookup of other
        # modules by dotted name; path: file path relative to package.root.
        self.package = package
        self.path = path

    @staticmethod
    def _import_names(statement):
        """Split the text after ``import``/``from`` into dotted module names.

        Trailing comments and ``;``-separated follow-up statements are
        dropped, comma-separated names are split and ``as`` aliases removed,
        e.g. ``'a, b.c as d  # x'`` -> ``['a', 'b.c']``.
        """
        statement = statement.split('#', 1)[0].split(';', 1)[0]
        names = []
        for part in statement.split(','):
            words = part.split()
            if words:
                names.append(words[0])
        return names

    def _add_imports(self, statements, **style):
        """Add an ``Import`` edge for each package module *statements* name."""
        for statement in statements:
            for name in self._import_names(statement):
                try:
                    imported = self.package[name]
                except KeyError:
                    # Not a module of this package (stdlib, third party, ...).
                    continue
                self.add(Import(self, imported, **style))

    def process(self):
        """Read this module's source and record its intra-package imports.

        ``import x`` lines become plain edges and ``from x import y`` lines
        become edges with ``style="dotted"``.  Plain edges are added first,
        so they win when a module uses both forms for the same target.
        """
        # tokenize.open honours PEP 263 coding cookies and BOMs instead of
        # relying on the locale's default encoding.
        with tokenize.open(os.path.join(self.package.root, self.path)) as fh:
            data = fh.read()
        # Join backslash-continued lines so each statement sits on one line
        # for the line-anchored regexes.
        data = data.replace('\\\n', ' ')

        self._add_imports(re_import.findall(data))
        self._add_imports(re_import_from.findall(data), style='dotted')

    def __getitem__(self, key):
        """Return the edge from this module to *key*.

        *key* may be a Module or a dotted module name.

        :raises KeyError: if this module does not import *key*.
        """
        target = key.name if isinstance(key, Module) else key
        for edge in self:
            if edge.imported.name == target:
                return edge
        raise KeyError(key)

    @property
    def name(self):
        """Dotted name, e.g. path ``sub/mod.py`` in ``pkg`` -> ``pkg.sub.mod``.

        An ``__init__.py`` is named after its directory.
        """
        stem = os.path.splitext(self.path)[0]
        # Normalise native separators so Windows paths split correctly.
        parts = [self.package.name] + stem.replace(os.sep, '/').split('/')
        if parts[-1] == '__init__':
            parts.pop()
        return '.'.join(parts)

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return self.name

    def __repr__(self):
        return '<{} {!r}>'.format(self.__class__.__name__, self.name)

    def _compare(self, other, op):
        """Apply the str comparison *op* to the names of two modules."""
        if not isinstance(other, Module):
            return NotImplemented
        return op(self.name, other.name)

    def __eq__(self, other):
        return self._compare(other, str.__eq__)

    def __ne__(self, other):
        return self._compare(other, str.__ne__)

    def __lt__(self, other):
        return self._compare(other, str.__lt__)

    def __le__(self, other):
        return self._compare(other, str.__le__)

    def __gt__(self, other):
        return self._compare(other, str.__gt__)

    def __ge__(self, other):
        return self._compare(other, str.__ge__)


class Cycle(tuple):
    """An import cycle stored as a closed walk of modules.

    The walk is rotated to start at its smallest element (by ``<``), and
    that element is repeated at the end, e.g. ``[b, c, a]`` becomes
    ``(a, b, c, a)``.  Rotations of the same loop of distinct modules
    therefore give equal Cycles, which lets callers de-duplicate them.

    :param modules: any iterable of mutually comparable items.
    :raises ValueError: if *modules* is empty.
    """

    def __new__(cls, modules):
        modules = list(modules)
        start = modules.index(min(modules))
        closed = modules[start:] + modules[:start + 1]
        return super(Cycle, cls).__new__(cls, closed)


class Package(dict):
    """A Python package on disk, mapping dotted module names to Modules.

    :param root: directory of the package; its basename is the package name.

    Construction walks *root*, registers a Module for every ``.py`` file and
    then has each module scan its own imports.
    """

    def __init__(self, root):
        super(Package, self).__init__()
        self.root = root

        for dirpath, dirnames, filenames in os.walk(self.root):
            # Sorting (dirnames in place, so os.walk follows it) makes the
            # module order, and hence the DOT output, reproducible.
            dirnames.sort()
            for filename in sorted(filenames):
                if os.path.splitext(filename)[1] != '.py':
                    continue
                relpath = os.path.relpath(os.path.join(dirpath, filename),
                                          self.root)
                module = Module(self, relpath)
                self[module.name] = module

        # Scan imports only after every module is registered, so that any
        # module of the package can be resolved by name.
        for module in self.values():
            module.process()

    @property
    def name(self):
        """Package name: the last component of :attr:`root`."""
        return os.path.basename(self.root.rstrip(os.path.sep))

    def _find_cycles(self):
        """Yield a Cycle for every import loop, found by exhaustive DFS.

        Adapted from codereview.stackexchange.com/questions/86021.  There is
        deliberately no global "visited" set: every simple path from every
        module is explored so that every cycle is found, at a cost that can
        be exponential in the worst case.  The same cycle may be yielded
        many times; Cycle normalises the rotation so callers can
        de-duplicate.
        """
        path = []
        on_path = set()  # mirrors ``path`` for O(1) membership tests

        def visit(module):
            path.append(module)
            on_path.add(module)
            for edge in module:
                if edge.imported in on_path:
                    yield Cycle(path[path.index(edge.imported):])
                else:
                    yield from visit(edge.imported)
            on_path.discard(module)
            path.pop()

        for module in self.values():
            yield from visit(module)

    def find_cycles(self):
        """Return the distinct import cycles of the package, sorted."""
        return sorted(set(self._find_cycles()))

    def get_all_imports(self):
        """Yield every import edge of every module.

        Within a module, edges are ordered by target name so the output
        does not depend on set iteration (i.e. string hash) order.
        """
        for module in self.values():
            yield from sorted(module, key=lambda edge: edge.imported.name)

    def __str__(self):
        return '''\n
digraph "import" {{
charset="utf-8"
rankdir=BT
{}
}}
'''.format('\n'.join(str(i) for i in self.get_all_imports()))

def main(argv=None):
    """Write the package's import graph as DOT and report import cycles.

    :param argv: command line, defaulting to :data:`sys.argv`; ``argv[1]``
        is the root directory of the package to analyse.

    Every import cycle is written to stderr as ``a -> b -> ... -> a`` and
    its edges are coloured red; then the DOT graph is written to stdout.
    Exits with a usage message if no package directory is given.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        sys.exit('usage: {} PACKAGE_DIR'.format(argv[0] if argv else '<script>'))

    package = Package(argv[1])

    for cycle in package.find_cycles():
        # A Cycle repeats its first module at the end, so consecutive pairs
        # are exactly the edges of the loop.
        for importing, imported in zip(cycle, cycle[1:]):
            importing[imported].style['color'] = 'red'
        sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n')

    sys.stdout.write(str(package))

# Run the tool only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()