Add a route backwards from the DUT to the test host

This is useful during VPN tests where the default route gets set.

BUG=chromium-os:13757
TEST=Manual: Inspected "ip" commands from debug log.  Thutt: Please
verify

Change-Id: Ib3c86243a6b13c95706b58549049b00df633c348

R=thutt@chromium.org

Review URL: http://codereview.chromium.org/6689026
diff --git a/server/site_host_route.py b/server/site_host_route.py
new file mode 100644
index 0000000..78ea0b8
--- /dev/null
+++ b/server/site_host_route.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+
+import re, socket, subprocess
+from autotest_lib.client.common_lib import error
+
+class HostRoute(object):
+    """
+    Host Route: A utility for retrieving information about our route to a host
+
+    """
+
+    def __init__(self, host):
+        self.host = host    # Remote host
+        self.calculate()
+
+    def calculate(self):
+        output = self.run_command(["ip", "route", "get", self.host])
+        # This converts "172.22.18.53 via 10.0.0.1 dev eth0 src 10.0.0.200 \n.."
+        # into ("via", "10.0.0.1", "dev", "eth0", "src", "10.0.0.200")
+        route_info = re.split("\s*", output.split("\n")[0].rstrip(' '))[1:]
+
+        # Further, convert the list into a dict {"via": "10.0.0.1", ...}
+        self.route_info = dict(tuple(route_info[i:i+2])
+                               for i in range(0, len(route_info), 2))
+
+        if 'src' not in self.route_info:
+            raise error.TestFail('Cannot find route to host %s' % self.host)
+
+class LocalHostRoute(HostRoute):
+    """
+    Self Host Route: Retrieve host route for the test-host machine
+
+    """
+    def __init__(self, host):
+        # TODO(pstew): If we could depend on the host having the "ip" command
+        # we would just be able to do this:
+        #
+        #     HostRoute.__init__(self, host)
+        #
+        # but alas, we can't depend on this, so we fake it by creating a
+        # socket and figuring out what local address we bound to if we
+        # connected to the client
+        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        sock.connect((host, 22)) # NB: Port doesn't matter
+        self.route_info = { 'src': sock.getsockname()[0] }
+
+    def run_command(self, args):
+        return subprocess.Popen(args, stdout=subprocess.PIPE).communicate()[0]
+
+class RemoteHostRoute(HostRoute):
+    """
+    Remote Host Route: Retrieve host route for a remote (DUT, server) machine
+
+    """
+    def __init__(self, remote, host):
+        self.remote = remote
+        HostRoute.__init__(self, host)
+
+    def run_command(self, args):
+        return self.remote.run(' '.join(args)).stdout
diff --git a/server/site_wifitest.py b/server/site_wifitest.py
index e05f303..6ddcf0a 100644
--- a/server/site_wifitest.py
+++ b/server/site_wifitest.py
@@ -9,6 +9,7 @@
 from autotest_lib.server import site_linux_router
 from autotest_lib.server import site_linux_server
 from autotest_lib.server import site_host_attributes
+from autotest_lib.server import site_host_route
 from autotest_lib.server import site_eap_certs
 from autotest_lib.server import test
 from autotest_lib.client.common_lib import error
@@ -154,6 +155,7 @@
         self.__client_discover_commands(client)
         self.profile_save({})
         self.firewall_rules = []
+        self.host_route_args = {}
 
         # interface name on client
         self.client_wlanif = client.get('wlandev',
@@ -177,6 +179,7 @@
         self.profile_cleanup({})
         self.client_netdump_stop({})
         self.firewall_cleanup({})
+        self.host_route_cleanup({})
 
 
     def __must_be_installed(self, host, cmd):
@@ -1230,6 +1233,11 @@
         # Must get 'ca_certificate', 'client-certificate' and 'client-key'.
         cert_pathnames = params.get('files', {})
 
+        # Starting up the VPN client may cause the DUT's routing table (esp.
+        # the default route) to change.  Set up a host route backwards so
+        # we don't lose our control connection in that event.
+        __add_host_route(self.client)
+
         if self.vpn_kind is None:
             raise error.TestFail('No VPN kind specified for this test.')
         elif self.vpn_kind == 'openvpn':
@@ -1268,6 +1276,32 @@
                                      'for VPN kind (%s)' % self.vpn_kind)
             self.vpn_kind = None
 
+        __del_host_route(self.client)
+
+    def __add_host_route(self, host):
+        # What is the local address we use to get to the test host?
+        local_ip = site_host_route.LocalHostRoute(host.ip).route_info["src"]
+
+        # How does the test host currently get to this local address?
+        host_route = site_host_route.RemoteHostRoute(host, local_ip).route_info
+
+        # Flatten the returned dict into a single string
+        route_args = " ".join(" ".join(x) for x in host_route.iteritems())
+
+        self.host_route_args[host.ip] = "%s %s" % (local_ip, route_args)
+        host.run("ip route add %s" % self.host_route_args[host.ip])
+
+    def __del_host_route(self, host):
+        if host.ip not in self.host_route_args:
+            return
+
+        host.run("ip route del %s" % self.host_route_args.pop(host.ip))
+
+    def host_route_cleanup(self, params):
+        for host in (self.client, self.server, self.router):
+            self.__del_host_route(host)
+
+
 class HelperThread(threading.Thread):
     # Class that wraps a ping command in a thread so it can run in the bg.
     def __init__(self, client, cmd):