blob: 0e306caccdbd6d7897e825aa1c348e45dab451a3 [file] [log] [blame]
# -*- coding: utf-8 -*-
# Copyright 2018 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""The backend service of Google Storage Cache Server.
Run `./bin/gs_archive_server` to start the server. After started, it listens on
a TCP port and/or an Unix domain socket. The latter performs better when work
with a local hosted reverse proxy server, e.g. Nginx.
The server accepts below requests:
- GET /download/<bucket>/path/to/file
Download the file from google storage.
- GET /extract/<bucket>/path/to/archive?file=path/to/file
Extract a file form a compressed/uncompressed TAR archive.
from __future__ import print_function
import argparse
import functools
import logging
import os
import subprocess
import sys
import tempfile
from six.moves import http_client
from six.moves import urllib
import requests
import cherrypy # pylint: disable=import-error
import constants
from chromite.lib import gs
_WRITE_BUFFER_SIZE_BYTES = 1024 * 1024 # 1 MB
# When extract files from TAR (either compressed or uncompressed), we suppose
# the TAR exists, so we can call `download` RPC to get it. It's straightforward
# for uncompressed TAR. But for compressed TAR, we cannot `download` it from
# GS because it doesn't exist there at all. In this case, we call `decompress`
# RPC internally to download and decompress. In order to tell if invoke of
# `download` RPC is a real download, or download+decompress, we use below HTTP
# header as a flag. It can also tell use what's the extension name of the
# compressed tar, e.g. '.tar.gz', etc. We use this information to get the file
# name on GS.
# The max size of temporary spool file in memory.
_SPOOL_FILE_SIZE_BYTES = 100 * 1024 * 1024 # 100 MB
_logger = logging.getLogger(__file__)
def _log(*args, **kwargs):
  """A thin wrapper over logging.debug/info, etc.

  The optional |level| keyword argument selects the log level (defaults to
  DEBUG); every other argument is forwarded to the module logger. The current
  cherrypy request headers are attached as the log record's extra data.
  """
  severity = kwargs.pop('level', logging.DEBUG)
  _logger.log(severity, extra=cherrypy.request.headers, *args, **kwargs)
def _log_filtered_headers(all_headers, filtered_headers, level=logging.DEBUG):
  """Log only the headers named in |filtered_headers| at |level|."""
  selected = {name: all_headers.get(name) for name in filtered_headers}
  _log('Filtered headers: %s', selected, level=level)
def _check_file_extension(filename, ext_names=None):
"""Check the file name and, optionally, the ext name.
filename: The file name to be checked.
ext_names: The valid extension of |filename| should have.
The filename if the check is good.
ValueError: Raised if the checking failed.
if not filename:
raise ValueError('File name is required.')
if not ext_names:
return filename
for ext_name in ext_names:
if filename.endswith(ext_name):
return filename
raise ValueError("Extension name of '%s' isn't in %s" % (filename,
def _safe_get_param(all_params, param_name):
"""Check if |param_name| is in |all_params| and has non-empty value.
all_params: A dict of all parameters of the call.
param_name: The parameter name to be checked.
A set of all non-empty value.
Raise HTTP 400 error if no valid parameter in |all_params|.
value = all_params[param_name]
except KeyError:
raise cherrypy.HTTPError(http_client.BAD_REQUEST,
'Parameter "%s" is required!' % param_name)
return set(value) if isinstance(value, list) else {value}
def _to_cherrypy_error(func):
"""A decorator to convert Exceptions raised to proper cherrypy.HTTPError."""
def func_wrapper(*args, **kwargs):
return func(*args, **kwargs)
except requests.HTTPError as err:
# cherrypy.HTTPError wraps the error messages with HTML tags. But
# requests.HTTPError also do same work. So return the error message
# directly.
cherrypy.response.status = err.response.status_code
return err.response.content
except ValueError as err:
# The exception message is just a plain text, so wrap it with
# cherrypy.HTTPError to have necessary HTML tags
raise cherrypy.HTTPError(http_client.BAD_REQUEST, err.message) # pylint: disable=exception-message-attribute
return func_wrapper
class _CachingServer(object):
r"""The interface of caching server for GsArchiveServer.
This class provides an interface to work with the caching server (usually a
reversed http proxy server) which caches all intermediate results, e.g.
downloaded files, etc. and serves to GsArchiveServer.
The relationship of this class and other components is:
/-------------(python function call)-----------------------\
| |
v |
_CachingServer --(http/socket)--> NGINX --(http/socket)--> GsArchiveServer
^ |
| (https)
End user, DUTs ---(http)------------/ |
def __init__(self, url):
url: A tuple of URL scheme and netloc.
ValueError: Raised when input URL in wrong format.
self._url = url
def _call(self, action, path, args=None, headers=None):
"""Helper function to generate all RPC calls to the proxy server."""
url = urllib.parse.urlunsplit(
self._url + ('%s/%s' % (action, path),
urllib.parse.urlencode(args or {}),
_log('Sending request to caching server: %s', url)
# The header to control using or bypass cache.
_log_filtered_headers(headers, ('Range', 'X-No-Cache',
rsp = requests.get(url, headers=headers, stream=True)
_log('Caching server response %s: %s', rsp.status_code, url)
_log_filtered_headers(rsp.headers, ('Content-Type', 'Content-Length',
'Content-Range', 'X-Cache',
'Cache-Control', 'Date'))
return rsp
def _download_and_decompress_tar(self, path, ext_name, headers=None):
"""Helper function to download and decompress compressed TAR."""
# The |path| we have is like foo.tar. Combine with |ext_name| we can get
# the compressed file name on Google storage, e.g.
# 'foo.tar' + '.gz' => foo.tar.gz
# But it's special for '.tgz', i.e. 'foo.tar' + '.tgz' => 'foo.tgz'
if ext_name == '.tgz':
path, _ = os.path.splitext(path)
path = '%s%s' % (path, ext_name)
_log('Download and decompress %s', path)
return self._call('decompress', path, headers=headers)
def download(self, path, headers=None):
"""Download file |path| from the caching server."""
# When the request comes with header _HTTP_HEADER_COMPRESSED_TAR_EXT, we
# internally call `decompress` instead of `download` because Google storage
# only has the compressed version of the file to be "downloaded".
ext_name = headers.pop(_HTTP_HEADER_COMPRESSED_TAR_EXT, None)
# RPC `decompress` validates ext_name, so doesn't do that here.
if ext_name:
return self._download_and_decompress_tar(path, ext_name, headers=headers)
return self._call('download', path, headers=headers)
_SERVICE_ACCOUNT_BOTO_FILE = os.path.expanduser('~/.boto.service_account')
class _GSContext(object):
  """A wrapper class of gs.GSContext for service account migration purpose."""

  def __init__(self):
    # The default context uses whatever credentials the environment provides.
    self._ctx_default = gs.GSContext()
    # NOTE(review): lines were truncated here in this copy of the file; the
    # apparent intent (per fetch_file below) is a second context backed by
    # the service account boto file when that file exists — confirm against
    # the upstream source.
    if os.path.isfile(_SERVICE_ACCOUNT_BOTO_FILE):
      self._ctx_service_account = gs.GSContext(
          boto_file=_SERVICE_ACCOUNT_BOTO_FILE)
    else:
      self._ctx_service_account = None

  def fetch_file(self, path, want_content=True):
    """Fetch the file stat and/or content from GS bucket.

    Args:
      path: The GS path of the file to fetch.
      want_content: A boolean of whether fetch the content (as an iterator).

    Returns:
      A tuple of (stat, content) which is the GS file stat and content
      iterator (or None).
    """
    # NOTE(review): the 'try:' and the two 'raise' statements below were
    # missing from this copy of the file; restored so that the service
    # account is tried only for authorization failures — confirm against the
    # upstream source.
    try:
      stat = self._ctx_default.Stat(path)
      ctx = self._ctx_default
    except gs.GSCommandError as err:
      if not self._ctx_service_account:
        raise
      if not err.stderr.startswith("You aren't authorized to read "):
        raise
      _log('Not authorized by default. Trying service account.')
      stat = self._ctx_service_account.Stat(path)
      ctx = self._ctx_service_account

    if want_content:
      _log('Downloading %s', path, level=logging.INFO)
      return stat, ctx.StreamingCat(path)
    return stat, None
class GsArchiveServerError(Exception):
  """Base exception type raised by GsArchiveServer operations."""
class GsArchiveServer(object):
  """The backend of Google Storage Cache server.

  Each public method below is a cherrypy RPC endpoint; a _CachingServer
  instance sits in front and caches intermediate results.

  NOTE(review): this copy of the file lost many physical lines during
  extraction (unclosed argument lists, 'except' without 'try:', missing loop
  bodies). Each such spot is annotated below; confirm against the upstream
  chromite source before relying on this code.
  """

  def __init__(self, caching_server):
    # Direct Google Storage access (with service-account fallback).
    self._gsutil = _GSContext()
    # The front caching server (a _CachingServer) used for cacheable RPCs.
    self._caching_server = caching_server

  def list_dir(self, *args):
    """Lists contents of specified GS bucket/<board>/version.

    Args:
      *args: All parts of the GS path to list, without gs:// prefix.

    Returns:
      The raw output of `gsutil ls` on the path.
    """
    path = 'gs://%s' % _check_file_extension('/'.join(args))
    gs_cmd = ['gsutil', 'ls', path]
    # NOTE(review): the Popen call below is truncated (unclosed argument
    # list) and the 'except' has no visible matching 'try:' — lines missing.
    proc = subprocess.Popen(gs_cmd, stdout=subprocess.PIPE,
    content, _ = proc.communicate()
    except subprocess.CalledProcessError as e:
      raise cherrypy.HTTPError(http_client.NOT_FOUND, e.output)
    return content

  @cherrypy.config(**{'': True})
  def download(self, *args):
    """Download a file from Google Storage.

    For example: GET /download/bucket/path/to/file. This downloads the file
    at that GS path and streams it back.

    Args:
      *args: All parts of the GS file path without gs:// prefix.

    Returns:
      The stream of downloaded file.
    """
    path = 'gs://%s' % _check_file_extension('/'.join(args))
    # HEAD requests only need the stat, not the content.
    want_content = cherrypy.request.method != 'HEAD'
    # NOTE(review): the 'try:' for the 'except' clauses below, and an 'else:'
    # between the two status assignments, appear to be missing.
    stat, content = self._gsutil.fetch_file(path, want_content)
    except gs.GSNoSuchKey as err:
      raise cherrypy.HTTPError(http_client.NOT_FOUND, err.message)  # pylint: disable=exception-message-attribute
    except gs.GSCommandError as err:
      if "You aren't authorized to read" in err.result.error:
        status = http_client.UNAUTHORIZED
      status = http_client.SERVICE_UNAVAILABLE
      raise cherrypy.HTTPError(status, '%s: %s' % (err.message,  # pylint: disable=exception-message-attribute
    # NOTE(review): the statement wrapping these header key/value pairs
    # (presumably a cherrypy.response.headers.update({...}) call) is missing.
    'Content-Type': stat.content_type,
    'Accept-Ranges': 'bytes',
    'Content-Length': stat.content_length,
    return content

  @cherrypy.config(**{'': True})
  def extract(self, *args, **kwargs):
    """Extract files from a compressed/uncompressed Tar archive.

    The RPC accepts query 'file=' which is either a file name or a glob
    pattern. It's optional to encode the file name or pattern in
    'percent-encoding', i.e. '/' -> '%2F', '*' -> '%2A', etc.

    For example, extracting file 'path/to/file' from files.tgz:
      GET /extract/<bucket>/files.tgz?file=path%2Fto%2Ffile

    Args:
      *args: All parts of the GS path of the archive, without gs:// prefix.
      kwargs: file: The path or pattern of file to be extracted.

    Returns:
      Extracted file content (Binary data).
    """
    # NOTE(review): the first positional argument of this call (presumably
    # '/'.join(args)) is missing from this copy of the file.
    archive = _check_file_extension(
        ext_names=['.tar', '.tar.gz', '.tgz', '.tar.bz2', '.tar.xz'])
    files = _safe_get_param(kwargs, 'file')
    if len(files) != 1:
      raise GsArchiveServerError('Cannot extract more than one file at a time.')
    file_to_be_extracted = files.pop()
    _log('Extracting "%s" from "%s".', file_to_be_extracted, archive)
    archive_basename, archive_extname = os.path.splitext(archive)
    headers = cherrypy.request.headers.copy()
    if archive_extname == '.tar':
      decompressed_archive_name = archive
    # Compressed tar archives: we don't decompress them here. Instead, we
    # suppose they have been decompressed, and continue the routine to extract
    # from the supposed decompressed archive name.
    # The magic is, we set a special HTTP header, and pass it to caching
    # server. Eventually, caching server loops it back to `download` RPC.
    # In `download`, we check this header. If it exists, then call
    # `decompress` RPC other than a normal `download` RPC.
    # NOTE(review): an 'else:' introducing this compressed-archive branch
    # appears to be missing here.
    headers[_HTTP_HEADER_COMPRESSED_TAR_EXT] = archive_extname
    # Get the name of decompressed archive, e.g. foo.tgz => foo.tar,
    # bar.tar.xz => bar.tar, etc.
    if archive_extname == '.tgz':
      decompressed_archive_name = '%s.tar' % archive_basename
    # NOTE(review): an 'else:' appears to be missing before this line.
    decompressed_archive_name = archive_basename
    return self._extract_file_from_tar(file_to_be_extracted,
                                       decompressed_archive_name, headers)

  def _extract_file_from_tar(self, target_file, archive, headers=None):
    """Extracts the target file from the given archive.

    Args:
      target_file: The file to be extracted.
      archive: The archive from which the file should be extracted.
      headers: headers for the request that will get the archive.

    Returns:
      Extracted file content (Binary data).
    """
    # NOTE(review): the download call below is truncated — its path argument
    # (presumably |archive|) is missing before the comma.
    rsp =, headers=headers)
    # Stream the archive through `tar -O -x <target_file>` so the extracted
    # file lands in the spool file via proc's stdout.
    cmd = ['tar', '-O', '-x', target_file]
    with tempfile.SpooledTemporaryFile(max_size=_SPOOL_FILE_SIZE_BYTES) as df:
      proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=df)
      # NOTE(review): the loop body feeding each chunk to proc.stdin, plus the
      # subsequent stdin close and process wait, appear to be missing here.
      for chunk in rsp.iter_content(constants.READ_BUFFER_SIZE_BYTES):
      # Update the response's content type to support yielding binary data.
      cherrypy.response.headers['Content-Type'] = 'application/octet-stream'
      # Go to the beginning of the file.
      # NOTE(review): the corresponding call is missing here.
      # Read the SpooledFile in chunks and yield the data.
      while True:
        data =
        if not data:
          # NOTE(review): a 'break' likely belongs under 'if not data', and
          # the read call above is missing its receiver/size; lines lost.
          yield data

  @cherrypy.config(**{'': True})
  def decompress(self, *args):
    """Decompress the compressed TAR archive.

    Args:
      *args: All parts of the GS path of the compressed archive, without
          gs:// prefix.

    Returns:
      The content generator of decompressed TAR archive.
    """
    zarchive = _check_file_extension(
        '/'.join(args), ext_names=['.tar.gz', '.tar.bz2', '.tar.xz', '.tgz'])
    _log('Decompressing "%s"', zarchive)
    # NOTE(review): the download call below is truncated — its path argument
    # and closing parenthesis are missing.
    rsp =,
    cherrypy.response.headers['Content-Type'] = 'application/x-tar'
    cherrypy.response.headers['Accept-Ranges'] = 'bytes'
    basename = os.path.basename(zarchive)
    _, extname = os.path.splitext(basename)
    # Command lines used to decompress file.
    # NOTE(review): the closing brace of this dict literal is missing.
    commands = {
        '.gz': ['gzip', '-d', '-c'],
        '.tgz': ['gzip', '-d', '-c'],
        '.xz': ['xz', '-d', '-c'],
        '.bz2': ['bzip2', '-d', '-c'],
    # NOTE(review): the argument lists of the three calls below are truncated
    # (spool size, Popen stdout, log argument), and the loop that pipes
    # chunks into the decompressor plus its stdin close/wait are missing.
    decompressed_file = tempfile.SpooledTemporaryFile(
    proc = subprocess.Popen(commands[extname], stdin=subprocess.PIPE,
    _log('Decompress process id: %s.',
    for chunk in rsp.iter_content(constants.READ_BUFFER_SIZE_BYTES):
    _log('Decompression done.')
    # The header of Content-Length is necessary for supporting range request.
    # So we have to decompress the file locally to get the size. This may cause
    # connection timeout issue if the decompression take too long time (e.g. 90
    # seconds). As a reference, it takes about 10 seconds to decompress a 400MB
    # NOTE(review): a seek-to-end call was fused into the comment line below
    # during extraction; its receiver (presumably is lost.
    # tgz file., os.SEEK_END)
    content_length = decompressed_file.tell()
    _log('Decompressed content length is %d bytes.', content_length)
    cherrypy.response.headers['Content-Length'] = str(content_length)

    def decompressed_content():
      # Stream the spooled, decompressed archive back to the client.
      _log('Streaming decompressed content of "%s" begin.', zarchive)
      while True:
        data =
        if not data:
          # NOTE(review): a 'break' likely belongs under 'if not data', and
          # the read call above is missing its receiver/size; lines lost.
          yield data
      _log('Streaming of "%s" done.', zarchive)

    return decompressed_content()
def _url_type(input_string):
"""Ensure |input_string| is a valid URL and convert to target type.
The target type is a tuple of (scheme, netloc).
split_result = urllib.parse.urlsplit(input_string)
if not split_result.scheme:
input_string = 'http://%s' % input_string
split_result = urllib.parse.urlsplit(input_string)
if not split_result.scheme or not split_result.netloc:
raise argparse.ArgumentTypeError('Wrong URL format: %s' % input_string)
return split_result.scheme, split_result.netloc
def parse_args(argv):
  """Parse arguments.

  Args:
    argv: Command line arguments, excluding the program name.

  Returns:
    The argparse.Namespace of parsed options.
  """
  # NOTE(review): the ArgumentParser(...) argument list and the opening
  # 'parser.add_argument(' lines of the last two options were truncated in
  # this copy of the file; restored to a minimal, behavior-compatible form —
  # confirm against the upstream source.
  parser = argparse.ArgumentParser(
      description=__doc__,
      formatter_class=argparse.RawDescriptionHelpFormatter)
  # The service can either bind to a socket or listen to a port, but doesn't do
  # both.
  socket_or_port = parser.add_mutually_exclusive_group(required=True)
  socket_or_port.add_argument('-s', '--socket',
                              help='Unix domain socket to bind')
  socket_or_port.add_argument('-p', '--port', type=int,
                              help='Port number to listen.')

  # TODO(guocb): support Unix domain socket
  parser.add_argument(
      '-c', '--caching-server', required=True, type=_url_type,
      help='URL of the proxy server. Valid format is '
      '[http://]{<hostname>|<IP>}[:<port_number>]. When skipped, the default '
      'scheme is http and port number is 80. Any other components in URL are '
      'ignored.')
  parser.add_argument(
      '-b', '--bind', default='', type=str,
      help='Option to specify alternate bind address. By default, '
      'gs_archive_server starts on')
  return parser.parse_args(argv)
def setup_logger():
  """Setup logger."""
  # The 'Remote-Addr' field is supplied per-record via _log's extra= headers.
  formatter = logging.Formatter(
      '%(module)s:%(asctime)-15s [%(Remote-Addr)s:%(thread)d] %(levelname)s:'
      ' %(message)s')
  handler = logging.StreamHandler(sys.stdout)
  # NOTE(review): the tail of this function was truncated in this copy of the
  # file; the standard wiring (attach formatter and handler, enable DEBUG) is
  # restored here — confirm against the upstream source.
  handler.setFormatter(formatter)
  _logger.addHandler(handler)
  _logger.setLevel(logging.DEBUG)
def main(argv):
  """Main function: parse arguments, configure cherrypy and serve forever.

  Args:
    argv: Command line arguments, excluding the program name.
  """
  args = parse_args(argv)
  setup_logger()

  if args.socket:
    # in order to allow group user writing to domain socket, the directory
    # should have GID bit set, i.e. g+s
    # NOTE(review): the body of this branch was truncated in this copy of the
    # file; the umask restores the apparent intent of the comment above —
    # confirm against the upstream source.
    os.umask(0o002)

  cherrypy.server.socket_port = args.port
  cherrypy.server.socket_file = args.socket
  cherrypy.server.socket_host = args.bind
  # NOTE(review): the server start-up call was missing from this copy of the
  # file; restored — confirm against the upstream source.
  cherrypy.quickstart(GsArchiveServer(_CachingServer(args.caching_server)))


if __name__ == '__main__':
  # NOTE(review): the guard body was missing from this copy of the file.
  main(sys.argv[1:])