[Train] Allow downloading checkpoints from S3 from GCP for llama-2 workspace (ray-project#38154)

* Use aws cli instead of awsv2
* Fix the time logging metrics
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
kouroshHakha authored Aug 7, 2023
1 parent 1df5ac5 commit 769afd1
Showing 6 changed files with 12 additions and 26 deletions.
@@ -1,13 +1,13 @@
 head_node_type:
   name: head_node_type
-  instance_type: n1-highmem-64-nvidia-k80-12gb-1
+  instance_type: g2-standard-32-nvidia-l4-1
   resources:
     custom_resources:
       large_cpu_mem: 1
 
 worker_node_types:
 - name: gpu_worker
-  instance_type: n1-standard-16-nvidia-k80-12gb-1
+  instance_type: g2-standard-16-nvidia-l4-1
   min_workers: 15
   max_workers: 15
   use_spot: false
@@ -35,8 +35,6 @@
 OPTIM_EPS = 1e-8
 OPTIM_WEIGHT_DECAY = 0.0
 
-MIRROR_LINK = "s3://llama-2-weights/"
-
 
 def get_number_of_params(model: nn.Module):
     state_dict = model.state_dict()
@@ -60,7 +58,9 @@ def collate_fn(batch, tokenizer, block_size, device):
 
 def get_pretrained_path(model_id: str):
     mirror_uri = get_mirror_link(model_id)
-    ckpt_path, _ = get_checkpoint_and_refs_dir(model_id=model_id, bucket_uri=mirror_uri)
+    ckpt_path, _ = get_checkpoint_and_refs_dir(
+        model_id=model_id, bucket_uri=mirror_uri, s3_sync_args=["--no-sign-request"]
+    )
     return ckpt_path
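For context, --no-sign-request makes the AWS CLI send anonymous (unsigned) requests, which is what lets a GCP node with no AWS credentials read the public llama-2-weights bucket. A minimal standalone sketch of the equivalent call (the local destination path is illustrative, not from this repo):

```python
import subprocess

# Anonymous download from a public S3 bucket; no AWS credentials needed,
# so this also works from a GCP instance.
subprocess.run(
    ["aws", "s3", "sync", "--no-sign-request",
     "s3://llama-2-weights/", "/tmp/llama-2-weights"],
    check=True,  # raise if the CLI call fails
)
```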


@@ -277,8 +277,9 @@ def training_function(kwargs: dict):
     print("Starting training ...")
     print("Number of batches on main process", train_ds_len // batch_size)
 
-    fwd_time_sum, bwd_time_sum, optim_step_time_sum = 0, 0, 0
     for epoch in range(num_epochs):
+
+        fwd_time_sum, bwd_time_sum, optim_step_time_sum = 0, 0, 0
         s_epoch = time.time()
         model.train()
         loss_sum = torch.tensor(0.0).to(accelerator.device)
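This hunk is the "time logging metrics" fix from the commit message: the forward/backward/optimizer-step accumulators now reset at the start of every epoch, so the logged sums describe one epoch instead of growing across the whole run. A stripped-down sketch of the corrected pattern (the loop body is illustrative):

```python
import time

num_epochs, steps_per_epoch = 2, 3

for epoch in range(num_epochs):
    # Reset per epoch so each report covers this epoch only.
    fwd_time_sum, bwd_time_sum, optim_step_time_sum = 0.0, 0.0, 0.0
    for _ in range(steps_per_epoch):
        t0 = time.time()
        # ... forward pass ...
        fwd_time_sum += time.time() - t0
        t1 = time.time()
        # ... backward pass ...
        bwd_time_sum += time.time() - t1
        t2 = time.time()
        # ... optimizer step ...
        optim_step_time_sum += time.time() - t2
    print(f"epoch {epoch}: fwd={fwd_time_sum:.3f}s "
          f"bwd={bwd_time_sum:.3f}s optim={optim_step_time_sum:.3f}s")
```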
@@ -1,14 +1,5 @@
 #!/bin/bash
 
-# Function to setup AWS
-setup_aws() {
-  echo "Setting up AWS..."
-  chmod +x ./setup_aws.sh
-  if ! ./setup_aws.sh; then
-    echo "Failed to setup AWS. Exiting..."
-    exit 1
-  fi
-}
 
 # Function to prepare nodes
 prepare_nodes() {
@@ -105,7 +96,6 @@ esac
 MODEL_ID="meta-llama/Llama-2-${SIZE}-hf"
 CONFIG_DIR="./deepspeed_configs/zero_3_llama_2_${SIZE}.json"
 
-setup_aws
 prepare_nodes "${MODEL_ID}"
 check_and_create_dataset "${DATA_DIR}"
 fine_tune "$BS" "$ND" "$MODEL_ID" "$BASE_DIR" "$CONFIG_DIR" "$TRAIN_PATH" "$TEST_PATH" "$TOKEN_PATH" "${params[@]}"

This file was deleted.

@@ -12,7 +12,7 @@ def get_hash_from_bucket(
 
     s3_sync_args = s3_sync_args or []
     subprocess.run(
-        ["awsv2", "s3", "cp", "--quiet"]
+        ["aws", "s3", "cp", "--quiet"]
        + s3_sync_args
        + [os.path.join(bucket_uri, "refs", "main"), "."]
    )
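This is the first of the two spots where the awsv2 wrapper is swapped for the standard aws CLI, per the commit message. A self-contained sketch of the same pattern with hypothetical names (fetch_ref_hash is not the repo's actual helper): copy a mirror's refs/main object locally and read the hash out of it.

```python
import os
import subprocess

def fetch_ref_hash(bucket_uri: str, s3_sync_args=None) -> str:
    """Copy <bucket_uri>/refs/main to the working directory and return its contents."""
    s3_sync_args = s3_sync_args or []
    subprocess.run(
        ["aws", "s3", "cp", "--quiet"]
        + s3_sync_args
        + [os.path.join(bucket_uri, "refs", "main"), "."],
        check=True,  # fail loudly if the copy does not succeed
    )
    # `aws s3 cp s3://.../refs/main .` writes a local file named "main".
    with open("main") as f:
        return f.read().strip()

# Hypothetical usage against the public mirror from this diff:
# fetch_ref_hash("s3://llama-2-weights/", s3_sync_args=["--no-sign-request"])
```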
@@ -67,15 +67,12 @@ def download_model(
     path = os.path.join(TRANSFORMERS_CACHE, f"models--{model_id.replace('/', '--')}")
 
     cmd = (
-        [
-            "awsv2",
-            "s3",
-            "sync",
-        ]
+        ["aws", "s3", "sync"]
         + s3_sync_args
         + (["--exclude", "*", "--include", "*token*"] if tokenizer_only else [])
         + [bucket_uri, path]
     )
     print(f"RUN({cmd})")
     subprocess.run(cmd)
     print("done")

@@ -5,14 +5,14 @@ allowed_azs:
 
 head_node_type:
   name: head_node_type
-  instance_type: n1-highmem-64-nvidia-k80-12gb-1
+  instance_type: g2-standard-32-nvidia-l4-1
   resources:
     custom_resources:
       large_cpu_mem: 1
 
 worker_node_types:
 - name: gpu_worker
-  instance_type: n1-standard-16-nvidia-k80-12gb-1
+  instance_type: g2-standard-16-nvidia-l4-1
   min_workers: 15
   max_workers: 15
   use_spot: false
