big refactor and testing
This commit is contained in:
parent
a2ed11b2fb
commit
d6b1b90034
26 changed files with 908 additions and 70 deletions
|
|
@ -5,7 +5,6 @@ import (
|
|||
|
||||
"github.com/amidaware/rmmagent/agent/utils"
|
||||
"github.com/gonutz/w32/v2"
|
||||
trmm "github.com/wh1te909/trmm-shared"
|
||||
"golang.org/x/sys/windows/svc/mgr"
|
||||
)
|
||||
|
||||
|
|
@ -51,14 +50,14 @@ func GetServiceStatus(name string) (string, error) {
|
|||
if err != nil {
|
||||
return "n/a", err
|
||||
}
|
||||
defer conn.Disconnect()
|
||||
|
||||
defer conn.Disconnect()
|
||||
srv, err := conn.OpenService(name)
|
||||
if err != nil {
|
||||
return "n/a", err
|
||||
}
|
||||
defer srv.Close()
|
||||
|
||||
defer srv.Close()
|
||||
q, err := srv.Query()
|
||||
if err != nil {
|
||||
return "n/a", err
|
||||
|
|
@ -90,28 +89,28 @@ func serviceStatusText(num uint32) string {
|
|||
}
|
||||
|
||||
// GetServices returns a list of windows services
|
||||
func GetServices() []trmm.WindowsService {
|
||||
ret := make([]trmm.WindowsService, 0)
|
||||
func GetServices() ([]Service, []error, error) {
|
||||
ret := make([]Service, 0)
|
||||
|
||||
conn, err := mgr.Connect()
|
||||
if err != nil {
|
||||
//a.Logger.Debugln(err)
|
||||
return ret
|
||||
return ret, nil, err
|
||||
}
|
||||
|
||||
defer conn.Disconnect()
|
||||
|
||||
svcs, err := conn.ListServices()
|
||||
|
||||
if err != nil {
|
||||
//a.Logger.Debugln(err)
|
||||
return ret
|
||||
return ret, nil, err
|
||||
}
|
||||
|
||||
errors := []error{}
|
||||
|
||||
for _, s := range svcs {
|
||||
srv, err := conn.OpenService(s)
|
||||
if err != nil {
|
||||
if err.Error() != "Access is denied." {
|
||||
//a.Logger.Debugln("Open Service", s, err)
|
||||
errors = append(errors, err)
|
||||
}
|
||||
|
||||
continue
|
||||
|
|
@ -120,17 +119,17 @@ func GetServices() []trmm.WindowsService {
|
|||
defer srv.Close()
|
||||
q, err := srv.Query()
|
||||
if err != nil {
|
||||
//a.Logger.Debugln(err)
|
||||
errors = append(errors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
conf, err := srv.Config()
|
||||
if err != nil {
|
||||
//a.Logger.Debugln(err)
|
||||
errors = append(errors, err)
|
||||
continue
|
||||
}
|
||||
|
||||
ret = append(ret, trmm.WindowsService{
|
||||
ret = append(ret, Service{
|
||||
Name: s,
|
||||
Status: serviceStatusText(uint32(q.State)),
|
||||
DisplayName: utils.CleanString(conf.DisplayName),
|
||||
|
|
@ -142,7 +141,8 @@ func GetServices() []trmm.WindowsService {
|
|||
DelayedAutoStart: conf.DelayedAutoStart,
|
||||
})
|
||||
}
|
||||
return ret
|
||||
|
||||
return ret, errors, nil
|
||||
}
|
||||
|
||||
// https://docs.microsoft.com/en-us/dotnet/api/system.serviceprocess.servicestartmode?view=dotnet-plat-ext-3.1
|
||||
|
|
|
|||
79
agent/services/services_windows_test.go
Normal file
79
agent/services/services_windows_test.go
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
package services_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/amidaware/rmmagent/agent/services"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func TestGetServices(t *testing.T) {
|
||||
testTable := []struct {
|
||||
name string
|
||||
expected []services.Service
|
||||
atLeast int
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Get Services",
|
||||
expected: []services.Service{},
|
||||
atLeast: 1,
|
||||
expectedError: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testTable {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, errs, err := services.GetServices()
|
||||
if fmt.Sprintf("%T", result) != "[]services.Service" {
|
||||
t.Errorf("expected type %T, got type %T", tt.expected, result)
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
t.Logf("Continue errors occured %v", errs)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("expected error (%v), got error(%v)", tt.expectedError, err)
|
||||
}
|
||||
|
||||
if len(result) < tt.atLeast {
|
||||
t.Errorf("expect at least %d, got %d", tt.atLeast, len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetServiceStatus(t *testing.T) {
|
||||
testTable := []struct {
|
||||
name string
|
||||
expected string
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "CryptSvc",
|
||||
expected: "running",
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "NonExistentService",
|
||||
expected: "n/a",
|
||||
expectedError: windows.ERROR_SERVICE_DOES_NOT_EXIST,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testTable {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := services.GetServiceStatus(tt.name)
|
||||
if result != tt.expected {
|
||||
t.Errorf("expected %s, got %s", tt.expected, result)
|
||||
}
|
||||
|
||||
if !errors.Is(err, tt.expectedError) {
|
||||
t.Errorf("expected (%v), got (%v)", tt.expectedError, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
13
agent/services/structs.go
Normal file
13
agent/services/structs.go
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
package services
|
||||
|
||||
type Service struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
DisplayName string `json:"display_name"`
|
||||
BinPath string `json:"binpath"`
|
||||
Description string `json:"description"`
|
||||
Username string `json:"username"`
|
||||
PID uint32 `json:"pid"`
|
||||
StartType string `json:"start_type"`
|
||||
DelayedAutoStart bool `json:"autodelay"`
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue