import-graph 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #!/usr/bin/env python3
  2. import itertools
  3. import os
  4. import re
  5. import sys
  6. re_import = re.compile(r'^import (.*?)$', re.M)
  7. re_import_from = re.compile(r'^from (.*?) import .*?$', re.M)
  8. class Import(object):
  9. defstyle = {'arrowhead': 'open', 'arrowtail':'none'}
  10. def __init__(self, importing, imported, **kwargs):
  11. self.importing = importing
  12. self.imported = imported
  13. self.style = self.defstyle.copy()
  14. self.style.update(kwargs)
  15. def __str__(self):
  16. return '{}"{}" -> "{}" [{}];'.format(
  17. ('//' if self.commented else ''),
  18. self.importing,
  19. self.imported,
  20. ', '.join('{}="{}"'.format(*i) for i in self.style.items()))
  21. def __eq__(self, other):
  22. return (self.importing.name, self.imported.name) \
  23. == (other.importing.name, other.imported.name)
  24. def __hash__(self):
  25. return hash((self.importing.name, self.imported.name))
  26. @property
  27. def commented(self):
  28. if self.style.get('color', '') != 'red':
  29. return True
  30. # for i in (self.importing, self.imported):
  31. # if i.name.startswith('qubes.tests'): return True
  32. # if i.name.startswith('qubes.tools'): return True
  33. class Module(set):
  34. def __init__(self, package, path):
  35. self.package = package
  36. self.path = path
  37. def process(self):
  38. with open(os.path.join(self.package.root, self.path)) as fh:
  39. data = fh.read()
  40. data.replace('\\\n', ' ')
  41. for imported in re_import.findall(data):
  42. try:
  43. imported = self.package[imported]
  44. except KeyError:
  45. continue
  46. self.add(Import(self, imported))
  47. for imported in re_import_from.findall(data):
  48. try:
  49. imported = self.package[imported]
  50. except KeyError:
  51. continue
  52. self.add(Import(self, imported, style='dotted'))
  53. def __getitem__(self, key):
  54. for i in self:
  55. if i.imported == key:
  56. return i
  57. raise KeyError(key)
  58. @property
  59. def name(self):
  60. names = os.path.splitext(self.path)[0].split('/')
  61. names.insert(0, self.package.name)
  62. if names[-1] == '__init__':
  63. del names[-1]
  64. return '.'.join(names)
  65. def __hash__(self):
  66. return hash(self.name)
  67. def __str__(self):
  68. return self.name
  69. def __repr__(self):
  70. return '<{} {!r}>'.format(self.__class__.__name__, self.name)
  71. def __lt__(self, other):
  72. return self.name < other.name
  73. def __eq__(self, other):
  74. return self.name == other.name
  75. class Cycle(tuple):
  76. def __new__(cls, modules):
  77. i = modules.index(sorted(modules)[0])
  78. # sys.stderr.write('modules={!r} i={!r}\n'.format(modules, i))
  79. return super(Cycle, cls).__new__(cls, modules[i:] + modules[:i+1])
  80. # def __lt__(self, other):
  81. # if len(self) < len(other):
  82. # return True
  83. # elif len(self) > len(other):
  84. # return False
  85. #
  86. # return super(Cycle, self).__lt__(other)
  87. class Package(dict):
  88. def __init__(self, root):
  89. super(Package, self).__init__()
  90. self.root = root
  91. for dirpath, dirnames, filenames in os.walk(self.root):
  92. for filename in filenames:
  93. if not os.path.splitext(filename)[1] == '.py':
  94. continue
  95. module = Module(self,
  96. os.path.relpath(os.path.join(dirpath, filename), self.root))
  97. self[module.name] = module
  98. for name, module in self.items():
  99. module.process()
  100. @property
  101. def name(self):
  102. return os.path.basename(self.root.rstrip(os.path.sep))
  103. def _find_cycles(self):
  104. # stolen from codereview.stackexchange.com/questions/86021 and hacked
  105. path = []
  106. visited = set()
  107. def visit(module):
  108. # if module in visited:
  109. # return
  110. # visited.add(module)
  111. path.append(module)
  112. for i in module:
  113. if i.imported in path:
  114. yield Cycle(path[path.index(i.imported):])
  115. else:
  116. yield from visit(i.imported)
  117. path.pop()
  118. for v in self.values():
  119. yield from visit(v)
  120. def find_cycles(self):
  121. return list(sorted(set(self._find_cycles())))
  122. def get_all_imports(self):
  123. for module in self.values():
  124. yield from module
  125. def __str__(self):
  126. return '''\n
  127. digraph "import" {{
  128. charset="utf-8"
  129. rankdir=BT
  130. {}
  131. }}
  132. '''.format('\n'.join(str(i) for i in self.get_all_imports()))
  133. def main():
  134. package = Package(sys.argv[1])
  135. for cycle in package.find_cycles():
  136. for i in range(len(cycle) - 1):
  137. edge = cycle[i][cycle[i+1]]
  138. edge.style['color'] = 'red'
  139. sys.stderr.write(' -> '.join(str(module) for module in cycle) + '\n')
  140. sys.stdout.write(str(package))
  141. if __name__ == '__main__':
  142. main()