Add cache and reconnect functionality to cros-dut

Add the ability to use the CacheForDUT service to cache files onto the
DUT, and the ability to force a reconnect to the DUT if the connection
is lost. cros-dut now requires the -cache_address flag.

BUG=None
TEST=unit

Change-Id: I9d915e886275b1f48dcedaf4b850b27d758cb19f
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/dev-util/+/3271912
Tested-by: Jaques Clapauch <jaquesc@google.com>
Auto-Submit: Jaques Clapauch <jaquesc@google.com>
Reviewed-by: Seewai Fu <seewaifu@google.com>
Commit-Queue: Jaques Clapauch <jaquesc@google.com>
diff --git a/src/chromiumos/test/dut/cmd/cros-dut/dutserver.go b/src/chromiumos/test/dut/cmd/cros-dut/dutserver.go
index 88bc35c..3dd6cc5 100644
--- a/src/chromiumos/test/dut/cmd/cros-dut/dutserver.go
+++ b/src/chromiumos/test/dut/cmd/cros-dut/dutserver.go
@@ -15,6 +15,8 @@
 	"io"
 	"log"
 	"net"
+	"net/url"
+	"path"
 	"strings"
 	"sync"
 
@@ -31,6 +33,10 @@
 	"chromiumos/test/dut/internal"
 )
 
+const CACHE_DOWNLOAD_URI = "/download/%s"                  // used for GsFile sources
+const CACHE_UNTAR_AND_DOWNLOAD_URI = "/extract/%s?file=%s" // used for GsTarFile sources; file= carries SourceFile
+const CACHE_EXTRACT_AND_DOWNLOAD_URI = "/decompress/%s"    // used for GsZipFile sources
+
 // DutServiceServer implementation of dut_service.proto
 type DutServiceServer struct {
 	manager        *lro.Manager
@@ -40,10 +46,11 @@
 	protoChunkSize int64
 	dutName        string
 	wiringAddress  string
+	cacheAddress   string
 }
 
 // newDutServiceServer creates a new dut service server to listen to rpc requests.
-func newDutServiceServer(l net.Listener, logger *log.Logger, conn dutssh.ClientInterface, serializerPath string, protoChunkSize int64, dutName, wiringAddress string) (*grpc.Server, func()) {
+func newDutServiceServer(l net.Listener, logger *log.Logger, conn dutssh.ClientInterface, serializerPath string, protoChunkSize int64, dutName, wiringAddress string, cacheAddress string) (*grpc.Server, func()) {
 	s := &DutServiceServer{
 		manager:        lro.New(),
 		logger:         logger,
@@ -52,6 +59,7 @@
 		protoChunkSize: protoChunkSize,
 		dutName:        dutName,
 		wiringAddress:  wiringAddress,
+		cacheAddress:   cacheAddress,
 	}
 
 	server := grpc.NewServer()
@@ -199,17 +207,120 @@
 	return stream.Send(resp)
 }
 
+// Cache downloads the requested Google Storage file onto the DUT by having
+// the DUT fetch it from the CacheForDUT service with curl.
+//
+// Errors from building the cache URL or from running curl on the DUT are
+// returned to the caller; on success the operation's result is set to a
+// successful CacheResponse.
 func (s *DutServiceServer) Cache(ctx context.Context, req *api.CacheRequest) (*longrunning.Operation, error) {
 	s.logger.Println("Received api.CacheRequest: ", *req)
 	op := s.manager.NewOperation()
 
+	// Runs on the DUT: -C - resumes a partial download and --retry retries
+	// transient failures. The trailing space yields a double space before -o,
+	// which the expected commands in the tests rely on.
+	command := "curl -S -s -v -# -C - --retry 3 --retry-delay 60 "
+
+	// Named cacheURL rather than url so the net/url package is not shadowed.
+	cacheURL, err := s.getCacheURL(req)
+	if err != nil {
+		return nil, err
+	}
+
+	// NOTE(review): DestinationPath and cacheURL are interpolated unquoted into
+	// a command run by the DUT's shell; they should be shell-quoted (which also
+	// requires updating the expected commands in the tests).
+	if _, err = s.runCmdOutput(fmt.Sprintf("%s -o %s %s", command, req.DestinationPath, cacheURL)); err != nil {
+		return nil, err
+	}
+
+	s.manager.SetResult(op.Name, &api.CacheResponse{
+		Result: &api.CacheResponse_Success_{},
+	})
+
 	return op, nil
 }
 
+// getCacheURL builds the CacheForDUT URL that serves the request's source.
+//
+// The service path is appended to cacheAddress by concatenation rather than
+// path.Join: path.Join cleans its result, which turns a scheme-qualified
+// address such as "http://host:8082" into "http:/host:8082" and rewrites
+// "." and ".." segments even inside the query string. The tar member name is
+// query-escaped so characters such as '&', '#' or spaces cannot corrupt the URL.
+func (s *DutServiceServer) getCacheURL(req *api.CacheRequest) (string, error) {
+	base := strings.TrimRight(s.cacheAddress, "/")
+	switch op := req.Source.(type) {
+	case *api.CacheRequest_GsFile:
+		parsedPath, err := parseGSURL(op.GsFile.SourcePath)
+		if err != nil {
+			return "", err
+		}
+		return base + fmt.Sprintf(CACHE_DOWNLOAD_URI, parsedPath), nil
+	case *api.CacheRequest_GsTarFile:
+		parsedPath, err := parseGSURL(op.GsTarFile.SourcePath)
+		if err != nil {
+			return "", err
+		}
+		file := url.QueryEscape(op.GsTarFile.SourceFile)
+		return base + fmt.Sprintf(CACHE_UNTAR_AND_DOWNLOAD_URI, parsedPath, file), nil
+	case *api.CacheRequest_GsZipFile:
+		parsedPath, err := parseGSURL(op.GsZipFile.SourcePath)
+		if err != nil {
+			return "", err
+		}
+		return base + fmt.Sprintf(CACHE_EXTRACT_AND_DOWNLOAD_URI, parsedPath), nil
+	default:
+		return "", fmt.Errorf("type can only be one of GsFile, GsTarFile or GSZipFile")
+	}
+}
+
+// parseGSURL converts a Google Storage URL of the form "gs://bucket/object"
+// into the "bucket/object" path used by the cache service. It returns an
+// error if the URL does not start with "gs://", cannot be parsed, or does
+// not name both a bucket and an object.
+func parseGSURL(gsURL string) (string, error) {
+	if !strings.HasPrefix(gsURL, "gs://") {
+		return "", fmt.Errorf("gs url must begin with 'gs://', instead have, %s", gsURL)
+	}
+
+	u, err := url.Parse(gsURL)
+	if err != nil {
+		return "", fmt.Errorf("unable to parse url, %w", err)
+	}
+
+	// Host corresponds to the bucket and Path to the object.
+	if u.Host == "" || strings.Trim(u.Path, "/") == "" {
+		return "", fmt.Errorf("gs url must name both a bucket and an object, instead have, %s", gsURL)
+	}
+	return path.Join(u.Host, u.Path), nil
+}
+
+// ForceReconnect re-establishes the SSH connection to the DUT and makes the
+// new connection the one used by subsequent requests.
 func (s *DutServiceServer) ForceReconnect(ctx context.Context, req *api.ForceReconnectRequest) (*longrunning.Operation, error) {
 	s.logger.Println("Received api.ForceReconnectRequest: ", *req)
+
 	op := s.manager.NewOperation()
 
+	conn, err := GetConnection(ctx, s.dutName, s.wiringAddress)
+	if err != nil {
+		return nil, err
+	}
+
+	// Swap in the new connection before reporting success, so the result is
+	// only published once the server is actually using the new connection.
+	old := s.connection
+	s.connection = &dutssh.SSHClient{Client: conn}
+	if old != nil {
+		// Best effort: a reconnect is typically requested after the old
+		// connection was lost, so an error from closing it is not actionable.
+		old.Close()
+	}
+
+	s.manager.SetResult(op.Name, &api.ForceReconnectResponse{})
+
 	return op, nil
 }
 
diff --git a/src/chromiumos/test/dut/cmd/cros-dut/dutserver_test.go b/src/chromiumos/test/dut/cmd/cros-dut/dutserver_test.go
index e26e6b7..5fff297 100644
--- a/src/chromiumos/test/dut/cmd/cros-dut/dutserver_test.go
+++ b/src/chromiumos/test/dut/cmd/cros-dut/dutserver_test.go
@@ -9,6 +9,7 @@
 	"chromiumos/test/dut/cmd/cros-dut/dutssh/mock_dutssh"
 	"context"
 	"errors"
+	"fmt"
 	"io"
 	"log"
 	"net"
@@ -16,6 +17,7 @@
 	"testing"
 
 	"github.com/golang/mock/gomock"
+	"go.chromium.org/chromiumos/config/go/longrunning"
 	"go.chromium.org/chromiumos/config/go/test/api"
 	"golang.org/x/crypto/ssh"
 	"google.golang.org/grpc"
@@ -53,7 +55,7 @@
 	}
 
 	ctx := context.Background()
-	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "", 0, "dutname", "wiringaddress")
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "", 0, "dutname", "wiringaddress", "cacheaddress")
 	defer destructor()
 	if err != nil {
 		t.Fatalf("Failed to start DutServiceServer: %v", err)
@@ -139,7 +141,7 @@
 	}
 
 	ctx := context.Background()
-	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "", 0, "dutname", "wiringaddress")
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "", 0, "dutname", "wiringaddress", "cacheaddress")
 	defer destructor()
 	if err != nil {
 		t.Fatalf("Failed to start DutServiceServer: %v", err)
@@ -216,7 +218,7 @@
 	}
 
 	ctx := context.Background()
-	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "", 0, "dutname", "wiringaddress")
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "", 0, "dutname", "wiringaddress", "cacheaddress")
 	defer destructor()
 	if err != nil {
 		t.Fatalf("Failed to start DutServiceServer: %v", err)
@@ -297,7 +299,7 @@
 	}
 
 	ctx := context.Background()
-	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "", 0, "dutname", "wiringaddress")
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "", 0, "dutname", "wiringaddress", "cacheaddress")
 	defer destructor()
 	if err != nil {
 		t.Fatalf("Failed to start DutServiceServer: %v", err)
@@ -364,7 +366,7 @@
 	}
 
 	ctx := context.Background()
-	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "", 0, "dutname", "wiringaddress")
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "", 0, "dutname", "wiringaddress", "cacheaddress")
 	defer destructor()
 	if err != nil {
 		t.Fatalf("Failed to start DutServiceServer: %v", err)
@@ -438,7 +440,7 @@
 	}
 
 	ctx := context.Background()
-	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress")
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress", "cacheaddress")
 	defer destructor()
 	if err != nil {
 		t.Fatalf("Failed to start DutServiceServer: %v", err)
@@ -487,7 +489,7 @@
 	}
 
 	ctx := context.Background()
-	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress")
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress", "cacheaddress")
 	defer destructor()
 	if err != nil {
 		t.Fatalf("Failed to start DutServiceServer: %v", err)
@@ -533,7 +535,7 @@
 	}
 
 	ctx := context.Background()
-	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress")
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress", "cacheaddress")
 	defer destructor()
 	if err != nil {
 		t.Fatalf("Failed to start DutServiceServer: %v", err)
@@ -587,7 +589,7 @@
 	}
 
 	ctx := context.Background()
-	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress")
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress", "cacheaddress")
 	defer destructor()
 	if err != nil {
 		t.Fatalf("Failed to start DutServiceServer: %v", err)
@@ -641,7 +643,7 @@
 	}
 
 	ctx := context.Background()
-	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress")
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress", "cacheaddress")
 	defer destructor()
 	if err != nil {
 		t.Fatalf("Failed to start DutServiceServer: %v", err)
@@ -693,7 +695,7 @@
 	}
 
 	ctx := context.Background()
-	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress")
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress", "cacheaddress")
 	defer destructor()
 	if err != nil {
 		t.Fatalf("Failed to start DutServiceServer: %v", err)
@@ -718,3 +720,310 @@
 		t.Fatalf("Failed at api.Restart: %v", err)
 	}
 }
+
+// TestCache tests that Cache runs the expected curl command for a GsFile source.
+func TestCache(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mci := mock_dutssh.NewMockClientInterface(ctrl)
+	msi := mock_dutssh.NewMockSessionInterface(ctrl)
+
+	gomock.InOrder(
+		mci.EXPECT().NewSession().Return(msi, nil),
+		msi.EXPECT().Output(gomock.Eq("curl -S -s -v -# -C - --retry 3 --retry-delay 60  -o /dest/path cacheaddress/download/source/path")).Return([]byte("curl output"), nil),
+		msi.EXPECT().Close(),
+		mci.EXPECT().Close(),
+	)
+
+	var logBuf bytes.Buffer
+	l, err := net.Listen("tcp", ":0")
+	if err != nil {
+		t.Fatal("Failed to create a net listener: ", err)
+	}
+
+	ctx := context.Background()
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress", "cacheaddress")
+	defer destructor()
+	if err != nil {
+		t.Fatalf("Failed to start DutServiceServer: %v", err)
+	}
+	go srv.Serve(l)
+	defer srv.Stop()
+
+	conn, err := grpc.Dial(l.Addr().String(), grpc.WithInsecure())
+	if err != nil {
+		t.Fatalf("Failed to dial: %v", err)
+	}
+	defer conn.Close()
+
+	cl := api.NewDutServiceClient(conn)
+	op, err := cl.Cache(ctx, &api.CacheRequest{
+		DestinationPath: "/dest/path",
+		Source: &api.CacheRequest_GsFile{
+			GsFile: &api.CacheRequest_GSFile{
+				SourcePath: "gs://source/path"},
+		},
+	})
+
+	if err != nil {
+		t.Fatalf("Failed at api.Cache: %v", err)
+	}
+
+	switch op.Result.(type) {
+	case *longrunning.Operation_Error:
+		t.Fatalf("api.Cache operation failed: %v", op.GetError())
+	}
+}
+
+// TestUntarCache tests that Cache runs the expected curl command for a GsTarFile source.
+func TestUntarCache(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mci := mock_dutssh.NewMockClientInterface(ctrl)
+	msi := mock_dutssh.NewMockSessionInterface(ctrl)
+
+	gomock.InOrder(
+		mci.EXPECT().NewSession().Return(msi, nil),
+		msi.EXPECT().Output(gomock.Eq("curl -S -s -v -# -C - --retry 3 --retry-delay 60  -o /dest/path cacheaddress/extract/source/path?file=somefile")).Return([]byte("curl output"), nil),
+		msi.EXPECT().Close(),
+		mci.EXPECT().Close(),
+	)
+
+	var logBuf bytes.Buffer
+	l, err := net.Listen("tcp", ":0")
+	if err != nil {
+		t.Fatal("Failed to create a net listener: ", err)
+	}
+
+	ctx := context.Background()
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress", "cacheaddress")
+	defer destructor()
+	if err != nil {
+		t.Fatalf("Failed to start DutServiceServer: %v", err)
+	}
+	go srv.Serve(l)
+	defer srv.Stop()
+
+	conn, err := grpc.Dial(l.Addr().String(), grpc.WithInsecure())
+	if err != nil {
+		t.Fatalf("Failed to dial: %v", err)
+	}
+	defer conn.Close()
+
+	cl := api.NewDutServiceClient(conn)
+	op, err := cl.Cache(ctx, &api.CacheRequest{
+		DestinationPath: "/dest/path",
+		Source: &api.CacheRequest_GsTarFile{
+			GsTarFile: &api.CacheRequest_GSTARFile{
+				SourcePath: "gs://source/path",
+				SourceFile: "somefile",
+			},
+		},
+	})
+
+	if err != nil {
+		t.Fatalf("Failed at api.Cache: %v", err)
+	}
+
+	switch op.Result.(type) {
+	case *longrunning.Operation_Error:
+		t.Fatalf("api.Cache operation failed: %v", op.GetError())
+	}
+}
+
+// TestUnzipCache tests that Cache runs the expected curl command for a GsZipFile source.
+func TestUnzipCache(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mci := mock_dutssh.NewMockClientInterface(ctrl)
+	msi := mock_dutssh.NewMockSessionInterface(ctrl)
+
+	gomock.InOrder(
+		mci.EXPECT().NewSession().Return(msi, nil),
+		msi.EXPECT().Output(gomock.Eq("curl -S -s -v -# -C - --retry 3 --retry-delay 60  -o /dest/path cacheaddress/decompress/source/path")).Return([]byte("curl output"), nil),
+		msi.EXPECT().Close(),
+		mci.EXPECT().Close(),
+	)
+
+	var logBuf bytes.Buffer
+	l, err := net.Listen("tcp", ":0")
+	if err != nil {
+		t.Fatal("Failed to create a net listener: ", err)
+	}
+
+	ctx := context.Background()
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress", "cacheaddress")
+	defer destructor()
+	if err != nil {
+		t.Fatalf("Failed to start DutServiceServer: %v", err)
+	}
+	go srv.Serve(l)
+	defer srv.Stop()
+
+	conn, err := grpc.Dial(l.Addr().String(), grpc.WithInsecure())
+	if err != nil {
+		t.Fatalf("Failed to dial: %v", err)
+	}
+	defer conn.Close()
+
+	cl := api.NewDutServiceClient(conn)
+	op, err := cl.Cache(ctx, &api.CacheRequest{
+		DestinationPath: "/dest/path",
+		Source: &api.CacheRequest_GsZipFile{
+			GsZipFile: &api.CacheRequest_GSZipFile{
+				SourcePath: "gs://source/path",
+			},
+		},
+	})
+
+	if err != nil {
+		t.Fatalf("Failed at api.Cache: %v", err)
+	}
+
+	switch op.Result.(type) {
+	case *longrunning.Operation_Error:
+		t.Fatalf("api.Cache operation failed: %v", op.GetError())
+	}
+}
+
+// TestCacheFailsWrongURL tests that Cache rejects a source path that is not a gs:// URL.
+func TestCacheFailsWrongURL(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mci := mock_dutssh.NewMockClientInterface(ctrl)
+
+	mci.EXPECT().Close()
+
+	var logBuf bytes.Buffer
+	l, err := net.Listen("tcp", ":0")
+	if err != nil {
+		t.Fatal("Failed to create a net listener: ", err)
+	}
+
+	ctx := context.Background()
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress", "cacheaddress")
+	defer destructor()
+	if err != nil {
+		t.Fatalf("Failed to start DutServiceServer: %v", err)
+	}
+	go srv.Serve(l)
+	defer srv.Stop()
+
+	conn, err := grpc.Dial(l.Addr().String(), grpc.WithInsecure())
+	if err != nil {
+		t.Fatalf("Failed to dial: %v", err)
+	}
+	defer conn.Close()
+
+	cl := api.NewDutServiceClient(conn)
+	_, err = cl.Cache(ctx, &api.CacheRequest{
+		DestinationPath: "/dest/path",
+		Source: &api.CacheRequest_GsZipFile{
+			GsZipFile: &api.CacheRequest_GSZipFile{
+				SourcePath: "source/path",
+			},
+		},
+	})
+
+	// The failure must come from the gs:// prefix check, not from some other step.
+	if err == nil || !strings.Contains(err.Error(), "gs://") {
+		t.Fatalf("Expected failure due to improper formatting, got: %v", err)
+	}
+}
+
+// TestCacheFailsCommandFails tests that Cache returns an error when the curl command on the DUT fails.
+func TestCacheFailsCommandFails(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mci := mock_dutssh.NewMockClientInterface(ctrl)
+	msi := mock_dutssh.NewMockSessionInterface(ctrl)
+
+	gomock.InOrder(
+		mci.EXPECT().NewSession().Return(msi, nil),
+		msi.EXPECT().Output(gomock.Eq("curl -S -s -v -# -C - --retry 3 --retry-delay 60  -o /dest/path cacheaddress/download/source/path")).Return([]byte(""), fmt.Errorf("couldn't download")),
+		msi.EXPECT().Close(),
+		mci.EXPECT().Close(),
+	)
+
+	var logBuf bytes.Buffer
+	l, err := net.Listen("tcp", ":0")
+	if err != nil {
+		t.Fatal("Failed to create a net listener: ", err)
+	}
+
+	ctx := context.Background()
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress", "cacheaddress")
+	defer destructor()
+	if err != nil {
+		t.Fatalf("Failed to start DutServiceServer: %v", err)
+	}
+	go srv.Serve(l)
+	defer srv.Stop()
+
+	conn, err := grpc.Dial(l.Addr().String(), grpc.WithInsecure())
+	if err != nil {
+		t.Fatalf("Failed to dial: %v", err)
+	}
+	defer conn.Close()
+
+	cl := api.NewDutServiceClient(conn)
+	_, err = cl.Cache(ctx, &api.CacheRequest{
+		DestinationPath: "/dest/path",
+		Source: &api.CacheRequest_GsFile{
+			GsFile: &api.CacheRequest_GSFile{
+				SourcePath: "gs://source/path",
+			},
+		},
+	})
+
+	if err == nil || !strings.Contains(err.Error(), "couldn't download") {
+		t.Fatalf("Expected the curl failure to be returned, got: %v", err)
+	}
+
+}
+
+// TestForceReconnect tests that a ForceReconnect request reaches the reconnect step.
+func TestForceReconnect(t *testing.T) {
+	ctrl := gomock.NewController(t)
+	defer ctrl.Finish()
+
+	mci := mock_dutssh.NewMockClientInterface(ctrl)
+
+	mci.EXPECT().Close()
+
+	var logBuf bytes.Buffer
+	l, err := net.Listen("tcp", ":0")
+	if err != nil {
+		t.Fatal("Failed to create a net listener: ", err)
+	}
+
+	ctx := context.Background()
+	srv, destructor := newDutServiceServer(l, log.New(&logBuf, "", log.LstdFlags|log.LUTC), mci, "serializer_path", 0, "dutname", "wiringaddress", "cacheaddress")
+	defer destructor()
+	if err != nil {
+		t.Fatalf("Failed to start DutServiceServer: %v", err)
+	}
+	go srv.Serve(l)
+	defer srv.Stop()
+
+	conn, err := grpc.Dial(l.Addr().String(), grpc.WithInsecure())
+	if err != nil {
+		t.Fatalf("Failed to dial: %v", err)
+	}
+	defer conn.Close()
+
+	cl := api.NewDutServiceClient(conn)
+	_, err = cl.ForceReconnect(ctx, &api.ForceReconnectRequest{})
+
+	// The fake wiring address is not expected to be reachable, so the reconnect
+	// itself should fail; getting that connection error shows the request got
+	// as far as the reconnect step without having to mock it.
+	if err == nil || !strings.Contains(err.Error(), "connection error") {
+		t.Fatalf("Failed at api.ForceReconnect: %v", err)
+	}
+}
diff --git a/src/chromiumos/test/dut/cmd/cros-dut/main.go b/src/chromiumos/test/dut/cmd/cros-dut/main.go
index 1fd89e5..989c3d4 100644
--- a/src/chromiumos/test/dut/cmd/cros-dut/main.go
+++ b/src/chromiumos/test/dut/cmd/cros-dut/main.go
@@ -54,6 +54,7 @@
 		wiringAddress := flag.String("wiring_address", "", "Address to TLW. Only required if using DUT name.")
 		protoChunkSize := flag.Int64("chunk_size", 1024*1024, "Largest size of blob or coredumps to include in an individual response.")
 		serializerPath := flag.String("serializer_path", "/usr/local/sbin/crash_serializer", "Location of the serializer binary on disk in the DUT.")
+		cacheAddress := flag.String("cache_address", "", "CacheForDUT service address.")
 		port := flag.Int("port", 0, "the port used to start service. default not specified")
 		flag.Parse()
 
@@ -78,6 +79,11 @@
 			fmt.Println("A Wiring address should not be specified if DUT address is used.")
 		}
 
+		if *cacheAddress == "" {
+			fmt.Println("Caching address must be specified.")
+			return 2
+		}
+
 		if *port == 0 {
 			fmt.Println("Please specify the port.")
 			return 2
@@ -109,7 +115,7 @@
 			return 2
 		}
 
-		server, destructor := newDutServiceServer(l, logger, &dutssh.SSHClient{Client: conn}, *serializerPath, *protoChunkSize, *dutName, *wiringAddress)
+		server, destructor := newDutServiceServer(l, logger, &dutssh.SSHClient{Client: conn}, *serializerPath, *protoChunkSize, *dutName, *wiringAddress, *cacheAddress)
 		defer destructor()
 
 		err = server.Serve(l)