Skip to content

Commit

Permalink
Don't set TPU topology env vars in cloud_tpu_init.py
Browse files Browse the repository at this point in the history
This used to be necessary. However, now these are automatically set in
libtpu. Beyond being redundant, the Python logic needs to be updated
to avoid getting KeyErrors on new topologies and TPU versions, so
better to remove it.

This also moves `get_metadata` to cloud_tpu_cluster.py since it's only
used in that file now.
  • Loading branch information
skye committed Jan 13, 2023
1 parent 86ba62d commit 004b8c1
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 65 deletions.
64 changes: 0 additions & 64 deletions jax/_src/cloud_tpu_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,67 +50,3 @@ def cloud_tpu_init():
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
os.environ['TPU_ML_PLATFORM'] = 'JAX'

# If the user has set any topology-related env vars, don't set any
# automatically.
if any([
os.environ.get('CLOUD_TPU_TASK_ID', None),
os.environ.get('TPU_CHIPS_PER_HOST_BOUNDS', None),
os.environ.get('TPU_HOST_BOUNDS', None),
os.environ.get('TPU_MESH_CONTROLLER_ADDRESS', None),
os.environ.get('TPU_MESH_CONTROLLER_PORT', None),
os.environ.get('TPU_VISIBLE_DEVICES', None),
]):
return

worker_id = get_metadata('agent-worker-number')
accelerator_type = get_metadata('accelerator-type')

accelerator_type_to_host_bounds = {
'v2-8': '1,1,1',
'v2-32': '2,2,1',
'v2-128': '4,4,1',
'v2-256': '4,8,1',
'v2-512': '8,8,1',
'v3-8': '1,1,1',
'v3-32': '2,2,1',
'v3-64': '2,4,1',
'v3-128': '4,4,1',
'v3-256': '4,8,1',
'v3-512': '8,8,1',
'v3-1024': '8,16,1',
'v3-2048': '16,16,1',
}

os.environ['CLOUD_TPU_TASK_ID'] = worker_id

# If v4 TPU don't set any topology related flags, libtpu will set these values.
if not accelerator_type.startswith('v4-'):
os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'
os.environ['TPU_HOST_BOUNDS'] = accelerator_type_to_host_bounds[
accelerator_type]


def get_metadata(key):
import requests # pytype: disable=import-error
import time # pytype: disable=import-error
# Based on https://github.com/tensorflow/tensorflow/pull/40317
gce_metadata_endpoint = 'http://' + os.environ.get(
'GCE_METADATA_IP', 'metadata.google.internal')

retry_count = 0
retrySeconds = 0.500
api_resp = None

while retry_count < 6:
api_resp = requests.get(
f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
headers={'Metadata-Flavor': 'Google'})
if api_resp.status_code == 200:
break
retry_count += 1
time.sleep(retrySeconds)

if api_resp is None:
raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
return api_resp.text
29 changes: 28 additions & 1 deletion jax/_src/clusters/cloud_tpu_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional
from jax._src.clusters import ClusterEnv
from jax._src.lib import xla_bridge
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm, get_metadata
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm


def get_metadata(key):
import requests # pytype: disable=import-error
import time # pytype: disable=import-error
# Based on https://github.com/tensorflow/tensorflow/pull/40317
gce_metadata_endpoint = 'http://' + os.environ.get(
'GCE_METADATA_IP', 'metadata.google.internal')

retry_count = 0
retrySeconds = 0.500
api_resp = None

while retry_count < 6:
api_resp = requests.get(
f'{gce_metadata_endpoint}/computeMetadata/v1/instance/attributes/{key}',
headers={'Metadata-Flavor': 'Google'})
if api_resp.status_code == 200:
break
retry_count += 1
time.sleep(retrySeconds)

if api_resp is None:
raise RuntimeError(f"Getting metadata['{key}'] failed for 6 tries")
return api_resp.text


class TpuCluster(ClusterEnv):
@classmethod
Expand Down

0 comments on commit 004b8c1

Please sign in to comment.