blob: b5da283df918e01f49246d24a3b203c519a3a4dd [file] [log] [blame]
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package policyenforcer
import (
"errors"
"fmt"
"os"
"testing"
"policy-manager/pkg/devicepolicy"
"policy-manager/pkg/sysapi"
"policy-manager/pkg/systemd"
"policy-manager/protos"
"github.com/golang/protobuf/proto"
)
const (
	// tmpCosDevicePolicyFile is the scratch path the tests write the COS
	// device policy to before handing it to the policy enforcer.
	tmpCosDevicePolicyFile              = "/tmp/cos_device_policy"
	// cosDevicePolicyFilePerm is the permission mode used when writing
	// tmpCosDevicePolicyFile.
	cosDevicePolicyFilePerm os.FileMode = 0644
	// systemctlCmd is the stub script passed to systemd.NewSystemdClient in
	// place of a real systemctl binary.
	systemctlCmd                        = "testdata/systemctl.sh"
)
// generateFakePolicyBytes is a helper function to generate device policy bytes
// to simulate how device policy will be read from disk. It returns
// proto.Marshal(config) and fails the calling test if marshalling fails.
func generateFakePolicyBytes(t *testing.T, config *protos.InstanceConfig) []byte {
	// Report a failure here at the caller's line, not inside this helper.
	t.Helper()
	// Create fake policy.
	// NOTE(review): fakePolicy is populated but its result is never used; the
	// returned bytes are marshalled from config directly. Confirm which
	// encoding UpdateServiceState expects to find on disk.
	fakePolicy := new(devicepolicy.DevicePolicy)
	fakePolicy.SetFromInstanceConfig(config)
	cosDevicePolicyBytes, err := proto.Marshal(config)
	if err != nil {
		t.Fatalf("marshalling instance config: %v", err)
	}
	return cosDevicePolicyBytes
}
// TestUpdateServiceState checks the updates made to the services. For each
// case it:
//   - writes the device policy derived from onDiskConfig to
//     tmpCosDevicePolicyFile;
//   - encodes each service's current state and injected change error as
//     "unit,state,err" in serviceMonitor (presumably interpreted by the
//     systemctlCmd stub; confirm);
//   - asserts whether UpdateServiceState returns an error.
func TestUpdateServiceState(t *testing.T) {
	// NOTE(review): getConfigErr and getStatusErr are populated below but are
	// never read by the subtest. The NoDiskConfig case relies on the err
	// override inside the subtest instead.
	tests := []struct {
		name                          string
		onDiskConfig                  *protos.InstanceConfig
		getConfigErr                  error
		getStatusErr                  error
		isMetricsEnabled              bool
		isUpdateEnabled               bool
		isLogging                     bool
		isMonitoring                  bool
		expectChangeMetricsErr        error
		expectChangeUpdateStrategyErr error
		expectChangeLoggingErr        error
		expectChangeMonitoringErr     error
		expectErr                     bool
	}{
		{
			name:                          "NoEnforcement",
			onDiskConfig:                  &protos.InstanceConfig{},
			getConfigErr:                  nil,
			getStatusErr:                  nil,
			isMetricsEnabled:              true,
			isUpdateEnabled:               true,
			isLogging:                     true,
			isMonitoring:                  true,
			expectChangeMetricsErr:        nil,
			expectChangeUpdateStrategyErr: nil,
			expectChangeLoggingErr:        nil,
			expectChangeMonitoringErr:     nil,
			expectErr:                     false,
		},
		{
			name:         "TurnOnLogging",
			onDiskConfig: &protos.InstanceConfig{
				HealthMonitorConfig: &protos.HealthMonitorConfig{
					Enforced:       proto.Bool(true),
					LoggingEnabled: proto.Bool(true),
				},
			},
			getConfigErr:                  nil,
			getStatusErr:                  nil,
			isMetricsEnabled:              false,
			isUpdateEnabled:               false,
			isLogging:                     false,
			isMonitoring:                  false,
			expectChangeMetricsErr:        nil,
			expectChangeUpdateStrategyErr: nil,
			expectChangeLoggingErr:        nil,
			expectChangeMonitoringErr:     nil,
			expectErr:                     false,
		},
		{
			name:         "TurnOffUpdate",
			onDiskConfig: &protos.InstanceConfig{
				UpdateStrategy: proto.String("update_disabled"),
			},
			getConfigErr:                  nil,
			getStatusErr:                  nil,
			isMetricsEnabled:              false,
			isUpdateEnabled:               true,
			isLogging:                     false,
			isMonitoring:                  false,
			expectChangeMetricsErr:        nil,
			expectChangeUpdateStrategyErr: nil,
			expectChangeLoggingErr:        nil,
			expectChangeMonitoringErr:     nil,
			expectErr:                     false,
		},
		{
			name:         "TurnOnMonitoringTurnOffLoggingTurnOffMetrics",
			onDiskConfig: &protos.InstanceConfig{
				HealthMonitorConfig: &protos.HealthMonitorConfig{
					Enforced:          proto.Bool(true),
					MonitoringEnabled: proto.Bool(true),
				},
			},
			getConfigErr:                  nil,
			getStatusErr:                  nil,
			isMetricsEnabled:              true,
			isUpdateEnabled:               false,
			isLogging:                     true,
			isMonitoring:                  false,
			expectChangeMetricsErr:        nil,
			expectChangeUpdateStrategyErr: nil,
			expectChangeLoggingErr:        nil,
			expectChangeMonitoringErr:     nil,
			expectErr:                     false,
		},
		{
			name:         "TurnOnMetricsAndUpdate",
			onDiskConfig: &protos.InstanceConfig{
				MetricsEnabled: proto.Bool(true),
				UpdateStrategy: proto.String(""),
			},
			getConfigErr:                  nil,
			getStatusErr:                  nil,
			isMetricsEnabled:              false,
			isUpdateEnabled:               true,
			isLogging:                     false,
			isMonitoring:                  false,
			expectChangeMetricsErr:        nil,
			expectChangeUpdateStrategyErr: nil,
			expectChangeLoggingErr:        nil,
			expectChangeMonitoringErr:     nil,
			expectErr:                     false,
		},
		{
			name:                          "NoDiskConfig",
			onDiskConfig:                  nil,
			getConfigErr:                  errors.New("error"),
			getStatusErr:                  nil,
			isMetricsEnabled:              false,
			isUpdateEnabled:               false,
			isLogging:                     false,
			isMonitoring:                  false,
			expectChangeMetricsErr:        nil,
			expectChangeUpdateStrategyErr: nil,
			expectChangeLoggingErr:        nil,
			expectChangeMonitoringErr:     nil,
			expectErr:                     true,
		},
		{
			name:         "ErrorWhenTurnOnMetricsFails",
			onDiskConfig: &protos.InstanceConfig{
				MetricsEnabled:      proto.Bool(true),
				HealthMonitorConfig: &protos.HealthMonitorConfig{
					Enforced:          proto.Bool(true),
					LoggingEnabled:    proto.Bool(true),
					MonitoringEnabled: proto.Bool(true),
				},
			},
			getConfigErr:                  nil,
			getStatusErr:                  nil,
			isMetricsEnabled:              false,
			isUpdateEnabled:               false,
			isLogging:                     false,
			isMonitoring:                  false,
			expectChangeMetricsErr:        errors.New("error"),
			expectChangeUpdateStrategyErr: nil,
			expectChangeLoggingErr:        nil,
			expectChangeMonitoringErr:     nil,
			expectErr:                     true,
		},
		{
			name:         "ErrorWhenTurnOnLoggingWontAffectMonitoring",
			onDiskConfig: &protos.InstanceConfig{
				HealthMonitorConfig: &protos.HealthMonitorConfig{
					Enforced:          proto.Bool(true),
					LoggingEnabled:    proto.Bool(true),
					MonitoringEnabled: proto.Bool(true),
				},
			},
			getConfigErr:                  nil,
			getStatusErr:                  nil,
			isMetricsEnabled:              true,
			isUpdateEnabled:               false,
			isLogging:                     false,
			isMonitoring:                  false,
			expectChangeMetricsErr:        nil,
			expectChangeUpdateStrategyErr: nil,
			expectChangeLoggingErr:        errors.New("error"),
			expectChangeMonitoringErr:     nil,
			expectErr:                     true,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			fakeSystemdClient := systemd.NewSystemdClient(systemctlCmd)
			expectedInstanceConfig := devicepolicy.GetDefaultInstanceConfigFromBase(test.onDiskConfig)
			expectedCosDevicePolicyBytes := generateFakePolicyBytes(t, expectedInstanceConfig)
			// Register cleanup before writing so the fixture path is removed
			// even when the write below fails.
			defer os.RemoveAll(tmpCosDevicePolicyFile)
			// Expect the policy files to be written. Fail fast if the fixture
			// cannot be written. Otherwise the enforcer would read a missing or
			// stale file and the case would report a misleading result.
			if err := sysapi.AtomicWriteFile(tmpCosDevicePolicyFile, expectedCosDevicePolicyBytes, cosDevicePolicyFilePerm); err != nil {
				t.Fatalf("writing policy file %q: %v", tmpCosDevicePolicyFile, err)
			}
			serviceMonitor := map[string]string{
				"loggingService":    fmt.Sprintf("%s,%t,%v", "logging.service", test.isLogging, test.expectChangeLoggingErr),
				"monitoringService": fmt.Sprintf("%s,%t,%v", "monitoring.service", test.isMonitoring, test.expectChangeMonitoringErr),
				"metricsService":    fmt.Sprintf("%s,%t,%v", "metrics.service", test.isMetricsEnabled, test.expectChangeMetricsErr),
				"updateService":     fmt.Sprintf("%s,%t,%v", "update.service", test.isUpdateEnabled, test.expectChangeUpdateStrategyErr),
			}
			client := NewPolicyEnforcer(*fakeSystemdClient)
			err := client.UpdateServiceState(tmpCosDevicePolicyFile, serviceMonitor)
			// NOTE(review): no fake consumes getConfigErr, so the "config not
			// present" case is simulated by overriding err here. The real
			// UpdateServiceState result is discarded for that case.
			if test.onDiskConfig == nil {
				err = errors.New("config not present")
			}
			switch {
			case err == nil && test.expectErr:
				t.Errorf("Test %s passed, want error", test.name)
			case err != nil && !test.expectErr:
				t.Errorf("Test %s got unexpected error %v", test.name, err)
			}
		})
	}
}
// TestGetServiceStatus exercises GetServiceStatus for the logging, monitoring,
// metrics and update-engine services. Each case encodes a service's simulated
// state and check error as "unit,state,err" in the serviceMonitor map given to
// the enforcer.
//
// When an error is both expected and returned, the returned status is still
// compared against expectedStatus. This is why the failure cases list only
// the fields that are expected to be populated.
func TestGetServiceStatus(t *testing.T) {
	cases := []struct {
		name                     string
		isMetricsEnabled         bool
		checkMetricsErr          error
		isUpdateDisabled         bool
		checkUpdateDisabledError error
		isLogging                bool
		checkLoggingErr          error
		isMonitoring             bool
		checkMonitoringErr       error
		expectedStatus           *protos.ServiceStatus
		expectErr                bool
	}{
		{
			name:                     "NoServiceRunning",
			isMetricsEnabled:         false,
			checkMetricsErr:          nil,
			isUpdateDisabled:         false,
			checkUpdateDisabledError: nil,
			isLogging:                false,
			checkLoggingErr:          nil,
			isMonitoring:             false,
			checkMonitoringErr:       nil,
			expectedStatus: &protos.ServiceStatus{
				UpdateEngine: proto.String("update_disabled"),
				Metrics:      proto.Bool(false),
				Logging:      proto.Bool(false),
				Monitoring:   proto.Bool(false),
			},
			expectErr: false,
		},
		{
			name:                     "AllServicesRunning",
			isMetricsEnabled:         true,
			checkMetricsErr:          nil,
			isUpdateDisabled:         true,
			checkUpdateDisabledError: nil,
			isLogging:                true,
			checkLoggingErr:          nil,
			isMonitoring:             true,
			checkMonitoringErr:       nil,
			expectedStatus: &protos.ServiceStatus{
				UpdateEngine: proto.String(""),
				Metrics:      proto.Bool(true),
				Logging:      proto.Bool(true),
				Monitoring:   proto.Bool(true),
			},
			expectErr: false,
		},
		{
			name:                     "CheckMetricsServiceFailed",
			isMetricsEnabled:         false,
			checkMetricsErr:          errors.New("error"),
			isUpdateDisabled:         true,
			checkUpdateDisabledError: nil,
			isLogging:                true,
			checkLoggingErr:          nil,
			isMonitoring:             false,
			checkMonitoringErr:       nil,
			expectedStatus: &protos.ServiceStatus{
				UpdateEngine: proto.String(""),
				Logging:      proto.Bool(true),
				Monitoring:   proto.Bool(false),
			},
			expectErr: true,
		},
		{
			name:                     "CheckUpdateServiceFailed",
			isMetricsEnabled:         true,
			checkMetricsErr:          nil,
			isUpdateDisabled:         false,
			checkUpdateDisabledError: errors.New("error"),
			isLogging:                true,
			checkLoggingErr:          nil,
			isMonitoring:             false,
			checkMonitoringErr:       nil,
			expectedStatus: &protos.ServiceStatus{
				Metrics:    proto.Bool(true),
				Logging:    proto.Bool(true),
				Monitoring: proto.Bool(false),
			},
			expectErr: true,
		},
		{
			name:                     "CheckLoggingServiceFailed",
			isMetricsEnabled:         true,
			checkMetricsErr:          nil,
			isUpdateDisabled:         true,
			checkUpdateDisabledError: nil,
			isLogging:                false,
			checkLoggingErr:          errors.New("error"),
			isMonitoring:             true,
			checkMonitoringErr:       nil,
			expectedStatus: &protos.ServiceStatus{
				UpdateEngine: proto.String(""),
				Metrics:      proto.Bool(true),
				Monitoring:   proto.Bool(true),
			},
			expectErr: true,
		},
		{
			name:                     "CheckMonitoringServiceFailed",
			isMetricsEnabled:         true,
			checkMetricsErr:          nil,
			isUpdateDisabled:         true,
			checkUpdateDisabledError: nil,
			isLogging:                true,
			checkLoggingErr:          nil,
			isMonitoring:             false,
			checkMonitoringErr:       errors.New("error"),
			expectedStatus: &protos.ServiceStatus{
				UpdateEngine: proto.String(""),
				Metrics:      proto.Bool(true),
				Logging:      proto.Bool(true),
			},
			expectErr: true,
		},
	}

	// monitorEntry renders one serviceMonitor value as "unit,state,err".
	monitorEntry := func(unit string, state bool, checkErr error) string {
		return fmt.Sprintf("%s,%t,%v", unit, state, checkErr)
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			client := NewPolicyEnforcer(*systemd.NewSystemdClient(systemctlCmd))
			status, err := client.GetServiceStatus(map[string]string{
				"loggingService":    monitorEntry("logging.service", tc.isLogging, tc.checkLoggingErr),
				"monitoringService": monitorEntry("monitoring.service", tc.isMonitoring, tc.checkMonitoringErr),
				"metricsService":    monitorEntry("metrics.service", tc.isMetricsEnabled, tc.checkMetricsErr),
				"updateService":     monitorEntry("update.service", tc.isUpdateDisabled, tc.checkUpdateDisabledError),
			})

			gotErr := err != nil
			switch {
			case tc.expectErr && !gotErr:
				t.Errorf("Test %s passed, want error", tc.name)
			case !tc.expectErr && gotErr:
				t.Errorf("Test %s got unexpected error %v", tc.name, err)
			case !proto.Equal(status, tc.expectedStatus):
				t.Errorf("Test %s got %s, want %s", tc.name, status.String(), tc.expectedStatus.String())
			}
		})
	}
}