Skip to content

Commit

Permalink
fix parameter shift rule
Browse files Browse the repository at this point in the history
  • Loading branch information
dsdsdshe committed Sep 4, 2023
1 parent 953d9a6 commit cd08444
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 98 deletions.
2 changes: 1 addition & 1 deletion ccsrc/include/simulator/vector/vector_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ class VectorState {
//! Get the expectation and gradient of hamiltonian by parameter-shift rule
virtual VVT<py_qs_data_t> GetExpectationWithGradParameterShiftOneMulti(
const std::vector<std::shared_ptr<Hamiltonian<calc_type>>>& hams, const circuit_t& circ,
const parameter::ParameterResolver& pr, const MST<size_t>& p_map, int n_thread);
parameter::ParameterResolver& pr, const MST<size_t>& p_map, int n_thread);

virtual VT<VVT<py_qs_data_t>> GetExpectationWithGradParameterShiftMultiMulti(
const std::vector<std::shared_ptr<Hamiltonian<calc_type>>>& hams, const circuit_t& circ,
Expand Down
47 changes: 45 additions & 2 deletions ccsrc/include/simulator/vector/vector_state.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ auto VectorState<qs_policy_t_>::GetExpectationWithGradMultiMulti(
template <typename qs_policy_t_>
auto VectorState<qs_policy_t_>::GetExpectationWithGradParameterShiftOneMulti(
const std::vector<std::shared_ptr<Hamiltonian<calc_type>>>& hams, const circuit_t& circ,
const parameter::ParameterResolver& pr, const MST<size_t>& p_map, int n_thread) -> VVT<py_qs_data_t> {
parameter::ParameterResolver& pr, const MST<size_t>& p_map, int n_thread) -> VVT<py_qs_data_t> {
auto n_hams = hams.size();
int max_thread = 15;
if (n_thread == 0) {
Expand Down Expand Up @@ -1155,7 +1155,14 @@ auto VectorState<qs_policy_t_>::GetExpectationWithGradParameterShiftOneMulti(
calc_type pr_shift = M_PI_2;
calc_type coeff = 0.5;
if (gate->id_ == GateID::CUSTOM) {
auto p_gate = static_cast<CustomGate*>(gate.get());
pr_shift = 0.001;
coeff = 0.5 / pr_shift;
}
if (gate->id_ == GateID::SWAPalpha) {
pr_shift = 0.5;
coeff = M_PI_2;
}
if (gate->id_ == GateID::FSim) {
pr_shift = 0.001;
coeff = 0.5 / pr_shift;
}
Expand All @@ -1164,18 +1171,54 @@ auto VectorState<qs_policy_t_>::GetExpectationWithGradParameterShiftOneMulti(
VT<py_qs_data_t> intrin_grad_list(p_gate->prs_.size());
for (int k = 0; k < p_gate->prs_.size(); k++) {
p_gate->prs_[k] += -pr_shift;
if (gate->id_ == GateID::U3) {
parameter::tn::Tensor coeff;
parameter::tn::Tensor tmp;
std::string key;
for (auto& [key_, v] : p_gate->prs_[k].data_) {
key = key_;
coeff = v;
tmp = pr.GetItem(key_);
}
tmp += -pr_shift / coeff;
pr.SetItem(key, tmp);
}
sim_l = *this;
sim_l.ApplyCircuit(circ, pr);
sim_rs[j - start] = sim_l;
sim_rs[j - start].ApplyHamiltonian(*hams[j]);
auto expect0 = qs_policy_t::Vdot(sim_l.qs, sim_rs[j - start].qs, dim);
p_gate->prs_[k] += 2 * pr_shift;
if (gate->id_ == GateID::U3) {
parameter::tn::Tensor coeff;
parameter::tn::Tensor tmp;
std::string key;
for (auto& [key_, v] : p_gate->prs_[k].data_) {
key = key_;
coeff = v;
tmp = pr.GetItem(key_);
}
tmp += 2 * pr_shift / coeff;
pr.SetItem(key, tmp);
}
sim_l = *this;
sim_l.ApplyCircuit(circ, pr);
sim_rs[j - start] = sim_l;
sim_rs[j - start].ApplyHamiltonian(*hams[j]);
auto expect1 = qs_policy_t::Vdot(sim_l.qs, sim_rs[j - start].qs, dim);
p_gate->prs_[k] += -pr_shift;
if (gate->id_ == GateID::U3) {
parameter::tn::Tensor coeff;
parameter::tn::Tensor tmp;
std::string key;
for (auto& [key_, v] : p_gate->prs_[k].data_) {
key = key_;
coeff = v;
tmp = pr.GetItem(key_);
}
tmp += -pr_shift / coeff;
pr.SetItem(key, tmp);
}
intrin_grad_list[k] = {coeff * std::real(expect1 - expect0), 0};
}
auto intrin_grad = tensor::Matrix(VVT<py_qs_data_t>{intrin_grad_list});
Expand Down
4 changes: 2 additions & 2 deletions mindquantum/simulator/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,8 @@ def get_expectation_with_grad(
batch in parallel threads. Default: ``None``.
pr_shift (bool): Whether or not to use parameter-shift rule. Only available in "mqvector" simulator.
It will be enabled automatically when the circuit contains a noise channel. Note that not every gate
uses the same shift value π/2, so the gradient of parameterized custom gate will be calculated
by finite difference method with gap 0.001. Default: ``False``.
uses the same shift value π/2, so the gradient of FSim gate and parameterized custom gate will be
calculated by finite difference method with gap 0.001. Default: ``False``.
Returns:
GradOpsWrapper, a grad ops wrapper that contains information to generate this grad ops.
Expand Down
67 changes: 38 additions & 29 deletions tests/st/test_simulator/test_basic_gate_with_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def test_single_parameter_gate_expectation_with_grad(config, gate): # pylint: d
Expectation: success.
"""
virtual_qc, dtype = config
g = gate('a')
g = gate({'a': 1, 'b': 2})
dim = 2**g.n_qubits
g = g.on(list(range(g.n_qubits)))
init_state = np.random.rand(dim) + np.random.rand(dim) * 1j
Expand All @@ -161,43 +161,54 @@ def test_single_parameter_gate_expectation_with_grad(config, gate): # pylint: d
sim = Simulator(virtual_qc, g.n_qubits, dtype=dtype)
sim.set_qs(init_state)
grad_ops = sim.get_expectation_with_grad(ham, Circuit(g))
pr = np.random.rand() * 2 * np.pi
f, grad = grad_ops([pr])
pr = np.random.rand(2) * 2 * np.pi
f, grad = grad_ops(pr)
ref_f = (
init_state.T.conj()
@ g.hermitian().matrix({'a': pr})
@ g.hermitian().matrix({'a': pr[0], 'b': pr[1]})
@ ham.hamiltonian.matrix(g.n_qubits)
@ g.matrix({'a': pr})
@ g.matrix({'a': pr[0], 'b': pr[1]})
@ init_state
)
ref_grad = (
init_state.T.conj()
@ g.hermitian().matrix({'a': pr})
@ ham.hamiltonian.matrix(g.n_qubits)
@ g.diff_matrix({'a': pr})
@ init_state
).real * 2
ref_grad = []
for about_what in ('a', 'b'):
ref_grad.append(
(
init_state.T.conj()
@ g.hermitian().matrix({'a': pr[0], 'b': pr[1]})
@ ham.hamiltonian.matrix(g.n_qubits)
@ g.diff_matrix({'a': pr[0], 'b': pr[1]}, about_what)
@ init_state
).real
* 2
)
assert np.allclose(f, ref_f, atol=1e-6)
assert np.allclose(grad, ref_grad.real, atol=1e-4)
assert np.allclose(grad, ref_grad, atol=1e-4)

c_g = g.on(list(range(g.n_qubits)), g.n_qubits)
c_init_state = np.random.rand(2 * dim) + np.random.rand(2 * dim) * 1j
c_init_state = c_init_state / np.linalg.norm(c_init_state)
c_sim = Simulator(virtual_qc, c_g.n_qubits + 1, dtype=dtype)
c_sim.set_qs(c_init_state)
c_grad_ops = c_sim.get_expectation_with_grad(ham, Circuit(c_g))
c_pr = np.random.rand() * 2 * np.pi
c_f, c_grad = c_grad_ops([c_pr])
m = np.block([[np.eye(dim), np.zeros((dim, dim))], [np.zeros((dim, dim)), g.matrix({'a': c_pr})]])
diff_m = np.block(
[[np.zeros((dim, dim)), np.zeros((dim, dim))], [np.zeros((dim, dim)), g.diff_matrix({'a': c_pr})]]
)
c_pr = np.random.rand(2) * 2 * np.pi
c_f, c_grad = c_grad_ops(c_pr)
m = np.block([[np.eye(dim), np.zeros((dim, dim))], [np.zeros((dim, dim)), g.matrix({'a': c_pr[0], 'b': c_pr[1]})]])
c_ref_f = c_init_state.T.conj() @ m.T.conj() @ ham.hamiltonian.matrix(g.n_qubits + 1) @ m @ c_init_state
c_ref_grad = (
2 * (c_init_state.T.conj() @ m.T.conj() @ ham.hamiltonian.matrix(g.n_qubits + 1) @ diff_m @ c_init_state).real
)
assert np.allclose(c_f, c_ref_f, atol=1e-6)
assert np.allclose(c_grad, c_ref_grad, atol=1e-6)
c_ref_grad = []
for about_what in ('a', 'b'):
diff_m = np.block(
[
[np.zeros((dim, dim)), np.zeros((dim, dim))],
[np.zeros((dim, dim)), g.diff_matrix({'a': c_pr[0], 'b': c_pr[1]}, about_what)],
]
)
c_ref_grad.append(
2
* (c_init_state.T.conj() @ m.T.conj() @ ham.hamiltonian.matrix(g.n_qubits + 1) @ diff_m @ c_init_state).real
)
assert np.allclose(c_f, c_ref_f, atol=1e-5)
assert np.allclose(c_grad, c_ref_grad, atol=1e-5)


@pytest.mark.level0
Expand Down Expand Up @@ -364,7 +375,7 @@ def diff_matrix(alpha):
@ init_state
).real * 2
assert np.allclose(f, ref_f, atol=1e-6)
assert np.allclose(grad, ref_grad.real, atol=1e-6)
assert np.allclose(grad, ref_grad, atol=1e-6)

c_g = g.on(list(range(n)), n)
c_init_state = np.random.rand(2 * dim) + np.random.rand(2 * dim) * 1j
Expand Down Expand Up @@ -425,8 +436,6 @@ def test_u3_expectation_with_grad(config): # pylint: disable=R0914
f, grad = grad_ops(pr)
ref_f, ref_grad = ref_grad_ops(ref_pr)
ref_grad = np.array([ref_grad[0][0][1], ref_grad[0][0][2], ref_grad[0][0][0]])
print(grad)
print(ref_grad)
assert np.allclose(f, ref_f, atol=1e-6)
assert np.allclose(grad, ref_grad.real, atol=1e-6)

Expand All @@ -445,7 +454,7 @@ def test_u3_expectation_with_grad(config): # pylint: disable=R0914
ref_c_f, ref_c_grad = ref_c_grad_ops(ref_c_pr)
ref_c_grad = np.array([ref_c_grad[0][0][1], ref_c_grad[0][0][2], ref_c_grad[0][0][0]])
assert np.allclose(c_f, ref_c_f, atol=1e-6)
assert np.allclose(c_grad, ref_c_grad, atol=1e-6)
assert np.allclose(c_grad, ref_c_grad.real, atol=1e-6)


@pytest.mark.level0
Expand Down Expand Up @@ -516,7 +525,7 @@ def phi_diff_matrix(phi):
).real * 2
ref_grad = np.array([ref_grad_theta, ref_grad_phi])
assert np.allclose(f, ref_f, atol=1e-6)
assert np.allclose(grad, ref_grad.real, atol=1e-6)
assert np.allclose(grad, ref_grad, atol=1e-6)

c_g = g.on(list(range(g.n_qubits)), g.n_qubits)
c_init_state = np.random.rand(2 * dim) + np.random.rand(2 * dim) * 1j
Expand Down
93 changes: 50 additions & 43 deletions tests/st/test_simulator/test_method_of_mqmatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,32 +216,37 @@ def test_get_expectation_with_grad(config):
init_state = init_state / np.linalg.norm(init_state)
circ0 = random_circuit(3, 100)
circ1 = random_circuit(3, 100)
circ = circ0 + G.RX('a').on(0) + circ1
circ = circ0 + G.RX({'a': 1, 'b': 2}).on(0) + circ1
sim = Simulator(virtual_qc, 3, dtype=dtype)
sim.set_qs(init_state)
ham0 = Hamiltonian(QubitOperator('X0 Y1'), dtype=dtype)
ham1 = ham0.sparse(3)
ham2 = Hamiltonian(csr_matrix(ham0.hamiltonian.matrix(3)), dtype=dtype)
for ham in (ham0, ham1, ham2):
grad_ops = sim.get_expectation_with_grad(ham, circ)
pr = np.random.rand()
f, g = grad_ops([pr])
pr = np.random.rand(2)
f, g = grad_ops(pr)
ref_f = (
init_state.T.conj()
@ circ.hermitian().matrix({'a': pr})
@ circ.hermitian().matrix({'a': pr[0], 'b': pr[1]})
@ ham0.hamiltonian.matrix(3)
@ circ.matrix({'a': pr})
@ circ.matrix({'a': pr[0], 'b': pr[1]})
@ init_state
)
ref_g = (
init_state.T.conj()
@ circ.hermitian().matrix({'a': pr})
@ ham0.hamiltonian.matrix(3)
@ circ1.matrix()
@ np.kron(np.eye(4, 4), G.RX('a').diff_matrix({'a': pr}))
@ circ0.matrix()
@ init_state
).real * 2
ref_g = []
for about_what in ('a', 'b'):
ref_g.append(
(
init_state.T.conj()
@ circ.hermitian().matrix({'a': pr[0], 'b': pr[1]})
@ ham0.hamiltonian.matrix(3)
@ circ1.matrix()
@ np.kron(np.eye(4, 4), G.RX({'a': 1, 'b': 2}).diff_matrix({'a': pr[0], 'b': pr[1]}, about_what))
@ circ0.matrix()
@ init_state
).real
* 2
)
assert np.allclose(f, ref_f, atol=1e-6)
assert np.allclose(g, ref_g, atol=1e-6)

Expand All @@ -262,7 +267,7 @@ def test_noise_get_expectation_with_grad(virtual_qc, dtype):
init_dm = np.outer(init_state, init_state.conj())
circ0 = random_circuit(3, 100, 1.0, 0.0)
circ1 = random_circuit(3, 100, 1.0, 0.0)
circ = circ0 + G.RX('a').on(0) + circ1
circ = circ0 + G.RX({'a': 1, 'b': 2}).on(0) + circ1
circ = circ.with_noise()
ham0 = Hamiltonian(QubitOperator('X0 Y1'), dtype=dtype)
ham1 = ham0.sparse(3)
Expand All @@ -271,39 +276,41 @@ def test_noise_get_expectation_with_grad(virtual_qc, dtype):
sim = Simulator(virtual_qc, 3, dtype=dtype)
sim.set_qs(init_dm)
grad_ops = sim.get_expectation_with_grad(ham, circ)
pr = np.random.rand()
f, grad = grad_ops([pr])
sim.apply_circuit(circ, [pr])
pr = np.random.rand(2)
f, grad = grad_ops(pr)
sim.apply_circuit(circ, pr)
dm = sim.get_qs()
ref_f = np.trace(ham0.hamiltonian.matrix(3) @ dm)
dm = init_dm
for g in circ:
if g.parameterized:
dm = (
np.kron(np.eye(4, 4), g.diff_matrix({'a': pr}))
@ dm
@ np.kron(np.eye(4, 4), g.hermitian().matrix({'a': pr}))
)
elif isinstance(g, G.NoiseGate):
tmp = np.zeros((8, 8), dtype=mq.to_np_type(dtype))
for m in g.matrix():
ref_grad = []
for about_what in ('a', 'b'):
dm = init_dm
for g in circ:
if g.parameterized:
dm = (
np.kron(np.eye(4, 4), g.diff_matrix({'a': pr[0], 'b': pr[1]}, about_what))
@ dm
@ np.kron(np.eye(4, 4), g.hermitian().matrix({'a': pr[0], 'b': pr[1]}))
)
elif isinstance(g, G.NoiseGate):
tmp = np.zeros((8, 8), dtype=mq.to_np_type(dtype))
for m in g.matrix():
if g.obj_qubits[0] == 0:
big_m = np.kron(np.eye(4, 4), m)
elif g.obj_qubits[0] == 1:
big_m = np.kron(np.kron(np.eye(2, 2), m), np.eye(2, 2))
else:
big_m = np.kron(m, np.eye(4, 4))
tmp += big_m @ dm @ big_m.conj().T
dm = tmp
else:
if g.obj_qubits[0] == 0:
big_m = np.kron(np.eye(4, 4), m)
big_m = np.kron(np.eye(4, 4), g.matrix())
elif g.obj_qubits[0] == 1:
big_m = np.kron(np.kron(np.eye(2, 2), m), np.eye(2, 2))
big_m = np.kron(np.kron(np.eye(2, 2), g.matrix()), np.eye(2, 2))
else:
big_m = np.kron(m, np.eye(4, 4))
tmp += big_m @ dm @ big_m.conj().T
dm = tmp
else:
if g.obj_qubits[0] == 0:
big_m = np.kron(np.eye(4, 4), g.matrix())
elif g.obj_qubits[0] == 1:
big_m = np.kron(np.kron(np.eye(2, 2), g.matrix()), np.eye(2, 2))
else:
big_m = np.kron(g.matrix(), np.eye(4, 4))
dm = big_m @ dm @ big_m.conj().T
ref_grad = np.trace(ham0.hamiltonian.matrix(3) @ dm).real * 2
big_m = np.kron(g.matrix(), np.eye(4, 4))
dm = big_m @ dm @ big_m.conj().T
ref_grad.append(np.trace(ham0.hamiltonian.matrix(3) @ dm).real * 2)
assert np.allclose(f, ref_f, atol=1e-6)
assert np.allclose(grad, ref_grad, atol=1e-4)

Expand Down
Loading

0 comments on commit cd08444

Please sign in to comment.