Skip to content

Commit

Permalink
Update PhyRes-LSTM model file
Browse files Browse the repository at this point in the history
Update PhyRes-LSTM model file
  • Loading branch information
SQY2021-Sub authored Aug 24, 2024
1 parent fa88606 commit d8127ba
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 0 deletions.
57 changes: 57 additions & 0 deletions model/ResNetBlock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn

# 定义ResNet块
class ResNetBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(ResNetBlock, self).__init__()
self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm1d(out_channels)
self.relu = nn.ReLU()
self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm1d(out_channels)
self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1)

def forward(self, x):
shortcut = self.shortcut(x)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x += shortcut
x = self.relu(x)
return x

# 定义模型
class ResNetLSTMModel(nn.Module):
def __init__(self, input_size, resnet_channels, lstm_hidden_size, output_size):
super(ResNetLSTMModel, self).__init__()
self.resnet_x1 = ResNetBlock(input_size, resnet_channels)
self.resnet_x2 = ResNetBlock(input_size, resnet_channels)
self.lstm_x1 = nn.LSTM(resnet_channels, lstm_hidden_size, batch_first=True)
self.lstm_x2 = nn.LSTM(resnet_channels, lstm_hidden_size, batch_first=True)
self.fc_lstm = nn.Linear(lstm_hidden_size, 2) # LSTM的输出连接到全连接层,输出2个值
self.fc_x1_y1 = nn.Linear(3, 1) # x1输出和y1通过全连接层
self.fc_x2_y2 = nn.Linear(3, 1) # x2输出和y2通过全连接层

def forward(self, x1, x2, y1, y2):
x1 = self.resnet_x1(x1)
x2 = self.resnet_x2(x2)
x1 = x1.permute(0, 2, 1) # 调整输入形状为 (batch_size, sequence_length, channels)
x2 = x2.permute(0, 2, 1) # 调整输入形状为 (batch_size, sequence_length, channels)
x1, _ = self.lstm_x1(x1)
x2, _ = self.lstm_x2(x2)
x1 = x1[:, -1, :] # 取最后一个时间步的输出
x2 = x2[:, -1, :] # 取最后一个时间步的输出

lstm_output_x1 = self.fc_lstm(x1)
lstm_output_x2 = self.fc_lstm(x2)

x1_pred_input = torch.cat((lstm_output_x1, y1), dim=1)
x2_pred_input = torch.cat((lstm_output_x2, y2), dim=1)

x1_pred = self.fc_x1_y1(x1_pred_input)
x2_pred = self.fc_x2_y2(x2_pred_input)

return torch.cat((x1_pred, x2_pred), dim=1)
Binary file added model/__pycache__/ResNetBlock.cpython-37.pyc
Binary file not shown.
Binary file not shown.
99 changes: 99 additions & 0 deletions model/physics_driven_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import numpy as np
import matplotlib.pyplot as plt
def Physics_driven_block(Vi, R_load, D1, D2, D_phi):
# Define parameters
Ls = 80e-6 # Transformer leakage inductor
C2 = 100e-6 # DC Output Link Capacitance
R_load = 15 # Load resistance
R_T = 0.085 # AC Inductor Resistance
Vi = 100 # DC Input Voltage
n = 1 # Transformer Turns Ratio
f = 20000 # Switching Frequency
T = 1 / f # Period
T_hs = T / 2 # Half period

D1 = 0.2212 # Duty cycle for primary side
D2 = 0.2212 # Duty cycle for secondary side
D_phi = 0.3957 # Phase shift between primary and secondary

# Define time step and simulation time
dt = 1e-7
t_end = 0.5 # Simulation time
t = np.arange(0, t_end, dt) # Time vector

# Initialize primary side bridge voltage Vab
Vab = np.zeros_like(t)
# Initialize secondary side bridge voltage Vcd
Vcd = np.zeros_like(t)

# Initialize v_C2 and i_L
v_C2 = np.zeros_like(t)
i_L = np.zeros_like(t)

# Record primary side bridge voltage Vab waveform
for i in range(len(t)):
t_mod = t[i] % T # Get the current time point within the period
if D1 * T_hs <= t_mod < T_hs:
Vab[i] = 1
elif (1 + D1) * T_hs <= t_mod < T:
Vab[i] = -1

if (D1 + D_phi) * T_hs <= t_mod < (D1 + D_phi + 1 - D2) * T_hs:
Vcd[i] = 1
elif (t_mod < (D1 + D_phi - D2) * T_hs or ((1 + D1 + D_phi) * T_hs <= t_mod < T)):
Vcd[i] = -1

# 4th order Runge-Kutta method for v_C2 and i_L
if i > 1:
k1_vC2 = (n * Vcd[i - 1] * i_L[i - 1] - v_C2[i - 1] / R_load) / C2
k1_iL = (Vab[i - 1] * Vi - n * Vcd[i - 1] * v_C2[i - 1] - R_T * i_L[i - 1]) / Ls

k2_vC2 = (n * Vcd[i - 1] * (i_L[i - 1] + k1_iL * dt / 2) - (v_C2[i - 1] + k1_vC2 * dt / 2) / R_load) / C2
k2_iL = (Vab[i - 1] * Vi - n * Vcd[i - 1] * (v_C2[i - 1] + k1_vC2 * dt / 2) - R_T * (
i_L[i - 1] + k1_iL * dt / 2)) / Ls

k3_vC2 = (n * Vcd[i - 1] * (i_L[i - 1] + k2_iL * dt / 2) - (v_C2[i - 1] + k2_vC2 * dt / 2) / R_load) / C2
k3_iL = (Vab[i - 1] * Vi - n * Vcd[i - 1] * (v_C2[i - 1] + k2_vC2 * dt / 2) - R_T * (
i_L[i - 1] + k2_iL * dt / 2)) / Ls

k4_vC2 = (n * Vcd[i - 1] * (i_L[i - 1] + k3_iL * dt) - (v_C2[i - 1] + k3_vC2 * dt) / R_load) / C2
k4_iL = (Vab[i - 1] * Vi - n * Vcd[i - 1] * (v_C2[i - 1] + k3_vC2 * dt) - R_T * (
i_L[i - 1] + k3_iL * dt)) / Ls

v_C2[i] = v_C2[i - 1] + (k1_vC2 + 2 * k2_vC2 + 2 * k3_vC2 + k4_vC2) * dt / 6
i_L[i] = i_L[i - 1] + (k1_iL + 2 * k2_iL + 2 * k3_iL + k4_iL) * dt / 6

tp = np.arange(1, len(t) // 2 + 1)

# Record only waveform change points
record_t = []
record_v_C2 = []
record_i_L = []

# Record initial values
record_t.append(t[0])
record_v_C2.append(v_C2[0])
record_i_L.append(i_L[0])

# Record the points where Vab and Vcd change
for i in tp[1:]:
if Vab[i] != Vab[i - 1] or Vcd[i] != Vcd[i - 1]:
record_t.append(t[i])
record_v_C2.append(v_C2[i])
record_i_L.append(i_L[i])

record_t = np.array(record_t)
record_v_C2 = np.array(record_v_C2)
record_i_L = np.array(record_i_L)

# 绘制record_t, record_v_C2, record_i_L
plt.figure()
plt.plot(record_t, record_v_C2, 'b', label='v_C2')
plt.plot(record_t, record_i_L, 'r', label='i_L')
plt.xlabel('time')
plt.ylabel('Amplitude')
plt.legend()
plt.title('Recorded Waveform Data')
plt.show()

return record_t, record_v_C2, record_i_L

0 comments on commit d8127ba

Please sign in to comment.