Skip to content

Commit a065cff

Browse files
testing
1 parent a3ea643 commit a065cff

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

src/main.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from tensor_array.core import tensor2 as t
1+
from tensor_array.core import Tensor
22

33
print("hello")
44

5-
t1 = t.Tensor([[1, 2, 3], [4, 5, 6]])
5+
t1 = Tensor([[1, 2, 3], [4, 5, 6]])
66
t2 = t1.clone()
77
print("tensor len", t1.__len__())
88
print(t1)
Binary file not shown.
+11-5
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
from tensor_array.layers import Layer
2+
from tensor_array.layers import Parameter
23
from tensor_array.core import Tensor
34
from typing import Any
45

56

67
class Linear(Layer):
7-
def __init__(self) -> None:
8-
self.w = t.Tensor(0)
9-
self.b = t.Tensor(0)
8+
def __init__(self, bias) -> None:
9+
super(self)
10+
self.bias_shape = bias
11+
self.b = Parameter(Tensor(shape = (bias)))
1012

11-
def __call__(self, input) -> Any:
12-
return input @ self.w + self.b
13+
def init_value(self, t):
14+
self.w = Parameter(Tensor(shape = (t.shape(-1), self.bias_shape)))
15+
16+
def calculate(self, t):
17+
return self.w @ t + self.b
18+

0 commit comments

Comments
 (0)