-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathblob.cc
67 lines (57 loc) · 1.99 KB
/
blob.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include "blob.h"
//#include "tensor.h"
#ifdef WITH_CUDA
#include <caffe2/core/context_gpu.h>
#endif
namespace caffe2 {
/// Returns the wrapped blob's contents as a TensorCPU.
///
/// When built with WITH_CUDA and the blob currently holds a TensorCUDA, the
/// result is a TensorCPU constructed from that CUDA tensor. In every other
/// case the blob is read as a TensorCPU and that tensor is returned.
TensorCPU BlobUtil::Get() {
#ifdef WITH_CUDA
  const bool holds_cuda_tensor = blob_.IsType<TensorCUDA>();
  if (holds_cuda_tensor) {
    const auto &device_tensor = blob_.Get<TensorCUDA>();
    return TensorCPU(device_tensor);
  }
#endif
  const auto &host_tensor = blob_.Get<TensorCPU>();
  return host_tensor;
}
/// Stores the CPU tensor `value` into the wrapped blob.
///
/// With WITH_CUDA defined, if `force_cuda` is true or the blob currently
/// holds a TensorCUDA, the blob's CUDA tensor is fetched with
/// BlobGetMutableTensor and filled through CopyFrom(value). Otherwise the
/// blob's CPU tensor is fetched, resized like `value` (ResizeLike) and then
/// pointed at `value`'s data through ShareData.
///
/// @param value       Source CPU tensor.
/// @param force_cuda  Take the CUDA path even if the blob does not hold a
///                    TensorCUDA yet. Only read when WITH_CUDA is defined.
void BlobUtil::Set(const TensorCPU &value, bool force_cuda) {
#ifdef WITH_CUDA
  if (force_cuda || blob_.IsType<TensorCUDA>()) {
    // Pass the blob's address, exactly as the CPU path below does; the
    // previous code passed the blob object itself here.
    auto tensor = caffe2::BlobGetMutableTensor(&blob_, DeviceType::CUDA);
    tensor->CopyFrom(value);
    return;
  }
#endif
  auto tensor = caffe2::BlobGetMutableTensor(&blob_, DeviceType::CPU);
  tensor->ResizeLike(value);
  tensor->ShareData(value);
}
/// Builds a CPU tensor of shape `dim` from the flat values in `data` and
/// stores it into the wrapped blob.
///
/// The values are first written, in order, into a temporary CPU tensor. With
/// WITH_CUDA defined, if `force_cuda` is true or the blob currently holds a
/// TensorCUDA, the blob's CUDA tensor is filled from that temporary through
/// CopyFrom. Otherwise the blob's CPU tensor is resized like the temporary
/// and pointed at its data through ShareData.
///
/// @tparam T          Element type of `data` and of the tensor being written.
/// @param dim         Dimensions of the tensor to create.
/// @param data        Flat list of element values, written in iteration order.
/// @param force_cuda  Take the CUDA path even if the blob does not hold a
///                    TensorCUDA yet. Only read when WITH_CUDA is defined.
template<typename T>
void BlobUtil::Set(const std::vector<int> &dim, const std::vector<T> &data, bool force_cuda) {
  // Stage the values on the CPU before branching: the CUDA path needs this
  // tensor as its copy source, and previously referenced it before it was
  // declared.
  // NOTE(review): nothing here checks that data.size() matches the element
  // count implied by `dim` -- confirm that callers guarantee it.
  auto value = Tensor(dim, DeviceType::CPU);
  int count = 0;
  for (const auto &element : data) {
    value.mutable_data<T>()[count] = element;
    ++count;
  }
#ifdef WITH_CUDA
  if (force_cuda || blob_.IsType<TensorCUDA>()) {
    // Pass the blob's address, exactly as the CPU path below does.
    auto tensor = caffe2::BlobGetMutableTensor(&blob_, DeviceType::CUDA);
    tensor->CopyFrom(value);
    return;
  }
#endif
  auto tensor = caffe2::BlobGetMutableTensor(&blob_, DeviceType::CPU);
  tensor->ResizeLike(value);
  tensor->ShareData(value);
}
void BlobUtil::Print(const std::string &name, int max) {
auto tensor = Get();
//I took out 11 07 2018 TensorUtil(tensor).Print(name, max);
}
// Explicit instantiations of the templated BlobUtil::Set for float and int element types.
// Alternatively, remove the lines below and move the function template's implementation
// into the header file, deleting it from this .cpp file.
// See https://stackoverflow.com/questions/10632251/undefined-reference-to-template-function for a better explanation.
template void BlobUtil::Set<float>(const std::vector<int> &, const std::vector<float> &, bool);
template void BlobUtil::Set<int>(const std::vector<int> &, const std::vector<int> &, bool);
} // namespace caffe2