diff --git a/contrib/import-graph b/contrib/import-graph index 1cc6f582..7a166a69 100755 --- a/contrib/import-graph +++ b/contrib/import-graph @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import itertools import os import re import sys @@ -8,15 +9,20 @@ re_import = re.compile(r'^import (.*?)$', re.M) re_import_from = re.compile(r'^from (.*?) import .*?$', re.M) class Import(object): - style = 'arrowhead="open", arrowtail="none"' - def __init__(self, importing, imported): + defstyle = {'arrowhead': 'open', 'arrowtail':'none'} + + def __init__(self, importing, imported, **kwargs): self.importing = importing self.imported = imported + self.style = self.defstyle.copy() + self.style.update(kwargs) def __str__(self): return '{}"{}" -> "{}" [{}];'.format( - ('#' if self.commented else ''), - self.importing, self.imported, self.style) + ('//' if self.commented else ''), + self.importing, + self.imported, + ', '.join('{}="{}"'.format(*i) for i in self.style.items())) def __eq__(self, other): return (self.importing.name, self.imported.name) \ @@ -27,13 +33,11 @@ class Import(object): @property def commented(self): - for i in (self.importing, self.imported): - if i.name.startswith('qubes.tests'): return True - if i.name.startswith('qubes.tools'): return True - - -class ImportFrom(Import): - style = 'arrowhead="open", arrowtail="none", color="red"' + if self.style.get('color', '') != 'red': + return True +# for i in (self.importing, self.imported): +# if i.name.startswith('qubes.tests'): return True +# if i.name.startswith('qubes.tools'): return True class Module(set): @@ -58,7 +62,13 @@ class Module(set): imported = self.package[imported] except KeyError: continue - self.add(ImportFrom(self, imported)) + self.add(Import(self, imported, style='dotted')) + + def __getitem__(self, key): + for i in self: + if i.imported == key: + return i + raise KeyError(key) @property def name(self): @@ -157,11 +167,14 @@ rankdir=BT def main(): package = Package(sys.argv[1]) - sys.stdout.write(str(package)) for cycle in package.find_cycles(): + for i in range(len(cycle) - 1): + edge = cycle[i][cycle[i+1]] + edge.style['color'] = 'red' sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n') + sys.stdout.write(str(package)) if __name__ == '__main__': main()