diff --git a/ray-on-gke/kuberay-tpu-webhook/main.go b/ray-on-gke/kuberay-tpu-webhook/main.go index 6ece06310..0c9abc596 100755 --- a/ray-on-gke/kuberay-tpu-webhook/main.go +++ b/ray-on-gke/kuberay-tpu-webhook/main.go @@ -154,21 +154,33 @@ func mutatePod( container := pod.Spec.Containers[i] if(containerRequestingTPUs(container)) { path := fmt.Sprintf("/spec/containers/%d/env", i) - value := corev1.EnvVar{ + value1 := corev1.EnvVar{ Name: "TPU_WORKER_ID", Value: fmt.Sprint(tpu_worker_id), } - patch := map[string]interface{}{ + value2 := corev1.EnvVar{ + Name: "TPU_NAME", + Value: fmt.Sprint(groupName), + } + patch1 := map[string]interface{}{ + "op": "add", + } + patch2 := map[string]interface{}{ "op": "add", } if(len(container.Env) == 0) { - patch["path"] = path - patch["value"] = []corev1.EnvVar{value} + patch1["path"] = path + patch1["value"] = []corev1.EnvVar{value1} + patch2["path"] = path + patch2["value"] = []corev1.EnvVar{value2} } else { - patch["path"] = fmt.Sprintf("%s/-", path) - patch["value"] = value + patch1["path"] = fmt.Sprintf("%s/-", path) + patch1["value"] = value1 + patch2["path"] = fmt.Sprintf("%s/-", path) + patch2["value"] = value2 } - patches = append(patches, patch) + patches = append(patches, patch1) + patches = append(patches, patch2) } }