Skip to content

Commit

Permalink
test: cover db.GroupCheckpointUUIDsByExperimentID (determined-ai#8508)
Browse files Browse the repository at this point in the history
  • Loading branch information
stoksc authored Jan 24, 2024
1 parent d661404 commit eb48302
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 1 deletion.
101 changes: 101 additions & 0 deletions master/internal/db/postgres_checkpoints_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"

"github.com/determined-ai/determined/master/pkg/etc"
"github.com/determined-ai/determined/master/pkg/model"
Expand Down Expand Up @@ -306,3 +307,103 @@ func BenchmarkUpdateCheckpointSize(b *testing.B) {

require.NoError(t, MarkCheckpointsDeleted(ctx, checkpoints))
}

// TestPgDB_GroupCheckpointUUIDsByExperimentID exercises
// db.GroupCheckpointUUIDsByExperimentID against a migrated test database.
//
// It seeds three experiments, each with one trial and three checkpoints
// attached to that trial's task, then checks that the groupings returned for
// various inputs map every checkpoint UUID to the experiment it was created
// under. Order is not significant: both experiment IDs and the checkpoint
// UUIDs within a group are compared as unordered collections.
func TestPgDB_GroupCheckpointUUIDsByExperimentID(t *testing.T) {
	require.NoError(t, etc.SetRootPath(RootFromDB))
	db := MustResolveTestPostgres(t)
	MustMigrateTestPostgres(t, db, MigrationsFromDB)

	// Setup some fake data for us to work with.
	expToCkptUUIDs := make(map[int][]uuid.UUID)
	user := RequireMockUser(t, db)
	for i := 0; i < 3; i++ {
		exp := RequireMockExperiment(t, db, user)
		_, tk := RequireMockTrial(t, db, exp)

		var ids []uuid.UUID
		for j := 0; j < 3; j++ {
			id := uuid.New()
			err := AddCheckpointMetadata(context.TODO(), &model.CheckpointV2{
				UUID:   id,
				TaskID: tk.TaskID,
			})
			require.NoError(t, err)
			ids = append(ids, id)
		}

		expToCkptUUIDs[exp.ID] = ids
	}

	type testCase struct {
		name    string
		input   []uuid.UUID
		want    map[int][]uuid.UUID
		wantErr bool
	}

	tests := []testCase{
		{
			name:  "empty is ok",
			input: []uuid.UUID{},
			want:  make(map[int][]uuid.UUID),
		},
		{
			// TODO: A missing checkpoint probably shouldn't be silently removed from the grouping.
			name:  "missing checkpoint returns an error (but it doesn't, yet)",
			input: []uuid.UUID{uuid.New()},
			want:  make(map[int][]uuid.UUID),
		},
	}

	// Any one of the seeded experiments will do for the single-experiment case.
	expID := maps.Keys(expToCkptUUIDs)[0]
	ckptUUIDs := expToCkptUUIDs[expID]
	tests = append(tests, testCase{
		name:  "grouping checkpoints but they all belong to one experiment",
		input: expToCkptUUIDs[expID],
		want:  map[int][]uuid.UUID{expID: ckptUUIDs},
	})

	tests = append(tests, testCase{
		name:  "grouping checkpoints across many experiments",
		input: flatten(maps.Values(expToCkptUUIDs)),
		want:  expToCkptUUIDs,
	})

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			groupings, err := db.GroupCheckpointUUIDsByExperimentID(tt.input)
			if tt.wantErr {
				// Stop here: without the return, the NoError assertion below
				// would fail every case that expects an error.
				require.Error(t, err)
				return
			}
			require.NoError(t, err)

			// Unpack the response into a sane format---this API just isn't very usable.
			// Each grouping carries its checkpoint UUIDs as one comma-separated string.
			got := make(map[int][]uuid.UUID)
			for _, g := range groupings {
				ckptStrs := strings.Split(g.CheckpointUUIDSStr, ",")
				var ckpts []uuid.UUID
				for _, ckptStr := range ckptStrs {
					ckpt, err := uuid.Parse(ckptStr)
					require.NoError(t, err)
					ckpts = append(ckpts, ckpt)
				}
				got[g.ExperimentID] = append(got[g.ExperimentID], ckpts...)
			}

			require.ElementsMatch(t, maps.Keys(tt.want), maps.Keys(got))
			for wantID, wantCkpts := range tt.want {
				require.ElementsMatch(t, wantCkpts, got[wantID])
			}
		})
	}
}

// flatten concatenates the slices in in, preserving order, into one slice.
// The result is nil when in holds no elements at all.
func flatten[T any](in [][]T) []T {
	var flat []T
	for idx := range in {
		flat = append(flat, in[idx]...)
	}
	return flat
}
7 changes: 7 additions & 0 deletions master/internal/db/postgres_trial.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,13 @@ func calculateNewSummaryMetrics(

// AddCheckpointMetadata persists metadata for a completed checkpoint to the database.
func AddCheckpointMetadata(ctx context.Context, m *model.CheckpointV2) error {
if m.ReportTime.IsZero() {
m.ReportTime = time.Now().UTC()
}
if m.State == "" {
m.State = model.CompletedState
}

var size int64
for _, v := range m.Resources {
size += v
Expand Down
6 changes: 5 additions & 1 deletion master/internal/rm/agentrm/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ func newAgent(
unregister func(),
) *agent {
a := &agent{
syslog: logrus.WithField("component", "agent").WithField("id", id),
syslog: logrus.WithFields(logrus.Fields{
"component": "agent",
"id": id,
"resource-pool": resourcePoolName,
}),
id: id,
registeredTime: time.Now(),
agentUpdates: agentUpdates,
Expand Down

0 comments on commit eb48302

Please sign in to comment.