try to make at::cat in mm_tree_reduction operate on contig tensors (#18816)

Summary:
Sometimes at::cat gets transposed inputs and takes a slow path; transposing them back lets the concatenation run on contiguous tensors (sketched below). Also, make the jit_premul LSTM benchmark add the bias to the whole input tensor, to avoid separate reduction kernels in the backward pass.
Pull Request resolved: pytorch/pytorch#18816

Differential Revision: D15013576

Pulled By: wanchaol

fbshipit-source-id: bcfa1cf44180b11b05b0f55f034707012f66281a
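A minimal standalone sketch of the contiguity idea (illustrative only, with made-up shapes; not code from this commit): when every input to a concatenation is a transposed, non-contiguous view, concatenating the contiguous .t() views along the flipped dimension and transposing the result back gives the same tensor while keeping at::cat on its contiguous fast path.

import torch

a = torch.randn(4, 3).t()  # 3x4 view with stride(0) == 1, i.e. non-contiguous
b = torch.randn(4, 3).t()  # another non-contiguous 3x4 view

slow = torch.cat([a, b], dim=1)              # cat sees transposed inputs
fast = torch.cat([a.t(), b.t()], dim=0).t()  # cat sees contiguous inputs, then view back

assert torch.equal(slow, fast)

This is the same trick the batch_mm.cpp change below applies: detect the transposed layout, cat the transposed views along the other dim, then transpose the result.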
Natalia Gimelshein authored and facebook-github-bot committed Apr 25, 2019
1 parent c571969 commit 3875e1b
Showing 5 changed files with 90 additions and 6 deletions.
2 changes: 1 addition & 1 deletion benchmarks/fastrnns/bench.py
@@ -195,7 +195,7 @@ def bench_group(model_list, bench_name, bench_group, bench_args):
parser.add_argument('--group', nargs='*', default=default_groups, help='Which group to run. cnns, rnns, etc.')

args = parser.parse_args()
- rnns = args.rnns or ['cudnn', 'aten', 'jit', 'jit_premul', 'jit_simple',
+ rnns = args.rnns or ['cudnn', 'aten', 'jit', 'jit_premul', 'jit_premul_bias', 'jit_simple',
'jit_multilayer', 'py']
cnns = args.cnns or ['resnet18', 'resnet18_jit', 'resnet50', 'resnet50_jit']
# TODO: Maybe add a separate section for the layernorm/dropout lstms
18 changes: 18 additions & 0 deletions benchmarks/fastrnns/cells.py
@@ -75,6 +75,24 @@ def premul_lstm_cell(igates, hidden, w_hh, b_ih, b_hh):
return hy, cy


def premul_lstm_cell_no_bias(igates, hidden, w_hh, b_hh):
# type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor) -> Tuple[Tensor, Tensor]
hx, cx = hidden
gates = igates + torch.mm(hx, w_hh.t()) + b_hh

ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

ingate = torch.sigmoid(ingate)
forgetgate = torch.sigmoid(forgetgate)
cellgate = torch.tanh(cellgate)
outgate = torch.sigmoid(outgate)

cy = (forgetgate * cx) + (ingate * cellgate)
hy = outgate * torch.tanh(cy)

return hy, cy


def gru_cell(input, hidden, w_ih, w_hh, b_ih, b_hh):
gi = torch.mm(input, w_ih.t()) + b_ih
gh = torch.mm(hidden, w_hh.t()) + b_hh
39 changes: 38 additions & 1 deletion benchmarks/fastrnns/factory.py
@@ -2,7 +2,7 @@

from collections import namedtuple

- from .cells import lstm_cell, premul_lstm_cell, flat_lstm_cell
+ from .cells import lstm_cell, premul_lstm_cell, premul_lstm_cell_no_bias, flat_lstm_cell


# list[list[T]] -> list[T]
@@ -128,6 +128,17 @@ def lstm_premul_creator(script=True, **kwargs):
backward=simple_backward)


def lstm_premul_bias_creator(script=True, **kwargs):
input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
inputs = [input, hidden] + params[0]
return ModelDef(
inputs=inputs,
params=flatten_list(params),
forward=lstm_factory_premul_bias(premul_lstm_cell_no_bias, script),
backward_setup=lstm_backward_setup,
backward=simple_backward)


def lstm_simple_creator(script=True, **kwargs):
input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
inputs = [input] + [h[0] for h in hidden] + params[0]
@@ -386,6 +397,32 @@ def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
return dynamic_rnn


# premul: we're going to premultiply the inputs & weights, and add bias
def lstm_factory_premul_bias(premul_cell, script):
def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
# type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
hx, cx = hidden
outputs = []
inpSize = input.size()
# add bias for all timesteps instead of going step-by-step, results in a single reduction kernel in the backward
# FIXME matmul(x,y) + bias currently goes through jit AD, and backward formula in AD is not optimized for this
# case. Workaround with mm and views.
inpSize = input.size()
inputs = torch.mm(input.view(-1, inpSize[2]), wih.t()) + bih
inputs = inputs.view(inpSize[0], inpSize[1], -1).unbind(0)
hy, cy = hx[0], cx[0]
for seq_idx in range(len(inputs)):
hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bhh)
outputs += [hy]
return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

if script:
premul_cell = torch.jit.script(premul_cell)
dynamic_rnn = torch.jit.script(dynamic_rnn)

return dynamic_rnn


# simple: flat inputs (no tuples), no list to accumulate outputs
# useful mostly for benchmarking older JIT versions
def lstm_factory_simple(cell, script):
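The bias change in the benchmark above can be seen in isolation with a small sketch (illustrative only, hypothetical sizes; not code from this commit): adding b_ih once to the premultiplied input for all timesteps matches adding it at every cell step, but the backward pass then reduces the bias gradient once instead of once per timestep.

import torch

T, B, D, H = 5, 2, 8, 16  # hypothetical sizes: seq length, batch, input size, hidden size
x = torch.randn(T, B, D)
w_ih = torch.randn(4 * H, D)
b_ih = torch.randn(4 * H)

# bias added per timestep: one broadcasted add (and one bias-grad reduction) per step
per_step = [torch.mm(x[t], w_ih.t()) + b_ih for t in range(T)]

# bias added once to the whole premultiplied input, as lstm_factory_premul_bias does above
whole = (torch.mm(x.view(-1, D), w_ih.t()) + b_ih).view(T, B, -1)

assert all(torch.allclose(whole[t], per_step[t]) for t in range(T))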
1 change: 1 addition & 0 deletions benchmarks/fastrnns/runner.py
@@ -52,6 +52,7 @@ def get_nn_runners(*names):
'aten': RNNRunner('aten', pytorch_lstm_creator, DisableCuDNN),
'jit': RNNRunner('jit', lstm_creator, DummyContext),
'jit_premul': RNNRunner('jit_premul', lstm_premul_creator, DummyContext),
'jit_premul_bias': RNNRunner('jit_premul_bias', lstm_premul_bias_creator, DummyContext),
'jit_simple': RNNRunner('jit_simple', lstm_simple_creator, DummyContext),
'jit_multilayer': RNNRunner('jit_multilayer', lstm_multilayer_creator, DummyContext),
'jit_layernorm': RNNRunner('jit_layernorm', lnlstm_creator, DummyContext),
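For completeness, one plausible way to exercise the new runner next to the existing one (an assumption, not shown in this diff: that bench.py is launched from the benchmarks/ directory and exposes an --rnns flag, as the args.rnns usage above suggests):

python -m fastrnns.bench --rnns jit_premul jit_premul_bias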
36 changes: 32 additions & 4 deletions torch/csrc/jit/passes/batch_mm.cpp
@@ -77,10 +77,21 @@ static constexpr size_t min_fusion_size = 4;

bool have_same_shape(at::TensorList inputs) {
auto expected_sizes = inputs[0].sizes();
- return std::all_of(
+ return (std::all_of(
inputs.begin(), inputs.end(), [expected_sizes](const at::Tensor& t) {
return t.sizes() == expected_sizes;
- });
+ }));
}

bool should_be_transposed(at::TensorList inputs) {
return (std::all_of(
inputs.begin(), inputs.end(), [](const at::Tensor& t) {
return t.stride(0) == 1 && t.stride(1) == t.size(0);
}));
}

std::vector<at::Tensor> transpose_inputs(at::TensorList inputs){
return fmap(inputs, [](const at::Tensor& i) { return i.t(); });
}

bool shape_is_fast_for_reduce(const at::Tensor& lhs, const at::Tensor& rhs) {
@@ -111,8 +122,25 @@ RegisterOperators mm_tree_reduction_reg(
// failing
if (have_same_shape(lhs_inputs) && have_same_shape(rhs_inputs) &&
shape_is_fast_for_reduce(lhs_inputs[0], rhs_inputs[0])) {
- auto lhs = at::cat(lhs_inputs, /*dim=*/1);
- auto rhs = at::cat(rhs_inputs, /*dim=*/0);
+ //sometimes lhs_inputs or rhs_inputs are not contiguous, and that causes at::cat to go through slow path
+ //view them as contiguous if possible by transposing
+ bool lhs_input_transposed = should_be_transposed(lhs_inputs);
+ bool rhs_input_transposed = should_be_transposed(rhs_inputs);
+ at::Tensor lhs, rhs;
+ if (lhs_input_transposed) {
+ std::vector<at::Tensor> lhs_contig_inputs = transpose_inputs(lhs_inputs);
+ lhs = at::cat(lhs_contig_inputs, /*dim*/0);
+ lhs = lhs.t();
+ } else {
+ lhs = at::cat(lhs_inputs, /*dim=*/1);
+ }
+ if (rhs_input_transposed) {
+ std::vector<at::Tensor> rhs_contig_inputs = transpose_inputs(rhs_inputs);
+ rhs = at::cat(rhs_contig_inputs, /*dim*/1);
+ rhs = rhs.t();
+ } else {
+ rhs = at::cat(rhs_inputs, /*dim=*/0);
+ }
push(stack, at::mm(lhs, rhs));
} else {
auto acc = at::mm(inputs[0], inputs[side_num_elems]);