Remove FLOPs calculation from finetuning scripts (Lightning-AI#688)
carmocca authored Oct 30, 2023
1 parent 4826ab3 commit f684138
Showing 4 changed files with 0 additions and 71 deletions.
finetune/adapter.py: 18 changes (0 additions, 18 deletions)
@@ -136,23 +136,6 @@ def train(

     validate(fabric, model, val_data, tokenizer) # sanity check

-    with torch.device("meta"):
-        meta_model = GPT(model.config)
-        mark_only_adapter_as_trainable(meta_model)
-        # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
-        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
-        # consider passing `flops_per_batch=estimated_flops` instead
-        estimated_flops = estimate_flops(meta_model, training=True) * micro_batch_size
-        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
-        # this assumes that all samples have a fixed length equal to the longest sequence length
-        # which is most likely false during finetuning
-        x = torch.randint(0, 1, (micro_batch_size, longest_seq_length))
-        forward_fn = lambda: meta_model(x)
-        loss_fn = lambda y: chunked_cross_entropy(y, x, chunk_size=0)
-        measured_flops = measure_flops(meta_model, forward_fn, loss_fn)
-        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
-        del meta_model, x
-
     throughput = ThroughputMonitor(fabric, window_size=50)
     step_count = 0
     total_lengths = 0
@@ -190,7 +173,6 @@ def train(
                 time=t1 - total_t0,
                 samples=(iter_num + 1) * micro_batch_size,
                 lengths=total_lengths,
-                flops_per_batch=measured_flops,
             )
             throughput.compute_and_log(step=iter_num)
             fabric.print(
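For context on the comments in the block removed above: the "estimated" number that is "widely used in the wild" usually comes from the rule of thumb of roughly 6 FLOPs per parameter per training token (about 2N for the forward pass and 4N for the backward pass). Below is a minimal back-of-the-envelope sketch of that approximation; the numbers are purely illustrative and this is not the repo's estimate_flops implementation.

```python
# Rule-of-thumb training FLOPs: ~6 * parameters * tokens per step
# (forward ~2N, backward ~4N per token). Illustrative numbers only.
n_params = 7_000_000_000          # e.g. a 7B-parameter model
micro_batch_size = 4
seq_length = 2048

tokens_per_micro_batch = micro_batch_size * seq_length
flops_per_micro_batch = 6 * n_params * tokens_per_micro_batch
print(f"~{flops_per_micro_batch / 1e12:.1f} TFLOPs per micro-batch")
```

Because it counts only the dense matrix multiplies implied by the parameter count, this estimate tends to be optimistic compared with a measured number, which is the distinction the removed comments draw.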
finetune/adapter_v2.py: 18 changes (0 additions, 18 deletions)
@@ -136,23 +136,6 @@ def train(

     validate(fabric, model, val_data, tokenizer) # sanity check

-    with torch.device("meta"):
-        meta_model = GPT(model.config)
-        mark_only_adapter_v2_as_trainable(meta_model)
-        # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
-        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
-        # consider passing `flops_per_batch=estimated_flops` instead
-        estimated_flops = estimate_flops(meta_model, training=True) * micro_batch_size
-        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
-        # this assumes that all samples have a fixed length equal to the longest sequence length
-        # which is most likely false during finetuning
-        x = torch.randint(0, 1, (micro_batch_size, longest_seq_length))
-        forward_fn = lambda: meta_model(x)
-        loss_fn = lambda y: chunked_cross_entropy(y, x, chunk_size=0)
-        measured_flops = measure_flops(meta_model, forward_fn, loss_fn)
-        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
-        del meta_model, x
-
     throughput = ThroughputMonitor(fabric, window_size=50)
     step_count = 0
     total_lengths = 0
@@ -190,7 +173,6 @@ def train(
                 time=t1 - total_t0,
                 samples=(iter_num + 1) * micro_batch_size,
                 lengths=total_lengths,
-                flops_per_batch=measured_flops,
             )
             throughput.compute_and_log(step=iter_num)
             fabric.print(
finetune/full.py: 17 changes (0 additions, 17 deletions)
@@ -132,22 +132,6 @@ def train(

     validate(fabric, model, val_data, tokenizer) # sanity check

-    with torch.device("meta"):
-        meta_model = GPT(model.config)
-        # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
-        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
-        # consider passing `flops_per_batch=estimated_flops` instead
-        estimated_flops = estimate_flops(meta_model, training=True) * micro_batch_size
-        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
-        # this assumes that all samples have a fixed length equal to the longest sequence length
-        # which is most likely false during finetuning
-        x = torch.randint(0, 1, (micro_batch_size, longest_seq_length))
-        forward_fn = lambda: meta_model(x)
-        loss_fn = lambda y: chunked_cross_entropy(y, x, chunk_size=0)
-        measured_flops = measure_flops(meta_model, forward_fn, loss_fn)
-        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
-        del meta_model, x
-
     throughput = ThroughputMonitor(fabric, window_size=50)
     step_count = 0
     total_lengths = 0
@@ -184,7 +168,6 @@ def train(
                 time=t1 - total_t0,
                 samples=(iter_num + 1) * micro_batch_size,
                 lengths=total_lengths,
-                flops_per_batch=measured_flops,
             )
             throughput.compute_and_log(step=iter_num)
             fabric.print(
finetune/lora.py: 18 changes (0 additions, 18 deletions)
@@ -178,23 +178,6 @@ def train(

     validate(fabric, model, val_data, tokenizer) # sanity check

-    with torch.device("meta"):
-        meta_model = GPT(model.config)
-        mark_only_lora_as_trainable(meta_model)
-        # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
-        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
-        # consider passing `flops_per_batch=estimated_flops` instead
-        estimated_flops = estimate_flops(meta_model, training=True) * micro_batch_size
-        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
-        # this assumes that all samples have a fixed length equal to the longest sequence length
-        # which is most likely false during finetuning
-        x = torch.randint(0, 1, (micro_batch_size, longest_seq_length))
-        forward_fn = lambda: meta_model(x)
-        loss_fn = lambda y: chunked_cross_entropy(y, x, chunk_size=0)
-        measured_flops = measure_flops(meta_model, forward_fn, loss_fn)
-        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
-        del meta_model, x
-
     throughput = ThroughputMonitor(fabric, window_size=50)
     step_count = 0
     total_lengths = 0
@@ -234,7 +217,6 @@ def train(
                 time=t1 - total_t0,
                 samples=(iter_num + 1) * micro_batch_size,
                 lengths=total_lengths,
-                flops_per_batch=measured_flops,
             )
             throughput.compute_and_log(step=iter_num)
             fabric.print(
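If you still want FLOPs numbers after this change, the deleted logic can be run once, outside the training loop, and the result passed to whatever throughput logging you use. Below is a minimal sketch that mirrors the removed block; the import paths and the Config.from_name model name are assumptions and may differ across lit-gpt and Lightning versions.

```python
import torch
from lightning.fabric.utilities.throughput import measure_flops  # assumed import path

from lit_gpt import GPT, Config  # assumed lit_gpt layout
from lit_gpt.utils import chunked_cross_entropy, estimate_flops  # assumed import path

micro_batch_size = 4  # illustrative values
longest_seq_length = 512

config = Config.from_name("pythia-70m")  # any supported model name
with torch.device("meta"):
    meta_model = GPT(config)
    # the finetuning scripts also called mark_only_*_as_trainable(meta_model) here
    # optimistic estimate, comparable to numbers other projects report
    estimated_flops = estimate_flops(meta_model, training=True) * micro_batch_size
    # measuring assumes every sample is as long as the longest sequence,
    # which is usually not true during finetuning
    x = torch.randint(0, 1, (micro_batch_size, longest_seq_length))
    forward_fn = lambda: meta_model(x)
    loss_fn = lambda y: chunked_cross_entropy(y, x, chunk_size=0)
    measured_flops = measure_flops(meta_model, forward_fn, loss_fn)

print(f"Estimated TFLOPs per device: {estimated_flops / 1e12:.2f}")
print(f"Measured TFLOPs per device: {measured_flops / 1e12:.2f}")
```

measured_flops here is the value these scripts previously passed to the throughput monitor as flops_per_batch.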
