#!/usr/bin/env python3
"""Find import cycles inside a Python package and draw them with Graphviz.

Usage: pass the package root directory as the only argument.

Every ``*.py`` file below the root becomes a node.  Top-level ``import x``
and ``from x import ...`` statements that name another module of the same
package become edges.  The graph is written to stdout in DOT format; each
detected cycle is written to stderr.  Edges that are not part of a cycle
are emitted commented out (``//``), so only the cycles are rendered.
"""

import itertools
import os
import re
import sys

# Both patterns are anchored at the start of a line (re.M), so only
# unindented import statements are recognised.
re_import = re.compile(r'^import (.*?)$', re.M)
re_import_from = re.compile(r'^from (.*?) import .*?$', re.M)


class Import(object):
    """A directed edge: module ``importing`` imports module ``imported``.

    ``style`` holds the DOT attributes of the edge; keyword arguments given
    to the constructor override the defaults in ``defstyle``.
    """

    defstyle = {'arrowhead': 'open', 'arrowtail': 'none'}

    def __init__(self, importing, imported, **kwargs):
        self.importing = importing
        self.imported = imported
        self.style = self.defstyle.copy()
        self.style.update(kwargs)

    def __str__(self):
        """Return the DOT statement for this edge, commented out if needed."""
        return '{}"{}" -> "{}" [{}];'.format(
            ('//' if self.commented else ''),
            self.importing,
            self.imported,
            ', '.join('{}="{}"'.format(*i) for i in self.style.items()))

    def __eq__(self, other):
        # Two edges are equal when they connect the same pair of module
        # names; the style is deliberately not part of the identity.
        if not isinstance(other, Import):
            return NotImplemented
        return (self.importing.name, self.imported.name) \
            == (other.importing.name, other.imported.name)

    def __hash__(self):
        return hash((self.importing.name, self.imported.name))

    @property
    def commented(self):
        """True unless the edge has been coloured red (i.e. is in a cycle)."""
        return self.style.get('color', '') != 'red'


class Module(set):
    """One source file of the package; also the set of its ``Import`` edges.

    Identity (equality, hashing, ordering via ``<``) is based on the dotted
    module name, not on the contained edges.
    """

    def __init__(self, package, path):
        super().__init__()
        self.package = package
        # Path of the source file relative to the package root.
        self.path = path

    def _add_import(self, name, **style):
        """Add an edge to the module called *name*, if the package has it."""
        try:
            imported = self.package[name]
        except KeyError:
            # Not a module of this package (stdlib, third party, relative
            # import, ...): not interesting for the graph.
            return
        self.add(Import(self, imported, **style))

    def process(self):
        """Scan the source file and record imports of sibling modules."""
        with open(os.path.join(self.package.root, self.path),
                  encoding='utf-8', errors='replace') as fh:
            data = fh.read()
        # Join backslash-continued lines.  str.replace returns a new string,
        # so the result has to be assigned back.
        data = data.replace('\\\n', ' ')

        for statement in re_import.findall(data):
            # "import a.b as c, d  # note" -> "a.b", "d"
            for clause in statement.split(','):
                tokens = clause.split()
                if tokens:
                    self._add_import(tokens[0])

        for imported in re_import_from.findall(data):
            # Joining continued lines may leave padding around the name.
            self._add_import(imported.strip(), style='dotted')

    def __getitem__(self, key):
        """Return the edge from this module to module *key*.

        Raises KeyError when this module does not import *key*.
        """
        for i in self:
            if i.imported == key:
                return i
        raise KeyError(key)

    @property
    def name(self):
        """Dotted module name, e.g. ``pkg.sub.mod`` (``__init__`` dropped)."""
        stem = os.path.splitext(self.path)[0]
        # os.path.relpath yields OS-specific separators; normalise them.
        names = stem.replace(os.sep, '/').split('/')
        names.insert(0, self.package.name)
        if names[-1] == '__init__':
            del names[-1]
        return '.'.join(names)

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return self.name

    def __repr__(self):
        return '<{} {!r}>'.format(self.__class__.__name__, self.name)

    def __lt__(self, other):
        return self.name < other.name

    def __eq__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name == other.name

    def __ne__(self, other):
        # set defines its own __ne__ (content comparison); keep it in line
        # with the name-based __eq__ above.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


class Cycle(tuple):
    """An import cycle in canonical form.

    The modules are rotated so that the smallest one (by name) comes first,
    and that module is repeated at the end to close the loop.  The same
    cycle discovered from different starting points therefore compares and
    hashes equal.
    """

    def __new__(cls, modules):
        i = modules.index(min(modules))
        return super().__new__(cls, modules[i:] + modules[:i + 1])


class Package(dict):
    """Mapping of dotted module name to ``Module`` for a package directory."""

    def __init__(self, root):
        super().__init__()
        self.root = root

        for dirpath, dirnames, filenames in os.walk(self.root):
            for filename in filenames:
                if os.path.splitext(filename)[1] != '.py':
                    continue
                module = Module(self, os.path.relpath(
                    os.path.join(dirpath, filename), self.root))
                self[module.name] = module

        # Imports are resolved only after every module is registered, so
        # that lookups by name can succeed regardless of walk order.
        for module in self.values():
            module.process()

    @property
    def name(self):
        """Name of the package: the last component of the root path."""
        return os.path.basename(self.root.rstrip(os.path.sep))

    def _find_cycles(self):
        """Yield a ``Cycle`` for every cycle met during a depth-first walk.

        The same cycle is usually yielded several times; ``find_cycles``
        removes the duplicates.  Visited modules are not memoised, so every
        simple path is explored.
        """
        # stolen from codereview.stackexchange.com/questions/86021 and hacked
        path = []

        def visit(module):
            path.append(module)
            for i in module:
                if i.imported in path:
                    yield Cycle(path[path.index(i.imported):])
                else:
                    yield from visit(i.imported)
            path.pop()

        for v in self.values():
            yield from visit(v)

    def find_cycles(self):
        """Return the distinct cycles of the package as a sorted list."""
        return sorted(set(self._find_cycles()))

    def get_all_imports(self):
        """Yield every ``Import`` edge of every module."""
        for module in self.values():
            yield from module

    def __str__(self):
        """Return the whole import graph as a DOT document."""
        edges = '\n'.join(str(i) for i in self.get_all_imports())
        # The edge list must sit on lines of its own: an edge rendered as a
        # "//" comment would otherwise swallow the closing brace.
        return (
            '\n'
            'digraph "import" {{\n'
            'charset="utf-8"\n'
            'rankdir=BT\n'
            '{}\n'
            '}}\n'
        ).format(edges)


def main():
    """Report the import cycles of the package given on the command line."""
    if len(sys.argv) < 2:
        sys.exit('usage: {} PACKAGE_ROOT'.format(sys.argv[0]))

    package = Package(sys.argv[1])

    for cycle in package.find_cycles():
        # Colour every edge between consecutive modules of the cycle.
        for importing, imported in zip(cycle, cycle[1:]):
            importing[imported].style['color'] = 'red'
        sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n')

    sys.stdout.write(str(package))


if __name__ == '__main__':
    main()