blob: adeca2cb165f0112524689b382a1dce110bdb2e7 [file] [log] [blame]
# Copyright 2018 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""The backend service of Google Storage Cache Server.

Run `./bin/gs_archive_server` to start the server. Once started, it listens on
either a TCP port or a Unix domain socket. The latter performs better when
working with a locally hosted reverse proxy server, e.g. Nginx.

The server accepts the requests below:
  - GET /download/<bucket>/path/to/file: download the file from google storage
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import contextlib
import functools
import os
import StringIO
import subprocess
import sys
import tempfile
import urllib
import urlparse
import cherrypy
import requests
import tarfile_utils
from chromite.lib import cros_logging as logging
from chromite.lib import gs
# HTTP status codes referenced by the request handlers in this module.
_HTTP_OK = 200
_HTTP_PARTIAL_CONTENT = 206
_HTTP_BAD_REQUEST = 400
_HTTP_UNAUTHORIZED = 401
_HTTP_NOT_FOUND = 404
_HTTP_INTERNAL_SERVER_ERROR = 500
_HTTP_SERVICE_UNAVAILABLE = 503

# Chunk sizes used when streaming data in and out of the server.
_READ_BUFFER_SIZE_BYTES = 1024 ** 2  # 1 MB
_WRITE_BUFFER_SIZE_BYTES = 1024 ** 2  # 1 MB

# The max size of temporary spool file in memory.
_SPOOL_FILE_SIZE_BYTES = 100 * 1024 ** 2  # 100 MB

# Module-wide logger, named after the path of this file.
_logger = logging.getLogger(__file__)
def _log(*args, **kwargs):
  """Log a message through the module logger with request headers attached.

  Args:
    *args: The message and its format arguments, as for logging.Logger.log.
    **kwargs: Extra keyword arguments for logging.Logger.log. The optional
      `level` keyword selects the log level and defaults to logging.DEBUG.
  """
  severity = kwargs.pop('level', logging.DEBUG)
  # The headers of the current request are passed as `extra` so the log
  # formatter can reference them by name.
  request_headers = cherrypy.request.headers
  _logger.log(severity, *args, extra=request_headers, **kwargs)
def _log_filtered_headers(all_headers, filtered_headers, level=logging.DEBUG):
  """Log the filtered headers only.

  Args:
    all_headers: A dict-like object of all headers, or None when there are no
      headers at all.
    filtered_headers: An iterable of the header names to be logged. A header
      missing from |all_headers| is logged with value None.
    level: The logging level to use.
  """
  # _CachingServer._call forwards its optional |headers| argument here, which
  # defaults to None; treat that as "no headers" instead of crashing on .get().
  all_headers = all_headers or {}
  selected = {k: all_headers.get(k) for k in filtered_headers}
  _log('Filtered headers: %s', selected, level=level)
def _check_file_extension(filename, ext_names=None):
  """Check the file name and, optionally, the ext name.

  Args:
    filename: The file name to be checked.
    ext_names: An iterable of valid extensions (e.g. ['.tar']); |filename| must
      end with one of them. When None or empty, the extension isn't checked.

  Returns:
    The filename if the check is good.

  Raises:
    ValueError: Raised if |filename| is empty, or if |ext_names| is given and
      |filename| ends with none of them.
  """
  if not filename:
    raise ValueError('File name is required.')

  # Only enforce the extension when the caller asked for it (download() passes
  # no |ext_names| and must accept any name), and accept a match against ANY
  # of the listed extensions, not just the first one.
  if ext_names and not filename.endswith(tuple(ext_names)):
    raise ValueError("Extension name of '%s' isn't in %s" % (filename,
                                                             ext_names))
  return filename
def _safe_get_param(all_params, param_name):
  """Fetch a required request parameter.

  Args:
    all_params: A dict of all parameters of the call.
    param_name: The name of the required parameter.

  Returns:
    The value stored under |param_name| in |all_params|.

  Raises:
    cherrypy.HTTPError: An HTTP 400 error when |param_name| isn't in
      |all_params|.
  """
  if param_name not in all_params:
    message = 'Parameter "%s" is required!' % param_name
    raise cherrypy.HTTPError(_HTTP_BAD_REQUEST, message)
  return all_params[param_name]
def _to_cherrypy_error(func):
  """Decorate |func| so the exceptions it raises become HTTP error responses.

  Args:
    func: The request handler to be wrapped.

  Returns:
    The wrapped handler. requests.HTTPError is turned into a response carrying
    the upstream status code and body; ValueError is turned into an HTTP 400
    cherrypy.HTTPError. Other exceptions propagate unchanged.
  """
  @functools.wraps(func)
  def wrapped(*args, **kwargs):
    try:
      return func(*args, **kwargs)
    except requests.HTTPError as http_error:
      # The upstream error body is already formatted, and cherrypy.HTTPError
      # would wrap it in HTML once more, so relay status and body as they are.
      upstream = http_error.response
      cherrypy.response.status = upstream.status_code
      return upstream.content
    except ValueError as value_error:
      # A ValueError only carries plain text; let cherrypy.HTTPError add the
      # necessary HTML tags around it.
      raise cherrypy.HTTPError(_HTTP_BAD_REQUEST, value_error.message)

  return wrapped
class _CachingServer(object):
  r"""The interface of caching server for GsArchiveServer.

  This class provides an interface to work with the caching server (usually a
  reversed http proxy server) which caches all intermediate results, e.g.
  downloaded files, etc. and serves to GsArchiveServer.

  The relationship of this class and other components is:
    /-------------(python function call)-----------------------\
    |                                                          |
    v                                                          |
  _CachingServer --(http/socket)--> NGINX --(http/socket)--> GsArchiveServer
                                      ^                           |
                                      |                        (https)
  End user, DUTs ---(http)------------/                           |
                                                                  V
                                                            GoogleStorage
  """

  def __init__(self, url):
    """Constructor.

    Args:
      url: A tuple of URL scheme and netloc, e.g. ('http', 'localhost:8080').
        It is stored as-is; no validation is done here.
    """
    self._url = url

  def _call(self, action, path, args=None, headers=None):
    """Helper function to generate all RPC calls to the proxy server.

    Args:
      action: The RPC name; it becomes the first component of the URL path.
      path: The rest of the URL path, appended after |action|.
      args: A dict of query parameters, or None for no query string.
      headers: A dict-like object of HTTP request headers to send, or None.

    Returns:
      The response object of the GET request, opened in streaming mode.

    Raises:
      requests.HTTPError: Raised by raise_for_status() when the caching server
        responds with an error status.
    """
    url = urlparse.urlunsplit(self._url + (
        '%s/%s' % (action, path), urllib.urlencode(args or {}), None))
    _log('Sending request to caching server: %s', url)
    # The header to control using or bypass cache. |headers| is optional, so
    # hand an empty dict to the logging helper rather than None (the helper
    # calls .get() on it).
    _log_filtered_headers(headers or {}, ('X-No-Cache',))
    rsp = requests.get(url, headers=headers, stream=True)
    _log('Caching server response %s', rsp.status_code)
    _log_filtered_headers(rsp.headers, ('Content-Type', 'Content-Length',
                                        'X-Cache', 'Cache-Control', 'Date'))
    rsp.raise_for_status()
    return rsp

  def download(self, path, headers=None):
    """Call download RPC.

    Args:
      path: The path of the file, appended after 'download/' in the URL.
      headers: A dict-like object of HTTP request headers to send, or None.

    Returns:
      The streaming response of the caching server.
    """
    return self._call('download', path, headers=headers)

  def list_member(self, path, headers=None):
    """Call list_member RPC.

    Args:
      path: The path of the archive, appended after 'list_member/' in the URL.
      headers: A dict-like object of HTTP request headers to send, or None.

    Returns:
      The streaming response of the caching server.
    """
    return self._call('list_member', path, headers=headers)
class GsArchiveServer(object):
  """The backend of Google Storage Cache server."""

  def __init__(self, caching_server):
    """Constructor.

    Args:
      caching_server: The object used to send `download` and `list_member`
        requests through the caching server, e.g. a _CachingServer.
    """
    self._gsutil = gs.GSContext()
    self._caching_server = caching_server

  @cherrypy.expose
  @_to_cherrypy_error
  def list_member(self, *args):
    """Get file list of a tar archive in CSV format.

    An example, GET /list_member/bucket/path/to/file.tar

    The output is in format of:
      <file name>,<data1>,<data2>,<data3>,<data4>
      <file name>,<data1>,<data2>,<data3>,<data4>
      ...

    Details:
      <file name>: The file name of the member, in URL percent encoding, e.g.
        path/to/file,name -> path/to/file%2Cname.
      <data1>: File record start offset, in bytes.
      <data2>: File record size, in bytes.
      <data3>: File content start offset, in bytes.
      <data4>: File content size, in bytes.

    This is an internal RPC and shouldn't be called by end user!

    Args:
      *args: All parts of tar file name (must end with '.tar').

    Returns:
      The generator of CSV stream.
    """
    # TODO(guocb): new parameter to filter the list
    archive = _check_file_extension('/'.join(args), ext_names=['.tar'])
    rsp = self._caching_server.download(archive, cherrypy.request.headers)
    cherrypy.response.headers['Content-Type'] = 'text/csv'

    # We run tar command to get member list of a tar file (python tarfile module
    # is too slow). Option '--block-number/-R' of tar prints out the starting
    # block number for each file record.
    _log('list member of the tar %s', archive)
    tar_tv = tempfile.SpooledTemporaryFile(max_size=_SPOOL_FILE_SIZE_BYTES)
    tar = subprocess.Popen(['tar', 'tv', '--block-number'],
                           stdin=subprocess.PIPE, stdout=tar_tv)
    try:
      for chunk in rsp.iter_content(_READ_BUFFER_SIZE_BYTES):
        tar.stdin.write(chunk)
    finally:
      # Close the pipe so tar sees EOF; otherwise wait() may block forever if
      # tar is still expecting input. Doing it in `finally` also reaps the
      # child process when feeding it fails.
      tar.stdin.close()
      tar.wait()

    def _tar_member_list():
      """Yield the CSV lines in chunks of about _WRITE_BUFFER_SIZE_BYTES."""
      with tar_tv, contextlib.closing(StringIO.StringIO()) as stream:
        tar_tv.seek(0)
        for info in tarfile_utils.list_tar_members(tar_tv):
          # Encode file name using URL percent encoding, so ',' in file name
          # becomes to '%2C'.
          stream.write('%s,%d,%d,%d,%d\n' % (
              urllib.quote(info.filename), info.record_start, info.record_size,
              info.content_start, info.size))

          if stream.tell() > _WRITE_BUFFER_SIZE_BYTES:
            yield stream.getvalue()
            # Rewinding alone keeps the old bytes in the buffer, and
            # getvalue() returns the whole buffer, so already sent lines
            # would be emitted again. Drop them explicitly.
            stream.seek(0)
            stream.truncate()

        if stream.tell():
          yield stream.getvalue()

      _log('list_member done')

    return _tar_member_list()

  @cherrypy.expose
  @_to_cherrypy_error
  def download(self, *args):
    """Download a file from Google Storage.

    For example: GET /download/bucket/path/to/file. This downloads the file
    gs://bucket/path/to/file.

    Args:
      *args: All parts of the GS file path without gs:// prefix.

    Returns:
      The stream of downloaded file.

    Raises:
      cherrypy.HTTPError: HTTP 404 when the file doesn't exist, 401 when
        reading it isn't authorized, 503 for any other gsutil command error.
    """
    path = 'gs://%s' % _check_file_extension('/'.join(args))

    _log('Downloading %s', path, level=logging.INFO)
    try:
      stat = self._gsutil.Stat(path)
      content = self._gsutil.StreamingCat(path)
    except gs.GSNoSuchKey as err:
      raise cherrypy.HTTPError(_HTTP_NOT_FOUND, err.message)
    except gs.GSCommandError as err:
      if "You aren't authorized to read" in err.result.error:
        status = _HTTP_UNAUTHORIZED
      else:
        status = _HTTP_SERVICE_UNAVAILABLE
      raise cherrypy.HTTPError(status, '%s: %s' % (err.message,
                                                   err.result.error))

    cherrypy.response.headers.update({
        'Content-Type': stat.content_type,
        'Accept-Ranges': 'bytes',
        'Content-Length': stat.content_length,
    })
    # NOTE(review): this is logged when the handler returns; with
    # 'response.stream' enabled the body is presumably sent after that.
    _log('Download complete.')

    return content

  @cherrypy.expose
  @_to_cherrypy_error
  def extract(self, *args, **kwargs):
    """Extract a file from a Tar archive.

    Examples:
      GET /extract/chromeos-image-archive/release/files.tar?file=path/to/file

    Args:
      *args: All parts of the GS path of the archive, without gs:// prefix.
      kwargs: file: The URL ENCODED path of file to be extracted.

    Returns:
      The stream of extracted file.

    Raises:
      cherrypy.HTTPError: HTTP 400 when the `file` parameter is missing or the
        file isn't in the archive.
    """
    # TODO(guocb): support compressed format of tar
    archive = _check_file_extension('/'.join(args), ext_names=['.tar'])
    filename = _safe_get_param(kwargs, 'file')
    _log('Extracting "%s" from "%s".', filename, archive)
    return self._extract_file_from_tar(filename, archive)

  def _extract_file_from_tar(self, filename, archive):
    """Extract file of |filename| from |archive|.

    Args:
      filename: The name of the member, as it appears in the first column of
        the `list_member` output.
      archive: The path of the tar archive.

    Returns:
      An iterator over the content of the member.

    Raises:
      cherrypy.HTTPError: HTTP 400 when |filename| isn't in |archive|.
    """
    # Call `list_member` and search |filename| in it. If found, create another
    # "Range Request" to download that range of bytes.
    all_files = self._caching_server.list_member(
        archive, headers=cherrypy.request.headers)
    # The format of each line is '<filename>,<data1>,<data2>...'. And the
    # filename is encoded by URL percent encoding, so no ',' in filename. Thus
    # search '<filename>,' is good enough for looking up the file information.
    target_str = '%s,' % filename

    # The listing is a streaming response and the search usually stops before
    # it is fully read, so close it explicitly to release the connection.
    with contextlib.closing(all_files):
      for line in all_files.iter_lines(chunk_size=_READ_BUFFER_SIZE_BYTES):
        # TODO(guocb): Seems the searching performance is OK: for a 60K lines
        # list, searching a file in very last takes about 0.1 second. But it
        # would be better if there's a better and not very complex way.

        # We don't split the line by ',' and compare the filename part is
        # because of performance.
        if line.startswith(target_str):
          _log('The line for the file found: %s', line)
          file_info = tarfile_utils.TarMemberInfo._make(line.split(','))
          rsp = self._send_range_request(archive, file_info.content_start,
                                         file_info.size)
          rsp.raise_for_status()
          return rsp.iter_content(_WRITE_BUFFER_SIZE_BYTES)

    raise cherrypy.HTTPError(
        _HTTP_BAD_REQUEST,
        'File "%s" is not in archive "%s"!' % (filename, archive))

  def _send_range_request(self, archive, start, size):
    """Create and send a "Range Request" to caching server.

    Set HTTP Range header and just download the bytes in that "range".
    https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests

    Args:
      archive: The path of the tar archive.
      start: The offset of the first byte to download. The caller passes a
        field of a CSV line, so it is converted with int().
      size: The number of bytes to download; converted with int() as well.

    Returns:
      The response of the caching server. When the status isn't 206, the
      status code of the returned response is overwritten with 500.
    """
    headers = cherrypy.request.headers.copy()
    # The end position of an HTTP byte range is inclusive, hence the '- 1'.
    headers['Range'] = 'bytes=%s-%d' % (start, int(start) + int(size) - 1)
    rsp = self._caching_server.download(archive, headers=headers)

    if rsp.status_code == _HTTP_PARTIAL_CONTENT:
      # Although this is a partial response, it has full content of
      # extracted file. Thus reset the status code and content-length.
      cherrypy.response.status = _HTTP_OK
      cherrypy.response.headers['Content-Length'] = size
    else:
      _log('Expect HTTP_PARTIAL_CONTENT (206), but got %s', rsp.status_code,
           level=logging.ERROR)
      rsp.status_code = _HTTP_INTERNAL_SERVER_ERROR
      rsp.reason = 'Range request failed for "%s"' % archive

    return rsp

  # Enable streaming responses for the exposed handlers.
  # pylint:disable=protected-access
  download._cp_config = {'response.stream': True}
  list_member._cp_config = {'response.stream': True}
  extract._cp_config = {'response.stream': True}
def _url_type(input_string):
  """Validate |input_string| as a URL and reduce it to (scheme, netloc).

  Meant to be used as an argparse `type` converter. When the input has no
  scheme, 'http://' is prepended before it is parsed again.

  Args:
    input_string: The URL given on the command line.

  Returns:
    A tuple of (scheme, netloc).

  Raises:
    argparse.ArgumentTypeError: Raised when scheme or netloc is still missing
      after the default scheme has been applied.
  """
  parts = urlparse.urlsplit(input_string)
  if not parts.scheme:
    # No scheme given: fall back to http and parse once more.
    input_string = 'http://%s' % input_string
    parts = urlparse.urlsplit(input_string)

  if parts.scheme and parts.netloc:
    return parts.scheme, parts.netloc

  raise argparse.ArgumentTypeError('Wrong URL format: %s' % input_string)
def parse_args(argv):
  """Parse the command line arguments.

  Args:
    argv: The list of command line arguments, without the program name.

  Returns:
    The namespace object produced by argparse.
  """
  parser = argparse.ArgumentParser(
      description=__doc__,
      formatter_class=argparse.RawDescriptionHelpFormatter)

  # The service can either bind to a socket or listen to a port, but doesn't do
  # both.
  bind_group = parser.add_mutually_exclusive_group(required=True)
  bind_group.add_argument(
      '-s', '--socket', help='Unix domain socket to bind')
  bind_group.add_argument(
      '-p', '--port', type=int, help='Port number to listen.')

  # TODO(guocb): support Unix domain socket
  caching_server_help = (
      'URL of the proxy server. Valid format is '
      '[http://]{<hostname>|<IP>}[:<port_number>]. When skipped, the default '
      'scheme is http and port number is 80. Any other components in URL are '
      'ignored.')
  parser.add_argument(
      '-c', '--caching-server', required=True, type=_url_type,
      help=caching_server_help)

  return parser.parse_args(argv)
def setup_logger():
  """Make the module logger write DEBUG and above to stdout."""
  # '%(Remote-Addr)s' isn't a standard log record field; it is taken from the
  # request headers that _log() passes as `extra`.
  log_format = (
      '%(module)s:%(asctime)-15s [%(Remote-Addr)s:%(thread)d] %(levelname)s:'
      ' %(message)s')
  stdout_handler = logging.StreamHandler(sys.stdout)
  stdout_handler.setFormatter(logging.Formatter(log_format))
  _logger.addHandler(stdout_handler)
  _logger.setLevel(logging.DEBUG)
def main(argv):
  """Main function.

  Parses the arguments, sets up logging and runs the CherryPy server until it
  is stopped.

  Args:
    argv: The list of command line arguments, without the program name.
  """
  args = parse_args(argv)
  setup_logger()

  if args.socket:
    # in order to allow group user writing to domain socket, the directory
    # should have GID bit set, i.e. g+s
    # The mask is spelled '0o002' because the bare '0002' form is only valid
    # in Python 2, while this one is accepted by Python 2.6+ and Python 3.
    os.umask(0o002)

  # Exactly one of --port/--socket is given, so the other one is None here.
  cherrypy.server.socket_port = args.port
  cherrypy.server.socket_file = args.socket

  cherrypy.quickstart(GsArchiveServer(_CachingServer(args.caching_server)))
if __name__ == '__main__':
  # Only start the server when this file is executed as a script. The return
  # value of main() is handed to sys.exit().
  cli_args = sys.argv[1:]
  sys.exit(main(cli_args))