123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- #!/usr/bin/env python3
- import itertools
- import os
- import re
- import sys
# Statement patterns applied line by line (MULTILINE) to a module's source.
# Both are anchored at column 0, so indented (function-local, conditional)
# imports are not matched.  The group captures the text naming the module.
re_import = re.compile(r'^import (.*?)$', flags=re.MULTILINE)
re_import_from = re.compile(r'^from (.*?) import .*?$', flags=re.MULTILINE)
class Import(object):
    """A directed edge "importing imports imported" in the DOT graph.

    Two edges are equal (and hash equal) when they join the same pair of
    module names, so a Module -- a set of Import objects -- keeps only the
    first edge it sees for each target.
    """

    # Default Graphviz edge attributes.  Each instance gets its own copy so
    # per-edge tweaks (e.g. colouring cycle edges red) never touch this dict.
    defstyle = {'arrowhead': 'open', 'arrowtail': 'none'}

    def __init__(self, importing, imported, **kwargs):
        """Create an edge from *importing* to *imported*.

        :param importing: module object doing the import (needs ``name``).
        :param imported: module object being imported (needs ``name``).
        :param kwargs: extra Graphviz edge attributes, overriding the
            defaults, e.g. ``style='dotted'``.
        """
        self.importing = importing
        self.imported = imported
        self.style = dict(self.defstyle, **kwargs)

    def __str__(self):
        """Render the edge as one DOT statement, ``//``-prefixed if hidden."""
        return '{}"{}" -> "{}" [{}];'.format(
            ('//' if self.commented else ''),
            self.importing,
            self.imported,
            ', '.join('{}="{}"'.format(*i) for i in self.style.items()))

    def __eq__(self, other):
        if not isinstance(other, Import):
            return NotImplemented
        return ((self.importing.name, self.imported.name)
                == (other.importing.name, other.imported.name))

    def __hash__(self):
        return hash((self.importing.name, self.imported.name))

    @property
    def commented(self):
        """Whether the edge is written out as a ``//`` comment.

        Only edges whose ``color`` is ``red`` (main() colours the edges of
        detected import cycles red) are drawn; all others are commented out.
        Always returns a real bool.
        """
        # NOTE: further edges (e.g. those touching 'qubes.tests' or
        # 'qubes.tools' modules) could be hidden here by a name-prefix check.
        return self.style.get('color') != 'red'
class Module(set):
    """One ``.py`` file of a Package, holding its outgoing Import edges.

    A Module is a set of Import objects whose ``importing`` side is the
    module itself.  Modules compare, hash and sort by dotted name only, so
    they can be found in lists, sets and cycle tuples regardless of which
    edges they have collected.
    """

    def __init__(self, package, path):
        """Create a module with no edges yet (call process() to fill it).

        :param package: the owning Package; provides ``root``, ``name`` and
            dotted-name -> Module lookup via ``package[name]``.
        :param path: path of the source file relative to ``package.root``.
        """
        super(Module, self).__init__()
        self.package = package
        self.path = path

    def process(self):
        """Parse this module's source and add one edge per in-package import.

        ``import X`` statements give plain edges and ``from X import ...``
        statements give dotted edges.  Edges are de-duplicated by module
        pair, so when both forms target the same module the plain edge (added
        first) is kept.  Names are looked up verbatim in the package, so
        relative imports and imports of modules outside the package are
        skipped.
        """
        import tokenize  # tokenize.open() honours PEP 263 coding cookies

        source_path = os.path.join(self.package.root, self.path)
        with tokenize.open(source_path) as fh:
            # Join backslash line continuations so a continued import is
            # matched as one line (keep the result: str.replace returns a
            # new string rather than modifying in place).
            data = fh.read().replace('\\\n', ' ')
        for spec in re_import.findall(data):
            for name in self._imported_names(spec):
                self._add_edge(name)
        for name in re_import_from.findall(data):
            self._add_edge(name.strip(), style='dotted')

    @staticmethod
    def _imported_names(spec):
        """Yield the module names listed in the text after ``import``.

        Handles ``a, b`` lists, ``a as b`` aliases, trailing ``# comments``
        and a ``; other statement`` tail on the same line.
        """
        spec = spec.partition('#')[0].partition(';')[0]
        for part in spec.split(','):
            words = part.split()
            if words:
                yield words[0]

    def _add_edge(self, name, **style):
        """Add an Import edge to module *name* if it belongs to the package."""
        try:
            imported = self.package[name]
        except KeyError:
            return  # stdlib, third-party or unresolvable name
        self.add(Import(self, imported, **style))

    def __getitem__(self, key):
        """Return this module's edge to module *key*.

        :raises KeyError: if this module has no edge to *key*.
        """
        for edge in self:
            if edge.imported == key:
                return edge
        raise KeyError(key)

    @property
    def name(self):
        """Dotted module name, e.g. ``sub/mod.py`` in ``pkg`` -> ``pkg.sub.mod``.

        The package name is prepended and a trailing ``__init__`` dropped,
        so a package's ``__init__.py`` is named after the package itself.
        """
        stem = os.path.splitext(os.path.normpath(self.path))[0]
        # Split on the platform separator (what os.path.relpath produces),
        # not a hard-coded '/', so names are also right on Windows.
        names = [self.package.name] + stem.split(os.sep)
        if names[-1] == '__init__':
            del names[-1]
        return '.'.join(names)

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return self.name

    def __repr__(self):
        return '<{} {!r}>'.format(self.__class__.__name__, self.name)

    # set defines every rich comparison as a content/subset test; override
    # all six so that Modules consistently compare by name instead.
    def __eq__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name == other.name

    def __ne__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name != other.name

    def __lt__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name < other.name

    def __le__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name <= other.name

    def __gt__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name > other.name

    def __ge__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name >= other.name
class Cycle(tuple):
    """A closed import cycle in canonical form.

    The modules are rotated so that the smallest one comes first, and that
    module is repeated at the end: ``[b, c, a]`` becomes ``(a, b, c, a)``.
    Different rotations of the same loop therefore compare and hash equal,
    which lets the set in Package.find_cycles() drop duplicates.  Ordering
    between cycles is plain tuple ordering.
    """

    def __new__(cls, modules):
        """Build the canonical cycle from the open sequence *modules*."""
        smallest = sorted(modules)[0]
        start = modules.index(smallest)
        ordered = modules[start:] + modules[:start]
        closed = ordered + ordered[:1]  # repeat the first module to close it
        return super(Cycle, cls).__new__(cls, closed)
class Package(dict):
    """All modules below a package directory, keyed by dotted module name."""

    def __init__(self, root):
        """Collect every ``.py`` file under *root* and parse its imports.

        :param root: path of the package directory; its basename becomes
            the top-level package name.
        """
        super(Package, self).__init__()
        self.root = root
        for dirpath, _dirnames, filenames in os.walk(self.root):
            for filename in filenames:
                if os.path.splitext(filename)[1] != '.py':
                    continue
                module = Module(self, os.path.relpath(
                    os.path.join(dirpath, filename), self.root))
                self[module.name] = module
        # Parse only after every module is registered, so imports of
        # modules found later in the walk still resolve.
        for module in self.values():
            module.process()

    @property
    def name(self):
        """Top-level package name: the basename of the root directory."""
        # abspath normalises "pkg/", "./pkg" and "." alike; plain
        # basename(".") would give "." and yield names like "..mod".
        return os.path.basename(os.path.abspath(self.root))

    def _find_cycles(self):
        """Yield every elementary import cycle as a Cycle, possibly repeatedly.

        Depth-first search from every module, tracking the current import
        chain; an edge back into the chain closes a cycle.  There is no
        global "visited" set: modules are re-explored from every start so
        that all elementary cycles are reported.  The price is that every
        simple import path is enumerated, which can be slow on densely
        connected packages.
        """
        # Adapted from codereview.stackexchange.com/questions/86021.
        path = []        # current chain of importing modules
        on_path = set()  # same modules, for O(1) membership tests

        def visit(module):
            path.append(module)
            on_path.add(module)
            for edge in module:
                target = edge.imported
                if target in on_path:
                    yield Cycle(path[path.index(target):])
                else:
                    yield from visit(target)
            on_path.discard(module)
            path.pop()

        for start in self.values():
            yield from visit(start)

    def find_cycles(self):
        """Return the distinct import cycles, sorted."""
        return sorted(set(self._find_cycles()))

    def get_all_imports(self):
        """Yield every Import edge, ordered by module then by target name.

        A fixed order keeps the DOT output reproducible between runs
        (set iteration order depends on string hash randomisation).
        """
        for module in sorted(self.values()):
            for edge in sorted(module, key=lambda e: e.imported.name):
                yield edge

    def __str__(self):
        """Render the whole package as a Graphviz DOT digraph."""
        edges = '\n'.join(str(edge) for edge in self.get_all_imports())
        return ('\n\ndigraph "import" {\n'
                'charset="utf-8"\n'
                'rankdir=BT\n'
                + edges +
                '\n}\n')
def main(argv=None):
    """Write the package's import graph as DOT, highlighting import cycles.

    Each cycle is listed on stderr as ``a -> b -> a`` and its edges are
    coloured red (which makes them the only uncommented edges); the DOT
    digraph is written to stdout.

    :param argv: command-line arguments including the program name;
        defaults to ``sys.argv``.  ``argv[1]`` is the package directory.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        sys.exit('usage: {} PACKAGE_DIR'.format(argv[0] if argv else '<script>'))
    package = Package(argv[1])
    for cycle in package.find_cycles():
        # A cycle repeats its first module at the end, so consecutive pairs
        # cover every edge of the loop.
        for importing, imported in zip(cycle, cycle[1:]):
            importing[imported].style['color'] = 'red'
        sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n')
    sys.stdout.write(str(package))
if __name__ == '__main__':
    # Run only when executed as a script, never on import.
    main()
|