diff --git a/benchmarks/benchmark_aq.py b/benchmarks/benchmark_aq.py
index bedd6b1425..174038d206 100644
--- a/benchmarks/benchmark_aq.py
+++ b/benchmarks/benchmark_aq.py
@@ -7,24 +7,61 @@
 )
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_5,
+    unwrap_tensor_subclass,
 )
 from torchao.quantization.quant_api import (
+    int4_weight_only,
+    int8_weight_only,
+    int8_dynamic_activation_int8_weight,
+    quantize_,
     _replace_with_custom_fn_if_matches_filter,
 )
 import copy

+def _int8wo_api(mod, **kwargs):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_woqtensors(mod, **kwargs)
+
+def _int8da_int8w_api(mod, **kwargs):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        quantize_(mod, int8_dynamic_activation_int8_weight(**kwargs), set_inductor_config=False)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int8_dqtensors(mod, **kwargs)
+
+def _int4wo_api(mod, **kwargs):
+    if TORCH_VERSION_AT_LEAST_2_4:
+        kwargs_copy = kwargs.copy()
+        # quantize_ expects "group_size"; rename the legacy "groupsize" key
+        if "groupsize" in kwargs_copy:
+            kwargs_copy["group_size"] = kwargs_copy["groupsize"]
+            del kwargs_copy["groupsize"]
+        quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False)
+        if not TORCH_VERSION_AT_LEAST_2_5:
+            unwrap_tensor_subclass(mod)
+    else:
+        change_linear_weights_to_int4_woqtensors(mod, **kwargs)
+
 class ToyLinearModel(torch.nn.Module):
-    def __init__(self, m=64, n=32, k=64):
+    """Single linear layer for an m x k x n problem size."""
+    def __init__(self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"):
         super().__init__()
-        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
+        self.m = m
+        self.dtype = dtype
+        self.device = device
+        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(dtype=self.dtype, device=self.device)

-    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
-        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
+    def example_inputs(self):
+        return (torch.randn(self.m, self.linear.in_features, dtype=self.dtype, device=self.device),)

     def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
+        x = self.linear(x)
         return x

 def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
@@ -69,14 +106,17 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):

 _ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)

-def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
+torch._dynamo.config.cache_size_limit = 50000
+
+@torch.no_grad
+def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     if kwargs is None:
         kwargs = {}

-    m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+    m = ToyLinearModel(M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda").eval()
+    m_bf16 = copy.deepcopy(m)
     m_ref = copy.deepcopy(m)
-    # setting batch_size to 20 to be compatible with the kernel
-    example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
+    example_inputs = m.example_inputs()

     api(m, **kwargs)

@@ -91,27 +131,40 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
     # perf comparison
     from torchao.utils import benchmark_model
     # warmup
-    WARMUP = 5
+    WARMUP = 20
     RUNS = 100
-    m = torch.compile(m, mode='max-autotune', fullgraph=True)
-
-    benchmark_model(m, WARMUP, example_inputs)
-    elapsed_time = benchmark_model(m, RUNS, example_inputs)

     m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
     benchmark_model(m_ref, WARMUP, example_inputs)
     ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)

-    print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
-    assert elapsed_time < 1.05 * ref_elapsed_time
+    m = torch.compile(m, mode='max-autotune', fullgraph=True)
+    benchmark_model(m, WARMUP, example_inputs)
+    elapsed_time = benchmark_model(m, RUNS, example_inputs)
+
+    m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
+    benchmark_model(m_bf16, WARMUP, example_inputs)
+    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
+
+    print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")

 if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
+    all_shapes = [
+        (20, 2048, 2048),
+    ]
+
+    print("_int8da_int8w_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
-    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)
+    for M, N, K in all_shapes:
+        _bench_quantized_tensor_subclass_perf(_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K)

+    print("_int8wo_api")
     from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
-    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_woqtensors, _ref_change_linear_weights_to_int8_woqtensors)
+    for M, N, K in all_shapes:
+        _bench_quantized_tensor_subclass_perf(_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K)

+    print("_int4wo_api")
     kwargs = {"groupsize": 32}
     from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
-    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int4_woqtensors, _ref_change_linear_weights_to_int4_woqtensors, kwargs)
+    for M, N, K in all_shapes:
+        _bench_quantized_tensor_subclass_perf(_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs)
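The three new `_int8wo_api` / `_int8da_int8w_api` / `_int4wo_api` wrappers share one version-gated pattern: on torch >= 2.4 they quantize through the `quantize_` API, additionally calling `unwrap_tensor_subclass` on torch < 2.5 so that `torch.compile(..., fullgraph=True)` can handle the subclass weights, and on older torch they fall back to the legacy `change_linear_weights_to_*` functions. Below is a minimal standalone sketch of that pattern for the int8 weight-only case; the `Linear(2048, 2048)` module and batch size are illustrative choices, not taken from the diff:

```python
import torch
from torchao.quantization.quant_api import int8_weight_only, quantize_
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass

# Illustrative toy model; any module containing nn.Linear works the same way.
model = torch.nn.Linear(2048, 2048, bias=True).to(torch.bfloat16).to("cuda").eval()

# Replace the Linear weight with an int8 weight-only quantized tensor subclass.
# set_inductor_config=False mirrors the benchmark: it skips applying torchao's
# recommended global inductor settings on each run.
quantize_(model, int8_weight_only(), set_inductor_config=False)

if not TORCH_VERSION_AT_LEAST_2_5:
    # Before torch 2.5, subclass weights need to be unwrapped for
    # torch.compile(fullgraph=True) to trace the module.
    unwrap_tensor_subclass(model)

model = torch.compile(model, mode="max-autotune", fullgraph=True)
out = model(torch.randn(20, 2048, dtype=torch.bfloat16, device="cuda"))
```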