This commit is contained in:
redanthrax 2022-06-21 09:58:08 -07:00
parent 4e656c6556
commit 91c9de6e34
5 changed files with 155 additions and 42 deletions

View file

@ -56,18 +56,15 @@ func RunScript(code string, shell string, args []string, timeout int) (stdout, s
tmpfn, err := ioutil.TempFile(dir, ext) tmpfn, err := ioutil.TempFile(dir, ext)
if err != nil { if err != nil {
//a.Logger.Errorln(err)
return "", err.Error(), 85, err return "", err.Error(), 85, err
} }
defer os.Remove(tmpfn.Name()) defer os.Remove(tmpfn.Name())
if _, err := tmpfn.Write(content); err != nil { if _, err := tmpfn.Write(content); err != nil {
//a.Logger.Errorln(err)
return "", err.Error(), 85, err return "", err.Error(), 85, err
} }
if err := tmpfn.Close(); err != nil { if err := tmpfn.Close(); err != nil {
//a.Logger.Errorln(err)
return "", err.Error(), 85, err return "", err.Error(), 85, err
} }
@ -93,7 +90,6 @@ func RunScript(code string, shell string, args []string, timeout int) (stdout, s
cmd.Stdout = &outb cmd.Stdout = &outb
cmd.Stderr = &errb cmd.Stderr = &errb
if cmdErr := cmd.Start(); cmdErr != nil { if cmdErr := cmd.Start(); cmdErr != nil {
//a.Logger.Debugln(cmdErr)
return "", cmdErr.Error(), 65, cmdErr return "", cmdErr.Error(), 65, cmdErr
} }
@ -474,7 +470,7 @@ func OsString() string {
return osFullName return osFullName
} }
func AddDefenderExlusions() { func AddDefenderExlusions() error {
code := ` code := `
Add-MpPreference -ExclusionPath 'C:\Program Files\TacticalAgent\*' Add-MpPreference -ExclusionPath 'C:\Program Files\TacticalAgent\*'
Add-MpPreference -ExclusionPath 'C:\Windows\Temp\winagent-v*.exe' Add-MpPreference -ExclusionPath 'C:\Windows\Temp\winagent-v*.exe'
@ -483,6 +479,8 @@ Add-MpPreference -ExclusionPath 'C:\Program Files\Mesh Agent\*'
` `
_, _, _, err := RunScript(code, "powershell", []string{}, 20) _, _, _, err := RunScript(code, "powershell", []string{}, 20)
if err != nil { if err != nil {
//a.Logger.Debugln(err) return err
} }
return nil
} }

View file

@ -0,0 +1,67 @@
package system_test
import (
"errors"
"strings"
"testing"
"github.com/amidaware/rmmagent/agent/system"
)
func TestRunScript(t *testing.T) {
testTable := []struct {
name string
code string
shell string
args []string
timeout int
expectedStdout string
expectedStderr string
expectedExitCode int
expectedError error
}{
{
name: "Run Script",
code: "Test-Path -Path C:\\Windows",
shell: "powershell",
args: []string{},
timeout: 30,
expectedStdout: "True\r\n",
expectedStderr: "",
expectedExitCode: 0,
expectedError: nil,
},
{
name: "Run Error Script",
code: "Get-ThisError",
shell: "powershell",
args: []string{},
timeout: 30,
expectedStdout: "",
expectedStderr: "The term 'Get-ThisError' is not recognized as the name of a cmdlet",
expectedExitCode: 0,
expectedError: nil,
},
}
for _, tt := range testTable {
t.Run(tt.name, func(t *testing.T) {
stdout, stderr, exitcode, err := system.RunScript(tt.code, tt.shell, tt.args, tt.timeout)
if stdout != tt.expectedStdout {
t.Errorf("expected stdout %s, got %s", tt.expectedStdout, stdout)
}
if !strings.Contains(stderr, tt.expectedStderr) {
t.Errorf("expected stderr to contain %s, got %s", tt.expectedStderr, stderr)
}
if exitcode != tt.expectedExitCode {
t.Errorf("expected exitcode %d, got %d", tt.expectedExitCode, exitcode)
}
if !errors.Is(err, tt.expectedError) {
t.Errorf("expected error (%v), got (%v)", tt.expectedError, err)
}
})
}
}

View file

@ -1,31 +1,40 @@
package utils package utils_test
import ( import (
"testing" "testing"
"github.com/amidaware/rmmagent/agent/utils"
) )
func TestByteCountSI(t *testing.T) { func TestByteCountSI(t *testing.T) {
var bytes uint64 = 1048576 testTable := []struct {
mb := ByteCountSI(bytes) name string
if mb != "1.0 MB" { expected string
t.Errorf("Expected 1.0 MB, got %s", mb) bytes uint64
} }{
} {
name: "Bytes to Kilobytes",
func TestRemoveWinNewLines(t *testing.T) { expected: "1.0 kB",
result := RemoveWinNewLines("test\r\n") bytes: 1024,
if result != "test\n" { },
t.Fatalf("Expected testing\\n, got %s", result) {
name: "Bytes to Megabytes",
expected: "1.0 MB",
bytes: 1048576,
},
{
name: "Bytes to Gigabytes",
expected: "1.0 GB",
bytes: 1073741824,
},
} }
t.Logf("Result: %s", result) for _, tt := range testTable {
} t.Run(tt.name, func(t *testing.T) {
result := utils.ByteCountSI(tt.bytes)
func TestStripAll(t *testing.T) { if result != tt.expected {
result := StripAll(" test\r\n ") t.Errorf("expected %s, got %s", tt.expected, result)
if result != "test" { }
t.Fatalf("Expecte test, got %s", result) })
} }
t.Log("Test result expected")
} }

View file

@ -7,72 +7,73 @@ import (
"github.com/yusufpapurcu/wmi" "github.com/yusufpapurcu/wmi"
) )
func GetWMIInfo() map[string]interface{} { func GetWMIInfo() (map[string]interface{}, []error) {
wmiInfo := make(map[string]interface{}) wmiInfo := make(map[string]interface{})
errs := []error{}
compSysProd, err := GetWin32_ComputerSystemProduct() compSysProd, err := GetWin32_ComputerSystemProduct()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
compSys, err := GetWin32_ComputerSystem() compSys, err := GetWin32_ComputerSystem()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
netAdaptConfig, err := GetWin32_NetworkAdapterConfiguration() netAdaptConfig, err := GetWin32_NetworkAdapterConfiguration()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
physMem, err := GetWin32_PhysicalMemory() physMem, err := GetWin32_PhysicalMemory()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
winOS, err := GetWin32_OperatingSystem() winOS, err := GetWin32_OperatingSystem()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
baseBoard, err := GetWin32_BaseBoard() baseBoard, err := GetWin32_BaseBoard()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
bios, err := GetWin32_BIOS() bios, err := GetWin32_BIOS()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
disk, err := GetWin32_DiskDrive() disk, err := GetWin32_DiskDrive()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
netAdapt, err := GetWin32_NetworkAdapter() netAdapt, err := GetWin32_NetworkAdapter()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
desktopMon, err := GetWin32_DesktopMonitor() desktopMon, err := GetWin32_DesktopMonitor()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
cpu, err := GetWin32_Processor() cpu, err := GetWin32_Processor()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
usb, err := GetWin32_USBController() usb, err := GetWin32_USBController()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
graphics, err := GetWin32_VideoController() graphics, err := GetWin32_VideoController()
if err != nil { if err != nil {
//a.Logger.Debugln(err) errs = append(errs, err)
} }
wmiInfo["comp_sys_prod"] = compSysProd wmiInfo["comp_sys_prod"] = compSysProd
@ -89,7 +90,7 @@ func GetWMIInfo() map[string]interface{} {
wmiInfo["usb"] = usb wmiInfo["usb"] = usb
wmiInfo["graphics"] = graphics wmiInfo["graphics"] = graphics
return wmiInfo return wmiInfo, errs
} }
func GetWin32_ComputerSystemProduct() ([]interface{}, error) { func GetWin32_ComputerSystemProduct() ([]interface{}, error) {
@ -113,6 +114,7 @@ func GetWin32_ComputerSystemProduct() ([]interface{}, error) {
if err := json.Unmarshal(b, &un); err != nil { if err := json.Unmarshal(b, &un); err != nil {
return ret, err return ret, err
} }
tmp = append(tmp, un) tmp = append(tmp, un)
ret = append(ret, tmp) ret = append(ret, tmp)
} }

View file

@ -0,0 +1,37 @@
package wmi_test
import (
"reflect"
"testing"
"github.com/amidaware/rmmagent/agent/wmi"
)
func TestGetWMIInfo(t *testing.T) {
testTable := []struct {
name string
expected map[string]interface{}
atLeast int
expectedErrors []error
}{
{
name: "Get WMI Data",
expected: make(map[string]interface{}),
atLeast: 1,
expectedErrors: []error{},
},
}
for _, tt := range testTable {
t.Run(tt.name, func(t *testing.T) {
result, errs := wmi.GetWMIInfo()
if len(result) < tt.atLeast {
t.Errorf("expected at least %d, got %d", tt.atLeast, len(result))
}
if !reflect.DeepEqual(errs, tt.expectedErrors) {
t.Errorf("expected errors (%v), got (%v)", tt.expectedErrors, errs)
}
})
}
}