From a3e27caa42c8653f0a4f44146632b575e0441769 Mon Sep 17 00:00:00 2001 From: Arik Cohen Date: Mon, 30 Oct 2023 00:18:31 -0400 Subject: [PATCH] Added support for --gpus --- datastore/postgres.go | 8 ++++++-- datastore/postgres_test.go | 2 ++ db/postgres/schema.go | 3 ++- go.mod | 2 ++ go.sum | 5 +++++ input/task.go | 2 ++ runtime/docker/docker.go | 19 +++++++++++++++---- task.go | 2 ++ 8 files changed, 36 insertions(+), 7 deletions(-) diff --git a/datastore/postgres.go b/datastore/postgres.go index 507a0ef1..e484a34b 100644 --- a/datastore/postgres.go +++ b/datastore/postgres.go @@ -56,6 +56,7 @@ type taskRecord struct { Each []byte `db:"each_"` SubJob []byte `db:"subjob"` SubJobID string `db:"subjob_id"` + GPUs string `db:"gpus"` } type jobRecord struct { @@ -202,6 +203,7 @@ func (r taskRecord) toTask() (*tork.Task, error) { Each: each, Description: r.Description, SubJob: subjob, + GPUs: r.GPUs, }, nil } @@ -409,12 +411,13 @@ func (ds *PostgresDatastore) CreateTask(ctx context.Context, t *tork.Task) error subjob, -- $31 networks, -- $32 files_, -- $33 - registry -- $34 + registry, -- $34 + gpus -- $35 ) values ( $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14, $15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26, - $27,$28,$29,$30,$31,$32,$33,$34)` + $27,$28,$29,$30,$31,$32,$33,$34,$35)` _, err = ds.exec(q, t.ID, // $1 t.JobID, // $2 @@ -450,6 +453,7 @@ func (ds *PostgresDatastore) CreateTask(ctx context.Context, t *tork.Task) error pq.StringArray(t.Networks), // $32 files, // $33 registry, // $34 + t.GPUs, // $35 ) if err != nil { return errors.Wrapf(err, "error inserting task to the db") diff --git a/datastore/postgres_test.go b/datastore/postgres_test.go index ca7b0ddc..5046feb0 100644 --- a/datastore/postgres_test.go +++ b/datastore/postgres_test.go @@ -35,6 +35,7 @@ func TestPostgresCreateAndGetTask(t *testing.T) { Networks: []string{"some-network"}, Files: map[string]string{"myfile": "hello world"}, Registry: &tork.Registry{Username: "me", Password: "secret"}, + GPUs: "all", } err = ds.CreateTask(ctx, &t1) assert.NoError(t, err) @@ -46,6 +47,7 @@ func TestPostgresCreateAndGetTask(t *testing.T) { assert.Equal(t, map[string]string{"myfile": "hello world"}, t2.Files) assert.Equal(t, "me", t2.Registry.Username) assert.Equal(t, "secret", t2.Registry.Password) + assert.Equal(t, "all", t2.GPUs) assert.Nil(t, t2.Parallel) } diff --git a/db/postgres/schema.go b/db/postgres/schema.go index 7ad5b850..0cdcee67 100644 --- a/db/postgres/schema.go +++ b/db/postgres/schema.go @@ -85,7 +85,8 @@ CREATE TABLE tasks ( each_ jsonb, description text, subjob jsonb, - networks text[] + networks text[], + gpus text ); CREATE INDEX idx_tasks_state ON tasks (state); diff --git a/go.mod b/go.mod index f5bec77c..f6748bf7 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ retract v0.1.0 require ( github.com/antonmedv/expr v1.15.3 + github.com/docker/cli v24.0.7+incompatible github.com/docker/docker v24.0.6+incompatible github.com/docker/go-units v0.5.0 github.com/fatih/color v1.15.0 @@ -63,6 +64,7 @@ require ( github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect + github.com/sirupsen/logrus v1.9.0 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect diff --git a/go.sum b/go.sum index a2c23263..f7a60578 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46t github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/docker/cli v24.0.7+incompatible h1:wa/nIwYFW7BVTGa7SWPVyyXU9lgORqUb1xfI36MSkFg= +github.com/docker/cli v24.0.7+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/docker v24.0.6+incompatible h1:hceabKCtUgDqPu+qm0NgsaXf28Ljf4/pWFL7xjWWDgE= @@ -120,6 +122,8 @@ github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFt github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= +github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -178,6 +182,7 @@ golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211103235746-7861aae1554b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/input/task.go b/input/task.go index b0122fa8..c720f54b 100644 --- a/input/task.go +++ b/input/task.go @@ -28,6 +28,7 @@ type Task struct { Parallel *Parallel `json:"parallel,omitempty" yaml:"parallel,omitempty"` Each *Each `json:"each,omitempty" yaml:"each,omitempty"` SubJob *SubJob `json:"subjob,omitempty" yaml:"subjob,omitempty"` + GPUs string `json:"gpus,omitempty" yaml:"gpus,omitempty"` } type Mount struct { Type string `json:"type,omitempty" yaml:"type,omitempty"` @@ -140,6 +141,7 @@ func (i Task) toTask() *tork.Task { Parallel: parallel, Each: each, SubJob: subjob, + GPUs: i.GPUs, } } diff --git a/runtime/docker/docker.go b/runtime/docker/docker.go index 94c14226..99354911 100644 --- a/runtime/docker/docker.go +++ b/runtime/docker/docker.go @@ -14,6 +14,7 @@ import ( "time" "unicode" + cliopts "github.com/docker/cli/opts" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/mount" @@ -211,13 +212,23 @@ func (d *DockerRuntime) doRun(ctx context.Context, t *tork.Task) error { return errors.Wrapf(err, "invalid memory value") } + resources := container.Resources{ + NanoCPUs: cpus, + Memory: mem, + } + + if t.GPUs != "" { + gpuOpts := cliopts.GpuOpts{} + if err := gpuOpts.Set(t.GPUs); err != nil { + return errors.Wrapf(err, "error setting GPUs") + } + resources.DeviceRequests = gpuOpts.Value() + } + hc := container.HostConfig{ PublishAllPorts: true, Mounts: mounts, - Resources: container.Resources{ - NanoCPUs: cpus, - Memory: mem, - }, + Resources: resources, } cmd := t.CMD diff --git a/task.go b/task.go index f522cbd0..320f56fe 100644 --- a/task.go +++ b/task.go @@ -58,6 +58,7 @@ type Task struct { Parallel *ParallelTask `json:"parallel,omitempty"` Each *EachTask `json:"each,omitempty"` SubJob *SubJobTask `json:"subjob,omitempty"` + GPUs string `json:"gpus,omitempty"` } type SubJobTask struct { @@ -163,6 +164,7 @@ func (t *Task) Clone() *Task { Each: each, Description: t.Description, SubJob: subjob, + GPUs: t.GPUs, } }