From 004b8c1a09e8f56104a6d64ddafb0eed391b2059 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Fri, 13 Jan 2023 21:51:40 +0000 Subject: [PATCH] Don't set TPU topology env vars in cloud_tpu_init.py This used to be necessary. However, now these are automatically set in libtpu. Beyond being redundant, the Python logic needs to be updated to avoid getting KeyErrors on new topologies and TPU versions, so better to remove it. This also moves `get_metadata` to cloud_tpu_cluster.py since it's only used in that file now. --- jax/_src/cloud_tpu_init.py | 64 -------------------------- jax/_src/clusters/cloud_tpu_cluster.py | 29 +++++++++++- 2 files changed, 28 insertions(+), 65 deletions(-) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index beb0c04b77d5..432be617b717 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -50,67 +50,3 @@ def cloud_tpu_init(): os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu') os.environ['TPU_ML_PLATFORM'] = 'JAX' - - # If the user has set any topology-related env vars, don't set any - # automatically. - if any([ - os.environ.get('CLOUD_TPU_TASK_ID', None), - os.environ.get('TPU_CHIPS_PER_HOST_BOUNDS', None), - os.environ.get('TPU_HOST_BOUNDS', None), - os.environ.get('TPU_MESH_CONTROLLER_ADDRESS', None), - os.environ.get('TPU_MESH_CONTROLLER_PORT', None), - os.environ.get('TPU_VISIBLE_DEVICES', None), - ]): - return - - worker_id = get_metadata('agent-worker-number') - accelerator_type = get_metadata('accelerator-type') - - accelerator_type_to_host_bounds = { - 'v2-8': '1,1,1', - 'v2-32': '2,2,1', - 'v2-128': '4,4,1', - 'v2-256': '4,8,1', - 'v2-512': '8,8,1', - 'v3-8': '1,1,1', - 'v3-32': '2,2,1', - 'v3-64': '2,4,1', - 'v3-128': '4,4,1', - 'v3-256': '4,8,1', - 'v3-512': '8,8,1', - 'v3-1024': '8,16,1', - 'v3-2048': '16,16,1', - } - - os.environ['CLOUD_TPU_TASK_ID'] = worker_id - - # If v4 TPU don't set any topology related flags, libtpu will set these values. - if not accelerator_type.startswith('v4-'): - os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1' - os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[ - accelerator_type] - - -def get_metadata(key): - import requests # pytype: disable=import-error - import time # pytype: disable=import-error - # Based on https://github.com/tensorflow/tensorflow/pull/40317 - gce_metadata_endpoint = 'http://' + os.environ.get( - 'GCE_METADATA_IP', 'metadata.google.internal') - - retry_count = 0 - retrySeconds = 0.500 - api_resp = None - - while retry_count < 6: - api_resp = requests.get( - f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}', - headers={'Metadata-Flavor': 'Google'}) - if api_resp.status_code == 200: - break - retry_count += 1 - time.sleep(retrySeconds) - - if api_resp is None: - raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries") - return api_resp.text diff --git a/jax/_src/clusters/cloud_tpu_cluster.py b/jax/_src/clusters/cloud_tpu_cluster.py index d241b3f8d144..ac6161bb1a1a 100644 --- a/jax/_src/clusters/cloud_tpu_cluster.py +++ b/jax/_src/clusters/cloud_tpu_cluster.py @@ -12,10 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Optional from jax._src.clusters import ClusterEnv from jax._src.lib import xla_bridge -from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm, get_metadata +from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm + + +def get_metadata(key): + import requests # pytype: disable=import-error + import time # pytype: disable=import-error + # Based on https://github.com/tensorflow/tensorflow/pull/40317 + gce_metadata_endpoint = 'http://' + os.environ.get( + 'GCE_METADATA_IP', 'metadata.google.internal') + + retry_count = 0 + retrySeconds = 0.500 + api_resp = None + + while retry_count < 6: + api_resp = requests.get( + f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}', + headers={'Metadata-Flavor': 'Google'}) + if api_resp.status_code == 200: + break + retry_count += 1 + time.sleep(retrySeconds) + + if api_resp is None: + raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries") + return api_resp.text + class TpuCluster(ClusterEnv): @classmethod