import threading
import socket
import logging
from random import getrandbits
from time import sleep, time
from multiprocessing.pool import ThreadPool
from multiprocessing import Lock
from threading import Thread

class MassTraceroute:
	"""Run UDP traceroutes to many targets concurrently.

	For every target, one empty UDP probe is sent per TTL (1..maxhops) from a
	source port reserved for that target, each probe to a destination port
	reserved for that hop. A single raw ICMP socket collects the replies and
	maps them back to (target, hop) through the UDP ports quoted inside them.
	Results accumulate in ``self.routes`` as ``{dstip: {ttl: addr}}`` where
	``addr`` is what ``recvfrom`` returned for the reply.
	Creating the raw ICMP socket typically requires root / CAP_NET_RAW.
	"""

	def __init__(self, socket_timeout, listener_timeout, threads, targets, maxhops):
		"""Create the worker pool, the ICMP listener and one sender slot per thread.

		:param socket_timeout: timeout applied to each UDP sender socket.
		:param listener_timeout: seconds the listener keeps running after the
			last probe was sent (also its per-recv timeout).
		:param threads: number of concurrent sender workers / slots.
		:param targets: iterable of destination IPv4 addresses.
		:param maxhops: highest TTL probed per target.
		"""
		logging.info("Setting up class")
		self.wait = 0.2
		self.maxhops = maxhops
		self.listener_timeout = listener_timeout
		self.socket_timeout = socket_timeout
		self.targets = targets
		self.threads = threads
		self.pool = ThreadPool(processes=self.threads)
		self.routes = dict()
		logging.info("creating listener socket")
		self.listener = self.init_listener()
		# srcport -> {"ip": dstip, dstport: ttl, ...}
		self.sources = dict()
		# Guards reservation of new srcports across pool threads.
		self._sources_lock = threading.Lock()
		self.last_send = 0
		self.sender = dict()
		for thread in range(threads):
			logging.info("Creating #{} sender socket".format(thread))
			self.sender[thread] = dict()
			self.sender[thread]['socket'] = self.init_sender()
			# Slots are only shared between threads, so a threading lock suffices.
			self.sender[thread]['lock'] = threading.Lock()

	def main(self):
		"""Start the listener and sender threads and return immediately."""
		# Give the listener a full listener_timeout window before the first
		# probe goes out; with last_send == 0 it would exit at once.
		self.last_send = time()
		logging.info("Starting listener thread")
		Thread(target=self.process_listener).start()
		logging.info("Starting sender thread")
		Thread(target=self.process_sender).start()

	def init_listener(self):
		"""Return a raw ICMP socket with ``listener_timeout`` as its recv timeout."""
		listener = socket.socket(
			family=socket.AF_INET,
			type=socket.SOCK_RAW,
			proto=socket.IPPROTO_ICMP
		)
		listener.settimeout(self.listener_timeout)
		listener.bind(('', 0))
		return listener

	@staticmethod
	def _parse_ports(data):
		"""Return ``(srcport, dstport)`` of the UDP datagram quoted in an ICMP error.

		``data`` is a raw IPv4 packet: outer IP header, 8-byte ICMP header, then
		the quoted inner IP header followed by the start of the UDP header.
		Header lengths are taken from the IHL fields, so IP options are handled.
		Returns None when the packet is too short, is not a time-exceeded (11)
		or destination-unreachable (3) message, or does not quote a UDP datagram.
		"""
		if not data:
			return None
		outer_len = (data[0] & 0x0F) * 4  # IHL counts 32-bit words
		inner = outer_len + 8  # skip the ICMP header
		if outer_len < 20 or len(data) < inner + 20:
			return None
		if data[outer_len] not in (3, 11):
			return None
		inner_len = (data[inner] & 0x0F) * 4
		if inner_len < 20 or data[inner + 9] != socket.IPPROTO_UDP:
			return None
		udp = inner + inner_len
		if len(data) < udp + 4:
			return None
		srcport = int.from_bytes(data[udp:udp + 2], byteorder="big")
		dstport = int.from_bytes(data[udp + 2:udp + 4], byteorder="big")
		return srcport, dstport

	def process_listener(self):
		"""Record ICMP replies until listener_timeout seconds pass without a send."""
		while (self.last_send + self.listener_timeout) > time():
			try:
				data, addr = self.listener.recvfrom(1024)
			except socket.timeout:
				# Nothing arrived in time; re-check the deadline.
				continue
			except OSError as e:
				logging.error("Listener receive failed: {}".format(e))
				continue
			ports = self._parse_ports(data)
			if ports is None:
				# A raw ICMP socket sees all ICMP traffic (pings etc.), not only ours.
				continue
			srcport, dstport = ports
			probes = self.sources.get(srcport)
			if probes is not None and dstport in probes:
				hop = probes[dstport]
				ip = probes["ip"]
				self.routes.setdefault(ip, dict())[hop] = addr
			else:
				# Likely another process's UDP traffic; not an error for us.
				logging.debug("Received package with srcport or dstport non existsant?")

	def process_sender(self):
		"""Trace every target, running at most ``threads`` traces concurrently."""
		self.pool.map(self.send, self.targets)

	def init_sender(self):
		"""Return a fresh, unbound UDP socket with ``socket_timeout`` applied."""
		sender = socket.socket(
			family=socket.AF_INET,
			type=socket.SOCK_DGRAM,
			proto=socket.IPPROTO_UDP
		)
		sender.settimeout(self.socket_timeout)
		return sender

	def _acquire_slot(self):
		"""Lock the first free sender slot and return its index, polling every ``wait`` s."""
		while True:
			for thread in range(self.threads):
				# Non-blocking, so a busy slot does not stall the scan.
				if self.sender[thread]["lock"].acquire(False):
					return thread
			sleep(self.wait)

	def _bind_source_port(self, sock, dstip, attempts=1000):
		"""Bind ``sock`` to an unused random port >= 1024 and register it for ``dstip``.

		:raises OSError: if no port could be bound within ``attempts`` tries.
		"""
		for _ in range(attempts):
			srcport = getrandbits(16)
			if srcport < 1024:
				# Port 0 would let the kernel pick a different port; also avoid
				# the well-known range.
				continue
			with self._sources_lock:
				if srcport in self.sources:
					continue
				try:
					sock.bind(('', srcport))
				except OSError:
					continue  # port already in use on this host
				self.sources[srcport] = {"ip": dstip}
				return srcport
		raise OSError("no free source port found for {}".format(dstip))

	def _new_dstport(self, srcport, ttl):
		"""Reserve a non-zero dstport not yet used under ``srcport`` and map it to ``ttl``."""
		probes = self.sources[srcport]
		dstport = getrandbits(16)
		while dstport == 0 or dstport in probes:
			dstport = getrandbits(16)
		probes[dstport] = ttl
		return dstport

	def send(self, dstip):
		"""Send one UDP probe per TTL in 1..maxhops to ``dstip``.

		A failure for this target is logged instead of raised, so the other
		targets in the pool keep going. The slot's socket is single-use
		(a socket can only be bound once), so it is replaced afterwards.
		"""
		thread = self._acquire_slot()
		sock = self.sender[thread]["socket"]
		try:
			self.routes.setdefault(dstip, dict())
			srcport = self._bind_source_port(sock, dstip)
			for ttl in range(1, self.maxhops + 1):
				sock.setsockopt(socket.IPPROTO_IP, socket.IP_TTL, ttl)
				dstport = self._new_dstport(srcport, ttl)
				logging.debug("Sending packet srcport = {}, dstport = {}, ttl = {}, dstip = {}".format(srcport, dstport, ttl, dstip))
				sock.sendto(b'', (dstip, dstport))
				self.last_send = time()
		except OSError as e:
			logging.error("Traceroute to {} failed: {}".format(dstip, e))
		finally:
			try:
				sock.close()
				self.sender[thread]["socket"] = self.init_sender()
			finally:
				self.sender[thread]["lock"].release()