refactor and testing

This commit is contained in:
redanthrax 2022-06-21 11:17:38 -07:00
parent 91c9de6e34
commit c038774f2c
16 changed files with 445 additions and 195 deletions

View file

@ -80,6 +80,7 @@ func New(logger *logrus.Logger, version string) *Agent {
if len(ac.Proxy) > 0 {
restyC.SetProxy(ac.Proxy)
}
if len(ac.Cert) > 0 {
restyC.SetRootCertificate(ac.Cert)
}
@ -137,22 +138,6 @@ func New(logger *logrus.Logger, version string) *Agent {
}
}
type CmdStatus struct {
Status gocmd.Status
Stdout string
Stderr string
}
type CmdOptions struct {
Shell string
Command string
Args []string
Timeout time.Duration
IsScript bool
IsExecutable bool
Detached bool
}
func (a *Agent) NewCMDOpts() *CmdOptions {
return &CmdOptions{
Shell: "/bin/bash",

92
agent/agent_linux.go Normal file
View file

@ -0,0 +1,92 @@
package agent
func New(logger *logrus.Logger, version string) *Agent {
host, _ := ps.Host()
info := host.Info()
pd := filepath.Join(os.Getenv("ProgramFiles"), progFilesName)
exe := filepath.Join(pd, winExeName)
sd := os.Getenv("SystemDrive")
var pybin string
switch runtime.GOARCH {
case "amd64":
pybin = filepath.Join(pd, "py38-x64", "python.exe")
case "386":
pybin = filepath.Join(pd, "py38-x32", "python.exe")
}
ac := NewAgentConfig()
headers := make(map[string]string)
if len(ac.Token) > 0 {
headers["Content-Type"] = "application/json"
headers["Authorization"] = fmt.Sprintf("Token %s", ac.Token)
}
restyC := resty.New()
restyC.SetBaseURL(ac.BaseURL)
restyC.SetCloseConnection(true)
restyC.SetHeaders(headers)
restyC.SetTimeout(15 * time.Second)
restyC.SetDebug(logger.IsLevelEnabled(logrus.DebugLevel))
if len(ac.Proxy) > 0 {
restyC.SetProxy(ac.Proxy)
}
if len(ac.Cert) > 0 {
restyC.SetRootCertificate(ac.Cert)
}
var MeshSysBin string
if len(ac.CustomMeshDir) > 0 {
MeshSysBin = filepath.Join(ac.CustomMeshDir, "MeshAgent.exe")
} else {
MeshSysBin = filepath.Join(os.Getenv("ProgramFiles"), "Mesh Agent", "MeshAgent.exe")
}
if runtime.GOOS == "linux" {
MeshSysBin = "/opt/tacticalmesh/meshagent"
}
svcConf := &service.Config{
Executable: exe,
Name: winSvcName,
DisplayName: "TacticalRMM Agent Service",
Arguments: []string{"-m", "svc"},
Description: "TacticalRMM Agent Service",
Option: service.KeyValue{
"StartType": "automatic",
"OnFailure": "restart",
"OnFailureDelayDuration": "5s",
"OnFailureResetPeriod": 10,
},
}
return &Agent{
Hostname: info.Hostname,
Arch: info.Architecture,
BaseURL: ac.BaseURL,
AgentID: ac.AgentID,
ApiURL: ac.APIURL,
Token: ac.Token,
AgentPK: ac.PK,
Cert: ac.Cert,
ProgramDir: pd,
EXE: exe,
SystemDrive: sd,
MeshInstaller: "meshagent.exe",
MeshSystemBin: MeshSysBin,
MeshSVC: meshSvcName,
PyBin: pybin,
Headers: headers,
Logger: logger,
Version: version,
Debug: logger.IsLevelEnabled(logrus.DebugLevel),
rClient: restyC,
Proxy: ac.Proxy,
Platform: runtime.GOOS,
GoArch: runtime.GOARCH,
ServiceConfig: svcConf,
}
}

View file

@ -2,7 +2,9 @@ package agent
import (
"os"
"time"
gocmd "github.com/go-cmd/cmd"
"github.com/go-resty/resty/v2"
"github.com/kardianos/service"
"github.com/sirupsen/logrus"
@ -37,3 +39,31 @@ type Agent struct {
GoArch string
ServiceConfig *service.Config
}
type AgentConfig struct {
BaseURL string
AgentID string
APIURL string
Token string
AgentPK string
PK int
Cert string
Proxy string
CustomMeshDir string
}
type CmdStatus struct {
Status gocmd.Status
Stdout string
Stderr string
}
type CmdOptions struct {
Shell string
Command string
Args []string
Timeout time.Duration
IsScript bool
IsExecutable bool
Detached bool
}

View file

@ -97,7 +97,7 @@ func CmdV2(c *CmdOptions) CmdStatus {
return
case <-ctx.Done():
pid := envCmd.Status().PID
utils.KillProc(int32(pid))
KillProc(int32(pid))
}
}()

View file

@ -18,6 +18,7 @@ import (
ps "github.com/elastic/go-sysinfo"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
"github.com/shirou/gopsutil/process"
wapf "github.com/wh1te909/go-win64api"
trmm "github.com/wh1te909/trmm-shared"
"golang.org/x/sys/windows"
@ -102,7 +103,7 @@ func RunScript(code string, shell string, args []string, timeout int) (stdout, s
<-ctx.Done()
_ = utils.KillProc(p)
_ = KillProc(p)
timedOut = true
}(pid)
@ -223,7 +224,7 @@ func CMDShell(shell string, cmdArgs []string, command string, timeout int, detac
<-ctx.Done()
_ = utils.KillProc(p)
_ = KillProc(p)
timedOut = true
}(pid)
@ -441,7 +442,7 @@ func KillHungUpdates() {
}
if strings.Contains(p.Exe, "winagent-v") {
//a.Logger.Debugln("killing process", p.Exe)
utils.KillProc(int32(p.PID))
KillProc(int32(p.PID))
}
}
}
@ -484,3 +485,26 @@ Add-MpPreference -ExclusionPath 'C:\Program Files\Mesh Agent\*'
return nil
}
// KillProc kills a process and its children
func KillProc(pid int32) error {
p, err := process.NewProcess(pid)
if err != nil {
return err
}
children, err := p.Children()
if err == nil {
for _, child := range children {
if err := child.Kill(); err != nil {
continue
}
}
}
if err := p.Kill(); err != nil {
return err
}
return nil
}

View file

@ -1 +1,54 @@
package api
import (
"fmt"
"time"
"github.com/amidaware/rmmagent/agent/tactical/config"
"github.com/amidaware/rmmagent/shared"
"github.com/go-resty/resty/v2"
)
var restyC resty.Client
func init() {
ac := config.NewAgentConfig()
headers := make(map[string]string)
if len(ac.Token) > 0 {
headers["Content-Type"] = "application/json"
headers["Authorization"] = fmt.Sprintf("Token %s", ac.Token)
}
restyC := resty.New()
restyC.SetBaseURL(ac.BaseURL)
restyC.SetCloseConnection(true)
restyC.SetHeaders(headers)
restyC.SetTimeout(15 * time.Second)
restyC.SetDebug(shared.DEBUG)
if len(ac.Proxy) > 0 {
restyC.SetProxy(ac.Proxy)
}
if len(ac.Cert) > 0 {
restyC.SetRootCertificate(ac.Cert)
}
}
func PostPayload(payload interface{}, url string) error {
_, err := restyC.R().SetBody(payload).Post("/api/v3/syncmesh/")
if err != nil {
return err
}
return nil
}
func GetResult(result interface{}, url string) (*resty.Response, error) {
r, err := restyC.R().SetResult(result).Get(url)
if err != nil {
return nil, err
}
return r, nil
}

View file

@ -3,50 +3,45 @@ package checks
import (
"encoding/json"
"fmt"
"runtime"
"sync"
"time"
"github.com/amidaware/rmmagent/agent/system"
"github.com/amidaware/rmmagent/agent/tactical/api"
"github.com/amidaware/rmmagent/agent/utils"
rmm "github.com/amidaware/rmmagent/shared"
ps "github.com/elastic/go-sysinfo"
"github.com/go-resty/resty/v2"
)
func CheckRunner(agentID string) {
func CheckRunner(agentID string) error {
sleepDelay := utils.RandRange(14, 22)
//a.Logger.Debugf("CheckRunner() init sleeping for %v seconds", sleepDelay)
time.Sleep(time.Duration(sleepDelay) * time.Second)
for {
interval, err := GetCheckInterval(agentID)
if err == nil && !ChecksRunning() {
if runtime.GOOS == "windows" {
_, err = system.CMD(system.GetProgramEXE(), []string{"-m", "checkrunner"}, 600, false)
if err != nil {
//a.Logger.Errorln("Checkrunner RunChecks", err)
}
} else {
RunChecks(agentID, false)
_, err = system.CMD(system.GetProgramEXE(), []string{"-m", "checkrunner"}, 600, false)
if err != nil {
return err
}
}
//a.Logger.Debugln("Checkrunner sleeping for", interval)
time.Sleep(time.Duration(interval) * time.Second)
}
return nil
}
func GetCheckInterval(agentID string) (int, error) {
r, err := a.rClient.R().SetResult(&rmm.CheckInfo{}).Get(fmt.Sprintf("/api/v3/%s/checkinterval/", a.AgentID))
r, err := api.GetResult(CheckInfo{}, fmt.Sprintf("/api/v3/%s/checkinterval/", agentID))
if err != nil {
a.Logger.Debugln(err)
return 120, err
}
if r.IsError() {
a.Logger.Debugln("Checkinterval response code:", r.StatusCode())
return 120, fmt.Errorf("checkinterval response code: %v", r.StatusCode())
}
interval := r.Result().(*rmm.CheckInfo).Interval
interval := r.Result().(*CheckInfo).Interval
return interval, nil
}
@ -90,7 +85,7 @@ func RunChecks(agentID string, force bool) error {
} else {
url = fmt.Sprintf("/api/v3/%s/checkrunner/", agentID)
}
r, err := a.rClient.R().Get(url)
if err != nil {
//a.Logger.Debugln(err)
@ -177,4 +172,4 @@ func RunChecks(agentID string, force bool) error {
wg.Wait()
return nil
}
}

View file

@ -0,0 +1,6 @@
package checks
type CheckInfo struct {
AgentPK int `json:"agent"`
Interval int `json:"check_interval"`
}

View file

@ -0,0 +1,37 @@
package config
import (
"strconv"
"golang.org/x/sys/windows/registry"
)
func NewAgentConfig() *AgentConfig {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\TacticalRMM`, registry.ALL_ACCESS)
if err != nil {
return &AgentConfig{}
}
baseurl, _, _ := k.GetStringValue("BaseURL")
agentid, _, _ := k.GetStringValue("AgentID")
apiurl, _, _ := k.GetStringValue("ApiURL")
token, _, _ := k.GetStringValue("Token")
agentpk, _, _ := k.GetStringValue("AgentPK")
pk, _ := strconv.Atoi(agentpk)
cert, _, _ := k.GetStringValue("Cert")
proxy, _, _ := k.GetStringValue("Proxy")
customMeshDir, _, _ := k.GetStringValue("MeshDir")
return &AgentConfig{
BaseURL: baseurl,
AgentID: agentid,
APIURL: apiurl,
Token: token,
AgentPK: agentpk,
PK: pk,
Cert: cert,
Proxy: proxy,
CustomMeshDir: customMeshDir,
}
}

View file

@ -0,0 +1,13 @@
package config
type AgentConfig struct {
BaseURL string
AgentID string
APIURL string
Token string
AgentPK string
PK int
Cert string
Proxy string
CustomMeshDir string
}

View file

@ -0,0 +1,26 @@
package mesh
import (
"github.com/amidaware/rmmagent/agent/tactical/api"
"github.com/amidaware/rmmagent/agent/utils"
)
func SyncMeshNodeID(agentID string) error {
id, err := GetMeshNodeID()
if err != nil {
return err
}
payload := MeshNodeID{
Func: "syncmesh",
Agentid: agentID,
NodeID: utils.StripAll(id),
}
err = api.PostPayload(payload, "/api/v3/syncmesh/")
if err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,123 @@
package mesh
import (
"errors"
"os"
"path/filepath"
"strings"
"github.com/amidaware/rmmagent/agent/system"
"github.com/amidaware/rmmagent/agent/tactical/config"
"github.com/amidaware/rmmagent/agent/utils"
ps "github.com/elastic/go-sysinfo"
)
// ForceKillMesh kills all mesh agent related processes
func ForceKillMesh() error {
pids := make([]int, 0)
procs, err := ps.Processes()
if err != nil {
return err
}
for _, process := range procs {
p, err := process.Info()
if err != nil {
continue
}
if strings.Contains(strings.ToLower(p.Name), "meshagent") {
pids = append(pids, p.PID)
}
}
for _, pid := range pids {
if err := system.KillProc(int32(pid)); err != nil {
return err
}
}
return nil
}
func GetMeshNodeID() (string, error) {
out, err := system.CMD(getMeshBinLocation(), []string{"-nodeid"}, 10, false)
if err != nil {
return "", err
}
stdout := out[0]
stderr := out[1]
if stderr != "" {
return "", err
}
if stdout == "" || strings.Contains(strings.ToLower(utils.StripAll(stdout)), "not defined") {
return "", errors.New("failed to get mesh node id")
}
return stdout, nil
}
func getMeshBinLocation() string {
ac := config.NewAgentConfig()
var MeshSysBin string
if len(ac.CustomMeshDir) > 0 {
MeshSysBin = filepath.Join(ac.CustomMeshDir, "MeshAgent.exe")
} else {
MeshSysBin = filepath.Join(os.Getenv("ProgramFiles"), "Mesh Agent", "MeshAgent.exe")
}
return MeshSysBin
}
func installMesh(meshbin, exe, proxy string) (string, error) {
var meshNodeID string
meshInstallArgs := []string{"-fullinstall"}
if len(proxy) > 0 {
meshProxy := fmt.Sprintf("--WebProxy=%s", proxy)
meshInstallArgs = append(meshInstallArgs, meshProxy)
}
//a.Logger.Debugln("Mesh install args:", meshInstallArgs)
meshOut, meshErr := system.CMD(meshbin, meshInstallArgs, int(90), false)
if meshErr != nil {
fmt.Println(meshOut[0])
fmt.Println(meshOut[1])
fmt.Println(meshErr)
}
fmt.Println(meshOut)
//a.Logger.Debugln("Sleeping for 5")
time.Sleep(5 * time.Second)
meshSuccess := false
for !meshSuccess {
//a.Logger.Debugln("Getting mesh node id")
pMesh, pErr := system.CMD(exe, []string{"-nodeid"}, int(30), false)
if pErr != nil {
//a.Logger.Errorln(pErr)
time.Sleep(5 * time.Second)
continue
}
if pMesh[1] != "" {
//a.Logger.Errorln(pMesh[1])
time.Sleep(5 * time.Second)
continue
}
meshNodeID = utils.StripAll(pMesh[0])
//a.Logger.Debugln("Node id:", meshNodeID)
if strings.Contains(strings.ToLower(meshNodeID), "not defined") {
//a.Logger.Errorln(meshNodeID)
time.Sleep(5 * time.Second)
continue
}
meshSuccess = true
}
return meshNodeID, nil
}

View file

@ -0,0 +1,7 @@
package mesh
type MeshNodeID struct {
Func string `json:"func"`
Agentid string `json:"agent_id"`
NodeID string `json:"nodeid"`
}

View file

@ -1,46 +1 @@
package tactical
import (
"time"
"github.com/amidaware/rmmagent/agent/utils"
"github.com/amidaware/rmmagent/shared"
"github.com/go-resty/resty/v2"
)
func PostRequest(url string, body interface{}, timeout time.Duration) (resty.Response, error) {
agentConfig := NewAgentConfig()
client := resty.New()
client.SetBaseURL(agentConfig.BaseURL)
client.SetTimeout(timeout * time.Second)
client.SetCloseConnection(true)
if len(agentConfig.Proxy) > 0 {
client.SetProxy(agentConfig.Proxy)
}
if shared.DEBUG {
client.SetDebug(true)
}
response, err := client.R().SetBody(body).Post(url)
return *response, err
}
func SyncMeshNodeID() bool {
id, err := GetMeshNodeID()
if err != nil {
//a.Logger.Errorln("SyncMeshNodeID() getMeshNodeID()", err)
return false
}
agentConfig := NewAgentConfig()
payload := shared.MeshNodeID{
Func: "syncmesh",
Agentid: agentConfig.AgentID,
NodeID: utils.StripAll(id),
}
_, err = PostRequest("/api/v3/syncmesh/", payload, 15)
return err == nil
}

View file

@ -1,21 +1,19 @@
package tactical
import (
"errors"
"fmt"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
"github.com/amidaware/rmmagent/agent/patching"
"github.com/amidaware/rmmagent/agent/services"
"github.com/amidaware/rmmagent/agent/software"
"github.com/amidaware/rmmagent/agent/system"
"github.com/amidaware/rmmagent/agent/tactical/config"
"github.com/amidaware/rmmagent/agent/tactical/rpc"
"github.com/amidaware/rmmagent/agent/tasks"
"github.com/amidaware/rmmagent/agent/utils"
@ -27,87 +25,6 @@ import (
"golang.org/x/sys/windows/registry"
)
func GetMeshBinary() string {
var MeshSysBin string
ac := NewAgentConfig()
if len(ac.CustomMeshDir) > 0 {
MeshSysBin = filepath.Join(ac.CustomMeshDir, "MeshAgent.exe")
} else {
MeshSysBin = filepath.Join(os.Getenv("ProgramFiles"), "Mesh Agent", "MeshAgent.exe")
}
return MeshSysBin
}
func NewAgentConfig() *rmm.AgentConfig {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, `SOFTWARE\TacticalRMM`, registry.ALL_ACCESS)
if err != nil {
return &rmm.AgentConfig{}
}
baseurl, _, _ := k.GetStringValue("BaseURL")
agentid, _, _ := k.GetStringValue("AgentID")
apiurl, _, _ := k.GetStringValue("ApiURL")
token, _, _ := k.GetStringValue("Token")
agentpk, _, _ := k.GetStringValue("AgentPK")
pk, _ := strconv.Atoi(agentpk)
cert, _, _ := k.GetStringValue("Cert")
proxy, _, _ := k.GetStringValue("Proxy")
customMeshDir, _, _ := k.GetStringValue("MeshDir")
return &rmm.AgentConfig{
BaseURL: baseurl,
AgentID: agentid,
APIURL: apiurl,
Token: token,
AgentPK: agentpk,
PK: pk,
Cert: cert,
Proxy: proxy,
CustomMeshDir: customMeshDir,
}
}
func GetMeshNodeID() (string, error) {
out, err := system.CMD(GetMeshBinary(), []string{"-nodeid"}, 10, false)
if err != nil {
//a.Logger.Debugln(err)
return "", err
}
stdout := out[0]
stderr := out[1]
if stderr != "" {
//a.Logger.Debugln(stderr)
return "", err
}
if stdout == "" || strings.Contains(strings.ToLower(utils.StripAll(stdout)), "not defined") {
//a.Logger.Debugln("Failed getting mesh node id", stdout)
return "", errors.New("failed to get mesh node id")
}
return stdout, nil
}
func SendSoftware() {
sw := software.GetInstalledSoftware()
//a.Logger.Debugln(sw)
config := NewAgentConfig()
payload := map[string]interface{}{
"agent_id": config.AgentID,
"software": sw,
}
_, err := PostRequest("/api/v3/software/", payload, 15)
if err != nil {
//a.Logger.Debugln(err)
}
}
func UninstallCleanup() {
registry.DeleteKey(registry.LOCAL_MACHINE, `SOFTWARE\TacticalRMM`)
patching.PatchMgmnt(false)
@ -123,7 +40,7 @@ func AgentUpdate(url, inno, version string) {
//a.Logger.Infof("Agent updating from %s to %s", a.Version, version)
//a.Logger.Infoln("Downloading agent update from", url)
config := NewAgentConfig()
config := config.NewAgentConfig()
rClient := resty.New()
rClient.SetCloseConnection(true)
rClient.SetTimeout(15 * time.Minute)

View file

@ -14,8 +14,6 @@ import (
"github.com/amidaware/rmmagent/shared"
"github.com/go-resty/resty/v2"
"github.com/shirou/gopsutil/v3/process"
trmm "github.com/wh1te909/trmm-shared"
)
func CaptureOutput(f func()) string {
@ -94,37 +92,17 @@ func StripAll(s string) string {
return s
}
// KillProc kills a process and its children
func KillProc(pid int32) error {
p, err := process.NewProcess(pid)
if err != nil {
return err
}
children, err := p.Children()
if err == nil {
for _, child := range children {
if err := child.Kill(); err != nil {
continue
}
}
}
if err := p.Kill(); err != nil {
return err
}
return nil
}
func CreateTRMMTempDir() {
func CreateTRMMTempDir() error {
// create the temp dir for running scripts
dir := filepath.Join(os.TempDir(), "trmm")
if !trmm.FileExists(dir) {
if !FileExists(dir) {
err := os.Mkdir(dir, 0775)
if err != nil {
//a.Logger.Errorln(err)
return err
}
}
return nil
}
func RandRange(min, max int) int {
@ -199,3 +177,12 @@ func BytesToString(b []byte) (string, uint32) {
}
return string(utf16.Decode(s)), uint32(i * 2)
}
func FileExists(path string) bool {
if _, err := os.Stat(path); err != nil {
if os.IsNotExist(err) {
return false
}
}
return true
}