import-graph 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. #!/usr/bin/env python3
  2. import os
  3. import re
  4. import sys
  5. re_import = re.compile(r'^import (.*?)$', re.M)
  6. re_import_from = re.compile(r'^from (.*?) import .*?$', re.M)
  7. class Import(object):
  8. style = 'arrowhead="open", arrowtail="none"'
  9. def __init__(self, importing, imported):
  10. self.importing = importing
  11. self.imported = imported
  12. def __str__(self):
  13. return '{}"{}" -> "{}" [{}];'.format(
  14. ('#' if self.commented else ''),
  15. self.importing, self.imported, self.style)
  16. def __eq__(self, other):
  17. return (self.importing.name, self.imported.name) \
  18. == (other.importing.name, other.imported.name)
  19. def __hash__(self):
  20. return hash((self.importing.name, self.imported.name))
  21. @property
  22. def commented(self):
  23. for i in (self.importing, self.imported):
  24. if i.name.startswith('qubes.tests'): return True
  25. if i.name.startswith('qubes.tools'): return True
  26. class ImportFrom(Import):
  27. style = 'arrowhead="open", arrowtail="none", color="red"'
  28. class Module(set):
  29. def __init__(self, package, path):
  30. self.package = package
  31. self.path = path
  32. def process(self):
  33. with open(os.path.join(self.package.root, self.path)) as fh:
  34. data = fh.read()
  35. data.replace('\\\n', ' ')
  36. for imported in re_import.findall(data):
  37. try:
  38. imported = self.package[imported]
  39. except KeyError:
  40. continue
  41. self.add(Import(self, imported))
  42. for imported in re_import_from.findall(data):
  43. try:
  44. imported = self.package[imported]
  45. except KeyError:
  46. continue
  47. self.add(ImportFrom(self, imported))
  48. @property
  49. def name(self):
  50. names = os.path.splitext(self.path)[0].split('/')
  51. names.insert(0, self.package.name)
  52. if names[-1] == '__init__':
  53. del names[-1]
  54. return '.'.join(names)
  55. def __hash__(self):
  56. return hash(self.name)
  57. def __str__(self):
  58. return self.name
  59. def __repr__(self):
  60. return '<{} {!r}>'.format(self.__class__.__name__, self.name)
  61. def __lt__(self, other):
  62. return self.name < other.name
  63. def __eq__(self, other):
  64. return self.name == other.name
  65. class Cycle(tuple):
  66. def __new__(cls, modules):
  67. i = modules.index(sorted(modules)[0])
  68. # sys.stderr.write('modules={!r} i={!r}\n'.format(modules, i))
  69. return super(Cycle, cls).__new__(cls, modules[i:] + modules[:i+1])
  70. # def __lt__(self, other):
  71. # if len(self) < len(other):
  72. # return True
  73. # elif len(self) > len(other):
  74. # return False
  75. #
  76. # return super(Cycle, self).__lt__(other)
  77. class Package(dict):
  78. def __init__(self, root):
  79. super(Package, self).__init__()
  80. self.root = root
  81. for dirpath, dirnames, filenames in os.walk(self.root):
  82. for filename in filenames:
  83. if not os.path.splitext(filename)[1] == '.py':
  84. continue
  85. module = Module(self,
  86. os.path.relpath(os.path.join(dirpath, filename), self.root))
  87. self[module.name] = module
  88. for name, module in self.items():
  89. module.process()
  90. @property
  91. def name(self):
  92. return os.path.basename(self.root.rstrip(os.path.sep))
  93. def _find_cycles(self):
  94. # stolen from codereview.stackexchange.com/questions/86021 and hacked
  95. path = []
  96. visited = set()
  97. def visit(module):
  98. # if module in visited:
  99. # return
  100. # visited.add(module)
  101. path.append(module)
  102. for i in module:
  103. if i.imported in path:
  104. yield Cycle(path[path.index(i.imported):])
  105. else:
  106. yield from visit(i.imported)
  107. path.pop()
  108. for v in self.values():
  109. yield from visit(v)
  110. def find_cycles(self):
  111. return list(sorted(set(self._find_cycles())))
  112. def get_all_imports(self):
  113. for module in self.values():
  114. yield from module
  115. def __str__(self):
  116. return '''\n
  117. digraph "import" {{
  118. charset="utf-8"
  119. rankdir=BT
  120. {}
  121. }}
  122. '''.format('\n'.join(str(i) for i in self.get_all_imports()))
  123. def main():
  124. package = Package(sys.argv[1])
  125. sys.stdout.write(str(package))
  126. for cycle in package.find_cycles():
  127. sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n')
  128. if __name__ == '__main__':
  129. main()