168 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			168 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
#!/usr/bin/env python3
 | 
						|
 | 
						|
import os
 | 
						|
import re
 | 
						|
import sys
 | 
						|
 | 
						|
# Matches an ``import ...`` statement that starts at column 0; group 1 is the
# raw text after the keyword (e.g. "os", or "a, b" for a multi-name import).
re_import = re.compile(r'^import (.*?)$', flags=re.MULTILINE)
# Matches a ``from X import ...`` statement that starts at column 0; group 1
# is the module part X.
re_import_from = re.compile(r'^from (.*?) import .*?$', flags=re.MULTILINE)
 | 
						|
 | 
						|
class Import(object):
    """A directed "importing -> imported" edge between two modules.

    ``str(edge)`` renders it as one Graphviz DOT edge statement.  Two edges
    are equal (and hash equally) when they join modules with the same names,
    whatever their concrete subclass is.
    """

    # DOT attribute list used when rendering this kind of edge.
    style = 'arrowhead="open", arrowtail="none"'

    def __init__(self, importing, imported):
        self.importing = importing
        self.imported = imported

    def __str__(self):
        # Edges touching qubes.tests / qubes.tools get a leading '#'; Graphviz
        # discards lines starting with '#', so they stay in the file but are
        # not drawn.
        return '{}"{}" -> "{}" [{}];'.format(
            ('#' if self.commented else ''),
            self.importing, self.imported, self.style)

    def _key(self):
        """Identity of the edge: the (importing, imported) module names."""
        return (self.importing.name, self.imported.name)

    def __eq__(self, other):
        # Let Python fall back to its default handling for foreign types
        # instead of failing on a missing attribute.
        if not isinstance(other, Import):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self):
        return hash(self._key())

    @property
    def commented(self):
        """True if either endpoint is under qubes.tests or qubes.tools."""
        return any(module.name.startswith(('qubes.tests', 'qubes.tools'))
                   for module in (self.importing, self.imported))
 | 
						|
 | 
						|
 | 
						|
class ImportFrom(Import):
    """Edge created by a ``from X import ...`` statement.

    Behaves exactly like :class:`Import`; only the rendered DOT attributes
    differ (the edge is drawn in red).
    """

    style = 'arrowhead="open", arrowtail="none", color="red"'
 | 
						|
 | 
						|
 | 
						|
class Module(set):
    """One ``.py`` file of a package.

    The set holds the Import / ImportFrom edges from this module to other
    modules of the same package; :meth:`process` fills it.  Modules compare,
    hash and sort by their dotted :attr:`name`, not by their contents.
    """

    def __init__(self, package, path):
        # package: owning package object -- supplies .root, .name and lookup
        #          of modules by dotted name (package[name]).
        # path: path of the source file relative to package.root.
        self.package = package
        self.path = path

    @staticmethod
    def _split_import_clause(clause):
        """Return the module names from the text following ``import``.

        Handles ``a, b`` lists, ``a as b`` aliases, a trailing ``# comment``
        and a ``;``-separated following statement (which is dropped), e.g.
        ``'x.y as z, w  # note'`` -> ``['x.y', 'w']``.
        """
        clause = clause.split('#', 1)[0].split(';', 1)[0]
        names = []
        for part in clause.split(','):
            words = part.split()
            if words:
                # First word is the module; anything after is "as alias".
                names.append(words[0])
        return names

    def process(self):
        """Scan this module's source and record edges to in-package modules.

        Imported names that are not keys of ``self.package`` (stdlib,
        third-party, relative imports, ...) are ignored.
        """
        with open(os.path.join(self.package.root, self.path),
                encoding='utf-8') as fh:
            data = fh.read()
        # Join backslash-continued lines so a statement split across lines
        # still matches the line-based regexes.  str.replace returns a new
        # string, so the result must be reassigned.
        data = data.replace('\\\n', ' ')

        for clause in re_import.findall(data):
            for imported in self._split_import_clause(clause):
                try:
                    imported = self.package[imported]
                except KeyError:
                    continue
                self.add(Import(self, imported))

        for imported in re_import_from.findall(data):
            try:
                # strip(): a joined continuation can leave extra spaces
                # between the module name and the "import" keyword.
                imported = self.package[imported.strip()]
            except KeyError:
                continue
            self.add(ImportFrom(self, imported))

    @property
    def name(self):
        """Dotted module name; a package's ``__init__`` maps to the package."""
        stem = os.path.splitext(self.path)[0].replace(os.sep, '/')
        names = [self.package.name] + stem.split('/')
        if names[-1] == '__init__':
            names.pop()
        return '.'.join(names)

    def __hash__(self):
        return hash(self.name)

    def __str__(self):
        return self.name

    def __repr__(self):
        return '<{} {!r}>'.format(self.__class__.__name__, self.name)

    def __lt__(self, other):
        return self.name < other.name

    def __eq__(self, other):
        if not isinstance(other, Module):
            return NotImplemented
        return self.name == other.name

    def __ne__(self, other):
        # set supplies its own __ne__ (comparing set contents); override it
        # so that != stays the negation of the name-based __eq__.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result
 | 
						|
 | 
						|
 | 
						|
class Cycle(tuple):
    """An import cycle in canonical, rotation-independent form.

    ``modules`` lists the modules on the cycle in traversal order.  The
    sequence is rotated so its smallest member comes first, and that member
    is repeated at the end to close the loop.  The same cycle found from
    different starting points therefore yields equal Cycle objects.
    Cycles order like plain tuples, element by element.
    """

    def __new__(cls, modules):
        start = modules.index(sorted(modules)[0])
        rotated = modules[start:] + modules[:start]
        closed = rotated + rotated[:1]
        return super(Cycle, cls).__new__(cls, closed)
 | 
						|
 | 
						|
 | 
						|
class Package(dict):
    """All modules of one Python package, keyed by dotted module name.

    Construction walks ``root`` for ``*.py`` files, creates a Module for
    each one and then has every module collect its intra-package imports.
    ``str(package)`` renders the import graph as a Graphviz DOT digraph.
    """

    def __init__(self, root):
        super(Package, self).__init__()
        self.root = root

        for dirpath, dirnames, filenames in os.walk(self.root):
            # Walk in sorted order so module discovery (and dict order)
            # does not depend on the filesystem's listing order.
            dirnames.sort()
            for filename in sorted(filenames):
                if os.path.splitext(filename)[1] != '.py':
                    continue
                module = Module(self,
                    os.path.relpath(os.path.join(dirpath, filename), self.root))
                self[module.name] = module

        # Process only after every module is registered, so imports can be
        # resolved against the complete name table.
        for module in self.values():
            module.process()

    @property
    def name(self):
        """Package name: the last component of ``root``."""
        return os.path.basename(self.root.rstrip(os.path.sep))

    def _find_cycles(self):
        """Yield a Cycle for every back edge met during a depth-first search.

        Adapted from codereview.stackexchange.com/questions/86021.  No global
        "visited" set is kept: every simple import path from every module is
        explored.  Each cycle is therefore yielded many times (find_cycles
        de-duplicates), and the running time can grow exponentially with
        the number of distinct import paths.
        """
        path = []

        def visit(module):
            path.append(module)
            for i in module:
                if i.imported in path:
                    # Back edge: the part of the current path starting at
                    # the imported module forms a cycle.
                    yield Cycle(path[path.index(i.imported):])
                else:
                    yield from visit(i.imported)
            path.pop()

        for v in self.values():
            yield from visit(v)

    def find_cycles(self):
        """Return the distinct import cycles as a sorted list."""
        return sorted(set(self._find_cycles()))

    def get_all_imports(self):
        """Yield every import edge of every module in the package."""
        for module in self.values():
            yield from module

    def __str__(self):
        # Each module is a set, and import edges hash by their (string)
        # names.  Sort the edges so the generated DOT text is reproducible
        # between runs.
        edges = sorted(self.get_all_imports(),
            key=lambda i: (i.importing.name, i.imported.name))
        return '''\n
digraph "import" {{
charset="utf-8"
rankdir=BT
{}
}}
'''.format('\n'.join(str(i) for i in edges))
 | 
						|
 | 
						|
def main(argv=None):
    """Write the package's DOT import graph to stdout and its cycles to stderr.

    argv defaults to sys.argv; argv[1] is the package root directory (any
    further arguments are ignored).  Exits with a message when the root
    argument is missing or is not a directory.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        sys.exit('usage: {} PACKAGE_ROOT'.format(
            argv[0] if argv else sys.argv[0]))
    root = argv[1]
    # os.walk silently yields nothing for a bad path, which would otherwise
    # produce an empty graph instead of an error.
    if not os.path.isdir(root):
        sys.exit('{}: not a directory'.format(root))

    package = Package(root)
    sys.stdout.write(str(package))

    for cycle in package.find_cycles():
        sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n')
 | 
						|
 | 
						|
 | 
						|
# Script entry point; main() returns None, so the exit status is 0 on success.
if __name__ == '__main__':
    sys.exit(main())
 |