blob: a347f5510850e3ffbb7575020d5d64be33c2e945 [file] [log] [blame]
# Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""HTTP Server classes for Recall server.
This module should not be imported directly, instead the public classes
are imported directly into the top-level recall package.
"""
__all__ = ["HTTPServer", "HTTPSServer", "HTTPRequestHandler"]
import BaseHTTPServer
import fnmatch
import httplib
import logging
import SocketServer
import ssl
import tempfile
import threading
import socket_util
from certificate_authority import CertificateAuthority
from dns_client import DNSRequest, DNSClient
from http_client import HTTPRequest, HTTPClient
def _GetHostnameForAddress(dns_client, address):
  """Reverse-resolve an address to a hostname using a DNS client.

  Args:
    dns_client: callable that is passed the request built by
      DNSRequest.ReverseLookup(address) and returns an iterable of answers,
      each exposing a `text` attribute.
    address: the address to look up.

  Returns:
    The text of the first answer with all trailing periods stripped, or None
    when the lookup yields no answers at all.
  """
  answers = iter(dns_client(DNSRequest.ReverseLookup(address)))
  try:
    first_answer = next(answers)
  except StopIteration:
    return None
  return first_answer.text.rstrip('.')
class HTTPServer(SocketServer.ThreadingMixIn,
                 BaseHTTPServer.HTTPServer,
                 threading.Thread):
  """Simple multithreaded HTTP Server.

  This class implements a multithreaded HTTP Server that uses the HTTP Client
  passed to the constructor to resolve requests. It also accepts a DNS Client
  and Certificate Authority for consistency with HTTPSServer. It does not use
  them itself, but HTTPRequestHandler reads them from the server: the DNS
  Client names the original destination of requests that lack a Host header,
  and the Certificate Authority backs the GetRootCertificate command.

  The constructor binds server_address and starts serving on a daemon thread.
  The shutdown() method must be called to stop serving and release the socket.
  """

  logger = logging.getLogger("HTTPServer")
  # Read by HTTPRequestHandler as self.server.ssl to mark requests as plain
  # HTTP. (This shadows the ssl module only inside the class body.)
  ssl = False
  # Listen backlog. TCPServer.server_activate() passes this to listen() while
  # the base-class constructor runs, so it must be in place before then.
  # Assigning it on the instance after construction had no effect.
  request_queue_size = 128

  def __init__(self, server_address,
               http_client=None,
               dns_client=None,
               certificate_authority=None):
    """Bind to server_address and start serving on a daemon thread.

    Args:
      server_address: address to bind the listening socket to.
      http_client: callable that turns an HTTPRequest into a response; a new
        HTTPClient() is created when omitted.
      dns_client: DNS client used by the request handler for reverse
        lookups; a new DNSClient() is created when omitted.
      certificate_authority: optional; only read by the GetRootCertificate
        handler command.
    """
    # Build defaults per instance instead of once at import time, so separate
    # servers never silently share one client object.
    if http_client is None:
      http_client = HTTPClient()
    if dns_client is None:
      dns_client = DNSClient()
    BaseHTTPServer.HTTPServer.__init__(self, server_address, HTTPRequestHandler)
    threading.Thread.__init__(self, target=self.serve_forever)
    self.http_client = http_client
    self.dns_client = dns_client
    self.certificate_authority = certificate_authority
    self.logger.info("Starting on %s", self.server_address)
    self.daemon = True
    self.start()

  def shutdown(self):
    """Stop serve_forever() and close the listening socket."""
    self.logger.info("Shutting down")
    # Blocks until the serving loop started in __init__ has exited.
    super(HTTPServer, self).shutdown()
    self.server_close()
class HTTPSServer(SocketServer.ThreadingMixIn,
                  BaseHTTPServer.HTTPServer,
                  threading.Thread):
  """Multithreaded HTTPS Server.

  This class implements a multithreaded HTTPS Server that uses the HTTP Client
  passed to the constructor to resolve requests. The original destination
  address of each incoming connection is resolved to a hostname using the
  passed DNS Client, and a certificate for that name is generated using the
  passed Certificate Authority.

  For best results, the DNS Client should be the SymmetricDNSClient class.

  The constructor binds server_address and starts serving on a daemon thread.
  The shutdown() method must be called to stop serving and release the socket.
  """

  logger = logging.getLogger("HTTPSServer")
  # Read by HTTPRequestHandler as self.server.ssl to mark requests as HTTPS.
  # This name shadows the ssl module only inside the class body; methods
  # still resolve `ssl` to the module, which get_request() relies on.
  ssl = True
  # Listen backlog. TCPServer.server_activate() passes this to listen() while
  # the base-class constructor runs, so it must be in place before then.
  # Assigning it on the instance after construction had no effect.
  request_queue_size = 128

  def __init__(self, server_address,
               http_client=None,
               dns_client=None,
               certificate_authority=None):
    """Bind to server_address and start serving on a daemon thread.

    Args:
      server_address: address to bind the listening socket to.
      http_client: callable that turns an HTTPRequest into a response; a new
        HTTPClient() is created when omitted.
      dns_client: DNS client used to reverse-resolve original destinations;
        a new DNSClient() is created when omitted.
      certificate_authority: object providing GetCertificateAndPrivateKey().
        Required in practice, since get_request() calls it for every
        connection.
    """
    # Build defaults per instance instead of once at import time, so separate
    # servers never silently share one client object.
    if http_client is None:
      http_client = HTTPClient()
    if dns_client is None:
      dns_client = DNSClient()
    BaseHTTPServer.HTTPServer.__init__(self, server_address, HTTPRequestHandler)
    threading.Thread.__init__(self, target=self.serve_forever)
    self.http_client = http_client
    self.dns_client = dns_client
    self.certificate_authority = certificate_authority
    self.logger.info("Starting on %s", self.server_address)
    self.daemon = True
    self.start()

  def shutdown(self):
    """Stop serve_forever() and close the listening socket."""
    self.logger.info("Shutting down")
    # Blocks until the serving loop started in __init__ has exited.
    super(HTTPSServer, self).shutdown()
    self.server_close()

  def _GetCertificateHostname(self, conn):
    """Choose the hostname to issue the certificate for on this connection.

    Uses the reverse-resolved name of the connection's original destination,
    or the bare destination address when the lookup finds no name. If either
    lookup raises TypeError or KeyError, falls back to this server's own name.
    """
    try:
      original_address, original_port = \
          socket_util.GetOriginalDestinationAddress(conn)
      certificate_hostname = _GetHostnameForAddress(self.dns_client,
                                                    original_address)
      if certificate_hostname is None:
        certificate_hostname = original_address
      self.logger.debug("Original destination %s:%d; using certificate for %s",
                        original_address, original_port, certificate_hostname)
    except (TypeError, KeyError):
      certificate_hostname = self.server_name
      self.logger.warning("Using our own certificate for this request")
    return certificate_hostname

  def get_request(self):
    """Accept an incoming connection and wrap it in SSL.

    Looks up the original destination of the connection and resolves it to a
    hostname using the DNS Client passed to the constructor. A certificate
    and private key for that hostname are then obtained from the Certificate
    Authority, and the connection is individually wrapped through SSL.

    Returns:
      (ssl_socket, client_address)

    Raises:
      Whatever certificate lookup or ssl.wrap_socket raises. The accepted
      socket is closed first so it does not leak.
    """
    conn, address = self.socket.accept()
    self.logger.debug("Accepted request from %s:%d", address[0], address[1])
    try:
      certificate_hostname = self._GetCertificateHostname(conn)
      certificate_file, private_key_file = \
          self.certificate_authority.GetCertificateAndPrivateKey(
              certificate_hostname)
      # NOTE(review): with the default do_handshake_on_connect the handshake
      # runs here, on the serve_forever() thread, before ThreadingMixIn starts
      # a per-request thread. A client stalling mid-handshake therefore
      # delays accepting other connections.
      secure_conn = ssl.wrap_socket(conn, server_side=True,
                                    certfile=certificate_file,
                                    keyfile=private_key_file)
    except Exception:
      conn.close()
      raise
    return (secure_conn, address)
class HTTPRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  """Request handler for HTTP and HTTPS Servers.

  Handles incoming HTTP requests on behalf of HTTPServer and HTTPSServer
  (distinguished by their ssl members). The request is converted to an
  HTTPRequest object and a response obtained from the server's http_client
  member before being sent back to the client.

  Additionally, if the incoming request is directed at the server itself, the
  first element of the path may name a command listed in _SELF_COMMANDS. In
  that case that method is run to generate the response, and any further path
  elements are passed to it as string arguments.
  """

  protocol_version = 'HTTP/1.1'
  # Turn on buffering, we explicitly flush where we need to
  wbufsize = -1

  # Methods that a request addressed to this server may invoke by name.
  # Dispatch is limited to this set so a client cannot reach arbitrary handler
  # methods (e.g. /finish, /send_error, or /do_GET, which would recurse)
  # through the URL. Subclasses that add commands should extend it.
  _SELF_COMMANDS = frozenset(['GetRootCertificate'])

  class _AddressLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that prefixes every message with the client address.

    Used instead of logging.getLogger("HTTPRequestHandler:<ip>:<port>"). The
    logging module keeps every named logger for the life of the process, so
    naming one after each client's ephemeral port grows memory without bound.
    """

    def process(self, msg, kwargs):
      return '%s %s' % (self.extra['client'], msg), kwargs

    # Python 2.7's LoggerAdapter has no warn() alias; keep the name usable
    # for callers that relied on it with the former per-connection Logger.
    def warn(self, msg, *args, **kwargs):
      self.warning(msg, *args, **kwargs)

  def __init__(self, request, client_address, server):
    # Set up logging before calling the base constructor, which handles the
    # whole connection before it returns.
    self.logger = self._AddressLoggerAdapter(
        logging.getLogger("HTTPRequestHandler"),
        {'client': '%s:%d' % (client_address[0], client_address[1])})
    BaseHTTPServer.BaseHTTPRequestHandler.__init__(
        self, request, client_address, server)

  def log_request(self, code='-', size='-'):
    # we do our own request logging
    pass

  # reformat other log messages to our own logger
  def log_error(self, format, *args):
    self.logger.error(format, *args)

  def log_message(self, format, *args):
    self.logger.info(format, *args)

  # handle all methods the same way
  def do_HEAD(self):
    self._HandleRequest()

  def do_GET(self):
    self._HandleRequest()

  def do_POST(self):
    self._HandleRequest()

  def _RequestIsForSelf(self):
    """Check whether the request is for our server name or IP Address.

    The host part of self.host (the text before the first ':') is compared
    case-insensitively against the server name, the first label of the
    server name, and the local address of this connection.

    Returns:
      True if request is for us, False otherwise.
    """
    server_name = self.server.server_name
    server_aliases = [server_name, self.request.getsockname()[0]]
    if '.' in server_name:
      server_aliases.append(server_name.split('.', 1)[0])
    requested_host = self.host.split(':')[0].lower()
    return requested_host in [alias.lower() for alias in server_aliases]

  def _HandleRequest(self):
    """Handle the request.

    Reads the request body and determines the target host (from the Host
    header, falling back to the connection's original destination). It then
    either runs a local command, when the request is addressed to this
    server, or relays the request through the server's http_client.
    """
    # Consume the body first, so that every reply below, including errors,
    # leaves a kept-alive connection positioned at the next request.
    try:
      body = self._ReadRequestBody()
    except ValueError:
      # Without a usable length the next request on this connection cannot
      # be located, so close it after replying.
      self.close_connection = 1
      return self._Error("Invalid Content-Length header in request",
                         httplib.BAD_REQUEST)

    # Lookup the hostname
    self.host = self.headers.get('host', None)
    if not self.host:
      self.host = self._GetOriginalDestinationHost()
      if self.host is None:
        return self._Error("Missing Host header in request, "
                           "and can't obtain original destination")
      self.logger.debug("Missing Host header in request, used %s", self.host)

    # Handle requests for our own host
    if self._RequestIsForSelf():
      return self._RunSelfCommand()

    request = HTTPRequest(self.host, self.command, self.path,
                          self.headers.items(), body,
                          self.server.ssl)
    try:
      response = self.server.http_client(request)
    except KeyError:
      return self._Error("Not found in archive", httplib.NOT_FOUND)
    self._SendResponse(response)

  def _ReadRequestBody(self):
    """Read the request body announced by the Content-Length header.

    Returns:
      The body, or None when Content-Length is absent or zero.

    Raises:
      ValueError: Content-Length is not a non-negative integer.
    """
    content_length = int(self.headers.get('Content-Length', 0))
    if content_length < 0:
      raise ValueError("negative Content-Length: %d" % content_length)
    if not content_length:
      return None
    return self.rfile.read(content_length)

  def _GetOriginalDestinationHost(self):
    """Describe the connection's original destination as 'host:port'.

    The host part is the reverse-resolved hostname when the server's
    dns_client finds one, otherwise the original address itself.

    Returns:
      The 'host:port' string, or None if the lookup raised TypeError.
      (Presumably this happens when socket_util reports no original
      destination; confirm against socket_util.)
    """
    try:
      original_address, original_port = \
          socket_util.GetOriginalDestinationAddress(self.request)
      hostname = _GetHostnameForAddress(self.server.dns_client,
                                        original_address)
      return '%s:%d' % (hostname or original_address, original_port)
    except TypeError:
      return None

  def _RunSelfCommand(self):
    """Run the local command named by the first path element.

    Further path elements are passed as positional string arguments. The
    following produce an error reply: an unlisted command, an argument
    mismatch, or an AttributeError/TypeError raised by the command itself.
    """
    # Ignore any query string, so "/Command?x=1" still names "Command".
    command = self.path.split('?', 1)[0][1:].split('/')
    name, args = command[0], command[1:]
    if name not in self._SELF_COMMANDS:
      return self._Error("Unknown command %s" % name)
    try:
      return getattr(self, name)(*args)
    except (AttributeError, TypeError) as e:
      return self._Error(str(e))

  def _SendResponse(self, response):
    """Relay a response obtained from the server's http_client to the client.

    Reads the response's version, status, reason, headers, chunked and chunks
    attributes. Upstream headers are forwarded unchanged.
    """
    if response.version == 10:
      self.protocol_version = 'HTTP/1.0'
    self.send_response(response.status, response.reason)
    sent_content_length = False
    for header, value in response.headers:
      self.send_header(header, value)
      if header.title() == 'Content-Length':
        sent_content_length = True

    if response.chunked or sent_content_length:
      # The upstream headers already describe the framing, so stream the
      # body through as it arrives.
      self.end_headers()
      self.wfile.flush()
      # NOTE(review): no terminating zero-size chunk is written here.
      # Presumably response.chunks ends with an empty chunk, which the format
      # below renders as '0\r\n\r\n'; confirm against HTTPClient.
      for chunk in response.chunks:
        if response.chunked:
          self.wfile.write('%x\r\n%s\r\n' % (len(chunk), chunk))
        else:
          self.wfile.write(chunk)
        self.wfile.flush()
    else:
      # No length from upstream, so we must send Content-Length ourselves.
      # Buffer the whole body so the length covers every chunk, not just the
      # first one.
      self.logger.debug("Will send Content-Length later")
      body = ''.join(response.chunks)
      if not body:
        self.logger.debug("Handled empty request")
      self.send_header('Content-Length', str(len(body)))
      self.end_headers()
      self.wfile.write(body)
      self.wfile.flush()

    if response.version == 10:
      self.close_connection = 1

  def GetRootCertificate(self):
    """Generate a response with the CA's certificate.

    Command intended for use by clients. It writes back the contents of the
    server's certificate_authority.certificate_file as text/plain.
    """
    with open(self.server.certificate_authority.certificate_file) as cert:
      certificate = cert.read()
    self._SendText(httplib.OK, certificate)

  def _Error(self, message, code=httplib.INTERNAL_SERVER_ERROR):
    """Reply with an error.

    Logs message as a warning and sends it to the client as a text/plain
    response with the given status code (500 by default).
    """
    self.logger.warning(message)
    self._SendText(code, message)

  def _SendText(self, code, text):
    """Send a complete text/plain response with the given status code."""
    self.send_response(code)
    self.send_header('Content-Type', 'text/plain')
    self.send_header('Content-Length', str(len(text)))
    self.end_headers()
    # A response to HEAD carries headers only. Writing the body would be
    # misread as the start of the next response on a kept-alive connection.
    if self.command != 'HEAD':
      self.wfile.write(text)
    self.wfile.flush()