optim.go
package gotorch

// #cgo CFLAGS: -I ${SRCDIR}
// #cgo LDFLAGS: -L ${SRCDIR}/cgotorch -Wl,-rpath ${SRCDIR}/cgotorch -lcgotorch
// #cgo LDFLAGS: -L ${SRCDIR}/cgotorch/libtorch/lib -Wl,-rpath ${SRCDIR}/cgotorch/libtorch/lib -lc10 -ltorch -ltorch_cpu
// #include "cgotorch/cgotorch.h"
import "C"

import (
	"runtime"
	"unsafe"
)

// Optimizer wraps a C.Optimizer handle from the cgotorch bindings.
type Optimizer struct {
	Opt *C.Optimizer
}

// SGD creates an SGD optimizer with the given learning rate, momentum,
// dampening, weight decay, and Nesterov flag.
func SGD(lr, momentum, dampening, weightDecay float64, nesterov bool) Optimizer {
	// The C API takes the Nesterov flag as an integer.
	nt := 0
	if nesterov {
		nt = 1
	}
	sgd := C.SGD(C.double(lr), C.double(momentum), C.double(dampening),
		C.double(weightDecay), C.int64_t(nt))
	// Free the underlying C optimizer when the Go handle is garbage-collected.
	runtime.SetFinalizer(&sgd, func(p *C.Optimizer) { C.Optimizer_Close(*p) })
	return Optimizer{&sgd}
}

// Adam creates an Adam optimizer with the given learning rate, beta1,
// beta2, and weight decay.
func Adam(lr, beta1, beta2, weightDecay float64) Optimizer {
	adam := C.Adam(C.double(lr), C.double(beta1), C.double(beta2), C.double(weightDecay))
	// Free the underlying C optimizer when the Go handle is garbage-collected.
	runtime.SetFinalizer(&adam, func(p *C.Optimizer) { C.Optimizer_Close(*p) })
	return Optimizer{&adam}
}
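
// A minimal caller-side sketch of constructing either optimizer; the
// hyperparameter values below are illustrative, not defaults of this
// package:
//
//	sgd := gotorch.SGD(0.1, 0.9, 0.0, 1e-4, true) // lr, momentum, dampening, weight decay, Nesterov
//	adam := gotorch.Adam(1e-3, 0.9, 0.999, 0.0)   // lr, beta1, beta2, weight decay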

// AddParameters registers the given tensors as parameters to be updated
// by the optimizer.
func (opt Optimizer) AddParameters(tensors []Tensor) {
	if len(tensors) == 0 {
		return // nothing to register; also avoids indexing an empty slice below
	}
	// Collect the underlying C tensor handles contiguously so they can be
	// passed to C as a pointer-and-length pair.
	CT := make([]C.Tensor, 0, len(tensors))
	for _, t := range tensors {
		CT = append(CT, C.Tensor(*t.T))
	}
	p := (*C.Tensor)(unsafe.Pointer(&CT[0]))
	C.Optimizer_AddParameters(*opt.Opt, p, C.int64_t(len(CT)))
}
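
// A hedged usage sketch: the parameters typically come from a model. The
// model value and its Parameters method below are hypothetical, not part
// of this file; Parameters is assumed to return []gotorch.Tensor:
//
//	opt := gotorch.SGD(0.1, 0.9, 0.0, 0.0, false)
//	opt.AddParameters(model.Parameters())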

// ZeroGrad resets the gradients of all registered parameters to zero.
func (opt Optimizer) ZeroGrad() {
	C.Optimizer_ZeroGrad(*opt.Opt)
}

// Step updates the registered parameters using their accumulated gradients.
func (opt Optimizer) Step() {
	C.Optimizer_Step(*opt.Opt)
}

// SetLR sets the learning rate.
func (opt Optimizer) SetLR(lr float64) {
	C.Optimizer_SetLR(*opt.Opt, C.double(lr))
}

// Close releases the underlying C optimizer immediately, without waiting
// for the finalizer.
func (opt Optimizer) Close() {
	C.Optimizer_Close(*opt.Opt)
}
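
// A minimal end-to-end sketch of the optimizer life cycle, assuming
// hypothetical model, data, and forwardLoss helpers that are not part of
// this file, and that forwardLoss returns a scalar gotorch.Tensor whose
// Backward method (defined elsewhere in this package) runs backpropagation:
//
//	opt := gotorch.Adam(1e-3, 0.9, 0.999, 0.0)
//	opt.AddParameters(model.Parameters())
//	defer opt.Close() // release the C optimizer deterministically
//	for epoch := 0; epoch < 10; epoch++ {
//		for batch := range data {
//			loss := forwardLoss(model, batch)
//			opt.ZeroGrad()  // clear gradients from the previous step
//			loss.Backward() // accumulate fresh gradients
//			opt.Step()      // apply the gradient update
//		}
//		opt.SetLR(5e-4) // the learning rate can be changed between steps
//	}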