| /* |
| * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package e2e |
| |
| import ( |
| "bytes" |
| "fmt" |
| "os" |
| "os/exec" |
| "time" |
| |
| "golang.org/x/crypto/ssh" |
| ) |
| |
type (
	// localRunner is the Runner implementation that executes scripts on the
	// machine the tests themselves run on. It carries no state.
	localRunner struct{}

	// remoteRunner is the Runner implementation that executes scripts on
	// another machine over SSH.
	remoteRunner struct {
		// sshKey is the path of the private key file read when connecting.
		sshKey string
		// sshUser is the user name presented to the SSH server.
		sshUser string
		// host is the machine to connect to. When it is empty, NewRunner
		// returns a localRunner instead of a remoteRunner.
		host string
		// port is the TCP port dialed on host.
		port string
	}
)
| |
| type runnerOption func(*remoteRunner) |
| |
| type Runner interface { |
| Run(script string) (string, string, error) |
| } |
| |
| func WithSshKey(key string) runnerOption { |
| return func(r *remoteRunner) { |
| r.sshKey = key |
| } |
| } |
| |
| func WithSshUser(user string) runnerOption { |
| return func(r *remoteRunner) { |
| r.sshUser = user |
| } |
| } |
| |
| func WithHost(host string) runnerOption { |
| return func(r *remoteRunner) { |
| r.host = host |
| } |
| } |
| |
| func WithPort(port string) runnerOption { |
| return func(r *remoteRunner) { |
| r.port = port |
| } |
| } |
| |
| func NewRunner(opts ...runnerOption) Runner { |
| r := &remoteRunner{} |
| for _, opt := range opts { |
| opt(r) |
| } |
| |
| // If the Host is empty, return a local runner |
| if r.host == "" { |
| return localRunner{} |
| } |
| |
| // Otherwise, return a remote runner |
| return r |
| } |
| |
| func (l localRunner) Run(script string) (string, string, error) { |
| // Create a command to run the script using bash |
| cmd := exec.Command("bash", "-c", script) |
| |
| // Buffer to capture standard output |
| var stdout bytes.Buffer |
| cmd.Stdout = &stdout |
| |
| // Buffer to capture standard error |
| var stderr bytes.Buffer |
| cmd.Stderr = &stderr |
| |
| // Run the command |
| err := cmd.Run() |
| if err != nil { |
| return "", "", fmt.Errorf("script execution failed: %v\nSTDOUT: %s\nSTDERR: %s", err, stdout.String(), stderr.String()) |
| } |
| |
| // Return the captured stdout and nil error |
| return stdout.String(), "", nil |
| } |
| |
| func (r remoteRunner) Run(script string) (string, string, error) { |
| // Create a new SSH connection |
| client, err := connectOrDie(r.sshKey, r.sshUser, r.host, r.port) |
| if err != nil { |
| return "", "", fmt.Errorf("failed to connect to %s: %v", r.host, err) |
| } |
| defer client.Close() |
| |
| // Create a session |
| session, err := client.NewSession() |
| if err != nil { |
| return "", "", fmt.Errorf("failed to create session: %v", err) |
| } |
| defer session.Close() |
| |
| // Capture stdout and stderr |
| var stdout, stderr bytes.Buffer |
| session.Stdout = &stdout |
| session.Stderr = &stderr |
| |
| // Run the script |
| err = session.Run(script) |
| if err != nil { |
| return "", "", fmt.Errorf("script execution failed: %v\nSTDOUT: %s\nSTDERR: %s", err, stdout.String(), stderr.String()) |
| } |
| |
| // Return stdout as string if no errors |
| return stdout.String(), "", nil |
| } |
| |
| // createSshClient creates a ssh client, and retries if it fails to connect |
| func connectOrDie(sshKey, sshUser, host, port string) (*ssh.Client, error) { |
| var client *ssh.Client |
| var err error |
| key, err := os.ReadFile(sshKey) |
| if err != nil { |
| return nil, fmt.Errorf("failed to read key file: %v", err) |
| } |
| signer, err := ssh.ParsePrivateKey(key) |
| if err != nil { |
| return nil, fmt.Errorf("failed to parse private key: %v", err) |
| } |
| sshConfig := &ssh.ClientConfig{ |
| User: sshUser, |
| Auth: []ssh.AuthMethod{ |
| ssh.PublicKeys(signer), |
| }, |
| HostKeyCallback: ssh.InsecureIgnoreHostKey(), |
| } |
| |
| connectionFailed := false |
| for i := 0; i < 20; i++ { |
| client, err = ssh.Dial("tcp", host+":"+port, sshConfig) |
| if err == nil { |
| return client, nil // Connection succeeded, return the client. |
| } |
| connectionFailed = true |
| // Sleep for a brief moment before retrying. |
| // You can adjust the duration based on your requirements. |
| time.Sleep(1 * time.Second) |
| } |
| |
| if connectionFailed { |
| return nil, fmt.Errorf("failed to connect to %s after 10 retries, giving up", host) |
| } |
| |
| return client, nil |
| } |