-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathpyreader.h
156 lines (125 loc) · 3.04 KB
/
pyreader.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#ifndef stempyreaderpy_h
#define stempyreaderpy_h
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <memory>
#include <vector>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <stempy/reader.h>
namespace py = pybind11;
namespace stempy {
struct PYBIND11_EXPORT DataHolder
{
DataHolder() = default;
DataHolder(const DataHolder&) = default;
DataHolder(DataHolder&&) = default;
DataHolder& operator=(const DataHolder& other)
{
if (this != &other) {
// Need to acquire the gil before deleting the python array
// or there may be a crash.
py::gil_scoped_acquire gil;
this->array = other.array;
}
return *this;
}
DataHolder& operator=(DataHolder&& other)
{
if (this != &other) {
// Need to acquire the gil before deleting the python array
// or there may be a crash.
py::gil_scoped_acquire gil;
this->array = std::move(other.array);
}
return *this;
}
~DataHolder()
{
// Need to acquire the gil before deleting the python array
// or there may be a crash.
reset();
}
const uint16_t* get()
{
if (!this->array) {
return nullptr;
}
return this->array->data();
}
void reset()
{
py::gil_scoped_acquire gil;
this->array.reset();
}
std::shared_ptr<py::array_t<uint16_t>> array;
};
struct PYBIND11_EXPORT PyBlock
{
Header header;
DataHolder data;
PyBlock() = default;
PyBlock(py::array_t<uint16_t> pyarray);
};
/// Reads blocks of frames from a Python data set object.
///
/// read() produces one PyBlock at a time. The iterator below treats a block
/// whose data pointer is null as the end of the data. The constructor,
/// read(), reset(), begin() and end() are defined out of line.
class PYBIND11_EXPORT PyReader
{
public:
  // NOTE(review): imageNumbers is taken by non-const reference and is
  // presumably copied into m_imageNumbers by the out-of-line definition
  // (confirm). The signature must match that definition, so it is unchanged.
  PyReader(py::object pyDataSet, std::vector<uint32_t>& imageNumbers,
           Dimensions2D scanDimensions, uint32_t blockSize,
           uint32_t totalImageNum);

  /// Read the next block. A returned block with null data is interpreted by
  /// the iterator as "no more blocks".
  PyBlock read();
  void reset();

  class iterator;
  iterator begin();
  iterator end();

  /// Single-pass input iterator over the blocks produced by read().
  ///
  /// The iterator owns a copy of the current block. It becomes the end
  /// iterator (null reader pointer) as soon as read() returns a block whose
  /// data pointer is null. Equality compares only the reader pointer.
  class iterator
  {
  public:
    using self_type = iterator;
    using value_type = PyBlock;
    using reference = PyBlock&;
    using pointer = PyBlock*;
    using iterator_category = std::input_iterator_tag;
    // Distances cannot be computed without consuming the reader, but the
    // iterator requirements still demand a signed integer difference type.
    using difference_type = std::ptrdiff_t;

    /// A null @p pyreader yields the end iterator. Otherwise the first
    /// block is read immediately, and an empty first block also yields the
    /// end iterator.
    iterator(PyReader* pyreader) : m_PyReader(pyreader)
    {
      if (pyreader == nullptr) {
        return;
      }
      // read data at first time
      fetchNext();
    }

    /// Pre-increment: read the next block. Returns *this (not a copy), so
    /// `*++it` refers to this iterator's block rather than a temporary.
    self_type& operator++()
    {
      fetchNext();
      return *this;
    }

    /// Post-increment: returns a copy holding the block that was current
    /// before advancing.
    self_type operator++(int)
    {
      self_type previous = *this;
      fetchNext();
      return previous;
    }

    reference operator*() { return m_block; }
    pointer operator->() { return &m_block; }

    bool operator==(const self_type& rhs) const
    {
      return m_PyReader == rhs.m_PyReader;
    }
    bool operator!=(const self_type& rhs) const { return !(*this == rhs); }

  private:
    // Read the next block. A null data pointer means there is nothing more
    // to read, so this iterator turns into the end iterator.
    void fetchNext()
    {
      m_block = m_PyReader->read();
      if (m_block.data.get() == nullptr) {
        m_PyReader = nullptr;
      }
    }

    PyReader* m_PyReader = nullptr;
    value_type m_block;
  };

private:
  py::object m_pydataset;
  Dimensions2D m_scanDimensions;
  std::vector<uint32_t> m_imageNumbers;
  uint32_t m_currIndex = 0;
  // Default-initialized so they never hold indeterminate values; any
  // initialization done by the out-of-line constructor takes precedence.
  uint32_t m_imageNumInBlock = 0;
  uint32_t m_blockNumInFile = 0;
  uint32_t m_totalImageNum = 0;
};
} // namespace stempy
#endif