GS Cache: extract a file from a compressed TAR.

This change adds support to extract a file from a compressed TAR
archive, i.e. tar.gz, tgz, tar.xz, and tar.bz2.

BUG=chromium:824580
TEST=Ran all unit tests.

Change-Id: I9693198b876cb4f7b634265b007855390e97bda3
Reviewed-on: https://chromium-review.googlesource.com/1081150
Commit-Ready: ChromeOS CL Exonerator Bot <chromiumos-cl-exonerator@appspot.gserviceaccount.com>
Tested-by: Congbin Guo <guocb@chromium.org>
Reviewed-by: Congbin Guo <guocb@chromium.org>
diff --git a/gs_cache/gs_archive_server.py b/gs_cache/gs_archive_server.py
index a36d8f7..83adedf 100644
--- a/gs_cache/gs_archive_server.py
+++ b/gs_cache/gs_archive_server.py
@@ -9,7 +9,10 @@
 with a local hosted reverse proxy server, e.g. Nginx.
 
 The server accepts below requests:
-  - GET /download/<bucket>/path/to/file: download the file from google storage
+  - GET /download/<bucket>/path/to/file
+      Download the file from google storage.
+  - GET /extract/<bucket>/path/to/archive?file=path/to/file
+      Extract a file from a compressed/uncompressed TAR archive.
 """
 
 from __future__ import absolute_import
@@ -46,6 +49,17 @@
 _READ_BUFFER_SIZE_BYTES = 1024 * 1024  # 1 MB
 _WRITE_BUFFER_SIZE_BYTES = 1024 * 1024  # 1 MB
 
+# When extracting files from TAR (either compressed or uncompressed), we assume
+# the TAR exists, so we can call the `download` RPC to get it. That is
+# straightforward for an uncompressed TAR. But for a compressed TAR, we cannot
+# `download` it from GS because it doesn't exist there at all. In that case, we
+# call the `decompress` RPC internally to download and decompress. In order to
+# tell whether a call to the `download` RPC is a real download, or a
+# download+decompress, we use the HTTP header below as a flag. It also tells us
+# the extension name of the compressed tar, e.g. '.tar.gz', etc. We use this
+# information to get the file name on GS.
+_HTTP_HEADER_COMPRESSED_TAR_EXT = 'X-Compressed-Tar-Ext'
+
 # The max size of temporary spool file in memory.
 _SPOOL_FILE_SIZE_BYTES = 100 * 1024 * 1024  # 100 MB
 
@@ -166,17 +180,41 @@
                                            urllib.urlencode(args or {}), None))
     _log('Sending request to caching server: %s', url)
     # The header to control using or bypass cache.
-    _log_filtered_headers(headers, ('X-No-Cache',))
+    _log_filtered_headers(headers, ('Range', 'X-No-Cache',
+                                    _HTTP_HEADER_COMPRESSED_TAR_EXT))
     rsp = requests.get(url, headers=headers, stream=True)
-    _log('Caching server response %s', rsp.status_code)
+    _log('Caching server response %s: %s', rsp.status_code, url)
     _log_filtered_headers(rsp.headers, ('Content-Type', 'Content-Length',
-                                        'X-Cache', 'Cache-Control', 'Date'))
+                                        'Content-Range', 'X-Cache',
+                                        'Cache-Control', 'Date'))
     rsp.raise_for_status()
     return rsp
 
+  def _download_and_decompress_tar(self, path, ext_name, headers=None):
+    """Helper function to download and decompress compressed TAR."""
+    # The |path| we have is like foo.tar. Combine with |ext_name| we can get
+    # the compressed file name on Google storage, e.g.
+    # 'foo.tar' + '.gz' => foo.tar.gz
+    # But it's special for '.tgz', i.e. 'foo.tar' + '.tgz' => 'foo.tgz'
+    if ext_name == '.tgz':
+      path, _ = os.path.splitext(path)
+
+    path = '%s%s' % (path, ext_name)
+    _log('Download and decompress %s', path)
+    return self._call('decompress', path, headers=headers)
+
   def download(self, path, headers=None):
-    """Call download RPC."""
-    return self._call('download', path, headers=headers)
+    """Download file |path| from the caching server."""
+    # When the request comes with header _HTTP_HEADER_COMPRESSED_TAR_EXT, we
+    # internally call `decompress` instead of `download` because Google storage
+    # only has the compressed version of the file to be "downloaded".
+    ext_name = headers.pop(_HTTP_HEADER_COMPRESSED_TAR_EXT, None)
+
+    # RPC `decompress` validates ext_name, so we don't do that here.
+    if ext_name:
+      return self._download_and_decompress_tar(path, ext_name, headers=headers)
+    else:
+      return self._call('download', path, headers=headers)
 
   def list_member(self, path, headers=None):
     """Call list_member RPC."""
@@ -198,8 +236,8 @@
 
     An example, GET /list_member/bucket/path/to/file.tar
     The output is in format of:
-      <file name>,<data1>,<data2>,...<data6>
-      <file name>,<data1>,<data2>,...<data6>
+      <file name>,<data1>,<data2>,...<data4>
+      <file name>,<data1>,<data2>,...<data4>
       ...
 
     Details:
@@ -207,10 +245,8 @@
         path/to/file,name  -> path/to/file%2Cname.
       <data1>: File record start offset, in bytes.
       <data2>: File record size, in bytes.
-      <data3>: File record end offset, in bytes.
-      <data4>: File content start offset, in bytes.
-      <data5>: File content size, in bytes.
-      <data6>: File content end offset, in bytes.
+      <data3>: File content start offset, in bytes.
+      <data4>: File content size, in bytes.
 
     This is an internal RPC and shouldn't be called by end user!
 
@@ -303,10 +339,10 @@
   @cherrypy.config(**{'response.stream': True})
   @_to_cherrypy_error
   def extract(self, *args, **kwargs):
-    """Extract a file from a Tar archive.
+    """Extract a file from a compressed/uncompressed Tar archive.
 
     Examples:
-      GET /extract/chromeos-image-archive/release/files.tar?file=path/to/file
+      GET /extract/chromeos-image-archive/release/files.tgz?file=path/to/file
 
     Args:
       *args: All parts of the GS path of the archive, without gs:// prefix.
@@ -315,18 +351,40 @@
     Returns:
       The stream of extracted file.
     """
-    # TODO(guocb): support compressed format of tar
-    archive = _check_file_extension('/'.join(args), ext_names=['.tar'])
+    archive = _check_file_extension(
+        '/'.join(args),
+        ext_names=['.tar', '.tar.gz', '.tgz', '.tar.bz2', '.tar.xz'])
     filename = _safe_get_param(kwargs, 'file')
     _log('Extracting "%s" from "%s".', filename, archive)
-    return self._extract_file_from_tar(filename, archive)
+    archive_basename, archive_extname = os.path.splitext(archive)
 
-  def _extract_file_from_tar(self, filename, archive):
-    """Extract file of |filename| from |archive|."""
+    headers = cherrypy.request.headers.copy()
+    if archive_extname == '.tar':
+      decompressed_archive_name = archive
+    else:
+      # Compressed tar archives: we don't decompress them here. Instead, we
+      # assume they have been decompressed, and continue the routine to extract
+      # from the assumed decompressed archive name.
+      # The trick is that we set a special HTTP header and pass it to the
+      # caching server. Eventually, the caching server loops it back to the
+      # `download` RPC. In `download`, we check this header. If it exists, we
+      # call the `decompress` RPC instead of a normal `download` RPC.
+      headers[_HTTP_HEADER_COMPRESSED_TAR_EXT] = archive_extname
+      # Get the name of decompressed archive, e.g. foo.tgz => foo.tar,
+      # bar.tar.xz => bar.tar, etc.
+      if archive_extname == '.tgz':
+        decompressed_archive_name = '%s.tar' % archive_basename
+      else:
+        decompressed_archive_name = archive_basename
+
+    return self._extract_file_from_tar(filename, decompressed_archive_name,
+                                       headers)
+
+  def _extract_file_from_tar(self, filename, archive, headers=None):
+    """Extract file of |filename| from |archive| with http headers |headers|."""
     # Call `list_member` and search |filename| in it. If found, create another
     # "Range Request" to download that range of bytes.
-    all_files = self._caching_server.list_member(
-        archive, headers=cherrypy.request.headers)
+    all_files = self._caching_server.list_member(archive, headers=headers)
     # The format of each line is '<filename>,<data1>,<data2>...'. And the
     # filename is encoded by URL percent encoding, so no ',' in filename. Thus
     # search '<filename>,' is good enough for looking up the file information.
@@ -343,7 +401,7 @@
         _log('The line for the file found: %s', l)
         file_info = tarfile_utils.TarMemberInfo._make(l.split(','))
         rsp = self._send_range_request(archive, file_info.content_start,
-                                       file_info.size)
+                                       file_info.size, headers)
         rsp.raise_for_status()
         return rsp.iter_content(_WRITE_BUFFER_SIZE_BYTES)
 
@@ -351,13 +409,12 @@
         _HTTP_BAD_REQUEST,
         'File "%s" is not in archive "%s"!' % (filename, archive))
 
-  def _send_range_request(self, archive, start, size):
+  def _send_range_request(self, archive, start, size, headers):
     """Create and send a "Range Request" to caching server.
 
     Set HTTP Range header and just download the bytes in that "range".
     https://developer.mozilla.org/en-US/docs/Web/HTTP/Range_requests
     """
-    headers = cherrypy.request.headers.copy()
     headers['Range'] = 'bytes=%s-%d' % (start, int(start) + int(size) - 1)
     rsp = self._caching_server.download(archive, headers=headers)
 
@@ -394,6 +451,7 @@
     rsp = self._caching_server.download(zarchive,
                                         headers=cherrypy.request.headers)
     cherrypy.response.headers['Content-Type'] = 'application/x-tar'
+    cherrypy.response.headers['Accept-Ranges'] = 'bytes'
 
     basename = os.path.basename(zarchive)
     _, extname = os.path.splitext(basename)
diff --git a/gs_cache/tests/gs_archive_server_test.py b/gs_cache/tests/gs_archive_server_test.py
index 040157b..00dfc16 100644
--- a/gs_cache/tests/gs_archive_server_test.py
+++ b/gs_cache/tests/gs_archive_server_test.py
@@ -68,6 +68,22 @@
         'z_size': 51200,
         'z_md5': 'baa91444d9a1d8e173c42dfa776b1b98',
     },
+    'a_file_from_tgz': {
+        'path': 'dev_image_new/autotest/tools/common.py',
+        'from': '%s/stateful.tgz' % _DIR,
+        'md5': '634ac656b484758491674530ebe9fbc3'
+    },
+    'a_file_from_bz2': {
+        'path':
+            'autotest/au_control_files/control.paygen_au_canary_full_10500.0.0',
+        'from': '%s/paygen_au_canary_control.tar.bz2' % _DIR,
+        'md5': '5491d80aa4788084d974bd92df67815d'
+    },
+    'a_file_from_xz': {
+        'path': 'mount_image.sh',
+        'from': '%s/image_scripts.tar.xz' % _DIR,
+        'md5': 'e89dd3eb2fa386c3b0eef538a5ab57c3',
+    },
 }
 
 # a tgz file with only one file "bar" which content is "foo\n"
@@ -189,6 +205,19 @@
       rsp = self.server.decompress('baz.tar.xz')
       self.assertEquals(''.join(rsp), _A_TAR_FILE)
 
+  def test_extract_ztar(self):
+    """Test extracting a file from a compressed tar archive."""
+    with mock.patch.object(self.server, '_caching_server') as cache_server:
+      cache_server.list_member.return_value.iter_lines.return_value = [
+          'foobar,_,_,0,123']
+      self.server.extract('baz.tar.gz', file='foobar')
+      self.server.extract('baz.tar.bz2', file='foobar')
+      self.server.extract('baz.tar.xz', file='foobar')
+      self.server.extract('baz.tgz', file='foobar')
+
+      self.assertTrue(cache_server.list_member.called)
+      self.assertTrue(cache_server.download.called)
+
 
 def testing_server_setup():
   """Check if testing server is setup."""
@@ -266,6 +295,13 @@
       self.assertEquals(rsp.headers['Content-Type'], 'application/x-tar')
       self._verify_md5(rsp.content, tested_file['z_md5'])
 
+  def test_extract_from_compressed_tar(self):
+    """Test extracting a file from a compressed tar file."""
+    for k in ('a_file_from_tgz', 'a_file_from_xz', 'a_file_from_bz2'):
+      tested_file = _TEST_DATA[k]
+      rsp = self._get_page('/extract/%(from)s?file=%(path)s' % tested_file)
+      self._verify_md5(rsp.content, tested_file['md5'])
+
 
 if __name__ == "__main__":
   unittest.main()