Browse Source

contrib: import-graph creates smaller, more readable graph

Wojtek Porczyk 8 years ago
parent
commit
044aefe25a
1 changed files with 26 additions and 13 deletions
  1. 26 13
      contrib/import-graph

+ 26 - 13
contrib/import-graph

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 
+import itertools
 import os
 import re
 import sys
@@ -8,15 +9,20 @@ re_import = re.compile(r'^import (.*?)$', re.M)
 re_import_from = re.compile(r'^from (.*?) import .*?$', re.M)
 
 class Import(object):
-    style = 'arrowhead="open", arrowtail="none"'
-    def __init__(self, importing, imported):
+    defstyle = {'arrowhead': 'open', 'arrowtail':'none'}
+
+    def __init__(self, importing, imported, **kwargs):
         self.importing = importing
         self.imported = imported
+        self.style = self.defstyle.copy()
+        self.style.update(kwargs)
 
     def __str__(self):
         return '{}"{}" -> "{}" [{}];'.format(
-            ('#' if self.commented else ''),
-            self.importing, self.imported, self.style)
+            ('//' if self.commented else ''),
+            self.importing,
+            self.imported,
+            ', '.join('{}="{}"'.format(*i) for i in self.style.items()))
 
     def __eq__(self, other):
         return (self.importing.name, self.imported.name) \
@@ -27,13 +33,11 @@ class Import(object):
 
     @property
     def commented(self):
-        for i in (self.importing, self.imported):
-            if i.name.startswith('qubes.tests'): return True
-            if i.name.startswith('qubes.tools'): return True
-
-
-class ImportFrom(Import):
-    style = 'arrowhead="open", arrowtail="none", color="red"'
+        if self.style.get('color', '') != 'red':
+            return True
+#       for i in (self.importing, self.imported):
+#           if i.name.startswith('qubes.tests'): return True
+#           if i.name.startswith('qubes.tools'): return True
 
 
 class Module(set):
@@ -58,7 +62,13 @@ class Module(set):
                 imported = self.package[imported]
             except KeyError:
                 continue
-            self.add(ImportFrom(self, imported))
+            self.add(Import(self, imported, style='dotted'))
+
+    def __getitem__(self, key):
+        for i in self:
+            if i.imported == key:
+                return i
+        raise KeyError(key)
 
     @property
     def name(self):
@@ -157,11 +167,14 @@ rankdir=BT
 
 def main():
     package = Package(sys.argv[1])
-    sys.stdout.write(str(package))
 
     for cycle in package.find_cycles():
+        for i in range(len(cycle) - 1):
+            edge = cycle[i][cycle[i+1]]
+            edge.style['color'] = 'red'
         sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n')
 
+    sys.stdout.write(str(package))
 
 if __name__ == '__main__':
     main()