Skip to content

Commit

Permalink
TAS: Fix handling of ResourceFlavor tolerations (#3723)
Browse files Browse the repository at this point in the history
* TAS: Fix handling of ResourceFlavor tolerations

* Review remarks
  • Loading branch information
mimowo authored Dec 3, 2024
1 parent a5dbb05 commit 4e6d3f3
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pkg/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3816,7 +3816,7 @@ func TestSnapshotError(t *testing.T) {
cache := New(client)
cache.AddOrUpdateResourceFlavor(&flavor)
if flavor.Spec.TopologyName != nil {
tasFlavorCache := cache.tasCache.NewTASFlavorCache(*flavor.Spec.TopologyName, []string{corev1.LabelHostname}, flavor.Spec.NodeLabels)
tasFlavorCache := cache.tasCache.NewTASFlavorCache(*flavor.Spec.TopologyName, []string{corev1.LabelHostname}, flavor.Spec.NodeLabels, flavor.Spec.Tolerations)
cache.tasCache.Set(kueue.ResourceFlavorReference(flavor.Name), tasFlavorCache)
}
if err := cache.AddClusterQueue(ctx, &clusterQueue); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/cache/tas_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,7 @@ func TestFindTopologyAssignment(t *testing.T) {
client := clientBuilder.Build()

tasCache := NewTASCache(client)
tasFlavorCache := tasCache.NewTASFlavorCache("default", tc.levels, tc.nodeLabels)
tasFlavorCache := tasCache.NewTASFlavorCache("default", tc.levels, tc.nodeLabels, tc.tolerations)

snapshot, err := tasFlavorCache.snapshot(ctx)
if err != nil {
Expand Down
10 changes: 8 additions & 2 deletions pkg/cache/tas_flavor.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,22 @@ type TASFlavorCache struct {
// by the flavor corresponding to the cache.
Levels []string

// tolerations represents the list of tolerations specified for the resource
// flavor
Tolerations []corev1.Toleration

// usage maintains the usage per topology domain
usage map[utiltas.TopologyDomainID]resources.Requests
}

func (t *TASCache) NewTASFlavorCache(topologyName kueue.TopologyReference, levels []string, nodeLabels map[string]string) *TASFlavorCache {
func (t *TASCache) NewTASFlavorCache(topologyName kueue.TopologyReference, levels []string, nodeLabels map[string]string,
tolerations []corev1.Toleration) *TASFlavorCache {
return &TASFlavorCache{
client: t.client,
TopologyName: topologyName,
Levels: slices.Clone(levels),
NodeLabels: maps.Clone(nodeLabels),
Tolerations: slices.Clone(tolerations),
usage: make(map[utiltas.TopologyDomainID]resources.Requests),
}
}
Expand Down Expand Up @@ -111,7 +117,7 @@ func (c *TASFlavorCache) snapshotForNodes(log logr.Logger, nodes []corev1.Node,

log.V(3).Info("Constructing TAS snapshot", "nodeLabels", c.NodeLabels,
"levels", c.Levels, "nodeCount", len(nodes), "podCount", len(pods))
snapshot := newTASFlavorSnapshot(log, c.TopologyName, c.Levels)
snapshot := newTASFlavorSnapshot(log, c.TopologyName, c.Levels, c.Tolerations)
nodeToDomain := make(map[string]utiltas.TopologyDomainID)
for _, node := range nodes {
nodeToDomain[node.Name] = snapshot.addNode(node)
Expand Down
11 changes: 8 additions & 3 deletions pkg/cache/tas_flavor_snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,13 @@ type TASFlavorSnapshot struct {

// domainsPerLevel stores the static tree information
domainsPerLevel []domainByID

// tolerations represents the list of tolerations defined for the resource flavor
tolerations []corev1.Toleration
}

func newTASFlavorSnapshot(log logr.Logger, topologyName kueue.TopologyReference, levels []string) *TASFlavorSnapshot {
func newTASFlavorSnapshot(log logr.Logger, topologyName kueue.TopologyReference,
levels []string, tolerations []corev1.Toleration) *TASFlavorSnapshot {
domainsPerLevel := make([]domainByID, len(levels))
for level := range levels {
domainsPerLevel[level] = make(domainByID)
Expand All @@ -115,6 +119,7 @@ func newTASFlavorSnapshot(log logr.Logger, topologyName kueue.TopologyReference,
topologyName: topologyName,
levelKeys: slices.Clone(levels),
leaves: make(leafDomainByID),
tolerations: slices.Clone(tolerations),
domains: make(domainByID),
roots: make(domainByID),
domainsPerLevel: domainsPerLevel,
Expand Down Expand Up @@ -223,7 +228,7 @@ func (s *TASFlavorSnapshot) FindTopologyAssignment(
topologyRequest *kueue.PodSetTopologyRequest,
requests resources.Requests,
count int32,
tolerations []corev1.Toleration) (*kueue.TopologyAssignment, string) {
podSetTolerations []corev1.Toleration) (*kueue.TopologyAssignment, string) {
required := topologyRequest.Required != nil
key := levelKey(topologyRequest)
if key == nil {
Expand All @@ -234,7 +239,7 @@ func (s *TASFlavorSnapshot) FindTopologyAssignment(
return nil, fmt.Sprintf("no requested topology level: %s", *key)
}
// phase 1 - determine the number of pods which can fit in each topology domain
s.fillInCounts(requests, tolerations)
s.fillInCounts(requests, append(podSetTolerations, s.tolerations...))

// phase 2a: determine the level at which the assignment is done along with
// the domains which can accommodate all pods
Expand Down
2 changes: 1 addition & 1 deletion pkg/controller/tas/resource_flavor.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func (r *rfReconciler) Reconcile(ctx context.Context, req reconcile.Request) (re
return reconcile.Result{}, client.IgnoreNotFound(err)
}
levels := utiltas.Levels(&topology)
tasInfo := r.tasCache.NewTASFlavorCache(kueue.TopologyReference(topology.Name), levels, flv.Spec.NodeLabels)
tasInfo := r.tasCache.NewTASFlavorCache(kueue.TopologyReference(topology.Name), levels, flv.Spec.NodeLabels, flv.Spec.Tolerations)
r.tasCache.Set(flavorReference, tasInfo)
}

Expand Down
84 changes: 83 additions & 1 deletion pkg/scheduler/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4632,6 +4632,88 @@ func TestScheduleForTAS(t *testing.T) {
},
},
},
"scheduling workload on a tainted node when the toleration is on ResourceFlavor": {
nodes: []corev1.Node{
*testingnode.MakeNode("x1").
Label("tas-node", "true").
Label(corev1.LabelHostname, "x1").
StatusAllocatable(corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("1"),
}).
Taints(corev1.Taint{
Key: "example.com/gpu",
Value: "present",
Effect: corev1.TaintEffectNoSchedule,
}).
Ready().
Obj(),
},
topologies: []kueuealpha.Topology{defaultSingleLevelTopology},
resourceFlavors: []kueue.ResourceFlavor{{
ObjectMeta: metav1.ObjectMeta{
Name: "tas-default",
},
Spec: kueue.ResourceFlavorSpec{
NodeLabels: map[string]string{
"tas-node": "true",
},
Tolerations: []corev1.Toleration{
{
Key: "example.com/gpu",
Operator: corev1.TolerationOpExists,
},
},
TopologyName: ptr.To[kueue.TopologyReference]("tas-single-level"),
},
}},
clusterQueues: []kueue.ClusterQueue{
*utiltesting.MakeClusterQueue("tas-main").
ResourceGroup(
*utiltesting.MakeFlavorQuotas("tas-default").
Resource(corev1.ResourceCPU, "50").Obj()).
Obj(),
},
workloads: []kueue.Workload{
*utiltesting.MakeWorkload("foo", "default").
Queue("tas-main").
PodSets(
*utiltesting.MakePodSet("one", 1).
PreferredTopologyRequest(corev1.LabelHostname).
Request(corev1.ResourceCPU, "1").
Obj(),
).
Obj(),
},
wantNewAssignments: map[string]kueue.Admission{
"default/foo": *utiltesting.MakeAdmission("tas-main", "one").
Assignment(corev1.ResourceCPU, "tas-default", "1000m").
AssignmentPodCount(1).
TopologyAssignment(&kueue.TopologyAssignment{
Levels: utiltas.Levels(&defaultSingleLevelTopology),
Domains: []kueue.TopologyDomainAssignment{
{
Count: 1,
Values: []string{
"x1",
},
},
},
}).Obj(),
},
eventCmpOpts: []cmp.Option{eventIgnoreMessage},
wantEvents: []utiltesting.EventRecord{
{
Key: types.NamespacedName{Namespace: "default", Name: "foo"},
Reason: "QuotaReserved",
EventType: corev1.EventTypeNormal,
},
{
Key: types.NamespacedName{Namespace: "default", Name: "foo"},
Reason: "Admitted",
EventType: corev1.EventTypeNormal,
},
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
Expand Down Expand Up @@ -4662,7 +4744,7 @@ func TestScheduleForTAS(t *testing.T) {
t := topologyByName[*flavor.Spec.TopologyName]
tasCache := cqCache.TASCache()
levels := utiltas.Levels(&t)
tasFlavorCache := tasCache.NewTASFlavorCache(*flavor.Spec.TopologyName, levels, flavor.Spec.NodeLabels)
tasFlavorCache := tasCache.NewTASFlavorCache(*flavor.Spec.TopologyName, levels, flavor.Spec.NodeLabels, flavor.Spec.Tolerations)
tasCache.Set(kueue.ResourceFlavorReference(flavor.Name), tasFlavorCache)
}
}
Expand Down

0 comments on commit 4e6d3f3

Please sign in to comment.