Skip to content

Commit

Permalink
Added support for --gpus
Browse files (browse the repository at this point in the history)
  • Loading branch information
runabol committed Oct 30, 2023
1 parent 29f49f0 commit a3e27ca
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 7 deletions.
8 changes: 6 additions & 2 deletions datastore/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type taskRecord struct {
Each []byte `db:"each_"`
SubJob []byte `db:"subjob"`
SubJobID string `db:"subjob_id"`
GPUs string `db:"gpus"`
}

type jobRecord struct {
Expand Down Expand Up @@ -202,6 +203,7 @@ func (r taskRecord) toTask() (*tork.Task, error) {
Each: each,
Description: r.Description,
SubJob: subjob,
GPUs: r.GPUs,
}, nil
}

Expand Down Expand Up @@ -409,12 +411,13 @@ func (ds *PostgresDatastore) CreateTask(ctx context.Context, t *tork.Task) error
subjob, -- $31
networks, -- $32
files_, -- $33
-registry -- $34
+registry, -- $34
+gpus -- $35
)
values (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,
$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,
-$27,$28,$29,$30,$31,$32,$33,$34)`
+$27,$28,$29,$30,$31,$32,$33,$34,$35)`
_, err = ds.exec(q,
t.ID, // $1
t.JobID, // $2
Expand Down Expand Up @@ -450,6 +453,7 @@ func (ds *PostgresDatastore) CreateTask(ctx context.Context, t *tork.Task) error
pq.StringArray(t.Networks), // $32
files, // $33
registry, // $34
t.GPUs, // $35
)
if err != nil {
return errors.Wrapf(err, "error inserting task to the db")
Expand Down
2 changes: 2 additions & 0 deletions datastore/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func TestPostgresCreateAndGetTask(t *testing.T) {
Networks: []string{"some-network"},
Files: map[string]string{"myfile": "hello world"},
Registry: &tork.Registry{Username: "me", Password: "secret"},
GPUs: "all",
}
err = ds.CreateTask(ctx, &t1)
assert.NoError(t, err)
Expand All @@ -46,6 +47,7 @@ func TestPostgresCreateAndGetTask(t *testing.T) {
assert.Equal(t, map[string]string{"myfile": "hello world"}, t2.Files)
assert.Equal(t, "me", t2.Registry.Username)
assert.Equal(t, "secret", t2.Registry.Password)
assert.Equal(t, "all", t2.GPUs)
assert.Nil(t, t2.Parallel)
}

Expand Down
3 changes: 2 additions & 1 deletion db/postgres/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ CREATE TABLE tasks (
each_ jsonb,
description text,
subjob jsonb,
-networks text[]
+networks text[],
+gpus text
);
CREATE INDEX idx_tasks_state ON tasks (state);
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ retract v0.1.0

require (
github.com/antonmedv/expr v1.15.3
github.com/docker/cli v24.0.7+incompatible
github.com/docker/docker v24.0.6+incompatible
github.com/docker/go-units v0.5.0
github.com/fatih/color v1.15.0
Expand Down Expand Up @@ -63,6 +64,7 @@ require (
github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/sirupsen/logrus v1.9.0 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
Expand Down
5 changes: 5 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/docker/cli v24.0.7+incompatible h1:wa/nIwYFW7BVTGa7SWPVyyXU9lgORqUb1xfI36MSkFg=
github.com/docker/cli v24.0.7+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8=
github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
github.com/docker/docker v24.0.6+incompatible h1:hceabKCtUgDqPu+qm0NgsaXf28Ljf4/pWFL7xjWWDgE=
Expand Down Expand Up @@ -120,6 +122,8 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0=
github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
Expand Down Expand Up @@ -178,6 +182,7 @@ golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
Expand Down
2 changes: 2 additions & 0 deletions input/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Task struct {
Parallel *Parallel `json:"parallel,omitempty" yaml:"parallel,omitempty"`
Each *Each `json:"each,omitempty" yaml:"each,omitempty"`
SubJob *SubJob `json:"subjob,omitempty" yaml:"subjob,omitempty"`
GPUs string `json:"gpus,omitempty" yaml:"gpus,omitempty"`
}
type Mount struct {
Type string `json:"type,omitempty" yaml:"type,omitempty"`
Expand Down Expand Up @@ -140,6 +141,7 @@ func (i Task) toTask() *tork.Task {
Parallel: parallel,
Each: each,
SubJob: subjob,
GPUs: i.GPUs,
}
}

Expand Down
19 changes: 15 additions & 4 deletions runtime/docker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"time"
"unicode"

cliopts "github.com/docker/cli/opts"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/mount"
Expand Down Expand Up @@ -211,13 +212,23 @@ func (d *DockerRuntime) doRun(ctx context.Context, t *tork.Task) error {
return errors.Wrapf(err, "invalid memory value")
}

resources := container.Resources{
NanoCPUs: cpus,
Memory: mem,
}

if t.GPUs != "" {
gpuOpts := cliopts.GpuOpts{}
if err := gpuOpts.Set(t.GPUs); err != nil {
return errors.Wrapf(err, "error setting GPUs")
}
resources.DeviceRequests = gpuOpts.Value()
}

hc := container.HostConfig{
PublishAllPorts: true,
Mounts: mounts,
-Resources: container.Resources{
-NanoCPUs: cpus,
-Memory: mem,
-},
+Resources: resources,
}

cmd := t.CMD
Expand Down
2 changes: 2 additions & 0 deletions task.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type Task struct {
Parallel *ParallelTask `json:"parallel,omitempty"`
Each *EachTask `json:"each,omitempty"`
SubJob *SubJobTask `json:"subjob,omitempty"`
GPUs string `json:"gpus,omitempty"`
}

type SubJobTask struct {
Expand Down Expand Up @@ -163,6 +164,7 @@ func (t *Task) Clone() *Task {
Each: each,
Description: t.Description,
SubJob: subjob,
GPUs: t.GPUs,
}
}

Expand Down

0 comments on commit a3e27ca

Please sign in to comment.