Skip to content

Commit

Permalink
fixed the convolution.c import
Browse files Browse the repository at this point in the history
use cuda 10.1
  • Loading branch information
Ubuntu committed Feb 12, 2020
1 parent 3780adc commit a587ef5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion dnn/cgoflags.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ package cudnn
//
// // default locs:
// #cgo LDFLAGS:-L/usr/local/cuda/lib64 -L/usr/local/cuda/lib
// #cgo CFLAGS: -I/usr/include/x86_64-linux-gnu -I/usr/local/cuda-9.0/targets/x86_64-linux/include -I/usr/local/cuda/include
// #cgo CFLAGS: -I/usr/include/x86_64-linux-gnu -I/usr/local/cuda-10.1/targets/x86_64-linux/include -I/usr/local/cuda/include
import "C"
24 changes: 12 additions & 12 deletions dnn/convolution.c
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
#include <cudnn_v7.h>
#include <cudnn.h>

cudnnStatus_t gocudnnNewConvolution(cudnnConvolutionDescriptor_t *retVal,
cudnnMathType_t mathType, const int groupCount,
cudnnMathType_t mathType, const int groupCount,
const int size, const int* padding,
const int* filterStrides,
const int* filterStrides,
const int* dilation,
cudnnConvolutionMode_t convolutionMode, cudnnDataType_t dataType) {

cudnnStatus_t status ;
status = cudnnCreateConvolutionDescriptor(retVal);
if (status != CUDNN_STATUS_SUCCESS) {
return status;
return status;
}

status = cudnnSetConvolutionMathType(*retVal, mathType);
if (status != CUDNN_STATUS_SUCCESS) {
return status;
}

status = cudnnSetConvolutionGroupCount(*retVal, groupCount);
status = cudnnSetConvolutionGroupCount(*retVal, groupCount);
if (status != CUDNN_STATUS_SUCCESS) {
return status;
}

int padH;
int padW;
int u;
int v;
int v;
int dilationH;
int dilationW;
switch (size) {
Expand All @@ -39,17 +39,17 @@ cudnnStatus_t gocudnnNewConvolution(cudnnConvolutionDescriptor_t *retVal,
u = filterStrides[0];
v = filterStrides[1];
dilationH = dilation[0];
dilationW = dilation[1];
dilationW = dilation[1];

status = cudnnSetConvolution2dDescriptor(*retVal,
padH, padW,
u, v,
dilationH, dilationW,
status = cudnnSetConvolution2dDescriptor(*retVal,
padH, padW,
u, v,
dilationH, dilationW,
convolutionMode, dataType);
break;
default:
status = cudnnSetConvolutionNdDescriptor(*retVal, size, padding, filterStrides, dilation, convolutionMode, dataType);
break;
}
return status;
}
}

0 comments on commit a587ef5

Please sign in to comment.