[Test] Performance benchmarks for DGL kernels (dmlc#2582)
* add initial kernel benchmarks
* finished kernel benchmarks
* add desc
1 parent 12f6429 · commit 362f72c
Showing 6 changed files with 164 additions and 18 deletions.
Empty file.
@@ -0,0 +1,40 @@
import time
import dgl
import torch

from .. import utils

def calc_gflops(graph, feat_size, num_heads, time):
    return round(2 * graph.num_edges() * feat_size / 1000000000 / time, 2)  # count both mul and add

# The benchmarks include multi-head cases.
# Given feat_size = D and num_heads = H, the node feature shape will be
# (H, D // H) and u_dot_v takes the dot product along the last dimension.
# The total FLOP count is determined by feat_size, no matter how many
# heads there are.
# If num_heads = 0, it falls back to plain (D,)-shaped node features
# without the head dimension.
@utils.benchmark('flops', timeout=600)
@utils.parametrize('graph', ['ogbn-arxiv', 'reddit', 'ogbn-proteins'])
@utils.parametrize('feat_size', [4, 32, 256])
@utils.parametrize('num_heads', [0, 1, 4])
def track_flops(graph, feat_size, num_heads):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph, format='coo').to(device)
    if num_heads == 0:
        x = torch.randn(graph.num_nodes(), feat_size, device=device)
    else:
        x = torch.randn(graph.num_nodes(), num_heads, feat_size // num_heads, device=device)

    # dry run
    for i in range(3):
        y = dgl.ops.u_dot_v(graph, x, x)

    # timing
    accum = 0.
    for i in range(10):
        with utils.TorchOpTimer(device) as timer:
            y = dgl.ops.u_dot_v(graph, x, x)
        accum += timer.time

    return calc_gflops(graph, feat_size, num_heads, accum / 10)
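For reference, the shapes exercised by this benchmark can be checked on a toy graph. The sketch below is illustrative only and not part of the commit: the graph and sizes are made up, and it assumes an environment where dgl.ops.u_dot_v is importable as used above.

import dgl
import torch

# Toy graph: 3 nodes and 3 edges (sizes made up for illustration).
g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
feat_size, num_heads = 8, 4

# Multi-head case: node features of shape (N, H, D // H).
x = torch.randn(g.num_nodes(), num_heads, feat_size // num_heads)
y = dgl.ops.u_dot_v(g, x, x)            # per-edge dot product over the last dimension
print(y.shape)                          # torch.Size([3, 4, 1])

# num_heads == 0 case: plain (N, D) node features, no head dimension.
x = torch.randn(g.num_nodes(), feat_size)
print(dgl.ops.u_dot_v(g, x, x).shape)   # torch.Size([3, 1])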
@@ -0,0 +1,37 @@
import time
import dgl
import torch

from .. import utils

def calc_gflops(graph, feat_size, time):
    return round(graph.num_edges() * feat_size / 1000000000 / time, 2)  # one add (or max) per edge element

@utils.benchmark('flops', timeout=600)
@utils.parametrize('graph', ['ogbn-arxiv', 'reddit', 'ogbn-proteins'])
@utils.parametrize('feat_size', [4, 32, 256])
@utils.parametrize('reducer', ['sum', 'max'])
def track_flops(graph, feat_size, reducer):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph, format='csc').to(device)
    x = torch.randn(graph.num_nodes(), feat_size, device=device)

    if reducer == 'sum':
        op = dgl.ops.copy_u_sum
    elif reducer == 'max':
        op = dgl.ops.copy_u_max
    else:
        raise ValueError('Invalid reducer', reducer)

    # dry run
    for i in range(3):
        y = op(graph, x)

    # timing
    accum = 0.
    for i in range(10):
        with utils.TorchOpTimer(device) as timer:
            y = op(graph, x)
        accum += timer.time

    return calc_gflops(graph, feat_size, accum / 10)
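As a sanity check on the formula in calc_gflops: a hypothetical graph with 1e8 edges at feat_size = 256 needs 1e8 * 256 = 2.56e10 additions per copy_u_sum call, so a 25 ms call would be reported as 2.56e10 / 1e9 / 0.025 = 1024 GFLOPS. The toy sketch below (not part of the commit, sizes made up) shows what copy_u_sum computes: each destination node sums the features of its source neighbors, i.e. one SpMM against a binary adjacency matrix.

import dgl
import torch

# Toy graph with edges 0->1, 0->2, 1->2 (made up for illustration).
g = dgl.graph((torch.tensor([0, 0, 1]), torch.tensor([1, 2, 2])))
x = torch.arange(6, dtype=torch.float32).view(3, 2)   # (num_nodes, feat_size) = (3, 2)

# Node 0 has no in-edges, node 1 receives x[0], node 2 receives x[0] + x[1].
print(dgl.ops.copy_u_sum(g, x))
# tensor([[0., 0.],
#         [0., 1.],
#         [2., 4.]])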
@@ -0,0 +1,42 @@
import time
import dgl
import torch

from .. import utils

def calc_gflops(graph, feat_size, num_heads, time):
    return round(2 * graph.num_edges() * feat_size / 1000000000 / time, 2)  # count both mul and add

# The benchmarks include broadcasting cases.
# Given feat_size = D and num_heads = H, the node feature shape will be
# (H, D // H) while the edge feature shape will be (H, 1), so the tested
# operation broadcasts along the last dimension. The total FLOP count is
# determined by feat_size, no matter how many heads there are.
# If num_heads = 0, it falls back to the normal element-wise operation
# without broadcasting.
@utils.benchmark('flops', timeout=600)
@utils.parametrize('graph', ['ogbn-arxiv', 'reddit', 'ogbn-proteins'])
@utils.parametrize('feat_size', [4, 32, 256])
@utils.parametrize('num_heads', [0, 1, 4])
def track_flops(graph, feat_size, num_heads):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph, format='csc').to(device)
    if num_heads == 0:
        x = torch.randn(graph.num_nodes(), feat_size, device=device)
        w = torch.randn(graph.num_edges(), feat_size, device=device)
    else:
        x = torch.randn(graph.num_nodes(), num_heads, feat_size // num_heads, device=device)
        w = torch.randn(graph.num_edges(), num_heads, 1, device=device)

    # dry run
    for i in range(3):
        y = dgl.ops.u_mul_e_sum(graph, x, w)

    # timing
    accum = 0.
    for i in range(10):
        with utils.TorchOpTimer(device) as timer:
            y = dgl.ops.u_mul_e_sum(graph, x, w)
        accum += timer.time

    return calc_gflops(graph, feat_size, num_heads, accum / 10)
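The broadcasting case here mirrors attention-style aggregation: a per-head scalar on each edge scales the source node's per-head feature before summation at the destination. A toy-sized sketch, illustrative only and not part of the diff (graph and sizes made up):

import dgl
import torch

g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))   # toy 3-node, 3-edge graph
feat_size, num_heads = 8, 4

x = torch.randn(g.num_nodes(), num_heads, feat_size // num_heads)   # (N, H, D // H)
w = torch.randn(g.num_edges(), num_heads, 1)                        # per-head edge scalar, broadcast over D // H

# Scale each source feature by its edge weight, then sum at the destination node.
y = dgl.ops.u_mul_e_sum(g, x, w)
print(y.shape)   # torch.Size([3, 4, 2])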