contrib: script for drawing import graph and listing cycles
This commit is contained in:
parent
ebb79e9c4f
commit
c0741972ba
167
contrib/import-graph
Executable file
167
contrib/import-graph
Executable file
@ -0,0 +1,167 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
re_import = re.compile(r'^import (.*?)$', re.M)
|
||||
re_import_from = re.compile(r'^from (.*?) import .*?$', re.M)
|
||||
|
||||
class Import(object):
|
||||
style = 'arrowhead="open", arrowtail="none"'
|
||||
def __init__(self, importing, imported):
|
||||
self.importing = importing
|
||||
self.imported = imported
|
||||
|
||||
def __str__(self):
|
||||
return '{}"{}" -> "{}" [{}];'.format(
|
||||
('#' if self.commented else ''),
|
||||
self.importing, self.imported, self.style)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.importing.name, self.imported.name) \
|
||||
== (other.importing.name, other.imported.name)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.importing.name, self.imported.name))
|
||||
|
||||
@property
|
||||
def commented(self):
|
||||
for i in (self.importing, self.imported):
|
||||
if i.name.startswith('qubes.tests'): return True
|
||||
if i.name.startswith('qubes.tools'): return True
|
||||
|
||||
|
||||
class ImportFrom(Import):
|
||||
style = 'arrowhead="open", arrowtail="none", color="red"'
|
||||
|
||||
|
||||
class Module(set):
|
||||
def __init__(self, package, path):
|
||||
self.package = package
|
||||
self.path = path
|
||||
|
||||
def process(self):
|
||||
with open(os.path.join(self.package.root, self.path)) as fh:
|
||||
data = fh.read()
|
||||
data.replace('\\\n', ' ')
|
||||
|
||||
for imported in re_import.findall(data):
|
||||
try:
|
||||
imported = self.package[imported]
|
||||
except KeyError:
|
||||
continue
|
||||
self.add(Import(self, imported))
|
||||
|
||||
for imported in re_import_from.findall(data):
|
||||
try:
|
||||
imported = self.package[imported]
|
||||
except KeyError:
|
||||
continue
|
||||
self.add(ImportFrom(self, imported))
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
names = os.path.splitext(self.path)[0].split('/')
|
||||
names.insert(0, self.package.name)
|
||||
if names[-1] == '__init__':
|
||||
del names[-1]
|
||||
return '.'.join(names)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def __repr__(self):
|
||||
return '<{} {!r}>'.format(self.__class__.__name__, self.name)
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.name < other.name
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.name == other.name
|
||||
|
||||
|
||||
class Cycle(tuple):
|
||||
def __new__(cls, modules):
|
||||
i = modules.index(sorted(modules)[0])
|
||||
# sys.stderr.write('modules={!r} i={!r}\n'.format(modules, i))
|
||||
return super(Cycle, cls).__new__(cls, modules[i:] + modules[:i+1])
|
||||
|
||||
# def __lt__(self, other):
|
||||
# if len(self) < len(other):
|
||||
# return True
|
||||
# elif len(self) > len(other):
|
||||
# return False
|
||||
#
|
||||
# return super(Cycle, self).__lt__(other)
|
||||
|
||||
|
||||
class Package(dict):
|
||||
def __init__(self, root):
|
||||
super(Package, self).__init__()
|
||||
self.root = root
|
||||
|
||||
for dirpath, dirnames, filenames in os.walk(self.root):
|
||||
for filename in filenames:
|
||||
if not os.path.splitext(filename)[1] == '.py':
|
||||
continue
|
||||
module = Module(self,
|
||||
os.path.relpath(os.path.join(dirpath, filename), self.root))
|
||||
self[module.name] = module
|
||||
|
||||
for name, module in self.items():
|
||||
module.process()
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return os.path.basename(self.root.rstrip(os.path.sep))
|
||||
|
||||
def _find_cycles(self):
|
||||
# stolen from codereview.stackexchange.com/questions/86021 and hacked
|
||||
path = []
|
||||
visited = set()
|
||||
|
||||
def visit(module):
|
||||
# if module in visited:
|
||||
# return
|
||||
# visited.add(module)
|
||||
path.append(module)
|
||||
for i in module:
|
||||
if i.imported in path:
|
||||
yield Cycle(path[path.index(i.imported):])
|
||||
else:
|
||||
yield from visit(i.imported)
|
||||
path.pop()
|
||||
|
||||
for v in self.values():
|
||||
yield from visit(v)
|
||||
|
||||
def find_cycles(self):
|
||||
return list(sorted(set(self._find_cycles())))
|
||||
|
||||
def get_all_imports(self):
|
||||
for module in self.values():
|
||||
yield from module
|
||||
|
||||
def __str__(self):
|
||||
return '''\n
|
||||
digraph "import" {{
|
||||
charset="utf-8"
|
||||
rankdir=BT
|
||||
{}
|
||||
}}
|
||||
'''.format('\n'.join(str(i) for i in self.get_all_imports()))
|
||||
|
||||
def main():
|
||||
package = Package(sys.argv[1])
|
||||
sys.stdout.write(str(package))
|
||||
|
||||
for cycle in package.find_cycles():
|
||||
sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue
Block a user