package commands

import (
	"context"
	"flag"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"

	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/deviceinfo"
	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/features"
	"cos.googlesource.com/cos/tools.git/src/cmd/cos_gpu_installer/internal/installer"
	"cos.googlesource.com/cos/tools.git/src/pkg/cos"
	"cos.googlesource.com/cos/tools.git/src/pkg/gpuconfig"
	log "github.com/golang/glog"
	"github.com/google/subcommands"
	"github.com/pkg/errors"
)

// ListCommand is the subcommand to list supported GPU drivers.
//
// Its fields are populated from the command-line flags registered in SetFlags.
type ListCommand struct {
	// GCS bucket and path prefix overrides for downloading COS artifacts;
	// empty values mean the defaults are used.
	gcsDownloadBucket, gcsDownloadPrefix string
	// Directory, relative to the host root, where the GPU driver versions proto
	// is cached; empty means a temporary directory is used instead.
	gpuProtoCacheDir string
	// GPU type requested via --target-gpu; empty means auto-detect.
	targetGPU string
	// Enables debug mode (more detailed error logging).
	debug bool
}

// Name implements subcommands.Command.Name; the subcommand is invoked as
// "list".
func (*ListCommand) Name() string {
	const name = "list"
	return name
}

// Synopsis implements subcommands.Command.Synopsis by returning a one-line
// summary of what the list subcommand does.
func (*ListCommand) Synopsis() string {
	summary := "List supported GPU drivers for this version."
	return summary
}

// Usage implements subcommands.Command.Usage. The usage line is just the
// subcommand name; Execute does not read any positional arguments.
func (*ListCommand) Usage() string {
	const usage = "list\n"
	return usage
}

// SetFlags implements subcommands.Command.SetFlags.
//
// It registers the flags of the list subcommand:
//   - --gcs-download-bucket and --gcs-download-prefix override where COS
//     artifacts are downloaded from.
//   - --target-gpu selects the GPU type whose compatible drivers are listed
//     instead of auto-detecting it.
//   - --gpu-proto-cache-dir sets where the GPU driver versions proto is cached.
//   - --debug turns on debug mode (more detailed error logging).
func (c *ListCommand) SetFlags(f *flag.FlagSet) {
	f.StringVar(&c.gcsDownloadBucket, "gcs-download-bucket", "",
		"The GCS bucket to download COS artifacts from. "+
			"The default bucket is one of 'cos-tools', 'cos-tools-asia' and 'cos-tools-eu' based on where the VM is running. "+
			"Those are the public COS artifacts buckets.")
	// The trailing space keeps the two sentences from running together
	// ("artifacts.If not set") in the rendered help text.
	f.StringVar(&c.gcsDownloadPrefix, "gcs-download-prefix", "",
		"The GCS path prefix when downloading COS artifacts. "+
			"If not set then the COS build number and board (e.g. 13310.1041.38/lakitu) will be used.")
	f.StringVar(&c.targetGPU, "target-gpu", "", fmt.Sprintf("This flag specifies the GPU device to display its compatible drivers. "+
		"If specified, it must be one of %s. If not specified, the GPU device will be auto-detected by the installer.", strings.Join(deviceinfo.AllGPUTypeStrings, ", ")))
	f.StringVar(&c.gpuProtoCacheDir, "gpu-proto-cache-dir", "",
		"The GPU proto cache directory that GPU driver versions proto file is stored into. If unspecified, the GPU driver versions proto file will not be cached.")
	f.BoolVar(&c.debug, "debug", false,
		"Enable debug mode.")
}

// Execute implements subcommands.Command.Execute.
//
// It reads the COS build environment under hostRootPath, initializes feature
// flags from the COS GPU config JSON and prints the GPU driver versions that
// are available for this COS build:
//   - With PerGPULabelFeatureFlag enabled, driver versions and their labels are
//     read from the GPU driver versions proto and printed per GPU type: the
//     --target-gpu type if set, else the auto-detected type, else (when no GPU
//     is detected) every GPU type known to deviceinfo.
//   - Otherwise, driver versions are derived from the GPU extension artifact
//     names and annotated with major-version labels when those can be fetched.
func (c *ListCommand) Execute(ctx context.Context, _ *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
	envReader, err := cos.NewEnvReader(hostRootPath)
	if err != nil {
		c.logError(errors.Wrap(err, "failed to create envReader"))
		return subcommands.ExitFailure
	}
	log.Infof("Running on COS build id %s", envReader.BuildNumber())
	fullCOSGPUConfigJsonPath := filepath.Join(hostRootPath, cosGPUConfigJsonPath)
	log.Infof("Feature flags initialization")
	featureConfig, err := features.InitFlags(fullCOSGPUConfigJsonPath)
	if err != nil {
		log.Errorf("Unable to init feature flags: %v", err)
		return subcommands.ExitFailure
	}
	downloader := cos.NewGCSDownloader(nil, envReader, c.gcsDownloadBucket, c.gcsDownloadPrefix, "", "")

	if !featureConfig.PerGPULabelFeatureFlag {
		artifacts, err := downloader.ListGPUExtensionArtifacts()
		if err != nil {
			c.logError(errors.Wrap(err, "failed to list gpu extension artifacts"))
			return subcommands.ExitFailure
		}
		// Labels are best-effort: if they cannot be fetched the versions are
		// still listed, just without labels (lookups in a nil map find nothing).
		gpuDriverToLabelMap, err := getGPUDriverToMajorLabelMap(ctx, downloader)
		if err != nil {
			c.logWarning(fmt.Errorf("error appears when fetching the GPU driver versions to GPU driver major labels map: %w", err))
		}
		for _, artifact := range artifacts {
			var driverVersion string
			switch {
			case strings.HasSuffix(artifact, ".signature.tar.gz"):
				driverVersion = strings.TrimSuffix(artifact, ".signature.tar.gz")
			case strings.HasPrefix(artifact, "nvidia-drivers-") && strings.HasSuffix(artifact, "-signature.tar.gz"):
				driverVersion = strings.TrimSuffix(strings.TrimPrefix(artifact, "nvidia-drivers-"), "-signature.tar.gz")
			}
			if driverVersion == "" {
				continue
			}
			promptMsg := driverVersion
			if gpuDriverLabel, ok := gpuDriverToLabelMap[driverVersion]; ok {
				promptMsg += " " + gpuDriverLabel
			}
			fmt.Printf("%s\n", promptMsg)
		}
		return subcommands.ExitSuccess
	}

	gpuProtoCacheDir := filepath.Join(hostRootPath, c.gpuProtoCacheDir)
	if c.gpuProtoCacheDir == "" {
		tmpDir, err := os.MkdirTemp(scratchDir, "gpuProtoCacheDir")
		if err != nil {
			log.Errorf("Failed to create tmp gpu proto cache dir in: %s: %v", scratchDir, err)
			return subcommands.ExitFailure
		}
		log.Infof("Created the tmp gpu proto cache directory: %s", tmpDir)
		// Without --gpu-proto-cache-dir the proto is not meant to be cached, so
		// remove the temporary directory rather than leaking one per invocation.
		defer func() {
			if err := os.RemoveAll(tmpDir); err != nil {
				log.Warningf("Failed to remove the tmp gpu proto cache directory %s: %v", tmpDir, err)
			}
		}()
		gpuProtoCacheDir = tmpDir
	}

	var gpuType deviceinfo.GPUType = deviceinfo.NO_GPU
	if c.targetGPU != "" {
		gpuType, err = deviceinfo.ParseGPUType(c.targetGPU)
		if err != nil {
			c.logError(fmt.Errorf("failed to parse target GPU type: %w", err))
			return subcommands.ExitFailure
		}
	} else if detectedGPUType, detectErr := deviceinfo.GetGPUTypeInfo(); detectErr != nil {
		// The detected value is meaningless on error; keep NO_GPU so that
		// downloadGPUDriverLabelMapping lists the drivers for every GPU type.
		log.Infof("No GPU device detected! (%v)", detectErr)
	} else {
		gpuType = detectedGPUType
	}
	log.Infof("List GPU driver for device type: %s", gpuType)

	gpuVersionLabelMappings, err := downloadGPUDriverLabelMapping(ctx, downloader, gpuProtoCacheDir, gpuType)
	if err != nil {
		c.logError(fmt.Errorf("failed to fetch the available GPU driver versions and GPU driver labels for GPU type: %s with error: %w", gpuType.String(), err))
		return subcommands.ExitFailure
	}
	c.printGPUDriverLabelMappings(gpuVersionLabelMappings)
	return subcommands.ExitSuccess
}

// printGPUDriverLabelMappings prints each GPU type followed by its driver
// versions and their labels (if any). Go map iteration order is randomized,
// so GPU types are sorted by name and driver versions in ascending version
// order to keep the output stable across runs.
func (*ListCommand) printGPUDriverLabelMappings(mappings map[deviceinfo.GPUType]map[string]string) {
	// isNumeric reports whether s is a non-empty run of ASCII digits.
	isNumeric := func(s string) bool {
		return s != "" && strings.Trim(s, "0123456789") == ""
	}
	// lessVersion compares dotted versions component by component. Purely
	// numeric components compare by value (so "535.54.03" < "535.129.03");
	// anything else falls back to a plain string comparison.
	lessVersion := func(a, b string) bool {
		ap, bp := strings.Split(a, "."), strings.Split(b, ".")
		for i := 0; i < len(ap) && i < len(bp); i++ {
			if ap[i] == bp[i] {
				continue
			}
			if isNumeric(ap[i]) && isNumeric(bp[i]) {
				x, y := strings.TrimLeft(ap[i], "0"), strings.TrimLeft(bp[i], "0")
				if len(x) != len(y) {
					return len(x) < len(y)
				}
				if x != y {
					return x < y
				}
			}
			return ap[i] < bp[i]
		}
		return len(ap) < len(bp)
	}

	gpuTypes := make([]deviceinfo.GPUType, 0, len(mappings))
	for gpuType := range mappings {
		gpuTypes = append(gpuTypes, gpuType)
	}
	sort.Slice(gpuTypes, func(i, j int) bool { return gpuTypes[i].String() < gpuTypes[j].String() })

	for _, gpuType := range gpuTypes {
		gpuDriverMap := mappings[gpuType]
		driverVersions := make([]string, 0, len(gpuDriverMap))
		for driverVersion := range gpuDriverMap {
			driverVersions = append(driverVersions, driverVersion)
		}
		sort.Slice(driverVersions, func(i, j int) bool { return lessVersion(driverVersions[i], driverVersions[j]) })

		fmt.Printf("\nGPU Device: %s\n", gpuType.String())
		for _, driverVersion := range driverVersions {
			promptMsg := driverVersion
			if gpuDriverLabel := gpuDriverMap[driverVersion]; gpuDriverLabel != "" {
				promptMsg += " " + gpuDriverLabel
			}
			fmt.Printf("%s\n", promptMsg)
		}
	}
}

// downloadGPUDriverLabelMapping retrieves the mapping from GPUType to the
// corresponding driver versions and their labels.
//
// It downloads the GPU driver versions proto into gpuInstallDir (via
// installer.DownloadGPUDriverVersionsProto), parses it, and builds a two-level
// map: the first-level keys are GPU types, the second-level keys are driver
// versions, and the values are those versions' concatenated labels ("" when a
// version has no label).
// Example of the expected mapping:
//
//	V100 -> {
//		535.129.03 -> [R535][latest][default],
//		470.223.02 -> [R470],
//		525.147.05 -> [R525],
//	}
//
// Note:
//   - For NO_GPU, mappings are returned for every GPU type in
//     deviceinfo.AvailableGPUTypesList.
//   - For Others, a notice that the GPU device is not recognized is printed to
//     stdout, and the drivers listed under the Others category in the proto are
//     returned.
//
// Errors from downloading the proto are wrapped with %w so callers can inspect
// them with errors.Is/As.
func downloadGPUDriverLabelMapping(ctx context.Context, downloader cos.ArtifactsDownloader, gpuInstallDir string, gpuType deviceinfo.GPUType) (map[deviceinfo.GPUType]map[string]string, error) {
	gpuProtoContent, err := installer.DownloadGPUDriverVersionsProto(ctx, downloader, gpuInstallDir)
	if err != nil {
		return nil, fmt.Errorf("failed to download and read GPU driver versions proto with error: %w", err)
	}
	gpuTypeToDriverLabelMap := map[deviceinfo.GPUType]map[string]string{}
	if gpuType == deviceinfo.NO_GPU {
		log.Infof("GPU driver versions for all available GPU types supported in COS will be listed.")
		for _, availableGpuType := range deviceinfo.AvailableGPUTypesList {
			driverLabelMap, err := buildGPUDriverLabelMap(gpuProtoContent, availableGpuType)
			if err != nil {
				return nil, err
			}
			gpuTypeToDriverLabelMap[availableGpuType] = driverLabelMap
		}
		return gpuTypeToDriverLabelMap, nil
	}
	if gpuType == deviceinfo.Others {
		fmt.Print("\nNote: This GPU device is not recognized by the gpu installer, the listed GPU drivers below may not be compatible with the current GPU device.\n" +
			"Please go to https://cloud.google.com/compute/docs/gpus/install-drivers-gpu for more details\n")
	}
	driverLabelMap, err := buildGPUDriverLabelMap(gpuProtoContent, gpuType)
	if err != nil {
		return nil, err
	}
	gpuTypeToDriverLabelMap[gpuType] = driverLabelMap
	return gpuTypeToDriverLabelMap, nil
}
// buildGPUDriverLabelMap returns, for gpuType, a map from every driver version
// found in gpuProtoContent to its labels concatenated as "[label1][label2]...".
// Versions without any label are still present and map to "". Labels are
// appended in the order gpuconfig.GetGPUDriverVersions returns them. Errors from
// gpuconfig are wrapped with %w.
func buildGPUDriverLabelMap(gpuProtoContent []byte, gpuType deviceinfo.GPUType) (map[string]string, error) {
	driverVersions, err := gpuconfig.GetGPUDriverVersions(gpuProtoContent, gpuType.String())
	if err != nil {
		return nil, fmt.Errorf("failed to get the supported GPU driver versions for GPU type: %s with error: %w", gpuType.String(), err)
	}
	gpuDriverMap := make(map[string]string, len(driverVersions))
	for _, driverVersion := range driverVersions {
		labels := gpuDriverMap[driverVersion.Version]
		if driverVersion.Label != "" {
			labels += "[" + driverVersion.Label + "]"
		}
		// Always (re)store the entry so unlabeled versions are listed too.
		gpuDriverMap[driverVersion.Version] = labels
	}
	return gpuDriverMap, nil
}

// getGPUDriverToMajorLabelMap maps each GPU driver version to the concatenated
// labels of the major-version artifacts that point at it, e.g.
//
//	535.129.03 -> [R535][default][latest]
//	470.223.02 -> [R470]
//	525.147.05 -> [R525]
//
// A label is the artifact name with installer.MajorGPUDriverArtifactPrefix and
// installer.MajorGPUDriverArtifactSuffix trimmed. When several artifacts point
// at the same version, their labels are appended in sorted artifact-name order
// so the output is the same on every run.
func getGPUDriverToMajorLabelMap(ctx context.Context, downloader cos.ArtifactsDownloader) (map[string]string, error) {
	artifactToVersion, err := installer.DownloadGPUDriverVersionArtifacts(ctx, downloader)
	if err != nil {
		return nil, fmt.Errorf("error while fetching available gpu driver version files - %w", err)
	}
	// Go map iteration order is random; iterate over sorted artifact names so a
	// version with several labels always renders them in the same order.
	artifacts := make([]string, 0, len(artifactToVersion))
	for artifact := range artifactToVersion {
		artifacts = append(artifacts, artifact)
	}
	sort.Strings(artifacts)

	gpuDriverMap := make(map[string]string, len(artifactToVersion))
	for _, artifact := range artifacts {
		label := strings.TrimPrefix(artifact, installer.MajorGPUDriverArtifactPrefix)
		label = strings.TrimSuffix(label, installer.MajorGPUDriverArtifactSuffix)
		gpuDriverMap[artifactToVersion[artifact]] += "[" + label + "]"
	}
	return gpuDriverMap, nil
}
// logError logs err at error severity. In debug mode the error is formatted
// with %+v (which, for errors built with github.com/pkg/errors, includes the
// stack trace); otherwise only its message (%v) is logged.
func (c *ListCommand) logError(err error) {
	format := "%v"
	if c.debug {
		format = "%+v"
	}
	log.Errorf(format, err)
}
// logWarning logs err at warning severity, using %+v in debug mode (detailed
// output, e.g. stack traces for github.com/pkg/errors values) and %v otherwise.
func (c *ListCommand) logWarning(err error) {
	format := "%v"
	if c.debug {
		format = "%+v"
	}
	log.Warningf(format, err)
}
