diff --git a/agent/disk/disk_linux_test.go b/agent/disk/disk_linux_test.go deleted file mode 100644 index a2d15cb..0000000 --- a/agent/disk/disk_linux_test.go +++ /dev/null @@ -1,14 +0,0 @@ -package disk - -import ( - "testing" -) - -func TestGetDisks(t *testing.T) { - disks := GetDisks() - if len(disks) == 0 { - t.Fatalf("Could not get disks on linux system.") - } - - t.Logf("Got %d disks on linux system", len(disks)) -} \ No newline at end of file diff --git a/agent/disk/disk_test.go b/agent/disk/disk_test.go new file mode 100644 index 0000000..feb37ce --- /dev/null +++ b/agent/disk/disk_test.go @@ -0,0 +1,51 @@ +package disk_test + +import ( + "errors" + "fmt" + "testing" + + "github.com/amidaware/rmmagent/agent/disk" +) + +func TestGetDisks(t *testing.T) { + exampleDisk := disk.Disk{ + Device: "C:", + Fstype: "NTFS", + Total: "149.9 GB", + Used: "129.2 GB", + Free: "20.7 GB", + Percent: 86, + } + + testTable := []struct { + name string + expected []disk.Disk + atLeast int + expectedError error + }{ + { + name: "Get Disks", + expected: []disk.Disk{exampleDisk}, + atLeast: 1, + expectedError: nil, + }, + } + + for _, tt := range testTable { + t.Run(tt.name, func(t *testing.T) { + result, err := disk.GetDisks() + if fmt.Sprintf("%T", result) != "[]disk.Disk" { + t.Errorf("expected type %T, got type %T", tt.expected, result) + } + + if !errors.Is(err, tt.expectedError) { + t.Errorf("expected error (%v), got error(%v)", tt.expectedError, err) + } + + if len(result) < 1 { + t.Errorf("expected count at least %d, got %d", tt.atLeast, len(result)) + } + }) + } +} diff --git a/agent/disk/disk_windows.go b/agent/disk/disk_windows.go index 5939a96..9d743a4 100644 --- a/agent/disk/disk_windows.go +++ b/agent/disk/disk_windows.go @@ -5,7 +5,6 @@ import ( "github.com/amidaware/rmmagent/agent/utils" "github.com/shirou/gopsutil/disk" - trmm "github.com/wh1te909/trmm-shared" "golang.org/x/sys/windows" ) @@ -14,12 +13,11 @@ var ( ) // GetDisks returns a list of fixed disks -func GetDisks() []trmm.Disk { - ret := make([]trmm.Disk, 0) +func GetDisks() ([]Disk, error) { + ret := make([]Disk, 0) partitions, err := disk.Partitions(false) if err != nil { - //a.Logger.Debugln(err) - return ret + return ret, err } for _, p := range partitions { @@ -36,7 +34,7 @@ func GetDisks() []trmm.Disk { continue } - d := trmm.Disk{ + d := Disk{ Device: p.Device, Fstype: p.Fstype, Total: utils.ByteCountSI(usage.Total), @@ -44,7 +42,9 @@ func GetDisks() []trmm.Disk { Free: utils.ByteCountSI(usage.Free), Percent: int(usage.UsedPercent), } + ret = append(ret, d) } - return ret -} \ No newline at end of file + + return ret, err +} diff --git a/agent/disk/structs.go b/agent/disk/structs.go new file mode 100644 index 0000000..38ec792 --- /dev/null +++ b/agent/disk/structs.go @@ -0,0 +1,10 @@ +package disk + +type Disk struct { + Device string `json:"device"` + Fstype string `json:"fstype"` + Total string `json:"total"` + Used string `json:"used"` + Free string `json:"free"` + Percent int `json:"percent"` +} \ No newline at end of file diff --git a/agent/events/events_linux.go b/agent/events/events_linux.go new file mode 100644 index 0000000..9174b02 --- /dev/null +++ b/agent/events/events_linux.go @@ -0,0 +1,5 @@ +package events + +func GetEventLog(logName string, searchLastDays int) ([]EventLogMsg, error) { + return []EventLogMsg{}, nil +} \ No newline at end of file diff --git a/agent/events/events_windows.go b/agent/events/events_windows.go index 5205d13..1abe338 100644 --- a/agent/events/events_windows.go +++ b/agent/events/events_windows.go @@ -6,48 +6,40 @@ import ( "unsafe" "github.com/amidaware/rmmagent/agent/syscall" - rmm "github.com/amidaware/rmmagent/shared" "github.com/gonutz/w32/v2" "golang.org/x/sys/windows" ) -func GetEventLog(logName string, searchLastDays int) []rmm.EventLogMsg { +func GetEventLog(logName string, searchLastDays int) ([]EventLogMsg, error) { var ( oldestLog uint32 nextSize uint32 readBytes uint32 ) + buf := []byte{0} size := uint32(1) - - ret := make([]rmm.EventLogMsg, 0) + ret := make([]EventLogMsg, 0) startTime := time.Now().Add(time.Duration(-(time.Duration(searchLastDays)) * (24 * time.Hour))) - h := w32.OpenEventLog("", logName) defer w32.CloseEventLog(h) - numRecords, _ := w32.GetNumberOfEventLogRecords(h) - syscall.GetOldestEventLogRecord(h, &oldestLog) - + err := syscall.GetOldestEventLogRecord(h, &oldestLog) startNum := numRecords + oldestLog - 1 uid := 0 for i := startNum; i >= oldestLog; i-- { flags := syscall.EVENTLOG_BACKWARDS_READ | syscall.EVENTLOG_SEEK_READ - err := syscall.ReadEventLog(h, flags, i, &buf[0], size, &readBytes, &nextSize) if err != nil { if err != windows.ERROR_INSUFFICIENT_BUFFER { - //a.Logger.Debugln(err) break } buf = make([]byte, nextSize) size = nextSize err = syscall.ReadEventLog(h, flags, i, &buf[0], size, &readBytes, &nextSize) if err != nil { - //a.Logger.Debugln(err) break } - } r := *(*syscall.EVENTLOGRECORD)(unsafe.Pointer(&buf[0])) @@ -75,10 +67,11 @@ func GetEventLog(logName string, searchLastDays int) []rmm.EventLogMsg { if r.NumStrings > 0 { argsptr = uintptr(unsafe.Pointer(&args[0])) } + message, _ := syscall.GetResourceMessage(logName, sourceName, r.EventID, argsptr) uid++ - eventLogMsg := rmm.EventLogMsg{ + eventLogMsg := EventLogMsg{ Source: sourceName, EventType: eventType, EventID: eventID, @@ -86,9 +79,11 @@ func GetEventLog(logName string, searchLastDays int) []rmm.EventLogMsg { Time: timeWritten.String(), UID: uid, } + ret = append(ret, eventLogMsg) } - return ret + + return ret, err } // https://github.com/mackerelio/go-check-plugins/blob/ad7910fdc45ccb892b5e5fda65ba0956c2b2885d/check-windows-eventlog/lib/check-windows-eventlog.go#L219 @@ -102,6 +97,7 @@ func bytesToString(b []byte) (string, uint32) { break } } + return string(utf16.Decode(s)), uint32(i * 2) } diff --git a/agent/events/events_windows_test.go b/agent/events/events_windows_test.go new file mode 100644 index 0000000..4f20e3d --- /dev/null +++ b/agent/events/events_windows_test.go @@ -0,0 +1,46 @@ +package events_test + +import ( + "errors" + "fmt" + "testing" + + "github.com/amidaware/rmmagent/agent/events" +) + +func TestGetEventLog(t *testing.T) { + testTable := []struct { + name string + expected []events.EventLogMsg + atLeast int + expectedError error + logname string + search int + }{ + { + name: "Get EventLog", + expected: []events.EventLogMsg{}, + atLeast: 1, + expectedError: nil, + logname: "Application", + search: 1, + }, + } + + for _, tt := range testTable { + t.Run(tt.name, func(t *testing.T) { + result, err := events.GetEventLog(tt.logname, tt.search) + if fmt.Sprintf("%T", result) != "[]events.EventLogMsg" { + t.Errorf("expected type %T, got type %T", []events.EventLogMsg{}, result) + } + + if !errors.Is(err, tt.expectedError) { + t.Errorf("expected error (%v), got error (%v)", tt.expectedError, err) + } + + if len(result) < 1 { + t.Errorf("expected count at least %d, got %d", tt.atLeast, len(result)) + } + }) + } +} diff --git a/agent/events/structs.go b/agent/events/structs.go new file mode 100644 index 0000000..e8b8387 --- /dev/null +++ b/agent/events/structs.go @@ -0,0 +1,10 @@ +package events + +type EventLogMsg struct { + Source string `json:"source"` + EventType string `json:"eventType"` + EventID uint32 `json:"eventID"` + Message string `json:"message"` + Time string `json:"time"` + UID int `json:"uid"` // for vue +} diff --git a/agent/patching/patching_linux.go b/agent/patching/patching_linux.go index 74a9f6d..fe90636 100644 --- a/agent/patching/patching_linux.go +++ b/agent/patching/patching_linux.go @@ -2,6 +2,6 @@ package patching func PatchMgmnt(enable bool) error { return nil } -func GetWinUpdates() {} +func GetUpdates() {} -func InstallUpdates(guids []string) {} \ No newline at end of file +func InstallUpdates(guids []string) {} diff --git a/agent/patching/patching_test.go b/agent/patching/patching_test.go new file mode 100644 index 0000000..e69cf2d --- /dev/null +++ b/agent/patching/patching_test.go @@ -0,0 +1,63 @@ +package patching_test + +import ( + "errors" + "testing" + + "github.com/amidaware/rmmagent/agent/patching" +) + +func TestPatchMgmnt(t *testing.T) { + testTable := []struct { + name string + expectedError error + status bool + }{ + { + name: "Enable Patch Mgmnt", + expectedError: nil, + status: true, + }, + { + name: "Disable Patch Mgmnt", + expectedError: nil, + status: false, + }, + } + + for _, tt := range testTable { + t.Run(tt.name, func(t *testing.T) { + err := patching.PatchMgmnt(tt.status) + if err != tt.expectedError { + t.Errorf("expected error (%v), got error (%v)", tt.expectedError, err) + } + }) + } +} + +func TestGetUpdates(t *testing.T) { + testTable := []struct { + name string + expectedError error + atLeast int + }{ + { + name: "Get Updates", + expectedError: nil, + atLeast: 1, + }, + } + + for _, tt := range testTable { + t.Run(tt.name, func(t *testing.T) { + result, err := patching.GetUpdates() + if !errors.Is(tt.expectedError, err) { + t.Errorf("expected (%v), got (%v)", tt.expectedError, err) + } + + if len(result) < tt.atLeast { + t.Errorf("expected at least %d, got %d", tt.atLeast, len(result)) + } + }) + } +} diff --git a/agent/patching/patching_windows.go b/agent/patching/patching_windows.go index b934191..d9885e3 100644 --- a/agent/patching/patching_windows.go +++ b/agent/patching/patching_windows.go @@ -1,6 +1,9 @@ package patching -import "golang.org/x/sys/windows/registry" +import ( + "github.com/amidaware/rmmagent/agent/patching/wua" + "golang.org/x/sys/windows/registry" +) // PatchMgmnt enables/disables automatic update // 0 - Enable Automatic Updates (Default) @@ -25,4 +28,34 @@ func PatchMgmnt(enable bool) error { } return nil -} \ No newline at end of file +} + +type PackageList []Package + +func GetUpdates() (PackageList, error) { + wuaupdates, err := wua.WUAUpdates("IsInstalled=1 or IsInstalled=0 and Type='Software' and IsHidden=0") + packages := []Package{} + for _, p := range wuaupdates { + packages = append(packages, Package(p)) + } + + return packages, err + // if err != nil { + // a.Logger.Errorln(err) + // return + // } + + // for _, update := range updates { + // a.Logger.Debugln("GUID:", update.UpdateID) + // a.Logger.Debugln("Downloaded:", update.Downloaded) + // a.Logger.Debugln("Installed:", update.Installed) + // a.Logger.Debugln("KB:", update.KBArticleIDs) + // a.Logger.Debugln("--------------------------------") + // } + + // payload := rmm.WinUpdateResult{AgentID: a.AgentID, Updates: updates} + // _, err = a.rClient.R().SetBody(payload).Post("/api/v3/winupdates/") + // if err != nil { + // a.Logger.Debugln(err) + // } +} diff --git a/agent/patching/structs.go b/agent/patching/structs.go new file mode 100644 index 0000000..dbc1a93 --- /dev/null +++ b/agent/patching/structs.go @@ -0,0 +1,16 @@ +package patching + +type Package struct { + Title string `json:"title"` + Description string `json:"description"` + Categories []string `json:"categories"` + CategoryIDs []string `json:"category_ids"` + KBArticleIDs []string `json:"kb_article_ids"` + MoreInfoURLs []string `json:"more_info_urls"` + SupportURL string `json:"support_url"` + UpdateID string `json:"guid"` + RevisionNumber int32 `json:"revision_number"` + Severity string `json:"severity"` + Installed bool `json:"installed"` + Downloaded bool `json:"downloaded"` +} diff --git a/agent/patching/wua/structs.go b/agent/patching/wua/structs.go new file mode 100644 index 0000000..85e7a61 --- /dev/null +++ b/agent/patching/wua/structs.go @@ -0,0 +1,33 @@ +package wua + +import ( + ole "github.com/go-ole/go-ole" +) + +type WUAPackage struct { + Title string `json:"title"` + Description string `json:"description"` + Categories []string `json:"categories"` + CategoryIDs []string `json:"category_ids"` + KBArticleIDs []string `json:"kb_article_ids"` + MoreInfoURLs []string `json:"more_info_urls"` + SupportURL string `json:"support_url"` + UpdateID string `json:"guid"` + RevisionNumber int32 `json:"revision_number"` + Severity string `json:"severity"` + Installed bool `json:"installed"` + Downloaded bool `json:"downloaded"` +} + +// IUpdateSession is a an IUpdateSession. +type IUpdateSession struct { + *ole.IDispatch +} + +type IUpdateCollection struct { + *ole.IDispatch +} + +type IUpdate struct { + *ole.IDispatch +} diff --git a/agent/patching/wua/wua_windows.go b/agent/patching/wua/wua_windows.go new file mode 100644 index 0000000..52f1936 --- /dev/null +++ b/agent/patching/wua/wua_windows.go @@ -0,0 +1,437 @@ +package wua + +import ( + "fmt" + ole "github.com/go-ole/go-ole" + "github.com/go-ole/go-ole/oleutil" + "sync" +) + +const ( + S_OK = 0 + S_FALSE = 1 +) + +var wuaSession sync.Mutex + +func (s *IUpdateSession) Close() { + if s.IDispatch != nil { + s.IDispatch.Release() + } + + ole.CoUninitialize() + wuaSession.Unlock() +} + +func NewUpdateSession() (*IUpdateSession, error) { + wuaSession.Lock() + if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil { + e, ok := err.(*ole.OleError) + // S_OK and S_FALSE are both are Success codes. + // https://docs.microsoft.com/en-us/windows/win32/learnwin32/error-handling-in-com + if !ok || (e.Code() != S_OK && e.Code() != S_FALSE) { + wuaSession.Unlock() + return nil, fmt.Errorf(`ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED): %v`, err) + } + } + + s := &IUpdateSession{} + + unknown, err := oleutil.CreateObject("Microsoft.Update.Session") + if err != nil { + s.Close() + return nil, fmt.Errorf(`oleutil.CreateObject("Microsoft.Update.Session"): %v`, err) + } + disp, err := unknown.QueryInterface(ole.IID_IDispatch) + if err != nil { + unknown.Release() + s.Close() + return nil, fmt.Errorf(`error creating Dispatch object from Microsoft.Update.Session connection: %v`, err) + } + s.IDispatch = disp + + return s, nil +} + +// InstallWUAUpdate install a WIndows update. +func (s *IUpdateSession) InstallWUAUpdate(updt *IUpdate) error { + _, err := updt.GetProperty("Title") + if err != nil { + return fmt.Errorf(`updt.GetProperty("Title"): %v`, err) + } + + updts, err := NewUpdateCollection() + if err != nil { + return err + } + defer updts.Release() + + eula, err := updt.GetProperty("EulaAccepted") + if err != nil { + return fmt.Errorf(`updt.GetProperty("EulaAccepted"): %v`, err) + } + // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-oaut/7b39eb24-9d39-498a-bcd8-75c38e5823d0 + if eula.Val == 0 { + if _, err := updt.CallMethod("AcceptEula"); err != nil { + return fmt.Errorf(`updt.CallMethod("AcceptEula"): %v`, err) + } + } + + if err := updts.Add(updt); err != nil { + return err + } + + if err := s.DownloadWUAUpdateCollection(updts); err != nil { + return fmt.Errorf("DownloadWUAUpdateCollection error: %v", err) + } + + if err := s.InstallWUAUpdateCollection(updts); err != nil { + return fmt.Errorf("InstallWUAUpdateCollection error: %v", err) + } + + return nil +} + +func NewUpdateCollection() (*IUpdateCollection, error) { + updateCollObj, err := oleutil.CreateObject("Microsoft.Update.UpdateColl") + if err != nil { + return nil, fmt.Errorf(`oleutil.CreateObject("Microsoft.Update.UpdateColl"): %v`, err) + } + defer updateCollObj.Release() + + updateColl, err := updateCollObj.IDispatch(ole.IID_IDispatch) + if err != nil { + return nil, err + } + + return &IUpdateCollection{IDispatch: updateColl}, nil +} + +func (c *IUpdateCollection) Add(updt *IUpdate) error { + if _, err := c.CallMethod("Add", updt.IDispatch); err != nil { + return fmt.Errorf(`IUpdateCollection.CallMethod("Add", updt): %v`, err) + } + return nil +} + +func (c *IUpdateCollection) RemoveAt(i int) error { + if _, err := c.CallMethod("RemoveAt", i); err != nil { + return fmt.Errorf(`IUpdateCollection.CallMethod("RemoveAt", %d): %v`, i, err) + } + return nil +} + +func (c *IUpdateCollection) Count() (int32, error) { + return GetCount(c.IDispatch) +} + +func (c *IUpdateCollection) Item(i int) (*IUpdate, error) { + updtRaw, err := c.GetProperty("Item", i) + if err != nil { + return nil, fmt.Errorf(`IUpdateCollection.GetProperty("Item", %d): %v`, i, err) + } + return &IUpdate{IDispatch: updtRaw.ToIDispatch()}, nil +} + +// GetCount returns the Count property. +func GetCount(dis *ole.IDispatch) (int32, error) { + countRaw, err := dis.GetProperty("Count") + if err != nil { + return 0, fmt.Errorf(`IDispatch.GetProperty("Count"): %v`, err) + } + count, _ := countRaw.Value().(int32) + + return count, nil +} + +func (u *IUpdate) kbaIDs() ([]string, error) { + kbArticleIDsRaw, err := u.GetProperty("KBArticleIDs") + if err != nil { + return nil, fmt.Errorf(`IUpdate.GetProperty("KBArticleIDs"): %v`, err) + } + kbArticleIDs := kbArticleIDsRaw.ToIDispatch() + defer kbArticleIDs.Release() + + count, err := GetCount(kbArticleIDs) + if err != nil { + return nil, err + } + + if count == 0 { + return nil, nil + } + + var ss []string + for i := 0; i < int(count); i++ { + item, err := kbArticleIDs.GetProperty("Item", i) + if err != nil { + return nil, fmt.Errorf(`kbArticleIDs.GetProperty("Item", %d): %v`, i, err) + } + + ss = append(ss, item.ToString()) + } + return ss, nil +} + +func (u *IUpdate) categories() ([]string, []string, error) { + catRaw, err := u.GetProperty("Categories") + if err != nil { + return nil, nil, fmt.Errorf(`IUpdate.GetProperty("Categories"): %v`, err) + } + cat := catRaw.ToIDispatch() + defer cat.Release() + + count, err := GetCount(cat) + if err != nil { + return nil, nil, err + } + if count == 0 { + return nil, nil, nil + } + + var cns, cids []string + for i := 0; i < int(count); i++ { + itemRaw, err := cat.GetProperty("Item", i) + if err != nil { + return nil, nil, fmt.Errorf(`cat.GetProperty("Item", %d): %v`, i, err) + } + item := itemRaw.ToIDispatch() + defer item.Release() + + name, err := item.GetProperty("Name") + if err != nil { + return nil, nil, fmt.Errorf(`item.GetProperty("Name"): %v`, err) + } + + categoryID, err := item.GetProperty("CategoryID") + if err != nil { + return nil, nil, fmt.Errorf(`item.GetProperty("CategoryID"): %v`, err) + } + + cns = append(cns, name.ToString()) + cids = append(cids, categoryID.ToString()) + } + return cns, cids, nil +} + +func (u *IUpdate) moreInfoURLs() ([]string, error) { + moreInfoURLsRaw, err := u.GetProperty("MoreInfoURLs") + if err != nil { + return nil, fmt.Errorf(`IUpdate.GetProperty("MoreInfoURLs"): %v`, err) + } + moreInfoURLs := moreInfoURLsRaw.ToIDispatch() + defer moreInfoURLs.Release() + + count, err := GetCount(moreInfoURLs) + if err != nil { + return nil, err + } + + if count == 0 { + return nil, nil + } + + var ss []string + for i := 0; i < int(count); i++ { + item, err := moreInfoURLs.GetProperty("Item", i) + if err != nil { + return nil, fmt.Errorf(`moreInfoURLs.GetProperty("Item", %d): %v`, i, err) + } + + ss = append(ss, item.ToString()) + } + return ss, nil +} + +func (c *IUpdateCollection) extractPkg(item int) (*WUAPackage, error) { + updt, err := c.Item(item) + if err != nil { + return nil, err + } + defer updt.Release() + + title, err := updt.GetProperty("Title") + if err != nil { + return nil, fmt.Errorf(`updt.GetProperty("Title"): %v`, err) + } + + description, err := updt.GetProperty("Description") + if err != nil { + return nil, fmt.Errorf(`updt.GetProperty("Description"): %v`, err) + } + + kbArticleIDs, err := updt.kbaIDs() + if err != nil { + return nil, err + } + + categories, categoryIDs, err := updt.categories() + if err != nil { + return nil, err + } + + moreInfoURLs, err := updt.moreInfoURLs() + if err != nil { + return nil, err + } + + supportURL, err := updt.GetProperty("SupportURL") + if err != nil { + return nil, fmt.Errorf(`updt.GetProperty("SupportURL"): %v`, err) + } + + identityRaw, err := updt.GetProperty("Identity") + if err != nil { + return nil, fmt.Errorf(`updt.GetProperty("Identity"): %v`, err) + } + identity := identityRaw.ToIDispatch() + defer identity.Release() + + revisionNumber, err := identity.GetProperty("RevisionNumber") + if err != nil { + return nil, fmt.Errorf(`identity.GetProperty("RevisionNumber"): %v`, err) + } + + updateID, err := identity.GetProperty("UpdateID") + if err != nil { + return nil, fmt.Errorf(`identity.GetProperty("UpdateID"): %v`, err) + } + + severity, err := updt.GetProperty("MsrcSeverity") + if err != nil { + return nil, fmt.Errorf(`updt.GetProperty("MsrcSeverity"): %v`, err) + } + + isInstalled, err := updt.GetProperty("IsInstalled") + if err != nil { + return nil, fmt.Errorf(`updt.GetProperty("IsInstalled"): %v`, err) + } + + isDownloaded, err := updt.GetProperty("IsDownloaded") + if err != nil { + return nil, fmt.Errorf(`updt.GetProperty("IsDownloaded"): %v`, err) + } + + return &WUAPackage{ + Title: title.ToString(), + Description: description.ToString(), + SupportURL: supportURL.ToString(), + KBArticleIDs: kbArticleIDs, + UpdateID: updateID.ToString(), + Categories: categories, + CategoryIDs: categoryIDs, + MoreInfoURLs: moreInfoURLs, + Severity: severity.ToString(), + RevisionNumber: int32(revisionNumber.Val), + Downloaded: isDownloaded.Value().(bool), + Installed: isInstalled.Value().(bool), + }, nil +} + +// WUAUpdates queries the Windows Update Agent API searcher with the provided query. +func WUAUpdates(query string) ([]WUAPackage, error) { + session, err := NewUpdateSession() + if err != nil { + return nil, fmt.Errorf("error creating NewUpdateSession: %v", err) + } + defer session.Close() + + updts, err := session.GetWUAUpdateCollection(query) + if err != nil { + return nil, fmt.Errorf("error calling GetWUAUpdateCollection with query %q: %v", query, err) + } + defer updts.Release() + + updtCnt, err := updts.Count() + if err != nil { + return nil, err + } + + if updtCnt == 0 { + return nil, nil + } + + var packages []WUAPackage + for i := 0; i < int(updtCnt); i++ { + pkg, err := updts.extractPkg(i) + if err != nil { + return nil, err + } + packages = append(packages, *pkg) + } + return packages, nil +} + +// DownloadWUAUpdateCollection downloads all updates in a IUpdateCollection +func (s *IUpdateSession) DownloadWUAUpdateCollection(updates *IUpdateCollection) error { + // returns IUpdateDownloader + // https://docs.microsoft.com/en-us/windows/desktop/api/wuapi/nn-wuapi-iupdatedownloader + downloaderRaw, err := s.CallMethod("CreateUpdateDownloader") + if err != nil { + return fmt.Errorf("error calling method CreateUpdateDownloader on IUpdateSession: %v", err) + } + downloader := downloaderRaw.ToIDispatch() + defer downloader.Release() + + if _, err := downloader.PutProperty("Updates", updates.IDispatch); err != nil { + return fmt.Errorf("error calling PutProperty Updates on IUpdateDownloader: %v", err) + } + + if _, err := downloader.CallMethod("Download"); err != nil { + return fmt.Errorf("error calling method Download on IUpdateDownloader: %v", err) + } + return nil +} + +// InstallWUAUpdateCollection installs all updates in a IUpdateCollection +func (s *IUpdateSession) InstallWUAUpdateCollection(updates *IUpdateCollection) error { + // returns IUpdateInstallersession *ole.IDispatch, + // https://docs.microsoft.com/en-us/windows/desktop/api/wuapi/nf-wuapi-iupdatesession-createupdateinstaller + installerRaw, err := s.CallMethod("CreateUpdateInstaller") + if err != nil { + return fmt.Errorf("error calling method CreateUpdateInstaller on IUpdateSession: %v", err) + } + installer := installerRaw.ToIDispatch() + defer installer.Release() + + if _, err := installer.PutProperty("Updates", updates.IDispatch); err != nil { + return fmt.Errorf("error calling PutProperty Updates on IUpdateInstaller: %v", err) + } + + // TODO: Look into using the async methods and attempt to track/log progress. + if _, err := installer.CallMethod("Install"); err != nil { + return fmt.Errorf("error calling method Install on IUpdateInstaller: %v", err) + } + return nil +} + +// GetWUAUpdateCollection queries the Windows Update Agent API searcher with the provided query +// and returns a IUpdateCollection. +func (s *IUpdateSession) GetWUAUpdateCollection(query string) (*IUpdateCollection, error) { + // returns IUpdateSearcher + // https://msdn.microsoft.com/en-us/library/windows/desktop/aa386515(v=vs.85).aspx + searcherRaw, err := s.CallMethod("CreateUpdateSearcher") + if err != nil { + return nil, fmt.Errorf("error calling CreateUpdateSearcher: %v", err) + } + searcher := searcherRaw.ToIDispatch() + defer searcher.Release() + + // returns ISearchResult + // https://msdn.microsoft.com/en-us/library/windows/desktop/aa386077(v=vs.85).aspx + resultRaw, err := searcher.CallMethod("Search", query) + if err != nil { + return nil, fmt.Errorf("error calling method Search on IUpdateSearcher: %v", err) + } + result := resultRaw.ToIDispatch() + defer result.Release() + + // returns IUpdateCollection + // https://msdn.microsoft.com/en-us/library/windows/desktop/aa386107(v=vs.85).aspx + updtsRaw, err := result.GetProperty("Updates") + if err != nil { + return nil, fmt.Errorf("error calling GetProperty Updates on ISearchResult: %v", err) + } + + return &IUpdateCollection{IDispatch: updtsRaw.ToIDispatch()}, nil +} diff --git a/agent/patching/wua/wua_windows_test.go b/agent/patching/wua/wua_windows_test.go new file mode 100644 index 0000000..10aad54 --- /dev/null +++ b/agent/patching/wua/wua_windows_test.go @@ -0,0 +1,39 @@ +package wua_test + +import ( + "errors" + "testing" + + wua "github.com/amidaware/rmmagent/agent/patching/wua" +) + +func TestWUAUpdates(t *testing.T) { + testTable := []struct { + name string + expected []wua.WUAPackage + atLeast int + expectedError error + query string + }{ + { + name: "Get WUA Updates", + expected: []wua.WUAPackage{}, + atLeast: 1, + expectedError: nil, + query: "IsInstalled=1 or IsInstalled=0 and Type='Software' and IsHidden=0", + }, + } + + for _, tt := range testTable { + t.Run(tt.name, func(t *testing.T) { + result, err := wua.WUAUpdates(tt.query) + if len(result) < tt.atLeast { + t.Errorf("expected at least %d, got %d", tt.atLeast, len(result)) + } + + if !errors.Is(tt.expectedError, err) { + t.Errorf("expected (%v), got (%v)", tt.expectedError, err) + } + }) + } +} diff --git a/agent/services/services_windows.go b/agent/services/services_windows.go index 0c7bbe8..0ae8a70 100644 --- a/agent/services/services_windows.go +++ b/agent/services/services_windows.go @@ -5,7 +5,6 @@ import ( "github.com/amidaware/rmmagent/agent/utils" "github.com/gonutz/w32/v2" - trmm "github.com/wh1te909/trmm-shared" "golang.org/x/sys/windows/svc/mgr" ) @@ -51,14 +50,14 @@ func GetServiceStatus(name string) (string, error) { if err != nil { return "n/a", err } - defer conn.Disconnect() + defer conn.Disconnect() srv, err := conn.OpenService(name) if err != nil { return "n/a", err } - defer srv.Close() + defer srv.Close() q, err := srv.Query() if err != nil { return "n/a", err @@ -90,28 +89,28 @@ func serviceStatusText(num uint32) string { } // GetServices returns a list of windows services -func GetServices() []trmm.WindowsService { - ret := make([]trmm.WindowsService, 0) +func GetServices() ([]Service, []error, error) { + ret := make([]Service, 0) conn, err := mgr.Connect() if err != nil { - //a.Logger.Debugln(err) - return ret + return ret, nil, err } + defer conn.Disconnect() - svcs, err := conn.ListServices() - if err != nil { - //a.Logger.Debugln(err) - return ret + return ret, nil, err } + errors := []error{} + for _, s := range svcs { srv, err := conn.OpenService(s) if err != nil { if err.Error() != "Access is denied." { //a.Logger.Debugln("Open Service", s, err) + errors = append(errors, err) } continue @@ -120,17 +119,17 @@ func GetServices() []trmm.WindowsService { defer srv.Close() q, err := srv.Query() if err != nil { - //a.Logger.Debugln(err) + errors = append(errors, err) continue } conf, err := srv.Config() if err != nil { - //a.Logger.Debugln(err) + errors = append(errors, err) continue } - ret = append(ret, trmm.WindowsService{ + ret = append(ret, Service{ Name: s, Status: serviceStatusText(uint32(q.State)), DisplayName: utils.CleanString(conf.DisplayName), @@ -142,7 +141,8 @@ func GetServices() []trmm.WindowsService { DelayedAutoStart: conf.DelayedAutoStart, }) } - return ret + + return ret, errors, nil } // https://docs.microsoft.com/en-us/dotnet/api/system.serviceprocess.servicestartmode?view=dotnet-plat-ext-3.1 diff --git a/agent/services/services_windows_test.go b/agent/services/services_windows_test.go new file mode 100644 index 0000000..06f852c --- /dev/null +++ b/agent/services/services_windows_test.go @@ -0,0 +1,79 @@ +package services_test + +import ( + "errors" + "fmt" + "testing" + + "github.com/amidaware/rmmagent/agent/services" + "golang.org/x/sys/windows" +) + +func TestGetServices(t *testing.T) { + testTable := []struct { + name string + expected []services.Service + atLeast int + expectedError error + }{ + { + name: "Get Services", + expected: []services.Service{}, + atLeast: 1, + expectedError: nil, + }, + } + + for _, tt := range testTable { + t.Run(tt.name, func(t *testing.T) { + result, errs, err := services.GetServices() + if fmt.Sprintf("%T", result) != "[]services.Service" { + t.Errorf("expected type %T, got type %T", tt.expected, result) + } + + if len(errs) > 0 { + t.Logf("Continue errors occured %v", errs) + } + + if err != nil { + t.Errorf("expected error (%v), got error(%v)", tt.expectedError, err) + } + + if len(result) < tt.atLeast { + t.Errorf("expect at least %d, got %d", tt.atLeast, len(result)) + } + }) + } +} + +func TestGetServiceStatus(t *testing.T) { + testTable := []struct { + name string + expected string + expectedError error + }{ + { + name: "CryptSvc", + expected: "running", + expectedError: nil, + }, + { + name: "NonExistentService", + expected: "n/a", + expectedError: windows.ERROR_SERVICE_DOES_NOT_EXIST, + }, + } + + for _, tt := range testTable { + t.Run(tt.name, func(t *testing.T) { + result, err := services.GetServiceStatus(tt.name) + if result != tt.expected { + t.Errorf("expected %s, got %s", tt.expected, result) + } + + if !errors.Is(err, tt.expectedError) { + t.Errorf("expected (%v), got (%v)", tt.expectedError, err) + } + }) + } +} diff --git a/agent/services/structs.go b/agent/services/structs.go new file mode 100644 index 0000000..db758c9 --- /dev/null +++ b/agent/services/structs.go @@ -0,0 +1,13 @@ +package services + +type Service struct { + Name string `json:"name"` + Status string `json:"status"` + DisplayName string `json:"display_name"` + BinPath string `json:"binpath"` + Description string `json:"description"` + Username string `json:"username"` + PID uint32 `json:"pid"` + StartType string `json:"start_type"` + DelayedAutoStart bool `json:"autodelay"` +} diff --git a/agent/software/software_linux.go b/agent/software/software_linux.go index 8ccec96..cf3af2e 100644 --- a/agent/software/software_linux.go +++ b/agent/software/software_linux.go @@ -4,7 +4,7 @@ import ( trmm "github.com/wh1te909/trmm-shared" ) -func GetInstalledSoftware() []trmm.WinSoftwareList { return []trmm.WinSoftwareList{} } +func GetInstalledSoftware() []SoftwareList { return []WinSoftwareList{} } func InstallChoco() {} diff --git a/agent/software/software_test.go b/agent/software/software_test.go new file mode 100644 index 0000000..9076ff7 --- /dev/null +++ b/agent/software/software_test.go @@ -0,0 +1,7 @@ +package software_test + +import "testing" + +func TestGetInstalledSoftware(t *testing.T) { + +} diff --git a/agent/software/software_windows_386.go b/agent/software/software_windows_386.go index b08c94d..5d7bd7a 100644 --- a/agent/software/software_windows_386.go +++ b/agent/software/software_windows_386.go @@ -5,20 +5,18 @@ import ( "github.com/amidaware/rmmagent/agent/utils" wapi "github.com/iamacarpet/go-win64api" - trmm "github.com/wh1te909/trmm-shared" ) -func GetInstalledSoftware() []trmm.WinSoftwareList { - ret := make([]trmm.WinSoftwareList, 0) - +func GetInstalledSoftware() ([]WinSoftwareList, error) { + ret := make([]WinSoftwareList, 0) sw, err := installedSoftwareList() if err != nil { - return ret + return ret, err } for _, s := range sw { t := s.InstallDate - ret = append(ret, trmm.WinSoftwareList{ + ret = append(ret, WinSoftwareList{ Name: utils.CleanString(s.Name()), Version: utils.CleanString(s.Version()), Publisher: utils.CleanString(s.Publisher), @@ -29,5 +27,6 @@ func GetInstalledSoftware() []trmm.WinSoftwareList { Uninstall: utils.CleanString(s.UninstallString), }) } - return ret + + return ret, nil } \ No newline at end of file diff --git a/agent/software/software_windows_amd64.go b/agent/software/software_windows_amd64.go index c1caf8f..f1f0d7f 100644 --- a/agent/software/software_windows_amd64.go +++ b/agent/software/software_windows_amd64.go @@ -5,20 +5,18 @@ import ( "github.com/amidaware/rmmagent/agent/utils" wapi "github.com/iamacarpet/go-win64api" - trmm "github.com/wh1te909/trmm-shared" ) -func GetInstalledSoftware() []trmm.WinSoftwareList { - ret := make([]trmm.WinSoftwareList, 0) - +func GetInstalledSoftware() ([]SoftwareList, error) { + ret := make([]SoftwareList, 0) sw, err := wapi.InstalledSoftwareList() if err != nil { - return ret + return ret, err } for _, s := range sw { t := s.InstallDate - ret = append(ret, trmm.WinSoftwareList{ + ret = append(ret, SoftwareList{ Name: utils.CleanString(s.Name()), Version: utils.CleanString(s.Version()), Publisher: utils.CleanString(s.Publisher), @@ -29,5 +27,6 @@ func GetInstalledSoftware() []trmm.WinSoftwareList { Uninstall: utils.CleanString(s.UninstallString), }) } - return ret + + return ret, nil } diff --git a/agent/software/structs.go b/agent/software/structs.go new file mode 100644 index 0000000..d81a534 --- /dev/null +++ b/agent/software/structs.go @@ -0,0 +1,12 @@ +package software + +type SoftwareList struct { + Name string `json:"name"` + Version string `json:"version"` + Publisher string `json:"publisher"` + InstallDate string `json:"install_date"` + Size string `json:"size"` + Source string `json:"source"` + Location string `json:"location"` + Uninstall string `json:"uninstall"` +} diff --git a/agent/tactical/api/api.go b/agent/tactical/api/api.go new file mode 100644 index 0000000..778f64e --- /dev/null +++ b/agent/tactical/api/api.go @@ -0,0 +1 @@ +package api diff --git a/go.mod b/go.mod index bb36f28..fd8034c 100644 --- a/go.mod +++ b/go.mod @@ -72,6 +72,7 @@ require ( golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f // indirect golang.org/x/text v0.3.7 // indirect golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect + golang.org/x/tools v0.1.11 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/ini.v1 v1.66.6 // indirect gopkg.in/toast.v1 v1.0.0-20180812000517-0a84660828b2 // indirect diff --git a/go.sum b/go.sum index 032ce40..f084d93 100644 --- a/go.sum +++ b/go.sum @@ -530,6 +530,8 @@ golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.0.0-20210105154028-b0ab187a4818/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20210108195828-e2f9c7f1fc8e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= +golang.org/x/tools v0.1.11 h1:loJ25fNOEhSXfHrpoGj91eCUThwdNX6u24rO1xnNteY= +golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=