New fused edge_softmax op (dmlc#3650)
* [feature] edge softmax refactor.

* delete file

* fix backward and cmake version

* fix backward

* format function

* fix setting

* refix

* refix

* refix

* refix

* refix

* refix

* refix

* refix

* refix

* refix

* refix

* refix

* add cuda kernel for backward and rename some function

* add benchmark for edge_softmax

* fix format

* remove cuda_backward

* fix code format and add comment for op on CPU

* fix lint

Co-authored-by: Jinjing Zhou <[email protected]>
ranzhejiang and VoVAllen authored Feb 11, 2022
1 parent 45ac572 commit bc8f8b0
Showing 7 changed files with 364 additions and 9 deletions.
25 changes: 25 additions & 0 deletions benchmarks/benchmarks/kernel/bench_edgesoftmax.py
@@ -0,0 +1,25 @@
import time
import dgl
import torch

from .. import utils

# The benchmarks for ops edge_softmax
@utils.benchmark('time', timeout=600)
@utils.parametrize('graph', ['ogbn-arxiv', 'reddit', 'cora', 'pubmed'])
@utils.parametrize('num_heads', [1, 4, 8])
def track_time(graph, num_heads):
    device = utils.get_bench_device()
    graph = utils.get_graph(graph).to(device)
    score = torch.randn((graph.num_edges(),num_heads)).requires_grad_(True).float().to(device)

    # dry run
    for i in range(3):
        y = dgl.ops.edge_softmax(graph, score)

    # timing
    with utils.Timer(device) as t:
        for i in range(100):
            y = dgl.ops.edge_softmax(graph, score)

    return t.elapsed_secs / 100
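
As a quick illustration of the op being benchmarked (not part of the commit), a minimal sketch on a toy graph; the graph, head count, and check are illustrative assumptions:

import dgl
import torch

# Toy graph: edges 0->2, 1->2, 1->3; edges 0 and 1 share destination node 2.
g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([2, 2, 3])))
score = torch.randn(g.num_edges(), 4)      # 4 attention heads
prob = dgl.ops.edge_softmax(g, score)      # normalized over the incoming edges of each destination

# Probabilities of edges sharing a destination sum to one per head.
assert torch.allclose(prob[0] + prob[1], torch.ones(4))
assert torch.allclose(prob[2], torch.ones(4))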
24 changes: 17 additions & 7 deletions python/dgl/backend/pytorch/sparse.py
@@ -1,7 +1,7 @@
import torch as th
from distutils.version import LooseVersion
from ...base import is_all, ALL
-from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp
+from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _edge_softmax_forward, _edge_softmax_backward
from ...sparse import _csrmm, _csrsum, _csrmask, _scatter_add, _update_grad_minmax_hetero
from ...heterograph_index import create_unitgraph_from_csr

@@ -470,10 +470,15 @@ def forward(ctx, gidx, score, eids, norm_by):
        gidx = gidx.edge_subgraph([eids], True).graph
        if norm_by == 'src':
            gidx = gidx.reverse()
-       score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
-       score = th.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
-       score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
-       out = _gsddmm(gidx, 'div', score, score_sum, 'e', 'v')
+       #Note: Now _edge_softmax_forward op only supports CPU
+       #TODO(Zhejiang): We will support GPU in the future
+       if(score.is_cuda):
+           score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
+           score = th.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
+           score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
+           out = _gsddmm(gidx, 'div', score, score_sum, 'e', 'v')
+       else:
+           out = _edge_softmax_forward(gidx, score, 'copy_rhs')
        ctx.backward_cache = gidx
        ctx.save_for_backward(out)
        return out
@@ -500,9 +505,14 @@ def backward(ctx, grad_out):
        ctx.backward_cache = None
        out, = ctx.saved_tensors
        sds = out * grad_out
-       accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds)
+       #Note: Now _edge_softmax_backward op only supports CPU
+       #TODO(Zhejiang): We will support GPU in the future
+       if(out.is_cuda):
+           accum = gspmm(gidx, 'copy_rhs', 'sum', None, sds)

-       grad_score = sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v')
+           grad_score = sds - gsddmm(gidx, 'mul', out, accum, 'e', 'v')
+       else:
+           grad_score = _edge_softmax_backward(gidx, out, sds)
        return None, grad_score, None, None


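For reference, both branches compute the same quantity (notation is ours, not taken from the commit). For an edge e = (u, v) with score z_e, normalizing over the set \mathcal{E}(v) of edges that share its destination,

    \mathrm{out}_e = \frac{\exp\left(z_e - \max_{e' \in \mathcal{E}(v)} z_{e'}\right)}{\sum_{e' \in \mathcal{E}(v)} \exp\left(z_{e'} - \max_{e'' \in \mathcal{E}(v)} z_{e''}\right)}

and in the backward pass, with s_e = \mathrm{out}_e \cdot \partial L / \partial \mathrm{out}_e (the sds tensor above),

    \frac{\partial L}{\partial z_e} = s_e - \mathrm{out}_e \sum_{e' \in \mathcal{E}(v)} s_{e'}

The CUDA branch assembles these from _gspmm/_gsddmm calls; the CPU branch evaluates them in the fused kernels added below.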
61 changes: 61 additions & 0 deletions python/dgl/sparse.py
@@ -83,6 +83,67 @@ def get_typeid_by_target(gidx, etid, target):
'dst': 2
}

def _edge_softmax_backward(gidx, out, sds):
    r""" Edge_softmax backward interface.

    Parameters
    ----------
    gidx : HeteroGraphIndex
        The input graph index.
    out : tensor
        The result of Edge_softmax during forward.
    sds : tensor
        The result of out * gradient.

    Returns
    -------
    The result of Edge_softmax during backward

    Notes
    -----
    This function does not support gpu op.
    """
    op = 'copy_rhs'
    back_out = F.zeros_like(out)
    _CAPI_DGLKernelEdge_softmax_backward(gidx, op,
                                         to_dgl_nd(out),
                                         to_dgl_nd(sds),
                                         to_dgl_nd_for_write(back_out),
                                         to_dgl_nd(None))
    return back_out

def _edge_softmax_forward(gidx, e, op):
    r""" Edge_softmax forward interface.

    Parameters
    ----------
    gidx : HeteroGraphIndex
        The input graph index.
    op : str
        The binary op's name, default as ``copy_rhs``.
    e : tensor or None
        The feature on edges.

    Returns
    -------
    The result of Edge_softmax during forward

    Notes
    -----
    This function does not support gpu op.
    """
    if F.ndim(e) == 1:
        e = F.unsqueeze(e, -1)
        expand = True
    else:
        expand = False
    myout = F.zeros_like(e)
    _CAPI_DGLKernelEdge_softmax_forward(gidx, op,
                                        to_dgl_nd(None),
                                        to_dgl_nd(e),
                                        to_dgl_nd_for_write(myout))
    myout = F.squeeze(myout, -1) if expand else myout
    return myout

def _gspmm(gidx, op, reduce_op, u, e):
r""" Generalized Sparse Matrix Multiplication interface. It takes the result of
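As a sanity reference for the two interfaces above (not part of the commit), a dense PyTorch sketch of their semantics, grouping edges by the node they are normalized over; the function names and the explicit Python loop are illustrative assumptions:

import torch

def edge_softmax_forward_reference(dst, score):
    """Numerically stable softmax of edge scores within each destination group."""
    out = torch.empty_like(score)
    for v in dst.unique():
        mask = dst == v
        s = score[mask]
        s = torch.exp(s - s.max(dim=0, keepdim=True).values)  # subtract the group max first
        out[mask] = s / s.sum(dim=0, keepdim=True)
    return out

def edge_softmax_backward_reference(dst, out, sds):
    """Gradient w.r.t. the scores: sds - out * sum(sds) within each destination group."""
    grad = torch.empty_like(out)
    for v in dst.unique():
        mask = dst == v
        grad[mask] = sds[mask] - out[mask] * sds[mask].sum(dim=0, keepdim=True)
    return grad

A per-group loop like this is what the fused CPU kernels below perform row by row over the graph's CSR structure.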
80 changes: 80 additions & 0 deletions src/array/cpu/spmm.cc
@@ -192,6 +192,86 @@ template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);

/*! \brief Edge_softmax_csr forward op on Csr format. */
template <int XPU, typename IdType, int bits>
void Edge_softmax_csr_forward(const std::string& op,
                              const BcastOff& bcast,
                              const CSRMatrix& csr,
                              NDArray ufeat,
                              NDArray efeat,
                              NDArray out) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      cpu::Edge_softmax_csr_forward<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
    });
  });
}

/*! \brief Edge_softmax_csr backward op on Csr format. */
template <int XPU, typename IdType, int bits>
void Edge_softmax_csr_backward(const std::string& op,
                               const BcastOff& bcast,
                               const CSRMatrix& csr,
                               NDArray out,
                               NDArray sds,
                               NDArray back_out) {
  SWITCH_BITS(bits, DType, {
    SWITCH_OP(op, Op, {
      cpu::Edge_softmax_csr_backward<IdType, DType, Op>(bcast, csr, out, sds, back_out);
    });
  });
}

template void Edge_softmax_csr_forward<kDLCPU, int32_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 32>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 32>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);

template void Edge_softmax_csr_backward<kDLCPU, int32_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 32>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 32>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);

/*! \brief Generalized SpMM on Coo format. */
template <int XPU, typename IdType, int bits>
void SpMMCoo(const std::string& op, const std::string& reduce,
96 changes: 96 additions & 0 deletions src/array/cpu/spmm.h
@@ -9,9 +9,12 @@
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dgl/runtime/parallel_for.h>
#include <math.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <algorithm>
#include <vector>
#include "spmm_binary_ops.h"
#if !defined(_WIN32)
#ifdef USE_AVX
@@ -466,6 +469,99 @@ void SpMMCmpCoo(const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat,
}
}


/*!
* \brief CPU kernel of Edge_softmax_csr_forward on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param ufeat The feature on source nodes.
* \param efeat The feature on edges.
* \param out The result of edge_softmax_forward.
*/
template <typename IdType, typename DType, typename Op>
void Edge_softmax_csr_forward(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
                              NDArray efeat, NDArray out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* edges =
      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* W = Op::use_rhs ? static_cast<DType*>(efeat->data) : nullptr;
  const int64_t dim = bcast.out_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      std::vector<DType> data_e(row_end-row_start, 0);
      std::vector<IdType> num(row_end-row_start, 0);
      for (int64_t k = 0; k < dim; ++k) {
        DType max_v = -std::numeric_limits<DType>::infinity();
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off =
              Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
          data_e[j-row_start] = *rhs_off;
          num[j-row_start] = eid*rhs_dim+rhs_add;
          max_v = std::max<DType>(max_v, (*rhs_off));
        }
        DType exp_sum = 0;
        for (auto& element : data_e) {
          element -= max_v;
          element = std::exp(element);
          exp_sum += element;
        }
        for (int i=0; i < row_end-row_start; i++) {
          out.Ptr<DType>()[num[i]] = data_e[i]/exp_sum;
        }
      }
    }
  });
}


/*!
* \brief CPU kernel of Edge_softmax_csr_backward on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param out The result of forward.
* \param sds The result of gradient * out.
* \param back_out The result of edge_softmax_backward.
*/
template <typename IdType, typename DType, typename Op>
void Edge_softmax_csr_backward(const BcastOff& bcast, const CSRMatrix& csr, NDArray out,
                               NDArray sds, NDArray back_out) {
  const bool has_idx = !IsNullArray(csr.data);
  const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
  const IdType* edges =
      has_idx ? static_cast<IdType*>(csr.data->data) : nullptr;
  const DType* W_out = Op::use_rhs ? static_cast<DType*>(out->data) : nullptr;
  const DType* W_sds = Op::use_rhs ? static_cast<DType*>(sds->data) : nullptr;
  const int64_t dim = bcast.out_len, rhs_dim = bcast.rhs_len;
  runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
    for (auto rid = b; rid < e; ++rid) {
      const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
      for (int64_t k = 0; k < dim; ++k) {
        DType sum_sds = 0;
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off_sds =
              Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;
          sum_sds += (*rhs_off_sds);
        }
        for (IdType j = row_start; j < row_end; ++j) {
          const IdType eid = has_idx ? edges[j] : j;
          const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
          const DType* rhs_off_out =
              Op::use_rhs ? W_out + eid * rhs_dim + rhs_add : nullptr;
          const DType* rhs_off_sds =
              Op::use_rhs ? W_sds + eid * rhs_dim + rhs_add : nullptr;
          back_out.Ptr<DType>()[eid*rhs_dim+rhs_add] = (*rhs_off_sds) - sum_sds*(*rhs_off_out);
        }
      }
    }
  });
}

} // namespace cpu
} // namespace aten
} // namespace dgl
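To make the row-wise structure concrete, a NumPy sketch (ours, not from the commit) of the forward computation over a CSR matrix whose rows group the edges that are normalized together; edge_ids plays the role of csr.data, i.e. the edge-id indirection guarded by has_idx in the C++ kernel:

import numpy as np

def edge_softmax_csr_forward_reference(indptr, edge_ids, efeat):
    """Row-wise softmax over a CSR structure, mirroring the C++ kernel's row walk."""
    out = np.empty_like(efeat)
    for r in range(len(indptr) - 1):           # one row per group of edges normalized together
        eids = edge_ids[indptr[r]:indptr[r + 1]]
        if eids.size == 0:
            continue                            # node with no incident edges
        x = efeat[eids]                         # (row_nnz, num_feats)
        x = np.exp(x - x.max(axis=0, keepdims=True))   # subtract the row max for stability
        out[eids] = x / x.sum(axis=0, keepdims=True)
    return out

The backward kernel walks the same rows twice per feature, first accumulating sum_sds and then writing sds - out * sum_sds for each edge entry, matching the gradient formula given earlier.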
