# grad_check.py — forked from tczhangzhi/pytorch-parallel
# (original: 37 lines, 29 loc, 1.12 KB)
from __future__ import division
from __future__ import print_function
import argparse
import torch
from torch.autograd import Variable, gradcheck
parser = argparse.ArgumentParser()
parser.add_argument('example', choices=['py', 'cpp', 'cuda'])
parser.add_argument('-b', '--batch-size', type=int, default=3)
parser.add_argument('-f', '--feature-size', type=int, default=17)
parser.add_argument('-o', '--output-size', type=int, default=3)
parser.add_argument('-c', '--cuda', action='store_true')
options = parser.parse_args()
if options.example == 'py':
from python.dense import DenseFunction
elif options.example == 'cpp':
from cpp.dense import DenseFunction
else:
from cuda.dense import DenseFunction
options.cuda = True
X = torch.randn(options.batch_size, options.feature_size)
W = torch.randn(options.output_size, options.feature_size)
b = torch.randn(options.output_size)
variables = [X, W, b]
for i, var in enumerate(variables):
if options.cuda:
var = var.cuda()
variables[i] = Variable(var.double(), requires_grad=True)
if gradcheck(DenseFunction.apply, variables, eps=1e-6, atol=1e-4):
print('Ok')