168 lines
4.5 KiB
Python
Executable File
168 lines
4.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
import os
|
|
import re
|
|
import sys
|
|
|
|
# ``import <names>`` at the very start of a line; group 1 is the whole text
# following "import " up to the end of that line.
re_import = re.compile(r'^import (.*?)$', re.MULTILINE)

# ``from <module> import <names>`` at the very start of a line; group 1 is
# the module between "from " and " import ".
re_import_from = re.compile(r'^from (.*?) import .*?$', re.MULTILINE)
|
|
|
|
class Import(object):
    '''An ``import X`` edge from one package module to another.

    Equality and hashing use only the pair of module names, so edges of
    different classes (this class or a subclass) joining the same two
    modules compare equal, and a set keeps only one of them.
    '''

    # Graphviz attributes used to draw this kind of edge.
    style = 'arrowhead="open", arrowtail="none"'

    # Edges touching a module whose dotted name starts with one of these
    # prefixes are rendered commented out (see ``commented``).
    commented_prefixes = ('qubes.tests', 'qubes.tools')

    def __init__(self, importing, imported):
        # Module that contains the import statement.
        self.importing = importing
        # Module that is being imported.
        self.imported = imported

    def __str__(self):
        '''Render the edge as one Graphviz DOT edge statement.

        A commented edge gets a leading ``#``; Graphviz discards lines that
        start with ``#``, so such edges appear in the text but are not drawn.
        '''
        return '{}"{}" -> "{}" [{}];'.format(
            ('#' if self.commented else ''),
            self.importing, self.imported, self.style)

    def __eq__(self, other):
        # Let Python fall back to its default for unrelated types instead of
        # crashing on a missing ``importing`` attribute.
        if not isinstance(other, Import):
            return NotImplemented
        return ((self.importing.name, self.imported.name)
            == (other.importing.name, other.imported.name))

    def __hash__(self):
        # Must agree with __eq__: hash exactly the compared pair of names.
        return hash((self.importing.name, self.imported.name))

    @property
    def commented(self):
        '''Return True if either end matches ``commented_prefixes``, else False.'''
        return any(module.name.startswith(self.commented_prefixes)
            for module in (self.importing, self.imported))
|
|
|
|
|
|
class ImportFrom(Import):
    '''A ``from X import ...`` edge; identical to Import but drawn in red.'''

    # Same arrow shape as Import, plus a red colour to tell the kinds apart.
    style = 'arrowhead="open", arrowtail="none", color="red"'
|
|
|
|
|
|
class Module(set):
    '''One ``.py`` file of a package, held as the set of its import edges.

    The set is filled by ``process()`` with ``Import`` / ``ImportFrom``
    edges that point at other modules of the same package.
    '''

    def __init__(self, package, path):
        super(Module, self).__init__()
        # The Package (name -> Module mapping) this module belongs to.
        self.package = package
        # Path of the file relative to ``package.root``.
        self.path = path

    def process(self):
        '''Scan this module's source and record imports of package modules.

        Each ``import X`` / ``from X import ...`` at the start of a line whose
        target ``X`` is a key of ``self.package`` is added as an ``Import`` /
        ``ImportFrom`` edge; imports of anything else are ignored.
        '''
        # tokenize.open honours PEP 263 coding cookies and BOMs instead of
        # depending on the current locale's default encoding.
        import tokenize

        with tokenize.open(os.path.join(self.package.root, self.path)) as fh:
            data = fh.read()

        # Join backslash-continued lines so a statement split across lines
        # is matched as one. str.replace returns a new string, so the result
        # must be assigned back.
        data = data.replace('\\\n', ' ')

        for statement in re_import.findall(data):
            # "import a, b as c; x = 1  # note" may name several modules,
            # each optionally aliased, followed by other code or a comment.
            statement = statement.split('#', 1)[0].split(';', 1)[0]
            for part in statement.split(','):
                words = part.split()
                if not words:
                    continue
                try:
                    imported = self.package[words[0]]
                except KeyError:
                    continue
                self.add(Import(self, imported))

        for imported in re_import_from.findall(data):
            try:
                imported = self.package[imported.strip()]
            except KeyError:
                continue
            self.add(ImportFrom(self, imported))

    @property
    def name(self):
        '''Dotted name, e.g. path ``vm/qubesvm.py`` in package ``qubes``
        gives ``qubes.vm.qubesvm``; an ``__init__.py`` names its directory.
        '''
        # Split on the platform separator (relpath output uses os.sep).
        names = os.path.normpath(os.path.splitext(self.path)[0]).split(os.sep)
        names.insert(0, self.package.name)
        if names[-1] == '__init__':
            del names[-1]
        return '.'.join(names)

    def __hash__(self):
        # set subclasses are unhashable by default; identity is the name.
        return hash(self.name)

    def __str__(self):
        return self.name

    def __repr__(self):
        return '<{} {!r}>'.format(self.__class__.__name__, self.name)

    def __lt__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name < other.name

    def __eq__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name == other.name
|
|
|
|
|
|
class Cycle(tuple):
    '''An import loop in canonical form, so equal loops compare equal.

    The given modules are rotated so the smallest one (by ``<``) comes
    first, and that module is repeated at the end to close the loop:
    ``[b, a, c]`` becomes ``(a, c, b, a)``.
    '''

    def __new__(cls, modules):
        start = modules.index(min(modules))
        rotated = modules[start:] + modules[:start]
        closing = modules[start:start + 1]
        return super().__new__(cls, rotated + closing)
|
|
|
|
|
|
class Package(dict):
    '''A Python package on disk, mapping dotted module names to Modules.

    Construction walks *root* recursively, creates a ``Module`` for every
    ``.py`` file found and then has each module scan its imports.
    '''

    def __init__(self, root):
        super(Package, self).__init__()
        # Directory holding the package's files.
        self.root = root

        for dirpath, _dirnames, filenames in os.walk(self.root):
            for filename in filenames:
                if os.path.splitext(filename)[1] != '.py':
                    continue
                module = Module(self, os.path.relpath(
                    os.path.join(dirpath, filename), self.root))
                self[module.name] = module

        # Imports can only be resolved once every module is registered.
        for module in self.values():
            module.process()

    @property
    def name(self):
        '''Top-level package name: the last component of ``root``.'''
        # abspath strips trailing separators and resolves "." / "..", so a
        # root of "." still yields the real directory name.
        return os.path.basename(os.path.abspath(self.root))

    def _find_cycles(self):
        '''Yield a ``Cycle`` for every import loop, possibly many times.

        Depth-first search from every module along import edges. No global
        "visited" set is kept on purpose: every simple path is followed, so
        each elementary cycle is reported, at a cost that can grow
        exponentially with the import graph. Duplicates are removed by
        ``find_cycles``.
        '''
        # Based on codereview.stackexchange.com/questions/86021.
        path = []

        def visit(module):
            path.append(module)
            for edge in module:
                if edge.imported in path:
                    # Back edge: the path from the imported module onwards
                    # closes a loop.
                    yield Cycle(path[path.index(edge.imported):])
                else:
                    yield from visit(edge.imported)
            path.pop()

        for module in self.values():
            yield from visit(module)

    def find_cycles(self):
        '''Return every distinct import cycle, sorted.'''
        return sorted(set(self._find_cycles()))

    def get_all_imports(self):
        '''Yield every import edge of every module in the package.'''
        for module in self.values():
            yield from module

    def __str__(self):
        '''Render the whole import graph as a Graphviz DOT digraph.'''
        # Modules are sets and os.walk order is arbitrary; sort the edges so
        # the output is stable from run to run.
        edges = sorted(self.get_all_imports(),
            key=lambda edge: (edge.importing.name, edge.imported.name))
        lines = ['digraph "import" {', '    charset="utf-8"', '    rankdir=BT']
        lines.extend('    {}'.format(edge) for edge in edges)
        lines.append('}')
        return '\n'.join(lines) + '\n'
|
|
|
|
def main(argv=None):
    '''Write a package's import graph as DOT to stdout and its cycles to stderr.

    :param argv: command line to use, ``sys.argv`` by default; ``argv[1]``
        is the package directory. Extra arguments are ignored.

    Each cycle is written as one ``a -> b -> a`` line. Exits with status 2
    and a usage message when no package directory is given.
    '''
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        prog = argv[0] if argv else 'import-graph'
        sys.stderr.write('usage: {} PACKAGE_DIR\n'.format(prog))
        sys.exit(2)

    package = Package(argv[1])
    sys.stdout.write(str(package))

    for cycle in package.find_cycles():
        sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n')
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point; importing this file only defines names.
    main()
|