Skip to content

Commit

Permalink
simplify multi pr case in pr_shift
Browse files Browse the repository at this point in the history
  • Loading branch information
dsdsdshe committed Nov 17, 2023
1 parent bfe4d61 commit d3f7bbd
Showing 1 changed file with 14 additions and 30 deletions.
44 changes: 14 additions & 30 deletions ccsrc/include/simulator/vector/vector_state.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,7 @@ auto VectorState<qs_policy_t_>::GetExpectationWithGradParameterShiftOneMulti(
if (gate->GradRequired()) {
auto p_gate = static_cast<Parameterizable*>(gate.get());
std::shared_ptr<BasicGate> tmp_gate;
bool is_multi_pr = false;
switch (gate->id_) {
case (GateID::RX): {
tmp_gate = CONVERT_GATE(RXGate, p_gate);
Expand Down Expand Up @@ -1397,10 +1398,12 @@ auto VectorState<qs_policy_t_>::GetExpectationWithGradParameterShiftOneMulti(
}
case (GateID::U3): {
tmp_gate = CONVERT_GATE(U3, p_gate);
is_multi_pr = true;
break;
}
case (GateID::FSim): {
tmp_gate = CONVERT_GATE(FSim, p_gate);
is_multi_pr = true;
break;
}
case (GateID::CUSTOM): {
Expand Down Expand Up @@ -1444,16 +1447,10 @@ auto VectorState<qs_policy_t_>::GetExpectationWithGradParameterShiftOneMulti(
VT<py_qs_data_t> intrin_grad_list(tmp_p_gate->prs_.size());
for (int k = 0; k < tmp_p_gate->prs_.size(); k++) {
tmp_p_gate->prs_[k] += -pr_shift;
if (gate->id_ == GateID::U3 || gate->id_ == GateID::FSim) {
parameter::tn::Tensor coeff;
parameter::tn::Tensor tmp;
std::string key;
for (auto& [key_, v] : tmp_p_gate->prs_[k].data_) {
key = key_;
coeff = v;
tmp = pr.GetItem(key_);
}
tmp += -pr_shift / coeff;
if (is_multi_pr) {
std::string key = tmp_p_gate->prs_[k].data_.begin()->first;
parameter::tn::Tensor tmp = pr.GetItem(key);
tmp += -pr_shift / tmp_p_gate->prs_[k].data_.begin()->second;
tmp_pr.SetItem(key, tmp);
}
sim_l = *this;
Expand All @@ -1464,16 +1461,10 @@ auto VectorState<qs_policy_t_>::GetExpectationWithGradParameterShiftOneMulti(
sim_rs[j - start].ApplyHamiltonian(*hams[j]);
auto expect0 = qs_policy_t::Vdot(sim_l.qs, sim_rs[j - start].qs, dim);
tmp_p_gate->prs_[k] += 2 * pr_shift;
if (gate->id_ == GateID::U3 || gate->id_ == GateID::FSim) {
parameter::tn::Tensor coeff;
parameter::tn::Tensor tmp;
std::string key;
for (auto& [key_, v] : tmp_p_gate->prs_[k].data_) {
key = key_;
coeff = v;
tmp = pr.GetItem(key_);
}
tmp += pr_shift / coeff;
if (is_multi_pr) {
std::string key = tmp_p_gate->prs_[k].data_.begin()->first;
parameter::tn::Tensor tmp = pr.GetItem(key);
tmp += pr_shift / tmp_p_gate->prs_[k].data_.begin()->second;
tmp_pr.SetItem(key, tmp);
}
sim_l = *this;
Expand All @@ -1484,16 +1475,9 @@ auto VectorState<qs_policy_t_>::GetExpectationWithGradParameterShiftOneMulti(
sim_rs[j - start].ApplyHamiltonian(*hams[j]);
auto expect1 = qs_policy_t::Vdot(sim_l.qs, sim_rs[j - start].qs, dim);
tmp_p_gate->prs_[k] += -pr_shift;
if (gate->id_ == GateID::U3 || gate->id_ == GateID::FSim) {
parameter::tn::Tensor coeff;
parameter::tn::Tensor tmp;
std::string key;
for (auto& [key_, v] : tmp_p_gate->prs_[k].data_) {
key = key_;
coeff = v;
tmp = pr.GetItem(key_);
}
tmp_pr.SetItem(key, tmp);
if (is_multi_pr) {
std::string key = tmp_p_gate->prs_[k].data_.begin()->first;
tmp_pr.SetItem(key, pr.GetItem(key));
}
intrin_grad_list[k] = {coeff * std::real(expect1 - expect0), 0};
}
Expand Down

0 comments on commit d3f7bbd

Please sign in to comment.