task tests

This commit is contained in:
redanthrax 2022-06-20 22:13:37 -07:00
parent 7998bbe0eb
commit 4e656c6556
3 changed files with 85 additions and 15 deletions

View file

@ -1,2 +0,0 @@
package tasks

View file

@ -11,18 +11,15 @@ import (
) )
func CreateSchedTask(st SchedTask) (bool, error) { func CreateSchedTask(st SchedTask) (bool, error) {
//a.Logger.Debugf("%+v\n", st)
conn, err := taskmaster.Connect() conn, err := taskmaster.Connect()
if err != nil { if err != nil {
//a.Logger.Errorln(err)
return false, err return false, err
} }
defer conn.Disconnect()
defer conn.Disconnect()
var trigger taskmaster.Trigger var trigger taskmaster.Trigger
var action taskmaster.ExecAction var action taskmaster.ExecAction
var tasktrigger taskmaster.TaskTrigger var tasktrigger taskmaster.TaskTrigger
var now = time.Now() var now = time.Now()
if st.Trigger == "manual" { if st.Trigger == "manual" {
tasktrigger = taskmaster.TaskTrigger{ tasktrigger = taskmaster.TaskTrigger{
@ -47,7 +44,6 @@ func CreateSchedTask(st SchedTask) (bool, error) {
var path, workdir, args string var path, workdir, args string
def := conn.NewTaskDefinition() def := conn.NewTaskDefinition()
switch st.Trigger { switch st.Trigger {
case "runonce": case "runonce":
trigger = taskmaster.TimeTrigger{ trigger = taskmaster.TimeTrigger{
@ -96,10 +92,11 @@ func CreateSchedTask(st SchedTask) (bool, error) {
def.AddTrigger(trigger) def.AddTrigger(trigger)
currentDir, _ := os.Getwd()
switch st.Type { switch st.Type {
case "rmm": case "rmm":
path = winExeName path = os.Args[0]
workdir = a.ProgramDir workdir = currentDir
args = fmt.Sprintf("-m taskrunner -p %d", st.PK) args = fmt.Sprintf("-m taskrunner -p %d", st.PK)
case "schedreboot": case "schedreboot":
path = "shutdown.exe" path = "shutdown.exe"
@ -116,8 +113,8 @@ func CreateSchedTask(st SchedTask) (bool, error) {
WorkingDir: workdir, WorkingDir: workdir,
Args: args, Args: args,
} }
def.AddAction(action)
def.AddAction(action)
def.Principal.RunLevel = taskmaster.TASK_RUNLEVEL_HIGHEST def.Principal.RunLevel = taskmaster.TASK_RUNLEVEL_HIGHEST
def.Principal.LogonType = taskmaster.TASK_LOGON_SERVICE_ACCOUNT def.Principal.LogonType = taskmaster.TASK_LOGON_SERVICE_ACCOUNT
def.Principal.UserID = "SYSTEM" def.Principal.UserID = "SYSTEM"
@ -139,7 +136,6 @@ func CreateSchedTask(st SchedTask) (bool, error) {
_, success, err := conn.CreateTask(fmt.Sprintf("\\%s", st.Name), def, st.Overwrite) _, success, err := conn.CreateTask(fmt.Sprintf("\\%s", st.Name), def, st.Overwrite)
if err != nil { if err != nil {
//a.Logger.Errorln(err)
return false, err return false, err
} }
@ -161,17 +157,17 @@ func DeleteSchedTask(name string) error {
return nil return nil
} }
func ListSchedTasks() []string { func ListSchedTasks() ([]string, error) {
ret := make([]string, 0) ret := make([]string, 0)
conn, err := taskmaster.Connect() conn, err := taskmaster.Connect()
if err != nil { if err != nil {
return ret return ret, err
} }
defer conn.Disconnect() defer conn.Disconnect()
tasks, err := conn.GetRegisteredTasks() tasks, err := conn.GetRegisteredTasks()
if err != nil { if err != nil {
return ret return ret, err
} }
for _, task := range tasks { for _, task := range tasks {
@ -179,7 +175,7 @@ func ListSchedTasks() []string {
} }
tasks.Release() tasks.Release()
return ret return ret, nil
} }
// CleanupSchedTasks removes all tacticalrmm sched tasks during uninstall // CleanupSchedTasks removes all tacticalrmm sched tasks during uninstall

View file

@ -0,0 +1,76 @@
package tasks_test
import (
"errors"
"testing"
"github.com/amidaware/rmmagent/agent/tasks"
"github.com/amidaware/taskmaster"
)
func TestCreateSchedTask(t *testing.T) {
testTask := tasks.SchedTask{
PK: 0,
Name: "Test Task",
Trigger: "manual",
Enabled: false,
Type: "rmm",
TaskPolicy: taskmaster.TASK_INSTANCES_IGNORE_NEW,
DeleteAfter: true,
Overwrite: true,
}
testTable := []struct {
name string
expected bool
expectedError error
}{
{
name: "Create Sched Task",
expected: true,
expectedError: nil,
},
}
for _, tt := range testTable {
t.Run(tt.name, func(t *testing.T) {
result, err := tasks.CreateSchedTask(testTask)
if !result {
t.Errorf("Expected %t, got %t", tt.expected, result)
}
if !errors.Is(tt.expectedError, err) {
t.Errorf("expected (%v), got (%v)", tt.expectedError, err)
}
})
}
}
func TestListSchedTasks(t *testing.T) {
testTable := []struct {
name string
expected []string
atLeast int
expectedError error
}{
{
name: "List Sched Task",
expected: []string{},
atLeast: 1,
expectedError: nil,
},
}
for _, tt := range testTable {
t.Run(tt.name, func(t *testing.T) {
result, err := tasks.ListSchedTasks()
if len(result) < tt.atLeast {
t.Errorf("expect at least %d, got %d", tt.atLeast, len(result))
}
if !errors.Is(tt.expectedError, err) {
t.Errorf("expected (%v), got (%v)", tt.expectedError, err)
}
})
}
}