forked from SQY2021/PETL
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
4 changed files
with
156 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
# 定义ResNet块 | ||
class ResNetBlock(nn.Module): | ||
def __init__(self, in_channels, out_channels): | ||
super(ResNetBlock, self).__init__() | ||
self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1) | ||
self.bn1 = nn.BatchNorm1d(out_channels) | ||
self.relu = nn.ReLU() | ||
self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1) | ||
self.bn2 = nn.BatchNorm1d(out_channels) | ||
self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1) | ||
|
||
def forward(self, x): | ||
shortcut = self.shortcut(x) | ||
x = self.conv1(x) | ||
x = self.bn1(x) | ||
x = self.relu(x) | ||
x = self.conv2(x) | ||
x = self.bn2(x) | ||
x += shortcut | ||
x = self.relu(x) | ||
return x | ||
|
||
# 定义模型 | ||
class ResNetLSTMModel(nn.Module): | ||
def __init__(self, input_size, resnet_channels, lstm_hidden_size, output_size): | ||
super(ResNetLSTMModel, self).__init__() | ||
self.resnet_x1 = ResNetBlock(input_size, resnet_channels) | ||
self.resnet_x2 = ResNetBlock(input_size, resnet_channels) | ||
self.lstm_x1 = nn.LSTM(resnet_channels, lstm_hidden_size, batch_first=True) | ||
self.lstm_x2 = nn.LSTM(resnet_channels, lstm_hidden_size, batch_first=True) | ||
self.fc_lstm = nn.Linear(lstm_hidden_size, 2) # LSTM的输出连接到全连接层,输出2个值 | ||
self.fc_x1_y1 = nn.Linear(3, 1) # x1输出和y1通过全连接层 | ||
self.fc_x2_y2 = nn.Linear(3, 1) # x2输出和y2通过全连接层 | ||
|
||
def forward(self, x1, x2, y1, y2): | ||
x1 = self.resnet_x1(x1) | ||
x2 = self.resnet_x2(x2) | ||
x1 = x1.permute(0, 2, 1) # 调整输入形状为 (batch_size, sequence_length, channels) | ||
x2 = x2.permute(0, 2, 1) # 调整输入形状为 (batch_size, sequence_length, channels) | ||
x1, _ = self.lstm_x1(x1) | ||
x2, _ = self.lstm_x2(x2) | ||
x1 = x1[:, -1, :] # 取最后一个时间步的输出 | ||
x2 = x2[:, -1, :] # 取最后一个时间步的输出 | ||
|
||
lstm_output_x1 = self.fc_lstm(x1) | ||
lstm_output_x2 = self.fc_lstm(x2) | ||
|
||
x1_pred_input = torch.cat((lstm_output_x1, y1), dim=1) | ||
x2_pred_input = torch.cat((lstm_output_x2, y2), dim=1) | ||
|
||
x1_pred = self.fc_x1_y1(x1_pred_input) | ||
x2_pred = self.fc_x2_y2(x2_pred_input) | ||
|
||
return torch.cat((x1_pred, x2_pred), dim=1) |
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
def Physics_driven_block(Vi, R_load, D1, D2, D_phi): | ||
# Define parameters | ||
Ls = 80e-6 # Transformer leakage inductor | ||
C2 = 100e-6 # DC Output Link Capacitance | ||
R_load = 15 # Load resistance | ||
R_T = 0.085 # AC Inductor Resistance | ||
Vi = 100 # DC Input Voltage | ||
n = 1 # Transformer Turns Ratio | ||
f = 20000 # Switching Frequency | ||
T = 1 / f # Period | ||
T_hs = T / 2 # Half period | ||
|
||
D1 = 0.2212 # Duty cycle for primary side | ||
D2 = 0.2212 # Duty cycle for secondary side | ||
D_phi = 0.3957 # Phase shift between primary and secondary | ||
|
||
# Define time step and simulation time | ||
dt = 1e-7 | ||
t_end = 0.5 # Simulation time | ||
t = np.arange(0, t_end, dt) # Time vector | ||
|
||
# Initialize primary side bridge voltage Vab | ||
Vab = np.zeros_like(t) | ||
# Initialize secondary side bridge voltage Vcd | ||
Vcd = np.zeros_like(t) | ||
|
||
# Initialize v_C2 and i_L | ||
v_C2 = np.zeros_like(t) | ||
i_L = np.zeros_like(t) | ||
|
||
# Record primary side bridge voltage Vab waveform | ||
for i in range(len(t)): | ||
t_mod = t[i] % T # Get the current time point within the period | ||
if D1 * T_hs <= t_mod < T_hs: | ||
Vab[i] = 1 | ||
elif (1 + D1) * T_hs <= t_mod < T: | ||
Vab[i] = -1 | ||
|
||
if (D1 + D_phi) * T_hs <= t_mod < (D1 + D_phi + 1 - D2) * T_hs: | ||
Vcd[i] = 1 | ||
elif (t_mod < (D1 + D_phi - D2) * T_hs or ((1 + D1 + D_phi) * T_hs <= t_mod < T)): | ||
Vcd[i] = -1 | ||
|
||
# 4th order Runge-Kutta method for v_C2 and i_L | ||
if i > 1: | ||
k1_vC2 = (n * Vcd[i - 1] * i_L[i - 1] - v_C2[i - 1] / R_load) / C2 | ||
k1_iL = (Vab[i - 1] * Vi - n * Vcd[i - 1] * v_C2[i - 1] - R_T * i_L[i - 1]) / Ls | ||
|
||
k2_vC2 = (n * Vcd[i - 1] * (i_L[i - 1] + k1_iL * dt / 2) - (v_C2[i - 1] + k1_vC2 * dt / 2) / R_load) / C2 | ||
k2_iL = (Vab[i - 1] * Vi - n * Vcd[i - 1] * (v_C2[i - 1] + k1_vC2 * dt / 2) - R_T * ( | ||
i_L[i - 1] + k1_iL * dt / 2)) / Ls | ||
|
||
k3_vC2 = (n * Vcd[i - 1] * (i_L[i - 1] + k2_iL * dt / 2) - (v_C2[i - 1] + k2_vC2 * dt / 2) / R_load) / C2 | ||
k3_iL = (Vab[i - 1] * Vi - n * Vcd[i - 1] * (v_C2[i - 1] + k2_vC2 * dt / 2) - R_T * ( | ||
i_L[i - 1] + k2_iL * dt / 2)) / Ls | ||
|
||
k4_vC2 = (n * Vcd[i - 1] * (i_L[i - 1] + k3_iL * dt) - (v_C2[i - 1] + k3_vC2 * dt) / R_load) / C2 | ||
k4_iL = (Vab[i - 1] * Vi - n * Vcd[i - 1] * (v_C2[i - 1] + k3_vC2 * dt) - R_T * ( | ||
i_L[i - 1] + k3_iL * dt)) / Ls | ||
|
||
v_C2[i] = v_C2[i - 1] + (k1_vC2 + 2 * k2_vC2 + 2 * k3_vC2 + k4_vC2) * dt / 6 | ||
i_L[i] = i_L[i - 1] + (k1_iL + 2 * k2_iL + 2 * k3_iL + k4_iL) * dt / 6 | ||
|
||
tp = np.arange(1, len(t) // 2 + 1) | ||
|
||
# Record only waveform change points | ||
record_t = [] | ||
record_v_C2 = [] | ||
record_i_L = [] | ||
|
||
# Record initial values | ||
record_t.append(t[0]) | ||
record_v_C2.append(v_C2[0]) | ||
record_i_L.append(i_L[0]) | ||
|
||
# Record the points where Vab and Vcd change | ||
for i in tp[1:]: | ||
if Vab[i] != Vab[i - 1] or Vcd[i] != Vcd[i - 1]: | ||
record_t.append(t[i]) | ||
record_v_C2.append(v_C2[i]) | ||
record_i_L.append(i_L[i]) | ||
|
||
record_t = np.array(record_t) | ||
record_v_C2 = np.array(record_v_C2) | ||
record_i_L = np.array(record_i_L) | ||
|
||
# 绘制record_t, record_v_C2, record_i_L | ||
plt.figure() | ||
plt.plot(record_t, record_v_C2, 'b', label='v_C2') | ||
plt.plot(record_t, record_i_L, 'r', label='i_L') | ||
plt.xlabel('time') | ||
plt.ylabel('Amplitude') | ||
plt.legend() | ||
plt.title('Recorded Waveform Data') | ||
plt.show() | ||
|
||
return record_t, record_v_C2, record_i_L |