import threading
import socket
import logging
from random import getrandbits
from time import sleep, time
from multiprocessing.pool import ThreadPool
from multiprocessing import Lock
from threading import Thread


class MassTraceroute:
    """Trace the route to many targets concurrently using UDP probes.

    Every target gets its own UDP source port, and every probe sent to that
    target gets its own destination port. ICMP errors quote the start of the
    offending datagram, so the (source port, destination port) pair found in
    an ICMP reply identifies both the target and the TTL of the probe that
    triggered it.

    Bookkeeping attributes:
        sources: ``{srcport: {"ip": target, dstport: ttl, ...}}``.
        routes:  ``{target: {ttl: address returned by recvfrom}}``.
        sender:  ``{slot: {"socket": UDP socket, "lock": Lock}}``.
    """

    # Lowest port handed out by _random_port(); see that method for why.
    _MIN_PORT = 1024
    # How many different source ports to try before giving up on a target.
    _BIND_ATTEMPTS = 16
    # ICMP message types that quote the original datagram of a probe.
    _ICMP_DEST_UNREACHABLE = 3
    _ICMP_TIME_EXCEEDED = 11

    def __init__(self, socket_timeout, listener_timeout, threads, targets, maxhops):
        """Create the thread pool, the ICMP listener and the sender slots.

        Args:
            socket_timeout: Stored on the instance; not applied to any socket
                by this class.
            listener_timeout: Timeout in seconds for each listener read, and
                the idle period after the last probe before the listener stops.
            threads: Number of pool workers and of sender slots.
            targets: Iterable of destination addresses to trace.
            maxhops: Highest TTL probed for each target.
        """
        logging.info("Setting up class")
        self.wait = 0.2
        self.maxhops = maxhops
        self.listener_timeout = listener_timeout
        self.socket_timeout = socket_timeout
        self.targets = targets
        self.threads = threads
        self.pool = ThreadPool(processes=self.threads)
        self.routes = dict()
        logging.info("creating listener socket")
        self.listener = self.init_listener()
        self.sources = dict()
        # Guards source-port reservation in self.sources across pool threads.
        self._sources_lock = threading.Lock()
        self.last_send = 0
        self.sender = dict()
        for thread in range(threads):
            logging.info("Creating #{} sender socket".format(thread))
            self.sender[thread] = dict()
            self.sender[thread]['socket'] = self.init_sender()
            self.sender[thread]['lock'] = Lock()

    def main(self):
        """Start the listener and sender threads and return immediately."""
        # The listener loop runs while last_send is recent; stamp it now so the
        # listener does not exit before the first probe has been sent.
        self.last_send = time()
        logging.info("Starting listener thread")
        Thread(target=self.process_listener).start()
        logging.info("Starting sender thread")
        Thread(target=self.process_sender).start()

    def init_listener(self):
        """Return a raw ICMP socket with ``listener_timeout`` applied."""
        listener = socket.socket(
            family=socket.AF_INET,
            type=socket.SOCK_RAW,
            proto=socket.IPPROTO_ICMP
        )
        listener.settimeout(self.listener_timeout)
        listener.bind(('', 0))
        return listener

    def process_listener(self):
        """Record the replying address of every ICMP answer to a known probe.

        Runs until ``listener_timeout`` seconds have passed since the last
        probe was sent. Replies are stored in ``self.routes[target][ttl]``.
        """
        while (self.last_send + self.listener_timeout) > time():
            try:
                data, addr = self.listener.recvfrom(1024)
            except socket.error:
                # Timeout or transient receive error: nothing to parse, so
                # re-check the loop condition instead of using stale data.
                continue

            ports = self._parse_probe_ports(data)
            if ports is None:
                logging.debug("Ignoring ICMP packet that does not quote a UDP probe")
                continue

            srcport, dstport = ports
            probes = self.sources.get(srcport)
            if probes is not None and dstport in probes:
                hop = probes[dstport]
                ip = probes["ip"]
                self.routes.setdefault(ip, dict())[hop] = addr
            else:
                logging.error("Received package with srcport or dstport non existsant?")

    def process_sender(self):
        """Run ``send`` for every target on the thread pool."""
        self.pool.map(self.send, self.targets)

    def init_sender(self):
        """Return a new, unbound UDP socket used to emit probes."""
        sender = socket.socket(
            family=socket.AF_INET,
            type=socket.SOCK_DGRAM,
            proto=socket.IPPROTO_UDP
        )
        return sender

    def send(self, dstip):
        """Send one UDP probe per TTL from 1 to ``maxhops`` towards ``dstip``.

        Args:
            dstip: Destination address of the target to trace.

        Raises:
            OSError: If no source port could be bound or a probe could not be
                sent. The sender slot is released in either case.
        """
        slot_id = self._acquire_sender_slot()
        slot = self.sender[slot_id]
        try:
            # A bound socket cannot be bound again, so each target gets a
            # fresh socket in place of the slot's previous one.
            slot["socket"].close()
            sock, srcport = self._open_bound_sender(dstip)
            slot["socket"] = sock

            self.routes.setdefault(dstip, dict())
            probes = self.sources[srcport]
            for ttl in range(1, self.maxhops + 1):
                sock.setsockopt(socket.IPPROTO_IP, socket.IP_TTL, ttl)
                dstport = self._random_port()
                while dstport in probes:
                    dstport = self._random_port()
                probes[dstport] = ttl
                logging.debug("Sending packet srcport = {}, dstport = {}, ttl = {}, dstip = {}".format(srcport, dstport, ttl, dstip))
                sock.sendto(b'', (dstip, dstport))
                self.last_send = time()
        finally:
            slot["lock"].release()

    def _acquire_sender_slot(self):
        """Return the index of a sender slot whose lock is now held."""
        while True:
            for slot_id in range(self.threads):
                # Non-blocking: a busy slot must not stall the search.
                if self.sender[slot_id]["lock"].acquire(False):
                    return slot_id
            sleep(self.wait)

    def _open_bound_sender(self, dstip):
        """Return ``(socket, srcport)`` bound to a port reserved for ``dstip``."""
        for _ in range(self._BIND_ATTEMPTS):
            srcport = self._reserve_source_port(dstip)
            sock = self.init_sender()
            try:
                sock.bind(('', srcport))
            except OSError:
                # Port unavailable on this host: drop the reservation and retry.
                sock.close()
                with self._sources_lock:
                    self.sources.pop(srcport, None)
                continue
            return sock, srcport
        raise OSError("Could not bind a source port for {}".format(dstip))

    def _reserve_source_port(self, dstip):
        """Register an unused source port for ``dstip`` and return it."""
        with self._sources_lock:
            srcport = self._random_port()
            while srcport in self.sources:
                srcport = self._random_port()
            self.sources[srcport] = {"ip": dstip}
        return srcport

    @classmethod
    def _random_port(cls):
        """Return a random port between ``_MIN_PORT`` and 65535.

        Port 0 would make ``bind`` pick an arbitrary port and break the
        port-to-target mapping, so low ports are skipped entirely.
        """
        port = getrandbits(16)
        while port < cls._MIN_PORT:
            port = getrandbits(16)
        return port

    @classmethod
    def _parse_probe_ports(cls, data):
        """Return ``(srcport, dstport)`` quoted in an ICMP error, else ``None``.

        Assumes ``data`` starts with the outer IPv4 header, followed by the
        8-byte ICMP header and the quoted IPv4 + UDP headers of the probe
        (the layout the original fixed offsets 48/50 relied on).
        """
        if len(data) < 1:
            return None
        outer_len = (data[0] & 0x0F) * 4
        inner_start = outer_len + 8
        if outer_len < 20 or len(data) <= inner_start:
            return None
        if data[outer_len] not in (cls._ICMP_TIME_EXCEEDED, cls._ICMP_DEST_UNREACHABLE):
            return None

        inner_len = (data[inner_start] & 0x0F) * 4
        udp_start = inner_start + inner_len
        if inner_len < 20 or len(data) < udp_start + 4:
            return None
        # Byte 9 of an IPv4 header is the protocol number.
        if data[inner_start + 9] != socket.IPPROTO_UDP:
            return None

        srcport = int.from_bytes(data[udp_start:udp_start + 2], byteorder="big")
        dstport = int.from_bytes(data[udp_start + 2:udp_start + 4], byteorder="big")
        return srcport, dstport