Made DNS lookups threaded/async and switched to a real message queue for the pipes
diff --git a/tsproxy.py b/tsproxy.py index c8042f3..f905912 100644 --- a/tsproxy.py +++ b/tsproxy.py
@@ -18,10 +18,13 @@ import signal import socket import sys +import Queue +import threading server = None in_pipe = None out_pipe = None +must_exit = False connections = {} ######################################################################################################################## @@ -33,28 +36,56 @@ def __init__(self, direction): self.direction = direction + self.queue = Queue.Queue() def SendMessage(self, message): - global connections - connection_id = message['connection'] - if connection_id in connections: - peer = 'server' - if self.direction == self.PIPE_IN: - peer = 'client' - if peer in connections[connection_id]: - try: - connections[connection_id][peer].handle_message(message) - except: - try: - connections[connection_id]['server'].close() - except: - pass - try: - connections[connection_id]['client'].close() - except: - pass - del connections[connection_id] + self.queue.put(message); + def tick(self): + global connections + while not self.queue.empty(): + message = self.queue.get() + connection_id = message['connection'] + if connection_id in connections: + peer = 'server' + if self.direction == self.PIPE_IN: + peer = 'client' + if peer in connections[connection_id]: + try: + connections[connection_id][peer].handle_message(message) + except: + # Clean up any disconnected connections + try: + connections[connection_id]['server'].close() + except: + pass + try: + connections[connection_id]['client'].close() + except: + pass + del connections[connection_id] + + +######################################################################################################################## +# Threaded DNS resolver +######################################################################################################################## +class AsyncDNS(threading.Thread): + def __init__(self, client_id, hostname, port, result_pipe): + threading.Thread.__init__(self) + self.hostname = hostname + self.port = port + self.client_id = client_id + self.result_pipe = result_pipe + + def run(self): + try: + addresses = socket.getaddrinfo(self.hostname, self.port) + print '[{0:d}] Resolving {1}:{2:d} Completed'.format(self.client_id, self.hostname, self.port) + except: + addresses = () + print '[{0:d}] Resolving {1}:{2:d} Failed'.format(self.client_id, self.hostname, self.port) + message = {'message': 'resolved', 'connection': self.client_id, 'addresses': addresses} + self.result_pipe.SendMessage(message) ######################################################################################################################## # TCP Client @@ -63,9 +94,8 @@ STATE_ERROR = -1 STATE_IDLE = 0 STATE_RESOLVING = 1 - STATE_RESOLVED = 2 - STATE_CONNECTING = 3 - STATE_CONNECTED = 4 + STATE_CONNECTING = 2 + STATE_CONNECTED = 3 def __init__(self, client_id): asyncore.dispatcher.__init__(self) @@ -73,6 +103,9 @@ self.state = self.STATE_IDLE self.buffer = ''; self.addr = None + self.dns_thread = None + self.hostname = None + self.port = None def SendMessage(self, type, message): message['message'] = type @@ -130,21 +163,16 @@ self.SendMessage('data', {'data': data}) def HandleResolve(self, message): - #TODO (pmeenan): Run the actual lookup in a thread asynchronously + global in_pipe if 'hostname' in message: - hostname = message['hostname'] - port = 0 + self.hostname = message['hostname'] + self.port = 0 if 'port' in message: - port = message['port'] - print '[{0:d}] Resolving {1}:{2:d}'.format(self.client_id, hostname, port) + self.port = message['port'] + print '[{0:d}] Resolving {1}:{2:d}'.format(self.client_id, self.hostname, self.port) self.state = self.STATE_RESOLVING - try: - addresses = socket.getaddrinfo(hostname, port) - except: - addresses = () - print '[{0:d}] Resolving {1}:{2:d} FAILED'.format(self.client_id, hostname, port) - self.state = self.STATE_RESOLVED - self.SendMessage('resolved', {'addresses': addresses}) + self.dns_thread = AsyncDNS(self.client_id, self.hostname, self.port, in_pipe); + self.dns_thread.start() def HandleConnect(self, message): if 'addresses' in message and len(message['addresses']): @@ -328,13 +356,29 @@ print 'Starting Socks5 proxy server on {0}:{1:d}'.format(options.interface, options.port) signal.signal(signal.SIGINT, signal_handler) server = Socks5Server(options.interface, options.port) - asyncore.loop(timeout = 0.001, use_poll = True) + run_loop(timeout = 0.001) def signal_handler(signal, frame): global server print('Exiting...') + must_exit = True del server sys.exit(0) +# Replacement for the asyncore loop that allows us to schedule our own timers +def run_loop(timeout=0.001, use_poll=False): + global must_exit + global in_pipe + global out_pipe + if use_poll and hasattr(asyncore.select, 'poll'): + asyncore_poll = asyncore.poll2 + else: + asyncore_poll = asyncore.poll + map = asyncore.socket_map + while not must_exit: + asyncore_poll(timeout, map) + in_pipe.tick() + out_pipe.tick() + if '__main__' == __name__: main()