Skip to content

Commit

Permalink
remove unnecessary contiguous check in volumetricconvolution (intel#1357
Browse files Browse the repository at this point in the history
)

remove unnecessary contiguous check in volumetricconvolution
  • Loading branch information
dding3 authored Jul 25, 2017
1 parent a81e6ba commit 7ca4923
Showing 1 changed file with 2 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,10 @@ class VolumetricConvolution[T: ClassTag](
override def updateGradInput(input: Tensor[T], gradOutput: Tensor[T]): Tensor[T] = {
require(input.dim() == 4 || input.dim() == 5,
s"4D or 5D (batch mode) tensor expected for input, but got: ${ input.dim() }d")
require(input.isContiguous(), "input should be contiguous")
require(gradOutput.isContiguous(), "gradOutput should be contiguous")
gradInput.resizeAs(input)
fGradInput.resizeAs(fInput).zero()
if (input.dim() == 4) {
require(gradOutput.isContiguous(), "gradOutput should be contiguous")
updateGradInputFrame(gradInput, gradOutput, weightMM.transpose(1, 2), fGradInput,
kT, kW, kH,
dT, dW, dH,
Expand All @@ -266,6 +265,7 @@ class VolumetricConvolution[T: ClassTag](
val gradInputT = gradInput.select(1, t)
val gradOutputT = gradOutput.select(1, t)
val fGradInputT = fGradInput.select(1, t)
require(gradOutputT.isContiguous(), "each batch of gradOutput should be contiguous")
updateGradInputFrame(gradInputT, gradOutputT, weightMM.transpose(1, 2), fGradInputT,
kT, kW, kH,
dT, dW, dH,
Expand Down Expand Up @@ -303,7 +303,6 @@ class VolumetricConvolution[T: ClassTag](
}

override def accGradParameters(input: Tensor[T], gradOutput: Tensor[T]): Unit = {
require(input.isContiguous(), "input should be contiguous")
require(gradOutput.isContiguous(), "gradOutput should be contiguous")
if (gradWeightMM == null || gradWeightMM.storage().isEmpty) {
gradWeightMM = gradWeight.view(nOutputPlane, nInputPlane * kT * kH * kW)
Expand Down

0 comments on commit 7ca4923

Please sign in to comment.