From 4e656c6556297817cbace1d3027808ffc30edff3 Mon Sep 17 00:00:00 2001 From: redanthrax Date: Mon, 20 Jun 2022 22:13:37 -0700 Subject: [PATCH] task tests --- agent/tasks/structs.go | 2 - agent/tasks/tasks_windows.go | 22 ++++----- agent/tasks/tasks_windows_test.go | 76 +++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 15 deletions(-) delete mode 100644 agent/tasks/structs.go create mode 100644 agent/tasks/tasks_windows_test.go diff --git a/agent/tasks/structs.go b/agent/tasks/structs.go deleted file mode 100644 index 5ba87cf..0000000 --- a/agent/tasks/structs.go +++ /dev/null @@ -1,2 +0,0 @@ -package tasks - diff --git a/agent/tasks/tasks_windows.go b/agent/tasks/tasks_windows.go index d5960cb..4798f8f 100644 --- a/agent/tasks/tasks_windows.go +++ b/agent/tasks/tasks_windows.go @@ -11,18 +11,15 @@ import ( ) func CreateSchedTask(st SchedTask) (bool, error) { - //a.Logger.Debugf("%+v\n", st) conn, err := taskmaster.Connect() if err != nil { - //a.Logger.Errorln(err) return false, err } - defer conn.Disconnect() + defer conn.Disconnect() var trigger taskmaster.Trigger var action taskmaster.ExecAction var tasktrigger taskmaster.TaskTrigger - var now = time.Now() if st.Trigger == "manual" { tasktrigger = taskmaster.TaskTrigger{ @@ -47,7 +44,6 @@ func CreateSchedTask(st SchedTask) (bool, error) { var path, workdir, args string def := conn.NewTaskDefinition() - switch st.Trigger { case "runonce": trigger = taskmaster.TimeTrigger{ @@ -96,10 +92,11 @@ func CreateSchedTask(st SchedTask) (bool, error) { def.AddTrigger(trigger) + currentDir, _ := os.Getwd() switch st.Type { case "rmm": - path = winExeName - workdir = a.ProgramDir + path = os.Args[0] + workdir = currentDir args = fmt.Sprintf("-m taskrunner -p %d", st.PK) case "schedreboot": path = "shutdown.exe" @@ -116,8 +113,8 @@ func CreateSchedTask(st SchedTask) (bool, error) { WorkingDir: workdir, Args: args, } - def.AddAction(action) + def.AddAction(action) def.Principal.RunLevel = taskmaster.TASK_RUNLEVEL_HIGHEST def.Principal.LogonType = taskmaster.TASK_LOGON_SERVICE_ACCOUNT def.Principal.UserID = "SYSTEM" @@ -139,7 +136,6 @@ func CreateSchedTask(st SchedTask) (bool, error) { _, success, err := conn.CreateTask(fmt.Sprintf("\\%s", st.Name), def, st.Overwrite) if err != nil { - //a.Logger.Errorln(err) return false, err } @@ -161,17 +157,17 @@ func DeleteSchedTask(name string) error { return nil } -func ListSchedTasks() []string { +func ListSchedTasks() ([]string, error) { ret := make([]string, 0) conn, err := taskmaster.Connect() if err != nil { - return ret + return ret, err } defer conn.Disconnect() tasks, err := conn.GetRegisteredTasks() if err != nil { - return ret + return ret, err } for _, task := range tasks { @@ -179,7 +175,7 @@ func ListSchedTasks() []string { } tasks.Release() - return ret + return ret, nil } // CleanupSchedTasks removes all tacticalrmm sched tasks during uninstall diff --git a/agent/tasks/tasks_windows_test.go b/agent/tasks/tasks_windows_test.go new file mode 100644 index 0000000..3a79a62 --- /dev/null +++ b/agent/tasks/tasks_windows_test.go @@ -0,0 +1,76 @@ +package tasks_test + +import ( + "errors" + "testing" + + "github.com/amidaware/rmmagent/agent/tasks" + "github.com/amidaware/taskmaster" +) + +func TestCreateSchedTask(t *testing.T) { + testTask := tasks.SchedTask{ + PK: 0, + Name: "Test Task", + Trigger: "manual", + Enabled: false, + Type: "rmm", + TaskPolicy: taskmaster.TASK_INSTANCES_IGNORE_NEW, + DeleteAfter: true, + Overwrite: true, + } + + testTable := []struct { + name string + expected bool + expectedError error + }{ + { + name: "Create Sched Task", + expected: true, + expectedError: nil, + }, + } + + for _, tt := range testTable { + t.Run(tt.name, func(t *testing.T) { + result, err := tasks.CreateSchedTask(testTask) + if !result { + t.Errorf("Expected %t, got %t", tt.expected, result) + } + + if !errors.Is(tt.expectedError, err) { + t.Errorf("expected (%v), got (%v)", tt.expectedError, err) + } + }) + } +} + +func TestListSchedTasks(t *testing.T) { + testTable := []struct { + name string + expected []string + atLeast int + expectedError error + }{ + { + name: "List Sched Task", + expected: []string{}, + atLeast: 1, + expectedError: nil, + }, + } + + for _, tt := range testTable { + t.Run(tt.name, func(t *testing.T) { + result, err := tasks.ListSchedTasks() + if len(result) < tt.atLeast { + t.Errorf("expect at least %d, got %d", tt.atLeast, len(result)) + } + + if !errors.Is(tt.expectedError, err) { + t.Errorf("expected (%v), got (%v)", tt.expectedError, err) + } + }) + } +}