#!/usr/bin/env python3

"""Draw the import graph of a Python package and report import cycles.

Usage: ``import-graph PACKAGE_DIR``

Every ``*.py`` file below PACKAGE_DIR is scanned for top-level (unindented)
``import X`` and ``from X import ...`` statements.  Imports that resolve to
another module of the same package become edges of a Graphviz ``digraph``
written to stdout; every import cycle found is written to stderr, one per
line.
"""

import os
import re
import sys

# Both patterns are anchored at the start of a line (re.M), so only
# unindented statements are seen; imports nested in functions are ignored.
re_import = re.compile(r'^import (.*?)$', re.M)
re_import_from = re.compile(r'^from (.*?) import .*?$', re.M)


def _module_names(clause):
    """Yield the dotted module names listed after an ``import`` keyword.

    *clause* is the remainder of the line, e.g. ``"os, a.b as c  # why"``.
    A trailing comment is dropped, the list is split on commas and any
    ``as alias`` suffix is discarded, giving ``os`` and ``a.b``.
    """
    clause = clause.split('#', 1)[0]
    for item in clause.split(','):
        words = item.split()
        if words:
            yield words[0]


class Import:
    """Graph edge: module *importing* has ``import <imported>``.

    Edges compare and hash by the names of their two endpoints only, so an
    :class:`Import` and an :class:`ImportFrom` between the same pair of
    modules are considered the same edge.
    """

    # Graphviz attributes used when the edge is rendered.
    style = 'arrowhead="open", arrowtail="none"'

    def __init__(self, importing, imported):
        self.importing = importing
        self.imported = imported

    def __str__(self):
        """Return the edge as a DOT statement, '#'-prefixed if commented."""
        return '{}"{}" -> "{}" [{}];'.format(
            ('#' if self.commented else ''),
            self.importing, self.imported, self.style)

    def __eq__(self, other):
        if not isinstance(other, Import):
            return NotImplemented
        return (self.importing.name, self.imported.name) \
            == (other.importing.name, other.imported.name)

    def __hash__(self):
        return hash((self.importing.name, self.imported.name))

    @property
    def commented(self):
        """True if either endpoint is under ``qubes.tests``/``qubes.tools``.

        Such edges are still emitted, but with a leading ``#``.
        """
        return any(
            end.name.startswith(('qubes.tests', 'qubes.tools'))
            for end in (self.importing, self.imported))


class ImportFrom(Import):
    """Graph edge created by ``from <imported> import ...``; drawn in red."""

    style = 'arrowhead="open", arrowtail="none", color="red"'


class Module(set):
    """One source file of a package; as a set, it holds its outgoing edges.

    *package* is the owning :class:`Package` and *path* is the file's path
    relative to ``package.root``.  Modules compare, order and hash by their
    dotted :attr:`name`, not by their set contents.
    """

    def __init__(self, package, path):
        super().__init__()
        self.package = package
        self.path = path

    def process(self):
        """Scan the file and add an edge for each intra-package import.

        Imports naming something that is not a key of the package (standard
        library, third party, relative imports, ...) are skipped.
        """
        source_path = os.path.join(self.package.root, self.path)
        # Only ASCII import lines matter here, so undecodable bytes
        # elsewhere in the file must not abort the scan.
        with open(source_path, encoding='utf-8', errors='replace') as fh:
            data = fh.read()
        # Join backslash-continued lines.  str.replace returns a new string,
        # so the result has to be rebound for the join to take effect.
        data = data.replace('\\\n', ' ')

        # Plain imports are added first: because edges compare by endpoints,
        # a later ImportFrom between the same modules does not replace them.
        for clause in re_import.findall(data):
            for name in _module_names(clause):
                self._add_edge(Import, name)

        for clause in re_import_from.findall(data):
            # Joining continuation lines may leave padding around the name.
            self._add_edge(ImportFrom, clause.strip())

    def _add_edge(self, kind, name):
        """Add a *kind* edge to module *name* if it belongs to the package."""
        try:
            imported = self.package[name]
        except KeyError:
            return
        self.add(kind(self, imported))

    @property
    def name(self):
        """Dotted module name, prefixed with the package name.

        ``sub/mod.py`` becomes ``<package>.sub.mod``; an ``__init__.py``
        is named after its directory.
        """
        # self.path is produced by os.path.relpath, hence os.sep.
        names = os.path.splitext(self.path)[0].split(os.sep)
        names.insert(0, self.package.name)
        if names[-1] == '__init__':
            del names[-1]
        return '.'.join(names)

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return self.name

    def __repr__(self):
        return '<{} {!r}>'.format(self.__class__.__name__, self.name)

    def __lt__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name < other.name

    def __eq__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name == other.name

    def __ne__(self, other):
        # set defines its own __ne__ (content comparison); override it so
        # that != stays the negation of the name-based __eq__ above.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result


class Cycle(tuple):
    """An import cycle in canonical form.

    *modules* is the sequence of distinct modules on the cycle, in import
    order.  It is rotated so that the smallest module comes first, and that
    module is repeated at the end to close the loop.  The same cycle found
    from different starting points therefore yields equal tuples.
    """

    def __new__(cls, modules):
        i = modules.index(min(modules))
        return super().__new__(cls, modules[i:] + modules[:i + 1])


class Package(dict):
    """Mapping of dotted module name to :class:`Module` for one directory.

    Creating a Package walks *root*, registers every ``*.py`` file found
    and then scans each of them for imports.
    """

    def __init__(self, root):
        super().__init__()
        self.root = root

        for dirpath, _dirnames, filenames in os.walk(self.root):
            for filename in filenames:
                if os.path.splitext(filename)[1] != '.py':
                    continue
                module = Module(
                    self,
                    os.path.relpath(
                        os.path.join(dirpath, filename), self.root))
                self[module.name] = module

        # Scanning needs the complete name table, so it is a second pass.
        for module in self.values():
            module.process()

    @property
    def name(self):
        """Name of the package: the last component of the root directory."""
        # abspath drops trailing separators and resolves '.' and '..', which
        # would otherwise leak into every module name.
        return os.path.basename(os.path.abspath(self.root))

    def _find_cycles(self):
        """Yield a :class:`Cycle` each time the search closes a loop.

        The same cycle is usually yielded several times; see find_cycles().
        """
        # stolen from codereview.stackexchange.com/questions/86021 and hacked
        # No "visited" set is kept: every path is explored, only the current
        # path is checked for a repeated module.
        path = []

        def visit(module):
            path.append(module)
            for edge in module:
                if edge.imported in path:
                    yield Cycle(path[path.index(edge.imported):])
                else:
                    yield from visit(edge.imported)
            path.pop()

        for module in self.values():
            yield from visit(module)

    def find_cycles(self):
        """Return the distinct import cycles as a sorted list."""
        return sorted(set(self._find_cycles()))

    def get_all_imports(self):
        """Yield every edge of every module in the package."""
        for module in self.values():
            yield from module

    def __str__(self):
        """Return the whole graph as Graphviz DOT source."""
        return '''\n
digraph "import" {{
charset="utf-8"
rankdir=BT
{}
}}
'''.format('\n'.join(str(i) for i in self.get_all_imports()))


def main():
    """Write the graph of the package in argv[1] to stdout, cycles to stderr."""
    if len(sys.argv) < 2:
        sys.stderr.write('usage: {} PACKAGE_DIR\n'.format(sys.argv[0]))
        sys.exit(2)

    package = Package(sys.argv[1])
    sys.stdout.write(str(package))

    for cycle in package.find_cycles():
        sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n')


if __name__ == '__main__':
    main()