big refactor and testing
This commit is contained in:
parent
a2ed11b2fb
commit
d6b1b90034
26 changed files with 908 additions and 70 deletions
|
|
@ -1,14 +0,0 @@
|
||||||
package disk
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestGetDisks(t *testing.T) {
|
|
||||||
disks := GetDisks()
|
|
||||||
if len(disks) == 0 {
|
|
||||||
t.Fatalf("Could not get disks on linux system.")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Logf("Got %d disks on linux system", len(disks))
|
|
||||||
}
|
|
||||||
51
agent/disk/disk_test.go
Normal file
51
agent/disk/disk_test.go
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
package disk_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/amidaware/rmmagent/agent/disk"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetDisks(t *testing.T) {
|
||||||
|
exampleDisk := disk.Disk{
|
||||||
|
Device: "C:",
|
||||||
|
Fstype: "NTFS",
|
||||||
|
Total: "149.9 GB",
|
||||||
|
Used: "129.2 GB",
|
||||||
|
Free: "20.7 GB",
|
||||||
|
Percent: 86,
|
||||||
|
}
|
||||||
|
|
||||||
|
testTable := []struct {
|
||||||
|
name string
|
||||||
|
expected []disk.Disk
|
||||||
|
atLeast int
|
||||||
|
expectedError error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Get Disks",
|
||||||
|
expected: []disk.Disk{exampleDisk},
|
||||||
|
atLeast: 1,
|
||||||
|
expectedError: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range testTable {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := disk.GetDisks()
|
||||||
|
if fmt.Sprintf("%T", result) != "[]disk.Disk" {
|
||||||
|
t.Errorf("expected type %T, got type %T", tt.expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !errors.Is(err, tt.expectedError) {
|
||||||
|
t.Errorf("expected error (%v), got error(%v)", tt.expectedError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) < 1 {
|
||||||
|
t.Errorf("expected count at least %d, got %d", tt.atLeast, len(result))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -5,7 +5,6 @@ import (
|
||||||
|
|
||||||
"github.com/amidaware/rmmagent/agent/utils"
|
"github.com/amidaware/rmmagent/agent/utils"
|
||||||
"github.com/shirou/gopsutil/disk"
|
"github.com/shirou/gopsutil/disk"
|
||||||
trmm "github.com/wh1te909/trmm-shared"
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -14,12 +13,11 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetDisks returns a list of fixed disks
|
// GetDisks returns a list of fixed disks
|
||||||
func GetDisks() []trmm.Disk {
|
func GetDisks() ([]Disk, error) {
|
||||||
ret := make([]trmm.Disk, 0)
|
ret := make([]Disk, 0)
|
||||||
partitions, err := disk.Partitions(false)
|
partitions, err := disk.Partitions(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//a.Logger.Debugln(err)
|
return ret, err
|
||||||
return ret
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range partitions {
|
for _, p := range partitions {
|
||||||
|
|
@ -36,7 +34,7 @@ func GetDisks() []trmm.Disk {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
d := trmm.Disk{
|
d := Disk{
|
||||||
Device: p.Device,
|
Device: p.Device,
|
||||||
Fstype: p.Fstype,
|
Fstype: p.Fstype,
|
||||||
Total: utils.ByteCountSI(usage.Total),
|
Total: utils.ByteCountSI(usage.Total),
|
||||||
|
|
@ -44,7 +42,9 @@ func GetDisks() []trmm.Disk {
|
||||||
Free: utils.ByteCountSI(usage.Free),
|
Free: utils.ByteCountSI(usage.Free),
|
||||||
Percent: int(usage.UsedPercent),
|
Percent: int(usage.UsedPercent),
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = append(ret, d)
|
ret = append(ret, d)
|
||||||
}
|
}
|
||||||
return ret
|
|
||||||
|
return ret, err
|
||||||
}
|
}
|
||||||
10
agent/disk/structs.go
Normal file
10
agent/disk/structs.go
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
package disk
|
||||||
|
|
||||||
|
type Disk struct {
|
||||||
|
Device string `json:"device"`
|
||||||
|
Fstype string `json:"fstype"`
|
||||||
|
Total string `json:"total"`
|
||||||
|
Used string `json:"used"`
|
||||||
|
Free string `json:"free"`
|
||||||
|
Percent int `json:"percent"`
|
||||||
|
}
|
||||||
5
agent/events/events_linux.go
Normal file
5
agent/events/events_linux.go
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
package events
|
||||||
|
|
||||||
|
func GetEventLog(logName string, searchLastDays int) ([]EventLogMsg, error) {
|
||||||
|
return []EventLogMsg{}, nil
|
||||||
|
}
|
||||||
|
|
@ -6,48 +6,40 @@ import (
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/amidaware/rmmagent/agent/syscall"
|
"github.com/amidaware/rmmagent/agent/syscall"
|
||||||
rmm "github.com/amidaware/rmmagent/shared"
|
|
||||||
"github.com/gonutz/w32/v2"
|
"github.com/gonutz/w32/v2"
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetEventLog(logName string, searchLastDays int) []rmm.EventLogMsg {
|
func GetEventLog(logName string, searchLastDays int) ([]EventLogMsg, error) {
|
||||||
var (
|
var (
|
||||||
oldestLog uint32
|
oldestLog uint32
|
||||||
nextSize uint32
|
nextSize uint32
|
||||||
readBytes uint32
|
readBytes uint32
|
||||||
)
|
)
|
||||||
|
|
||||||
buf := []byte{0}
|
buf := []byte{0}
|
||||||
size := uint32(1)
|
size := uint32(1)
|
||||||
|
ret := make([]EventLogMsg, 0)
|
||||||
ret := make([]rmm.EventLogMsg, 0)
|
|
||||||
startTime := time.Now().Add(time.Duration(-(time.Duration(searchLastDays)) * (24 * time.Hour)))
|
startTime := time.Now().Add(time.Duration(-(time.Duration(searchLastDays)) * (24 * time.Hour)))
|
||||||
|
|
||||||
h := w32.OpenEventLog("", logName)
|
h := w32.OpenEventLog("", logName)
|
||||||
defer w32.CloseEventLog(h)
|
defer w32.CloseEventLog(h)
|
||||||
|
|
||||||
numRecords, _ := w32.GetNumberOfEventLogRecords(h)
|
numRecords, _ := w32.GetNumberOfEventLogRecords(h)
|
||||||
syscall.GetOldestEventLogRecord(h, &oldestLog)
|
err := syscall.GetOldestEventLogRecord(h, &oldestLog)
|
||||||
|
|
||||||
startNum := numRecords + oldestLog - 1
|
startNum := numRecords + oldestLog - 1
|
||||||
uid := 0
|
uid := 0
|
||||||
for i := startNum; i >= oldestLog; i-- {
|
for i := startNum; i >= oldestLog; i-- {
|
||||||
flags := syscall.EVENTLOG_BACKWARDS_READ | syscall.EVENTLOG_SEEK_READ
|
flags := syscall.EVENTLOG_BACKWARDS_READ | syscall.EVENTLOG_SEEK_READ
|
||||||
|
|
||||||
err := syscall.ReadEventLog(h, flags, i, &buf[0], size, &readBytes, &nextSize)
|
err := syscall.ReadEventLog(h, flags, i, &buf[0], size, &readBytes, &nextSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err != windows.ERROR_INSUFFICIENT_BUFFER {
|
if err != windows.ERROR_INSUFFICIENT_BUFFER {
|
||||||
//a.Logger.Debugln(err)
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
buf = make([]byte, nextSize)
|
buf = make([]byte, nextSize)
|
||||||
size = nextSize
|
size = nextSize
|
||||||
err = syscall.ReadEventLog(h, flags, i, &buf[0], size, &readBytes, &nextSize)
|
err = syscall.ReadEventLog(h, flags, i, &buf[0], size, &readBytes, &nextSize)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//a.Logger.Debugln(err)
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
r := *(*syscall.EVENTLOGRECORD)(unsafe.Pointer(&buf[0]))
|
r := *(*syscall.EVENTLOGRECORD)(unsafe.Pointer(&buf[0]))
|
||||||
|
|
@ -75,10 +67,11 @@ func GetEventLog(logName string, searchLastDays int) []rmm.EventLogMsg {
|
||||||
if r.NumStrings > 0 {
|
if r.NumStrings > 0 {
|
||||||
argsptr = uintptr(unsafe.Pointer(&args[0]))
|
argsptr = uintptr(unsafe.Pointer(&args[0]))
|
||||||
}
|
}
|
||||||
|
|
||||||
message, _ := syscall.GetResourceMessage(logName, sourceName, r.EventID, argsptr)
|
message, _ := syscall.GetResourceMessage(logName, sourceName, r.EventID, argsptr)
|
||||||
|
|
||||||
uid++
|
uid++
|
||||||
eventLogMsg := rmm.EventLogMsg{
|
eventLogMsg := EventLogMsg{
|
||||||
Source: sourceName,
|
Source: sourceName,
|
||||||
EventType: eventType,
|
EventType: eventType,
|
||||||
EventID: eventID,
|
EventID: eventID,
|
||||||
|
|
@ -86,9 +79,11 @@ func GetEventLog(logName string, searchLastDays int) []rmm.EventLogMsg {
|
||||||
Time: timeWritten.String(),
|
Time: timeWritten.String(),
|
||||||
UID: uid,
|
UID: uid,
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = append(ret, eventLogMsg)
|
ret = append(ret, eventLogMsg)
|
||||||
}
|
}
|
||||||
return ret
|
|
||||||
|
return ret, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/mackerelio/go-check-plugins/blob/ad7910fdc45ccb892b5e5fda65ba0956c2b2885d/check-windows-eventlog/lib/check-windows-eventlog.go#L219
|
// https://github.com/mackerelio/go-check-plugins/blob/ad7910fdc45ccb892b5e5fda65ba0956c2b2885d/check-windows-eventlog/lib/check-windows-eventlog.go#L219
|
||||||
|
|
@ -102,6 +97,7 @@ func bytesToString(b []byte) (string, uint32) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(utf16.Decode(s)), uint32(i * 2)
|
return string(utf16.Decode(s)), uint32(i * 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
46
agent/events/events_windows_test.go
Normal file
46
agent/events/events_windows_test.go
Normal file
|
|
@ -0,0 +1,46 @@
|
||||||
|
package events_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/amidaware/rmmagent/agent/events"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetEventLog(t *testing.T) {
|
||||||
|
testTable := []struct {
|
||||||
|
name string
|
||||||
|
expected []events.EventLogMsg
|
||||||
|
atLeast int
|
||||||
|
expectedError error
|
||||||
|
logname string
|
||||||
|
search int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Get EventLog",
|
||||||
|
expected: []events.EventLogMsg{},
|
||||||
|
atLeast: 1,
|
||||||
|
expectedError: nil,
|
||||||
|
logname: "Application",
|
||||||
|
search: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range testTable {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := events.GetEventLog(tt.logname, tt.search)
|
||||||
|
if fmt.Sprintf("%T", result) != "[]events.EventLogMsg" {
|
||||||
|
t.Errorf("expected type %T, got type %T", []events.EventLogMsg{}, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !errors.Is(err, tt.expectedError) {
|
||||||
|
t.Errorf("expected error (%v), got error (%v)", tt.expectedError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) < 1 {
|
||||||
|
t.Errorf("expected count at least %d, got %d", tt.atLeast, len(result))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
10
agent/events/structs.go
Normal file
10
agent/events/structs.go
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
package events
|
||||||
|
|
||||||
|
type EventLogMsg struct {
|
||||||
|
Source string `json:"source"`
|
||||||
|
EventType string `json:"eventType"`
|
||||||
|
EventID uint32 `json:"eventID"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Time string `json:"time"`
|
||||||
|
UID int `json:"uid"` // for vue
|
||||||
|
}
|
||||||
|
|
@ -2,6 +2,6 @@ package patching
|
||||||
|
|
||||||
func PatchMgmnt(enable bool) error { return nil }
|
func PatchMgmnt(enable bool) error { return nil }
|
||||||
|
|
||||||
func GetWinUpdates() {}
|
func GetUpdates() {}
|
||||||
|
|
||||||
func InstallUpdates(guids []string) {}
|
func InstallUpdates(guids []string) {}
|
||||||
63
agent/patching/patching_test.go
Normal file
63
agent/patching/patching_test.go
Normal file
|
|
@ -0,0 +1,63 @@
|
||||||
|
package patching_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/amidaware/rmmagent/agent/patching"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPatchMgmnt(t *testing.T) {
|
||||||
|
testTable := []struct {
|
||||||
|
name string
|
||||||
|
expectedError error
|
||||||
|
status bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Enable Patch Mgmnt",
|
||||||
|
expectedError: nil,
|
||||||
|
status: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Disable Patch Mgmnt",
|
||||||
|
expectedError: nil,
|
||||||
|
status: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range testTable {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := patching.PatchMgmnt(tt.status)
|
||||||
|
if err != tt.expectedError {
|
||||||
|
t.Errorf("expected error (%v), got error (%v)", tt.expectedError, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUpdates(t *testing.T) {
|
||||||
|
testTable := []struct {
|
||||||
|
name string
|
||||||
|
expectedError error
|
||||||
|
atLeast int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Get Updates",
|
||||||
|
expectedError: nil,
|
||||||
|
atLeast: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range testTable {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := patching.GetUpdates()
|
||||||
|
if !errors.Is(tt.expectedError, err) {
|
||||||
|
t.Errorf("expected (%v), got (%v)", tt.expectedError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) < tt.atLeast {
|
||||||
|
t.Errorf("expected at least %d, got %d", tt.atLeast, len(result))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
package patching
|
package patching
|
||||||
|
|
||||||
import "golang.org/x/sys/windows/registry"
|
import (
|
||||||
|
"github.com/amidaware/rmmagent/agent/patching/wua"
|
||||||
|
"golang.org/x/sys/windows/registry"
|
||||||
|
)
|
||||||
|
|
||||||
// PatchMgmnt enables/disables automatic update
|
// PatchMgmnt enables/disables automatic update
|
||||||
// 0 - Enable Automatic Updates (Default)
|
// 0 - Enable Automatic Updates (Default)
|
||||||
|
|
@ -26,3 +29,33 @@ func PatchMgmnt(enable bool) error {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PackageList []Package
|
||||||
|
|
||||||
|
func GetUpdates() (PackageList, error) {
|
||||||
|
wuaupdates, err := wua.WUAUpdates("IsInstalled=1 or IsInstalled=0 and Type='Software' and IsHidden=0")
|
||||||
|
packages := []Package{}
|
||||||
|
for _, p := range wuaupdates {
|
||||||
|
packages = append(packages, Package(p))
|
||||||
|
}
|
||||||
|
|
||||||
|
return packages, err
|
||||||
|
// if err != nil {
|
||||||
|
// a.Logger.Errorln(err)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for _, update := range updates {
|
||||||
|
// a.Logger.Debugln("GUID:", update.UpdateID)
|
||||||
|
// a.Logger.Debugln("Downloaded:", update.Downloaded)
|
||||||
|
// a.Logger.Debugln("Installed:", update.Installed)
|
||||||
|
// a.Logger.Debugln("KB:", update.KBArticleIDs)
|
||||||
|
// a.Logger.Debugln("--------------------------------")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// payload := rmm.WinUpdateResult{AgentID: a.AgentID, Updates: updates}
|
||||||
|
// _, err = a.rClient.R().SetBody(payload).Post("/api/v3/winupdates/")
|
||||||
|
// if err != nil {
|
||||||
|
// a.Logger.Debugln(err)
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
|
||||||
16
agent/patching/structs.go
Normal file
16
agent/patching/structs.go
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
package patching
|
||||||
|
|
||||||
|
type Package struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Categories []string `json:"categories"`
|
||||||
|
CategoryIDs []string `json:"category_ids"`
|
||||||
|
KBArticleIDs []string `json:"kb_article_ids"`
|
||||||
|
MoreInfoURLs []string `json:"more_info_urls"`
|
||||||
|
SupportURL string `json:"support_url"`
|
||||||
|
UpdateID string `json:"guid"`
|
||||||
|
RevisionNumber int32 `json:"revision_number"`
|
||||||
|
Severity string `json:"severity"`
|
||||||
|
Installed bool `json:"installed"`
|
||||||
|
Downloaded bool `json:"downloaded"`
|
||||||
|
}
|
||||||
33
agent/patching/wua/structs.go
Normal file
33
agent/patching/wua/structs.go
Normal file
|
|
@ -0,0 +1,33 @@
|
||||||
|
package wua
|
||||||
|
|
||||||
|
import (
|
||||||
|
ole "github.com/go-ole/go-ole"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WUAPackage struct {
|
||||||
|
Title string `json:"title"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Categories []string `json:"categories"`
|
||||||
|
CategoryIDs []string `json:"category_ids"`
|
||||||
|
KBArticleIDs []string `json:"kb_article_ids"`
|
||||||
|
MoreInfoURLs []string `json:"more_info_urls"`
|
||||||
|
SupportURL string `json:"support_url"`
|
||||||
|
UpdateID string `json:"guid"`
|
||||||
|
RevisionNumber int32 `json:"revision_number"`
|
||||||
|
Severity string `json:"severity"`
|
||||||
|
Installed bool `json:"installed"`
|
||||||
|
Downloaded bool `json:"downloaded"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IUpdateSession is a an IUpdateSession.
|
||||||
|
type IUpdateSession struct {
|
||||||
|
*ole.IDispatch
|
||||||
|
}
|
||||||
|
|
||||||
|
type IUpdateCollection struct {
|
||||||
|
*ole.IDispatch
|
||||||
|
}
|
||||||
|
|
||||||
|
type IUpdate struct {
|
||||||
|
*ole.IDispatch
|
||||||
|
}
|
||||||
437
agent/patching/wua/wua_windows.go
Normal file
437
agent/patching/wua/wua_windows.go
Normal file
|
|
@ -0,0 +1,437 @@
|
||||||
|
package wua
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
ole "github.com/go-ole/go-ole"
|
||||||
|
"github.com/go-ole/go-ole/oleutil"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
S_OK = 0
|
||||||
|
S_FALSE = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
var wuaSession sync.Mutex
|
||||||
|
|
||||||
|
func (s *IUpdateSession) Close() {
|
||||||
|
if s.IDispatch != nil {
|
||||||
|
s.IDispatch.Release()
|
||||||
|
}
|
||||||
|
|
||||||
|
ole.CoUninitialize()
|
||||||
|
wuaSession.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUpdateSession() (*IUpdateSession, error) {
|
||||||
|
wuaSession.Lock()
|
||||||
|
if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil {
|
||||||
|
e, ok := err.(*ole.OleError)
|
||||||
|
// S_OK and S_FALSE are both are Success codes.
|
||||||
|
// https://docs.microsoft.com/en-us/windows/win32/learnwin32/error-handling-in-com
|
||||||
|
if !ok || (e.Code() != S_OK && e.Code() != S_FALSE) {
|
||||||
|
wuaSession.Unlock()
|
||||||
|
return nil, fmt.Errorf(`ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED): %v`, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &IUpdateSession{}
|
||||||
|
|
||||||
|
unknown, err := oleutil.CreateObject("Microsoft.Update.Session")
|
||||||
|
if err != nil {
|
||||||
|
s.Close()
|
||||||
|
return nil, fmt.Errorf(`oleutil.CreateObject("Microsoft.Update.Session"): %v`, err)
|
||||||
|
}
|
||||||
|
disp, err := unknown.QueryInterface(ole.IID_IDispatch)
|
||||||
|
if err != nil {
|
||||||
|
unknown.Release()
|
||||||
|
s.Close()
|
||||||
|
return nil, fmt.Errorf(`error creating Dispatch object from Microsoft.Update.Session connection: %v`, err)
|
||||||
|
}
|
||||||
|
s.IDispatch = disp
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InstallWUAUpdate install a WIndows update.
|
||||||
|
func (s *IUpdateSession) InstallWUAUpdate(updt *IUpdate) error {
|
||||||
|
_, err := updt.GetProperty("Title")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(`updt.GetProperty("Title"): %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updts, err := NewUpdateCollection()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer updts.Release()
|
||||||
|
|
||||||
|
eula, err := updt.GetProperty("EulaAccepted")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(`updt.GetProperty("EulaAccepted"): %v`, err)
|
||||||
|
}
|
||||||
|
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-oaut/7b39eb24-9d39-498a-bcd8-75c38e5823d0
|
||||||
|
if eula.Val == 0 {
|
||||||
|
if _, err := updt.CallMethod("AcceptEula"); err != nil {
|
||||||
|
return fmt.Errorf(`updt.CallMethod("AcceptEula"): %v`, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := updts.Add(updt); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.DownloadWUAUpdateCollection(updts); err != nil {
|
||||||
|
return fmt.Errorf("DownloadWUAUpdateCollection error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.InstallWUAUpdateCollection(updts); err != nil {
|
||||||
|
return fmt.Errorf("InstallWUAUpdateCollection error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUpdateCollection() (*IUpdateCollection, error) {
|
||||||
|
updateCollObj, err := oleutil.CreateObject("Microsoft.Update.UpdateColl")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`oleutil.CreateObject("Microsoft.Update.UpdateColl"): %v`, err)
|
||||||
|
}
|
||||||
|
defer updateCollObj.Release()
|
||||||
|
|
||||||
|
updateColl, err := updateCollObj.IDispatch(ole.IID_IDispatch)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &IUpdateCollection{IDispatch: updateColl}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *IUpdateCollection) Add(updt *IUpdate) error {
|
||||||
|
if _, err := c.CallMethod("Add", updt.IDispatch); err != nil {
|
||||||
|
return fmt.Errorf(`IUpdateCollection.CallMethod("Add", updt): %v`, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *IUpdateCollection) RemoveAt(i int) error {
|
||||||
|
if _, err := c.CallMethod("RemoveAt", i); err != nil {
|
||||||
|
return fmt.Errorf(`IUpdateCollection.CallMethod("RemoveAt", %d): %v`, i, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *IUpdateCollection) Count() (int32, error) {
|
||||||
|
return GetCount(c.IDispatch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *IUpdateCollection) Item(i int) (*IUpdate, error) {
|
||||||
|
updtRaw, err := c.GetProperty("Item", i)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`IUpdateCollection.GetProperty("Item", %d): %v`, i, err)
|
||||||
|
}
|
||||||
|
return &IUpdate{IDispatch: updtRaw.ToIDispatch()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCount returns the Count property.
|
||||||
|
func GetCount(dis *ole.IDispatch) (int32, error) {
|
||||||
|
countRaw, err := dis.GetProperty("Count")
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf(`IDispatch.GetProperty("Count"): %v`, err)
|
||||||
|
}
|
||||||
|
count, _ := countRaw.Value().(int32)
|
||||||
|
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *IUpdate) kbaIDs() ([]string, error) {
|
||||||
|
kbArticleIDsRaw, err := u.GetProperty("KBArticleIDs")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`IUpdate.GetProperty("KBArticleIDs"): %v`, err)
|
||||||
|
}
|
||||||
|
kbArticleIDs := kbArticleIDsRaw.ToIDispatch()
|
||||||
|
defer kbArticleIDs.Release()
|
||||||
|
|
||||||
|
count, err := GetCount(kbArticleIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if count == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ss []string
|
||||||
|
for i := 0; i < int(count); i++ {
|
||||||
|
item, err := kbArticleIDs.GetProperty("Item", i)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`kbArticleIDs.GetProperty("Item", %d): %v`, i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ss = append(ss, item.ToString())
|
||||||
|
}
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *IUpdate) categories() ([]string, []string, error) {
|
||||||
|
catRaw, err := u.GetProperty("Categories")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf(`IUpdate.GetProperty("Categories"): %v`, err)
|
||||||
|
}
|
||||||
|
cat := catRaw.ToIDispatch()
|
||||||
|
defer cat.Release()
|
||||||
|
|
||||||
|
count, err := GetCount(cat)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
if count == 0 {
|
||||||
|
return nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cns, cids []string
|
||||||
|
for i := 0; i < int(count); i++ {
|
||||||
|
itemRaw, err := cat.GetProperty("Item", i)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf(`cat.GetProperty("Item", %d): %v`, i, err)
|
||||||
|
}
|
||||||
|
item := itemRaw.ToIDispatch()
|
||||||
|
defer item.Release()
|
||||||
|
|
||||||
|
name, err := item.GetProperty("Name")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf(`item.GetProperty("Name"): %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
categoryID, err := item.GetProperty("CategoryID")
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf(`item.GetProperty("CategoryID"): %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cns = append(cns, name.ToString())
|
||||||
|
cids = append(cids, categoryID.ToString())
|
||||||
|
}
|
||||||
|
return cns, cids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *IUpdate) moreInfoURLs() ([]string, error) {
|
||||||
|
moreInfoURLsRaw, err := u.GetProperty("MoreInfoURLs")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`IUpdate.GetProperty("MoreInfoURLs"): %v`, err)
|
||||||
|
}
|
||||||
|
moreInfoURLs := moreInfoURLsRaw.ToIDispatch()
|
||||||
|
defer moreInfoURLs.Release()
|
||||||
|
|
||||||
|
count, err := GetCount(moreInfoURLs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if count == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ss []string
|
||||||
|
for i := 0; i < int(count); i++ {
|
||||||
|
item, err := moreInfoURLs.GetProperty("Item", i)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`moreInfoURLs.GetProperty("Item", %d): %v`, i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ss = append(ss, item.ToString())
|
||||||
|
}
|
||||||
|
return ss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *IUpdateCollection) extractPkg(item int) (*WUAPackage, error) {
|
||||||
|
updt, err := c.Item(item)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer updt.Release()
|
||||||
|
|
||||||
|
title, err := updt.GetProperty("Title")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`updt.GetProperty("Title"): %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
description, err := updt.GetProperty("Description")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`updt.GetProperty("Description"): %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
kbArticleIDs, err := updt.kbaIDs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
categories, categoryIDs, err := updt.categories()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
moreInfoURLs, err := updt.moreInfoURLs()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
supportURL, err := updt.GetProperty("SupportURL")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`updt.GetProperty("SupportURL"): %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
identityRaw, err := updt.GetProperty("Identity")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`updt.GetProperty("Identity"): %v`, err)
|
||||||
|
}
|
||||||
|
identity := identityRaw.ToIDispatch()
|
||||||
|
defer identity.Release()
|
||||||
|
|
||||||
|
revisionNumber, err := identity.GetProperty("RevisionNumber")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`identity.GetProperty("RevisionNumber"): %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updateID, err := identity.GetProperty("UpdateID")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`identity.GetProperty("UpdateID"): %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
severity, err := updt.GetProperty("MsrcSeverity")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`updt.GetProperty("MsrcSeverity"): %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
isInstalled, err := updt.GetProperty("IsInstalled")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`updt.GetProperty("IsInstalled"): %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
isDownloaded, err := updt.GetProperty("IsDownloaded")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf(`updt.GetProperty("IsDownloaded"): %v`, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &WUAPackage{
|
||||||
|
Title: title.ToString(),
|
||||||
|
Description: description.ToString(),
|
||||||
|
SupportURL: supportURL.ToString(),
|
||||||
|
KBArticleIDs: kbArticleIDs,
|
||||||
|
UpdateID: updateID.ToString(),
|
||||||
|
Categories: categories,
|
||||||
|
CategoryIDs: categoryIDs,
|
||||||
|
MoreInfoURLs: moreInfoURLs,
|
||||||
|
Severity: severity.ToString(),
|
||||||
|
RevisionNumber: int32(revisionNumber.Val),
|
||||||
|
Downloaded: isDownloaded.Value().(bool),
|
||||||
|
Installed: isInstalled.Value().(bool),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WUAUpdates queries the Windows Update Agent API searcher with the provided query.
|
||||||
|
func WUAUpdates(query string) ([]WUAPackage, error) {
|
||||||
|
session, err := NewUpdateSession()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating NewUpdateSession: %v", err)
|
||||||
|
}
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
updts, err := session.GetWUAUpdateCollection(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error calling GetWUAUpdateCollection with query %q: %v", query, err)
|
||||||
|
}
|
||||||
|
defer updts.Release()
|
||||||
|
|
||||||
|
updtCnt, err := updts.Count()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if updtCnt == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var packages []WUAPackage
|
||||||
|
for i := 0; i < int(updtCnt); i++ {
|
||||||
|
pkg, err := updts.extractPkg(i)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
packages = append(packages, *pkg)
|
||||||
|
}
|
||||||
|
return packages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DownloadWUAUpdateCollection downloads all updates in a IUpdateCollection
|
||||||
|
func (s *IUpdateSession) DownloadWUAUpdateCollection(updates *IUpdateCollection) error {
|
||||||
|
// returns IUpdateDownloader
|
||||||
|
// https://docs.microsoft.com/en-us/windows/desktop/api/wuapi/nn-wuapi-iupdatedownloader
|
||||||
|
downloaderRaw, err := s.CallMethod("CreateUpdateDownloader")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error calling method CreateUpdateDownloader on IUpdateSession: %v", err)
|
||||||
|
}
|
||||||
|
downloader := downloaderRaw.ToIDispatch()
|
||||||
|
defer downloader.Release()
|
||||||
|
|
||||||
|
if _, err := downloader.PutProperty("Updates", updates.IDispatch); err != nil {
|
||||||
|
return fmt.Errorf("error calling PutProperty Updates on IUpdateDownloader: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := downloader.CallMethod("Download"); err != nil {
|
||||||
|
return fmt.Errorf("error calling method Download on IUpdateDownloader: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InstallWUAUpdateCollection installs all updates in a IUpdateCollection
|
||||||
|
func (s *IUpdateSession) InstallWUAUpdateCollection(updates *IUpdateCollection) error {
|
||||||
|
// returns IUpdateInstallersession *ole.IDispatch,
|
||||||
|
// https://docs.microsoft.com/en-us/windows/desktop/api/wuapi/nf-wuapi-iupdatesession-createupdateinstaller
|
||||||
|
installerRaw, err := s.CallMethod("CreateUpdateInstaller")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error calling method CreateUpdateInstaller on IUpdateSession: %v", err)
|
||||||
|
}
|
||||||
|
installer := installerRaw.ToIDispatch()
|
||||||
|
defer installer.Release()
|
||||||
|
|
||||||
|
if _, err := installer.PutProperty("Updates", updates.IDispatch); err != nil {
|
||||||
|
return fmt.Errorf("error calling PutProperty Updates on IUpdateInstaller: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Look into using the async methods and attempt to track/log progress.
|
||||||
|
if _, err := installer.CallMethod("Install"); err != nil {
|
||||||
|
return fmt.Errorf("error calling method Install on IUpdateInstaller: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWUAUpdateCollection queries the Windows Update Agent API searcher with the provided query
|
||||||
|
// and returns a IUpdateCollection.
|
||||||
|
func (s *IUpdateSession) GetWUAUpdateCollection(query string) (*IUpdateCollection, error) {
|
||||||
|
// returns IUpdateSearcher
|
||||||
|
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa386515(v=vs.85).aspx
|
||||||
|
searcherRaw, err := s.CallMethod("CreateUpdateSearcher")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error calling CreateUpdateSearcher: %v", err)
|
||||||
|
}
|
||||||
|
searcher := searcherRaw.ToIDispatch()
|
||||||
|
defer searcher.Release()
|
||||||
|
|
||||||
|
// returns ISearchResult
|
||||||
|
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa386077(v=vs.85).aspx
|
||||||
|
resultRaw, err := searcher.CallMethod("Search", query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error calling method Search on IUpdateSearcher: %v", err)
|
||||||
|
}
|
||||||
|
result := resultRaw.ToIDispatch()
|
||||||
|
defer result.Release()
|
||||||
|
|
||||||
|
// returns IUpdateCollection
|
||||||
|
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa386107(v=vs.85).aspx
|
||||||
|
updtsRaw, err := result.GetProperty("Updates")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error calling GetProperty Updates on ISearchResult: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &IUpdateCollection{IDispatch: updtsRaw.ToIDispatch()}, nil
|
||||||
|
}
|
||||||
39
agent/patching/wua/wua_windows_test.go
Normal file
39
agent/patching/wua/wua_windows_test.go
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
package wua_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
wua "github.com/amidaware/rmmagent/agent/patching/wua"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWUAUpdates(t *testing.T) {
|
||||||
|
testTable := []struct {
|
||||||
|
name string
|
||||||
|
expected []wua.WUAPackage
|
||||||
|
atLeast int
|
||||||
|
expectedError error
|
||||||
|
query string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Get WUA Updates",
|
||||||
|
expected: []wua.WUAPackage{},
|
||||||
|
atLeast: 1,
|
||||||
|
expectedError: nil,
|
||||||
|
query: "IsInstalled=1 or IsInstalled=0 and Type='Software' and IsHidden=0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range testTable {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := wua.WUAUpdates(tt.query)
|
||||||
|
if len(result) < tt.atLeast {
|
||||||
|
t.Errorf("expected at least %d, got %d", tt.atLeast, len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !errors.Is(tt.expectedError, err) {
|
||||||
|
t.Errorf("expected (%v), got (%v)", tt.expectedError, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -5,7 +5,6 @@ import (
|
||||||
|
|
||||||
"github.com/amidaware/rmmagent/agent/utils"
|
"github.com/amidaware/rmmagent/agent/utils"
|
||||||
"github.com/gonutz/w32/v2"
|
"github.com/gonutz/w32/v2"
|
||||||
trmm "github.com/wh1te909/trmm-shared"
|
|
||||||
"golang.org/x/sys/windows/svc/mgr"
|
"golang.org/x/sys/windows/svc/mgr"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -51,14 +50,14 @@ func GetServiceStatus(name string) (string, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "n/a", err
|
return "n/a", err
|
||||||
}
|
}
|
||||||
defer conn.Disconnect()
|
|
||||||
|
|
||||||
|
defer conn.Disconnect()
|
||||||
srv, err := conn.OpenService(name)
|
srv, err := conn.OpenService(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "n/a", err
|
return "n/a", err
|
||||||
}
|
}
|
||||||
defer srv.Close()
|
|
||||||
|
|
||||||
|
defer srv.Close()
|
||||||
q, err := srv.Query()
|
q, err := srv.Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "n/a", err
|
return "n/a", err
|
||||||
|
|
@ -90,28 +89,28 @@ func serviceStatusText(num uint32) string {
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetServices returns a list of windows services
|
// GetServices returns a list of windows services
|
||||||
func GetServices() []trmm.WindowsService {
|
func GetServices() ([]Service, []error, error) {
|
||||||
ret := make([]trmm.WindowsService, 0)
|
ret := make([]Service, 0)
|
||||||
|
|
||||||
conn, err := mgr.Connect()
|
conn, err := mgr.Connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//a.Logger.Debugln(err)
|
return ret, nil, err
|
||||||
return ret
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defer conn.Disconnect()
|
defer conn.Disconnect()
|
||||||
|
|
||||||
svcs, err := conn.ListServices()
|
svcs, err := conn.ListServices()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//a.Logger.Debugln(err)
|
return ret, nil, err
|
||||||
return ret
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
errors := []error{}
|
||||||
|
|
||||||
for _, s := range svcs {
|
for _, s := range svcs {
|
||||||
srv, err := conn.OpenService(s)
|
srv, err := conn.OpenService(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err.Error() != "Access is denied." {
|
if err.Error() != "Access is denied." {
|
||||||
//a.Logger.Debugln("Open Service", s, err)
|
//a.Logger.Debugln("Open Service", s, err)
|
||||||
|
errors = append(errors, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
@ -120,17 +119,17 @@ func GetServices() []trmm.WindowsService {
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
q, err := srv.Query()
|
q, err := srv.Query()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//a.Logger.Debugln(err)
|
errors = append(errors, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
conf, err := srv.Config()
|
conf, err := srv.Config()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//a.Logger.Debugln(err)
|
errors = append(errors, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
ret = append(ret, trmm.WindowsService{
|
ret = append(ret, Service{
|
||||||
Name: s,
|
Name: s,
|
||||||
Status: serviceStatusText(uint32(q.State)),
|
Status: serviceStatusText(uint32(q.State)),
|
||||||
DisplayName: utils.CleanString(conf.DisplayName),
|
DisplayName: utils.CleanString(conf.DisplayName),
|
||||||
|
|
@ -142,7 +141,8 @@ func GetServices() []trmm.WindowsService {
|
||||||
DelayedAutoStart: conf.DelayedAutoStart,
|
DelayedAutoStart: conf.DelayedAutoStart,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return ret
|
|
||||||
|
return ret, errors, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://docs.microsoft.com/en-us/dotnet/api/system.serviceprocess.servicestartmode?view=dotnet-plat-ext-3.1
|
// https://docs.microsoft.com/en-us/dotnet/api/system.serviceprocess.servicestartmode?view=dotnet-plat-ext-3.1
|
||||||
|
|
|
||||||
79
agent/services/services_windows_test.go
Normal file
79
agent/services/services_windows_test.go
Normal file
|
|
@ -0,0 +1,79 @@
|
||||||
|
package services_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/amidaware/rmmagent/agent/services"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetServices(t *testing.T) {
|
||||||
|
testTable := []struct {
|
||||||
|
name string
|
||||||
|
expected []services.Service
|
||||||
|
atLeast int
|
||||||
|
expectedError error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Get Services",
|
||||||
|
expected: []services.Service{},
|
||||||
|
atLeast: 1,
|
||||||
|
expectedError: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range testTable {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, errs, err := services.GetServices()
|
||||||
|
if fmt.Sprintf("%T", result) != "[]services.Service" {
|
||||||
|
t.Errorf("expected type %T, got type %T", tt.expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(errs) > 0 {
|
||||||
|
t.Logf("Continue errors occured %v", errs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("expected error (%v), got error(%v)", tt.expectedError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) < tt.atLeast {
|
||||||
|
t.Errorf("expect at least %d, got %d", tt.atLeast, len(result))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetServiceStatus(t *testing.T) {
|
||||||
|
testTable := []struct {
|
||||||
|
name string
|
||||||
|
expected string
|
||||||
|
expectedError error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "CryptSvc",
|
||||||
|
expected: "running",
|
||||||
|
expectedError: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NonExistentService",
|
||||||
|
expected: "n/a",
|
||||||
|
expectedError: windows.ERROR_SERVICE_DOES_NOT_EXIST,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range testTable {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := services.GetServiceStatus(tt.name)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("expected %s, got %s", tt.expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !errors.Is(err, tt.expectedError) {
|
||||||
|
t.Errorf("expected (%v), got (%v)", tt.expectedError, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
13
agent/services/structs.go
Normal file
13
agent/services/structs.go
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
package services
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
DisplayName string `json:"display_name"`
|
||||||
|
BinPath string `json:"binpath"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
PID uint32 `json:"pid"`
|
||||||
|
StartType string `json:"start_type"`
|
||||||
|
DelayedAutoStart bool `json:"autodelay"`
|
||||||
|
}
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
trmm "github.com/wh1te909/trmm-shared"
|
trmm "github.com/wh1te909/trmm-shared"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetInstalledSoftware() []trmm.WinSoftwareList { return []trmm.WinSoftwareList{} }
|
func GetInstalledSoftware() []SoftwareList { return []WinSoftwareList{} }
|
||||||
|
|
||||||
func InstallChoco() {}
|
func InstallChoco() {}
|
||||||
|
|
||||||
|
|
|
||||||
7
agent/software/software_test.go
Normal file
7
agent/software/software_test.go
Normal file
|
|
@ -0,0 +1,7 @@
|
||||||
|
package software_test
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestGetInstalledSoftware(t *testing.T) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
@ -5,20 +5,18 @@ import (
|
||||||
|
|
||||||
"github.com/amidaware/rmmagent/agent/utils"
|
"github.com/amidaware/rmmagent/agent/utils"
|
||||||
wapi "github.com/iamacarpet/go-win64api"
|
wapi "github.com/iamacarpet/go-win64api"
|
||||||
trmm "github.com/wh1te909/trmm-shared"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetInstalledSoftware() []trmm.WinSoftwareList {
|
func GetInstalledSoftware() ([]WinSoftwareList, error) {
|
||||||
ret := make([]trmm.WinSoftwareList, 0)
|
ret := make([]WinSoftwareList, 0)
|
||||||
|
|
||||||
sw, err := installedSoftwareList()
|
sw, err := installedSoftwareList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ret
|
return ret, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, s := range sw {
|
for _, s := range sw {
|
||||||
t := s.InstallDate
|
t := s.InstallDate
|
||||||
ret = append(ret, trmm.WinSoftwareList{
|
ret = append(ret, WinSoftwareList{
|
||||||
Name: utils.CleanString(s.Name()),
|
Name: utils.CleanString(s.Name()),
|
||||||
Version: utils.CleanString(s.Version()),
|
Version: utils.CleanString(s.Version()),
|
||||||
Publisher: utils.CleanString(s.Publisher),
|
Publisher: utils.CleanString(s.Publisher),
|
||||||
|
|
@ -29,5 +27,6 @@ func GetInstalledSoftware() []trmm.WinSoftwareList {
|
||||||
Uninstall: utils.CleanString(s.UninstallString),
|
Uninstall: utils.CleanString(s.UninstallString),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return ret
|
|
||||||
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
@ -5,20 +5,18 @@ import (
|
||||||
|
|
||||||
"github.com/amidaware/rmmagent/agent/utils"
|
"github.com/amidaware/rmmagent/agent/utils"
|
||||||
wapi "github.com/iamacarpet/go-win64api"
|
wapi "github.com/iamacarpet/go-win64api"
|
||||||
trmm "github.com/wh1te909/trmm-shared"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetInstalledSoftware() []trmm.WinSoftwareList {
|
func GetInstalledSoftware() ([]SoftwareList, error) {
|
||||||
ret := make([]trmm.WinSoftwareList, 0)
|
ret := make([]SoftwareList, 0)
|
||||||
|
|
||||||
sw, err := wapi.InstalledSoftwareList()
|
sw, err := wapi.InstalledSoftwareList()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ret
|
return ret, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, s := range sw {
|
for _, s := range sw {
|
||||||
t := s.InstallDate
|
t := s.InstallDate
|
||||||
ret = append(ret, trmm.WinSoftwareList{
|
ret = append(ret, SoftwareList{
|
||||||
Name: utils.CleanString(s.Name()),
|
Name: utils.CleanString(s.Name()),
|
||||||
Version: utils.CleanString(s.Version()),
|
Version: utils.CleanString(s.Version()),
|
||||||
Publisher: utils.CleanString(s.Publisher),
|
Publisher: utils.CleanString(s.Publisher),
|
||||||
|
|
@ -29,5 +27,6 @@ func GetInstalledSoftware() []trmm.WinSoftwareList {
|
||||||
Uninstall: utils.CleanString(s.UninstallString),
|
Uninstall: utils.CleanString(s.UninstallString),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return ret
|
|
||||||
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
|
||||||
12
agent/software/structs.go
Normal file
12
agent/software/structs.go
Normal file
|
|
@ -0,0 +1,12 @@
|
||||||
|
package software
|
||||||
|
|
||||||
|
type SoftwareList struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
Publisher string `json:"publisher"`
|
||||||
|
InstallDate string `json:"install_date"`
|
||||||
|
Size string `json:"size"`
|
||||||
|
Source string `json:"source"`
|
||||||
|
Location string `json:"location"`
|
||||||
|
Uninstall string `json:"uninstall"`
|
||||||
|
}
|
||||||
1
agent/tactical/api/api.go
Normal file
1
agent/tactical/api/api.go
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
package api
|
||||||
1
go.mod
1
go.mod
|
|
@ -72,6 +72,7 @@ require (
|
||||||
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect
|
golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect
|
||||||
golang.org/x/text v0.3.7 // indirect
|
golang.org/x/text v0.3.7 // indirect
|
||||||
golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect
|
golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect
|
||||||
|
golang.org/x/tools v0.1.11
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
|
||||||
gopkg.in/ini.v1 v1.66.6 // indirect
|
gopkg.in/ini.v1 v1.66.6 // indirect
|
||||||
gopkg.in/toast.v1 v1.0.0-20180812000517-0a84660828b2 // indirect
|
gopkg.in/toast.v1 v1.0.0-20180812000517-0a84660828b2 // indirect
|
||||||
|
|
|
||||||
2
go.sum
2
go.sum
|
|
@ -530,6 +530,8 @@ golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4f
|
||||||
golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||||
golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||||
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
|
golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0=
|
||||||
|
golang.org/x/tools v0.1.11 h1:loJ25fNOEhSXfHrpoGj91eCUThwdNX6u24rO1xnNteY=
|
||||||
|
golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue