core-admin/contrib/import-graph

181 lines
4.9 KiB
Plaintext
Raw Permalink Normal View History

#!/usr/bin/env python3
import itertools
import os
import re
import sys
# Matches an unindented ``import ...`` statement and captures everything after
# the keyword up to the end of that line (re.M makes ^ and $ work per line).
re_import = re.compile(r'^import (.*?)$', re.M)
# Matches an unindented ``from X import ...`` statement and captures ``X``,
# the text between ``from`` and the first following `` import``.
re_import_from = re.compile(r'^from (.*?) import .*?$', re.M)
class Import(object):
    """One edge of the import graph: *importing* imports *imported*.

    ``str()`` renders the edge as a Graphviz DOT statement. Two edges are
    equal (and hash alike) when both endpoints have the same ``name``; the
    style attributes do not take part in the comparison.
    """

    # Default DOT attributes; each instance works on its own copy.
    defstyle = {'arrowhead': 'open', 'arrowtail': 'none'}

    def __init__(self, importing, imported, **kwargs):
        """Create an edge.

        :param importing: the endpoint that contains the import statement
        :param imported: the endpoint being imported
        :param kwargs: extra DOT attributes layered over :attr:`defstyle`
        """
        self.importing = importing
        self.imported = imported
        self.style = self.defstyle.copy()
        self.style.update(kwargs)

    def __str__(self):
        """Return the DOT edge statement, prefixed with ``//`` if commented."""
        attributes = ', '.join(
            '{}="{}"'.format(*item) for item in self.style.items())
        return '{}"{}" -> "{}" [{}];'.format(
            '//' if self.commented else '',
            self.importing,
            self.imported,
            attributes)

    def _key(self):
        """Return the (importing name, imported name) pair identifying the edge."""
        return (self.importing.name, self.imported.name)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing with AttributeError on a missing attribute.
        if not isinstance(other, Import):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash(self._key())

    @property
    def commented(self):
        """``True`` unless the edge is coloured red.

        Every edge whose ``color`` style is not ``'red'`` is emitted commented
        out. A filter on the ``qubes.tests`` / ``qubes.tools`` name prefixes
        used to sit here as well but is disabled.
        """
        return self.style.get('color', '') != 'red'
class Module(set):
    """A ``.py`` file of a package, holding the set of its import edges.

    The set elements are :class:`Import` objects added by :meth:`process`.
    Identity (equality, hashing, ordering by ``<``) is based on the dotted
    module :attr:`name` only, not on the set contents.
    """

    def __init__(self, package, path):
        """Create an empty module.

        :param package: owning package; must provide ``root``, ``name`` and \
            lookup of modules by dotted name
        :param path: path of the source file relative to ``package.root``
        """
        super().__init__()
        self.package = package
        self.path = path

    @staticmethod
    def _join_continuations(data):
        """Return *data* with backslash-newline continuations joined by a space."""
        # str.replace() returns a new string; the result has to be kept.
        return data.replace('\\\n', ' ')

    @staticmethod
    def _import_targets(statement):
        """Return the module names listed after ``import`` in *statement*.

        Handles comma-separated lists, drops ``as alias`` parts and ignores a
        trailing ``#`` comment, e.g. ``'a, b.c as d  # x'`` -> ``['a', 'b.c']``.
        """
        statement = statement.split('#', 1)[0]
        targets = []
        for clause in statement.split(','):
            words = clause.split()
            if words:
                targets.append(words[0])
        return targets

    def _add_import(self, target, **style):
        """Add an edge to *target* if it names a module of the same package."""
        try:
            imported = self.package[target]
        except KeyError:
            # Not part of this package (stdlib, third party, relative, ...).
            return
        self.add(Import(self, imported, **style))

    def process(self):
        """Read the source file and record every import of a package module.

        ``import X`` statements become plain edges, ``from X import ...``
        statements become edges with ``style='dotted'``. Only statements
        matched by the module-level ``re_import`` / ``re_import_from``
        patterns are seen.
        """
        with open(os.path.join(self.package.root, self.path)) as fh:
            data = self._join_continuations(fh.read())

        for statement in re_import.findall(data):
            for target in self._import_targets(statement):
                self._add_import(target)

        for target in re_import_from.findall(data):
            self._add_import(target.strip(), style='dotted')

    def __getitem__(self, key):
        """Return the edge whose imported module equals *key*.

        :raises KeyError: if this module has no such edge
        """
        for edge in self:
            if edge.imported == key:
                return edge
        raise KeyError(key)

    @property
    def name(self):
        """Dotted module name: package name plus the path without ``.py``.

        A trailing ``__init__`` component is dropped, so a package's
        ``__init__.py`` is named after its directory.
        """
        stem = os.path.splitext(self.path)[0]
        # The path may use the platform separator; a no-op where that is '/'.
        names = stem.replace(os.sep, '/').split('/')
        names.insert(0, self.package.name)
        if names[-1] == '__init__':
            del names[-1]
        return '.'.join(names)

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return self.name

    def __repr__(self):
        return '<{} {!r}>'.format(self.__class__.__name__, self.name)

    def __lt__(self, other):
        return self.name < other.name

    def __eq__(self, other):
        return self.name == other.name
class Cycle(tuple):
    """An import cycle in canonical form.

    The given sequence of modules is rotated so that its smallest element
    comes first, and that element is repeated at the end to close the loop.
    Rotations of the same cycle therefore compare and hash equal.

    Cycles order by ordinary tuple comparison; a length-first ``__lt__`` was
    once sketched for this class but is not enabled.
    """

    def __new__(cls, modules):
        smallest = sorted(modules)[0]
        start = modules.index(smallest)
        rotated = modules[start:] + modules[:start]
        closing = modules[start:start + 1]
        return super().__new__(cls, rotated + closing)
class Package(dict):
    """All modules found under a directory, keyed by dotted module name."""

    def __init__(self, root):
        """Collect every ``.py`` file below *root* and process its imports.

        :param root: directory of the package to analyse
        """
        super(Package, self).__init__()
        self.root = root

        for dirpath, _dirnames, filenames in os.walk(self.root):
            for filename in filenames:
                if os.path.splitext(filename)[1] != '.py':
                    continue
                module = Module(self,
                    os.path.relpath(os.path.join(dirpath, filename), self.root))
                self[module.name] = module

        # Imports are resolved only after all modules are registered, so a
        # module can be looked up regardless of the order the walk found it.
        for module in self.values():
            module.process()

    @property
    def name(self):
        """Last path component of :attr:`root`, ignoring trailing separators."""
        return os.path.basename(self.root.rstrip(os.path.sep))

    def _find_cycles(self):
        """Yield a :class:`Cycle` for each import loop met in a depth-first walk.

        Adapted from codereview.stackexchange.com/questions/86021. There is no
        visited-set shortcut (one was tried and left disabled), so a loop can
        be yielded several times; :meth:`find_cycles` removes duplicates.
        """
        path = []

        def visit(module):
            path.append(module)
            for edge in module:
                if edge.imported in path:
                    # Back edge: the loop is the path tail from that module on.
                    yield Cycle(path[path.index(edge.imported):])
                else:
                    yield from visit(edge.imported)
            path.pop()

        for module in self.values():
            yield from visit(module)

    def find_cycles(self):
        """Return the distinct import cycles as a sorted list."""
        return sorted(set(self._find_cycles()))

    def get_all_imports(self):
        """Yield every import edge of every module."""
        for module in self.values():
            yield from module

    def __str__(self):
        """Return the whole graph as Graphviz DOT source."""
        return '''\n
digraph "import" {{
charset="utf-8"
rankdir=BT
{}
}}
'''.format('\n'.join(str(i) for i in self.get_all_imports()))
def main():
    """Analyse the package given on the command line.

    Every import cycle is reported on stderr as ``a -> b -> ... -> a`` and its
    edges are coloured red; the resulting Graphviz DOT graph is written to
    stdout. Without a package argument a usage line is printed and the
    process exits with status 2.
    """
    if len(sys.argv) < 2:
        program = sys.argv[0] if sys.argv else 'import-graph'
        sys.stderr.write('usage: {} <package-root>\n'.format(program))
        sys.exit(2)

    package = Package(sys.argv[1])

    for cycle in package.find_cycles():
        # A cycle repeats its first module at the end, so consecutive pairs
        # cover every edge of the loop.
        for importing, imported in zip(cycle, cycle[1:]):
            importing[imported].style['color'] = 'red'
        sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n')

    sys.stdout.write(str(package))


if __name__ == '__main__':
    main()