cos_gpu_config_builder: fix milestone extraction from kernel version

BUG=b/238798451

Change-Id: I58f70317ab64791ed62eecedcc3026a4e039b3d0
Reviewed-on: https://cos-review.googlesource.com/c/cos/tools/+/37269
Reviewed-by: He Gao <hegao@google.com>
Tested-by: Arnav Kansal <rnv@google.com>
Cloud-Build: GCB Service account <228075978874@cloudbuild.gserviceaccount.com>
diff --git a/src/pkg/gpuconfig/generate_configs.go b/src/pkg/gpuconfig/generate_configs.go
index d7fd5e0..1902eb2 100644
--- a/src/pkg/gpuconfig/generate_configs.go
+++ b/src/pkg/gpuconfig/generate_configs.go
@@ -36,6 +36,17 @@
 	VersionType   string                 `json:"version_type"`
 }
 
+func kernelVersionToMilestone(kernelVersion string) string {
+	milestone := ""
+	for _, sep := range []string{"m", "r"} { // release branch or main branch check
+		if split := strings.Split(kernelVersion, sep); len(split) == 2 {
+			milestone = split[1]
+			break
+		}
+	}
+	return milestone
+}
+
 // Generates and GPU precompilation build configs(and metadata) for a given
 // tuple of kernelVersion and driver versions
 func GenerateKernelCIConfigs(ctx context.Context, client *storage.Client, kernelVersion string, driverVersions []string) ([]GPUPrecompilationConfig, error) {
@@ -45,7 +56,7 @@
 		if err != nil {
 			return nil, err
 		}
-		milestone := strings.Split(kernelVersion, "m")[1]
+		milestone := kernelVersionToMilestone(kernelVersion)
 		configs = append(configs, GPUPrecompilationConfig{config, driverVersion, milestone, kernelVersion, "Kernel"})
 	}
 	return configs, nil
diff --git a/src/pkg/gpuconfig/generate_configs_test.go b/src/pkg/gpuconfig/generate_configs_test.go
index c9c8243..0ac1ffa 100644
--- a/src/pkg/gpuconfig/generate_configs_test.go
+++ b/src/pkg/gpuconfig/generate_configs_test.go
@@ -58,6 +58,21 @@
 	}
 }
 
+func TestKernelVersionToMilestone(t *testing.T) {
+	for _, tc := range []struct {
+		kernelVersion     string
+		milestoneExpected string
+	}{
+		{"5.10.100-14.m97", "97"},
+		{"5.10.107-10.r97", "97"},
+		{"5.10.100-14", ""},
+	} {
+		if got := kernelVersionToMilestone(tc.kernelVersion); got != tc.milestoneExpected {
+			t.Errorf("kernelVersionToMilestone() = %+v, want %+v", got, tc.milestoneExpected)
+		}
+	}
+}
+
 func TestFetchToolchainTarballPath(t *testing.T) {
 	gcs := fakes.GCSForTest(t)
 	defer gcs.Close()