#!/usr/bin/env python
"""A small man-in-the-middle HTTP/HTTPS intercepting proxy.

How traffic is handled:

* Plain HTTP requests are forwarded to the origin server.
* HTTPS is handled by answering ``CONNECT`` requests. The proxy terminates
  the client's TLS session with a per-host certificate signed by a local
  certificate authority, then opens its own TLS connection to the origin.
* Request and response bytes pass through the registered interceptor plugins
  before being relayed.
"""

from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from SocketServer import ThreadingMixIn
from httplib import HTTPResponse
from urlparse import urlparse
from ssl import wrap_socket
from os import path, mkdir
from socket import socket
from re import compile
from threading import RLock

from OpenSSL.crypto import (X509Extension, X509, dump_privatekey, dump_certificate, load_certificate,
                            load_privatekey, PKey, TYPE_RSA, X509Req)
from OpenSSL.SSL import FILETYPE_PEM

__author__ = 'Nadeem Douba'
__copyright__ = 'Copyright 2012, Cygnos Corporation'
__credits__ = ['Nadeem Douba']

__license__ = 'GPL'
__version__ = '0.1'
__maintainer__ = 'Nadeem Douba'
__email__ = 'ndouba@cygnos.com'
__status__ = 'Development'

__all__ = [
    'CertificateAuthority',
    'ProxyHandler',
    'RequestInterceptorPlugin',
    'ResponseInterceptorPlugin',
    'MitmProxy',
    'AsyncMitmProxy',
    'InvalidInterceptorPluginException'
]


class CertificateAuthority(object):
    """Self-signed CA that mints and caches per-host leaf certificates.

    The CA private key and certificate are stored together (key first, then
    certificate) in a single PEM file. Leaf key/certificate pairs are written
    to ``.ssl/.<cn>.pem`` in the current working directory and reused on
    later lookups.
    """

    def __init__(self, ca_file='ca.pem'):
        """Load the CA from *ca_file*, generating and saving a new one if it does not exist."""
        # Honour the caller's path (previously it was silently replaced by 'ca.pem').
        self.ca_file = ca_file
        self._serial = 1
        # Guards serial allocation and leaf-certificate creation: AsyncMitmProxy
        # runs handlers on separate threads. Re-entrant because __getitem__
        # holds it while reading self.serial.
        self._lock = RLock()
        if not path.exists('.ssl'):
            mkdir('.ssl')
        if not path.exists(ca_file):
            self._generate_ca()
        else:
            self._read_ca(ca_file)

    def _generate_ca(self):
        """Create a 2048-bit RSA CA key and self-signed certificate and save them to ``self.ca_file``."""
        # Generate key
        self.key = PKey()
        self.key.generate_key(TYPE_RSA, 2048)

        # Generate certificate
        self.cert = X509()
        # X.509 versions are zero-based: 2 means v3, which the extensions below
        # require. The previous value 3 is not a valid version.
        self.cert.set_version(2)
        self.cert.set_serial_number(1)
        self.cert.get_subject().CN = 'ca.mitm.com'
        self.cert.gmtime_adj_notBefore(0)
        self.cert.gmtime_adj_notAfter(315360000)  # 3650 days
        self.cert.set_issuer(self.cert.get_subject())
        self.cert.set_pubkey(self.key)
        self.cert.add_extensions([
            X509Extension("basicConstraints", True, "CA:TRUE, pathlen:0"),
            X509Extension("keyUsage", True, "keyCertSign, cRLSign"),
            X509Extension("subjectKeyIdentifier", False, "hash", subject=self.cert),
        ])
        # SHA-1 signatures are no longer accepted by current TLS clients.
        self.cert.sign(self.key, "sha256")

        with open(self.ca_file, 'wb+') as f:
            f.write(dump_privatekey(FILETYPE_PEM, self.key))
            f.write(dump_certificate(FILETYPE_PEM, self.cert))

    def _read_ca(self, file):
        """Load the CA private key and certificate from the combined PEM *file*."""
        # Read once and close the handle; both objects are parsed from the same text.
        with open(file, 'rb') as f:
            pem = f.read()
        self.cert = load_certificate(FILETYPE_PEM, pem)
        self.key = load_privatekey(FILETYPE_PEM, pem)

    def __getitem__(self, cn):
        """Return the path of a PEM file holding a key and certificate for *cn*.

        The file is created and signed by this CA on first use and reused
        afterwards.

        :raises ValueError: if *cn* contains a path separator. The proxy passes
            the client-supplied CONNECT host here, and it must not be able to
            write outside the ``.ssl`` directory.
        """
        if path.sep in cn or (path.altsep and path.altsep in cn):
            raise ValueError('Invalid common name %r' % (cn,))
        cnp = path.sep.join(['.ssl', '.%s.pem' % cn])
        # Hold the lock across check-and-create so two threads never write the
        # same file and no thread hands out a half-written one.
        with self._lock:
            if not path.exists(cnp):
                self._issue_certificate(cn, cnp)
        return cnp

    def _issue_certificate(self, cn, cnp):
        """Generate a key for *cn*, sign a one-year certificate for it and save both to *cnp*."""
        # create certificate
        key = PKey()
        key.generate_key(TYPE_RSA, 2048)

        # Generate CSR
        req = X509Req()
        req.get_subject().CN = cn
        req.set_pubkey(key)
        req.sign(key, 'sha256')

        # Sign CSR
        cert = X509()
        cert.set_subject(req.get_subject())
        cert.set_serial_number(self.serial)
        cert.gmtime_adj_notBefore(0)
        cert.gmtime_adj_notAfter(31536000)  # 365 days
        cert.set_issuer(self.cert.get_subject())
        cert.set_pubkey(req.get_pubkey())
        cert.sign(self.key, 'sha256')

        with open(cnp, 'wb+') as f:
            f.write(dump_privatekey(FILETYPE_PEM, key))
            f.write(dump_certificate(FILETYPE_PEM, cert))

    @property
    def serial(self):
        """Allocate and return the next certificate serial number (the first call returns 2)."""
        # NOTE(review): the counter restarts on every run, so serials can repeat
        # across runs for the same issuer; some clients may reject that.
        with self._lock:
            self._serial += 1
            return self._serial


class UnsupportedSchemeException(Exception):
    """Raised when a non-CONNECT request uses a URI scheme other than ``http``."""
    pass


class ProxyHandler(BaseHTTPRequestHandler):
    """Request handler that relays client requests to the origin server.

    How requests are dispatched:

    * ``CONNECT`` is answered by intercepting the tunnel's TLS. The single
      request that then arrives inside the tunnel is relayed.
    * Every other ``do_<METHOD>`` lookup resolves to :meth:`do_COMMAND`.
    """

    # Extracts the path (plus query) from an absolute-form ``http://`` request URI.
    r = compile(r'http://[^/]+(/?.*)')

    # True once do_CONNECT has started handling a TLS tunnel on this connection.
    is_connect = False
    # Socket to the origin server; None until _connect_to_host succeeds.
    _proxy_sock = None

    def _connect_target(self):
        """Split the CONNECT target ``host:port`` in ``self.path`` into ``(host, port)``."""
        # Split on the last colon only, and drop IPv6 literal brackets.
        host, port = self.path.rsplit(':', 1)
        return host.strip('[]'), int(port)

    def _connect_to_host(self):
        """Open ``self._proxy_sock`` to the origin server for the current request.

        :raises UnsupportedSchemeException: for a non-CONNECT request whose
            URI scheme is not ``http``.
        """
        # Get hostname and port to connect to
        if self.is_connect:
            host, port = self._connect_target()
        else:
            u = urlparse(self.path)
            if u.scheme != 'http':
                raise UnsupportedSchemeException('Unknown scheme %s' % repr(u.scheme))
            host = u.hostname
            port = u.port or 80

        # Connect to destination
        self._proxy_sock = socket()
        self._proxy_sock.settimeout(10)
        self._proxy_sock.connect((host, int(port)))

        # A CONNECT tunnel carries TLS, so speak TLS to the origin as well.
        # NOTE(review): ssl.wrap_socket does not verify the origin's certificate
        # by default and sends no SNI; some servers require SNI.
        if self.is_connect:
            self._proxy_sock = wrap_socket(self._proxy_sock)

    def _close_proxy_sock(self):
        """Close the origin-server socket, if one is open."""
        if self._proxy_sock is not None:
            self._proxy_sock.close()
            self._proxy_sock = None

    def _transition_to_ssl(self):
        """Re-wrap the client socket as a TLS server using a CA-issued certificate for the CONNECT host."""
        host = self._connect_target()[0]
        self.request = wrap_socket(self.request, server_side=True, certfile=self.server.ca[host])

    def do_CONNECT(self):
        """Answer CONNECT, intercept the tunnel's TLS, then handle the request sent through it."""
        self.is_connect = True
        try:
            # Connect to destination first
            self._connect_to_host()

            # If successful, let's do this!
            self.send_response(200, 'Connection established')
            self.end_headers()
            self._transition_to_ssl()
        except Exception as e:
            self._close_proxy_sock()
            self.send_error(500, str(e))
            return

        # Rebuild rfile/wfile on top of the now-TLS socket, then serve one request.
        self.setup()
        self.ssl_host = 'https://%s' % self.path
        try:
            self.handle_one_request()
        finally:
            # handle_one_request may return before do_COMMAND runs (e.g. client hang-up).
            self._close_proxy_sock()

    def do_COMMAND(self):
        """Forward the current request to the origin server and relay its response.

        The raw request passes through :meth:`mitm_request` before it is sent.
        The raw response passes through :meth:`mitm_response` before it is
        written back to the client.
        """
        target = self.path or '/'
        # Inside a CONNECT tunnel the origin connection already exists and the
        # request line already carries the path the origin expects.
        if not self.is_connect:
            try:
                # Connect to destination
                self._connect_to_host()
            except Exception as e:
                self._close_proxy_sock()
                self.send_error(500, str(e))
                return
            # Reduce the absolute-form URI to its path. If the pattern does not
            # match (e.g. upper-case scheme), keep the URI as sent.
            m = self.r.search(target)
            target = (m.groups()[0] if m else target) or '/'

        try:
            # Build request
            req = '%s %s %s\r\n' % (self.command, target, self.request_version)

            # Add headers to the request
            req += '%s\r\n' % self.headers

            # Append message body if present to the request
            if 'Content-Length' in self.headers:
                req += self.rfile.read(int(self.headers['Content-Length']))

            # Send it down the pipe!
            self._proxy_sock.sendall(self.mitm_request(req))

            # Pass the method so a response to HEAD is not expected to carry a body.
            h = HTTPResponse(self._proxy_sock, method=self.command)
            h.begin()

            # h.read() de-chunks the body, so the header no longer applies.
            del h.msg['Transfer-Encoding']

            # Time to relay the message across
            res = '%s %s %s\r\n' % (self.request_version, h.status, h.reason)
            res += '%s\r\n' % h.msg
            res += h.read()
            h.close()
        finally:
            # Close the remote end even if sending or parsing failed.
            self._close_proxy_sock()

        # Relay the message
        self.request.sendall(self.mitm_response(res))

    def mitm_request(self, data):
        """Pass raw request *data* through each registered request plugin in order."""
        for p in self.server._req_plugins:
            data = p(self.server, self).do_request(data)
        return data

    def mitm_response(self, data):
        """Pass raw response *data* through each registered response plugin in order."""
        for p in self.server._res_plugins:
            data = p(self.server, self).do_response(data)
        return data

    def __getattr__(self, item):
        """Resolve every missing ``do_<METHOD>`` handler to :meth:`do_COMMAND`.

        Any other missing attribute raises AttributeError as usual, instead of
        silently evaluating to None.
        """
        if item.startswith('do_'):
            return self.do_COMMAND
        raise AttributeError(item)


class InterceptorPlugin(object):
    """Base class for plugins; instantiated per message with the server and the handler."""

    def __init__(self, server, msg):
        self.server = server
        self.message = msg


class RequestInterceptorPlugin(InterceptorPlugin):
    """Plugin that may rewrite raw request bytes; the default passes them through."""

    def do_request(self, data):
        return data


class ResponseInterceptorPlugin(InterceptorPlugin):
    """Plugin that may rewrite raw response bytes; the default passes them through."""

    def do_response(self, data):
        return data


class InvalidInterceptorPluginException(Exception):
    """Raised when something other than an InterceptorPlugin subclass is registered."""
    pass


class MitmProxy(HTTPServer):
    """HTTP server that proxies traffic through ProxyHandler and holds the CA and plugin lists."""

    def __init__(self, server_address=('', 8080), RequestHandlerClass=ProxyHandler, bind_and_activate=True,
                 ca_file='ca.pem'):
        HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate)
        self.ca = CertificateAuthority(ca_file)
        self._res_plugins = []
        self._req_plugins = []

    def register_interceptor(self, interceptor_class):
        """Register a plugin class to run on requests, responses, or both.

        The class is instantiated for every intercepted message.

        :raises InvalidInterceptorPluginException: if *interceptor_class* is
            not a subclass of InterceptorPlugin.
        """
        # issubclass() raises TypeError for non-classes, so check that first.
        if not (isinstance(interceptor_class, type) and issubclass(interceptor_class, InterceptorPlugin)):
            raise InvalidInterceptorPluginException(
                'Expected type InterceptorPlugin got %r instead' % (interceptor_class,))
        if issubclass(interceptor_class, RequestInterceptorPlugin):
            self._req_plugins.append(interceptor_class)
        if issubclass(interceptor_class, ResponseInterceptorPlugin):
            self._res_plugins.append(interceptor_class)


class AsyncMitmProxy(ThreadingMixIn, MitmProxy):
    """MitmProxy that handles each connection on its own thread."""
    pass


class MitmProxyHandler(ProxyHandler):
    """ProxyHandler that prints the first 100 bytes of each request and response."""

    def mitm_request(self, data):
        print('>> %s' % repr(data[:100]))
        return data

    def mitm_response(self, data):
        print('<< %s' % repr(data[:100]))
        return data


class DebugInterceptor(RequestInterceptorPlugin, ResponseInterceptorPlugin):
    """Plugin that prints the first 100 bytes of each request and response."""

    def do_request(self, data):
        print('>> %s' % repr(data[:100]))
        return data

    def do_response(self, data):
        print('<< %s' % repr(data[:100]))
        return data


if __name__ == '__main__':
    proxy = AsyncMitmProxy()
    proxy.register_interceptor(DebugInterceptor)
    try:
        proxy.serve_forever()
    except KeyboardInterrupt:
        proxy.server_close()