Skip to content

Commit

Permalink
fix create jobState, order tasks by average duration (desc) (twitter#520
Browse files Browse the repository at this point in the history
)

* fix create jobState, order tasks by average duration (desc)

* minor style changes

* added logging
  • Loading branch information
JeanetteBruno authored Jul 14, 2021
1 parent 4003d62 commit d46edf0
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 19 deletions.
7 changes: 4 additions & 3 deletions scheduler/server/job_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (s taskStatesByDuration) Less(i, j int) bool {
// The jobState will reflect any previous progress made on this job and logged to the Sagalog
// Note: taskDurations is optional and only used to enable sorts using taskStatesByDuration above.
func newJobState(job *domain.Job, jobClass string, saga *saga.Saga, taskDurations *lru.Cache,
tasksByJobClassAndStartTimeSec map[taskClassAndStartKey]taskStateByJobIDTaskID) *jobState {
tasksByJobClassAndStartTimeSec map[taskClassAndStartKey]taskStateByJobIDTaskID, durationKeyExtractor func(string) string) *jobState {
j := &jobState{
Job: job,
Saga: saga,
Expand All @@ -88,10 +88,11 @@ func newJobState(job *domain.Job, jobClass string, saga *saga.Saga, taskDuration

for _, taskDef := range job.Def.Tasks {
var duration time.Duration
durationKey := durationKeyExtractor(taskDef.TaskID)
if taskDurations != nil {
if iface, ok := taskDurations.Get(taskDef.TaskID); !ok {
if iface, ok := taskDurations.Get(durationKey); !ok {
duration = math.MaxInt64
addOrUpdateTaskDuration(taskDurations, taskDef.TaskID, duration)
addOrUpdateTaskDuration(taskDurations, durationKey, duration)
} else {
if ad, ok := iface.(*averageDuration); ok {
duration = ad.duration
Expand Down
6 changes: 3 additions & 3 deletions scheduler/server/job_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func Test_GetUnscheduledTasks_ReturnsAllUnscheduledTasks(t *testing.T) {
jobAsBytes, _ := job.Serialize()

saga, _ := sagalogs.MakeInMemorySagaCoordinatorNoGC().MakeSaga(job.Id, jobAsBytes)
jobState := newJobState(&job, "", saga, nil, nil)
jobState := newJobState(&job, "", saga, nil, nil, nopDurationKeyExtractor)

tasks := jobState.getUnScheduledTasks()

Expand All @@ -31,7 +31,7 @@ func Test_NewJobState_PreviousProgress_StartedTasks(t *testing.T) {
for _, task := range job.Def.Tasks {
saga.StartTask(task.TaskID, nil)
}
jobState := newJobState(&job, "", saga, nil, nil)
jobState := newJobState(&job, "", saga, nil, nil, nopDurationKeyExtractor)

tasks := jobState.getUnScheduledTasks()
if len(tasks) != len(job.Def.Tasks) {
Expand All @@ -49,7 +49,7 @@ func Test_NewJobState_PreviousProgress_CompletedTasks(t *testing.T) {
saga.StartTask(task.TaskID, nil)
saga.EndTask(task.TaskID, nil)
}
jobState := newJobState(&job, "", saga, nil, nil)
jobState := newJobState(&job, "", saga, nil, nil, nopDurationKeyExtractor)

tasks := jobState.getUnScheduledTasks()
if len(tasks) != 0 {
Expand Down
27 changes: 18 additions & 9 deletions scheduler/server/stateful_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"os"
"reflect"
"sort"
"strings"
"sync"
Expand Down Expand Up @@ -89,6 +90,11 @@ func stringInSlice(a string, list []string) bool {
return false
}

// nopDurationKeyExtractor returns an unchanged key.
func nopDurationKeyExtractor(key string) string {
return key
}

// Scheduler Config variables read at initialization
// MaxRetriesPerTask - the number of times to retry a failing task before
// marking it as completed.
Expand Down Expand Up @@ -306,6 +312,11 @@ func NewStatefulScheduler(
return nil
}

dkef := durationKeyExtractorFn
if durationKeyExtractorFn == nil {
dkef = nopDurationKeyExtractor
}

sched := &statefulScheduler{
config: &config,
sagaCoord: sc,
Expand All @@ -325,7 +336,7 @@ func NewStatefulScheduler(

tasksByJobClassAndStartTimeSec: tasksByClassAndStartMap,
persistor: persistor,
durationKeyExtractorFn: durationKeyExtractorFn,
durationKeyExtractorFn: dkef,
}

sched.setThrottle(-1)
Expand Down Expand Up @@ -778,7 +789,7 @@ func (s *statefulScheduler) addJobsLoop() {

reqToClassMap, _ := s.GetRequestorToClassMap()
jc := GetRequestorClass(newJobMsg.job.Def.Requestor, reqToClassMap)
js := newJobState(newJobMsg.job, jc, newJobMsg.saga, s.taskDurations, s.tasksByJobClassAndStartTimeSec)
js := newJobState(newJobMsg.job, jc, newJobMsg.saga, s.taskDurations, s.tasksByJobClassAndStartTimeSec, s.durationKeyExtractorFn)
s.inProgressJobs = append(s.inProgressJobs, js)

sort.Sort(sort.Reverse(taskStatesByDuration(js.Tasks)))
Expand Down Expand Up @@ -910,10 +921,7 @@ func (s *statefulScheduler) scheduleTasks() {
jobState := s.getJob(jobID)
sa := jobState.Saga
rs := s.runnerFactory(nodeSt.node)
durationID := taskID
if s.durationKeyExtractorFn != nil {
durationID = s.durationKeyExtractorFn(durationID)
}
durationID := s.durationKeyExtractorFn(taskID)

preventRetries := bool(task.NumTimesTried >= s.config.MaxRetriesPerTask)

Expand Down Expand Up @@ -1276,19 +1284,20 @@ func addOrUpdateRequestorHistory(requestorHistory *lru.Cache, requestor, newHist
requestorHistory.Add(requestor, history)
}

func addOrUpdateTaskDuration(taskDurations *lru.Cache, taskId string, d time.Duration) {
func addOrUpdateTaskDuration(taskDurations *lru.Cache, durationKey string, d time.Duration) {
var ad *averageDuration
iface, ok := taskDurations.Get(taskId)
iface, ok := taskDurations.Get(durationKey)
if !ok {
ad = &averageDuration{count: 1, duration: d}
taskDurations.Add(durationKey, ad)
} else {
ad, ok = iface.(*averageDuration)
if !ok {
log.Errorf("task duration object was not *averageDuration type! (it is %s)", reflect.TypeOf(ad))
return
}
ad.update(d)
}
taskDurations.Add(taskId, ad)
}

// set the max schedulable tasks. -1 = unlimited, 0 = don't accept any more requests, >0 = only accept job
Expand Down
43 changes: 43 additions & 0 deletions scheduler/server/stateful_scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"errors"
"fmt"
"math"
"strings"
"testing"
"time"
Expand All @@ -24,6 +25,7 @@ import (
"github.com/twitter/scoot/scheduler/setup/worker"
"github.com/twitter/scoot/snapshot"
"github.com/twitter/scoot/snapshot/snapshots"
"github.com/twitter/scoot/tests/testhelpers"
)

//Mocks sometimes hang without useful output, this allows early exit with err msg.
Expand Down Expand Up @@ -881,3 +883,44 @@ func Test_StatefulScheduler_RequestorCountsStats(t *testing.T) {
assert.True(t, strings.Contains(tmp, "\"schedNumWaitingTasksGauge_fake R1\": 3"))
assert.True(t, strings.Contains(tmp, "\"schedNumWaitingTasksGauge_fake R2\": 7"))
}

// Test creating job definitions with tasks in descending duration order
func Test_TaskAssignments_TasksScheduledByDuration(t *testing.T) {
// create a test cluster with 3 nodes
testCluster := makeTestCluster("node1", "node2", "node3")
s := getDebugStatefulScheduler(testCluster)
taskKeyFn := func(key string) string {
keyParts := strings.Split(key, " ")
return keyParts[len(keyParts)-1]
}
s.durationKeyExtractorFn = taskKeyFn

// create a jobdef with 10 tasks
job := domain.GenJob(testhelpers.GenJobId(testhelpers.NewRand()), 10)

// set the scheduler's current (fake) duration data
for i := range job.Def.Tasks {
// update TaskID to match what we are really seeing from our clients (GenJob() generates different values needed by other tests)
job.Def.Tasks[i].TaskID = strings.Join(job.Def.Tasks[i].Argv, " ")
addOrUpdateTaskDuration(s.taskDurations, s.durationKeyExtractorFn(job.Def.Tasks[i].TaskID), time.Duration(i)*time.Second)
}
go func() {
// simulate checking the job and returning no error, so ScheduleJob() will put the job definition
// immediately on the addJobCh
checkJobMsg := <-s.checkJobCh
checkJobMsg.resultCh <- nil
}()

s.ScheduleJob(job.Def)
s.addJobs()

js1 := s.inProgressJobs[0]
// verify tasks are in descending duration order
for i, task := range js1.Tasks {
assert.True(t, task.AvgDuration != time.Duration(math.MaxInt64), "average duration not found for task %d, %v", i, task)
if i == 0 {
continue
}
assert.True(t, js1.Tasks[i-1].AvgDuration >= task.AvgDuration, fmt.Sprintf("tasks not in descending duration order at task %d, %v", i, task))
}
}
8 changes: 4 additions & 4 deletions scheduler/server/task_scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func Test_TaskAssignment_NoNodesAvailable(t *testing.T) {
jobAsBytes, _ := job.Serialize()

saga, _ := sagalogs.MakeInMemorySagaCoordinatorNoGC().MakeSaga(job.Id, jobAsBytes)
js := newJobState(&job, "", saga, nil, nil)
js := newJobState(&job, "", saga, nil, nil, nopDurationKeyExtractor)

// create a test cluster with no nodes
testCluster := makeTestCluster()
Expand All @@ -32,7 +32,7 @@ func Test_TaskAssignment_NoNodesAvailable(t *testing.T) {
}

func Test_TaskAssignment_NoTasks(t *testing.T) {
// create a test cluster with no nodes
// create a test cluster with 5 nodes
testCluster := makeTestCluster("node1", "node2", "node3", "node4", "node5")
s := getDebugStatefulScheduler(testCluster)
assignments := getTaskAssignments([]*jobState{}, s)
Expand All @@ -49,9 +49,9 @@ func Test_TaskAssignments_TasksScheduled(t *testing.T) {
jobAsBytes, _ := job.Serialize()

saga, _ := sagalogs.MakeInMemorySagaCoordinatorNoGC().MakeSaga(job.Id, jobAsBytes)
js := newJobState(&job, "", saga, nil, nil)
js := newJobState(&job, "", saga, nil, nil, nopDurationKeyExtractor)

// create a test cluster with no nodes
// create a test cluster with 5 nodes
testCluster := makeTestCluster("node1", "node2", "node3", "node4", "node5")
s := getDebugStatefulScheduler(testCluster)
unScheduledTasks := js.getUnScheduledTasks()
Expand Down

0 comments on commit d46edf0

Please sign in to comment.