cos_gpu_config_builder: fix milestone extraction from kernel version
BUG=b/238798451
Change-Id: I58f70317ab64791ed62eecedcc3026a4e039b3d0
Reviewed-on: https://cos-review.googlesource.com/c/cos/tools/+/37269
Reviewed-by: He Gao <hegao@google.com>
Tested-by: Arnav Kansal <rnv@google.com>
Cloud-Build: GCB Service account <228075978874@cloudbuild.gserviceaccount.com>
diff --git a/src/pkg/gpuconfig/generate_configs.go b/src/pkg/gpuconfig/generate_configs.go
index d7fd5e0..1902eb2 100644
--- a/src/pkg/gpuconfig/generate_configs.go
+++ b/src/pkg/gpuconfig/generate_configs.go
@@ -36,6 +36,17 @@
VersionType string `json:"version_type"`
}
+func kernelVersionToMilestone(kernelVersion string) string {
+ milestone := ""
+ for _, sep := range []string{"m", "r"} { // release branch or main branch check
+ if split := strings.Split(kernelVersion, sep); len(split) == 2 {
+ milestone = split[1]
+ break
+ }
+ }
+ return milestone
+}
+
// Generates and GPU precompilation build configs(and metadata) for a given
// tuple of kernelVersion and driver versions
func GenerateKernelCIConfigs(ctx context.Context, client *storage.Client, kernelVersion string, driverVersions []string) ([]GPUPrecompilationConfig, error) {
@@ -45,7 +56,7 @@
if err != nil {
return nil, err
}
- milestone := strings.Split(kernelVersion, "m")[1]
+ milestone := kernelVersionToMilestone(kernelVersion)
configs = append(configs, GPUPrecompilationConfig{config, driverVersion, milestone, kernelVersion, "Kernel"})
}
return configs, nil
diff --git a/src/pkg/gpuconfig/generate_configs_test.go b/src/pkg/gpuconfig/generate_configs_test.go
index c9c8243..0ac1ffa 100644
--- a/src/pkg/gpuconfig/generate_configs_test.go
+++ b/src/pkg/gpuconfig/generate_configs_test.go
@@ -58,6 +58,21 @@
}
}
+func TestKernelVersionToMilestone(t *testing.T) {
+ for _, tc := range []struct {
+ kernelVersion string
+ milestoneExpected string
+ }{
+ {"5.10.100-14.m97", "97"},
+ {"5.10.107-10.r97", "97"},
+ {"5.10.100-14", ""},
+ } {
+ if got := kernelVersionToMilestone(tc.kernelVersion); got != tc.milestoneExpected {
+ t.Errorf("kernelVersionToMilestone() = %+v, want %+v", got, tc.milestoneExpected)
+ }
+ }
+}
+
func TestFetchToolchainTarballPath(t *testing.T) {
gcs := fakes.GCSForTest(t)
defer gcs.Close()