organizing and refactoring

This commit is contained in:
redanthrax 2022-06-16 17:04:01 -07:00
parent 13b5474cd8
commit 6f159d4728
20 changed files with 832 additions and 488 deletions

View file

@ -14,8 +14,13 @@ go test -vet=off
Add to settings.json
```
"gopls": {
"build.buildFlags": [
"-tags=DEBUG"
]
},
"go.testFlags": [
"-vet=off"
],
"go.testTags": "TEST"
"go.testTags": "TEST",
```

View file

@ -29,7 +29,6 @@ import (
rmm "github.com/amidaware/rmmagent/shared"
ps "github.com/elastic/go-sysinfo"
gocmd "github.com/go-cmd/cmd"
"github.com/go-resty/resty/v2"
"github.com/kardianos/service"
nats "github.com/nats-io/nats.go"
@ -38,36 +37,6 @@ import (
trmm "github.com/wh1te909/trmm-shared"
)
// Agent struct
type Agent struct {
Hostname string
Arch string
AgentID string
BaseURL string
ApiURL string
Token string
AgentPK int
Cert string
ProgramDir string
EXE string
SystemDrive string
MeshInstaller string
MeshSystemBin string
MeshSVC string
PyBin string
Headers map[string]string
Logger *logrus.Logger
Version string
Debug bool
rClient *resty.Client
Proxy string
LogTo string
LogFile *os.File
Platform string
GoArch string
ServiceConfig *service.Config
}
const (
progFilesName = "TacticalAgent"
winExeName = "tacticalrmm.exe"
@ -167,114 +136,6 @@ func New(logger *logrus.Logger, version string) *Agent {
}
}
type CmdStatus struct {
Status gocmd.Status
Stdout string
Stderr string
}
type CmdOptions struct {
Shell string
Command string
Args []string
Timeout time.Duration
IsScript bool
IsExecutable bool
Detached bool
}
func (a *Agent) NewCMDOpts() *CmdOptions {
return &CmdOptions{
Shell: "/bin/bash",
Timeout: 30,
}
}
func (a *Agent) CmdV2(c *CmdOptions) CmdStatus {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout*time.Second)
defer cancel()
// Disable output buffering, enable streaming
cmdOptions := gocmd.Options{
Buffered: false,
Streaming: true,
}
// have a child process that is in a different process group so that
// parent terminating doesn't kill child
if c.Detached {
cmdOptions.BeforeExec = []func(cmd *exec.Cmd){
func(cmd *exec.Cmd) {
cmd.SysProcAttr = SetDetached()
},
}
}
var envCmd *gocmd.Cmd
if c.IsScript {
envCmd = gocmd.NewCmdOptions(cmdOptions, c.Shell, c.Args...) // call script directly
} else if c.IsExecutable {
envCmd = gocmd.NewCmdOptions(cmdOptions, c.Shell, c.Command) // c.Shell: bin + c.Command: args as one string
} else {
envCmd = gocmd.NewCmdOptions(cmdOptions, c.Shell, "-c", c.Command) // /bin/bash -c 'ls -l /var/log/...'
}
var stdoutBuf bytes.Buffer
var stderrBuf bytes.Buffer
// Print STDOUT and STDERR lines streaming from Cmd
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
// Done when both channels have been closed
// https://dave.cheney.net/2013/04/30/curious-channels
for envCmd.Stdout != nil || envCmd.Stderr != nil {
select {
case line, open := <-envCmd.Stdout:
if !open {
envCmd.Stdout = nil
continue
}
fmt.Fprintln(&stdoutBuf, line)
a.Logger.Debugln(line)
case line, open := <-envCmd.Stderr:
if !open {
envCmd.Stderr = nil
continue
}
fmt.Fprintln(&stderrBuf, line)
a.Logger.Debugln(line)
}
}
}()
// Run and wait for Cmd to return, discard Status
envCmd.Start()
go func() {
select {
case <-doneChan:
return
case <-ctx.Done():
a.Logger.Debugf("Command timed out after %d seconds\n", c.Timeout)
pid := envCmd.Status().PID
a.Logger.Debugln("Killing process with PID", pid)
KillProc(int32(pid))
}
}()
// Wait for goroutine to print everything
<-doneChan
ret := CmdStatus{
Status: envCmd.Status(),
Stdout: CleanString(stdoutBuf.String()),
Stderr: CleanString(stderrBuf.String()),
}
a.Logger.Debugf("%+v\n", ret)
return ret
}
func (a *Agent) GetCPULoadAvg() int {
fallback := false
pyCode := `
@ -326,7 +187,7 @@ func (a *Agent) ForceKillMesh() {
for _, pid := range pids {
a.Logger.Debugln("Killing mesh process with pid %d", pid)
if err := KillProc(int32(pid)); err != nil {
if err := utils.KillProc(int32(pid)); err != nil {
a.Logger.Debugln(err)
}
}
@ -468,3 +329,7 @@ func (a *Agent) CreateTRMMTempDir() {
}
}
}
func (a *Agent) GetDisks() []trmm.Disk {
return disk.GetDisks()
}

View file

@ -18,7 +18,6 @@ import (
"runtime"
"strconv"
"strings"
"syscall"
"time"
rmm "github.com/amidaware/rmmagent/shared"
@ -27,297 +26,9 @@ import (
"github.com/jaypipes/ghw"
"github.com/kardianos/service"
"github.com/shirou/gopsutil/v3/cpu"
"github.com/shirou/gopsutil/v3/disk"
psHost "github.com/shirou/gopsutil/v3/host"
"github.com/spf13/viper"
trmm "github.com/wh1te909/trmm-shared"
)
func ShowStatus(version string) {
fmt.Println(version)
}
func (a *Agent) GetDisks() []trmm.Disk {
ret := make([]trmm.Disk, 0)
partitions, err := disk.Partitions(false)
if err != nil {
a.Logger.Debugln(err)
return ret
}
for _, p := range partitions {
if strings.Contains(p.Device, "dev/loop") {
continue
}
usage, err := disk.Usage(p.Mountpoint)
if err != nil {
a.Logger.Debugln(err)
continue
}
d := trmm.Disk{
Device: p.Device,
Fstype: p.Fstype,
Total: ByteCountSI(usage.Total),
Used: ByteCountSI(usage.Used),
Free: ByteCountSI(usage.Free),
Percent: int(usage.UsedPercent),
}
ret = append(ret, d)
}
return ret
}
func (a *Agent) SystemRebootRequired() (bool, error) {
// deb
paths := [2]string{"/var/run/reboot-required", "/run/reboot-required"}
for _, p := range paths {
if trmm.FileExists(p) {
return true, nil
}
}
// rhel
bins := [2]string{"/usr/bin/needs-restarting", "/bin/needs-restarting"}
for _, bin := range bins {
if trmm.FileExists(bin) {
opts := a.NewCMDOpts()
// https://man7.org/linux/man-pages/man1/needs-restarting.1.html
// -r Only report whether a full reboot is required (exit code 1) or not (exit code 0).
opts.Command = fmt.Sprintf("%s -r", bin)
out := a.CmdV2(opts)
if out.Status.Error != nil {
a.Logger.Debugln("SystemRebootRequired(): ", out.Status.Error.Error())
continue
}
if out.Status.Exit == 1 {
return true, nil
}
return false, nil
}
}
return false, nil
}
func (a *Agent) LoggedOnUser() string {
var ret string
users, err := psHost.Users()
if err != nil {
return ret
}
// return the first logged in user
for _, user := range users {
if user.User != "" {
ret = user.User
break
}
}
return ret
}
func (a *Agent) osString() string {
h, err := psHost.Info()
if err != nil {
return "error getting host info"
}
return fmt.Sprintf("%s %s %s %s", strings.Title(h.Platform), h.PlatformVersion, h.KernelArch, h.KernelVersion)
}
func NewAgentConfig() *rmm.AgentConfig {
viper.SetConfigName("tacticalagent")
viper.SetConfigType("json")
viper.AddConfigPath("/etc/")
viper.AddConfigPath(".")
err := viper.ReadInConfig()
if err != nil {
return &rmm.AgentConfig{}
}
agentpk := viper.GetString("agentpk")
pk, _ := strconv.Atoi(agentpk)
ret := &rmm.AgentConfig{
BaseURL: viper.GetString("baseurl"),
AgentID: viper.GetString("agentid"),
APIURL: viper.GetString("apiurl"),
Token: viper.GetString("token"),
AgentPK: agentpk,
PK: pk,
Cert: viper.GetString("cert"),
Proxy: viper.GetString("proxy"),
CustomMeshDir: viper.GetString("meshdir"),
}
return ret
}
func (a *Agent) RunScript(code string, shell string, args []string, timeout int) (stdout, stderr string, exitcode int, e error) {
code = removeWinNewLines(code)
content := []byte(code)
f, err := createTmpFile()
if err != nil {
a.Logger.Errorln("RunScript createTmpFile()", err)
return "", err.Error(), 85, err
}
defer os.Remove(f.Name())
if _, err := f.Write(content); err != nil {
a.Logger.Errorln(err)
return "", err.Error(), 85, err
}
if err := f.Close(); err != nil {
a.Logger.Errorln(err)
return "", err.Error(), 85, err
}
if err := os.Chmod(f.Name(), 0770); err != nil {
a.Logger.Errorln(err)
return "", err.Error(), 85, err
}
opts := a.NewCMDOpts()
opts.IsScript = true
opts.Shell = f.Name()
opts.Args = args
opts.Timeout = time.Duration(timeout)
out := a.CmdV2(opts)
retError := ""
if out.Status.Error != nil {
retError += CleanString(out.Status.Error.Error())
retError += "\n"
}
if len(out.Stderr) > 0 {
retError += out.Stderr
}
return out.Stdout, retError, out.Status.Exit, nil
}
func SetDetached() *syscall.SysProcAttr {
return &syscall.SysProcAttr{Setpgid: true}
}
func (a *Agent) AgentUpdate(url, inno, version string) {
self, err := os.Executable()
if err != nil {
a.Logger.Errorln("AgentUpdate() os.Executable():", err)
return
}
f, err := createTmpFile()
if err != nil {
a.Logger.Errorln("AgentUpdate createTmpFile()", err)
return
}
defer os.Remove(f.Name())
a.Logger.Infof("Agent updating from %s to %s", a.Version, version)
a.Logger.Infoln("Downloading agent update from", url)
rClient := resty.New()
rClient.SetCloseConnection(true)
rClient.SetTimeout(15 * time.Minute)
rClient.SetDebug(a.Debug)
if len(a.Proxy) > 0 {
rClient.SetProxy(a.Proxy)
}
r, err := rClient.R().SetOutput(f.Name()).Get(url)
if err != nil {
a.Logger.Errorln("AgentUpdate() download:", err)
f.Close()
return
}
if r.IsError() {
a.Logger.Errorln("AgentUpdate() status code:", r.StatusCode())
f.Close()
return
}
f.Close()
os.Chmod(f.Name(), 0755)
err = os.Rename(f.Name(), self)
if err != nil {
a.Logger.Errorln("AgentUpdate() os.Rename():", err)
return
}
opts := a.NewCMDOpts()
opts.Detached = true
opts.Command = "systemctl restart tacticalagent.service"
a.CmdV2(opts)
}
func (a *Agent) AgentUninstall(code string) {
f, err := createTmpFile()
if err != nil {
a.Logger.Errorln("AgentUninstall createTmpFile():", err)
return
}
f.Write([]byte(code))
f.Close()
os.Chmod(f.Name(), 0770)
opts := a.NewCMDOpts()
opts.IsScript = true
opts.Shell = f.Name()
opts.Args = []string{"uninstall"}
opts.Detached = true
a.CmdV2(opts)
}
func (a *Agent) NixMeshNodeID() string {
var meshNodeID string
meshSuccess := false
a.Logger.Debugln("Getting mesh node id")
if !trmm.FileExists(a.MeshSystemBin) {
a.Logger.Debugln(a.MeshSystemBin, "does not exist. Skipping.")
return ""
}
opts := a.NewCMDOpts()
opts.IsExecutable = true
opts.Shell = a.MeshSystemBin
opts.Command = "-nodeid"
for !meshSuccess {
out := a.CmdV2(opts)
meshNodeID = out.Stdout
a.Logger.Debugln("Stdout:", out.Stdout)
a.Logger.Debugln("Stderr:", out.Stderr)
if meshNodeID == "" {
time.Sleep(1 * time.Second)
continue
} else if strings.Contains(strings.ToLower(meshNodeID), "graphical version") || strings.Contains(strings.ToLower(meshNodeID), "zenity") {
time.Sleep(1 * time.Second)
continue
}
meshSuccess = true
}
return meshNodeID
}
func (a *Agent) getMeshNodeID() (string, error) {
return a.NixMeshNodeID(), nil
}
func (a *Agent) RecoverMesh() {
a.Logger.Infoln("Attempting mesh recovery")
opts := a.NewCMDOpts()
opts.Command = "systemctl restart meshagent.service"
a.CmdV2(opts)
a.SyncMeshNodeID()
}
func (a *Agent) GetWMIInfo() map[string]interface{} {
wmiInfo := make(map[string]interface{})
ips := make([]string, 0)

40
agent/disk/disk_linux.go Normal file
View file

@ -0,0 +1,40 @@
package disk
import (
"strings"
d "github.com/shirou/gopsutil/v3/disk"
trmm "github.com/wh1te909/trmm-shared"
"github.com/amidaware/rmmagent/agent/utils"
)
func GetDisks() []trmm.Disk {
ret := make([]trmm.Disk, 0)
partitions, err := d.Partitions(false)
if err != nil {
return nil
}
for _, p := range partitions {
if strings.Contains(p.Device, "dev/loop") {
continue
}
usage, err := d.Usage(p.Mountpoint)
if err != nil {
continue
}
d := trmm.Disk{
Device: p.Device,
Fstype: p.Fstype,
Total: utils.ByteCountSI(usage.Total),
Used: utils.ByteCountSI(usage.Used),
Free: utils.ByteCountSI(usage.Free),
Percent: int(usage.UsedPercent),
}
ret = append(ret, d)
}
return ret
}

View file

@ -0,0 +1,14 @@
package disk
import (
"testing"
)
func TestGetDisks(t *testing.T) {
disks := GetDisks()
if len(disks) == 0 {
t.Fatalf("Could not get disks on linux system.")
}
t.Logf("Got %d disks on linux system", len(disks))
}

39
agent/structs.go Normal file
View file

@ -0,0 +1,39 @@
package agent
import (
"os"
"github.com/go-resty/resty/v2"
"github.com/kardianos/service"
"github.com/sirupsen/logrus"
)
// Agent struct
type Agent struct {
Hostname string
Arch string
AgentID string
BaseURL string
ApiURL string
Token string
AgentPK int
Cert string
ProgramDir string
EXE string
SystemDrive string
MeshInstaller string
MeshSystemBin string
MeshSVC string
PyBin string
Headers map[string]string
Logger *logrus.Logger
Version string
Debug bool
rClient *resty.Client
Proxy string
LogTo string
LogFile *os.File
Platform string
GoArch string
ServiceConfig *service.Config
}

13
agent/system/structs.go Normal file
View file

@ -0,0 +1,13 @@
package system
import "time"
type CmdOptions struct {
Shell string
Command string
Args []string
Timeout time.Duration
IsScript bool
IsExecutable bool
Detached bool
}

101
agent/system/system.go Normal file
View file

@ -0,0 +1,101 @@
package system
import (
"bytes"
"context"
"fmt"
"os/exec"
"time"
"github.com/amidaware/rmmagent/agent/utils"
gocmd "github.com/go-cmd/cmd"
)
type CmdStatus struct {
Status gocmd.Status
Stdout string
Stderr string
}
func CmdV2(c *CmdOptions) CmdStatus {
ctx, cancel := context.WithTimeout(context.Background(), c.Timeout * time.Second)
defer cancel()
// Disable output buffering, enable streaming
cmdOptions := gocmd.Options{
Buffered: false,
Streaming: true,
}
// have a child process that is in a different process group so that
// parent terminating doesn't kill child
if c.Detached {
cmdOptions.BeforeExec = []func(cmd *exec.Cmd){
func(cmd *exec.Cmd) {
cmd.SysProcAttr = SetDetached()
},
}
}
var envCmd *gocmd.Cmd
if c.IsScript {
envCmd = gocmd.NewCmdOptions(cmdOptions, c.Shell, c.Args...) // call script directly
} else if c.IsExecutable {
envCmd = gocmd.NewCmdOptions(cmdOptions, c.Shell, c.Command) // c.Shell: bin + c.Command: args as one string
} else {
envCmd = gocmd.NewCmdOptions(cmdOptions, c.Shell, "-c", c.Command) // /bin/bash -c 'ls -l /var/log/...'
}
var stdoutBuf bytes.Buffer
var stderrBuf bytes.Buffer
// Print STDOUT and STDERR lines streaming from Cmd
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
// Done when both channels have been closed
// https://dave.cheney.net/2013/04/30/curious-channels
for envCmd.Stdout != nil || envCmd.Stderr != nil {
select {
case line, open := <-envCmd.Stdout:
if !open {
envCmd.Stdout = nil
continue
}
fmt.Fprintln(&stdoutBuf, line)
case line, open := <-envCmd.Stderr:
if !open {
envCmd.Stderr = nil
continue
}
fmt.Fprintln(&stderrBuf, line)
}
}
}()
// Run and wait for Cmd to return, discard Status
envCmd.Start()
go func() {
select {
case <-doneChan:
return
case <-ctx.Done():
pid := envCmd.Status().PID
KillProc(int32(pid))
}
}()
// Wait for goroutine to print everything
<-doneChan
ret := CmdStatus{
Status: envCmd.Status(),
Stdout: utils.CleanString(stdoutBuf.String()),
Stderr: utils.CleanString(stderrBuf.String()),
}
return ret
}

View file

@ -0,0 +1,154 @@
package system
import (
"fmt"
"os"
"strings"
"syscall"
"time"
"github.com/amidaware/rmmagent/agent/utils"
"github.com/shirou/gopsutil/process"
psHost "github.com/shirou/gopsutil/v3/host"
"github.com/wh1te909/trmm-shared"
)
func NewCMDOpts() *CmdOptions {
return &CmdOptions{
Shell: "/bin/bash",
Timeout: 30,
}
}
func SetDetached() *syscall.SysProcAttr {
return &syscall.SysProcAttr{Setpgid: true}
}
func ShowStatus(version string) {
fmt.Println(version)
}
func SystemRebootRequired() (bool, error) {
// deb
paths := [2]string{"/var/run/reboot-required", "/run/reboot-required"}
for _, p := range paths {
if trmm.FileExists(p) {
return true, nil
}
}
// rhel
bins := [2]string{"/usr/bin/needs-restarting", "/bin/needs-restarting"}
for _, bin := range bins {
if trmm.FileExists(bin) {
opts := NewCMDOpts()
// https://man7.org/linux/man-pages/man1/needs-restarting.1.html
// -r Only report whether a full reboot is required (exit code 1) or not (exit code 0).
opts.Command = fmt.Sprintf("%s -r", bin)
out := CmdV2(opts)
if out.Status.Error != nil {
continue
}
if out.Status.Exit == 1 {
return true, nil
}
return false, nil
}
}
return false, nil
}
func LoggedOnUser() string {
var ret string
users, err := psHost.Users()
if err != nil {
return ret
}
// return the first logged in user
for _, user := range users {
if user.User != "" {
ret = user.User
break
}
}
return ret
}
func OsString() string {
h, err := psHost.Info()
if err != nil {
return "error getting host info"
}
return fmt.Sprintf("%s %s %s %s", strings.Title(h.Platform), h.PlatformVersion, h.KernelArch, h.KernelVersion)
}
// KillProc kills a process and its children
func KillProc(pid int32) error {
p, err := process.NewProcess(pid)
if err != nil {
return err
}
children, err := p.Children()
if err == nil {
for _, child := range children {
if err := child.Kill(); err != nil {
continue
}
}
}
if err := p.Kill(); err != nil {
return err
}
return nil
}
func RunScript(code string, shell string, args []string, timeout int) (stdout, stderr string, exitcode int, e error) {
code = utils.RemoveWinNewLines(code)
content := []byte(code)
f, err := utils.CreateTmpFile()
if err != nil {
return "", err.Error(), 85, err
}
defer os.Remove(f.Name())
if _, err := f.Write(content); err != nil {
return "", err.Error(), 85, err
}
if err := f.Close(); err != nil {
return "", err.Error(), 85, err
}
if err := os.Chmod(f.Name(), 0770); err != nil {
return "", err.Error(), 85, err
}
opts := NewCMDOpts()
opts.IsScript = true
opts.Shell = f.Name()
opts.Args = args
opts.Timeout = time.Duration(timeout)
out := CmdV2(opts)
retError := ""
if out.Status.Error != nil {
retError += utils.CleanString(out.Status.Error.Error())
retError += "\n"
}
if len(out.Stderr) > 0 {
retError += out.Stderr
}
return out.Stdout, retError, out.Status.Exit, nil
}

View file

@ -0,0 +1,67 @@
package system
import (
"testing"
"github.com/amidaware/rmmagent/agent/utils"
)
func TestNewCMDOpts(t *testing.T) {
opts := NewCMDOpts()
if opts.Shell != "/bin/bash" {
t.Fatalf("Expected /bin/bash, got %s", opts.Shell)
}
}
func TestSystemRebootRequired(t *testing.T) {
required, err := SystemRebootRequired()
if err != nil {
t.Fatal(err)
}
t.Logf("System Reboot Required %t", required)
}
func TestShowStatus(t *testing.T) {
output := utils.CaptureOutput(func() {
ShowStatus("1.0.0")
});
if output != "1.0.0\n" {
t.Fatalf("Expected 1.0.0, got %s", output)
}
}
func TestLoggedOnUser(t *testing.T) {
user := LoggedOnUser()
if user == "" {
t.Fatalf("Expected a user, got empty")
}
t.Logf("Logged on user: %s", user)
}
func TestOsString(t *testing.T) {
osString := OsString()
if osString == "error getting host info" {
t.Fatalf("Unable to get OS string")
}
t.Logf("OS String: %s", osString)
}
func TestRunScript(t *testing.T) {
stdout, stderr, exitcode, err := RunScript("#!/bin/sh\ncat /etc/os-release", "/bin/sh", nil, 30)
if err != nil {
t.Fatal(err)
}
if stderr != "" {
t.Fatal(stderr)
}
if exitcode != 0 {
t.Fatalf("Error: Exit Code %d", exitcode)
}
t.Logf("Result: %s", stdout)
}

View file

@ -0,0 +1,36 @@
package tactical
import (
"time"
"github.com/amidaware/rmmagent/agent/utils"
"github.com/amidaware/rmmagent/shared"
"github.com/go-resty/resty/v2"
)
func SyncMeshNodeID() bool {
id, err := GetMeshNodeID()
if err != nil {
//a.Logger.Errorln("SyncMeshNodeID() getMeshNodeID()", err)
return false
}
agentConfig := NewAgentConfig()
payload := shared.MeshNodeID{
Func: "syncmesh",
Agentid: agentConfig.AgentID,
NodeID: utils.StripAll(id),
}
client := resty.New()
client.SetBaseURL(agentConfig.BaseURL)
client.SetTimeout(15 * time.Second)
client.SetCloseConnection(true)
if shared.DEBUG {
client.SetDebug(true)
}
_, err = client.R().SetBody(payload).Post("/api/v3/syncmesh/")
return err == nil
}

View file

@ -0,0 +1,170 @@
package tactical
import (
"os"
"strconv"
"strings"
"time"
"github.com/amidaware/rmmagent/agent/system"
"github.com/amidaware/rmmagent/agent/utils"
"github.com/amidaware/rmmagent/shared"
"github.com/go-resty/resty/v2"
"github.com/spf13/viper"
"github.com/wh1te909/trmm-shared"
)
func GetMeshBinary() string {
return "/opt/tacticalmesh/meshagent"
}
func NewAgentConfig() *shared.AgentConfig {
viper.SetConfigName("tacticalagent")
viper.SetConfigType("json")
viper.AddConfigPath("/etc/")
viper.AddConfigPath(".")
err := viper.ReadInConfig()
if err != nil {
return &shared.AgentConfig{}
}
agentpk := viper.GetString("agentpk")
pk, _ := strconv.Atoi(agentpk)
ret := &shared.AgentConfig{
BaseURL: viper.GetString("baseurl"),
AgentID: viper.GetString("agentid"),
APIURL: viper.GetString("apiurl"),
Token: viper.GetString("token"),
AgentPK: agentpk,
PK: pk,
Cert: viper.GetString("cert"),
Proxy: viper.GetString("proxy"),
CustomMeshDir: viper.GetString("meshdir"),
}
return ret
}
func AgentUpdate(url, inno, version string) bool {
self, err := os.Executable()
if err != nil {
return false
}
f, err := utils.CreateTmpFile()
if err != nil {
return false
}
defer os.Remove(f.Name())
//logger.Infof("Agent updating from %s to %s", a.Version, version)
//logger.Infoln("Downloading agent update from", url)
rClient := resty.New()
rClient.SetCloseConnection(true)
rClient.SetTimeout(15 * time.Minute)
if shared.DEBUG {
rClient.SetDebug(true)
}
config := NewAgentConfig()
if len(config.Proxy) > 0 {
rClient.SetProxy(config.Proxy)
}
r, err := rClient.R().SetOutput(f.Name()).Get(url)
if err != nil {
//a.Logger.Errorln("AgentUpdate() download:", err)
f.Close()
return false
}
if r.IsError() {
//a.Logger.Errorln("AgentUpdate() status code:", r.StatusCode())
f.Close()
return false
}
f.Close()
os.Chmod(f.Name(), 0755)
err = os.Rename(f.Name(), self)
if err != nil {
//a.Logger.Errorln("AgentUpdate() os.Rename():", err)
return false
}
opts := system.NewCMDOpts()
opts.Detached = true
opts.Command = "systemctl restart tacticalagent.service"
system.CmdV2(opts)
return true
}
func AgentUninstall(code string) bool {
f, err := utils.CreateTmpFile()
if err != nil {
//a.Logger.Errorln("AgentUninstall createTmpFile():", err)
return false
}
f.Write([]byte(code))
f.Close()
os.Chmod(f.Name(), 0770)
opts := system.NewCMDOpts()
opts.IsScript = true
opts.Shell = f.Name()
opts.Args = []string{"uninstall"}
opts.Detached = true
system.CmdV2(opts)
return true
}
func NixMeshNodeID() string {
var meshNodeID string
meshSuccess := false
//a.Logger.Debugln("Getting mesh node id")
if !trmm.FileExists(GetMeshBinary()) {
//a.Logger.Debugln(a.MeshSystemBin, "does not exist. Skipping.")
return ""
}
opts := system.NewCMDOpts()
opts.IsExecutable = true
opts.Shell = GetMeshBinary()
opts.Command = "-nodeid"
for !meshSuccess {
out := system.CmdV2(opts)
meshNodeID = out.Stdout
//a.Logger.Debugln("Stdout:", out.Stdout)
//a.Logger.Debugln("Stderr:", out.Stderr)
if meshNodeID == "" {
time.Sleep(1 * time.Second)
continue
} else if strings.Contains(strings.ToLower(meshNodeID), "graphical version") || strings.Contains(strings.ToLower(meshNodeID), "zenity") {
time.Sleep(1 * time.Second)
continue
}
meshSuccess = true
}
return meshNodeID
}
func GetMeshNodeID() (string, error) {
return NixMeshNodeID(), nil
}
func RecoverMesh(agentID string) {
//a.Logger.Infoln("Attempting mesh recovery")
opts := system.NewCMDOpts()
opts.Command = "systemctl restart meshagent.service"
system.CmdV2(opts)
SyncMeshNodeID()
}

View file

@ -0,0 +1,43 @@
package tactical
import (
"testing"
)
func TestNewAgentConfig(t *testing.T) {
config := NewAgentConfig()
if config.BaseURL == "" {
t.Fatal("Could not get config")
}
t.Logf("Config BaseURL: %s", config.BaseURL)
}
func TestAgentUpdate(t *testing.T) {
url := "https://github.com/redanthrax/rmmagent/releases/download/v2.0.4/linuxagent"
result := AgentUpdate(url, "", "v2.0.4")
if(!result) {
t.Fatal("Agent update resulted in false")
}
t.Log("Agent update resulted in true")
}
func TestAgentUninstall(t *testing.T) {
result := AgentUninstall("foo")
if !result {
t.Fatal("Agent uninstall resulted in error")
}
t.Log("Agent uninstall was true")
}
func TestNixMeshNodeID(t *testing.T) {
nodeid := NixMeshNodeID()
if nodeid == "" {
t.Fatal("Unable to get mesh node id")
}
t.Logf("MeshNodeID: %s", nodeid)
}

View file

@ -0,0 +1,17 @@
package tactical
import "testing"
func TestSyncMeshNodeID(t *testing.T) {
agentConfig := NewAgentConfig()
if agentConfig.AgentID == "" {
t.Fatal("Could not get AgentID")
}
result := SyncMeshNodeID()
if !result {
t.Fatal("SyncMeshNodeID resulted in error")
}
t.Log("Synced mesh node id")
}

View file

@ -172,36 +172,6 @@ func IsValidIP(ip string) bool {
return net.ParseIP(ip) != nil
}
// StripAll strips all whitespace and newline chars
func StripAll(s string) string {
s = strings.TrimSpace(s)
s = strings.Trim(s, "\n")
s = strings.Trim(s, "\r")
return s
}
// KillProc kills a process and its children
func KillProc(pid int32) error {
p, err := process.NewProcess(pid)
if err != nil {
return err
}
children, err := p.Children()
if err == nil {
for _, child := range children {
if err := child.Kill(); err != nil {
continue
}
}
}
if err := p.Kill(); err != nil {
return err
}
return nil
}
// DjangoStringResp removes double quotes from django rest api resp
func DjangoStringResp(resp string) string {
return strings.Trim(resp, `"`)
@ -216,13 +186,6 @@ func TestTCP(addr string) error {
return nil
}
// CleanString removes invalid utf-8 byte sequences
func CleanString(s string) string {
r := strings.NewReplacer("\x00", "")
s = r.Replace(s)
return strings.ToValidUTF8(s, "")
}
// https://golangcode.com/unzip-files-in-go/
func Unzip(src, dest string) error {
r, err := zip.OpenReader(src)
@ -299,22 +262,6 @@ func randomCheckDelay() {
time.Sleep(time.Duration(randRange(300, 950)) * time.Millisecond)
}
func removeWinNewLines(s string) string {
return strings.ReplaceAll(s, "\r\n", "\n")
}
func createTmpFile() (*os.File, error) {
var f *os.File
f, err := os.CreateTemp("", "trmm")
if err != nil {
cwd, err := os.Getwd()
if err != nil {
return f, err
}
f, err = os.CreateTemp(cwd, "trmm")
if err != nil {
return f, err
}
}
return f, nil
}

89
agent/utils/utils.go Normal file
View file

@ -0,0 +1,89 @@
package utils
import (
"bytes"
"fmt"
"io"
"os"
"strings"
"time"
"github.com/amidaware/rmmagent/shared"
"github.com/go-resty/resty/v2"
)
func CaptureOutput(f func()) string {
old := os.Stdout
r, w, _ := os.Pipe()
os.Stdout = w
f()
w.Close()
os.Stdout = old
var buf bytes.Buffer
io.Copy(&buf, r)
return buf.String()
}
func ByteCountSI(b uint64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %cB",
float64(b)/float64(div), "kMGTPE"[exp])
}
// CleanString removes invalid utf-8 byte sequences
func CleanString(s string) string {
r := strings.NewReplacer("\x00", "")
s = r.Replace(s)
return strings.ToValidUTF8(s, "")
}
func RemoveWinNewLines(s string) string {
return strings.ReplaceAll(s, "\r\n", "\n")
}
func CreateTmpFile() (*os.File, error) {
var f *os.File
f, err := os.CreateTemp("", "trmm")
if err != nil {
cwd, err := os.Getwd()
if err != nil {
return f, err
}
f, err = os.CreateTemp(cwd, "trmm")
if err != nil {
return f, err
}
}
return f, nil
}
func WebRequest(requestType string, timeout time.Duration, payload map[string]string, url string, proxy string) (response resty.Response, err error) {
client := resty.New()
client.SetTimeout(timeout * time.Second)
client.SetCloseConnection(true)
if shared.DEBUG {
client.SetDebug(true)
}
result, err := client.R().Get(url)
return *result, err
}
// StripAll strips all whitespace and newline chars
func StripAll(s string) string {
s = strings.TrimSpace(s)
s = strings.Trim(s, "\n")
s = strings.Trim(s, "\r")
return s
}

31
agent/utils/utils_test.go Normal file
View file

@ -0,0 +1,31 @@
package utils
import (
"testing"
)
func TestByteCountSI(t *testing.T) {
var bytes uint64 = 1048576
mb := ByteCountSI(bytes)
if mb != "1.0 MB" {
t.Errorf("Expected 1.0 MB, got %s", mb)
}
}
func TestRemoveWinNewLines(t *testing.T) {
result := RemoveWinNewLines("test\r\n")
if result != "test\n" {
t.Fatalf("Expected testing\\n, got %s", result)
}
t.Logf("Result: %s", result)
}
func TestStripAll(t *testing.T) {
result := StripAll(" test\r\n ")
if result != "test" {
t.Fatalf("Expecte test, got %s", result)
}
t.Log("Test result expected")
}

1
go.mod
View file

@ -59,6 +59,7 @@ require (
github.com/rickb777/plural v1.4.1 // indirect
github.com/rogpeppe/go-internal v1.8.1 // indirect
github.com/scjalliance/comshim v0.0.0-20190308082608-cf06d2532c4e // indirect
github.com/shirou/gopsutil v3.21.11+incompatible // indirect
github.com/spf13/afero v1.8.2 // indirect
github.com/spf13/cast v1.5.0 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect

2
go.sum
View file

@ -255,6 +255,8 @@ github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XF
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
github.com/scjalliance/comshim v0.0.0-20190308082608-cf06d2532c4e h1:+/AzLkOdIXEPrAQtwAeWOBnPQ0BnYlBW0aCZmSb47u4=
github.com/scjalliance/comshim v0.0.0-20190308082608-cf06d2532c4e/go.mod h1:9Tc1SKnfACJb9N7cw2eyuI6xzy845G7uZONBsi5uPEA=
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
github.com/shirou/gopsutil/v3 v3.22.5 h1:atX36I/IXgFiB81687vSiBI5zrMsxcIBkP9cQMJQoJA=
github.com/shirou/gopsutil/v3 v3.22.5/go.mod h1:so9G9VzeHt/hsd0YwqprnjHnfARAUktauykSbr+y2gA=
github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE=

View file

@ -18,7 +18,6 @@ import (
"os/user"
"path/filepath"
"runtime"
"github.com/amidaware/rmmagent/agent"
"github.com/kardianos/service"
"github.com/sirupsen/logrus"