Skip to content

Commit

Permalink
Cancel task API
Browse files Browse the repository at this point in the history
  • Loading branch information
runabol committed Aug 10, 2023
1 parent 4a04b9a commit 707d26c
Show file tree
Hide file tree
Showing 12 changed files with 115 additions and 44 deletions.
26 changes: 26 additions & 0 deletions coordinator/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ func newAPI(cfg Config) *api {
}
r.GET("/status", s.status)
r.POST("/task", s.createTask)
r.PUT("/task/:id/cancel", s.cancelTask)
r.GET("/task/:id", s.getTask)
r.GET("/queue", s.listQueues)
r.GET("/node", s.listActiveNodes)
Expand Down Expand Up @@ -126,6 +127,31 @@ func (s *api) getTask(c *gin.Context) {
c.JSON(http.StatusOK, t)
}

func (s *api) cancelTask(c *gin.Context) {
id := c.Param("id")
err := s.ds.UpdateTask(c, id, func(u *task.Task) error {
if u.State != task.Running {
return errors.New("task in not running")
}
u.State = task.Cancelled
if u.Node != "" {
node, err := s.ds.GetNodeByID(c, u.Node)
if err != nil {
return err
}
if err := s.broker.PublishTask(c, node.Queue, u); err != nil {
return err
}
}
return nil
})
if err != nil {
c.AbortWithError(http.StatusBadRequest, err)
return
}
c.JSON(http.StatusOK, gin.H{"status": "OK"})
}

func (s *api) start() error {
go func() {
// service connections
Expand Down
32 changes: 22 additions & 10 deletions coordinator/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,15 @@ func (c *Coordinator) taskPendingHandler(thread string) func(ctx context.Context
if err := c.broker.PublishTask(ctx, qname, t); err != nil {
return err
}
return c.ds.UpdateTask(ctx, t.ID, func(u *task.Task) {
return c.ds.UpdateTask(ctx, t.ID, func(u *task.Task) error {
// we don't want to mark the task as SCHEDULED
// if an out-of-order task completion/failure
// arrived earlier
if u.State == task.Pending {
u.State = t.State
u.ScheduledAt = t.ScheduledAt
}
return nil
})
}
}
Expand All @@ -96,14 +97,16 @@ func (c *Coordinator) taskStartedHandler(thread string) func(ctx context.Context
Str("task-id", t.ID).
Str("thread", thread).
Msg("received task start")
return c.ds.UpdateTask(ctx, t.ID, func(u *task.Task) {
return c.ds.UpdateTask(ctx, t.ID, func(u *task.Task) error {
// we don't want to mark the task as RUNNING
// if an out-of-order task completion/failure
// arrived earlier
if u.State == task.Scheduled {
u.State = t.State
u.StartedAt = t.StartedAt
u.Node = t.Node
}
return nil
})
}
}
Expand All @@ -114,25 +117,33 @@ func (c *Coordinator) taskCompletedHandler(thread string) func(ctx context.Conte
Str("task-id", t.ID).
Str("thread", thread).
Msg("received task completion")
return c.ds.UpdateTask(ctx, t.ID, func(u *task.Task) {
return c.ds.UpdateTask(ctx, t.ID, func(u *task.Task) error {
u.State = task.Completed
u.CompletedAt = t.CompletedAt
u.Result = t.Result
return nil
})
}
}

func (c *Coordinator) taskFailedHandler(thread string) func(ctx context.Context, t *task.Task) error {
return func(ctx context.Context, t *task.Task) error {
log.Error().
Str("task-id", t.ID).
Str("task-error", t.Error).
Str("thread", thread).
Msg("received task failure")
return c.ds.UpdateTask(ctx, t.ID, func(u *task.Task) {
return c.ds.UpdateTask(ctx, t.ID, func(u *task.Task) error {
if u.State == task.Cancelled {
log.Debug().
Str("thread", thread).
Msgf("task %s was previously cancelled. ignoring error.", t.ID)
return nil
}
log.Error().
Str("task-id", t.ID).
Str("task-error", t.Error).
Str("thread", thread).
Msg("received task failure")
u.State = task.Failed
u.FailedAt = t.FailedAt
u.Error = t.Error
return nil
})
}
}
Expand All @@ -146,13 +157,14 @@ func (c *Coordinator) handleHeartbeats(ctx context.Context, n *node.Node) error
Msg("received first heartbeat")
return c.ds.CreateNode(ctx, n)
}
return c.ds.UpdateNode(ctx, n.ID, func(u *node.Node) {
return c.ds.UpdateNode(ctx, n.ID, func(u *node.Node) error {
log.Info().
Str("node-id", n.ID).
Float64("cpu-percent", n.CPUPercent).
Msg("received heartbeat")
u.LastHeartbeatAt = n.LastHeartbeatAt
u.CPUPercent = n.CPUPercent
return nil
})

}
Expand Down
4 changes: 2 additions & 2 deletions datastore/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ const (

type Datastore interface {
CreateTask(ctx context.Context, t *task.Task) error
UpdateTask(ctx context.Context, id string, modify func(u *task.Task)) error
UpdateTask(ctx context.Context, id string, modify func(u *task.Task) error) error
GetTaskByID(ctx context.Context, id string) (*task.Task, error)
CreateNode(ctx context.Context, n *node.Node) error
UpdateNode(ctx context.Context, id string, modify func(u *node.Node)) error
UpdateNode(ctx context.Context, id string, modify func(u *node.Node) error) error
GetNodeByID(ctx context.Context, id string) (*node.Node, error)
GetActiveNodes(ctx context.Context, lastHeartbeatAfter time.Time) ([]*node.Node, error)
}
12 changes: 8 additions & 4 deletions datastore/inmemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,16 @@ func (ds *InMemoryDatastore) GetTaskByID(ctx context.Context, id string) (*task.
return &t, nil
}

func (ds *InMemoryDatastore) UpdateTask(ctx context.Context, id string, modify func(t *task.Task)) error {
func (ds *InMemoryDatastore) UpdateTask(ctx context.Context, id string, modify func(u *task.Task) error) error {
ds.mu.Lock()
defer ds.mu.Unlock()
t, ok := ds.tasks[id]
if !ok {
return ErrTaskNotFound
}
modify(&t)
if err := modify(&t); err != nil {
return err
}
ds.tasks[t.ID] = t
return nil
}
Expand All @@ -63,14 +65,16 @@ func (ds *InMemoryDatastore) CreateNode(ctx context.Context, n *node.Node) error
return nil
}

func (ds *InMemoryDatastore) UpdateNode(ctx context.Context, id string, modify func(u *node.Node)) error {
func (ds *InMemoryDatastore) UpdateNode(ctx context.Context, id string, modify func(u *node.Node) error) error {
ds.mu.Lock()
defer ds.mu.Unlock()
n, ok := ds.nodes[id]
if !ok {
return ErrNodeNotFound
}
modify(&n)
if err := modify(&n); err != nil {
return err
}
ds.nodes[n.ID] = n
return nil
}
Expand Down
20 changes: 13 additions & 7 deletions datastore/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ type nodeRecord struct {
StartedAt time.Time `db:"started_at"`
LastHeartbeatAt time.Time `db:"last_heartbeat_at"`
CPUPercent float64 `db:"cpu_percent"`
Queue string `db:"queue"`
}

func (r nodeRecord) toNode() *node.Node {
Expand All @@ -40,6 +41,7 @@ func (r nodeRecord) toNode() *node.Node {
StartedAt: r.StartedAt,
CPUPercent: r.CPUPercent,
LastHeartbeatAt: r.LastHeartbeatAt,
Queue: r.Queue,
}
}

Expand Down Expand Up @@ -91,7 +93,7 @@ func (ds *PostgresDatastore) GetTaskByID(ctx context.Context, id string) (*task.
return t, nil
}

func (ds *PostgresDatastore) UpdateTask(ctx context.Context, id string, modify func(t *task.Task)) error {
func (ds *PostgresDatastore) UpdateTask(ctx context.Context, id string, modify func(t *task.Task) error) error {
tx, err := ds.db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return errors.Wrapf(err, "unable to begin tx")
Expand All @@ -104,7 +106,9 @@ func (ds *PostgresDatastore) UpdateTask(ctx context.Context, id string, modify f
if err := json.Unmarshal(tr.Serialized, t); err != nil {
return errors.Wrapf(err, "error desiralizing task")
}
modify(t)
if err := modify(t); err != nil {
return err
}
bytez, err := json.Marshal(t)
if err != nil {
return errors.Wrapf(err, "failed to serialize task")
Expand All @@ -130,17 +134,17 @@ func (ds *PostgresDatastore) UpdateTask(ctx context.Context, id string, modify f

func (ds *PostgresDatastore) CreateNode(ctx context.Context, n *node.Node) error {
q := `insert into nodes
(id,started_at,last_heartbeat_at,cpu_percent)
(id,started_at,last_heartbeat_at,cpu_percent,queue)
values
($1,$2,$3,$4)`
_, err := ds.db.Exec(q, n.ID, n.StartedAt, n.LastHeartbeatAt, n.CPUPercent)
($1,$2,$3,$4,$5)`
_, err := ds.db.Exec(q, n.ID, n.StartedAt, n.LastHeartbeatAt, n.CPUPercent, n.Queue)
if err != nil {
return errors.Wrapf(err, "error inserting node to the db")
}
return nil
}

func (ds *PostgresDatastore) UpdateNode(ctx context.Context, id string, modify func(u *node.Node)) error {
func (ds *PostgresDatastore) UpdateNode(ctx context.Context, id string, modify func(u *node.Node) error) error {
tx, err := ds.db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return errors.Wrapf(err, "unable to begin tx")
Expand All @@ -150,7 +154,9 @@ func (ds *PostgresDatastore) UpdateNode(ctx context.Context, id string, modify f
return errors.Wrapf(err, "error fetching node from db")
}
n := nr.toNode()
modify(n)
if err := modify(n); err != nil {
return err
}
q := `update nodes set
last_heartbeat_at = $1,
cpu_percent = $2
Expand Down
1 change: 1 addition & 0 deletions db/postgres/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ CREATE INDEX idx_tasks_serialized ON tasks USING GIN (serialized);

CREATE TABLE nodes (
id varchar(64) not null primary key,
queue varchar(64) not null,
started_at timestamp not null,
last_heartbeat_at timestamp not null,
cpu_percent float not null
Expand Down
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ require (
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0
github.com/jmoiron/sqlx v1.3.5
github.com/jxskiss/base62 v1.1.0
github.com/lib/pq v1.10.9
github.com/pkg/errors v0.9.1
github.com/rabbitmq/amqp091-go v1.8.1
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jxskiss/base62 v1.1.0 h1:A5zbF8v8WXx2xixnAKD2w+abC+sIzYJX+nxmhA6HWFw=
github.com/jxskiss/base62 v1.1.0/go.mod h1:HhWAlUXvxKThfOlZbcuFzsqwtF5TcqS9ru3y5GfjWAc=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
Expand Down
1 change: 1 addition & 0 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ type Node struct {
StartedAt time.Time `json:"startedAt,omitempty"`
CPUPercent float64 `json:"cpuPercent,omitempty"`
LastHeartbeatAt time.Time `json:"lastHeartbeatAt,omitempty"`
Queue string `json:"queue,omitempty"`
}
7 changes: 0 additions & 7 deletions uuid/uuid.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package uuid

import (
"encoding/base64"
"strings"

guuid "github.com/google/uuid"
Expand All @@ -11,9 +10,3 @@ import (
func NewUUID() string {
return strings.ReplaceAll(guuid.NewString(), "-", "")
}

// NewUUID creates a new random UUID and encodes it in base64 or panics.
func NewUUIDBase64() string {
u := guuid.New()
return base64.RawURLEncoding.EncodeToString(u[:])
}
4 changes: 0 additions & 4 deletions uuid/uuid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,3 @@ import (
func TestNewUUID(t *testing.T) {
assert.Equal(t, 32, len(uuid.NewUUID()))
}

func TestNewUUIDAsBase64(t *testing.T) {
assert.Less(t, len(uuid.NewUUIDBase64()), 32)
}
Loading

0 comments on commit 707d26c

Please sign in to comment.