Browse Source

contrib: script for drawing import graph and listing cycles

Wojtek Porczyk 8 years ago
parent
commit
c0741972ba
1 changed file with 167 additions and 0 deletions
  1. 167 0
      contrib/import-graph

+ 167 - 0
contrib/import-graph

@@ -0,0 +1,167 @@
+#!/usr/bin/env python3
+
+import os
+import re
+import sys
+
# Matches ``import ...`` statements that begin at the first column of a line.
# The single group captures everything after the keyword up to end of line,
# so it may still contain ``as`` aliases or a comma-separated list.
re_import = re.compile(r'^import (.*?)$', re.MULTILINE)

# Matches ``from ... import ...`` statements that begin at the first column.
# The single group captures the text between ``from`` and ``import``.
re_import_from = re.compile(r'^from (.*?) import .*?$', re.MULTILINE)
+
class Import(object):
    """One edge of the import graph: ``importing`` imports ``imported``.

    Both endpoints are expected to expose a ``name`` attribute and to render
    as that name through ``str()``.  Two edges are equal (and hash alike)
    when their endpoint names match; the concrete edge class is not part of
    the identity, so an edge and a subclass edge between the same modules
    collapse into one entry when stored in a set.
    """

    # Edge attributes written between the square brackets of the dot line.
    style = 'arrowhead="open", arrowtail="none"'

    # Edges touching a module whose name starts with one of these prefixes
    # are emitted with a leading '#'.
    _commented_prefixes = ('qubes.tests', 'qubes.tools')

    def __init__(self, importing, imported):
        """Store the module doing the import and the module being imported."""
        self.importing = importing
        self.imported = imported

    def __str__(self):
        """Return the dot-syntax line for this edge, '#'-prefixed if commented."""
        return '{}"{}" -> "{}" [{}];'.format(
            ('#' if self.commented else ''),
            self.importing, self.imported, self.style)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing on a missing ``importing`` attribute.
        if not isinstance(other, Import):
            return NotImplemented
        return (self.importing.name, self.imported.name) \
            == (other.importing.name, other.imported.name)

    def __hash__(self):
        return hash((self.importing.name, self.imported.name))

    @property
    def commented(self):
        """``True`` when either endpoint lives under a commented prefix."""
        # str.startswith accepts a tuple, so one call tests every prefix.
        return any(
            endpoint.name.startswith(self._commented_prefixes)
            for endpoint in (self.importing, self.imported))
+
+
class ImportFrom(Import):
    """Edge recorded for a ``from X import ...`` statement.

    Behaves exactly like :class:`Import`; only the edge attributes differ.
    """

    # Same arrow shape as the base class, with the edge coloured red.
    style = 'arrowhead="open", arrowtail="none", color="red"'
+
+
class Module(set):
    """One ``.py`` file of a package, together with the imports found in it.

    The instance is itself the set of :class:`Import` / :class:`ImportFrom`
    edges that start at this module; it is filled by :meth:`process`.
    Equality, hashing and ordering are based on the dotted module name
    only, never on the set contents.
    """

    def __init__(self, package, path):
        """Remember the owning package and the file path relative to its root.

        :param package: mapping of dotted names to modules; must also
            provide ``root`` and ``name`` attributes
        :param path: path of the source file, relative to ``package.root``
        """
        super(Module, self).__init__()
        self.package = package
        self.path = path

    @staticmethod
    def _imported_names(spec):
        """Split the tail of an ``import`` statement into bare module names.

        ``'a.b as c, d  # note'`` becomes ``['a.b', 'd']``: a trailing
        comment is cut off, the list is split on commas and any ``as``
        alias is discarded.
        """
        spec = spec.split('#', 1)[0]
        names = []
        for part in spec.split(','):
            tokens = part.split()
            if tokens:
                names.append(tokens[0])
        return names

    def _add_edge(self, edge_class, name):
        """Add an edge to module *name*, ignoring names outside the package."""
        try:
            imported = self.package[name]
        except KeyError:
            # Not a module of this package (stdlib, third party, relative
            # import, ...): it does not belong in the graph.
            return
        self.add(edge_class(self, imported))

    def process(self):
        """Read the source file and record every import of a package module."""
        # Python source is UTF-8 unless declared otherwise, so do not depend
        # on the locale; import lines are ASCII, so undecodable bytes
        # elsewhere in the file can safely be replaced.
        with open(os.path.join(self.package.root, self.path),
                encoding='utf-8', errors='replace') as fh:
            data = fh.read()
        # Join backslash-continued lines.  Strings are immutable, so the
        # result has to be assigned back for the join to take effect.
        data = data.replace('\\\n', ' ')

        for spec in re_import.findall(data):
            for name in self._imported_names(spec):
                self._add_edge(Import, name)

        for name in re_import_from.findall(data):
            # Joining continuation lines may leave padding around the name.
            self._add_edge(ImportFrom, name.strip())

    @property
    def name(self):
        """Dotted module name, e.g. ``pkg.sub.mod``.

        A package's ``__init__.py`` is named after its directory.
        """
        # The path is built with os.path functions, so split on the
        # platform separator rather than a hard-coded '/'.
        names = os.path.splitext(self.path)[0].split(os.sep)
        names.insert(0, self.package.name)
        if names[-1] == '__init__':
            del names[-1]
        return '.'.join(names)

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return self.name

    def __repr__(self):
        return '<{} {!r}>'.format(self.__class__.__name__, self.name)

    def __lt__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name < other.name

    def __eq__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name == other.name

    def __ne__(self, other):
        # Without this override the inherited set.__ne__ would compare set
        # contents and disagree with the name-based __eq__ above.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
+
+
class Cycle(tuple):
    """An import cycle stored in a canonical rotation.

    The given sequence is rotated so that its smallest element comes first,
    and that element is repeated at the end to close the loop.  The same
    cycle discovered from different starting points therefore produces
    equal tuples.
    """

    def __new__(cls, modules):
        smallest = sorted(modules)[0]
        pivot = modules.index(smallest)
        # Everything from the pivot onwards, then the part before it plus
        # the pivot element once more.
        head = modules[pivot:]
        tail = modules[:pivot + 1]
        return super(Cycle, cls).__new__(cls, head + tail)
+
+#   def __lt__(self, other):
+#       if len(self) < len(other):
+#           return True
+#       elif len(self) > len(other):
+#           return False
+#
+#       return super(Cycle, self).__lt__(other)
+
+
class Package(dict):
    """Mapping of dotted module name to :class:`Module` for one source tree.

    Constructing the object walks ``root``, registers every ``.py`` file
    and then lets each module scan its own source for imports.  ``str()``
    of the package renders the import graph in dot syntax.
    """

    def __init__(self, root):
        """Collect and process all modules below *root*.

        :param root: directory of the package to analyse
        """
        super(Package, self).__init__()
        self.root = root

        for dirpath, _dirnames, filenames in os.walk(self.root):
            for filename in filenames:
                if os.path.splitext(filename)[1] != '.py':
                    continue
                module = Module(self,
                    os.path.relpath(os.path.join(dirpath, filename), self.root))
                self[module.name] = module

        # Scan for imports only once every module is registered, because
        # each import is resolved by looking its name up in this mapping.
        for module in self.values():
            module.process()

    @property
    def name(self):
        """Name of the package: the last component of the root directory."""
        # abspath normalises trailing separators and resolves '.' / '..',
        # so running the tool on '.' still yields the real directory name.
        return os.path.basename(os.path.abspath(self.root))

    def _find_cycles(self):
        """Yield a :class:`Cycle` for every cycle met during a DFS.

        The same cycle is typically yielded more than once; callers are
        expected to deduplicate.
        """
        # stolen from codereview.stackexchange.com/questions/86021 and hacked
        path = []

        def visit(module):
            # Modules are deliberately not memoised as visited: skipping an
            # already explored module would hide cycles that pass through
            # it along a different path.  The price is repeated exploration.
            path.append(module)
            for i in module:
                if i.imported in path:
                    yield Cycle(path[path.index(i.imported):])
                else:
                    yield from visit(i.imported)
            path.pop()

        for v in self.values():
            yield from visit(v)

    def find_cycles(self):
        """Return the distinct import cycles as a sorted list."""
        return sorted(set(self._find_cycles()))

    def get_all_imports(self):
        """Yield every import edge of every module in the package."""
        for module in self.values():
            yield from module

    def __str__(self):
        return '''\n
digraph "import" {{
charset="utf-8"
rankdir=BT
{}
}}
'''.format('\n'.join(str(i) for i in self.get_all_imports()))
+
def main():
    """Command-line entry point.

    Takes the package directory as the first argument, writes the import
    graph in dot syntax to stdout and one line per import cycle to stderr.
    Exits with a usage message when the argument is missing.
    """
    if len(sys.argv) < 2:
        prog = sys.argv[0] if sys.argv else 'import-graph'
        # sys.exit() with a string prints it to stderr and exits non-zero.
        sys.exit('usage: {} PACKAGE_DIR'.format(prog))

    package = Package(sys.argv[1])
    sys.stdout.write(str(package))

    for cycle in package.find_cycles():
        sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n')
+
+
# Entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()