3
3
#include < pybind11/pybind11.h>
4
4
#include < pybind11/numpy.h>
5
5
#include < pybind11/operators.h>
6
+ #include < pybind11/stl.h>
6
7
7
8
using namespace tensor_array ::value;
8
9
using namespace tensor_array ::datatype;
@@ -35,9 +36,9 @@ pybind11::dtype get_py_type(const std::type_info& info)
35
36
throw std::exception ();
36
37
}
37
38
38
- pybind11::array convert_tensor_to_numpy (const Tensor& tensor )
39
+ pybind11::array convert_tensor_to_numpy (const Tensor& self )
39
40
{
40
- const TensorBase& base_tensor = tensor .get_buffer ().change_device ({tensor_array::devices::CPU, 0 });
41
+ const TensorBase& base_tensor = self .get_buffer ().change_device ({tensor_array::devices::CPU, 0 });
41
42
std::vector<pybind11::size_t > shape_vec (base_tensor.shape ().size ());
42
43
std::transform
43
44
(
@@ -54,15 +55,15 @@ pybind11::array convert_tensor_to_numpy(const Tensor& tensor)
54
55
return pybind11::array (ty1, shape_vec, base_tensor.data ());
55
56
}
56
57
57
- Tensor python_tuple_slice (const Tensor& t , pybind11::tuple tuple_slice)
58
+ Tensor python_tuple_slice (const Tensor& self , pybind11::tuple tuple_slice)
58
59
{
59
60
std::vector<Tensor::Slice> t_slices;
60
61
for (size_t i = 0 ; i < tuple_slice.size (); i++)
61
62
{
62
63
ssize_t start, stop, step;
63
64
ssize_t length;
64
65
pybind11::slice py_slice = tuple_slice[i].cast <pybind11::slice>();
65
- if (!py_slice.compute (t .get_buffer ().shape ().begin ()[i], &start, &stop, &step, &length))
66
+ if (!py_slice.compute (self .get_buffer ().shape ().begin ()[i], &start, &stop, &step, &length))
66
67
throw std::runtime_error (" Invalid slice" );
67
68
t_slices.insert
68
69
(
@@ -75,17 +76,17 @@ Tensor python_tuple_slice(const Tensor& t, pybind11::tuple tuple_slice)
75
76
}
76
77
);
77
78
}
78
- return t [tensor_array::wrapper::initializer_wrapper (t_slices.begin ().operator ->(), t_slices.end ().operator ->())];
79
+ return self [tensor_array::wrapper::initializer_wrapper (t_slices.begin ().operator ->(), t_slices.end ().operator ->())];
79
80
}
80
81
81
- Tensor python_slice (const Tensor& t , pybind11::slice py_slice)
82
+ Tensor python_slice (const Tensor& self , pybind11::slice py_slice)
82
83
{
83
84
std::vector<Tensor::Slice> t_slices;
84
85
ssize_t start, stop, step;
85
86
ssize_t length;
86
- if (!py_slice.compute (t .get_buffer ().shape ().begin ()[0 ], &start, &stop, &step, &length))
87
+ if (!py_slice.compute (self .get_buffer ().shape ().begin ()[0 ], &start, &stop, &step, &length))
87
88
throw std::runtime_error (" Invalid slice" );
88
- return t
89
+ return self
89
90
[
90
91
{
91
92
Tensor::Slice
@@ -98,25 +99,43 @@ Tensor python_slice(const Tensor& t, pybind11::slice py_slice)
98
99
];
99
100
}
100
101
101
- Tensor python_index (const Tensor& t , unsigned int i)
102
+ Tensor python_index (const Tensor& self , unsigned int i)
102
103
{
103
- return t [i];
104
+ return self [i];
104
105
}
105
106
106
- std::size_t python_len (const Tensor& t )
107
+ std::size_t python_len (const Tensor& self )
107
108
{
108
- std::initializer_list<unsigned int > shape_list = t .get_buffer ().shape ();
109
+ std::initializer_list<unsigned int > shape_list = self .get_buffer ().shape ();
109
110
return shape_list.size () != 0 ? shape_list.begin ()[0 ]: 1U ;
110
111
}
111
112
112
- pybind11::str tensor_to_string (const Tensor& t )
113
+ pybind11::str tensor_to_string (const Tensor& self )
113
114
{
114
- return pybind11::repr (convert_tensor_to_numpy (t ));
115
+ return pybind11::repr (convert_tensor_to_numpy (self ));
115
116
}
116
117
117
- Tensor tensor_cast_1 (const Tensor& t , DataType dtype)
118
+ Tensor tensor_cast_1 (const Tensor& self , DataType dtype)
118
119
{
119
- return t.tensor_cast (warp_type (dtype));
120
+ return self.tensor_cast (warp_type (dtype));
121
+ }
122
+
123
+ pybind11::tuple tensor_shape (const Tensor& self)
124
+ {
125
+ return pybind11::cast (std::vector (self.get_buffer ().shape ()));
126
+ }
127
+
128
+ Tensor tensor_copying (const Tensor& self)
129
+ {
130
+ return self;
131
+ }
132
+
133
+ Tensor py_zeros (pybind11::tuple shape_tuple, DataType dtype)
134
+ {
135
+ std::vector<unsigned int > shape_vec;
136
+ for (auto & it: shape_tuple)
137
+ shape_vec.push_back (it.cast <unsigned int >());
138
+ return TensorBase (warp_type (dtype), shape_vec);
120
139
}
121
140
122
141
PYBIND11_MODULE (tensor2, m)
@@ -136,9 +155,18 @@ PYBIND11_MODULE(tensor2, m)
136
155
.value (" U_INT_32" , U_INT_32)
137
156
.value (" U_INT_64" , U_INT_64)
138
157
.export_values ();
158
+
159
+ m.def
160
+ (
161
+ " zeros" ,
162
+ &py_zeros,
163
+ pybind11::arg (" shape" ),
164
+ pybind11::arg (" dtype" ) = S_INT_32
165
+ );
139
166
140
167
pybind11::class_<Tensor>(m, " Tensor" )
141
168
.def (pybind11::init ())
169
+ .def (pybind11::init (&tensor_copying))
142
170
.def (pybind11::init (&convert_numpy_to_tensor_base<float >))
143
171
.def (pybind11::self + pybind11::self)
144
172
.def (pybind11::self - pybind11::self)
@@ -176,11 +204,13 @@ PYBIND11_MODULE(tensor2, m)
176
204
.def (" matmul" , &matmul)
177
205
.def (" condition" , &condition)
178
206
.def (" numpy" , &convert_tensor_to_numpy)
207
+ .def (" shape" , &tensor_shape)
179
208
.def (" __getitem__" , &python_index)
180
209
.def (" __getitem__" , &python_slice)
181
210
.def (" __getitem__" , &python_tuple_slice)
182
211
.def (" __len__" , &python_len)
183
212
.def (" __matmul__" , &matmul)
184
213
.def (" __rmatmul__" , &matmul)
185
- .def (" __repr__" , &tensor_to_string);
214
+ .def (" __repr__" , &tensor_to_string)
215
+ .def (" __copy__" , &tensor_copying);
186
216
}
0 commit comments