profile_explorer: add op-kernel correlation info (microsoft#15946)
### Description
* Add aggregated op-kernel correlation information to the profile explorer output
when running an inference session.
* Add a filtering feature so that we can focus on the model runs of interest
(excluding warmup steps, etc.); see the example invocation below.
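
A minimal sketch of how the run filter might be invoked, assuming the profile JSON is passed as the script's positional `input` argument; the file name and `--csv` prefix below are placeholders:

```
# Default behavior: skip run 0 (often a warmup) and analyze all remaining runs.
python ./profile_explorer.py my_profile.json --mapping --csv my_model

# Restrict the analysis to runs 2 through 9 (--end is exclusive).
python ./profile_explorer.py my_profile.json --start 2 --end 10 --mapping --csv my_model
```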
mindest authored May 30, 2023
1 parent 31fc25d commit 90e8c8d
Showing 2 changed files with 186 additions and 45 deletions.
41 changes: 40 additions & 1 deletion onnxruntime/python/tools/profile_explorer/README.md
@@ -4,6 +4,8 @@ Profile Explorer is a script that makes analyzing profile outputs generated by O

## Usage

### Basic statistics over operators and kernels

For the most up-to-date options, please run `python ./profile_explorer.py --help`. Here are some of the common use cases covered:

Find the top-5 operators that are taking the most time:
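
One possible invocation, using the `-c`/`--count` option defined in `profile_explorer.py` (the profile file name is a placeholder):

```
python ./profile_explorer.py my_profile.json --count 5
```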
@@ -69,4 +71,41 @@ Cijk_Ailk_Bljk_HHS_BH_MT128x144x32_MI16x16x16x1_SE_1LDSB1_APM1_ABV0_ACED0_AF0EM8
void onnxruntime::rocm::_BinaryElementWiseSimple<true, false, __half, __half, __half, onnxruntime::rocm::OP_Mul<__half, __half, __half>, 512, 2>(__half const*, __half const*, __half*, onnxruntime::rocm::OP_Mul<__half, __half, __half> const&, int) 46 0.00 11 100.00 1511590
void onnxruntime::rocm::_BinaryElementWiseSimple<false, true, __half, __half, __half, onnxruntime::rocm::OP_Sub<__half, __half, __half>, 512, 2>(__half const*, __half const*, __half*, onnxruntime::rocm::OP_Sub<__half, __half, __half> const&, int) 45 0.00 11 100.00 1511635
void onnxruntime::rocm::_Fill<__half, 256, 4>(__half*, __half, int) 4 0.00 1 100.00 1511639
```

### Operator-kernel correlation statistics

We provide an optional argument `--mapping`/`-m` to turn on operator-kernel correlation analysis.
By default, this feature groups and counts operators by their names and input info (shape and type).
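
A possible way to enable it, assuming the profile JSON is passed as the positional `input` argument and using a placeholder `my_model` prefix for `--csv` (the script writes the correlation table to `<prefix>_op_kernel_mapping.csv`):

```
# Writes my_model_op_kernel_mapping.csv, alongside the CPU/GPU kernel-time CSVs.
python ./profile_explorer.py my_profile.json --mapping --csv my_model
```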

The output `.csv` file reports the occurrences of operators and kernels as well as kernel durations,
averaged over the analyzed model runs, so in some corner cases the occurrence counts can be fractional
(for example, a kernel launched 13 times across 2 analyzed runs is reported with a count of 6.5).
Each operator (with its input info) is expanded into one row per kernel when it calls multiple kernels.
The reported numbers include, but are not limited to, count, duration, and percentage.

The following is a simple example of the output file (some entries are omitted or simplified for brevity):

| op_name | input_type_shape | op_count | kernel_dims | kernel_count | kernel_avg_dur | ... | op_pct | kernel_name |
|--------------------|--------------------------|---------:|----------------------|-------------:|---------------:|-----|-------:|----------------|
| MultiHeadAttention | float16(2x4096x8x3x40) | 5 | b256x1x1,g512x1x1 | 5 | 1734.85 | ... | 23.26 | some_ck_kernel |
| GroupNorm | float16(2x64x64x320),... | 13 | b1024x1x1,g64x1x1 | 8 | 121.52 | ... | 4.28 | norm_kernel_1 |
| GroupNorm | float16(2x64x64x320),... | 13 | b256x1x1,g64x1x1 | 5 | 124.49 | ... | 4.28 | norm_kernel_2 |
| NhwcConv | float16(2x64x64x320),... | 8 | b256x1x1,g163840x1x1 | 7 | 178.51 | ... | 4.00 | igemm_kernel_1 |
| NhwcConv | float16(2x64x64x320),... | 8 | b512x1x1,g2560x1x1 | 7 | 13.27 | ... | 4.00 | OpAdd_kernel_1 |
| NhwcConv | float16(2x64x64x320),... | 8 | b256x1x1,g65536x1x1 | 8 | 9.88 | ... | 4.00 | SubTensorOp... |
| NhwcConv | float16(2x64x64x320),... | 8 | b256x1x1,g40960x1x1 | 1 | 64.03 | ... | 4.00 | igemm_kernel_2 |
| NhwcConv | float16(2x64x64x320),... | 8 | b512x1x1,g640x1x1 | 1 | 4.97 | ... | 4.00 | OpAdd_kernel_2 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |

From the above table, we see that the operator `MultiHeadAttention` with input info `float16(2x4096x8x3x40)`
has 5 occurrences in a model run (on average), and they all call the same kernel `some_ck_kernel`. Therefore,
the kernel `some_ck_kernel` has 5 occurrences in total.

There are cases where operators with the same input info call different kernels. For example, the operator `GroupNorm`
with input info `float16(2x64x64x320),...` has 13 occurrences in a model run, but 8 of them call the kernel
`norm_kernel_1` while the other 5 call the kernel `norm_kernel_2`.

There are also cases where a single operator calls multiple kernels. For example, the operator `NhwcConv` with
input info `float16(2x64x64x320),...` has 8 occurrences in a model run: 7 of them each call the kernels `igemm_kernel_1`,
`SubTensorOp...` and `OpAdd_kernel_1`, while the remaining one calls the kernels `igemm_kernel_2`,
`SubTensorOp...` and `OpAdd_kernel_2`. Since the `SubTensorOp...` kernel is shared by both groups, this yields
the five rows shown above for this operator.
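
Once generated, the correlation table can be sliced further with pandas. Here is a small sketch, assuming a `my_model` prefix was passed to `--csv`; note that the CSV column names carry units, e.g. `kernel_avg_dur (us)`, unlike the simplified headers in the table above:

```python
import pandas as pd

# Load the correlation table produced by `--mapping --csv my_model`.
df = pd.read_csv("my_model_op_kernel_mapping.csv")

# All kernel rows attributed to NhwcConv ops, heaviest first.
conv = df[df["op_name"] == "NhwcConv"].sort_values("kernel_total_dur (us)", ascending=False)
print(conv[["input_type_shape", "kernel_count", "kernel_avg_dur (us)", "kernel_name"]])

# Total per-run GPU time attributed to each operator type.
per_op = df.groupby("op_name")["kernel_total_dur (us)"].sum().sort_values(ascending=False)
print(per_op.head(10))
```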
190 changes: 146 additions & 44 deletions onnxruntime/python/tools/profile_explorer/profile_explorer.py
@@ -4,6 +4,7 @@
import fnmatch
import json
import subprocess as sp
from collections import defaultdict

import pandas as pd

@@ -28,7 +29,9 @@ def _get_args():
help="The command to use to demangle C++ identifiers",
)
parser.add_argument(
"--shape-sensitive", action="store_true", help="Perform a shape sensitive analysis of kernel execution times"
"--shape-sensitive",
action="store_true",
help="Perform a shape sensitive analysis of kernel execution times",
)

parser.add_argument(
@@ -44,9 +47,32 @@ def _get_args():
action="extend",
help="Restrict analysis to the specified identifiers, i.e., specify a filter list. Also supports UNIX-style wildcards.",
)
parser.add_argument("--csv", help="save data to csv")
parser.add_argument("-c", "--count", type=int, default=40, help="list top N items")
parser.add_argument("-v", "--verbose", action="store_true", help="verbose")
parser.add_argument("--csv", help="Save data to csv")
parser.add_argument("-c", "--count", type=int, default=40, help="List top N items")

parser.add_argument(
"--start",
"-s",
type=int,
default=1,
help="Index of the first model run to process (starting from 0, supports negative indices). "
"Defaults to 1 to skip the first run (run 0), which is often a warmup step.",
)
parser.add_argument(
"--end",
"-e",
type=int,
default=None,
help="Index of the last model run to process (exclusive, supports negative indices). "
"Defaults to None, which means all runs starting from --start will be included.",
)
parser.add_argument(
"--mapping",
"-m",
action="store_true",
help="Whether dump op-kernel correlation",
)

args = parser.parse_args()
return args

@@ -59,20 +85,15 @@ def _shape_to_string(shape):
key = list(dict_obj.keys())[0]
value = list(dict_obj.values())[0]
if len(res) != 0:
res += "__"
res += f'{key}_{"x".join(str(v) for v in value)}'
res += ","
res += f'{key}({"x".join(str(v) for v in value)})'
return res


def _json_to_df(profile_path, filter_matcher):
def _json_to_df(data, filter_matcher):
cpu_entries = []
gpu_entries = []

with open(profile_path, encoding="utf-8") as file_obj:
data = json.load(file_obj)
if isinstance(data, dict):
data = data["traceEvents"]

most_recent_kernel_launch_event = None
num_missing_kernel_launch_events = 0
total_kernel_events = 0
@@ -96,7 +117,7 @@ def _json_to_df(profile_path, filter_matcher):

if cat != "Kernel" and not name.endswith("kernel_time"):
continue
elif name.endswith("kernel_time"):
if name.endswith("kernel_time"):
most_recent_kernel_launch_event = item

block_x = arg.get("block_x", -1)
@@ -111,7 +132,7 @@ def _json_to_df(profile_path, filter_matcher):
{
"name": name,
"duration": dur,
"dimensions": f"{block_x}_{block_y}_{block_z}_{grid_x}_{grid_y}_{grid_z}",
"dimensions": f"b{block_x}x{block_y}x{block_z},g{grid_x}x{grid_y}x{grid_z}",
"op_name": op_name,
"input_type_shape": (
_shape_to_string(most_recent_kernel_launch_event["args"]["input_type_shape"])
@@ -135,7 +156,7 @@ def _json_to_df(profile_path, filter_matcher):

if num_missing_kernel_launch_events > 0:
print(
f"WARNNG: Could not resolve shapes for {num_missing_kernel_launch_events} of {total_kernel_events} kernels."
f"WARNING: Could not resolve shapes for {num_missing_kernel_launch_events} of {total_kernel_events} kernels."
)

cpu_df = pd.DataFrame(cpu_entries)
@@ -145,12 +166,16 @@ def _json_to_df(profile_path, filter_matcher):
return cpu_df, gpu_df


def _print_cpu_top_hitters(frame, args):
def _print_top_hitters(frame, args, target="cpu"):
if len(frame) == 0:
print("No CPU entries found!")
print(f"No {target.upper()} entries found!")
return
top = args.count
group_key = ["name"]

if target.lower() == "gpu" and args.dimension_sensitive:
group_key.append("dimensions")

if args.shape_sensitive:
group_key.append("input_type_shape")

@@ -161,35 +186,69 @@ def _print_cpu_top_hitters(frame, args):
frame1 = frame1.sort_values(by="duration", ascending=False)[:top]
frame1["cumulative_pct"] = frame1["pct"].cumsum()
frame1["cumulative_dur"] = frame1["duration"].cumsum()
print("\n------ Top CPU Kernel Times ------")

if target.lower() == "gpu":
frame1["name"] = frame1["name"].apply(lambda x: _demangle(x, args.demangler))

print(f"\n------ Top {target.upper()} Kernel Times ------")
print(frame1.round(2).to_string(index=False))
if args.csv:
frame1.to_csv(f"{args.csv}_cpu_kernel_times.csv", index=False)
frame1.to_csv(f"{args.csv}_{target}_kernel_times.csv", index=False)


def _print_gpu_top_hitters(frame, args):
if len(frame) == 0:
print("No GPU entries found!")
return
top = args.count
group_key = ["name"]
if args.dimension_sensitive:
group_key.append("dimensions")
if args.shape_sensitive:
group_key.append("input_type_shape")
def _print_op_kernel_mapping_info(cpu_df, gpu_df, num_runs, csv=None):
# Count op occurrences in the selected runs
op_counts = defaultdict(int)
for op in cpu_df.T.to_dict().values():
identifiers = tuple([op["name"], op["input_type_shape"]])
op_counts[identifiers] += 1

frame2 = frame[["duration", "count"]].sum()
frame["pct"] = 100 * (frame["duration"] / frame2["duration"])
fields = [*group_key, "duration", "pct", "count"]
frame1 = frame[fields].groupby(group_key).sum().reset_index()
frame1 = frame1.sort_values(by="duration", ascending=False)[:top]
frame1["cumulative_pct"] = frame1["pct"].cumsum()
frame1["cumulative_dur"] = frame1["duration"].cumsum()
frame1["name"] = frame1["name"].apply(lambda x: _demangle(x, args.demangler))
print("\n------ Top GPU Kernel Times ------")
print(frame1.round(2).to_string(index=False))
if args.csv:
frame1.to_csv(f"{args.csv}_gpu_kernel_times.csv", index=False)
# Collect kernel stats: count/duration
stat_dict = defaultdict(lambda: defaultdict(float))
for kernel in gpu_df.T.to_dict().values():
op_name = kernel["op_name"]
if op_name is None: # Only interested in op related kernels
continue
input_type_shape = kernel["input_type_shape"]
kernel_name = kernel["name"]
dimensions = kernel["dimensions"]
identifiers = tuple([op_name, input_type_shape, kernel_name, dimensions])
stat_dict[identifiers]["count"] += 1
stat_dict[identifiers]["duration"] += kernel["duration"]

# Create the DataFrame for kernel entries with op correlation info
kernel_list = []
for identifiers, stat in stat_dict.items():
op_name, input_type_shape, kernel_name, dimensions = identifiers
op_count = op_counts.get(tuple([op_name, input_type_shape]))
if op_count is None:
continue
kernel_list.append(
{
"op_name": op_name,
"input_type_shape": input_type_shape,
"op_count": op_count / num_runs, # Average op count per run
"kernel_name": kernel_name,
"kernel_dimensions": dimensions,
"kernel_count": stat["count"] / num_runs, # Average kernel count per run
"kernel_avg_dur (us)": stat["duration"] / stat["count"],
"kernel_total_dur (us)": stat["duration"] / num_runs,
}
)

df = pd.DataFrame(kernel_list)
df["op_dur (us)"] = df.groupby(["op_name", "input_type_shape"])["kernel_total_dur (us)"].transform("sum")
df["op_avg_dur (us)"] = df["op_dur (us)"] / df["op_count"]
df = df.sort_values(
by=["op_dur (us)", "op_name", "input_type_shape", "kernel_total_dur (us)"],
ascending=False,
).reset_index(drop=True)
df["kernel_pct (%)"] = df["kernel_total_dur (us)"] / df["op_dur (us)"] * 100
df["op_pct (%)"] = df["op_dur (us)"] / df["kernel_total_dur (us)"].sum() * 100
# Move kernel_name to the end since it tends to be long
df.insert(len(df.columns) - 1, "kernel_name", df.pop("kernel_name"))
if csv is not None:
df.to_csv(f"{csv}_op_kernel_mapping.csv", index=False)


def _construct_filter_matcher(args):
@@ -212,15 +271,58 @@ def _match_item(item):
return _match_item


def _split_data_across_runs(data, start=1, end=None):
"""
Splits the traces according to model runs they belong to.
By default, we skip the first model run (run 0) and consider all subsequent runs.
"""
# Here we assume that the traces are properly ordered, so we can simplify the splitting logic.
model_run_splits = [i for i, item in enumerate(data) if item.get("name") == "model_run"]
if not model_run_splits:
print('WARNING: Could not find "model_run" event in trace. Using entire traces.')
return data, 1  # Treat the entire trace as a single run so the caller can still unpack (data, num_runs).
total_num_runs = len(model_run_splits)
print(f"Found {total_num_runs} model_run events in trace.")

assert -total_num_runs <= start < total_num_runs, f"Invalid start index {start}."
if start < 0:
start += total_num_runs
if end is None:
end = total_num_runs
else:
assert -total_num_runs <= end < total_num_runs, f"Invalid end index {end}."
if end < 0:
end += total_num_runs
num_runs = end - start
assert num_runs > 0, "No valid model runs are included in the split."
print(f"Analyzing {num_runs} model run(s): {start}-{end - 1}.")

# Add index 0 in case user wants to include the first model run.
model_run_splits = [0, *model_run_splits]
return data[model_run_splits[start] : model_run_splits[end]], num_runs


def _load_json(profile_path):
with open(profile_path, encoding="utf-8") as file_obj:
data = json.load(file_obj)
if isinstance(data, dict):
data = data["traceEvents"]
return data


def main():
args = _get_args()
filter_matcher = _construct_filter_matcher(args)

cpu_df, gpu_df = _json_to_df(args.input, filter_matcher)
data = _load_json(args.input)
data, num_runs = _split_data_across_runs(data, args.start, args.end)
cpu_df, gpu_df = _json_to_df(data, filter_matcher)

pd.set_option("display.max_colwidth", 120)
_print_cpu_top_hitters(cpu_df, args)
_print_gpu_top_hitters(gpu_df, args)
_print_top_hitters(cpu_df, args, target="cpu")
_print_top_hitters(gpu_df, args, target="gpu")
if args.mapping:
_print_op_kernel_mapping_info(cpu_df, gpu_df, num_runs, args.csv)


if __name__ == "__main__":
