forked from AMReX-Combustion/PelePhysics
-
Notifications
You must be signed in to change notification settings - Fork 1
/
PeleParamsGeneric.H
118 lines (96 loc) · 2.95 KB
/
PeleParamsGeneric.H
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#ifndef PELEPARAMSGENERIC_H
#define PELEPARAMSGENERIC_H
#include <AMReX_Gpu.H>
#include "Factory.H"
namespace pele::physics {
// ParmType is the plain parameter struct whose bytes PeleParams mirrors from
// host to device (see PeleParams::sync_to_device).
//
// HostOnlyParm<ParmType> is companion storage for data that lives on the host
// only and is never copied to the device. The primary template is
// deliberately empty; presumably a ParmType that needs host-only data
// provides its own specialization (NOTE(review): confirm against users).
template <class ParmType>
struct HostOnlyParm {};
// Forward declaration of PeleParams. The default BaseParmType = ParmType is
// given here, because a default template argument may only be specified once.
template <class ParmType, class BaseParmType = ParmType>
class PeleParams;

// Host-side initialize/deallocate hooks invoked by
// PeleParams::initialize() and PeleParams::deallocate().
// The primary template does nothing. A ParmType that needs host work (e.g.
// filling m_h_parm or m_host_only_parm) would supply its own InitParm,
// which PeleParams befriends so it can reach those private members.
template <class ParmType, class BaseParmType = ParmType>
struct InitParm
{
  // No-op host initialization.
  static void host_initialize(PeleParams<ParmType, BaseParmType>* /*parm_in*/) {}

  // No-op host deallocation.
  static void host_deallocate(PeleParams<ParmType, BaseParmType>* /*parm_in*/) {}
};
// Abstract interface over PeleParams, parameterized on the *base* parameter
// type rather than the concrete one. This lets a PeleParams<Derived, Base>
// be driven through a PeleParamsGeneric<Base> reference when ParmType uses
// inheritance. It also derives from Factory<...>. NOTE(review): how
// base_identifier() is consumed is defined in Factory.H, which is not shown here.
template <typename BaseParmType>
class PeleParamsGeneric : public Factory<PeleParamsGeneric<BaseParmType>>
{
public:
  // Name identifying this family of parameter holders.
  static auto base_identifier() -> std::string
  {
    return std::string{"pele_params_base_generic"};
  }

  // Lifecycle: host setup + device allocation, explicit device allocation,
  // host -> device refresh, and teardown.
  virtual auto initialize() -> void = 0;
  virtual auto device_allocate() -> void = 0;
  virtual auto sync_to_device() -> void = 0;
  virtual auto deallocate() -> void = 0;

  // Accessors: mutable host copy, device copy, and host-only companion data.
  virtual auto host_parm() -> BaseParmType& = 0;
  virtual auto device_parm() -> const BaseParmType* = 0;
  virtual auto host_only_parm() -> HostOnlyParm<BaseParmType>& = 0;
};
// Template for class that allocates/deallocates a generic ParmType.
// For most use cases ParmType doesn't have inheritance to worry about
// and uses the default BaseParmType = ParmType defined in the forward
// declaration above.
//
// Holds three pieces of state:
//   * m_h_parm         - host copy of the parameters (host_parm()),
//   * m_d_parm         - device copy allocated from amrex::The_Device_Arena()
//                        (device_parm()), a byte copy of m_h_parm,
//   * m_host_only_parm - data that is never copied to the device.
//
// Typical lifecycle:
//   initialize()                           host setup + allocate + sync
//   host_parm() edits, then sync_to_device()  push host changes to device
//   deallocate()                           host teardown + free device copy
//
// Device memory is released only by deallocate(); the destructor does not
// free it. After deallocate() the object may be initialize()d again.
// NOTE(review): the implicit copy operations copy the raw m_d_parm pointer,
// so two copies would share (and could both free) the same device storage.
template <typename ParmType, typename BaseParmType>
class PeleParams : public PeleParamsGeneric<BaseParmType>
{
  // Lets a specialized InitParm fill the private host-side members.
  friend struct InitParm<ParmType, BaseParmType>;

public:
  static_assert(
    std::is_base_of_v<BaseParmType, ParmType>,
    "PeleParams: ParmType must be BaseParmType or derive from it");

  PeleParams() = default;
  ~PeleParams() override = default;

  // Run the host-side initializer hook, then ensure a device copy exists.
  // device_allocate() syncs only when it performs a fresh allocation.
  void initialize() override
  {
    InitParm<ParmType, BaseParmType>::host_initialize(this);
    device_allocate();
  }

  // Allocate device storage for one ParmType (if not already allocated) and
  // copy the current host parameters into it.
  void device_allocate() override
  {
    if (!m_device_allocated) {
      m_d_parm = static_cast<ParmType*>(
        amrex::The_Device_Arena()->alloc(sizeof(ParmType)));
      m_device_allocated = true;
      sync_to_device();
    }
  }

  // Copy the host parameter struct into the device allocation.
  // Aborts if device storage has not been allocated.
  void sync_to_device() override
  {
    if (!m_device_allocated) {
      amrex::Abort("Device params not allocated yet");
    } else {
      amrex::Gpu::copy(
        amrex::Gpu::hostToDevice, &m_h_parm, &m_h_parm + 1, m_d_parm);
    }
  }

  // Run the host-side deallocation hook and free the device copy.
  void deallocate() override
  {
    InitParm<ParmType, BaseParmType>::host_deallocate(this);
    if (m_device_allocated) {
      amrex::The_Device_Arena()->free(m_d_parm);
      // Reset so that a repeated deallocate() cannot double-free, and a later
      // sync_to_device() aborts instead of writing into freed memory. A later
      // device_allocate()/initialize() also allocates fresh storage.
      m_d_parm = nullptr;
      m_device_allocated = false;
    }
  }

  BaseParmType& host_parm() override { return m_h_parm; }

  // nullptr until device_allocate() has run (and again after deallocate()).
  const BaseParmType* device_parm() override { return m_d_parm; }

  HostOnlyParm<BaseParmType>& host_only_parm() override
  {
    return m_host_only_parm;
  }

private:
  HostOnlyParm<BaseParmType> m_host_only_parm;
  ParmType m_h_parm;
  ParmType* m_d_parm{nullptr};
  bool m_device_allocated{false};
};
} // namespace pele::physics
#endif