compiler_wrapper: refactor goma flag parsing

This is in preparation for adding more flag parsing of this type.

BUG=b:190741226
TEST=go test

Change-Id: Ief431a6e30b6ba22767cdd46247e362508addd4b
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/third_party/toolchain-utils/+/2956694
Reviewed-by: Ryan Beltran <ryanbeltran@chromium.org>
Tested-by: George Burgess <gbiv@chromium.org>
diff --git a/compiler_wrapper/gomacc_flag.go b/compiler_wrapper/gomacc_flag.go
index ac298b1..56522d4 100644
--- a/compiler_wrapper/gomacc_flag.go
+++ b/compiler_wrapper/gomacc_flag.go
@@ -5,32 +5,104 @@
 package main
 
 import (
+	"errors"
+	"fmt"
 	"os"
+	"strings"
 )
 
-func processGomaCccFlags(builder *commandBuilder) (gomaUsed bool, err error) {
-	gomaPath := ""
-	nextArgIsGomaPath := false
+var errNoSuchCmdlineArg = errors.New("no such commandline argument")
+
+// Removes one flag from `builder`, assuming that a value follows the flag. Two formats are
+// supported for this: `--foo=bar` and `--foo bar`. In either case, "bar" will be returned as the
+// `value`.
+//
+// If no flag is found on the commandline, this returns the `errNoSuchCmdlineArg` error. `builder`
+// is unmodified if this error is returned, but its contents are unspecified if any other error is
+// returned.
+//
+// In the case of multiple such flags, only the first encountered will be removed.
+func removeOneUserCmdlineFlagWithValue(builder *commandBuilder, flagName string) (flagValue string, err error) {
+	const (
+		searchingForFlag uint8 = iota
+		searchingForValue
+		searchComplete
+	)
+
+	flagRequiresAValue := func() error { return newUserErrorf("flag %q requires a value", flagName) }
+	searchState := searchingForFlag
 	builder.transformArgs(func(arg builderArg) string {
-		if arg.fromUser {
-			if arg.value == "--gomacc-path" {
-				nextArgIsGomaPath = true
-				return ""
-			}
-			if nextArgIsGomaPath {
-				gomaPath = arg.value
-				nextArgIsGomaPath = false
-				return ""
-			}
+		if err != nil {
+			return arg.value
 		}
-		return arg.value
+
+		switch searchState {
+		case searchingForFlag:
+			if !arg.fromUser {
+				return arg.value
+			}
+
+			if arg.value == flagName {
+				searchState = searchingForValue
+				return ""
+			}
+
+			isArgEq := strings.HasPrefix(arg.value, flagName) && arg.value[len(flagName)] == '='
+			if !isArgEq {
+				return arg.value
+			}
+
+			flagValue = arg.value[len(flagName)+1:]
+			searchState = searchComplete
+			return ""
+
+		case searchingForValue:
+			if !arg.fromUser {
+				err = flagRequiresAValue()
+				return arg.value
+			}
+
+			flagValue = arg.value
+			searchState = searchComplete
+			return ""
+
+		case searchComplete:
+			return arg.value
+
+		default:
+			panic(fmt.Sprintf("unknown search state: %v", searchState))
+		}
 	})
-	if nextArgIsGomaPath {
-		return false, newUserErrorf("--gomacc-path given without value")
+
+	if err != nil {
+		return "", err
 	}
-	if gomaPath == "" {
+
+	switch searchState {
+	case searchingForFlag:
+		return "", errNoSuchCmdlineArg
+
+	case searchingForValue:
+		return "", flagRequiresAValue()
+
+	case searchComplete:
+		return flagValue, nil
+
+	default:
+		panic(fmt.Sprintf("unknown search state: %v", searchState))
+	}
+}
+
+func processGomaCccFlags(builder *commandBuilder) (gomaUsed bool, err error) {
+	gomaPath, err := removeOneUserCmdlineFlagWithValue(builder, "--gomacc-path")
+	if err != nil && err != errNoSuchCmdlineArg {
+		return false, err
+	}
+
+	if err == errNoSuchCmdlineArg || gomaPath == "" {
 		gomaPath, _ = builder.env.getenv("GOMACC_PATH")
 	}
+
 	if gomaPath != "" {
 		if _, err := os.Lstat(gomaPath); err == nil {
 			builder.wrapPath(gomaPath)
diff --git a/compiler_wrapper/gomacc_flag_test.go b/compiler_wrapper/gomacc_flag_test.go
index e1dc33e..a436227 100644
--- a/compiler_wrapper/gomacc_flag_test.go
+++ b/compiler_wrapper/gomacc_flag_test.go
@@ -7,9 +7,101 @@
 import (
 	"os"
 	"path"
+	"reflect"
 	"testing"
 )
 
+func TestCommandlineFlagParsing(t *testing.T) {
+	withTestContext(t, func(ctx *testContext) {
+		type testCase struct {
+			extraFlags []string
+			// If this is nonempty, expectedValue is ignored. Otherwise, expectedValue
+			// has the expected value for the flag, and expectedCommand has the expected
+			// (extra) flags in the builder after filtering.
+			expectedError      string
+			expectedValue      string
+			expectedExtraFlags []string
+		}
+
+		const flagName = "--flag"
+		testCases := []testCase{
+			{
+				extraFlags:    nil,
+				expectedError: errNoSuchCmdlineArg.Error(),
+			},
+			{
+				extraFlags:    []string{flagName + "a"},
+				expectedError: errNoSuchCmdlineArg.Error(),
+			},
+			{
+				extraFlags:    []string{flagName},
+				expectedError: "flag \"" + flagName + "\" requires a value",
+			},
+			{
+				extraFlags:         []string{flagName, "foo"},
+				expectedValue:      "foo",
+				expectedExtraFlags: nil,
+			},
+			{
+				extraFlags:         []string{flagName + "=foo"},
+				expectedValue:      "foo",
+				expectedExtraFlags: nil,
+			},
+			{
+				extraFlags:         []string{flagName + "="},
+				expectedValue:      "",
+				expectedExtraFlags: nil,
+			},
+			{
+				extraFlags:         []string{flagName + "=foo", flagName + "=bar"},
+				expectedValue:      "foo",
+				expectedExtraFlags: []string{flagName + "=bar"},
+			},
+		}
+
+		for _, testCase := range testCases {
+			cmd := ctx.newCommand(gccX86_64, testCase.extraFlags...)
+			builder, err := newCommandBuilder(ctx, ctx.cfg, cmd)
+			if err != nil {
+				t.Fatalf("Failed creating a command builder: %v", err)
+			}
+
+			flagValue, err := removeOneUserCmdlineFlagWithValue(builder, flagName)
+			if err != nil {
+				if testCase.expectedError == "" {
+					t.Errorf("given extra flags %q, got unexpected error removing %q: %v", testCase.extraFlags, flagName, err)
+					continue
+				}
+
+				if e := err.Error(); e != testCase.expectedError {
+					t.Errorf("given extra flags %q, got error %q; wanted %q", testCase.extraFlags, e, testCase.expectedError)
+				}
+				continue
+			}
+
+			if testCase.expectedError != "" {
+				t.Errorf("given extra flags %q, got no error, but expected %q", testCase.extraFlags, testCase.expectedError)
+				continue
+			}
+
+			if flagValue != testCase.expectedValue {
+				t.Errorf("given extra flags %q, got value %q, but expected %q", testCase.extraFlags, flagValue, testCase.expectedValue)
+			}
+
+			currentFlags := []string{}
+			// Chop off the first arg, which should just be the compiler
+			for _, a := range builder.args {
+				currentFlags = append(currentFlags, a.value)
+			}
+
+			sameFlags := (len(currentFlags) == 0 && len(testCase.expectedExtraFlags) == 0) || reflect.DeepEqual(currentFlags, testCase.expectedExtraFlags)
+			if !sameFlags {
+				t.Errorf("given extra flags %q, got post-removal flags %q, but expected %q", testCase.extraFlags, currentFlags, testCase.expectedExtraFlags)
+			}
+		}
+	})
+}
+
 func TestCallGomaccIfEnvIsGivenAndValid(t *testing.T) {
 	withGomaccTestContext(t, func(ctx *testContext, gomaPath string) {
 		ctx.env = []string{"GOMACC_PATH=" + gomaPath}
@@ -76,7 +168,7 @@
 	withTestContext(t, func(ctx *testContext) {
 		stderr := ctx.mustFail(callCompiler(ctx, ctx.cfg,
 			ctx.newCommand(gccX86_64, mainCc, "--gomacc-path")))
-		if err := verifyNonInternalError(stderr, "--gomacc-path given without value"); err != nil {
+		if err := verifyNonInternalError(stderr, "flag \"--gomacc-path\" requires a value"); err != nil {
 			t.Error(err)
 		}
 	})