Add Shark Benchmark
- Introduce SharkBenchmark, which benchmarks models on regular torch, shark-py, and shark-c.
- Integrate iree-benchmark-module into Shark.
raikonenfnu committed May 27, 2022
1 parent 18a4423 commit 91867e1
Showing 4 changed files with 231 additions and 17 deletions.
35 changes: 35 additions & 0 deletions shark/examples/minilm_benchmark.py
@@ -0,0 +1,35 @@
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from shark.shark_inference import SharkInference

torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("microsoft/MiniLM-L12-H384-uncased")


class MiniLMSequenceClassification(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "microsoft/MiniLM-L12-H384-uncased",  # The pretrained model.
            num_labels=2,  # The number of output labels--2 for binary classification.
            output_attentions=False,  # Whether the model returns attention weights.
            output_hidden_states=False,  # Whether the model returns all hidden states.
            torchscript=True,
        )

    def forward(self, tokens):
        return self.model.forward(tokens)[0]


test_input = torch.randint(2, (1, 128))

shark_module = SharkInference(MiniLMSequenceClassification(), (test_input,),
                              jit_trace=True, benchmark_mode=True)

shark_module.compile()
shark_module.forward((test_input,))
shark_module.benchmark_all((test_input,))
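
With benchmark_mode=True, benchmark_all exercises the regular torch, shark-py (Python runtime), and shark-c (iree-benchmark-module) paths in turn. Based on the print formats added in shark_runner.py below, the output has roughly this shape; the values here are placeholders, not measured results:

Torch benchmark:<iter/s> iter/second, Total Iterations:<num_iterations>
Shark-torch Python-benchmark:<iter/s> iter/second, Total Iterations:<num_iterations>
Shark-torch C-benchmark:<iter/s> iter/second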
82 changes: 82 additions & 0 deletions shark/iree_utils.py
@@ -13,11 +13,14 @@
# limitations under the License.

import iree.runtime as ireert
import iree.runtime.scripts.iree_benchmark_module as benchmark_module
import iree.compiler as ireec
from iree.compiler import tf as tfc
from shark.torch_mlir_utils import get_module_name_for_asm_dump
import subprocess
import sys
import numpy as np
import os
import re

IREE_DEVICE_MAP = {
    "cpu": "dylib",
@@ -27,6 +30,10 @@
    "metal": "vulkan"
}

UNIT_TO_SECOND_MAP = {
    "ms": 0.001,
    "s": 1
}

def check_device_drivers(device):
"""Checks necessary drivers present for gpu and vulkan devices"""
@@ -152,6 +159,15 @@ def get_iree_compiled_module(module,

    return get_iree_module(module, device, input_type, args, func_name)

def export_iree_module_to_vmfb(module, device: str, directory: str):
    module_name = get_module_name_for_asm_dump(module)
    flatbuffer_blob = ireec.compile_str(
        str(module), target_backends=[IREE_DEVICE_MAP[device]])
    filename = os.path.join(directory, module_name + ".vmfb")
    with open(filename, 'wb') as f:
        f.write(flatbuffer_blob)
    return filename
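
For reference, a minimal sketch of how this helper gets used later in the diff; mlir_module stands in for the torch-mlir module produced by get_torch_mlir_module, and the directory value is illustrative, not taken from the commit:

# Compiles the module for the "dylib" (CPU) backend per IREE_DEVICE_MAP and
# writes <module_name>.vmfb into the given directory, returning its path.
vmfb_path = export_iree_module_to_vmfb(mlir_module, "cpu", "/tmp/shark_repro")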


def get_results(compiled_vm, input, config, frontend="torch"):
"""Runs a .vmfb file given inputs and config and returns output."""
@@ -171,3 +187,69 @@ def get_results(compiled_vm, input, config, frontend="torch"):
        return np.copy(res)
    else:
        return np.copy(np.asarray(result, dtype=result.dtype))

######### Benchmark Related Tools ###########

def tensor_to_type_str(input_tensors: tuple):
    """
    Input: A tuple of input tensors, i.e. tuple(torch.Tensor)
    Output: A list of strings representing MLIR types (e.g. 1x24xf64)
    # TODO: Support more than floats and ints
    """
    list_of_type = []
    for input_tensor in input_tensors:
        type_string = "x".join([str(dim) for dim in input_tensor.shape])
        dtype_string = str(input_tensor.dtype).replace("torch.", "")
        regex_split = re.compile("([a-zA-Z]+)([0-9]+)")
        match = regex_split.match(dtype_string)
        mlir_type_string = str(match.group(1)[0]) + str(match.group(2))
        type_string += f"x{mlir_type_string}"
        list_of_type.append(type_string)
    return list_of_type
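
A quick illustrative check of the mapping, assuming tensor_to_type_str is imported from shark.iree_utils as added here; the shapes are arbitrary:

import torch
from shark.iree_utils import tensor_to_type_str

# float64 keeps only the leading letter of the dtype: "f" + "64".
print(tensor_to_type_str((torch.zeros(1, 24, dtype=torch.float64),)))  # ['1x24xf64']
# torch.randint defaults to int64, so the suffix becomes "i64".
print(tensor_to_type_str((torch.randint(2, (1, 128)),)))               # ['1x128xi64']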

def build_benchmark_args(input_file: str, device: str, input_tensors: tuple, training=False):
    """
    Inputs: path to the vmfb file, target device, input tensors to the function,
            and whether the model is in training mode.
    Outputs: the command list that runs iree-benchmark-module on the target model.
    """
    path = benchmark_module.__path__[0]
    benchmarker_path = os.path.join(path, "..", "..", "iree-benchmark-module")
    benchmark_cl = [benchmarker_path, f"--module_file={input_file}"]
    fn_name = "forward"
    if training:
        # TODO: Replace name of train with actual train fn name.
        fn_name = "train"
    benchmark_cl.append(f"--entry_function={fn_name}")
    benchmark_cl.append(f"--driver={IREE_DEVICE_MAP[device]}")
    mlir_input_types = tensor_to_type_str(input_tensors)
    for mlir_input in mlir_input_types:
        benchmark_cl.append(f"--function_input={mlir_input}")
    # Keep only the reported time and its unit from the benchmark output.
    time_extractor = "| awk 'END{{print $2 $3}}'"
    benchmark_cl.append(time_extractor)
    return benchmark_cl
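
For the MiniLM example above on the "cpu" device, the resulting list would look roughly like the following; the benchmarker path and vmfb name are placeholders, while the flag spellings (--module_file, --entry_function, --driver, --function_input) are the ones used in this diff:

[".../iree-benchmark-module",
 "--module_file=<module_name>.vmfb",
 "--entry_function=forward",
 "--driver=dylib",
 "--function_input=1x128xi64",
 "| awk 'END{{print $2 $3}}'"]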

def run_cmd(cmd):
    """
    Inputs: cli command string.
    """
    try:
        result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE, check=True)
        result_str = result.stdout.decode()
        return result_str
    except Exception:
        sys.exit(f"Exiting program due to error running: {cmd}")

def run_benchmark(benchmark_cl):
    """
    Runs the benchmark command, extracts the result, and returns iterations/second.
    Input: benchmark command.
    """
    benchmark_path = benchmark_cl[0]
    assert os.path.exists(benchmark_path), \
        "Cannot find iree-benchmark-module. Please contact a SHARK maintainer on Discord."
    bench_result = run_cmd(' '.join(benchmark_cl))
    regex_split = re.compile("([0-9]+[.]*[0-9]*)([a-zA-Z]+)")
    match = regex_split.match(bench_result)
    time = float(match.group(1))
    unit = match.group(2)
    return 1.0 / (time * UNIT_TO_SECOND_MAP[unit])
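
As a quick worked example of the conversion: if the awk-filtered benchmark output is "5.2ms", the regex yields time = 5.2 and unit = "ms", so run_benchmark returns 1.0 / (5.2 * 0.001) ≈ 192.3 iterations/second. The 5.2 ms figure is purely illustrative.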
31 changes: 20 additions & 11 deletions shark/shark_inference.py
@@ -10,11 +10,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from shark.torch_mlir_utils import get_torch_mlir_module, run_on_refbackend
from shark.iree_utils import get_results, get_iree_compiled_module
import os
from shark.parser import shark_args
from shark.shark_runner import SharkRunner
from tqdm import tqdm
from shark.shark_runner import SharkRunner, SharkBenchmarkRunner
import time


@@ -28,11 +26,13 @@ def __init__(
        device: str = None,
        dynamic: bool = False,
        jit_trace: bool = False,
        benchmark_mode: bool = False,
    ):
        self.model = model
        self.input = input
        self.dynamic = dynamic
        self.jit_trace = jit_trace
        self.benchmark_mode = benchmark_mode

        # By default it's torch frontend.
        self.frontend = "pytorch"
@@ -47,14 +47,12 @@ def set_frontend(self, frontend: str):
        self.frontend = frontend

    def compile(self):
        if self.frontend in ["pytorch", "torch"]:
            self.model = get_torch_mlir_module(self.model, self.input,
                                               self.dynamic, self.jit_trace)

        iree_compilation_module, iree_config = get_iree_compiled_module(
            self.model, self.device, self.frontend)

        self.shark_runner = SharkRunner(iree_compilation_module, iree_config)
        # Inference does not use AOT.
        from_aot = False
        if self.benchmark_mode:
            self.shark_runner = SharkBenchmarkRunner(self.model, self.input,
                                                     self.dynamic, self.device,
                                                     self.jit_trace, from_aot,
                                                     self.frontend)
        else:
            self.shark_runner = SharkRunner(self.model, self.input, self.dynamic,
                                            self.device, self.jit_trace, from_aot,
                                            self.frontend)

    # inputs are considered to be np.array.
    def forward(self, inputs):
@@ -65,3 +63,14 @@ def forward(self, inputs):
        elif self.frontend in ["tensorflow", "tf"]:
            input_list = [x.numpy() for x in inputs]
        return self.shark_runner.forward(input_list, self.frontend)

    ######### Benchmark Related Functions #########
    def benchmark_mode(func):
        def inner(self, *args, **kwargs):
            assert self.benchmark_mode, "SharkInference needs to be in benchmark mode to run benchmark methods."
            return func(self, *args, **kwargs)
        return inner

    @benchmark_mode
    def benchmark_all(self, inputs):
        self.shark_runner.benchmark_all(inputs)
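
The benchmark_mode decorator simply guards the benchmark entry points. A hedged sketch of the failure mode, reusing the model and test_input from the MiniLM example added in this commit:

module = SharkInference(MiniLMSequenceClassification(), (test_input,), jit_trace=True)
module.compile()
module.benchmark_all((test_input,))  # raises AssertionError (benchmark_mode was left False)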
100 changes: 94 additions & 6 deletions shark/shark_runner.py
@@ -18,7 +18,7 @@

from shark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
from shark.torch_mlir_utils import get_torch_mlir_module, export_module_to_mlir_file, run_on_refbackend
from shark.iree_utils import get_results, get_iree_compiled_module
from shark.iree_utils import get_results, get_iree_compiled_module, export_iree_module_to_vmfb, build_benchmark_args, run_benchmark
import os
from shark.parser import shark_args
from tqdm import tqdm
@@ -30,19 +30,43 @@ class SharkRunner:

    def __init__(
        self,
        iree_compilation_module,
        iree_config,
        model,
        input: tuple,
        dynamic: bool = False,
        device: str = None,
        jit_trace: bool = False,
        from_aot: bool = False,
        frontend: str = "torch",
    ):
        self.model = model
        self.frontend_model = model
        self.from_aot = from_aot
        self.input = input
        self.frontend = frontend
        self.vmfb_file = None
        device = device if device is not None else shark_args.device
        if self.frontend in ["pytorch", "torch"]:
            self.model = get_torch_mlir_module(self.model, input, dynamic,
                                               jit_trace, from_aot)
        (
            self.iree_compilation_module,
            self.iree_config,
        ) = get_iree_compiled_module(self.model, device)

        self.iree_compilation_module = iree_compilation_module
        self.iree_config = iree_config
        # Debugging Options:
        if shark_args.save_mlir:
            export_module_to_mlir_file(self.model, shark_args.repro_dir)
        if shark_args.save_mlir:
            self.vmfb_file = export_iree_module_to_vmfb(self.model, device,
                                                        shark_args.repro_dir)

    # All the timings and benchmarking can be done here.
    def forward(self, input, frontend):
        return get_results(self.iree_compilation_module, input,
                           self.iree_config, frontend)


class SharkMode:

    def __init__(self, device="cpu"):
@@ -56,3 +80,67 @@ def __init__(self, device="cpu"):

    def __del__(self):
        self.guard.__exit__(None, None, None)

class SharkBenchmarkRunner(SharkRunner):
    # SharkRunner derived class with benchmarking capabilities.
    def __init__(
        self,
        model,
        input: tuple,
        dynamic: bool = False,
        device: str = None,
        jit_trace: bool = False,
        from_aot: bool = False,
        frontend: str = "torch",
    ):
        SharkRunner.__init__(self, model, input, dynamic, device, jit_trace,
                             from_aot, frontend)
        if self.vmfb_file is None:
            self.vmfb_file = export_iree_module_to_vmfb(self.model, device,
                                                        shark_args.repro_dir)
        self.benchmark_cl = build_benchmark_args(self.vmfb_file, device, input, from_aot)

    def benchmark_frontend(self, inputs):
        if self.frontend in ["pytorch", "torch"]:
            self.benchmark_torch(inputs)
        elif self.frontend in ["tensorflow", "tf"]:
            self.benchmark_tf(inputs)

    def benchmark_torch(self, inputs):
        inputs = self.input if self.from_aot else inputs
        inputs = inputs[0]
        for i in range(shark_args.num_warmup_iterations):
            self.frontend_model.forward(inputs)

        begin = time.time()
        for i in range(shark_args.num_iterations):
            out = self.frontend_model.forward(inputs)
            if i == shark_args.num_iterations - 1:
                end = time.time()
                break
        print(f"Torch benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}")

    def benchmark_tf(self, inputs):
        print("TF benchmark not implemented yet!")
        return

    def benchmark_c(self):
        result = run_benchmark(self.benchmark_cl)
        print(f"Shark-{self.frontend} C-benchmark:{result} iter/second")

    def benchmark_python(self, inputs):
        inputs = self.input if self.from_aot else inputs
        input_list = [x.detach().numpy() for x in inputs]
        for i in range(shark_args.num_warmup_iterations):
            self.forward(input_list, self.frontend)

        begin = time.time()
        for i in range(shark_args.num_iterations):
            out = self.forward(input_list, self.frontend)
            if i == shark_args.num_iterations - 1:
                end = time.time()
                print(f"Shark-{self.frontend} Python-benchmark:{shark_args.num_iterations/(end-begin)} iter/second, Total Iterations:{shark_args.num_iterations}")

    def benchmark_all(self, inputs):
        self.benchmark_frontend(inputs)
        self.benchmark_python(inputs)
        self.benchmark_c()
