diff --git a/CMakeLists.txt b/CMakeLists.txt index 39bf8dd1e..9fd344a97 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,6 +45,10 @@ set(SOURCE_FILES_NO_MAIN src/simulators/detection_simulator.cc src/simulators/error_fuser.cc src/simulators/frame_simulator.cc + src/simulators/measure_record_batch.cc + src/simulators/measure_record_batch_writer.cc + src/simulators/measure_record.cc + src/simulators/measure_record_writer.cc src/simulators/tableau_simulator.cc src/simulators/vector_simulator.cc src/stabilizers/pauli_string.cc @@ -60,6 +64,7 @@ set(TEST_FILES src/main_helper.test.cc src/probability_util.test.cc src/simd/bit_ref.test.cc + src/simd/monotonic_buffer.test.cc src/simd/simd_bit_table.test.cc src/simd/simd_bits.test.cc src/simd/simd_bits_range_ref.test.cc @@ -69,6 +74,10 @@ set(TEST_FILES src/simulators/detection_simulator.test.cc src/simulators/error_fuser.test.cc src/simulators/frame_simulator.test.cc + src/simulators/measure_record.test.cc + src/simulators/measure_record_batch.test.cc + src/simulators/measure_record_batch_writer.test.cc + src/simulators/measure_record_writer.test.cc src/simulators/tableau_simulator.test.cc src/simulators/vector_simulator.test.cc src/stabilizers/pauli_string.test.cc @@ -80,6 +89,7 @@ set(BENCHMARK_FILES src/benchmark_util.perf.cc src/circuit/circuit.perf.cc src/circuit/gate_data.perf.cc + src/main.perf.cc src/probability_util.perf.cc src/simd/simd_bit_table.perf.cc src/simd/simd_bits.perf.cc diff --git a/README.md b/README.md index 1a19114c0..5f20d51ab 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # Stim -Stim is a fast simulator for non-adaptive quantum stabilizer circuits. +Stim is a fast simulator for quantum stabilizer circuits. Stim is based on the stabilizer tableau representation introduced in [Scott Aaronson et al's CHP simulator](https://arxiv.org/abs/quant-ph/0406196). Stim makes three key improvements over CHP. 
@@ -28,90 +28,13 @@ Pauli string multiplication is a key bottleneck operation when updating a stabil Tracking Pauli frames can also benefit from vectorization, by combining them into batches and computing thousands of samples at a time. -# Usage (python) - -Stim can be installed into a python 3 environment using pip: - -```bash -pip install stim -``` - -Once stim is installed, you can `import stim` and use it. -There are two supported use cases: interactive usage and high speed sampling. - -You can use the Tableau simulator in an interactive fashion: - -```python -import stim - -s = stim.TableauSimulator() - -# Create a GHZ state. -s.h(0) -s.cnot(0, 1) -s.cnot(0, 2) - -# Measure the GHZ state. -print(s.measure_many(0, 1, 2)) # [False, False, False] or [True, True, True] -``` - -Alternatively, you can compile a circuit and then begin generating samples from it: - -```python -import stim - -# Create a circuit that measures a large GHZ state. -c = stim.Circuit() -c.append_operation("H", [0]) -for k in range(1, 30): - c.append_operation("CNOT", [0, k]) -c.append_operation("M", range(30)) - -# Compile the circuit into a high performance sampler. -sampler = c.compile_sampler() - -# Collect a batch of samples. -# Note: the ideal batch size, in terms of speed per sample, is roughly 1024. -# Smaller batches are slower because they are not sufficiently vectorized. -# Bigger batches are slower because they use more memory. -batch = sampler.sample(1024) -print(type(batch)) # numpy.ndarray -print(batch.dtype) # numpy.uint8 -print(batch.shape) # (1024, 30) -print(batch) -# Prints something like: -# [[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] -# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] -# [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] -# ... 
-# [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] -# [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] -# [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] -# [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]] -``` +# Building -The circuit can also include noise: - -```python -import stim -import numpy as np - -c = stim.Circuit(""" - X_ERROR(0.1) 0 - Y_ERROR(0.2) 1 - Z_ERROR(0.3) 2 - DEPOLARIZE1(0.4) 3 - DEPOLARIZE2(0.5) 4 5 - M 0 1 2 3 4 5 -""") -batch = c.compile_sampler().sample(2**20) -print(np.mean(batch, axis=0).round(3)) -# Prints something like: -# [0.1 0.2 0. 0.267 0.267 0.266] -``` +See the [developer documentation](dev/README.md). -You can also sample detection events using `stim.Circuit.compile_detector_sampler`. +# Usage (python) +See the [python documentation](glue/python/README.md). # Usage (command line) @@ -281,7 +204,7 @@ Only one mode can be specified. Detection event sampling mode. Outputs whether or not measurement sets specified by `DETECTOR` instructions have been flipped by noise. Assumes (does not verify) that all `DETECTOR` instructions corresponding to measurement sets with deterministic parity. - See also `--prepend_observables`, `--append_observables`. + See also `--append_observables`. If an integer argument is specified, run that many shots of the circuit. - `--detector_hypergraph`: Detector graph creation mode. @@ -307,11 +230,6 @@ Not all modifiers apply to all modes. In addition to outputting the values of detectors, output the values of logical observables built up using `OBSERVABLE_INCLUDE` instructions. Put these observables' values into the detection event output as if they were additional detectors at the end of the circuit. -- `--prepend_observables`: - Requires detection event sampling mode. - In addition to outputting the values of detectors, output the values of logical observables - built up using `OBSERVABLE_INCLUDE` instructions. 
- Put these observables' values into the detection event output as if they were additional detectors at the start of the circuit - `--in=FILEPATH`: Specifies a file to read a circuit from. If not specified, the `stdin` pipe is used. @@ -579,140 +497,3 @@ Not all modifiers apply to all modes. - `TICK`: Optional command indicating the end of a layer of gates. May be ignored, may force processing of internally queued operations and flushing of queued measurement results. - `REPEAT N { ... }`: Repeats the instructions in its body N times. - - -# Building - -### CMake Build - -```bash -cmake . -make stim -# ./out/stim -``` - -To control the vectorization (e.g. this is done for testing), -use `cmake . -DSIMD_WIDTH=256` (implying `-mavx2`) -or `cmake . -DSIMD_WIDTH=128` (implying `-msse2`) -or `cmake . -DSIMD_WIDTH=64` (implying no machine architecture flag). -If `SIMD_WIDTH` is not specified, `-march=native` is used. - -### Bazel Build - -```bash -bazel build stim -# bazel run stim -``` - -### Manual Build - -```bash -find src | grep "\\.cc" | grep -v "\\.\(test\|perf\|pybind\)\\.cc" | xargs g++ -pthread -std=c++11 -O3 -march=native -# ./a.out -``` - -### Python Package Build - -Environment requirements: - -```bash -pip install -y pybind11 cibuildwheel -``` - -Build source distribution (fallback for missing binary wheels): - -```bash -python setup.py sdist -``` - -Output in `dist` directory. - -Build manylinux binary distributions (takes 30+ minutes): - -```bash -python -m cibuildwheel --output-dir wheelhouse --platform=linux -``` - -Output in `wheelhouse` directory. - -Build `stimcirq` package: - -```bash -cd glue/cirq -python setup.py sdist -``` - -Output in `glue/cirq/dist` directory. - -# Testing - -### Run tests using CMAKE - -Unit testing with CMAKE requires GTest to be installed on your system and discoverable by CMake. -Follow the ["Standalone CMake Project" from the GTest README](https://github.com/google/googletest/tree/master/googletest). 
- -Run tests with address and memory sanitization, but without optimizations: - -```bash -cmake . -make stim_test -./out/stim_test -``` - -To force AVX vectorization, SSE vectorization, or no vectorization -pass `-DSIMD_WIDTH=256` or `-DSIMD_WIDTH=128` or -DSIMD_WIDTH=64` to the `cmake` command. - -Run tests with optimizations without sanitization: - -```bash -cmake . -make stim_test_o3 -./out/stim_test_o3 -``` - -### Run tests using Bazel - -Run tests with whatever settings Bazel feels like using: - -```bash -bazel :stim_test -``` - -### Run python binding tests - -In a fresh virtual environment: - -```bash -pip install -e . -pip install -y numpy pytest -python -m pytest src -``` - -# Benchmarking - -```bash -cmake . -make stim_benchmark -./out/stim_benchmark -``` - -This will output results like: - -``` -[....................*....................] 460 ns (vs 450 ns) ( 21 GBits/s) simd_bits_randomize_10K -[...................*|....................] 24 ns (vs 20 ns) (400 GBits/s) simd_bits_xor_10K -[....................|>>>>*...............] 3.6 ns (vs 4.0 ns) (270 GBits/s) simd_bits_not_zero_100K -[....................*....................] 5.8 ms (vs 6.0 ms) ( 17 GBits/s) simd_bit_table_inplace_square_transpose_diam10K -[...............*<<<<|....................] 8.1 ms (vs 5.0 ms) ( 12 GOpQubits/s) FrameSimulator_depolarize1_100Kqubits_1Ksamples_per1000 -[....................*....................] 5.3 ms (vs 5.0 ms) ( 18 GOpQubits/s) FrameSimulator_depolarize2_100Kqubits_1Ksamples_per1000 -``` - -The bars on the left show how fast each task is running compared to baseline expectations (on my dev machine). -Each tick away from the center `|` is 1 decibel slower or faster (i.e. each `<` or `>` represents a factor of `1.26`). - -Basically, if you see `[......*<<<<<<<<<<<<<|....................]` then something is *seriously* wrong, because the -code is running 25x slower than expected. - -The benchmark binary supports a `--only=BENCHMARK_NAME` filter flag. 
-Multiple filters can be specified by separating them with commas `--only=A,B`. -Ending a filter with a `*` turns it into a prefix filter `--only=sim_*`. diff --git a/dev/README.md b/dev/README.md new file mode 100644 index 000000000..90b5fb60a --- /dev/null +++ b/dev/README.md @@ -0,0 +1,163 @@ +# Stim Developer Documentation + +This is documentation for programmers working with stim, e.g. how to build it. +These notes generally assume you are on a Linux system. + +## Build stim command line tool + +### CMake Build + +```bash +cmake . +make stim +# ./out/stim +``` + +To control the vectorization (e.g. this is done for testing), +use `cmake . -DSIMD_WIDTH=256` (implying `-mavx2`) +or `cmake . -DSIMD_WIDTH=128` (implying `-msse2`) +or `cmake . -DSIMD_WIDTH=64` (implying no machine architecture flag). +If `SIMD_WIDTH` is not specified, `-march=native` is used. + +### Bazel Build + +```bash +bazel build stim +# bazel run stim +``` + +### Manual Build + +```bash +find src | grep "\\.cc" | grep -v "\\.\(test\|perf\|pybind\)\\.cc" | xargs g++ -pthread -std=c++11 -O3 -march=native +# ./a.out +``` + +# Profile stim command line tool + +```bash +find src | grep "\\.cc" | grep -v "\\.\(test\|perf\|pybind\)\\.cc" | xargs g++ -pthread -std=c++11 -O3 -march=native -g -fno-omit-frame-pointer +sudo perf record -g ./a.out # [ADD STIM FLAGS FOR THE CASE YOU WANT TO PROFILE] +sudo perf report +``` + +# Run benchmarks + +```bash +cmake . +make stim_benchmark +./out/stim_benchmark +``` + +This will output results like: + +``` +[....................*....................] 460 ns (vs 450 ns) ( 21 GBits/s) simd_bits_randomize_10K +[...................*|....................] 24 ns (vs 20 ns) (400 GBits/s) simd_bits_xor_10K +[....................|>>>>*...............] 3.6 ns (vs 4.0 ns) (270 GBits/s) simd_bits_not_zero_100K +[....................*....................] 
5.8 ms (vs 6.0 ms) ( 17 GBits/s) simd_bit_table_inplace_square_transpose_diam10K +[...............*<<<<|....................] 8.1 ms (vs 5.0 ms) ( 12 GOpQubits/s) FrameSimulator_depolarize1_100Kqubits_1Ksamples_per1000 +[....................*....................] 5.3 ms (vs 5.0 ms) ( 18 GOpQubits/s) FrameSimulator_depolarize2_100Kqubits_1Ksamples_per1000 +``` + +The bars on the left show how fast each task is running compared to baseline expectations (on my dev machine). +Each tick away from the center `|` is 1 decibel slower or faster (i.e. each `<` or `>` represents a factor of `1.26`). + +Basically, if you see `[......*<<<<<<<<<<<<<|....................]` then something is *seriously* wrong, because the +code is running 25x slower than expected. + +The benchmark binary supports a `--only=BENCHMARK_NAME` filter flag. +Multiple filters can be specified by separating them with commas `--only=A,B`. +Ending a filter with a `*` turns it into a prefix filter `--only=sim_*`. + +# Build stim python package + +Ensure python environment dependencies are present: + +```bash +pip install pybind11 +``` + +Create a source distribution: + +```bash +python setup.py sdist +``` + +Output is in the `dist` directory, and can be uploaded using `twine`. + +```bash +twine upload --username="${PROD_TWINE_USERNAME}" --password="${PROD_TWINE_PASSWORD}" dist/[CREATED_FILE_GOES_HERE] +``` + +# Build stimcirq python package + +Create a source distribution: + +```bash +cd glue/cirq +python setup.py sdist +cd ../.. +``` + +Output is in the `glue/cirq/dist` directory, and can be uploaded using `twine`. + +```bash +twine upload --username="${PROD_TWINE_USERNAME}" --password="${PROD_TWINE_PASSWORD}" glue/cirq/dist/[CREATED_FILE_GOES_HERE] +``` + +# Testing + +### Run C++ tests using CMAKE + +Unit testing with CMAKE requires GTest to be installed on your system and discoverable by CMake. 
+Follow the ["Standalone CMake Project" from the GTest README](https://github.com/google/googletest/tree/master/googletest). + +Run tests with address and memory sanitization, but without optimizations: + +```bash +cmake . +make stim_test +./out/stim_test +``` + +To force AVX vectorization, SSE vectorization, or no vectorization +pass `-DSIMD_WIDTH=256` or `-DSIMD_WIDTH=128` or `-DSIMD_WIDTH=64` to the `cmake` command. + +Run tests with optimizations without sanitization: + +```bash +cmake . +make stim_test_o3 +./out/stim_test_o3 +``` + +### Run C++ tests using Bazel + +Run tests with whatever settings Bazel feels like using: + +```bash +bazel test :stim_test +``` + +### Run stim python package tests + +In a clean virtual environment: + +```bash +pip install pytest +pip install -e . +python -m pytest src +python -c "import stim; import doctest; doctest.testmod(stim)" +``` + +### Run stimcirq python package tests + +In a clean virtual environment: + +```bash +pip install pytest +pip install -e glue/cirq +python -m pytest glue/cirq +python -c "import stimcirq; import doctest; doctest.testmod(stimcirq)" +``` diff --git a/glue/cirq/setup.py b/glue/cirq/setup.py index bd79040fe..4b410f79a 100644 --- a/glue/cirq/setup.py +++ b/glue/cirq/setup.py @@ -30,6 +30,6 @@ long_description_content_type='text/markdown', python_requires='>=3.6.0', data_files=['README.md'], - install_requires=['stim', 'cirq~=0.10.0'], + install_requires=['stim', 'cirq'], tests_require=['pytest', 'python3-distutils'], ) diff --git a/glue/cirq/stimcirq/_stim_sampler_test.py b/glue/cirq/stimcirq/_stim_sampler_test.py index c80f30952..8eebff25b 100644 --- a/glue/cirq/stimcirq/_stim_sampler_test.py +++ b/glue/cirq/stimcirq/_stim_sampler_test.py @@ -113,7 +113,6 @@ def test_more_unitary_gate_conversions(): assert ( str(c).strip() == """ -# Circuit [num_qubits=2, num_measurements=2] H 0 CX 0 1 M 0 1 @@ -314,7 +313,6 @@ def qubits(self) -> Tuple['cirq.Qid', ...]: stim_circuit = 
stimcirq.cirq_circuit_to_stim_circuit(cirq_circuit) assert str(stim_circuit).strip() == """ -# Circuit [num_qubits=3, num_measurements=5] M 0 1 2 DETECTOR rec[-2] M !1 1 diff --git a/glue/python/README.md b/glue/python/README.md index 1f3b9a1f1..9f8b31fb0 100644 --- a/glue/python/README.md +++ b/glue/python/README.md @@ -1,8 +1,8 @@ # Stim -Stim is a fast simulator for non-adaptive quantum stabilizer circuits. +Stim is a fast simulator for quantum stabilizer circuits. -API reference: https://github.com/quantumlib/stim/wiki +API references are available on the stim github wiki: https://github.com/quantumlib/stim/wiki Stim can be installed into a python 3 environment using pip: diff --git a/src/circuit/circuit.cc b/src/circuit/circuit.cc index 2e803400e..6784f89ef 100644 --- a/src/circuit/circuit.cc +++ b/src/circuit/circuit.cc @@ -28,9 +28,23 @@ enum READ_CONDITION { READ_UNTIL_END_OF_FILE, }; +/// Concatenates the second pointer range's data into the first. +/// Typically, the two ranges are contiguous and so this only requires advancing the end of the destination region. +/// In cases where that doesn't occur, space is created in the given monotonic buffer to store the result and both +/// the start and end of the destination range move. 
+void fuse_data(PointerRange &dst, PointerRange src, MonotonicBuffer &buf) { + if (dst.ptr_end != src.ptr_start) { + buf.ensure_available(src.size() + dst.size()); + dst = buf.take_copy(dst); + src = buf.take_copy(src); + } + assert(dst.ptr_end == src.ptr_start); + dst.ptr_end = src.ptr_end; +} + DetectorsAndObservables::DetectorsAndObservables(const Circuit &circuit) { size_t tick = 0; - auto resolve_into = [&](const Operation &op, std::vector &out) { + auto resolve_into = [&](const Operation &op, const std::function &func) { for (auto qb : op.target_data.targets) { auto dt = qb ^ TARGET_RECORD_BIT; if (!dt) { @@ -39,63 +53,62 @@ DetectorsAndObservables::DetectorsAndObservables(const Circuit &circuit) { if (dt > tick) { throw std::out_of_range("Referred to a measurement result before the beginning of time."); } - out.push_back(tick - dt); + func(tick - dt); } }; - for (const auto &p : circuit.operations) { + circuit.for_each_operation([&](const Operation &p) { if (p.gate->flags & GATE_PRODUCES_RESULTS) { tick += p.target_data.targets.size(); } else if (p.gate->id == gate_name_to_id("DETECTOR")) { - size_t n = jagged_data.vec.size(); - resolve_into(p, jagged_data.vec); - detectors.push_back(jagged_data.tail_view(n)); + resolve_into(p, [&](uint32_t k) { + jagged_detector_data.append_tail(k); + }); + detectors.push_back(jagged_detector_data.commit_tail()); } else if (p.gate->id == gate_name_to_id("OBSERVABLE_INCLUDE")) { size_t obs = (size_t)p.target_data.arg; if (obs != p.target_data.arg) { throw std::out_of_range("Observable index must be an integer."); } while (observables.size() <= obs) { - observables.push_back({}); + observables.emplace_back(); } - resolve_into(p, observables[obs]); + resolve_into(p, [&](uint32_t k) { + observables[obs].push_back(k); + }); } - } + }); } -Circuit::Circuit() : jagged_target_data(), operations(), num_qubits(0), num_measurements(0) { +Circuit::Circuit() : jag_targets(), operations(), blocks() { } Circuit::Circuit(const Circuit 
&circuit) - : jagged_target_data(), - operations(circuit.operations), - num_qubits(circuit.num_qubits), - num_measurements(circuit.num_measurements) { - jagged_target_data.vec = circuit.jagged_target_data.vec; + : jag_targets(circuit.jag_targets.total_allocated()), operations(circuit.operations), blocks(circuit.blocks) { + // Keep local copy of operation data. for (auto &op : operations) { - op.target_data.targets.vec_ptr = &jagged_target_data.vec; + op.target_data.targets = jag_targets.take_copy(op.target_data.targets); } } Circuit::Circuit(Circuit &&circuit) noexcept - : jagged_target_data(), + : jag_targets(circuit.jag_targets.total_allocated()), operations(std::move(circuit.operations)), - num_qubits(circuit.num_qubits), - num_measurements(circuit.num_measurements) { - jagged_target_data.vec = std::move(circuit.jagged_target_data.vec); + blocks(circuit.blocks) { + // Keep local copy of operation data. for (auto &op : operations) { - op.target_data.targets.vec_ptr = &jagged_target_data.vec; + op.target_data.targets = jag_targets.take_copy(op.target_data.targets); } } Circuit &Circuit::operator=(const Circuit &circuit) { if (&circuit != this) { - num_qubits = circuit.num_qubits; - num_measurements = circuit.num_measurements; - operations = circuit.operations; - jagged_target_data.vec = circuit.jagged_target_data.vec; + blocks = circuit.blocks; + + // Keep local copy of operation data. 
+ jag_targets = MonotonicBuffer(circuit.jag_targets.total_allocated()); for (auto &op : operations) { - op.target_data.targets.vec_ptr = &jagged_target_data.vec; + op.target_data.targets = jag_targets.take_copy(op.target_data.targets); } } return *this; @@ -103,12 +116,13 @@ Circuit &Circuit::operator=(const Circuit &circuit) { Circuit &Circuit::operator=(Circuit &&circuit) noexcept { if (&circuit != this) { - num_qubits = circuit.num_qubits; - num_measurements = circuit.num_measurements; operations = std::move(circuit.operations); - jagged_target_data.vec = std::move(circuit.jagged_target_data.vec); + blocks = std::move(circuit.blocks); + + // Keep local copy of operation data. + jag_targets = MonotonicBuffer(circuit.jag_targets.total_allocated()); for (auto &op : operations) { - op.target_data.targets.vec_ptr = &jagged_target_data.vec; + op.target_data.targets = jag_targets.take_copy(op.target_data.targets); } } return *this; @@ -141,15 +155,13 @@ bool OperationData::operator!=(const OperationData &other) const { } bool Circuit::operator==(const Circuit &other) const { - return num_qubits == other.num_qubits && num_measurements == other.num_measurements && - operations == other.operations; + return operations == other.operations && blocks == other.blocks; } bool Circuit::operator!=(const Circuit &other) const { return !(*this == other); } bool Circuit::approx_equals(const Circuit &other, double atol) const { - if (num_qubits != other.num_qubits || num_measurements != other.num_measurements || - operations.size() != other.operations.size()) { + if (operations.size() != other.operations.size() || blocks.size() != other.blocks.size()) { return false; } for (size_t k = 0; k < operations.size(); k++) { @@ -157,6 +169,11 @@ bool Circuit::approx_equals(const Circuit &other, double atol) const { return false; } } + for (size_t k = 0; k < blocks.size(); k++) { + if (!blocks[k].approx_equals(other.blocks[k], atol)) { + return false; + } + } return true; } @@ -259,8 +276,7 
@@ bool read_until_next_line_arg(int &c, SOURCE read_char) { template inline void read_raw_qubit_target_into(int &c, SOURCE read_char, Circuit &circuit) { uint32_t q = read_uint24_t(c, read_char); - circuit.jagged_target_data.vec.push_back(q); - circuit.num_qubits = std::max(circuit.num_qubits, (size_t)q + 1); + circuit.jag_targets.append_tail(q); } template @@ -274,7 +290,7 @@ inline void read_record_target_into(int &c, SOURCE read_char, Circuit &circuit) throw std::out_of_range("Expected a record argument like 'rec[-1]'."); } c = read_char(); - circuit.jagged_target_data.vec.push_back(lookback | TARGET_RECORD_BIT); + circuit.jag_targets.append_tail(lookback | TARGET_RECORD_BIT); } template @@ -313,22 +329,19 @@ inline void read_pauli_targets_into(int &c, SOURCE read_char, Circuit &circuit) throw std::out_of_range("Unexpected space after Pauli before target qubit index."); } size_t q = read_uint24_t(c, read_char); - circuit.jagged_target_data.vec.push_back(q | m); - circuit.num_qubits = std::max(circuit.num_qubits, (size_t)q + 1); + circuit.jag_targets.append_tail(q | m); } } template inline void read_result_targets_into(int &c, SOURCE read_char, const Gate &gate, Circuit &circuit) { while (read_until_next_line_arg(c, read_char)) { - uint32_t flipped = c == '!' ? uint32_t{1} << 31 : 0; - if (flipped) { + uint32_t flipped_flag = c == '!' ? 
uint32_t{1} << 31 : 0; + if (flipped_flag) { c = read_char(); } uint32_t q = read_uint24_t(c, read_char); - circuit.num_qubits = std::max(circuit.num_qubits, (size_t)q + 1); - circuit.jagged_target_data.vec.push_back(q ^ flipped); - circuit.num_measurements++; + circuit.jag_targets.append_tail(q ^ flipped_flag); } } @@ -359,14 +372,13 @@ void read_past_dead_space_between_commands(int &c, SOURCE read_char) { template void circuit_read_single_operation(Circuit &circuit, char lead_char, SOURCE read_char) { - int c = lead_char; + int c = (int)lead_char; const auto &gate = read_gate_name(c, read_char); double val = 0; if (gate.flags & GATE_TAKES_PARENS_ARGUMENT) { read_past_within_line_whitespace(c, read_char); val = read_parens_argument(c, gate, read_char); } - size_t offset = circuit.jagged_target_data.vec.size(); if (!(gate.flags & (GATE_IS_BLOCK | GATE_ONLY_TARGETS_MEASUREMENT_RECORD | GATE_PRODUCES_RESULTS | GATE_TARGETS_PAULI_STRING | GATE_CAN_TARGET_MEASUREMENT_RECORD))) { read_raw_qubit_targets_into(c, read_char, circuit); @@ -380,17 +392,17 @@ void circuit_read_single_operation(Circuit &circuit, char lead_char, SOURCE read read_pauli_targets_into(c, read_char, circuit); } else { while (read_until_next_line_arg(c, read_char)) { - circuit.jagged_target_data.vec.push_back(read_uint24_t(c, read_char)); + circuit.jag_targets.append_tail(read_uint24_t(c, read_char)); } } - if (c != '{' && (gate.flags & GATE_IS_BLOCK && c != '{')) { + if (c != '{' && (gate.flags & GATE_IS_BLOCK)) { throw std::out_of_range("Missing '{' at start of " + std::string(gate.name) + " block."); } if (c == '{' && !(gate.flags & GATE_IS_BLOCK)) { throw std::out_of_range("Unexpected '{' after non-block command " + std::string(gate.name) + "."); } - auto view = circuit.jagged_target_data.tail_view(offset); + auto view = circuit.jag_targets.commit_tail(); if (gate.flags & GATE_TARGETS_PAIRS) { if (view.size() & 1) { throw std::out_of_range( @@ -425,41 +437,39 @@ void 
circuit_read_operations(Circuit &circuit, SOURCE read_char, READ_CONDITION } return; } - size_t s = ops.size(); circuit_read_single_operation(circuit, c, read_char); + Operation &new_op = ops.back(); - if (ops[s].gate->id == gate_name_to_id("REPEAT")) { - if (ops[s].target_data.targets.size() != 1) { + if (new_op.gate->id == gate_name_to_id("REPEAT")) { + if (new_op.target_data.targets.size() != 1) { throw std::out_of_range("Invalid instruction. Expected one repetition arg like `REPEAT 100 {`."); } - size_t rep_count = circuit.jagged_target_data.vec.back(); - circuit.jagged_target_data.vec.pop_back(); - ops.pop_back(); + uint32_t rep_count = new_op.target_data.targets[0]; + uint32_t block_id = circuit.blocks.size(); if (rep_count == 0) { throw std::out_of_range("Repeating 0 times is not supported."); } - size_t ops_start = ops.size(); - size_t num_measure_start = circuit.num_measurements; - circuit.fusion_barrier(); - circuit_read_operations(circuit, read_char, READ_UNTIL_END_OF_BLOCK); - size_t ops_end = ops.size(); - circuit.num_measurements += (circuit.num_measurements - num_measure_start) * (rep_count - 1); - while (rep_count > 1) { - ops.insert(ops.end(), ops.data() + ops_start, ops.data() + ops_end); - rep_count--; - } - circuit.fusion_barrier(); + + // Read block. + circuit.blocks.emplace_back(); + circuit_read_operations(circuit.blocks.back(), read_char, READ_UNTIL_END_OF_BLOCK); + + // Rewrite target data to reference the parsed block. + circuit.jag_targets.ensure_available(2); + circuit.jag_targets.append_tail(block_id); + circuit.jag_targets.append_tail(rep_count); + new_op.target_data.targets = circuit.jag_targets.commit_tail(); } - while (s > circuit.min_safe_fusion_index && ops[s - 1].can_fuse(ops[s])) { - ops[s - 1].target_data.targets.length += ops[s].target_data.targets.length; + + // Fuse operations. 
+ while (ops.size() > 1 && ops[ops.size() - 2].can_fuse(new_op)) { + fuse_data(ops[ops.size() - 2].target_data.targets, new_op.target_data.targets, circuit.jag_targets); ops.pop_back(); - s--; } } while (read_condition != READ_AS_LITTLE_AS_POSSIBLE); } -bool Circuit::append_from_text(const char *text) { - size_t before = operations.size(); +void Circuit::append_from_text(const char *text) { size_t k = 0; circuit_read_operations( *this, @@ -467,70 +477,28 @@ bool Circuit::append_from_text(const char *text) { return text[k] != 0 ? text[k++] : EOF; }, READ_UNTIL_END_OF_FILE); - return operations.size() > before; -} - -void Circuit::update_metadata_for_manually_appended_operation() { - const auto &op = operations.back(); - const auto &gate = *op.gate; - const auto &vec = op.target_data.targets; - if (gate.flags & GATE_PRODUCES_RESULTS) { - num_measurements += vec.size(); - } - for (auto q : vec) { - if (!(q & TARGET_RECORD_BIT)) { - num_qubits = std::max(num_qubits, (size_t)((q & TARGET_VALUE_MASK) + 1)); - } - } -} - -void Circuit::append_circuit(const Circuit &circuit, size_t repetitions) { - if (!repetitions) { - return; - } - auto original_size = operations.size(); - - if (&circuit == this) { - num_measurements *= repetitions + 1; - do { - operations.insert(operations.end(), operations.begin(), operations.begin() + original_size); - } while (--repetitions); - return; - } - - fusion_barrier(); - for (const auto &op : circuit.operations) { - append_operation(op); - } - auto single_rep_end = operations.end(); - while (--repetitions) { - operations.insert(operations.end(), operations.begin() + original_size, single_rep_end); - } - fusion_barrier(); } void Circuit::append_operation(const Operation &operation) { operations.push_back( - {operation.gate, {operation.target_data.arg, jagged_target_data.inserted(operation.target_data.targets)}}); - update_metadata_for_manually_appended_operation(); + {operation.gate, {operation.target_data.arg, 
jag_targets.take_copy(operation.target_data.targets)}}); } void Circuit::append_op(const std::string &gate_name, const std::vector &vec, double arg) { const auto &gate = GATE_DATA.at(gate_name); - append_operation(gate, vec.data(), vec.size(), arg); + append_operation(gate, vec, arg); } -void Circuit::append_operation( - const Gate &gate, const uint32_t *targets_start, size_t num_targets, double arg) { +void Circuit::append_operation(const Gate &gate, ConstPointerRange targets, double arg) { if (gate.flags & GATE_TARGETS_PAIRS) { - if (num_targets & 1) { + if (targets.size() & 1) { throw std::out_of_range( "Two qubit gate " + std::string(gate.name) + " requires have an even number of targets."); } - for (size_t k = 0; k < num_targets; k += 2) { - if (targets_start[k] == targets_start[k + 1]) { + for (size_t k = 0; k < targets.size(); k += 2) { + if (targets[k] == targets[k + 1]) { throw std::out_of_range( - "Interacting a target with itself " + std::to_string(targets_start[k] & TARGET_VALUE_MASK) + + "Interacting a target with itself " + std::to_string(targets[k] & TARGET_VALUE_MASK) + " using gate " + std::string(gate.name) + "."); } } @@ -550,8 +518,10 @@ void Circuit::append_operation( if (gate.flags & (GATE_ONLY_TARGETS_MEASUREMENT_RECORD | GATE_CAN_TARGET_MEASUREMENT_RECORD)) { valid_target_mask |= TARGET_RECORD_BIT; } - for (size_t k = 0; k < num_targets; k++) { - auto q = targets_start[k]; + if (gate.flags & GATE_IS_BLOCK) { + throw std::out_of_range("Can't append a block as an operation."); + } + for (uint32_t q : targets) { if (q != (q & valid_target_mask)) { throw std::out_of_range( "Target " + std::to_string(q & TARGET_VALUE_MASK) + " has invalid flags " + @@ -559,32 +529,24 @@ void Circuit::append_operation( } } - if (!(gate.flags & GATE_IS_NOT_FUSABLE) && operations.size() > min_safe_fusion_index && - operations.back().gate->id == gate.id && operations.back().target_data.arg == arg) { - // Don't double count measurements when doing incremental update. 
- if (gate.flags & GATE_PRODUCES_RESULTS) { - num_measurements -= operations.back().target_data.targets.size(); - } + auto added = jag_targets.take_copy(targets); + Operation to_add = {&gate, {arg, added}}; + if (!operations.empty() && operations.back().can_fuse(to_add)) { // Extend targets of last gate. - jagged_target_data.vec.insert(jagged_target_data.vec.end(), targets_start, targets_start + num_targets); - operations.back().target_data.targets.length += num_targets; + fuse_data(operations.back().target_data.targets, to_add.target_data.targets, jag_targets); } else { // Add a fresh new operation with its own target data. - operations.push_back({&gate, {arg, jagged_target_data.inserted(targets_start, num_targets)}}); + operations.push_back(to_add); } - // Update num_measurements and num_qubits appropriately. - update_metadata_for_manually_appended_operation(); } -bool Circuit::append_from_file(FILE *file, bool stop_asap) { - size_t before = operations.size(); +void Circuit::append_from_file(FILE *file, bool stop_asap) { circuit_read_operations( *this, [&]() { return getc(file); }, stop_asap ? READ_AS_LITTLE_AS_POSSIBLE : READ_UNTIL_END_OF_FILE); - return operations.size() > before; } std::ostream &operator<<(std::ostream &out, const Operation &op) { @@ -617,20 +579,38 @@ std::ostream &operator<<(std::ostream &out, const Operation &op) { return out; } -std::ostream &operator<<(std::ostream &out, const Circuit &c) { - out << "# Circuit [num_qubits=" << c.num_qubits << ", num_measurements=" << c.num_measurements << "]"; +void print_circuit(std::ostream &out, const Circuit &c, const std::string &indentation) { + bool first = true; for (const auto &op : c.operations) { - out << "\n" << op; + if (first) { + first = false; + } else { + out << "\n"; + } + + // Recurse on repeat blocks. 
+ if (op.gate && op.gate->id == gate_name_to_id("REPEAT")) { + if (op.target_data.targets.size() == 2 && op.target_data.targets[0] < c.blocks.size()) { + out << indentation << "REPEAT " << op.target_data.targets[1] << " {\n"; + print_circuit(out, c.blocks[op.target_data.targets[0]], indentation + " "); + out << "\n" << indentation << "}"; + continue; + } + } + + out << indentation << op; } +} + +std::ostream &operator<<(std::ostream &out, const Circuit &c) { + print_circuit(out, c, ""); return out; } void Circuit::clear() { - num_qubits = 0; - num_measurements = 0; - jagged_target_data.vec.clear(); + jag_targets.clear(); operations.clear(); - min_safe_fusion_index = 0; + blocks.clear(); } Circuit Circuit::operator+(const Circuit &other) const { @@ -639,24 +619,55 @@ Circuit Circuit::operator+(const Circuit &other) const { return result; } Circuit Circuit::operator*(size_t repetitions) const { - Circuit result = *this; - result *= repetitions; + if (repetitions == 0) { + return Circuit(); + } + if (repetitions == 1) { + return *this; + } + // If the entire circuit is a repeat block, just adjust its repeat count. + if (operations.size() == 1 && operations[0].gate->id == gate_name_to_id("REPEAT")) { + uint64_t old_reps = operations[0].target_data.targets[1]; + uint64_t new_reps = old_reps * repetitions; + // Don't create an overflowed repeat count. 
+ if (new_reps == (new_reps & TARGET_VALUE_MASK)) { + Circuit copy = *this; + copy.operations[0].target_data.targets[1] *= repetitions; + return copy; + } + } + Circuit result; + result.blocks.push_back(*this); + result.jag_targets.append_tail(0); + result.jag_targets.append_tail(repetitions); + result.operations.push_back({&GATE_DATA.at("REPEAT"), {0, result.jag_targets.commit_tail()}}); return result; } -void Circuit::fusion_barrier() { - min_safe_fusion_index = operations.size(); -} - Circuit &Circuit::operator+=(const Circuit &other) { - append_circuit(other, 1); + if (&other == this) { + operations.insert(operations.end(), operations.begin(), operations.end()); + return *this; + } + + size_t block_offset = blocks.size(); + blocks.insert(blocks.end(), other.blocks.begin(), other.blocks.end()); + for (const auto &op : other.operations) { + assert(op.gate != nullptr); + append_operation(op); + if (op.gate->id == gate_name_to_id("REPEAT")) { + assert(op.target_data.targets.size() == 2); + operations.back().target_data.targets[0] += block_offset; + } + } + return *this; } Circuit &Circuit::operator*=(size_t repetitions) { if (repetitions == 0) { clear(); } else { - append_circuit(*this, repetitions - 1); + *this = *this * repetitions; } return *this; } @@ -686,37 +697,117 @@ Circuit Circuit::from_text(const char *text) { } DetectorsAndObservables::DetectorsAndObservables(DetectorsAndObservables &&other) noexcept - : jagged_data(), detectors(std::move(other.detectors)), observables(std::move(other.observables)) { - jagged_data.vec = std::move(other.jagged_data.vec); - for (auto &e : detectors) { - e.vec_ptr = &jagged_data.vec; + : jagged_detector_data(other.jagged_detector_data.total_allocated()), + detectors(std::move(other.detectors)), + observables(std::move(other.observables)) { + // Keep a local copy of the detector data. 
+ for (PointerRange &e : detectors) { + e = jagged_detector_data.take_copy(e); } } DetectorsAndObservables &DetectorsAndObservables::operator=(DetectorsAndObservables &&other) noexcept { - jagged_data.vec = std::move(other.jagged_data.vec); observables = std::move(other.observables); detectors = std::move(other.detectors); - for (auto &e : detectors) { - e.vec_ptr = &jagged_data.vec; + + // Keep a local copy of the detector data. + jagged_detector_data = MonotonicBuffer(other.jagged_detector_data.total_allocated()); + for (PointerRange &e : detectors) { + e = jagged_detector_data.take_copy(e); } + return *this; } DetectorsAndObservables::DetectorsAndObservables(const DetectorsAndObservables &other) - : jagged_data(), detectors(other.detectors), observables(other.observables) { - jagged_data.vec = other.jagged_data.vec; - for (auto &e : detectors) { - e.vec_ptr = &jagged_data.vec; + : jagged_detector_data(other.jagged_detector_data.total_allocated()), + detectors(other.detectors), + observables(other.observables) { + // Keep a local copy of the detector data. + for (PointerRange &e : detectors) { + e = jagged_detector_data.take_copy(e); } } DetectorsAndObservables &DetectorsAndObservables::operator=(const DetectorsAndObservables &other) { - jagged_data.vec = other.jagged_data.vec; + if (this == &other) { + return *this; + } + observables = other.observables; detectors = other.detectors; - for (auto &e : detectors) { - e.vec_ptr = &jagged_data.vec; + + // Keep a local copy of the detector data. + jagged_detector_data = MonotonicBuffer(other.jagged_detector_data.total_allocated()); + for (PointerRange &e : detectors) { + e = jagged_detector_data.take_copy(e); } + return *this; } + +size_t Circuit::count_qubits() const { + size_t n = 0; + for (const auto &block : blocks) { + n = std::max(n, block.count_qubits()); + } + for (const auto &op : operations) { + if (op.gate->flags & GATE_IS_BLOCK) { + // Handled in block case. 
+ continue; + } + for (uint32_t t : op.target_data.targets) { + if (!(t & TARGET_RECORD_BIT)) { + n = std::max(n, (t & TARGET_VALUE_MASK) + size_t{1}); + } + } + } + return n; +} + +size_t Circuit::max_lookback() const { + size_t n = 0; + for (const auto &block : blocks) { + n = std::max(n, block.max_lookback()); + } + for (const auto &op : operations) { + if (op.gate->flags & (GATE_CAN_TARGET_MEASUREMENT_RECORD | GATE_ONLY_TARGETS_MEASUREMENT_RECORD)) { + for (uint32_t t : op.target_data.targets) { + if (t & TARGET_RECORD_BIT) { + n = std::max(n, size_t{t & TARGET_VALUE_MASK}); + } + } + } + } + return n; +} + +uint64_t Circuit::count_measurements() const { + uint64_t n = 0; + for (const auto &op : operations) { + assert(op.gate != nullptr); + if (op.gate->id == gate_name_to_id("REPEAT")) { + assert(op.target_data.targets.size() == 2); + assert(op.target_data.targets[0] < blocks.size()); + n += blocks[op.target_data.targets[0]].count_measurements() * op.target_data.targets[1]; + } else if (op.gate->flags & GATE_PRODUCES_RESULTS) { + n += op.target_data.targets.size(); + } + } + return n; +} + +uint64_t Circuit::count_detectors_and_observables() const { + uint64_t n = 0; + for (const auto &op : operations) { + assert(op.gate != nullptr); + if (op.gate->id == gate_name_to_id("REPEAT")) { + assert(op.target_data.targets.size() == 2); + assert(op.target_data.targets[0] < blocks.size()); + n += blocks[op.target_data.targets[0]].count_detectors_and_observables() * op.target_data.targets[1]; + } else if (op.gate->id == gate_name_to_id("DETECTOR") || op.gate->id == gate_name_to_id("OBSERVABLE_INCLUDE")) { + n++; + } + } + return n; +} diff --git a/src/circuit/circuit.h b/src/circuit/circuit.h index ba0b56dd1..af3c389f3 100644 --- a/src/circuit/circuit.h +++ b/src/circuit/circuit.h @@ -24,7 +24,7 @@ #include #include -#include "../simd/vector_view.h" +#include "../simd/monotonic_buffer.h" #include "gate_data.h" #define TARGET_VALUE_MASK ((uint32_t{1} << 24) - uint32_t{1}) 
@@ -50,7 +50,7 @@ enum SampleFormat { /// Transposed binary format. /// /// For each measurement: - /// For each group of 8 shots (padded with 0s if needed): + /// For each group of 64 shots (padded with 0s if needed): /// Output bit packed bytes (least significant bit of first byte has first shot) SAMPLE_FORMAT_PTB64, /// Human readable compressed format. @@ -83,7 +83,7 @@ struct OperationData { /// The bottom 24 bits of each item always refer to a qubit index. /// The top 8 bits are used for additional data such as /// Pauli basis, record lookback, and measurement inversion. - VectorView targets; + PointerRange targets; bool operator==(const OperationData &other) const; bool operator!=(const OperationData &other) const; @@ -111,17 +111,16 @@ struct Operation { /// A description of a quantum computation. struct Circuit { - /// Variable-sized operation data is stored as views into this single contiguous array. - /// Appending operations will append their target data into this vector, and the operation will reference it. - /// This decreases memory fragmentation and the number of allocations during parsing. - JaggedDataArena jagged_target_data; + /// Backing data store for variable-sized target data referenced by operations. + MonotonicBuffer jag_targets; /// Operations in the circuit, from earliest to latest. std::vector operations; - /// One more than the maximum qubit index seen in the circuit (so far). - size_t num_qubits; - /// The total number of measurement results the circuit (so far) will produce. - size_t num_measurements; - size_t min_safe_fusion_index = 0; + std::vector blocks; + + size_t count_qubits() const; + uint64_t count_measurements() const; + uint64_t count_detectors_and_observables() const; + size_t max_lookback() const; /// Constructs an empty circuit. Circuit(); @@ -152,35 +151,23 @@ struct Circuit { /// interactive (repl) mode, where measurements should produce results immediately instead of only after the /// circuit is entirely specified. 
*This has significantly worse performance. It prevents measurement /// batching.* - /// - /// Returns: - /// true: Operations were read from the file. - /// false: The file has ended, and no operations were read. - bool append_from_file(FILE *file, bool stop_asap); + void append_from_file(FILE *file, bool stop_asap); /// Grows the circuit using operations from a string. /// /// Note: operations are automatically fused. - /// - /// Returns: - /// true: Operations were read from the string. - /// false: The string contained no operations. - bool append_from_text(const char *text); + void append_from_text(const char *text); Circuit operator+(const Circuit &other) const; Circuit operator*(size_t repetitions) const; Circuit &operator+=(const Circuit &other); Circuit &operator*=(size_t repetitions); - /// Appends a circuit to the end of this one. - void append_circuit(const Circuit &circuit, size_t repetitions); /// Safely adds an operation at the end of the circuit, copying its data into the circuit's jagged data as needed. void append_operation(const Operation &operation); /// Safely adds an operation at the end of the circuit, copying its data into the circuit's jagged data as needed. - void append_op( - const std::string &gate_name, const std::vector &vec, double arg = 0); + void append_op(const std::string &gate_name, const std::vector &vec, double arg = 0); /// Safely adds an operation at the end of the circuit, copying its data into the circuit's jagged data as needed. - void append_operation( - const Gate &gate, const uint32_t *targets_start, size_t num_targets, double arg); + void append_operation(const Gate &gate, ConstPointerRange targets, double arg); /// Resets the circuit back to an empty circuit. void clear(); @@ -194,16 +181,48 @@ struct Circuit { /// Approximate equality. bool approx_equals(const Circuit &other, double atol) const; - /// Updates metadata (e.g. num_qubits) to account for an operation appended via non-standard means. 
- void update_metadata_for_manually_appended_operation(); - - void fusion_barrier(); + template + void for_each_operation(const CALLBACK &callback) const { + for (const auto &op : operations) { + assert(op.gate != nullptr); + if (op.gate->id == gate_name_to_id("REPEAT")) { + assert(op.target_data.targets.size() == 2); + assert(op.target_data.targets[0] < blocks.size()); + size_t repeats = op.target_data.targets[1]; + const auto &block = blocks[op.target_data.targets[0]]; + for (size_t k = 0; k < repeats; k++) { + block.for_each_operation(callback); + } + } else { + callback(op); + } + } + } + + template + void for_each_operation_reverse(const CALLBACK &callback) const { + for (size_t p = operations.size(); p-- > 0;) { + const auto &op = operations[p]; + assert(op.gate != nullptr); + if (op.gate->id == gate_name_to_id("REPEAT")) { + assert(op.target_data.targets.size() == 2); + assert(op.target_data.targets[0] < blocks.size()); + size_t repeats = op.target_data.targets[1]; + const auto &block = blocks[op.target_data.targets[0]]; + for (size_t k = 0; k < repeats; k++) { + block.for_each_operation_reverse(callback); + } + } else { + callback(op); + } + } + } }; /// Lists sets of measurements that have deterministic parity under noiseless execution from a circuit. 
struct DetectorsAndObservables { - JaggedDataArena jagged_data; - std::vector> detectors; + MonotonicBuffer jagged_detector_data; + std::vector> detectors; std::vector> observables; DetectorsAndObservables(const Circuit &circuit); diff --git a/src/circuit/circuit.perf.cc b/src/circuit/circuit.perf.cc index cb832e948..7d71a7cd6 100644 --- a/src/circuit/circuit.perf.cc +++ b/src/circuit/circuit.perf.cc @@ -25,7 +25,7 @@ CNOT 4 5 6 7 M 1 2 3 4 5 6 7 8 9 10 11 )input"); }).goal_nanos(950); - if (c.num_qubits == 0) { + if (c.count_qubits() == 0) { std::cerr << "impossible"; } } @@ -41,7 +41,7 @@ BENCHMARK(circuit_parse_sparse) { benchmark_go([&]() { c = Circuit::from_text(text.data()); }).goal_micros(150); - if (c.num_qubits == 0) { + if (c.count_qubits() == 0) { std::cerr << "impossible"; } } diff --git a/src/circuit/circuit.pybind.cc b/src/circuit/circuit.pybind.cc index c685b1140..bca3975a2 100644 --- a/src/circuit/circuit.pybind.cc +++ b/src/circuit/circuit.pybind.cc @@ -12,17 +12,22 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "circuit.h" +#include "circuit.pybind.h" #include "../py/base.pybind.h" #include "../py/compiled_detector_sampler.pybind.h" #include "../py/compiled_measurement_sampler.pybind.h" -#include "circuit.pybind.h" + +std::string circuit_repr(const Circuit &self) { + if (self.operations.empty()) { + return "stim.Circuit()"; + } + return "stim.Circuit('''\n" + self.str() + "\n''')"; +} void pybind_circuit(pybind11::module &m) { pybind11::class_( - m, - "Circuit", + m, "Circuit", R"DOC( A mutable stabilizer circuit. @@ -65,8 +70,8 @@ void pybind_circuit(pybind11::module &m) { ... M 1 ... ''') )DOC") - .def_readonly("num_measurements", &Circuit::num_measurements, R"DOC( - The number of measurement bits produced when sampling from the circuit. 
+ .def_property_readonly("num_measurements", &Circuit::count_measurements, R"DOC( + Counts the number of measurement bits produced when sampling from the circuit. Examples: >>> import stim @@ -77,8 +82,8 @@ void pybind_circuit(pybind11::module &m) { >>> c.num_measurements 3 )DOC") - .def_readonly("num_qubits", &Circuit::num_qubits, R"DOC( - The number of qubits used when simulating the circuit. + .def_property_readonly("num_qubits", &Circuit::count_qubits, R"DOC( + Counts the number of qubits used when simulating the circuit. Examples: >>> import stim @@ -142,8 +147,8 @@ void pybind_circuit(pybind11::module &m) { ... Y 1 2 ... ''') >>> c.clear() - >>> print(c) - # Circuit [num_qubits=0, num_measurements=0] + >>> c + stim.Circuit() )DOC") .def("__iadd__", &Circuit::operator+=, pybind11::arg("second"), R"DOC( Appends a circuit into the receiving circuit (mutating it). @@ -159,7 +164,6 @@ void pybind_circuit(pybind11::module &m) { ... ''') >>> c1 += c2 >>> print(c1) - # Circuit [num_qubits=3, num_measurements=3] X 0 Y 1 2 M 0 1 2 @@ -187,7 +191,6 @@ void pybind_circuit(pybind11::module &m) { ... M 0 1 2 ... ''') >>> print(c1 + c2) - # Circuit [num_qubits=3, num_measurements=3] X 0 Y 1 2 M 0 1 2 @@ -203,7 +206,6 @@ void pybind_circuit(pybind11::module &m) { ... ''') >>> c *= 3 >>> print(c) - # Circuit [num_qubits=3, num_measurements=0] X 0 Y 1 2 X 0 @@ -221,7 +223,6 @@ void pybind_circuit(pybind11::module &m) { ... Y 1 2 ... ''') >>> print(c * 3) - # Circuit [num_qubits=3, num_measurements=0] X 0 Y 1 2 X 0 @@ -241,13 +242,10 @@ void pybind_circuit(pybind11::module &m) { ... Y 1 2 ... 
''') >>> print(3 * c) - # Circuit [num_qubits=3, num_measurements=0] - X 0 - Y 1 2 - X 0 - Y 1 2 - X 0 - Y 1 2 + REPEAT 3 { + X 0 + Y 1 2 + } )DOC") .def( "append_operation", &Circuit::append_op, R"DOC( @@ -263,7 +261,6 @@ void pybind_circuit(pybind11::module &m) { >>> c.append_operation("X_ERROR", [0], 0.125) >>> c.append_operation("CORRELATED_ERROR", [stim.target_x(0), stim.target_y(2)], 0.25) >>> print(c) - # Circuit [num_qubits=3, num_measurements=2] X 0 H 0 1 M 0 !1 @@ -296,7 +293,6 @@ void pybind_circuit(pybind11::module &m) { ... CNOT rec[-1] 1 ... ''') >>> print(c) - # Circuit [num_qubits=3, num_measurements=1] H 0 CX 0 2 M 2 @@ -307,7 +303,5 @@ void pybind_circuit(pybind11::module &m) { )DOC", pybind11::arg("stim_program_text")) .def("__str__", &Circuit::str) - .def("__repr__", [](const Circuit &self) { - return "stim.Circuit(\"\"\"\n" + self.str() + "\n\"\"\")"; - }); + .def("__repr__", &circuit_repr); } diff --git a/src/circuit/circuit.pybind.h b/src/circuit/circuit.pybind.h index 64a253562..657e0460e 100644 --- a/src/circuit/circuit.pybind.h +++ b/src/circuit/circuit.pybind.h @@ -17,6 +17,9 @@ #include +#include "circuit.h" + void pybind_circuit(pybind11::module &m); +std::string circuit_repr(const Circuit &self); #endif \ No newline at end of file diff --git a/src/circuit/circuit.test.cc b/src/circuit/circuit.test.cc index a0c7e922a..3f481c1d1 100644 --- a/src/circuit/circuit.test.cc +++ b/src/circuit/circuit.test.cc @@ -27,7 +27,7 @@ OpDat OpDat::flipped(size_t target) { } OpDat::operator OperationData() { - return {0, {&targets, 0, targets.size()}}; + return {0, targets}; } TEST(circuit, from_text) { @@ -114,33 +114,35 @@ TEST(circuit, from_text) { expected.append_op("M", {0, 0 | TARGET_INVERTED_BIT, 1, 1 | TARGET_INVERTED_BIT}); ASSERT_EQ(f("M 0 !0 1 !1"), expected); + // Measurement fusion. 
expected.clear(); expected.append_op("H", {0}); expected.append_op("M", {0, 1, 2}); expected.append_op("SWAP", {0, 1}); expected.append_op("M", {0, 10}); ASSERT_EQ( - f("# Measurement fusion\n" - "H 0\n" - "M 0\n" - "M 1\n" - "M 2\n" - "SWAP 0 1\n" - "M 0\n" - "M 10\n"), + f(R"CIRCUIT( + H 0 + M 0 + M 1 + M 2 + SWAP 0 1 + M 0 + M 10 + )CIRCUIT"), expected); expected.clear(); expected.append_op("X", {0}); - expected.append_op("Y", {1, 2}); - expected.fusion_barrier(); - expected.append_op("Y", {1, 2}); + expected += Circuit::from_text("Y 1 2") * 2; ASSERT_EQ( - f("X 0\n" - "REPEAT 2 {\n" - " Y 1\n" - " Y 2 #####\n" - "} #####"), + f(R"CIRCUIT( + X 0 + REPEAT 2 { + Y 1 + Y 2 #####" + } #####" + )CIRCUIT"), expected); expected.clear(); @@ -150,25 +152,16 @@ TEST(circuit, from_text) { expected.append_op("DETECTOR", {6 | TARGET_RECORD_BIT}); ASSERT_EQ(f("DETECTOR rec[-6]"), expected); - expected.clear(); - expected.append_op("M", {0}); - expected.fusion_barrier(); - expected.append_op("M", {1, 2, 3}); - expected.fusion_barrier(); - expected.append_op("M", {1, 2, 3}); - expected.fusion_barrier(); - expected.append_op("M", {1, 2, 3}); - expected.fusion_barrier(); - expected.append_op("M", {1, 2, 3}); - expected.fusion_barrier(); - expected.append_op("M", {1, 2, 3}); - ASSERT_EQ( + Circuit parsed = f("M 0\n" "REPEAT 5 {\n" " M 1 2\n" " M 3\n" - "} #####"), - expected); + "} #####"); + ASSERT_EQ(parsed.operations.size(), 2); + ASSERT_EQ(parsed.blocks.size(), 1); + ASSERT_EQ(parsed.blocks[0].operations.size(), 1); + ASSERT_EQ(parsed.blocks[0].operations[0].target_data.targets.size(), 3); expected.clear(); expected.append_op( @@ -194,23 +187,20 @@ TEST(circuit, append_circuit) { Circuit expected; expected.append_op("X", {0, 1}); expected.append_op("M", {0, 1, 2, 4}); - expected.fusion_barrier(); expected.append_op("M", {7}); Circuit actual = c1; actual += c2; + ASSERT_EQ(actual.operations.size(), 3); + actual = Circuit::from_text(actual.str().data()); + 
ASSERT_EQ(actual.operations.size(), 2); ASSERT_EQ(actual, expected); actual *= 4; - for (size_t k = 0; k < 3; k++) { - expected.append_op("X", {0, 1}); - expected.append_op("M", {0, 1, 2, 4}); - expected.fusion_barrier(); - expected.append_op("M", {7}); - } - ASSERT_EQ(actual, expected); - ASSERT_EQ(actual.jagged_target_data.vec.size(), 7); - ASSERT_EQ(expected.jagged_target_data.vec.size(), 7 * 4); + ASSERT_EQ(actual.str(), R"CIRCUIT(REPEAT 4 { + X 0 1 + M 0 1 2 4 7 +})CIRCUIT"); } TEST(circuit, append_op_fuse) { @@ -257,8 +247,7 @@ TEST(circuit, str) { 29 | TARGET_PAULI_X_BIT | TARGET_PAULI_Z_BIT, }, 0.25); - ASSERT_EQ(c.str(), R"circuit(# Circuit [num_qubits=30, num_measurements=3] -TICK + ASSERT_EQ(c.str(), R"circuit(TICK CX 2 3 rec[-5] 3 M 1 3 2 DETECTOR rec[-7] @@ -272,6 +261,7 @@ TEST(circuit, append_op_validation) { ASSERT_THROW({ c.append_op("CNOT", {0}); }, std::out_of_range); c.append_op("CNOT", {0, 1}); + ASSERT_THROW({ c.append_op("REPEAT", {100}); }, std::out_of_range); ASSERT_THROW({ c.append_op("X", {0 | TARGET_PAULI_X_BIT}); }, std::out_of_range); ASSERT_THROW({ c.append_op("X", {0 | TARGET_PAULI_Z_BIT}); }, std::out_of_range); ASSERT_THROW({ c.append_op("X", {0 | TARGET_INVERTED_BIT}); }, std::out_of_range); @@ -306,10 +296,315 @@ TEST(circuit, classical_controls) { expected.append_op("CY", {2 | TARGET_RECORD_BIT, 1}); expected.append_op("CZ", {4 | TARGET_RECORD_BIT, 1}); ASSERT_EQ( - Circuit::from_text(R"circuit(# Circuit [num_qubits=5, num_measurements=0] -ZCX 0 1 + Circuit::from_text(R"circuit(ZCX 0 1 ZCX rec[-1] 1 ZCY rec[-2] 1 ZCZ rec[-4] 1)circuit"), expected); } + +TEST(circuit, for_each_operation) { + Circuit c; + c.append_from_text(R"CIRCUIT( + H 0 + M 0 1 + REPEAT 2 { + X 1 + REPEAT 3 { + Y 2 + } + } + )CIRCUIT"); + + Circuit flat; + auto f = [&](const char *gate, const std::vector &targets) { + flat.append_operation({&GATE_DATA.at(gate), {0, flat.jag_targets.take_copy(targets)}}); + }; + f("H", {0}); + f("M", {0, 1}); + f("X", {1}); + 
f("Y", {2}); + f("Y", {2}); + f("Y", {2}); + f("X", {1}); + f("Y", {2}); + f("Y", {2}); + f("Y", {2}); + + std::vector ops; + c.for_each_operation([&](const Operation &op) { + ops.push_back(op); + }); + ASSERT_EQ(ops, flat.operations); +} + +TEST(circuit, for_each_operation_reverse) { + Circuit c; + c.append_from_text(R"CIRCUIT( + H 0 + M 0 1 + REPEAT 2 { + X 1 + REPEAT 3 { + Y 2 + } + } + )CIRCUIT"); + + Circuit flat; + auto f = [&](const char *gate, const std::vector &targets) { + flat.append_operation({&GATE_DATA.at(gate), {0, flat.jag_targets.take_copy(targets)}}); + }; + f("Y", {2}); + f("Y", {2}); + f("Y", {2}); + f("X", {1}); + f("Y", {2}); + f("Y", {2}); + f("Y", {2}); + f("X", {1}); + f("M", {0, 1}); + f("H", {0}); + + std::vector ops; + c.for_each_operation_reverse([&](const Operation &op) { + ops.push_back(op); + }); + ASSERT_EQ(ops, flat.operations); +} + +TEST(circuit, count_qubits) { + ASSERT_EQ(Circuit().count_qubits(), 0); + + ASSERT_EQ( + Circuit::from_text(R"CIRCUIT( + H 0 + M 0 1 + REPEAT 2 { + X 1 + REPEAT 3 { + Y 2 + M 2 + } + } + )CIRCUIT") + .count_qubits(), + 3); + + // Ensure not unrolling to compute. + ASSERT_EQ( + Circuit::from_text(R"CIRCUIT( + H 0 + M 0 1 + REPEAT 999999 { + REPEAT 999999 { + REPEAT 999999 { + REPEAT 999999 { + X 1 + REPEAT 999999 { + Y 2 + M 2 + } + } + } + } + } + )CIRCUIT") + .count_qubits(), + 3); +} + +TEST(circuit, count_detectors_and_observables) { + ASSERT_EQ(Circuit().count_detectors_and_observables(), 0); + + ASSERT_EQ( + Circuit::from_text(R"CIRCUIT( + M 0 1 2 + DETECTOR rec[-1] + OBSERVABLE_INCLUDE(5) rec[-1] + )CIRCUIT") + .count_detectors_and_observables(), + 2); + + // Ensure not unrolling to compute. 
+ ASSERT_EQ( + Circuit::from_text(R"CIRCUIT( + M 0 1 + REPEAT 1000 { + REPEAT 1000 { + REPEAT 1000 { + REPEAT 1000 { + DETECTOR rec[-1] + } + } + } + } + )CIRCUIT") + .count_detectors_and_observables(), + 1000000000000ULL); +} + +TEST(circuit, max_lookback) { + ASSERT_EQ(Circuit().max_lookback(), 0); + ASSERT_EQ( + Circuit::from_text(R"CIRCUIT( + M 0 1 2 3 4 5 6 + )CIRCUIT") + .max_lookback(), + 0); + + ASSERT_EQ( + Circuit::from_text(R"CIRCUIT( + M 0 1 2 3 4 5 6 + REPEAT 2 { + CNOT rec[-4] 0 + REPEAT 3 { + CNOT rec[-1] 0 + } + } + )CIRCUIT") + .max_lookback(), + 4); + + // Ensure not unrolling to compute. + ASSERT_EQ( + Circuit::from_text(R"CIRCUIT( + M 0 1 2 3 4 5 + REPEAT 999999 { + REPEAT 999999 { + REPEAT 999999 { + REPEAT 999999 { + REPEAT 999999 { + CNOT rec[-5] 0 + } + } + } + } + } + )CIRCUIT") + .max_lookback(), + 5); +} + +TEST(circuit, count_measurements) { + ASSERT_EQ(Circuit().count_measurements(), 0); + + ASSERT_EQ( + Circuit::from_text(R"CIRCUIT( + H 0 + M 0 1 + REPEAT 2 { + X 1 + REPEAT 3 { + Y 2 + M 2 + } + } + )CIRCUIT") + .count_measurements(), + 8); + + // Ensure not unrolling to compute. 
+ ASSERT_EQ( + Circuit::from_text(R"CIRCUIT( + REPEAT 999999 { + REPEAT 999999 { + REPEAT 999999 { + M 0 + } + } + } + )CIRCUIT") + .count_measurements(), + 999999ULL * 999999ULL * 999999ULL); +} + +TEST(circuit, preserves_repetition_blocks) { + Circuit c = Circuit::from_text(R"CIRCUIT( + H 0 + M 0 1 + REPEAT 2 { + X 1 + REPEAT 3 { + Y 2 + M 2 + X 0 + } + } + )CIRCUIT"); + ASSERT_EQ(c.operations.size(), 3); + ASSERT_EQ(c.blocks.size(), 1); + ASSERT_EQ(c.blocks[0].operations.size(), 2); + ASSERT_EQ(c.blocks[0].blocks.size(), 1); + ASSERT_EQ(c.blocks[0].blocks[0].operations.size(), 3); + ASSERT_EQ(c.blocks[0].blocks[0].blocks.size(), 0); +} + +TEST(circuit, multiplication_repeats) { + Circuit c = Circuit::from_text(R"CIRCUIT( + H 0 + M 0 1 + )CIRCUIT"); + ASSERT_EQ((c * 2).str(), R"CIRCUIT(REPEAT 2 { + H 0 + M 0 1 +})CIRCUIT"); + + ASSERT_EQ(c * 0, Circuit()); + ASSERT_EQ(c * 1, c); + Circuit copy = c; + c *= 1; + ASSERT_EQ(c, copy); + c *= 0; + ASSERT_EQ(c, Circuit()); +} + +TEST(circuit, self_addition) { + Circuit c = Circuit::from_text(R"CIRCUIT( + X 0 + )CIRCUIT"); + c += c; + ASSERT_EQ(c.operations.size(), 2); + ASSERT_EQ(c.blocks.size(), 0); + ASSERT_EQ(c.operations[0], c.operations[1]); + + c = Circuit::from_text(R"CIRCUIT( + X 0 + REPEAT 2 { + Y 0 + } + )CIRCUIT"); + c += c; + ASSERT_EQ(c.operations.size(), 4); + ASSERT_EQ(c.blocks.size(), 1); + ASSERT_EQ(c.operations[0], c.operations[2]); + ASSERT_EQ(c.operations[1], c.operations[3]); +} + +TEST(circuit, addition_shares_blocks) { + Circuit c1 = Circuit::from_text(R"CIRCUIT( + X 0 + REPEAT 2 { + X 1 + } + )CIRCUIT"); + Circuit c2 = Circuit::from_text(R"CIRCUIT( + X 2 + REPEAT 2 { + X 3 + } + )CIRCUIT"); + Circuit c3 = Circuit::from_text(R"CIRCUIT( + X 0 + REPEAT 2 { + X 1 + } + X 2 + REPEAT 2 { + X 3 + } + )CIRCUIT"); + ASSERT_EQ(c1 + c2, c3); + c1 += c2; + ASSERT_EQ(c1, c3); +} diff --git a/src/circuit/circuit_pybind_test.py b/src/circuit/circuit_pybind_test.py index c6899d9d7..0b8f1fc6e 100644 --- 
a/src/circuit/circuit_pybind_test.py +++ b/src/circuit/circuit_pybind_test.py @@ -20,15 +20,12 @@ def test_circuit_init_num_measurements_num_qubits(): c = stim.Circuit() assert c.num_qubits == c.num_measurements == 0 - assert str(c).strip() == """ -# Circuit [num_qubits=0, num_measurements=0] - """.strip() + assert str(c).strip() == "" c.append_operation("X", [3]) assert c.num_qubits == 4 assert c.num_measurements == 0 assert str(c).strip() == """ -# Circuit [num_qubits=4, num_measurements=0] X 3 """.strip() @@ -36,7 +33,6 @@ def test_circuit_init_num_measurements_num_qubits(): assert c.num_qubits == 4 assert c.num_measurements == 1 assert str(c).strip() == """ -# Circuit [num_qubits=4, num_measurements=1] X 3 M 0 """.strip() @@ -74,7 +70,6 @@ def test_circuit_append_operation(): c.append_operation("DETECTOR", [stim.target_rec(-1)]) c.append_operation("OBSERVABLE_INCLUDE", [stim.target_rec(-1), stim.target_rec(-2)], 5) assert str(c).strip() == """ -# Circuit [num_qubits=4, num_measurements=2] X 0 1 2 3 CX 0 1 M 0 !1 @@ -93,7 +88,6 @@ def test_circuit_iadd(): c2.append_operation("M", [4]) c += c2 assert str(c).strip() == """ -# Circuit [num_qubits=5, num_measurements=1] X 1 2 Y 3 M 4 @@ -101,7 +95,6 @@ def test_circuit_iadd(): c += c assert str(c).strip() == """ -# Circuit [num_qubits=5, num_measurements=2] X 1 2 Y 3 M 4 @@ -118,14 +111,12 @@ def test_circuit_add(): c2.append_operation("Y", [3]) c2.append_operation("M", [4]) assert str(c + c2).strip() == """ - # Circuit [num_qubits=5, num_measurements=1] X 1 2 Y 3 M 4 """.strip() assert str(c2 + c2).strip() == """ -# Circuit [num_qubits=5, num_measurements=2] Y 3 M 4 Y 3 @@ -137,22 +128,23 @@ def test_circuit_mul(): c = stim.Circuit() c.append_operation("Y", [3]) c.append_operation("M", [4]) - expected = """ -# Circuit [num_qubits=5, num_measurements=2] -Y 3 -M 4 -Y 3 -M 4 + assert str(c * 2) == str(2 * c) == """ +REPEAT 2 { + Y 3 + M 4 +} + """.strip() + assert str((c * 2) * 3) == """ +REPEAT 6 { + Y 3 + M 4 +} 
""".strip() - assert str(c * 2) == str(2 * c) == expected expected = """ -# Circuit [num_qubits=5, num_measurements=3] -Y 3 -M 4 -Y 3 -M 4 -Y 3 -M 4 +REPEAT 3 { + Y 3 + M 4 +} """.strip() assert str(c * 3) == str(3 * c) == expected c *= 3 @@ -160,7 +152,7 @@ def test_circuit_mul(): c *= 1 assert str(c) == expected c *= 0 - assert str(c) == "# Circuit [num_qubits=0, num_measurements=0]" + assert str(c) == "" def test_circuit_repr(): @@ -169,11 +161,10 @@ def test_circuit_repr(): M 0 """) r = repr(v) - assert r == '''stim.Circuit(""" -# Circuit [num_qubits=1, num_measurements=1] + assert r == """stim.Circuit(''' X 0 M 0 -""")''' +''')""" assert eval(r, {'stim': stim}) == v @@ -208,44 +199,45 @@ def test_circuit_compile_sampler(): c = stim.Circuit() s = c.compile_sampler() c.append_operation("M", [0]) - assert str(s) == """ -# reference sample: -# Circuit [num_qubits=0, num_measurements=0] - """.strip() + print(repr(s)) + assert repr(s) == "stim.CompiledMeasurementSampler(stim.Circuit())" s = c.compile_sampler() - assert str(s) == """ -# reference sample: 0 -# Circuit [num_qubits=1, num_measurements=1] + assert repr(s) == """ +stim.CompiledMeasurementSampler(stim.Circuit(''' M 0 +''')) """.strip() c.append_operation("H", [0, 1, 2, 3, 4]) c.append_operation("M", [0, 1, 2, 3, 4]) s = c.compile_sampler() - assert str(s) == """ -# reference sample: 000000 -# Circuit [num_qubits=5, num_measurements=6] + r = repr(s) + assert r == """ +stim.CompiledMeasurementSampler(stim.Circuit(''' M 0 H 0 1 2 3 4 M 0 1 2 3 4 +''')) """.strip() == str(stim.CompiledMeasurementSampler(c)) + # Check that expression can be evaluated. 
+ _ = eval(r, {"stim": stim}) + def test_circuit_compile_detector_sampler(): c = stim.Circuit() s = c.compile_detector_sampler() c.append_operation("M", [0]) - assert str(s) == """ -# num_detectors: 0 -# num_observables: 0 -# Circuit [num_qubits=0, num_measurements=0] - """.strip() + assert repr(s) == "stim.CompiledDetectorSampler(stim.Circuit())" c.append_operation("DETECTOR", [stim.target_rec(-1)]) s = c.compile_detector_sampler() - assert str(s) == """ -# num_detectors: 1 -# num_observables: 0 -# Circuit [num_qubits=1, num_measurements=1] + r = repr(s) + assert r == """ +stim.CompiledDetectorSampler(stim.Circuit(''' M 0 DETECTOR rec[-1] +''')) """.strip() + + # Check that expression can be evaluated. + _ = eval(r, {"stim": stim}) diff --git a/src/circuit/gate_data.cc b/src/circuit/gate_data.cc index f9b1b2c5d..5873c26c2 100644 --- a/src/circuit/gate_data.cc +++ b/src/circuit/gate_data.cc @@ -381,7 +381,7 @@ extern const GateDataMap GATE_DATA( &TableauSimulator::I, &FrameSimulator::I, &ErrorFuser::I, - GATE_IS_BLOCK, + (GateFlags)(GATE_IS_BLOCK | GATE_IS_NOT_FUSABLE), {}, {}, }, @@ -469,7 +469,7 @@ Gate::Gate( : name(name), tableau_simulator_function(tableau_simulator_function), frame_simulator_function(frame_simulator_function), - hit_simulator_function(hit_simulator_function), + reverse_error_fuser_function(hit_simulator_function), flags(flags), unitary_data(unitary_data), tableau_data(tableau_data), diff --git a/src/circuit/gate_data.h b/src/circuit/gate_data.h index e1a8d262f..8d31ee42b 100644 --- a/src/circuit/gate_data.h +++ b/src/circuit/gate_data.h @@ -118,7 +118,7 @@ struct Gate { const char *name; void (TableauSimulator::*tableau_simulator_function)(const OperationData &); void (FrameSimulator::*frame_simulator_function)(const OperationData &); - void (ErrorFuser::*hit_simulator_function)(const OperationData &); + void (ErrorFuser::*reverse_error_fuser_function)(const OperationData &); GateFlags flags; TruncatedArray, 4>, 4> unitary_data; TruncatedArray 
tableau_data; diff --git a/src/main.perf.cc b/src/main.perf.cc new file mode 100644 index 000000000..37012ac55 --- /dev/null +++ b/src/main.perf.cc @@ -0,0 +1,164 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "benchmark_util.h" +#include "main_helper.h" +#include "simulators/detection_simulator.h" + +Circuit make_rep_code(uint32_t distance, uint32_t rounds) { + Circuit round_ops; + for (uint32_t k = 0; k < distance - 1; k++) { + round_ops.append_op("CNOT", {2 * k, 2 * k + 1}); + } + for (uint32_t k = 0; k < distance - 1; k++) { + round_ops.append_op("DEPOLARIZE2", {2 * k, 2 * k + 1}, 0.001); + } + for (uint32_t k = 1; k < distance; k++) { + round_ops.append_op("CNOT", {2 * k, 2 * k - 1}); + } + for (uint32_t k = 1; k < distance; k++) { + round_ops.append_op("DEPOLARIZE2", {2 * k, 2 * k - 1}, 0.001); + } + for (uint32_t k = 0; k < distance - 1; k++) { + round_ops.append_op("X_ERROR", {2 * k + 1}, 0.001); + } + for (uint32_t k = 0; k < distance - 1; k++) { + round_ops.append_op("MR", {2 * k + 1}); + } + Circuit detectors; + for (uint32_t k = 1; k < distance; k++) { + detectors.append_op("DETECTOR", {k | TARGET_RECORD_BIT, (k + distance - 1) | TARGET_RECORD_BIT}); + } + + Circuit result = round_ops + (round_ops + detectors) * (rounds - 1); + for (uint32_t k = 0; k < distance; k++) { + result.append_op("X_ERROR", {2 * k}, 0.001); + } + for (uint32_t k = 0; k < distance; k++) { + result.append_op("M", {2 * k}); + } 
+ for (uint32_t k = 1; k < distance; k++) { + result.append_op( + "DETECTOR", {k | TARGET_RECORD_BIT, (k + 1) | TARGET_RECORD_BIT, (k + distance) | TARGET_RECORD_BIT}); + } + result.append_op("OBSERVABLE_INCLUDE", {1 | TARGET_RECORD_BIT}, 0); + return result; +} + +BENCHMARK(main_sample1_tableau_rep_d1000_r100) { + size_t distance = 1000; + size_t rounds = 100; + auto circuit = make_rep_code(distance, rounds); + FILE *in = tmpfile(); + FILE *out = tmpfile(); + fprintf(in, "%s", circuit.str().data()); + std::mt19937_64 rng(0); // NOLINT(cert-msc51-cpp) + benchmark_go([&]() { + rewind(in); + rewind(out); + TableauSimulator::sample_stream(in, out, SAMPLE_FORMAT_B8, false, rng); + }) + .goal_millis(30) + .show_rate("Samples", circuit.count_measurements()); +} + +BENCHMARK(main_sample1_pauliframe_b8_rep_d1000_r100) { + size_t distance = 1000; + size_t rounds = 100; + auto circuit = make_rep_code(distance, rounds); + FILE *out = tmpfile(); + std::mt19937_64 rng(0); // NOLINT(cert-msc51-cpp) + simd_bits ref(0); + benchmark_go([&]() { + rewind(out); + FrameSimulator::sample_out(circuit, ref, 1, out, SAMPLE_FORMAT_B8, rng); + }) + .goal_millis(16) + .show_rate("Samples", circuit.count_measurements()); +} + +BENCHMARK(main_sample1_detectors_b8_rep_d1000_r100) { + size_t distance = 1000; + size_t rounds = 100; + auto circuit = make_rep_code(distance, rounds); + FILE *out = tmpfile(); + std::mt19937_64 rng(0); // NOLINT(cert-msc51-cpp) + simd_bits ref(circuit.count_measurements()); + benchmark_go([&]() { + rewind(out); + detector_samples_out(circuit, 1, false, true, out, SAMPLE_FORMAT_B8, rng); + }) + .goal_millis(20) + .show_rate("Samples", circuit.count_measurements()); +} + +BENCHMARK(main_sample256_pauliframe_b8_rep_d1000_r100) { + size_t distance = 1000; + size_t rounds = 100; + auto circuit = make_rep_code(distance, rounds); + FILE *out = tmpfile(); + std::mt19937_64 rng(0); // NOLINT(cert-msc51-cpp) + simd_bits ref(0); + benchmark_go([&]() { + rewind(out); + 
FrameSimulator::sample_out(circuit, ref, 256, out, SAMPLE_FORMAT_B8, rng); + }) + .goal_millis(20) + .show_rate("Samples", circuit.count_measurements()); +} + +BENCHMARK(main_sample256_pauliframe_b8_rep_d1000_r1000_stream) { + size_t distance = 1000; + size_t rounds = 1000; + auto circuit = make_rep_code(distance, rounds); + FILE *out = tmpfile(); + std::mt19937_64 rng(0); // NOLINT(cert-msc51-cpp) + simd_bits ref(0); + benchmark_go([&]() { + rewind(out); + FrameSimulator::sample_out(circuit, ref, 256, out, SAMPLE_FORMAT_B8, rng); + }) + .goal_millis(360) + .show_rate("Samples", circuit.count_measurements()); +} + +BENCHMARK(main_sample256_detectors_b8_rep_d1000_r100) { + size_t distance = 1000; + size_t rounds = 100; + auto circuit = make_rep_code(distance, rounds); + FILE *out = tmpfile(); + std::mt19937_64 rng(0); // NOLINT(cert-msc51-cpp) + simd_bits ref(0); + benchmark_go([&]() { + rewind(out); + detector_samples_out(circuit, 256, false, true, out, SAMPLE_FORMAT_B8, rng); + }) + .goal_millis(25) + .show_rate("Samples", circuit.count_measurements()); +} + +BENCHMARK(main_sample256_detectors_b8_rep_d1000_r1000_stream) { + size_t distance = 1000; + size_t rounds = 1000; + auto circuit = make_rep_code(distance, rounds); + FILE *out = tmpfile(); + std::mt19937_64 rng(0); // NOLINT(cert-msc51-cpp) + simd_bits ref(0); + benchmark_go([&]() { + rewind(out); + detector_samples_out(circuit, 256, false, true, out, SAMPLE_FORMAT_B8, rng); + }) + .goal_millis(360) + .show_rate("Samples", circuit.count_measurements()); +} diff --git a/src/main_helper.cc b/src/main_helper.cc index b291f3127..0ad433d18 100644 --- a/src/main_helper.cc +++ b/src/main_helper.cc @@ -40,17 +40,17 @@ static std::vector sample_mode_known_arguments{ "--sample", "--frame0", "--out_format", "--out", "--in", }; static std::vector detect_mode_known_arguments{ - "--detect", "--append_observables", "--prepend_observables", "--out_format", "--out", "--in", + "--detect", "--prepend_observables", 
"--append_observables", "--out_format", "--out", "--in", }; static std::vector detector_hypergraph_mode_known_arguments{ - "--detector_hypergraph", "--out", "--in", + "--detector_hypergraph", + "--out", + "--in", }; static std::vector repl_mode_known_arguments{ "--repl", }; -static std::vector format_names{ - "01", "b8", "ptb64", "hits", "r8", "dets" -}; +static std::vector format_names{"01", "b8", "ptb64", "hits", "r8", "dets"}; static std::vector format_values{ SAMPLE_FORMAT_01, SAMPLE_FORMAT_B8, SAMPLE_FORMAT_PTB64, SAMPLE_FORMAT_HITS, SAMPLE_FORMAT_R8, SAMPLE_FORMAT_DETS, }; @@ -126,7 +126,7 @@ M 0 1 2 std::mt19937_64 rng = externally_seeded_rng(); if (mode_interactive) { check_for_unknown_arguments(repl_mode_known_arguments, "--repl", argc, argv); - TableauSimulator::sample_stream(in, out, mode_interactive, rng); + TableauSimulator::sample_stream(in, out, SAMPLE_FORMAT_01, mode_interactive, rng); return EXIT_SUCCESS; } if (mode_sampling) { @@ -138,13 +138,13 @@ M 0 1 2 if (num_shots == 0) { return EXIT_SUCCESS; } - if (num_shots == 1 && out_format == SAMPLE_FORMAT_01 && !frame0) { - TableauSimulator::sample_stream(in, out, false, rng); + if (num_shots == 1 && !frame0) { + TableauSimulator::sample_stream(in, out, out_format, false, rng); return EXIT_SUCCESS; } auto circuit = Circuit::from_file(in); - simd_bits ref(circuit.num_measurements); + simd_bits ref(0); if (!frame0) { ref = TableauSimulator::reference_sample_circuit(circuit); } @@ -166,7 +166,6 @@ M 0 1 2 } auto circuit = Circuit::from_file(in); - simd_bits ref(circuit.num_measurements); detector_samples_out(circuit, num_shots, prepend_observables, append_observables, out, out_format, rng); return EXIT_SUCCESS; } diff --git a/src/probability_util.h b/src/probability_util.h index c5e991f40..838852022 100644 --- a/src/probability_util.h +++ b/src/probability_util.h @@ -48,7 +48,7 @@ struct RareErrorIterator { } template - inline static void for_samples(double p, const VectorView &vals, std::mt19937_64 
&rng, BODY body) { + inline static void for_samples(double p, const PointerRange &vals, std::mt19937_64 &rng, BODY body) { RareErrorIterator skipper((float)p); while (true) { size_t s = skipper.next(rng); diff --git a/src/probability_util.perf.cc b/src/probability_util.perf.cc index 55d32da54..ead419713 100644 --- a/src/probability_util.perf.cc +++ b/src/probability_util.perf.cc @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "benchmark_util.h" - #include "probability_util.h" + +#include "benchmark_util.h" #include "simd/simd_bits.h" BENCHMARK(biased_random_1024_0point1percent) { @@ -22,9 +22,11 @@ BENCHMARK(biased_random_1024_0point1percent) { float p = 0.001; size_t n = 1024; simd_bits data(n); - benchmark_go([&](){ + benchmark_go([&]() { biased_randomize_bits(p, data.u64, data.u64 + data.num_u64_padded(), rng); - }).goal_nanos(70).show_rate("bits", n); + }) + .goal_nanos(70) + .show_rate("bits", n); } BENCHMARK(biased_random_1024_0point01percent) { @@ -32,9 +34,11 @@ BENCHMARK(biased_random_1024_0point01percent) { float p = 0.0001; size_t n = 1024; simd_bits data(n); - benchmark_go([&](){ + benchmark_go([&]() { biased_randomize_bits(p, data.u64, data.u64 + data.num_u64_padded(), rng); - }).goal_nanos(35).show_rate("bits", n); + }) + .goal_nanos(35) + .show_rate("bits", n); } BENCHMARK(biased_random_1024_1percent) { @@ -42,9 +46,11 @@ BENCHMARK(biased_random_1024_1percent) { float p = 0.01; size_t n = 1024; simd_bits data(n); - benchmark_go([&](){ + benchmark_go([&]() { biased_randomize_bits(p, data.u64, data.u64 + data.num_u64_padded(), rng); - }).goal_nanos(250).show_rate("bits", n); + }) + .goal_nanos(250) + .show_rate("bits", n); } BENCHMARK(biased_random_1024_40percent) { @@ -52,9 +58,11 @@ BENCHMARK(biased_random_1024_40percent) { float p = 0.4; size_t n = 1024; simd_bits data(n); - benchmark_go([&](){ + benchmark_go([&]() { biased_randomize_bits(p, data.u64, data.u64 + 
data.num_u64_padded(), rng); - }).goal_nanos(420).show_rate("bits", n); + }) + .goal_nanos(420) + .show_rate("bits", n); } BENCHMARK(biased_random_1024_50percent) { @@ -62,9 +70,11 @@ BENCHMARK(biased_random_1024_50percent) { float p = 0.5; size_t n = 1024; simd_bits data(n); - benchmark_go([&](){ + benchmark_go([&]() { biased_randomize_bits(p, data.u64, data.u64 + data.num_u64_padded(), rng); - }).goal_nanos(40).show_rate("bits", n); + }) + .goal_nanos(40) + .show_rate("bits", n); } BENCHMARK(biased_random_1024_90percent) { @@ -72,9 +82,11 @@ BENCHMARK(biased_random_1024_90percent) { float p = 0.9; size_t n = 1024; simd_bits data(n); - benchmark_go([&](){ + benchmark_go([&]() { biased_randomize_bits(p, data.u64, data.u64 + data.num_u64_padded(), rng); - }).goal_nanos(450).show_rate("bits", n); + }) + .goal_nanos(450) + .show_rate("bits", n); } BENCHMARK(biased_random_1024_99percent) { @@ -82,7 +94,9 @@ BENCHMARK(biased_random_1024_99percent) { float p = 0.99; size_t n = 1024; simd_bits data(n); - benchmark_go([&](){ + benchmark_go([&]() { biased_randomize_bits(p, data.u64, data.u64 + data.num_u64_padded(), rng); - }).goal_nanos(260).show_rate("bits", n); + }) + .goal_nanos(260) + .show_rate("bits", n); } diff --git a/src/probability_util.test.cc b/src/probability_util.test.cc index 4b7dabd6e..071909dce 100644 --- a/src/probability_util.test.cc +++ b/src/probability_util.test.cc @@ -56,7 +56,8 @@ TEST(probability_util, biased_random) { float dev = sqrtf(p * (1 - p) * n); float min_expected = n * p - dev * 5; float max_expected = n * p + dev * 5; - ASSERT_TRUE( min_expected >= 0 && max_expected <= n) << min_expected << ", " << max_expected; - EXPECT_TRUE(min_expected <= t && t <= max_expected) << min_expected/n << " < " << t/(float)n << " < " << max_expected/n << " for p=" << p; + ASSERT_TRUE(min_expected >= 0 && max_expected <= n) << min_expected << ", " << max_expected; + EXPECT_TRUE(min_expected <= t && t <= max_expected) + << min_expected / n << " < " << t / 
(float)n << " < " << max_expected / n << " for p=" << p; } } \ No newline at end of file diff --git a/src/py/compiled_detector_sampler.pybind.cc b/src/py/compiled_detector_sampler.pybind.cc index 2deb69d56..f4fe5c18c 100644 --- a/src/py/compiled_detector_sampler.pybind.cc +++ b/src/py/compiled_detector_sampler.pybind.cc @@ -14,6 +14,7 @@ #include "compiled_detector_sampler.pybind.h" +#include "../circuit/circuit.pybind.h" #include "../simulators/detection_simulator.h" #include "../simulators/frame_simulator.h" #include "../simulators/tableau_simulator.h" @@ -40,9 +41,14 @@ pybind11::array_t CompiledDetectorSampler::sample( } size_t n = dets_obs.detectors.size() + dets_obs.observables.size() * (prepend_observables + append_observables); - return pybind11::array_t(pybind11::buffer_info( - bytes.data(), sizeof(uint8_t), pybind11::format_descriptor::value, 2, {num_shots, n}, - {(long long)sample.num_minor_bits_padded(), (long long)1}, true)); + + void *ptr = bytes.data(); + ssize_t itemsize = sizeof(uint8_t); + std::vector shape{(ssize_t)num_shots, (ssize_t)n}; + std::vector stride{(ssize_t)sample.num_minor_bits_padded(), 1}; + const std::string &format = pybind11::format_descriptor::value; + bool readonly = true; + return pybind11::array_t(pybind11::buffer_info(ptr, itemsize, format, 2, shape, stride, readonly)); } pybind11::array_t CompiledDetectorSampler::sample_bit_packed( @@ -51,16 +57,21 @@ pybind11::array_t CompiledDetectorSampler::sample_bit_packed( detector_samples(circuit, dets_obs, num_shots, prepend_observables, append_observables, PYBIND_SHARED_RNG()) .transposed(); size_t n = dets_obs.detectors.size() + dets_obs.observables.size() * (prepend_observables + append_observables); - return pybind11::array_t(pybind11::buffer_info( - sample.data.u8, sizeof(uint8_t), pybind11::format_descriptor::value, 2, {num_shots, (n + 7) / 8}, - {(long long)sample.num_minor_u8_padded(), (long long)1}, true)); + + void *ptr = sample.data.u8; + ssize_t itemsize = 
sizeof(uint8_t); + std::vector shape{(ssize_t)num_shots, (ssize_t)(n + 7) / 8}; + std::vector stride{(ssize_t)sample.num_minor_u8_padded(), 1}; + const std::string &format = pybind11::format_descriptor::value; + bool readonly = true; + return pybind11::array_t(pybind11::buffer_info(ptr, itemsize, format, 2, shape, stride, readonly)); } -std::string CompiledDetectorSampler::str() const { +std::string CompiledDetectorSampler::repr() const { std::stringstream result; - result << "# num_detectors: " << dets_obs.detectors.size() << "\n"; - result << "# num_observables: " << dets_obs.observables.size() << "\n"; - result << circuit; + result << "stim.CompiledDetectorSampler("; + result << circuit_repr(circuit); + result << ")"; return result.str(); } @@ -79,7 +90,7 @@ void pybind_compiled_detector_sampler(pybind11::module &m) { shots: The number of times to sample every detector in the circuit. prepend_observables: Defaults to false. When set, observables are included with the detectors and are placed at the start of the results. - prepend_observables: Defaults to false. When set, observables are included with the detectors and are + append_observables: Defaults to false. When set, observables are included with the detectors and are placed at the end of the results. Returns: @@ -100,7 +111,7 @@ void pybind_compiled_detector_sampler(pybind11::module &m) { shots: The number of times to sample every detector in the circuit. prepend_observables: Defaults to false. When set, observables are included with the detectors and are placed at the start of the results. - prepend_observables: Defaults to false. When set, observables are included with the detectors and are + append_observables: Defaults to false. When set, observables are included with the detectors and are placed at the end of the results. 
Returns: @@ -110,5 +121,5 @@ void pybind_compiled_detector_sampler(pybind11::module &m) { )DOC", pybind11::arg("shots"), pybind11::kw_only(), pybind11::arg("prepend_observables") = false, pybind11::arg("append_observables") = false) - .def("__str__", &CompiledDetectorSampler::str); + .def("__repr__", &CompiledDetectorSampler::repr); } diff --git a/src/py/compiled_detector_sampler.pybind.h b/src/py/compiled_detector_sampler.pybind.h index 473cf93bb..f51f70b11 100644 --- a/src/py/compiled_detector_sampler.pybind.h +++ b/src/py/compiled_detector_sampler.pybind.h @@ -30,7 +30,7 @@ struct CompiledDetectorSampler { CompiledDetectorSampler(Circuit circuit); pybind11::array_t sample(size_t num_shots, bool prepend_observables, bool append_observables); pybind11::array_t sample_bit_packed(size_t num_shots, bool prepend_observables, bool append_observables); - std::string str() const; + std::string repr() const; }; #endif diff --git a/src/py/compiled_measurement_sampler.pybind.cc b/src/py/compiled_measurement_sampler.pybind.cc index 178766bb7..abfe36961 100644 --- a/src/py/compiled_measurement_sampler.pybind.cc +++ b/src/py/compiled_measurement_sampler.pybind.cc @@ -14,6 +14,7 @@ #include "compiled_measurement_sampler.pybind.h" +#include "../circuit/circuit.pybind.h" #include "../simulators/detection_simulator.h" #include "../simulators/frame_simulator.h" #include "../simulators/tableau_simulator.h" @@ -37,27 +38,32 @@ pybind11::array_t CompiledMeasurementSampler::sample(size_t num_samples } } - return pybind11::array_t(pybind11::buffer_info( - bytes.data(), sizeof(uint8_t), pybind11::format_descriptor::value, 2, - {num_samples, circuit.num_measurements}, {(long long)sample.num_minor_bits_padded(), (long long)1}, true)); + void *ptr = bytes.data(); + ssize_t itemsize = sizeof(uint8_t); + std::vector shape{(ssize_t)num_samples, (ssize_t)circuit.count_measurements()}; + std::vector stride{(ssize_t)sample.num_minor_bits_padded(), 1}; + const std::string &format = 
pybind11::format_descriptor::value; + bool readonly = true; + return pybind11::array_t(pybind11::buffer_info(ptr, itemsize, format, 2, shape, stride, readonly)); } pybind11::array_t CompiledMeasurementSampler::sample_bit_packed(size_t num_samples) { auto sample = FrameSimulator::sample(circuit, ref, num_samples, PYBIND_SHARED_RNG()); - return pybind11::array_t(pybind11::buffer_info( - sample.data.u8, sizeof(uint8_t), pybind11::format_descriptor::value, 2, - {num_samples, (circuit.num_measurements + 7) / 8}, {(long long)sample.num_minor_u8_padded(), (long long)1}, - true)); + + void *ptr = sample.data.u8; + ssize_t itemsize = sizeof(uint8_t); + std::vector shape{(ssize_t)num_samples, (ssize_t)(circuit.count_measurements() + 7) / 8}; + std::vector stride{(ssize_t)sample.num_minor_u8_padded(), 1}; + const std::string &format = pybind11::format_descriptor::value; + bool readonly = true; + return pybind11::array_t(pybind11::buffer_info(ptr, itemsize, format, 2, shape, stride, readonly)); } -std::string CompiledMeasurementSampler::str() const { +std::string CompiledMeasurementSampler::repr() const { std::stringstream result; - result << "# reference sample: "; - for (size_t k = 0; k < circuit.num_measurements; k++) { - result << "01"[ref[k]]; - } - result << "\n"; - result << circuit; + result << "stim.CompiledMeasurementSampler("; + result << circuit_repr(circuit); + result << ")"; return result.str(); } @@ -109,5 +115,5 @@ void pybind_compiled_measurement_sampler(pybind11::module &m) { The bit for measurement `m` in shot `s` is at `result[s, (m // 8)] & 2**(m % 8)`. 
)DOC", pybind11::arg("shots")) - .def("__str__", &CompiledMeasurementSampler::str); + .def("__repr__", &CompiledMeasurementSampler::repr); } diff --git a/src/py/compiled_measurement_sampler.pybind.h b/src/py/compiled_measurement_sampler.pybind.h index 86bc14c54..cf8fb8180 100644 --- a/src/py/compiled_measurement_sampler.pybind.h +++ b/src/py/compiled_measurement_sampler.pybind.h @@ -30,7 +30,7 @@ struct CompiledMeasurementSampler { CompiledMeasurementSampler(Circuit circuit); pybind11::array_t sample(size_t num_samples); pybind11::array_t sample_bit_packed(size_t num_samples); - std::string str() const; + std::string repr() const; }; #endif diff --git a/src/simd/monotonic_buffer.h b/src/simd/monotonic_buffer.h new file mode 100644 index 000000000..2ba5f9a3f --- /dev/null +++ b/src/simd/monotonic_buffer.h @@ -0,0 +1,164 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MONOTONIC_BUFFER_H +#define MONOTONIC_BUFFER_H + +#include +#include +#include +#include +#include +#include +#include + +#include "pointer_range.h" + +/// A memory resource that can efficiently incrementally accumulate data. +/// +/// There are three important types of "region" in play: the tail region, the current region, and old regions. +/// +/// The tail is for contiguous data being added incrementally into the buffer. 
+/// When the tail grows beyond the currently available storage, more memory is allocated and the tail is copied into +/// the new memory so that it can stay contiguous. At any time, the tail can be discarded or committed. Discarding the +/// tail allows the memory it was covering to be re-used when writing the next tail. Committing the tail permanently +/// preserves that data (until the monotonic buffer is cleared or deconstructed) and also guarantees it will no longer +/// move so pointers to it can be stored. +/// +/// The current region is a contiguous chunk of memory that the tail is being written into. +/// When the tail grows beyond this region and triggers an allocation, the current region is relabelled as an old region +/// and the newly allocated memory is now the current region. Each subsequent current region will be at least double the +/// size of the previous one. +/// +/// The old regions are memory that has been finalized, and will be stored until the buffer is cleared or deconstructed. +template +struct MonotonicBuffer { + /// Contiguous memory that is being appended to, but has not yet been committed. + PointerRange tail; + /// The current contiguous memory region with a mix of committed, staged, and unused memory. + PointerRange cur; + /// Old contiguous memory regions that have been committed and now need to be kept. + std::vector> old_areas; + + /// Constructs an empty monotonic buffer. + MonotonicBuffer() = default; + /// Constructs an empty monotonic buffer with initial capacity for its current region. 
+ MonotonicBuffer(size_t reserve) { + ensure_available(reserve); + } + ~MonotonicBuffer() { + for (auto &v : old_areas) { + free(v.ptr_start); + } + if (cur.ptr_start) { + free(cur.ptr_start); + } + old_areas.clear(); + cur.ptr_start = cur.ptr_end = tail.ptr_start = tail.ptr_end = nullptr; + } + MonotonicBuffer(MonotonicBuffer &&other) noexcept + : tail(other.tail), cur(other.cur), old_areas(std::move(other.old_areas)) { + other.cur.ptr_start = nullptr; + other.cur.ptr_end = nullptr; + other.tail.ptr_start = nullptr; + other.tail.ptr_end = nullptr; + } + MonotonicBuffer(const MonotonicBuffer &other) = delete; + MonotonicBuffer &operator=(MonotonicBuffer &&other) noexcept { + (*this).~MonotonicBuffer(); + new (this) MonotonicBuffer(std::move(other)); + return *this; + } + + /// Invalidates all previous data and resets the class into a clean state. + /// + /// Happens to keep the current contiguous memory region and free old regions. + void clear() { + for (auto &v : old_areas) { + free(v.ptr_start); + } + old_areas.clear(); + tail.ptr_end = tail.ptr_start = cur.ptr_start; + } + + /// Returns the size of memory allocated and held by this monotonic buffer (in units of sizeof(T)). + size_t total_allocated() const { + size_t result = cur.size(); + for (auto &old : old_areas) { + result += old.size(); + } + return result; + } + + /// Appends and commits data. + /// Requires the tail to be empty, to avoid bugs where previously staged data is committed. + PointerRange take_copy(ConstPointerRange data) { + assert(tail.size() == 0); + append_tail(data); + return commit_tail(); + } + + /// Adds a staged data item. + void append_tail(T item) { + ensure_available(1); + *tail.ptr_end = item; + tail.ptr_end++; + } + + /// Adds staged data. + void append_tail(ConstPointerRange data) { + ensure_available(data.size()); + std::copy(data.begin(), data.end(), tail.ptr_end); + tail.ptr_end += data.size(); + } + + /// Throws away staged data, so its memory can be re-used. 
+ void discard_tail() { + tail.ptr_end = tail.ptr_start; + } + + /// Changes staged data into committed data that will be kept until the buffer is cleared or deconstructed. + PointerRange commit_tail() { + PointerRange result(tail); + tail.ptr_start = tail.ptr_end; + return result; + } + + /// Ensures it is possible to stage at least `min_required` more items without more reallocations. + void ensure_available(size_t min_required) { + size_t available = cur.ptr_end - tail.ptr_end; + if (available >= min_required) { + return; + } + + size_t alloc_count = std::max(min_required, cur.size() << 1); + if (cur.ptr_start != nullptr) { + old_areas.push_back(cur); + } + cur.ptr_start = (T *)malloc(alloc_count * sizeof(T)); + cur.ptr_end = cur.ptr_start + alloc_count; + + // Staged data is not complete yet; keep it contiguous by copying it to the new larger memory region. + size_t tail_size = tail.size(); + if (tail_size) { + std::move(tail.ptr_start, tail.ptr_end, cur.ptr_start); + } + + tail = {cur.ptr_start, cur.ptr_start + tail_size}; + } +}; + +#endif diff --git a/src/simd/monotonic_buffer.test.cc b/src/simd/monotonic_buffer.test.cc new file mode 100644 index 000000000..543aa3f22 --- /dev/null +++ b/src/simd/monotonic_buffer.test.cc @@ -0,0 +1,51 @@ +// Copyright 2021 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "monotonic_buffer.h" + +#include + +TEST(pointer_range, equality) { + int data[100]{}; + PointerRange r1{&data[0], &data[10]}; + PointerRange r2{&data[10], &data[20]}; + PointerRange r4{&data[30], &data[50]}; + ASSERT_TRUE(r1 == r2); + ASSERT_FALSE(r1 != r2); + ASSERT_EQ(r1, r2); + ASSERT_NE(r1, r4); + r2[0] = 1; + ASSERT_EQ(data[10], 1); + ASSERT_NE(r1, r2); + ASSERT_TRUE(r1 != r2); + ASSERT_FALSE(r1 == r2); + ASSERT_NE(r1, r4); + r2[0] = 0; + ASSERT_EQ(r1, r2); + r2[6] = 1; + ASSERT_NE(r1, r2); +} + +TEST(monotonic_buffer, x) { + MonotonicBuffer buf; + for (size_t k = 0; k < 100; k++) { + buf.append_tail(k); + } + + PointerRange rng = buf.commit_tail(); + ASSERT_EQ(rng.size(), 100); + for (size_t k = 0; k < 100; k++) { + ASSERT_EQ(rng[k], k); + } +} diff --git a/src/simd/pointer_range.h b/src/simd/pointer_range.h new file mode 100644 index 000000000..6a3f54245 --- /dev/null +++ b/src/simd/pointer_range.h @@ -0,0 +1,174 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef POINTER_RANGE_H +#define POINTER_RANGE_H + +#include +#include +#include +#include +#include +#include + +/// Delineates mutable memory using an inclusive start pointer and exclusive end pointer. +template +struct PointerRange { + T *ptr_start; + T *ptr_end; + PointerRange() : ptr_start(nullptr), ptr_end(nullptr) { + } + PointerRange(T *begin, T *end) : ptr_start(begin), ptr_end(end) { + } + // Implicit conversions. 
+ PointerRange(std::vector &items) : ptr_start(items.data()), ptr_end(items.data() + items.size()) { + } + + size_t size() const { + return ptr_end - ptr_start; + } + const T *begin() const { + return ptr_start; + } + const T *end() const { + return ptr_end; + } + T *begin() { + return ptr_start; + } + T *end() { + return ptr_end; + } + const T &operator[](size_t index) const { + return ptr_start[index]; + } + T &operator[](size_t index) { + return ptr_start[index]; + } + + bool operator==(const PointerRange &other) const { + size_t n = size(); + if (n != other.size()) { + return false; + } + for (size_t k = 0; k < n; k++) { + if (ptr_start[k] != other[k]) { + return false; + } + } + return true; + } + bool operator!=(const PointerRange &other) const { + return !(*this == other); + } + + std::string str() const { + std::stringstream ss; + ss << *this; + return ss.str(); + } + + /// Lexicographic ordering. + bool operator<(const PointerRange &other) const { + auto n = std::min(size(), other.size()); + for (size_t k = 0; k < n; k++) { + if ((*this)[k] != other[k]) { + return (*this)[k] < other[k]; + } + } + return size() < other.size(); + } +}; + +/// Delineates readable memory using an inclusive start pointer and exclusive end pointer. +template +struct ConstPointerRange { + const T *ptr_start; + const T *ptr_end; + + ConstPointerRange() : ptr_start(nullptr), ptr_end(nullptr) { + } + ConstPointerRange(const T *begin, const T *end) : ptr_start(begin), ptr_end(end) { + } + // Implicit conversions. 
+ ConstPointerRange(PointerRange items) : ptr_start(items.ptr_start), ptr_end(items.ptr_end) { + } + ConstPointerRange(const std::vector &items) : ptr_start(items.data()), ptr_end(items.data() + items.size()) { + } + + size_t size() const { + return ptr_end - ptr_start; + } + const T *begin() const { + return ptr_start; + } + const T *end() const { + return ptr_end; + } + const T &operator[](size_t index) const { + return ptr_start[index]; + } + + bool operator==(const ConstPointerRange &other) const { + size_t n = size(); + if (n != other.size()) { + return false; + } + for (size_t k = 0; k < n; k++) { + if (ptr_start[k] != other[k]) { + return false; + } + } + return true; + } + bool operator!=(const ConstPointerRange &other) const { + return !(*this == other); + } + + std::string str() const { + std::stringstream ss; + ss << *this; + return ss.str(); + } + + /// Lexicographic ordering. + bool operator<(const ConstPointerRange &other) const { + auto n = std::min(size(), other.size()); + for (size_t k = 0; k < n; k++) { + if ((*this)[k] != other[k]) { + return (*this)[k] < other[k]; + } + } + return size() < other.size(); + } +}; + +template +std::ostream &operator<<(std::ostream &out, ConstPointerRange v) { + out << "PointerRange{"; + bool first = true; + for (auto &e : v) { + if (!first) { + out << ", "; + } + first = false; + out << e; + } + out << "}"; + return out; +} + +#endif diff --git a/src/simd/simd_bit_table.cc b/src/simd/simd_bit_table.cc index 0ab495a37..335d28de1 100644 --- a/src/simd/simd_bit_table.cc +++ b/src/simd/simd_bit_table.cc @@ -148,6 +148,14 @@ simd_bit_table simd_bit_table::transposed() const { return result; } +simd_bit_table simd_bit_table::slice_maj(size_t maj_start_bit, size_t maj_stop_bit) const { + simd_bit_table result(maj_stop_bit - maj_start_bit, num_minor_bits_padded()); + for (size_t k = maj_start_bit; k < maj_stop_bit; k++) { + result[k - maj_start_bit] = (*this)[k]; + } + return result; +} + void 
simd_bit_table::transpose_into(simd_bit_table &out) const { assert(out.num_simd_words_minor == num_simd_words_major); assert(out.num_simd_words_major == num_simd_words_minor); @@ -205,3 +213,12 @@ std::string simd_bit_table::str(size_t n) const { } return out.str(); } + +simd_bit_table simd_bit_table::random( + size_t num_randomized_major_bits, size_t num_randomized_minor_bits, std::mt19937_64 &rng) { + simd_bit_table result(num_randomized_major_bits, num_randomized_minor_bits); + for (size_t maj = 0; maj < num_randomized_major_bits; maj++) { + result[maj].randomize(num_randomized_minor_bits, rng); + } + return result; +} diff --git a/src/simd/simd_bit_table.h b/src/simd/simd_bit_table.h index a08a25831..92dd85aa2 100644 --- a/src/simd/simd_bit_table.h +++ b/src/simd/simd_bit_table.h @@ -26,6 +26,9 @@ struct simd_bit_table { /// Creates zero initialized table. simd_bit_table(size_t min_bits_major, size_t min_bits_minor); + /// Creates a randomly initialized table. + static simd_bit_table random( + size_t num_randomized_major_bits, size_t num_randomized_minor_bits, std::mt19937_64 &rng); /// Creates a square table with 1s down the diagonal. static simd_bit_table identity(size_t n); /// Concatenates tables together to form a larger table. @@ -55,8 +58,11 @@ struct simd_bit_table { void do_square_transpose(); /// Transposes the table out of place into a target location. void transpose_into(simd_bit_table &out) const; + void transpose_into(simd_bit_table &out, size_t major_start_bit, size_t min_major_length_bits) const; /// Transposes the table out of place. simd_bit_table transposed() const; + /// Returns a subset of the table. + simd_bit_table slice_maj(size_t maj_start_bit, size_t maj_stop_bit) const; /// Sets all bits in the table to zero. 
void clear(); diff --git a/src/simd/simd_bit_table.test.cc b/src/simd/simd_bit_table.test.cc index d7e142ec7..d9cc062ba 100644 --- a/src/simd/simd_bit_table.test.cc +++ b/src/simd/simd_bit_table.test.cc @@ -16,6 +16,8 @@ #include +#include "../test_util.test.h" + TEST(bit_mat, creation) { simd_bit_table a(3, 3); ASSERT_EQ( @@ -167,6 +169,24 @@ TEST(bit_mat, transposed) { ASSERT_EQ(trans2, m); } +TEST(bit_mat, random) { + auto t = simd_bit_table::random(100, 90, SHARED_TEST_RNG()); + ASSERT_NE(t[99], simd_bits(90)); + ASSERT_EQ(t[100], simd_bits(90)); + t = t.transposed(); + ASSERT_NE(t[89], simd_bits(100)); + ASSERT_EQ(t[90], simd_bits(100)); + ASSERT_NE(simd_bit_table::random(10, 10, SHARED_TEST_RNG()), simd_bit_table::random(10, 10, SHARED_TEST_RNG())); +} + +TEST(bit_mat, slice_maj) { + auto m = simd_bit_table::random(100, 64, SHARED_TEST_RNG()); + auto s = m.slice_maj(5, 15); + ASSERT_EQ(s[0], m[5]); + ASSERT_EQ(s[9], m[14]); + ASSERT_FALSE(s[10].not_zero()); +} + TEST(bit_mat, from_quadrants) { simd_bit_table t(2, 2); simd_bit_table z(2, 2); diff --git a/src/simd/simd_bits.cc b/src/simd/simd_bits.cc index e5d170789..8a10c6033 100644 --- a/src/simd/simd_bits.cc +++ b/src/simd/simd_bits.cc @@ -134,6 +134,16 @@ simd_bits &simd_bits::operator^=(const simd_bits_range_ref other) { return *this; } +simd_bits &simd_bits::operator&=(const simd_bits_range_ref other) { + simd_bits_range_ref(*this) &= other; + return *this; +} + +simd_bits &simd_bits::operator|=(const simd_bits_range_ref other) { + simd_bits_range_ref(*this) |= other; + return *this; +} + bool simd_bits::not_zero() const { return simd_bits_range_ref(*this).not_zero(); } diff --git a/src/simd/simd_bits.h b/src/simd/simd_bits.h index a9417d0e9..dd9b9e7d5 100644 --- a/src/simd/simd_bits.h +++ b/src/simd/simd_bits.h @@ -57,6 +57,9 @@ struct simd_bits { simd_bits &operator=(simd_bits &&other) noexcept; // Xor assignment. simd_bits &operator^=(const simd_bits_range_ref other); + // Mask assignment. 
+ simd_bits &operator&=(const simd_bits_range_ref other); + simd_bits &operator|=(const simd_bits_range_ref other); // Swap assignment. simd_bits &swap_with(simd_bits_range_ref other); diff --git a/src/simd/simd_bits.test.cc b/src/simd/simd_bits.test.cc index 635f7bdba..cb6f14ea4 100644 --- a/src/simd/simd_bits.test.cc +++ b/src/simd/simd_bits.test.cc @@ -235,3 +235,31 @@ TEST(simd_bits, invert_bits) { ASSERT_EQ(r[k], k != 5); } } + +TEST(simd_bits, mask_assignment_and) { + simd_bits a(4); + simd_bits b(4); + a[2] = true; + a[3] = true; + b[1] = true; + b[3] = true; + b &= a; + simd_bits expected(4); + expected[3] = true; + ASSERT_EQ(b, expected); +} + +TEST(simd_bits, mask_assignment_or) { + simd_bits a(4); + simd_bits b(4); + a[2] = true; + a[3] = true; + b[1] = true; + b[3] = true; + b |= a; + simd_bits expected(4); + expected[1] = true; + expected[2] = true; + expected[3] = true; + ASSERT_EQ(b, expected); +} diff --git a/src/simd/simd_bits_range_ref.cc b/src/simd/simd_bits_range_ref.cc index 651ce0221..d03c1a2c0 100644 --- a/src/simd/simd_bits_range_ref.cc +++ b/src/simd/simd_bits_range_ref.cc @@ -30,6 +30,20 @@ simd_bits_range_ref simd_bits_range_ref::operator^=(const simd_bits_range_ref ot return *this; } +simd_bits_range_ref simd_bits_range_ref::operator|=(const simd_bits_range_ref other) { + for_each_word(other, [](simd_word &w0, simd_word &w1) { + w0 |= w1; + }); + return *this; +} + +simd_bits_range_ref simd_bits_range_ref::operator&=(const simd_bits_range_ref other) { + for_each_word(other, [](simd_word &w0, simd_word &w1) { + w0 &= w1; + }); + return *this; +} + simd_bits_range_ref simd_bits_range_ref::operator=(const simd_bits_range_ref other) { memcpy(ptr_simd, other.ptr_simd, num_u8_padded()); return *this; diff --git a/src/simd/simd_bits_range_ref.h b/src/simd/simd_bits_range_ref.h index 91a73f468..0585f6c9e 100644 --- a/src/simd/simd_bits_range_ref.h +++ b/src/simd/simd_bits_range_ref.h @@ -47,6 +47,9 @@ struct simd_bits_range_ref { other); // 
NOLINT(cppcoreguidelines-c-copy-assignment-signature,misc-unconventional-assign-operator) /// Xor assignment. simd_bits_range_ref operator^=(const simd_bits_range_ref other); + /// Mask assignment. + simd_bits_range_ref operator&=(const simd_bits_range_ref other); + simd_bits_range_ref operator|=(const simd_bits_range_ref other); /// Swap assignment. void swap_with(simd_bits_range_ref other); diff --git a/src/simd/simd_bits_range_ref.test.cc b/src/simd/simd_bits_range_ref.test.cc index a37742834..628d394c9 100644 --- a/src/simd/simd_bits_range_ref.test.cc +++ b/src/simd/simd_bits_range_ref.test.cc @@ -217,6 +217,8 @@ TEST(simd_bits_range_ref, for_each_set_bit) { ref[5] = true; ref[101] = true; std::vector hits; - ref.for_each_set_bit([&](size_t k) { hits.push_back(k); }); + ref.for_each_set_bit([&](size_t k) { + hits.push_back(k); + }); ASSERT_EQ(hits, (std::vector{5, 101})); } diff --git a/src/simd/sparse_xor_vec.h b/src/simd/sparse_xor_vec.h index c7b61a065..6d28f95cb 100644 --- a/src/simd/sparse_xor_vec.h +++ b/src/simd/sparse_xor_vec.h @@ -17,25 +17,28 @@ #ifndef SPARSE_XOR_TABLE_H #define SPARSE_XOR_TABLE_H +#include +#include #include +#include #include #include -#include "vector_view.h" +#include "monotonic_buffer.h" -/// Merge sorts the elements of two sorted buffers into an output buffer while cancelling out duplicate items. +/// Merges the elements of two sorted buffers into an output buffer while cancelling out duplicate items. /// -/// \param p1: Pointer to the first input buffer. -/// \param n1: Number of items in the first input buffer. -/// \param p2: Pointer to the second input buffer. -/// \param n2: Number of items in the second input buffer. -/// \param out: Pointer to the output buffer. The output buffer must have a size of at least n1+n2. -/// \return: The (exclusive) end pointer of the written part of the output buffer. +/// \param sorted_in1: Pointer range covering the first sorted list. 
+/// \param sorted_in2: Pointer range covering the second sorted list. +/// \param out: Where to write the output. Must have size of at least sorted_in1.size() + sorted_in2.size(). +/// \return: A pointer to the end of the output (one past the last place written). template -inline T *xor_merge_sorted_items_into(const T *p1, size_t n1, const T *p2, size_t n2, T *out) { +inline T *xor_merge_sort(ConstPointerRange sorted_in1, ConstPointerRange sorted_in2, T *out) { // Interleave sorted src and dst into a sorted work buffer. - auto *end1 = p1 + n1; - auto *end2 = p2 + n2; + const T *p1 = sorted_in1.ptr_start; + const T *p2 = sorted_in2.ptr_start; + const T *end1 = sorted_in1.ptr_end; + const T *end2 = sorted_in2.ptr_end; while (p1 != end1) { if (p2 == end2 || *p1 < *p2) { *out++ = *p1++; @@ -53,60 +56,63 @@ inline T *xor_merge_sorted_items_into(const T *p1, size_t n1, const T *p2, size_ return out; } -// HACK: this should be templated, but it's not in order to have compatibility with C++11. 
-static std::vector _shared_buf; - -template -inline void vector_tail_view_xor_in_place(VectorView &buf, const T *p2, size_t n2) { - size_t max = buf.size() + n2; - if (_shared_buf.size() < max) { - _shared_buf.resize(2 * max); - } - auto end = xor_merge_sorted_items_into(buf.begin(), buf.size(), p2, n2, _shared_buf.data()); - buf.length = end - _shared_buf.data(); - buf.vec_ptr->resize(buf.offset); - buf.vec_ptr->insert(buf.vec_ptr->end(), _shared_buf.data(), _shared_buf.data() + buf.length); -} - -template -inline void xor_into_vector_tail_view(VectorView &buf, const T *p1, size_t n1, const T *p2, size_t n2) { - buf.vec_ptr->resize(buf.offset + n1 + n2); - auto end = xor_merge_sorted_items_into(p1, n1, p2, n2, buf.begin()); - buf.length = end - buf.begin(); - buf.vec_ptr->resize(buf.offset + buf.length); +template +inline void xor_merge_sort_temp_buffer_callback( + ConstPointerRange sorted_items_1, ConstPointerRange sorted_items_2, CALLBACK handler) { + constexpr size_t STACK_SIZE = 64; + T data[STACK_SIZE]; + size_t max = sorted_items_1.size() + sorted_items_2.size(); + T *begin = max > STACK_SIZE ? new T[max] : &data[0]; + T *end = xor_merge_sort(sorted_items_1, sorted_items_2, begin); + handler(ConstPointerRange(begin, end)); + if (max > STACK_SIZE) { + delete[] begin; + } } /// A sparse set of integers that supports efficient xoring (computing the symmetric difference). template struct SparseXorVec { - private: - inline SparseXorVec xor_helper(const T *src_ptr, size_t src_size) const { - SparseXorVec result; - result.vec.resize(size() + src_size); - auto n = xor_merge_sorted_items_into(begin(), size(), src_ptr, src_size, result.begin()) - result.begin(); - result.vec.resize(n); - return result; + public: + // Sorted list of entries. 
+ std::vector sorted_items; + + SparseXorVec() = default; + SparseXorVec(std::vector &&vec) : sorted_items(std::move(vec)) { } - public: - std::vector vec; + void set_to_xor_merge_sort(ConstPointerRange sorted_items1, ConstPointerRange sorted_items2) { + sorted_items.resize(sorted_items1.size() + sorted_items2.size()); + auto written = xor_merge_sort(sorted_items, sorted_items1, sorted_items2); + sorted_items.resize(written.size()); + } - inline void inplace_xor_helper(const T *src_ptr, size_t src_size) { - VectorView view{&vec, 0, vec.size()}; - vector_tail_view_xor_in_place(view, src_ptr, src_size); + void xor_sorted_items(ConstPointerRange sorted) { + xor_merge_sort_temp_buffer_callback(range(), sorted, [&](ConstPointerRange result) { + sorted_items.clear(); + sorted_items.insert(sorted_items.end(), result.begin(), result.end()); + }); } - SparseXorVec &operator^=(const T &other) { - inplace_xor_helper(&other, 1); - return *this; + + void clear() { + sorted_items.clear(); } - SparseXorVec &operator^=(const SparseXorVec &other) { - inplace_xor_helper(other.begin(), other.size()); + void xor_item(const T &item) { + xor_sorted_items({&item, &item + 1}); + } + + SparseXorVec &operator^=(const SparseXorVec &other) { + xor_sorted_items(other.range()); return *this; } - SparseXorVec operator^(const SparseXorVec &other) const { - return xor_helper(other.begin(), other.size()); + SparseXorVec operator^(const SparseXorVec &other) const { + SparseXorVec result; + result.sorted_items.resize(size() + other.size()); + auto n = xor_merge_sort(range(), other.range(), result.begin()) - result.begin(); + result.sorted_items.resize(n); + return result; } SparseXorVec operator^(const T &other) const { @@ -114,35 +120,39 @@ struct SparseXorVec { } bool operator<(const SparseXorVec &other) const { - return view() < other.view(); + return range() < other.range(); } inline size_t size() const { - return vec.size(); + return sorted_items.size(); } inline T *begin() { - return 
vec.data(); + return sorted_items.data(); } inline T *end() { - return vec.data() + size(); + return sorted_items.data() + size(); } inline const T *begin() const { - return vec.data(); + return sorted_items.data(); } inline const T *end() const { - return vec.data() + size(); + return sorted_items.data() + size(); } bool operator==(const SparseXorVec &other) const { - return vec == other.vec; + return sorted_items == other.sorted_items; } bool operator!=(const SparseXorVec &other) const { - return vec != other.vec; + return sorted_items != other.sorted_items; + } + + ConstPointerRange range() const { + return {begin(), end()}; } std::string str() const { @@ -150,11 +160,6 @@ struct SparseXorVec { ss << *this; return ss.str(); } - - const VectorView view() const { - // Temporarily remove const correctness but then immediately restore it. - return VectorView{(std::vector *)&vec, 0, vec.size()}; - } }; template diff --git a/src/simd/sparse_xor_vec.perf.cc b/src/simd/sparse_xor_vec.perf.cc index 23f605d7e..0e952bca6 100644 --- a/src/simd/sparse_xor_vec.perf.cc +++ b/src/simd/sparse_xor_vec.perf.cc @@ -22,11 +22,11 @@ BENCHMARK(SparseXorTable_SmallRowXor_1000) { size_t n = 1000; std::vector> table(n); for (uint32_t k = 0; k < n; k++) { - table[k] ^= k; - table[k] ^= k + 1; - table[k] ^= k + 4; - table[k] ^= k + 8; - table[k] ^= k + 15; + table[k].xor_item(k); + table[k].xor_item(k + 1); + table[k].xor_item(k + 4); + table[k].xor_item(k + 8); + table[k].xor_item(k + 15); } benchmark_go([&]() { diff --git a/src/simd/sparse_xor_vec.test.cc b/src/simd/sparse_xor_vec.test.cc index 9455fd9b2..789f161a3 100644 --- a/src/simd/sparse_xor_vec.test.cc +++ b/src/simd/sparse_xor_vec.test.cc @@ -23,62 +23,62 @@ TEST(sparse_xor_table, inplace_xor) { SparseXorVec v1; SparseXorVec v2; - v1 ^= 1; - v1 ^= 3; - v2 ^= 2; - v2 ^= 3; - ASSERT_EQ(v1.vec, (std::vector{1, 3})); - ASSERT_EQ(v2.vec, (std::vector{2, 3})); + v1.xor_item(1); + v1.xor_item(3); + v2.xor_item(2); + v2.xor_item(3); + 
ASSERT_EQ(v1.sorted_items, (std::vector{1, 3})); + ASSERT_EQ(v2.sorted_items, (std::vector{2, 3})); v1 ^= v2; - ASSERT_EQ(v1.vec, (std::vector{1, 2})); - ASSERT_EQ(v2.vec, (std::vector{2, 3})); + ASSERT_EQ(v1.sorted_items, (std::vector{1, 2})); + ASSERT_EQ(v2.sorted_items, (std::vector{2, 3})); } TEST(sparse_xor_table, grow) { SparseXorVec v1; SparseXorVec v2; - v1 ^= 1; - v1 ^= 3; - v1 ^= 6; - v2 ^= 2; - v2 ^= 3; - v2 ^= 4; - v2 ^= 5; + v1.xor_item(1); + v1.xor_item(3); + v1.xor_item(6); + v2.xor_item(2); + v2.xor_item(3); + v2.xor_item(4); + v2.xor_item(5); v1 ^= v2; - ASSERT_EQ(v1.vec, (std::vector{1, 2, 4, 5, 6})); + ASSERT_EQ(v1.sorted_items, (std::vector{1, 2, 4, 5, 6})); } TEST(sparse_xor_table, historical_failure_case) { SparseXorVec v1; SparseXorVec v2; - v1 ^= 1; - v1 ^= 2; - v1 ^= 3; - v1 ^= 6; - v1 ^= 9; - v2 ^= 2; - v1 ^= 2; - ASSERT_EQ(v1.vec, (std::vector{1, 3, 6, 9})); - ASSERT_EQ(v2.vec, (std::vector{2})); + v1.xor_item(1); + v1.xor_item(2); + v1.xor_item(3); + v1.xor_item(6); + v1.xor_item(9); + v2.xor_item(2); + v1.xor_item(2); + ASSERT_EQ(v1.sorted_items, (std::vector{1, 3, 6, 9})); + ASSERT_EQ(v2.sorted_items, (std::vector{2})); } TEST(sparse_xor_table, comparison) { SparseXorVec v1; - v1 ^= 1; - v1 ^= 3; + v1.xor_item(1); + v1.xor_item(3); SparseXorVec v2; - v2 ^= 1; + v2.xor_item(1); ASSERT_TRUE(v1 != v2); ASSERT_TRUE(!(v1 == v2)); ASSERT_TRUE(v2 < v1); ASSERT_TRUE(!(v1 < v2)); - v2 ^= 4; + v2.xor_item(4); ASSERT_TRUE(v1 != v2); ASSERT_TRUE(!(v1 == v2)); ASSERT_TRUE(!(v2 < v1)); ASSERT_TRUE(v1 < v2); - v2 ^= 4; - v2 ^= 3; + v2.xor_item(4); + v2.xor_item(3); ASSERT_TRUE(v1 == v2); ASSERT_TRUE(!(v1 != v2)); ASSERT_TRUE(!(v2 < v1)); @@ -88,10 +88,10 @@ TEST(sparse_xor_table, comparison) { TEST(sparse_xor_table, str) { SparseXorVec v; ASSERT_EQ(v.str(), "SparseXorVec{}"); - v ^= 5; + v.xor_item(5); ASSERT_EQ(v.str(), "SparseXorVec{5}"); - v ^= 2; + v.xor_item(2); ASSERT_EQ(v.str(), "SparseXorVec{2, 5}"); - v ^= 5000; + v.xor_item(5000); 
ASSERT_EQ(v.str(), "SparseXorVec{2, 5, 5000}"); } diff --git a/src/simd/vector_view.h b/src/simd/vector_view.h deleted file mode 100644 index 8b52ae811..000000000 --- a/src/simd/vector_view.h +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Copyright 2021 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef VECTOR_VIEW_H -#define VECTOR_VIEW_H - -#include -#include - -/// A pointer to contiguous data inside a std::vector. -/// -/// This is used to avoid copies and allocations in contexts where there is a lot of jagged data. The data is aggregated -/// into one vector, and then referenced using views into that vector. For example, it is used by Circuit to store -/// operation target data. -/// -/// A notable distinction between a vector view and a raw pointer with size information is that the vector view does not -/// become invalid when appending to the underlying vector (e.g. due to it reallocating its backing buffer). 
-template -struct VectorView { - std::vector *vec_ptr; - size_t offset; - size_t length; - - inline size_t size() const { - return length; - } - - inline T operator[](size_t k) const { - return (*vec_ptr)[offset + k]; - } - - inline T &operator[](size_t k) { - return (*vec_ptr)[offset + k]; - } - - T *begin() { - return vec_ptr->data() + offset; - } - - T *end() { - return vec_ptr->data() + offset + length; - } - - const T *begin() const { - return vec_ptr->data() + offset; - } - - const T *end() const { - return vec_ptr->data() + offset + length; - } - - bool operator==(const VectorView &other) const { - if (length != other.length) { - return false; - } - for (size_t k = 0; k < length; k++) { - if ((*this)[k] != other[k]) { - return false; - } - } - return true; - } - - bool operator!=(const VectorView &other) const { - return !(*this == other); - } - - /// Lexicographic ordering. - bool operator<(const VectorView &other) const { - auto n = std::min(size(), other.size()); - for (size_t k = 0; k < n; k++) { - if ((*this)[k] != other[k]) { - return (*this)[k] < other[k]; - } - } - return size() < other.size(); - } - - std::string str() const { - std::stringstream ss; - ss << *this; - return ss.str(); - } -}; - -template -std::ostream &operator<<(std::ostream &out, const VectorView &v) { - out << "VectorView{"; - bool first = true; - for (auto &e : v) { - if (!first) { - out << ", "; - } - first = false; - out << e; - } - out << "}"; - return out; -} - -template -struct JaggedDataArena { - std::vector vec; - JaggedDataArena() : vec() { - } - JaggedDataArena(const JaggedDataArena &) = delete; - JaggedDataArena(JaggedDataArena &&) noexcept = delete; - JaggedDataArena &operator=(JaggedDataArena &&) noexcept = delete; - JaggedDataArena &operator=(const JaggedDataArena &) = delete; - - VectorView view(size_t start, size_t size) { - return {&vec, start, size}; - } - VectorView tail_view(size_t start) { - return view(start, vec.size() - start); - } - VectorView 
inserted(const T *data, size_t size) { - size_t n = vec.size(); - vec.insert(vec.end(), data, data + size); - return tail_view(n); - } - VectorView inserted(const std::vector &items) { - return inserted(items.data(), items.size()); - } - VectorView inserted(const VectorView &items) { - if (items.vec_ptr == &vec) { - return items; - } - return inserted(items.begin(), items.size()); - } -}; - -#endif diff --git a/src/simulators/detection_simulator.cc b/src/simulators/detection_simulator.cc index 7690489a5..f961a5794 100644 --- a/src/simulators/detection_simulator.cc +++ b/src/simulators/detection_simulator.cc @@ -31,21 +31,21 @@ simd_bit_table detector_samples( auto num_detectors = det_obs.detectors.size(); auto num_obs = det_obs.observables.size(); - size_t num_results = num_detectors + num_obs * ((int)prepend_observables + (int)append_observables); + size_t num_results = num_detectors + num_obs * (prepend_observables + append_observables); simd_bit_table result(num_results, num_shots); // Xor together measurement samples to form detector samples. 
size_t offset = 0; if (prepend_observables) { - for (auto obs : det_obs.observables) { + for (const auto &obs : det_obs.observables) { xor_measurement_set_into_result(obs, frame_samples, result, offset); } } - for (auto det : det_obs.detectors) { + for (const auto &det : det_obs.detectors) { xor_measurement_set_into_result(det, frame_samples, result, offset); } if (append_observables) { - for (auto obs : det_obs.observables) { + for (const auto &obs : det_obs.observables) { xor_measurement_set_into_result(obs, frame_samples, result, offset); } } @@ -59,7 +59,56 @@ simd_bit_table detector_samples( circuit, DetectorsAndObservables(circuit), num_shots, prepend_observables, append_observables, rng); } -void detector_samples_out( +void detector_sample_out_helper_stream( + const Circuit &circuit, FrameSimulator &sim, size_t num_samples, bool append_observables, FILE *out, + SampleFormat format) { + MeasureRecordBatchWriter writer(out, num_samples, format); + std::vector observables; + sim.reset_all(); + writer.begin_result_type('D'); + simd_bit_table detector_buffer(1024, num_samples); + size_t buffered_detectors = 0; + circuit.for_each_operation([&](const Operation &op) { + if (op.gate->id == gate_name_to_id("DETECTOR")) { + simd_bits_range_ref result = detector_buffer[buffered_detectors]; + for (auto t : op.target_data.targets) { + assert(t & TARGET_RECORD_BIT); + result ^= sim.m_record.lookback(t ^ TARGET_RECORD_BIT); + } + buffered_detectors++; + if (buffered_detectors == 1024) { + writer.batch_write_bytes(detector_buffer, 1024 >> 6); + buffered_detectors = 0; + } + } else if (op.gate->id == gate_name_to_id("OBSERVABLE_INCLUDE")) { + if (append_observables) { + size_t id = (size_t)op.target_data.arg; + while (observables.size() <= id) { + observables.emplace_back(num_samples); + } + simd_bits_range_ref result = observables[id]; + + for (auto t : op.target_data.targets) { + assert(t & TARGET_RECORD_BIT); + result ^= sim.m_record.lookback(t ^ TARGET_RECORD_BIT); + } + 
} + } else { + (sim.*op.gate->frame_simulator_function)(op.target_data); + sim.m_record.mark_all_as_written(); + } + }); + for (size_t k = 0; k < buffered_detectors; k++) { + writer.batch_write_bit(detector_buffer[k]); + } + writer.begin_result_type('L'); + for (const auto &result : observables) { + writer.batch_write_bit(result); + } + writer.write_end(); +} + +void detector_samples_out_in_memory( const Circuit &circuit, size_t num_shots, bool prepend_observables, bool append_observables, FILE *out, SampleFormat format, std::mt19937_64 &rng) { if (prepend_observables && append_observables) { @@ -86,15 +135,38 @@ void detector_samples_out( ct = 0; } + auto table = detector_samples(circuit, det_obs, num_shots, prepend_observables, append_observables, rng); + write_table_data(out, num_shots, num_sample_locations, simd_bits(0), table, format, c1, c2, ct); +} + +void detector_sample_out_helper( + const Circuit &circuit, FrameSimulator &sim, size_t num_shots, bool prepend_observables, bool append_observables, + FILE *out, SampleFormat format, std::mt19937_64 &rng) { + uint64_t approx_mem_usage = std::max(num_shots, size_t{256}) * + std::max(circuit.count_measurements(), circuit.count_detectors_and_observables()); + if (!prepend_observables && approx_mem_usage > SWITCH_TO_STREAMING_MEASUREMENT_THRESHOLD) { + detector_sample_out_helper_stream(circuit, sim, num_shots, append_observables, out, format); + } else { + detector_samples_out_in_memory(circuit, num_shots, prepend_observables, append_observables, out, format, rng); + } +} + +void detector_samples_out( + const Circuit &circuit, size_t num_shots, bool prepend_observables, bool append_observables, FILE *out, + SampleFormat format, std::mt19937_64 &rng) { constexpr size_t GOOD_BLOCK_SIZE = 1024; - simd_bits reference_sample(num_sample_locations); - while (num_shots > GOOD_BLOCK_SIZE) { - auto table = detector_samples(circuit, det_obs, GOOD_BLOCK_SIZE, prepend_observables, append_observables, rng); - 
write_table_data(out, GOOD_BLOCK_SIZE, num_sample_locations, reference_sample, table, format, c1, c2, ct); - num_shots -= GOOD_BLOCK_SIZE; + size_t num_qubits = circuit.count_qubits(); + size_t max_lookback = circuit.max_lookback(); + if (num_shots >= GOOD_BLOCK_SIZE) { + auto sim = FrameSimulator(num_qubits, GOOD_BLOCK_SIZE, max_lookback, rng); + while (num_shots > GOOD_BLOCK_SIZE) { + detector_sample_out_helper( + circuit, sim, GOOD_BLOCK_SIZE, prepend_observables, append_observables, out, format, rng); + num_shots -= GOOD_BLOCK_SIZE; + } } if (num_shots) { - auto table = detector_samples(circuit, det_obs, num_shots, prepend_observables, append_observables, rng); - write_table_data(out, num_shots, num_sample_locations, reference_sample, table, format, c1, c2, ct); + auto sim = FrameSimulator(num_qubits, num_shots, max_lookback, rng); + detector_sample_out_helper(circuit, sim, num_shots, prepend_observables, append_observables, out, format, rng); } } diff --git a/src/simulators/error_fuser.cc b/src/simulators/error_fuser.cc index c1d7aebc8..202ae1d77 100644 --- a/src/simulators/error_fuser.cc +++ b/src/simulators/error_fuser.cc @@ -25,8 +25,8 @@ void ErrorFuser::R(const OperationData &dat) { for (size_t k = dat.targets.size(); k-- > 0;) { auto q = dat.targets[k]; - xs[q].vec.clear(); - zs[q].vec.clear(); + xs[q].clear(); + zs[q].clear(); } } @@ -34,11 +34,10 @@ void ErrorFuser::M(const OperationData &dat) { for (size_t k = dat.targets.size(); k-- > 0;) { auto q = dat.targets[k] & TARGET_VALUE_MASK; scheduled_measurement_time++; - auto view = jagged_detector_sets.inserted(measurement_to_detectors[scheduled_measurement_time]); - std::sort(view.begin(), view.end()); - zs[q].inplace_xor_helper(view.begin(), view.size()); - jagged_detector_sets.vec.resize(view.offset); - measurement_to_detectors.erase(scheduled_measurement_time); + + std::vector &d = measurement_to_detectors[scheduled_measurement_time]; + std::sort(d.begin(), d.end()); + zs[q].xor_sorted_items(d); } } 
@@ -50,7 +49,7 @@ void ErrorFuser::MR(const OperationData &dat) { void ErrorFuser::H_XZ(const OperationData &dat) { for (size_t k = dat.targets.size(); k-- > 0;) { auto q = dat.targets[k]; - std::swap(xs[q].vec, zs[q].vec); + std::swap(xs[q], zs[q]); } } @@ -141,17 +140,21 @@ void ErrorFuser::ZCX(const OperationData &dat) { void ErrorFuser::feedback(uint32_t record_control, size_t target, bool x, bool z) { uint32_t time = scheduled_measurement_time + (record_control & ~TARGET_RECORD_BIT); - auto &out_view = measurement_to_detectors[time]; - std::sort(out_view.begin(), out_view.end()); - VectorView view{&out_view, 0, out_view.size()}; + std::vector &dst = measurement_to_detectors[time]; + + // Temporarily move map's vector data into a SparseXorVec for manipulation. + std::sort(dst.begin(), dst.end()); + SparseXorVec tmp(std::move(dst)); + if (x) { - auto &x_vec = xs[target]; - vector_tail_view_xor_in_place(view, x_vec.begin(), x_vec.size()); + tmp ^= xs[target]; } if (z) { - auto &z_vec = zs[target]; - vector_tail_view_xor_in_place(view, z_vec.begin(), z_vec.size()); + tmp ^= zs[target]; } + + // Move data back into the map. 
+ dst = std::move(tmp.sorted_items); } void ErrorFuser::single_cx(uint32_t c, uint32_t t) { @@ -214,8 +217,8 @@ void ErrorFuser::SWAP(const OperationData &dat) { for (size_t k = dat.targets.size() - 2; k + 2 != 0; k -= 2) { auto a = dat.targets[k]; auto b = dat.targets[k + 1]; - std::swap(xs[a].vec, xs[b].vec); - std::swap(zs[a].vec, zs[b].vec); + std::swap(xs[a], xs[b]); + std::swap(zs[a], zs[b]); } } @@ -227,8 +230,8 @@ void ErrorFuser::ISWAP(const OperationData &dat) { zs[a] ^= xs[b]; zs[b] ^= xs[a]; zs[b] ^= xs[b]; - std::swap(xs[a].vec, xs[b].vec); - std::swap(zs[a].vec, zs[b].vec); + std::swap(xs[a], xs[b]); + std::swap(zs[a], zs[b]); } } @@ -254,49 +257,50 @@ ErrorFuser::ErrorFuser(size_t num_qubits) : xs(num_qubits), zs(num_qubits) { } void ErrorFuser::run_circuit(const Circuit &circuit) { - for (size_t k = circuit.operations.size(); k-- > 0;) { - const auto &op = circuit.operations[k]; - (this->*op.gate->hit_simulator_function)(op.target_data); - } - uint32_t detector_id_root = UINT32_MAX - num_found_detectors + 1; - for (auto &t : jagged_detector_sets.vec) { - if (t > (UINT32_MAX >> 2)) { - t -= detector_id_root; - t |= TARGET_PAULI_X_BIT; - } - } + circuit.for_each_operation_reverse([&](const Operation &op) { + (this->*op.gate->reverse_error_fuser_function)(op.target_data); + }); } void ErrorFuser::X_ERROR(const OperationData &dat) { for (auto q : dat.targets) { - independent_error_1(dat.arg, zs[q]); + add_error(dat.arg, zs[q]); } } void ErrorFuser::Y_ERROR(const OperationData &dat) { for (auto q : dat.targets) { - independent_error_2(dat.arg, xs[q], zs[q]); + add_xored_error(dat.arg, xs[q], zs[q]); } } void ErrorFuser::Z_ERROR(const OperationData &dat) { for (auto q : dat.targets) { - independent_error_1(dat.arg, xs[q]); + add_error(dat.arg, xs[q]); } } +template +inline void inplace_xor_tail(MonotonicBuffer &dst, const SparseXorVec &src) { + ConstPointerRange in1 = dst.tail; + ConstPointerRange in2 = src.range(); + 
xor_merge_sort_temp_buffer_callback(in1, in2, [&](ConstPointerRange result) { + dst.discard_tail(); + dst.append_tail(result); + }); +} + void ErrorFuser::CORRELATED_ERROR(const OperationData &dat) { - VectorView tail = jagged_detector_sets.tail_view(jagged_detector_sets.vec.size()); for (auto qp : dat.targets) { auto q = qp & TARGET_VALUE_MASK; if (qp & TARGET_PAULI_Z_BIT) { - vector_tail_view_xor_in_place(tail, xs[q].begin(), xs[q].size()); + inplace_xor_tail(jag_flip_data, xs[q]); } if (qp & TARGET_PAULI_X_BIT) { - vector_tail_view_xor_in_place(tail, zs[q].begin(), zs[q].size()); + inplace_xor_tail(jag_flip_data, zs[q]); } } - independent_error_placed_tail(dat.arg, tail); + add_error_in_sorted_jagged_tail(dat.arg); } void ErrorFuser::DEPOLARIZE1(const OperationData &dat) { @@ -306,9 +310,9 @@ void ErrorFuser::DEPOLARIZE1(const OperationData &dat) { } double p = 0.5 - 0.5 * sqrt(1 - (4 * dat.arg) / 3); for (auto q : dat.targets) { - independent_error_1(p, xs[q]); - independent_error_1(p, zs[q]); - independent_error_2(p, xs[q], zs[q]); + add_error(p, xs[q]); + add_error(p, zs[q]); + add_xored_error(p, xs[q], zs[q]); } } @@ -330,23 +334,23 @@ void ErrorFuser::DEPOLARIZE2(const OperationData &dat) { auto y2 = x2 ^ z2; // Isolated errors. - independent_error_1(p, x1); - independent_error_1(p, y1); - independent_error_1(p, z1); - independent_error_1(p, x2); - independent_error_1(p, y2); - independent_error_1(p, z2); + add_error(p, x1); + add_error(p, y1); + add_error(p, z1); + add_error(p, x2); + add_error(p, y2); + add_error(p, z2); // Paired errors. 
- independent_error_2(p, x1, x2); - independent_error_2(p, y1, x2); - independent_error_2(p, z1, x2); - independent_error_2(p, x1, y2); - independent_error_2(p, y1, y2); - independent_error_2(p, z1, y2); - independent_error_2(p, x1, z2); - independent_error_2(p, y1, z2); - independent_error_2(p, z1, z2); + add_xored_error(p, x1, x2); + add_xored_error(p, y1, x2); + add_xored_error(p, z1, x2); + add_xored_error(p, x1, y2); + add_xored_error(p, y1, y2); + add_xored_error(p, z1, y2); + add_xored_error(p, x1, z2); + add_xored_error(p, y1, z2); + add_xored_error(p, z1, z2); } } @@ -356,17 +360,18 @@ void ErrorFuser::ELSE_CORRELATED_ERROR(const OperationData &dat) { } void ErrorFuser::convert_circuit_out(const Circuit &circuit, FILE *out) { - ErrorFuser fuser(circuit.num_qubits); + ErrorFuser fuser(circuit.count_qubits()); fuser.run_circuit(circuit); std::stringstream ss; + uint32_t detector_id_root = UINT32_MAX - fuser.num_found_detectors + 1; for (const auto &kv : fuser.error_class_probabilities) { ss.str(""); ss << std::setprecision(std::numeric_limits::digits10 + 1) << kv.second; fprintf(out, "error(%s)", ss.str().data()); for (auto e : kv.first) { - if (e & TARGET_PAULI_X_BIT) { - fprintf(out, " D%lld", (long long)(e - TARGET_PAULI_X_BIT)); + if (e > (UINT32_MAX >> 2)) { + fprintf(out, " D%lld", (long long)(e - detector_id_root)); } else { fprintf(out, " L%lld", (long long)e); } @@ -375,29 +380,29 @@ void ErrorFuser::convert_circuit_out(const Circuit &circuit, FILE *out) { } } -void ErrorFuser::independent_error_1(double probability, const SparseXorVec &d1) { - independent_error_1(probability, d1.begin(), d1.size()); +void ErrorFuser::add_error(double probability, const SparseXorVec &flipped) { + jag_flip_data.append_tail(flipped.range()); + add_error_in_sorted_jagged_tail(probability); } -void ErrorFuser::independent_error_1(double probability, const uint32_t *begin, size_t size) { - independent_error_placed_tail(probability, jagged_detector_sets.inserted(begin, 
size)); +void ErrorFuser::add_xored_error( + double probability, const SparseXorVec &flipped1, const SparseXorVec &flipped2) { + jag_flip_data.ensure_available(flipped1.size() + flipped2.size()); + jag_flip_data.tail.ptr_end = + xor_merge_sort(flipped1.range(), flipped2.range(), jag_flip_data.tail.ptr_end); + add_error_in_sorted_jagged_tail(probability); } -void ErrorFuser::independent_error_placed_tail(double probability, VectorView detector_set) { - if (detector_set.size()) { - if (error_class_probabilities.find(detector_set) != error_class_probabilities.end()) { - auto &p = error_class_probabilities[detector_set]; +void ErrorFuser::add_error_in_sorted_jagged_tail(double probability) { + auto flipped = jag_flip_data.tail; + if (flipped.size()) { + if (error_class_probabilities.find(flipped) != error_class_probabilities.end()) { + auto &p = error_class_probabilities[flipped]; p = p * (1 - probability) + (1 - p) * probability; - jagged_detector_sets.vec.resize(jagged_detector_sets.vec.size() - detector_set.size()); + jag_flip_data.discard_tail(); } else { - error_class_probabilities[detector_set] = probability; + error_class_probabilities[flipped] = probability; + jag_flip_data.commit_tail(); } } } - -void ErrorFuser::independent_error_2( - double probability, const SparseXorVec &d1, const SparseXorVec &d2) { - auto view = jagged_detector_sets.tail_view(jagged_detector_sets.vec.size()); - xor_into_vector_tail_view(view, d1.begin(), d1.size(), d2.begin(), d2.size()); - independent_error_placed_tail(probability, view); -} diff --git a/src/simulators/error_fuser.h b/src/simulators/error_fuser.h index 033dc8760..059764b13 100644 --- a/src/simulators/error_fuser.h +++ b/src/simulators/error_fuser.h @@ -23,6 +23,7 @@ #include #include "../circuit/circuit.h" +#include "../simd/monotonic_buffer.h" #include "../simd/sparse_xor_vec.h" struct ErrorFuser { @@ -36,10 +37,9 @@ struct ErrorFuser { size_t scheduled_measurement_time = 0; /// The final result. 
Independent probabilities of flipping various sets of detectors. - /// - /// The backing data for the vector views is in the `jagged_data` field. - std::map, double> error_class_probabilities; - JaggedDataArena jagged_detector_sets; + std::map, double> error_class_probabilities; + /// Backing datastore for values in error_class_probabilities. + MonotonicBuffer jag_flip_data; ErrorFuser(size_t num_qubits); @@ -81,10 +81,10 @@ struct ErrorFuser { void ISWAP(const OperationData &dat); private: - void independent_error_1(double probability, const SparseXorVec &detector_set); - void independent_error_1(double probability, const uint32_t *begin, size_t size); - void independent_error_2(double probability, const SparseXorVec &d1, const SparseXorVec &d2); - void independent_error_placed_tail(double probability, VectorView detector_set); + void add_error(double probability, const SparseXorVec &data); + void add_xored_error( + double probability, const SparseXorVec &flipped1, const SparseXorVec &flipped2); + void add_error_in_sorted_jagged_tail(double probability); void single_cx(uint32_t c, uint32_t t); void single_cy(uint32_t c, uint32_t t); void single_cz(uint32_t c, uint32_t t); diff --git a/src/simulators/error_fuser.test.cc b/src/simulators/error_fuser.test.cc index a798aa0e2..e37f08476 100644 --- a/src/simulators/error_fuser.test.cc +++ b/src/simulators/error_fuser.test.cc @@ -26,7 +26,7 @@ std::string convert(const char *text) { ErrorFuser::convert_circuit_out(Circuit::from_text(text), f); rewind(f); std::string s; - while(true) { + while (true) { int c = getc(f); if (c == EOF) { break; @@ -42,7 +42,8 @@ TEST(ErrorFuser, convert_circuit) { X_ERROR(0.25) 3 M 3 DETECTOR rec[-1] - )circuit"), R"graph(error(0.25) D0 + )circuit"), + R"graph(error(0.25) D0 )graph"); ASSERT_EQ( @@ -50,7 +51,8 @@ TEST(ErrorFuser, convert_circuit) { Y_ERROR(0.25) 3 M 3 DETECTOR rec[-1] - )circuit"), R"graph(error(0.25) D0 + )circuit"), + R"graph(error(0.25) D0 )graph"); ASSERT_EQ( @@ -58,14 
+60,16 @@ TEST(ErrorFuser, convert_circuit) { Z_ERROR(0.25) 3 M 3 DETECTOR rec[-1] - )circuit"), R"graph()graph"); + )circuit"), + R"graph()graph"); ASSERT_EQ( convert(R"circuit( DEPOLARIZE1(0.25) 3 M 3 DETECTOR rec[-1] - )circuit"), R"graph(error(0.1666666666666666574) D0 + )circuit"), + R"graph(error(0.1666666666666666574) D0 )graph"); ASSERT_EQ( @@ -75,17 +79,20 @@ TEST(ErrorFuser, convert_circuit) { M 0 1 OBSERVABLE_INCLUDE(3) rec[-1] DETECTOR rec[-2] - )circuit"), R"graph(error(0.125) L3 + )circuit"), + R"graph(error(0.125) L3 error(0.25) D0 )graph"); - ASSERT_EQ(convert(R"circuit( + ASSERT_EQ( + convert(R"circuit( X_ERROR(0.25) 0 X_ERROR(0.125) 1 M 0 1 OBSERVABLE_INCLUDE(3) rec[-1] DETECTOR rec[-2] - )circuit"), R"graph(error(0.125) L3 + )circuit"), + R"graph(error(0.125) L3 error(0.25) D0 )graph"); @@ -96,7 +103,8 @@ error(0.25) D0 M 5 DETECTOR rec[-1] DETECTOR rec[-2] - )circuit"), R"graph(error(0.07182558071116235121) D0 + )circuit"), + R"graph(error(0.07182558071116235121) D0 error(0.07182558071116235121) D0 D1 error(0.07182558071116235121) D1 )graph"); @@ -113,7 +121,8 @@ error(0.07182558071116235121) D1 DETECTOR rec[-2] DETECTOR rec[-3] DETECTOR rec[-4] - )circuit"), R"graph(error(0.01901372644820353841) D0 + )circuit"), + R"graph(error(0.01901372644820353841) D0 error(0.01901372644820353841) D0 D1 error(0.01901372644820353841) D0 D1 D2 error(0.01901372644820353841) D0 D1 D2 D3 @@ -132,32 +141,32 @@ error(0.01901372644820353841) D3 } TEST(ErrorFuser, unitary_gates_match_frame_simulator) { - FrameSimulator f(16, 16, 0, SHARED_TEST_RNG()); + FrameSimulator f(16, 16, SIZE_MAX, SHARED_TEST_RNG()); ErrorFuser e(16); for (size_t q = 0; q < 16; q++) { if (q & 1) { - e.xs[q] ^= 0; + e.xs[q].xor_item(0); f.x_table[q][0] = true; } if (q & 2) { - e.xs[q] ^= 1; + e.xs[q].xor_item(1); f.x_table[q][1] = true; } if (q & 4) { - e.zs[q] ^= 0; + e.zs[q].xor_item(0); f.z_table[q][0] = true; } if (q & 8) { - e.zs[q] ^= 1; + e.zs[q].xor_item(1); f.z_table[q][1] = true; } } 
std::vector data{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - OperationData targets = {0, {&data, 0, data.size()}}; + OperationData targets = {0, data}; for (const auto &gate : GATE_DATA.gates()) { if (gate.flags & GATE_IS_UNITARY) { - (e.*gate.hit_simulator_function)(targets); + (e.*gate.reverse_error_fuser_function)(targets); (f.*gate.frame_simulator_function)(targets); for (size_t q = 0; q < 16; q++) { bool xs[2]{}; @@ -188,7 +197,8 @@ TEST(ErrorFuser, reversed_operation_order) { M 0 1 DETECTOR rec[-2] DETECTOR rec[-1] - )circuit"), R"graph(error(0.25) D1 + )circuit"), + R"graph(error(0.25) D1 )graph"); } @@ -200,7 +210,8 @@ TEST(ErrorFuser, classical_error_propagation) { CNOT rec[-1] 1 M 1 DETECTOR rec[-1] - )circuit"), R"graph(error(0.125) D0 + )circuit"), + R"graph(error(0.125) D0 )graph"); ASSERT_EQ( @@ -212,7 +223,8 @@ TEST(ErrorFuser, classical_error_propagation) { H 1 M 1 DETECTOR rec[-1] - )circuit"), R"graph(error(0.125) D0 + )circuit"), + R"graph(error(0.125) D0 )graph"); ASSERT_EQ( @@ -224,7 +236,8 @@ TEST(ErrorFuser, classical_error_propagation) { H 1 M 1 DETECTOR rec[-1] - )circuit"), R"graph(error(0.125) D0 + )circuit"), + R"graph(error(0.125) D0 )graph"); ASSERT_EQ( @@ -234,7 +247,8 @@ TEST(ErrorFuser, classical_error_propagation) { CY rec[-1] 1 M 1 DETECTOR rec[-1] - )circuit"), R"graph(error(0.125) D0 + )circuit"), + R"graph(error(0.125) D0 )graph"); ASSERT_EQ( @@ -244,7 +258,8 @@ TEST(ErrorFuser, classical_error_propagation) { XCZ 1 rec[-1] M 1 DETECTOR rec[-1] - )circuit"), R"graph(error(0.125) D0 + )circuit"), + R"graph(error(0.125) D0 )graph"); ASSERT_EQ( @@ -254,6 +269,7 @@ TEST(ErrorFuser, classical_error_propagation) { YCZ 1 rec[-1] M 1 DETECTOR rec[-1] - )circuit"), R"graph(error(0.125) D0 + )circuit"), + R"graph(error(0.125) D0 )graph"); } diff --git a/src/simulators/frame_simulator.cc b/src/simulators/frame_simulator.cc index 0bd74e57f..e5b3f2a72 100644 --- a/src/simulators/frame_simulator.cc +++ 
b/src/simulators/frame_simulator.cc @@ -14,6 +14,7 @@ #include "frame_simulator.h" +#include #include #include "../circuit/gate_data.h" @@ -35,171 +36,42 @@ inline void for_each_target_pair(FrameSimulator &sim, const OperationData &targe } } -FrameSimulator::FrameSimulator(size_t num_qubits, size_t num_samples, size_t num_measurements, std::mt19937_64 &rng) +FrameSimulator::FrameSimulator(size_t num_qubits, size_t batch_size, size_t max_lookback, std::mt19937_64 &rng) : num_qubits(num_qubits), - num_samples_raw(num_samples), - num_measurements_raw(num_measurements), + batch_size(batch_size), num_recorded_measurements(0), - x_table(num_qubits, num_samples), - z_table(num_qubits, num_samples), - m_table(num_measurements, num_samples), - rng_buffer(num_samples), - last_correlated_error_occurred(num_samples), + x_table(num_qubits, batch_size), + z_table(num_qubits, batch_size), + m_record(batch_size, max_lookback), + rng_buffer(batch_size), + last_correlated_error_occurred(batch_size), rng(rng) { } -simd_bit_table transposed_vs_ref( - size_t num_samples_raw, const simd_bit_table &table, const simd_bits &reference_sample) { - auto result = table.transposed(); - for (size_t s = 0; s < num_samples_raw; s++) { - result[s] ^= reference_sample; - } - return result; -} - -void write_table_data( - FILE *out, - size_t num_shots_raw, - size_t num_sample_locations_raw, - const simd_bits &reference_sample, - const simd_bit_table &table, - SampleFormat format, - char dets_prefix_1, - char dets_prefix_2, - size_t dets_prefix_transition) { - if (format == SAMPLE_FORMAT_01) { - auto result = transposed_vs_ref(num_shots_raw, table, reference_sample); - for (size_t s = 0; s < num_shots_raw; s++) { - for (size_t k = 0; k < num_sample_locations_raw; k++) { - putc('0' + result[s][k], out); - } - putc('\n', out); - } - return; - } - - if (format == SAMPLE_FORMAT_B8) { - auto result = transposed_vs_ref(num_shots_raw, table, reference_sample); - auto n = (num_sample_locations_raw + 7) >> 3; - 
for (size_t s = 0; s < num_shots_raw; s++) { - fwrite(result[s].u8, 1, n, out); - } - return; - } - - if (format == SAMPLE_FORMAT_PTB64) { - auto f64 = num_shots_raw >> 6; - for (size_t s = 0; s < f64; s++) { - for (size_t m = 0; m < num_sample_locations_raw; m++) { - uint64_t v = table[m].u64[s]; - if (reference_sample[m]) { - v = ~v; - } - fwrite(&v, 1, 64 >> 3, out); - } - } - if (num_shots_raw & 63) { - uint64_t mask = (uint64_t{1} << (num_shots_raw & 63)) - 1ULL; - for (size_t m = 0; m < num_sample_locations_raw; m++) { - uint64_t v = table[m].u64[f64]; - if (reference_sample[m]) { - v = ~v; - } - v &= mask; - fwrite(&v, 1, 64 >> 3, out); - } - } - return; - } - - if (format == SAMPLE_FORMAT_HITS) { - auto result = transposed_vs_ref(num_shots_raw, table, reference_sample); - for (size_t s = 0; s < num_shots_raw; s++) { - bool rest = false; - result[s].for_each_set_bit([&](size_t k) { - if (rest) { - putc(',', out); - } - rest = true; - fprintf(out, "%lld", (unsigned long long)(k)); - }); - putc('\n', out); - } - return; - } - - if (format == SAMPLE_FORMAT_DETS) { - auto result = transposed_vs_ref(num_shots_raw, table, reference_sample); - for (size_t s = 0; s < num_shots_raw; s++) { - fprintf(out, "shot"); - result[s].for_each_set_bit([&](size_t k) { - if (k < dets_prefix_transition) { - fprintf(out, " %c%lld", dets_prefix_1, (unsigned long long)k); - } else { - fprintf(out, " %c%lld", dets_prefix_2, (unsigned long long)k - dets_prefix_transition); - } - }); - putc('\n', out); - } - return; - } - - if (format == SAMPLE_FORMAT_R8) { - auto result = transposed_vs_ref(num_shots_raw, table, reference_sample); - for (size_t s = 0; s < num_shots_raw; s++) { - size_t prev = 0; - auto write_gap = [&](size_t k) { - size_t gap = k - prev; - while (gap >= 0xFF) { - gap -= 0xFF; - putc(0xFF, out); - } - putc((char)gap, out); - prev = k + 1; - }; - result[s].for_each_set_bit(write_gap); - - // Always encode a trailing 1 just past the end of the measurement results. 
- write_gap(num_sample_locations_raw); - } - return; - } - - throw std::out_of_range("Unrecognized output format."); -} - -void FrameSimulator::write_measurements(FILE *out, const simd_bits &reference_sample, SampleFormat format) const { - write_table_data(out, num_samples_raw, num_measurements_raw, reference_sample, m_table, format, 'M', 'M', 0); -} - simd_bits_range_ref FrameSimulator::measurement_record_ref(uint32_t encoded_target) { - uint8_t b = encoded_target ^ TARGET_RECORD_BIT; - if (b == 0 || b > num_recorded_measurements) { - throw std::out_of_range("Referred to a measurement record before the beginning of time."); - } - return m_table[num_recorded_measurements - b]; + assert(encoded_target & TARGET_RECORD_BIT); + return m_record.lookback(encoded_target ^ TARGET_RECORD_BIT); } void FrameSimulator::reset_all() { num_recorded_measurements = 0; x_table.clear(); z_table.data.randomize(z_table.data.num_bits_padded(), rng); - m_table.clear(); + m_record.clear(); } void FrameSimulator::reset_all_and_run(const Circuit &circuit) { - assert(circuit.num_measurements == num_measurements_raw); reset_all(); - for (const auto &op : circuit.operations) { + circuit.for_each_operation([&](const Operation &op) { (this->*op.gate->frame_simulator_function)(op.target_data); - } + }); } void FrameSimulator::measure(const OperationData &target_data) { for (auto q : target_data.targets) { q &= TARGET_VALUE_MASK; // Flipping is ignored because it is accounted for in the reference sample. z_table[q].randomize(z_table[q].num_bits_padded(), rng); - m_table[num_recorded_measurements] = x_table[q]; + m_record.record_result(x_table[q]); num_recorded_measurements++; } } @@ -215,7 +87,7 @@ void FrameSimulator::measure_reset(const OperationData &target_data) { // Note: Caution when implementing this. Can't group the resets. because the same qubit target may appear twice. 
for (auto q : target_data.targets) { q &= TARGET_VALUE_MASK; // Flipping is ignored because it is accounted for in the reference sample. - m_table[num_recorded_measurements] = x_table[q]; + m_record.record_result(x_table[q]); x_table[q].clear(); z_table[q].randomize(z_table[q].num_bits_padded(), rng); num_recorded_measurements++; @@ -226,7 +98,7 @@ void FrameSimulator::I(const OperationData &target_data) { } PauliString FrameSimulator::get_frame(size_t sample_index) const { - assert(sample_index < num_samples_raw); + assert(sample_index < batch_size); PauliString result(num_qubits); for (size_t q = 0; q < num_qubits; q++) { result.xs[q] = x_table[q][sample_index]; @@ -236,7 +108,7 @@ PauliString FrameSimulator::get_frame(size_t sample_index) const { } void FrameSimulator::set_frame(size_t sample_index, const PauliStringRef &new_frame) { - assert(sample_index < num_samples_raw); + assert(sample_index < batch_size); assert(new_frame.num_qubits == num_qubits); for (size_t q = 0; q < num_qubits; q++) { x_table[q][sample_index] = new_frame.xs[q]; @@ -409,10 +281,10 @@ void FrameSimulator::YCZ(const OperationData &target_data) { void FrameSimulator::DEPOLARIZE1(const OperationData &target_data) { const auto &targets = target_data.targets; - RareErrorIterator::for_samples(target_data.arg, targets.size() * num_samples_raw, rng, [&](size_t s) { + RareErrorIterator::for_samples(target_data.arg, targets.size() * batch_size, rng, [&](size_t s) { auto p = 1 + (rng() % 3); - auto target_index = s / num_samples_raw; - auto sample_index = s % num_samples_raw; + auto target_index = s / batch_size; + auto sample_index = s % batch_size; auto t = targets[target_index]; x_table[t][sample_index] ^= p & 1; z_table[t][sample_index] ^= p & 2; @@ -422,11 +294,11 @@ void FrameSimulator::DEPOLARIZE1(const OperationData &target_data) { void FrameSimulator::DEPOLARIZE2(const OperationData &target_data) { const auto &targets = target_data.targets; assert(!(targets.size() & 1)); - auto n = 
(targets.size() * num_samples_raw) >> 1; + auto n = (targets.size() * batch_size) >> 1; RareErrorIterator::for_samples(target_data.arg, n, rng, [&](size_t s) { auto p = 1 + (rng() % 15); - auto target_index = (s / num_samples_raw) << 1; - auto sample_index = s % num_samples_raw; + auto target_index = (s / batch_size) << 1; + auto sample_index = s % batch_size; size_t t1 = targets[target_index]; size_t t2 = targets[target_index + 1]; x_table[t1][sample_index] ^= p & 1; @@ -438,9 +310,9 @@ void FrameSimulator::DEPOLARIZE2(const OperationData &target_data) { void FrameSimulator::X_ERROR(const OperationData &target_data) { const auto &targets = target_data.targets; - RareErrorIterator::for_samples(target_data.arg, targets.size() * num_samples_raw, rng, [&](size_t s) { - auto target_index = s / num_samples_raw; - auto sample_index = s % num_samples_raw; + RareErrorIterator::for_samples(target_data.arg, targets.size() * batch_size, rng, [&](size_t s) { + auto target_index = s / batch_size; + auto sample_index = s % batch_size; auto t = targets[target_index]; x_table[t][sample_index] ^= true; }); @@ -448,9 +320,9 @@ void FrameSimulator::X_ERROR(const OperationData &target_data) { void FrameSimulator::Y_ERROR(const OperationData &target_data) { const auto &targets = target_data.targets; - RareErrorIterator::for_samples(target_data.arg, targets.size() * num_samples_raw, rng, [&](size_t s) { - auto target_index = s / num_samples_raw; - auto sample_index = s % num_samples_raw; + RareErrorIterator::for_samples(target_data.arg, targets.size() * batch_size, rng, [&](size_t s) { + auto target_index = s / batch_size; + auto sample_index = s % batch_size; auto t = targets[target_index]; x_table[t][sample_index] ^= true; z_table[t][sample_index] ^= true; @@ -459,9 +331,9 @@ void FrameSimulator::Y_ERROR(const OperationData &target_data) { void FrameSimulator::Z_ERROR(const OperationData &target_data) { const auto &targets = target_data.targets; - 
RareErrorIterator::for_samples(target_data.arg, targets.size() * num_samples_raw, rng, [&](size_t s) { - auto target_index = s / num_samples_raw; - auto sample_index = s % num_samples_raw; + RareErrorIterator::for_samples(target_data.arg, targets.size() * batch_size, rng, [&](size_t s) { + auto target_index = s / batch_size; + auto sample_index = s % batch_size; auto t = targets[target_index]; z_table[t][sample_index] ^= true; }); @@ -469,9 +341,9 @@ void FrameSimulator::Z_ERROR(const OperationData &target_data) { simd_bit_table FrameSimulator::sample_flipped_measurements( const Circuit &circuit, size_t num_samples, std::mt19937_64 &rng) { - FrameSimulator sim(circuit.num_qubits, num_samples, circuit.num_measurements, rng); + FrameSimulator sim(circuit.count_qubits(), num_samples, SIZE_MAX, rng); sim.reset_all_and_run(circuit); - return sim.m_table; + return sim.m_record.storage; } simd_bit_table FrameSimulator::sample( @@ -487,9 +359,9 @@ void FrameSimulator::CORRELATED_ERROR(const OperationData &target_data) { void FrameSimulator::ELSE_CORRELATED_ERROR(const OperationData &target_data) { // Sample error locations. - biased_randomize_bits(target_data.arg, rng_buffer.u64, rng_buffer.u64 + ((num_samples_raw + 63) >> 6), rng); - if (num_samples_raw & 63) { - rng_buffer.u64[num_samples_raw >> 6] &= (uint64_t{1} << (num_samples_raw & 63)) - 1; + biased_randomize_bits(target_data.arg, rng_buffer.u64, rng_buffer.u64 + ((batch_size + 63) >> 6), rng); + if (batch_size & 63) { + rng_buffer.u64[batch_size >> 6] &= (uint64_t{1} << (batch_size & 63)) - 1; } // Omit locations blocked by prev error, while updating prev error mask. 
simd_bits_range_ref{rng_buffer}.for_each_word(last_correlated_error_occurred, [](simd_word &buf, simd_word &prev) { @@ -512,21 +384,44 @@ void FrameSimulator::ELSE_CORRELATED_ERROR(const OperationData &target_data) { } } +void sample_out_helper( + const Circuit &circuit, FrameSimulator &sim, simd_bits_range_ref ref_sample, size_t num_shots, FILE *out, + SampleFormat format) { + sim.reset_all(); + + if (std::max(num_shots, size_t{256}) * circuit.count_measurements() > SWITCH_TO_STREAMING_MEASUREMENT_THRESHOLD) { + // Results getting quite large. Stream them (with buffering to disk) instead of trying to store them all. + MeasureRecordBatchWriter writer(out, num_shots, format); + circuit.for_each_operation([&](const Operation &op) { + (sim.*op.gate->frame_simulator_function)(op.target_data); + sim.m_record.intermediate_write_unwritten_results_to(writer, ref_sample); + }); + sim.m_record.final_write_unwritten_results_to(writer, ref_sample); + } else { + // Small case. Just do everything in memory. 
+ circuit.for_each_operation([&](const Operation &op) { + (sim.*op.gate->frame_simulator_function)(op.target_data); + }); + write_table_data( + out, num_shots, circuit.count_measurements(), ref_sample, sim.m_record.storage, format, 'M', 'M', 0); + } +} + void FrameSimulator::sample_out( - const Circuit &circuit, const simd_bits &reference_sample, size_t num_samples, FILE *out, SampleFormat format, + const Circuit &circuit, const simd_bits &reference_sample, size_t num_shots, FILE *out, SampleFormat format, std::mt19937_64 &rng) { constexpr size_t GOOD_BLOCK_SIZE = 1024; - if (num_samples >= GOOD_BLOCK_SIZE) { - auto sim = FrameSimulator(circuit.num_qubits, GOOD_BLOCK_SIZE, circuit.num_measurements, rng); - while (num_samples > GOOD_BLOCK_SIZE) { - sim.reset_all_and_run(circuit); - sim.write_measurements(out, reference_sample, format); - num_samples -= GOOD_BLOCK_SIZE; + size_t num_qubits = circuit.count_qubits(); + size_t max_lookback = circuit.max_lookback(); + if (num_shots >= GOOD_BLOCK_SIZE) { + auto sim = FrameSimulator(num_qubits, GOOD_BLOCK_SIZE, max_lookback, rng); + while (num_shots > GOOD_BLOCK_SIZE) { + sample_out_helper(circuit, sim, reference_sample, GOOD_BLOCK_SIZE, out, format); + num_shots -= GOOD_BLOCK_SIZE; } } - if (num_samples) { - auto sim = FrameSimulator(circuit.num_qubits, num_samples, circuit.num_measurements, rng); - sim.reset_all_and_run(circuit); - sim.write_measurements(out, reference_sample, format); + if (num_shots) { + auto sim = FrameSimulator(num_qubits, num_shots, max_lookback, rng); + sample_out_helper(circuit, sim, reference_sample, num_shots, out, format); } } diff --git a/src/simulators/frame_simulator.h b/src/simulators/frame_simulator.h index 39ac9cdec..7316df5e9 100644 --- a/src/simulators/frame_simulator.h +++ b/src/simulators/frame_simulator.h @@ -17,11 +17,14 @@ #ifndef SIM_FRAME_H #define SIM_FRAME_H +#define SWITCH_TO_STREAMING_MEASUREMENT_THRESHOLD 100000000 + #include #include "../circuit/circuit.h" #include 
"../simd/simd_bit_table.h" #include "../stabilizers/pauli_string.h" +#include "measure_record_batch.h" /// A Pauli Frame simulator that computes many samples simultaneously. /// @@ -30,23 +33,22 @@ /// This requires a set of reference measurements to diff against. struct FrameSimulator { size_t num_qubits; - size_t num_samples_raw; - size_t num_measurements_raw; + size_t batch_size; size_t num_recorded_measurements; simd_bit_table x_table; simd_bit_table z_table; - simd_bit_table m_table; + MeasureRecordBatch m_record; simd_bits rng_buffer; simd_bits last_correlated_error_occurred; std::mt19937_64 &rng; - FrameSimulator(size_t num_qubits, size_t num_samples, size_t num_measurements, std::mt19937_64 &rng); + FrameSimulator(size_t num_qubits, size_t batch_size, size_t max_lookback, std::mt19937_64 &rng); static simd_bit_table sample_flipped_measurements(const Circuit &circuit, size_t num_samples, std::mt19937_64 &rng); static simd_bit_table sample( const Circuit &circuit, const simd_bits &reference_sample, size_t num_samples, std::mt19937_64 &rng); static void sample_out( - const Circuit &circuit, const simd_bits &reference_sample, size_t num_samples, FILE *out, SampleFormat format, + const Circuit &circuit, const simd_bits &reference_sample, size_t num_shots, FILE *out, SampleFormat format, std::mt19937_64 &rng); PauliString get_frame(size_t sample_index) const; @@ -55,8 +57,6 @@ struct FrameSimulator { void reset_all_and_run(const Circuit &circuit); void reset_all(); - void write_measurements(FILE *out, const simd_bits &reference_sample, SampleFormat format) const; - void measure(const OperationData &target_data); void reset(const OperationData &target_data); void measure_reset(const OperationData &target_data); @@ -91,15 +91,4 @@ struct FrameSimulator { void single_cy(uint32_t c, uint32_t t); }; -void write_table_data( - FILE *out, - size_t num_shots_raw, - size_t num_sample_locations_raw, - const simd_bits &reference_sample, - const simd_bit_table &table, - 
SampleFormat format, - char dets_prefix_1, - char dets_prefix_2, - size_t dets_prefix_transition); - #endif diff --git a/src/simulators/frame_simulator.perf.cc b/src/simulators/frame_simulator.perf.cc index ef9ba9e64..7996837b9 100644 --- a/src/simulators/frame_simulator.perf.cc +++ b/src/simulators/frame_simulator.perf.cc @@ -21,13 +21,13 @@ BENCHMARK(FrameSimulator_depolarize1_100Kqubits_1Ksamples_per1000) { size_t num_samples = 1000; float probability = 0.001f; std::mt19937_64 rng(0); // NOLINT(cert-msc51-cpp) - FrameSimulator sim(num_qubits, num_samples, 1, rng); + FrameSimulator sim(num_qubits, num_samples, SIZE_MAX, rng); std::vector targets; for (size_t k = 0; k < num_qubits; k++) { targets.push_back(k); } - OperationData op_data{probability, {&targets, 0, targets.size()}}; + OperationData op_data{probability, targets}; op_data.arg = probability; benchmark_go([&]() { sim.DEPOLARIZE1(op_data); @@ -41,13 +41,13 @@ BENCHMARK(FrameSimulator_depolarize2_100Kqubits_1Ksamples_per1000) { size_t num_samples = 1000; float probability = 0.001f; std::mt19937_64 rng(0); // NOLINT(cert-msc51-cpp) - FrameSimulator sim(num_qubits, num_samples, 1, rng); + FrameSimulator sim(num_qubits, num_samples, SIZE_MAX, rng); std::vector targets; for (size_t k = 0; k < num_qubits; k++) { targets.push_back(k); } - OperationData op_data{probability, {&targets, 0, targets.size()}}; + OperationData op_data{probability, targets}; op_data.arg = probability; benchmark_go([&]() { @@ -61,13 +61,13 @@ BENCHMARK(FrameSimulator_hadamard_100Kqubits_1Ksamples) { size_t num_qubits = 100 * 1000; size_t num_samples = 1000; std::mt19937_64 rng(0); // NOLINT(cert-msc51-cpp) - FrameSimulator sim(num_qubits, num_samples, 1, rng); + FrameSimulator sim(num_qubits, num_samples, SIZE_MAX, rng); std::vector targets; for (size_t k = 0; k < num_qubits; k++) { targets.push_back(k); } - OperationData op_data{0, {&targets, 0, targets.size()}}; + OperationData op_data{0, targets}; benchmark_go([&]() { 
sim.H_XZ(op_data); @@ -80,13 +80,13 @@ BENCHMARK(FrameSimulator_CX_100Kqubits_1Ksamples) { size_t num_qubits = 100 * 1000; size_t num_samples = 1000; std::mt19937_64 rng(0); // NOLINT(cert-msc51-cpp) - FrameSimulator sim(num_qubits, num_samples, 1, rng); + FrameSimulator sim(num_qubits, num_samples, SIZE_MAX, rng); std::vector targets; for (size_t k = 0; k < num_qubits; k++) { targets.push_back(k); } - OperationData op_data{0, {&targets, 0, targets.size()}}; + OperationData op_data{0, targets}; benchmark_go([&]() { sim.ZCX(op_data); diff --git a/src/simulators/frame_simulator.test.cc b/src/simulators/frame_simulator.test.cc index 5a83b9edf..2b7e93f98 100644 --- a/src/simulators/frame_simulator.test.cc +++ b/src/simulators/frame_simulator.test.cc @@ -21,6 +21,18 @@ #include "../test_util.test.h" #include "tableau_simulator.h" +static std::string rewind_read_all(FILE *f) { + rewind(f); + std::string result; + while (true) { + int c = getc(f); + if (c == EOF) { + return result; + } + result.push_back((char)c); + } +} + TEST(FrameSimulator, get_set_frame) { FrameSimulator sim(6, 4, 999, SHARED_TEST_RNG()); ASSERT_EQ(sim.get_frame(0), PauliString::from_str("______")); @@ -52,21 +64,15 @@ bool is_bulk_frame_operation_consistent_with_tableau(const Gate &gate) { size_t num_qubits = 500; size_t num_samples = 1000; - size_t num_measurements = 10; - FrameSimulator sim(num_qubits, num_samples, num_measurements, SHARED_TEST_RNG()); + size_t max_lookback = 10; + FrameSimulator sim(num_qubits, num_samples, max_lookback, SHARED_TEST_RNG()); size_t num_targets = tableau.num_qubits; assert(num_targets == 1 || num_targets == 2); std::vector targets{101, 403, 202, 100}; while (targets.size() > num_targets) { targets.pop_back(); } - OperationData op_data{ - 0, - { - &targets, - 0, - targets.size(), - }}; + OperationData op_data{0, targets}; for (size_t k = 7; k < num_samples; k += 101) { PauliString test_value = PauliString::random(num_qubits, SHARED_TEST_RNG()); PauliStringRef 
test_value_ref(test_value); @@ -95,30 +101,28 @@ TEST(FrameSimulator, bulk_operations_consistent_with_tableau_data) { } } -#define EXPECT_SAMPLES_POSSIBLE(program) EXPECT_TRUE(is_sim_frame_consistent_with_sim_tableau(program)) << program - bool is_output_possible_promising_no_bare_resets(const Circuit &circuit, const simd_bits_range_ref output) { - auto tableau_sim = TableauSimulator(circuit.num_qubits, SHARED_TEST_RNG()); + auto tableau_sim = TableauSimulator(circuit.count_qubits(), SHARED_TEST_RNG()); size_t out_p = 0; - for (const auto &op : circuit.operations) { + bool pass = true; + circuit.for_each_operation([&](const Operation &op) { if (op.gate->name == std::string("M")) { for (auto qf : op.target_data.targets) { tableau_sim.sign_bias = output[out_p] ? -1 : +1; tableau_sim.measure(OpDat(qf)); - if (output[out_p] != tableau_sim.measurement_record.back()) { - return false; + if (output[out_p] != tableau_sim.measurement_record.storage.back()) { + pass = false; } out_p++; } } else { (tableau_sim.*op.gate->tableau_simulator_function)(op.target_data); } - } - - return true; + }); + return pass; } -TEST(PauliFrameSimulation, test_util_is_output_possible) { +TEST(FrameSimulator, test_util_is_output_possible) { auto circuit = Circuit::from_text( "H 0\n" "CNOT 0 1\n" @@ -145,8 +149,8 @@ bool is_sim_frame_consistent_with_sim_tableau(const char *program_text) { simd_bits_range_ref sample = samples[k]; if (!is_output_possible_promising_no_bare_resets(circuit, sample)) { std::cerr << "Impossible output: "; - for (size_t k = 0; k < circuit.num_measurements; k++) { - std::cerr << '0' + sample[k]; + for (size_t k2 = 0; k2 < circuit.count_measurements(); k2++) { + std::cerr << '0' + sample[k2]; } std::cerr << "\n"; return false; @@ -155,7 +159,9 @@ bool is_sim_frame_consistent_with_sim_tableau(const char *program_text) { return true; } -TEST(PauliFrameSimulation, consistency) { +#define EXPECT_SAMPLES_POSSIBLE(program) 
EXPECT_TRUE(is_sim_frame_consistent_with_sim_tableau(program)) << program + +TEST(FrameSimulator, consistency) { EXPECT_SAMPLES_POSSIBLE( "H 0\n" "CNOT 0 1\n" @@ -275,7 +281,7 @@ TEST(PauliFrameSimulation, consistency) { "M 6"); } -TEST(PauliFrameSimulation, sample_out) { +TEST(FrameSimulator, sample_out) { auto circuit = Circuit::from_text( "X 0\n" "M 1\n" @@ -290,16 +296,7 @@ TEST(PauliFrameSimulation, sample_out) { FILE *tmp = tmpfile(); FrameSimulator::sample_out(circuit, ref, 5, tmp, SAMPLE_FORMAT_01, SHARED_TEST_RNG()); - rewind(tmp); - std::stringstream ss; - while (true) { - auto i = getc(tmp); - if (i == EOF) { - break; - } - ss << (char)i; - } - ASSERT_EQ(ss.str(), "0100\n0100\n0100\n0100\n0100\n"); + ASSERT_EQ(rewind_read_all(tmp), "0100\n0100\n0100\n0100\n0100\n"); tmp = tmpfile(); FrameSimulator::sample_out(circuit, ref, 5, tmp, SAMPLE_FORMAT_B8, SHARED_TEST_RNG()); @@ -325,7 +322,7 @@ TEST(PauliFrameSimulation, sample_out) { ASSERT_EQ(getc(tmp), EOF); } -TEST(PauliFrameSimulation, big_circuit_measurements) { +TEST(FrameSimulator, big_circuit_measurements) { Circuit circuit; for (uint32_t k = 0; k < 1250; k += 3) { circuit.append_op("X", {k}); @@ -366,7 +363,7 @@ TEST(PauliFrameSimulation, big_circuit_measurements) { ASSERT_EQ(getc(tmp), EOF); } -TEST(PauliFrameSimulation, run_length_measurement_formats) { +TEST(FrameSimulator, run_length_measurement_formats) { Circuit circuit; circuit.append_op("X", {100, 500, 501, 551, 1200}); for (uint32_t k = 0; k < 1250; k++) { @@ -376,17 +373,13 @@ TEST(PauliFrameSimulation, run_length_measurement_formats) { FILE *tmp = tmpfile(); FrameSimulator::sample_out(circuit, ref, 3, tmp, SAMPLE_FORMAT_HITS, SHARED_TEST_RNG()); - rewind(tmp); - for (char c : "100,500,501,551,1200\n100,500,501,551,1200\n100,500,501,551,1200\n") { - ASSERT_EQ(getc(tmp), c == '\0' ? 
EOF : c); - } + ASSERT_EQ(rewind_read_all(tmp), "100,500,501,551,1200\n100,500,501,551,1200\n100,500,501,551,1200\n"); tmp = tmpfile(); FrameSimulator::sample_out(circuit, ref, 3, tmp, SAMPLE_FORMAT_DETS, SHARED_TEST_RNG()); - rewind(tmp); - for (char c : "shot M100 M500 M501 M551 M1200\nshot M100 M500 M501 M551 M1200\nshot M100 M500 M501 M551 M1200\n") { - ASSERT_EQ(getc(tmp), c == '\0' ? EOF : c); - } + ASSERT_EQ( + rewind_read_all(tmp), + "shot M100 M500 M501 M551 M1200\nshot M100 M500 M501 M551 M1200\nshot M100 M500 M501 M551 M1200\n"); tmp = tmpfile(); FrameSimulator::sample_out(circuit, ref, 3, tmp, SAMPLE_FORMAT_R8, SHARED_TEST_RNG()); @@ -405,7 +398,7 @@ TEST(PauliFrameSimulation, run_length_measurement_formats) { ASSERT_EQ(getc(tmp), EOF); } -TEST(PauliFrameSimulation, big_circuit_random_measurements) { +TEST(FrameSimulator, big_circuit_random_measurements) { Circuit circuit; for (uint32_t k = 0; k < 270; k++) { circuit.append_op("H_XZ", {k}); @@ -746,3 +739,31 @@ TEST(FrameSimulator, classical_controls) { ref, 1, SHARED_TEST_RNG())[0], expected); } + +TEST(FrameSimulator, record_gets_trimmed) { + FrameSimulator sim(100, 1024, 5, SHARED_TEST_RNG()); + Circuit c = Circuit::from_text("M 0 1 2 3 4 5 6 7 8 9"); + MeasureRecordBatchWriter b(tmpfile(), 1024, SAMPLE_FORMAT_B8); + for (size_t k = 0; k < 1000; k++) { + sim.measure(c.operations[0].target_data); + sim.m_record.intermediate_write_unwritten_results_to(b, simd_bits(0)); + ASSERT_LT(sim.m_record.storage.num_major_bits_padded(), 2500); + } +} + +TEST(FrameSimulator, stream_huge_case) { + FILE *tmp = tmpfile(); + FrameSimulator::sample_out( + Circuit::from_text(R"CIRCUIT( + X_ERROR(1) 2 + REPEAT 100000 { + M 0 1 2 3 + } + )CIRCUIT"), + simd_bits(0), 256, tmp, SAMPLE_FORMAT_B8, SHARED_TEST_RNG()); + rewind(tmp); + for (size_t k = 0; k < 256 * 100000 * 4 / 8; k++) { + ASSERT_EQ(getc(tmp), 0x44); + } + ASSERT_EQ(getc(tmp), EOF); +} diff --git a/src/simulators/measure_record.cc 
b/src/simulators/measure_record.cc new file mode 100644 index 000000000..2ac9aa7d3 --- /dev/null +++ b/src/simulators/measure_record.cc @@ -0,0 +1,53 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "measure_record.h" + +#include + +#include "measure_record_writer.h" + +MeasureRecord::MeasureRecord(size_t max_lookback) : max_lookback(max_lookback), unwritten(0) { +} + +void MeasureRecord::write_unwritten_results_to(MeasureRecordWriter &writer) { + size_t n = storage.size(); + for (size_t k = n - unwritten; k < n; k++) { + writer.write_bit(storage[k]); + } + unwritten = 0; + if ((storage.size() >> 1) > max_lookback) { + storage.erase(storage.begin(), storage.end() - max_lookback); + } +} + +bool MeasureRecord::lookback(size_t lookback) const { + if (lookback > storage.size()) { + throw std::out_of_range("Referred to a measurement record before the beginning of time."); + } + if (lookback == 0) { + throw std::out_of_range("Lookback must be non-zero."); + } + if (lookback > max_lookback) { + throw std::out_of_range("Referred to a measurement record past the lookback limit."); + } + return *(storage.end() - lookback); +} + +void MeasureRecord::record_result(bool result) { + storage.push_back(result); + unwritten++; +} diff --git a/src/simulators/measure_record.h b/src/simulators/measure_record.h new file mode 100644 index 000000000..3153d1705 --- /dev/null +++ b/src/simulators/measure_record.h @@ -0,0 +1,52 
@@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RECORD_STORAGE_H +#define RECORD_STORAGE_H + +#include +#include + +#include "measure_record_writer.h" + +/// Stores a historical record of measurement results that can be looked up and written to the external world. +/// +/// Results that have been written and are further back than `max_lookback` may be discarded from memory. +struct MeasureRecord { + /// How far back into the measurement record a circuit being simulated may look. + /// Results younger than this cannot be discarded. + size_t max_lookback; + /// How many results have been recorded but not yet written to the external world. + /// Results younger than this cannot be discarded. + size_t unwritten; + /// The actual recorded results. + std::vector storage; + /// Creates an empty measurement record. + MeasureRecord(size_t max_lookback = SIZE_MAX); + /// Forces all unwritten results to be written via the given writer. + /// + /// After the results are written, older measurements now eligible to be discarded may be removed from memory. + void write_unwritten_results_to(MeasureRecordWriter &writer); + /// Returns a measurement result from the record. + /// + /// Args: + /// lookback: How far back the measurement is. lookback=1 is the latest measurement, 2 the second latest, etc. + bool lookback(size_t lookback) const; + /// Appends a measurement to the record. 
+ void record_result(bool result); +}; + +#endif diff --git a/src/simulators/measure_record.test.cc b/src/simulators/measure_record.test.cc new file mode 100644 index 000000000..4591be676 --- /dev/null +++ b/src/simulators/measure_record.test.cc @@ -0,0 +1,45 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "measure_record.h" + +#include "gtest/gtest.h" + +#include "../test_util.test.h" + +TEST(MeasureRecord, basic_usage) { + MeasureRecord r(20); + r.record_result(true); + ASSERT_EQ(r.lookback(1), true); + r.record_result(false); + ASSERT_EQ(r.lookback(1), false); + ASSERT_EQ(r.lookback(2), true); + for (size_t k = 0; k < 50; k++) { + r.record_result(true); + r.record_result(false); + } + ASSERT_EQ(r.storage.size(), 102); + + FILE *tmp = tmpfile(); + r.write_unwritten_results_to(*MeasureRecordWriter::make(tmp, SAMPLE_FORMAT_01)); + rewind(tmp); + for (size_t k = 0; k < 102; k++) { + ASSERT_EQ(getc(tmp), '0' + (~k & 1)); + } + ASSERT_EQ(getc(tmp), EOF); + + ASSERT_LE(r.storage.size(), 40); +} diff --git a/src/simulators/measure_record_batch.cc b/src/simulators/measure_record_batch.cc new file mode 100644 index 000000000..752fbb6c9 --- /dev/null +++ b/src/simulators/measure_record_batch.cc @@ -0,0 +1,107 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "measure_record_batch.h" + +#include + +#include "measure_record_batch_writer.h" + +MeasureRecordBatch::MeasureRecordBatch(size_t num_shots, size_t max_lookback) + : max_lookback(max_lookback), unwritten(0), stored(0), written(0), shot_mask(num_shots), storage(1, num_shots) { + for (size_t k = 0; k < num_shots; k++) { + shot_mask[k] = true; + } +} + +void MeasureRecordBatch::record_result(simd_bits_range_ref result) { + if (stored >= storage.num_major_bits_padded()) { + simd_bit_table new_storage(storage.num_major_bits_padded() * 2, storage.num_minor_bits_padded()); + new_storage.data.word_range_ref(0, storage.data.num_simd_words) = storage.data; + storage = std::move(new_storage); + } + storage[stored] = result; + storage[stored] &= shot_mask; + stored++; + unwritten++; +} + +simd_bits_range_ref MeasureRecordBatch::lookback(size_t lookback) const { + if (lookback > stored) { + throw std::out_of_range("Referred to a measurement record before the beginning of time."); + } + if (lookback == 0) { + throw std::out_of_range("Lookback must be non-zero."); + } + if (lookback > max_lookback) { + throw std::out_of_range("Referred to a measurement record past the lookback limit."); + } + return storage[stored - lookback]; +} + +void MeasureRecordBatch::mark_all_as_written() { + unwritten = 0; + size_t m = max_lookback; + if ((stored >> 1) > m) { + memcpy(storage.data.u8, storage[stored - m].u8, m * storage.num_minor_u8_padded()); + stored = m; + } +} + +void MeasureRecordBatch::intermediate_write_unwritten_results_to( + 
MeasureRecordBatchWriter &writer, simd_bits_range_ref ref_sample) { + while (unwritten >= 1024) { + auto slice = storage.slice_maj(stored - unwritten, stored - unwritten + 1024); + for (size_t k = 0; k < 1024; k++) { + size_t j = written + k; + if (j < ref_sample.num_bits_padded() && ref_sample[j]) { + slice[k] ^= shot_mask; + } + } + writer.batch_write_bytes(slice, 1024 >> 6); + unwritten -= 1024; + written += 1024; + } + + size_t m = std::max(max_lookback, unwritten); + if ((stored >> 1) > m) { + memcpy(storage.data.u8, storage[stored - m].u8, m * storage.num_minor_u8_padded()); + stored = m; + } +} + +void MeasureRecordBatch::final_write_unwritten_results_to( + MeasureRecordBatchWriter &writer, simd_bits_range_ref ref_sample) { + size_t n = stored; + for (size_t k = n - unwritten; k < n; k++) { + bool invert = written < ref_sample.num_bits_padded() && ref_sample[written]; + if (invert) { + storage[k] ^= shot_mask; + } + writer.batch_write_bit(storage[k]); + if (invert) { + storage[k] ^= shot_mask; + } + written++; + } + unwritten = 0; + writer.write_end(); +} + +void MeasureRecordBatch::clear() { + stored = 0; + unwritten = 0; +} diff --git a/src/simulators/measure_record_batch.h b/src/simulators/measure_record_batch.h new file mode 100644 index 000000000..53d1245fe --- /dev/null +++ b/src/simulators/measure_record_batch.h @@ -0,0 +1,67 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef RECORD_BATCH_STORAGE_H +#define RECORD_BATCH_STORAGE_H + +#include "measure_record_batch_writer.h" + +/// Stores a record of multiple measurement streams that can be looked up and written to the external world. +/// +/// Results that have been written and are further back than `max_lookback` may be discarded from memory. +struct MeasureRecordBatch { + /// How far back into the measurement record a circuit being simulated may look. + /// Results younger than this cannot be discarded. + size_t max_lookback; + /// How many results have been recorded but not yet written to the external world. + /// Results younger than this cannot be discarded. + size_t unwritten; + /// How many results are currently stored (from each separate stream). + size_t stored; + /// How many results have been written to the external world. + size_t written; + /// For performance reasons, measurement data given to store may include non-zero values past the data corresponding + /// to the number of expected shots. AND-ing the data with this mask fixes the problem. + simd_bits shot_mask; + /// The 2-dimensional block of bits storing the measurement results from each separate measurement stream. + /// Major index is measurement index, minor index is shot index. + simd_bit_table storage; + + /// Constructs an empty MeasureRecordBatch configured for the given max_lookback and number of shots. + MeasureRecordBatch(size_t num_shots, size_t max_lookback); + + /// Allows measurements older than max_lookback to be discarded, even though they weren't written out. + /// + /// E.g. this is used during detection event sampling, when what is written is derived detection events. + void mark_all_as_written(); + /// Hints that measurements can be written to the given writer. + /// + /// For performance reasons, they may not be written until a large enough block has been accumulated. 
+ void intermediate_write_unwritten_results_to(MeasureRecordBatchWriter &writer, simd_bits_range_ref ref_sample); + /// Forces measurements to be written to the given writer, and to tell the writer the measurements are ending. + void final_write_unwritten_results_to(MeasureRecordBatchWriter &writer, simd_bits_range_ref ref_sample); + /// Looks up a historical batch measurement. + /// + /// Returns: + /// A reference into the storage table, with the bit at offset k corresponding to the measurement from stream k. + simd_bits_range_ref lookback(size_t lookback) const; + /// Appends a batch measurement result into storage. + void record_result(simd_bits_range_ref result); + /// Resets the record to an empty state. + void clear(); +}; + +#endif diff --git a/src/simulators/measure_record_batch.test.cc b/src/simulators/measure_record_batch.test.cc new file mode 100644 index 000000000..f7e0971ef --- /dev/null +++ b/src/simulators/measure_record_batch.test.cc @@ -0,0 +1,74 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "measure_record_batch.h" + +#include "gtest/gtest.h" + +#include "../test_util.test.h" + +TEST(MeasureRecordBatch, basic_usage) { + simd_bits s0(5); + simd_bits s1(5); + s0[0] = true; + s1[1] = true; + s0[2] = true; + s1[3] = true; + s0[4] = true; + MeasureRecordBatch r(5, 20); + ASSERT_EQ(r.stored, 0); + r.record_result(s0); + ASSERT_EQ(r.stored, 1); + ASSERT_EQ(r.lookback(1), s0); + r.record_result(s1); + ASSERT_EQ(r.stored, 2); + ASSERT_EQ(r.lookback(1), s1); + ASSERT_EQ(r.lookback(2), s0); + + for (size_t k = 0; k < 50; k++) { + r.record_result(s0); + r.record_result(s1); + } + ASSERT_EQ(r.unwritten, 102); + ASSERT_EQ(r.stored, 102); + FILE *tmp = tmpfile(); + MeasureRecordBatchWriter w(tmp, 5, SAMPLE_FORMAT_01); + r.intermediate_write_unwritten_results_to(w, simd_bits(0)); + ASSERT_EQ(r.unwritten, 102); + + for (size_t k = 0; k < 500; k++) { + r.record_result(s0); + r.record_result(s1); + } + ASSERT_EQ(r.unwritten, 1102); + ASSERT_EQ(r.stored, 1102); + r.intermediate_write_unwritten_results_to(w, simd_bits(0)); + ASSERT_LT(r.unwritten, 100); + ASSERT_LT(r.stored, 100); + r.final_write_unwritten_results_to(w, simd_bits(0)); + ASSERT_EQ(r.unwritten, 0); + ASSERT_LT(r.stored, 100); + + rewind(tmp); + for (size_t s = 0; s < 5; s++) { + simd_bits sk = (s & 1) ? s1 : s0; + for (size_t k = 0; k < 1102; k++) { + ASSERT_EQ(getc(tmp), '0' + ((s + k + 1) & 1)); + } + ASSERT_EQ(getc(tmp), '\n'); + } + ASSERT_EQ(getc(tmp), EOF); +} diff --git a/src/simulators/measure_record_batch_writer.cc b/src/simulators/measure_record_batch_writer.cc new file mode 100644 index 000000000..2d3e554eb --- /dev/null +++ b/src/simulators/measure_record_batch_writer.cc @@ -0,0 +1,104 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "measure_record_batch_writer.h" + +#include + +#include "measure_record_batch.h" + +MeasureRecordBatchWriter::MeasureRecordBatchWriter(FILE *out, size_t num_shots, SampleFormat output_format) + : output_format(output_format), out(out) { + auto f = output_format; + auto s = num_shots; + if (output_format == SAMPLE_FORMAT_PTB64) { + f = SAMPLE_FORMAT_B8; + s += 63; + s /= 64; + } + if (s) { + writers.push_back(MeasureRecordWriter::make(out, f)); + } + for (size_t k = 1; k < s; k++) { + FILE *file = tmpfile(); + writers.push_back(MeasureRecordWriter::make(file, f)); + temporary_files.push_back(file); + } +} + +MeasureRecordBatchWriter::~MeasureRecordBatchWriter() { + for (auto &e : temporary_files) { + fclose(e); + } + temporary_files.clear(); +} + +void MeasureRecordBatchWriter::begin_result_type(char result_type) { + for (auto &e : writers) { + e->begin_result_type(result_type); + } +} + +void MeasureRecordBatchWriter::batch_write_bit(simd_bits_range_ref bits) { + if (output_format == SAMPLE_FORMAT_PTB64) { + uint8_t *p = bits.u8; + for (auto &writer : writers) { + uint8_t *n = p + 8; + writer->write_bytes({p, n}); + p = n; + } + } else { + for (size_t k = 0; k < writers.size(); k++) { + writers[k]->write_bit(bits[k]); + } + } +} + +void MeasureRecordBatchWriter::batch_write_bytes(const simd_bit_table &table, size_t num_major_u64) { + if (output_format == SAMPLE_FORMAT_PTB64) { + for (size_t k = 0; k < writers.size(); k++) { + for (size_t w = 0; w < num_major_u64; w++) { + uint8_t *p = table.data.u8 + (k * 8) + 
table.num_minor_u8_padded() * w; + writers[k]->write_bytes({p, p + 8}); + } + } + } else { + auto transposed = table.transposed(); + for (size_t k = 0; k < writers.size(); k++) { + uint8_t *p = transposed[k].u8; + writers[k]->write_bytes({p, p + num_major_u64 * 8}); + } + } +} + +void MeasureRecordBatchWriter::write_end() { + for (auto &writer : writers) { + writer->write_end(); + } + + for (FILE *file : temporary_files) { + rewind(file); + while (true) { + int c = getc(file); + if (c == EOF) { + break; + } + putc(c, out); + } + fclose(file); + } + temporary_files.clear(); +} diff --git a/src/simulators/measure_record_batch_writer.h b/src/simulators/measure_record_batch_writer.h new file mode 100644 index 000000000..c59d01d14 --- /dev/null +++ b/src/simulators/measure_record_batch_writer.h @@ -0,0 +1,60 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RECORD_BATCH_WRITER_H +#define RECORD_BATCH_WRITER_H + +#include "../simd/simd_bit_table.h" +#include "measure_record_writer.h" + +/// Handles buffering and writing multiple measurement data streams that ultimately need to be concatenated. +struct MeasureRecordBatchWriter { + SampleFormat output_format; + FILE *out; + /// Temporary files used to hold data that will eventually be concatenated onto the main stream. + std::vector temporary_files; + /// The individual writers for each incoming stream of measurement results. 
+ /// The first writer will go directly to `out`, whereas the others go into temporary files. + std::vector> writers; + + MeasureRecordBatchWriter(FILE *out, size_t num_shots, SampleFormat output_format); + /// Cleans up temporary files. + ~MeasureRecordBatchWriter(); + /// See MeasureRecordWriter::begin_result_type. + void begin_result_type(char result_type); + + /// Writes a separate measurement result to each MeasureRecordWriter. + /// + /// Args: + /// bits: The measurement results. The bit at offset k is the bit for the writer at offset k. + void batch_write_bit(simd_bits_range_ref bits); + /// Writes multiple separate measurement results to each MeasureRecordWriter. + /// + /// This method can be called after calling `batch_write_bit`, but for performance reasons it is recommended to not + /// do this since it can result in the individual writers doing extra work due to not being on byte boundaries. + /// + /// Args: + /// table: The measurement results. + /// The bits at minor offset k, from major offset 0 to major offset 64*num_major_u64, are the bits for the + /// writer at offset k. + /// num_major_u64: The number of measurement results (divided by 64) for each writer. The actual number of + /// results is required to be a multiple of 64 for performance reasons. + void batch_write_bytes(const simd_bit_table &table, size_t num_major_u64); + /// Tells each writer to finish up, then concatenates all of their data into the `out` stream and cleans up. + void write_end(); +}; + +#endif diff --git a/src/simulators/measure_record_batch_writer.test.cc b/src/simulators/measure_record_batch_writer.test.cc new file mode 100644 index 000000000..2e0e6a88c --- /dev/null +++ b/src/simulators/measure_record_batch_writer.test.cc @@ -0,0 +1,41 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "measure_record_batch_writer.h" + +#include "gtest/gtest.h" + +#include "../test_util.test.h" + +TEST(MeasureRecordBatchWriter, basic_usage) { + FILE *tmp = tmpfile(); + MeasureRecordBatchWriter w(tmp, 5, SAMPLE_FORMAT_01); + simd_bits v(5); + v[1] = true; + w.batch_write_bit(v); + w.write_end(); + rewind(tmp); + ASSERT_EQ(getc(tmp), '0'); + ASSERT_EQ(getc(tmp), '\n'); + ASSERT_EQ(getc(tmp), '1'); + ASSERT_EQ(getc(tmp), '\n'); + ASSERT_EQ(getc(tmp), '0'); + ASSERT_EQ(getc(tmp), '\n'); + ASSERT_EQ(getc(tmp), '0'); + ASSERT_EQ(getc(tmp), '\n'); + ASSERT_EQ(getc(tmp), '0'); + ASSERT_EQ(getc(tmp), '\n'); +} diff --git a/src/simulators/measure_record_writer.cc b/src/simulators/measure_record_writer.cc new file mode 100644 index 000000000..0f6039d7d --- /dev/null +++ b/src/simulators/measure_record_writer.cc @@ -0,0 +1,262 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "measure_record_writer.h" + +#include + +std::unique_ptr MeasureRecordWriter::make(FILE *out, SampleFormat output_format) { + switch (output_format) { + case SAMPLE_FORMAT_01: + return std::unique_ptr(new MeasureRecordWriterFormat01(out)); + case SAMPLE_FORMAT_B8: + return std::unique_ptr(new MeasureRecordWriterFormatB8(out)); + case SAMPLE_FORMAT_DETS: + return std::unique_ptr(new MeasureRecordFormatDets(out)); + case SAMPLE_FORMAT_HITS: + return std::unique_ptr(new MeasureRecordWriterFormatHits(out)); + case SAMPLE_FORMAT_PTB64: + throw std::invalid_argument("SAMPLE_FORMAT_PTB64 incompatible with SingleMeasurementRecord"); + case SAMPLE_FORMAT_R8: + return std::unique_ptr(new MeasureRecordFormatR8(out)); + default: + throw std::invalid_argument("Sample format not recognized by SingleMeasurementRecord"); + } +} + +void MeasureRecordWriter::begin_result_type(char result_type) { +} + +void MeasureRecordWriter::write_bytes(ConstPointerRange data) { + for (uint8_t b : data) { + for (size_t k = 0; k < 8; k++) { + write_bit((b >> k) & 1); + } + } +} + +MeasureRecordWriterFormat01::MeasureRecordWriterFormat01(FILE *out) : out(out) { +} + +void MeasureRecordWriterFormat01::write_bit(bool b) { + putc('0' + b, out); +} + +void MeasureRecordWriterFormat01::write_end() { + putc('\n', out); +} + +MeasureRecordWriterFormatB8::MeasureRecordWriterFormatB8(FILE *out) : out(out) { +} + +void MeasureRecordWriterFormatB8::write_bytes(ConstPointerRange data) { + if (count == 0) { + fwrite(data.ptr_start, sizeof(uint8_t), data.ptr_end - data.ptr_start, out); + } else { + MeasureRecordWriter::write_bytes(data); + } +} + +void MeasureRecordWriterFormatB8::write_bit(bool b) { + payload |= uint8_t{b} << count; + count++; + if (count == 8) { + putc(payload, out); + count = 0; + payload = 0; + } +} + +void MeasureRecordWriterFormatB8::write_end() { + if (count > 0) { + putc(payload, out); + count = 0; + payload = 0; + } +} + 
+MeasureRecordWriterFormatHits::MeasureRecordWriterFormatHits(FILE *out) : out(out) { +} + +void MeasureRecordWriterFormatHits::write_bytes(ConstPointerRange data) { + for (uint8_t b : data) { + if (!b) { + position += 8; + } else { + for (size_t k = 0; k < 8; k++) { + write_bit((b >> k) & 1); + } + } + } +} + +void MeasureRecordWriterFormatHits::write_bit(bool b) { + if (b) { + if (first) { + first = false; + } else { + putc(',', out); + } + fprintf(out, "%lld", (unsigned long long)(position)); + } + position++; +} + +void MeasureRecordWriterFormatHits::write_end() { + putc('\n', out); + position = 0; + first = true; +} + +MeasureRecordFormatR8::MeasureRecordFormatR8(FILE *out) : out(out) { +} + +void MeasureRecordFormatR8::write_bytes(ConstPointerRange data) { + for (uint8_t b : data) { + if (!b) { + run_length += 8; + if (run_length >= 0xFF) { + putc(0xFF, out); + run_length -= 0xFF; + } + } else { + for (size_t k = 0; k < 8; k++) { + write_bit((b >> k) & 1); + } + } + } +} + +void MeasureRecordFormatR8::write_bit(bool b) { + if (b) { + putc(run_length, out); + run_length = 0; + } else { + run_length++; + if (run_length == 255) { + putc(run_length, out); + run_length = 0; + } + } +} + +void MeasureRecordFormatR8::write_end() { + putc(run_length, out); + run_length = 0; +} + +MeasureRecordFormatDets::MeasureRecordFormatDets(FILE *out) : out(out) { + fprintf(out, "shot"); +} + +void MeasureRecordFormatDets::begin_result_type(char new_result_type) { + result_type = new_result_type; + position = 0; +} + +void MeasureRecordFormatDets::write_bytes(ConstPointerRange data) { + for (uint8_t b : data) { + if (!b) { + position += 8; + } else { + for (size_t k = 0; k < 8; k++) { + write_bit((b >> k) & 1); + } + } + } +} + +void MeasureRecordFormatDets::write_bit(bool b) { + if (b) { + putc(' ', out); + putc(result_type, out); + fprintf(out, "%lld", (unsigned long long)(position)); + } + position++; +} + +void MeasureRecordFormatDets::write_end() { + putc('\n', out); + 
position = 0; +} + +simd_bit_table transposed_vs_ref( + size_t num_samples_raw, const simd_bit_table &table, const simd_bits &reference_sample) { + auto result = table.transposed(); + for (size_t s = 0; s < num_samples_raw; s++) { + result[s].word_range_ref(0, reference_sample.num_simd_words) ^= reference_sample; + } + return result; +} + +void write_table_data( + FILE *out, size_t num_shots, size_t num_measurements, const simd_bits &reference_sample, + const simd_bit_table &table, SampleFormat format, char dets_prefix_1, char dets_prefix_2, + size_t dets_prefix_transition) { + if (format == SAMPLE_FORMAT_PTB64) { + auto f64 = num_shots >> 6; + for (size_t s = 0; s < f64; s++) { + for (size_t m = 0; m < num_measurements; m++) { + uint64_t v = table[m].u64[s]; + if (m < reference_sample.num_bits_padded() && reference_sample[m]) { + v = ~v; + } + fwrite(&v, 1, 64 >> 3, out); + } + } + if (num_shots & 63) { + uint64_t mask = (uint64_t{1} << (num_shots & 63)) - 1ULL; + for (size_t m = 0; m < num_measurements; m++) { + uint64_t v = table[m].u64[f64]; + if (m < reference_sample.num_bits_padded() && reference_sample[m]) { + v = ~v; + } + v &= mask; + fwrite(&v, 1, 64 >> 3, out); + } + } + return; + } else { + auto result = transposed_vs_ref(num_shots, table, reference_sample); + if (dets_prefix_transition == 0) { + dets_prefix_transition = num_measurements; + dets_prefix_1 = dets_prefix_2; + } else if (dets_prefix_1 == dets_prefix_2 || dets_prefix_transition >= num_measurements) { + dets_prefix_transition = num_measurements; + } + for (size_t shot = 0; shot < num_shots; shot++) { + auto w = MeasureRecordWriter::make(out, format); + + w->begin_result_type(dets_prefix_1); + size_t n8 = dets_prefix_transition >> 3; + uint8_t *p = result[shot].u8; + w->write_bytes({p, p + n8}); + size_t m = n8 << 3; + while (m < dets_prefix_transition) { + w->write_bit(result[shot][m]); + m++; + } + + w->begin_result_type(dets_prefix_2); + while (m < num_measurements) { + 
w->write_bit(result[shot][m]); + m++; + } + + w->write_end(); + } + } +} diff --git a/src/simulators/measure_record_writer.h b/src/simulators/measure_record_writer.h new file mode 100644 index 000000000..2684ec404 --- /dev/null +++ b/src/simulators/measure_record_writer.h @@ -0,0 +1,103 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RECORD_WRITER_H +#define RECORD_WRITER_H + +#include + +#include "../circuit/circuit.h" +#include "../simd/pointer_range.h" +#include "../simd/simd_bit_table.h" + +/// Handles writing measurement data to the outside world. +/// +/// Child classes implement the various output formats. +struct MeasureRecordWriter { + /// Creates a MeasureRecordWriter that writes the given format into the given FILE*. + static std::unique_ptr make(FILE *out, SampleFormat output_format); + /// Writes (or buffers) one measurement result. + virtual void write_bit(bool b) = 0; + /// Writes (or buffers) multiple measurement results. + virtual void write_bytes(ConstPointerRange data); + /// Flushes all buffered measurement results and writes any end-of-result markers that are needed (e.g. a newline). + virtual void write_end() = 0; + /// Used to control the DETS format prefix character (M for measurement, D for detector, L for logical observable). + /// + /// Setting this is understood to reset the "result index" back to 0 so that e.g. 
listing logical observables after + /// detectors results in the first logical observable being L0 instead of L[number-of-detectors]. + virtual void begin_result_type(char result_type); +}; + +struct MeasureRecordWriterFormat01 : MeasureRecordWriter { + FILE *out; + MeasureRecordWriterFormat01(FILE *out); + void write_bit(bool b) override; + void write_end() override; +}; + +struct MeasureRecordWriterFormatB8 : MeasureRecordWriter { + FILE *out; + uint8_t payload = 0; + uint8_t count = 0; + MeasureRecordWriterFormatB8(FILE *out); + void write_bytes(ConstPointerRange data) override; + void write_bit(bool b) override; + void write_end() override; +}; + +struct MeasureRecordWriterFormatHits : MeasureRecordWriter { + FILE *out; + uint64_t position = 0; + bool first = true; + + MeasureRecordWriterFormatHits(FILE *out); + void write_bytes(ConstPointerRange data) override; + void write_bit(bool b) override; + void write_end() override; +}; + +struct MeasureRecordFormatR8 : MeasureRecordWriter { + FILE *out; + uint16_t run_length = 0; + + MeasureRecordFormatR8(FILE *out); + void write_bytes(ConstPointerRange data) override; + void write_bit(bool b) override; + void write_end() override; +}; + +struct MeasureRecordFormatDets : MeasureRecordWriter { + FILE *out; + uint64_t position = 0; + char result_type = 'M'; + + MeasureRecordFormatDets(FILE *out); + void begin_result_type(char result_type) override; + void write_bytes(ConstPointerRange data) override; + void write_bit(bool b) override; + void write_end() override; +}; + +simd_bit_table transposed_vs_ref( + size_t num_samples_raw, const simd_bit_table &table, const simd_bits &reference_sample); + +void write_table_data( + FILE *out, size_t num_shots, size_t num_measurements, const simd_bits &reference_sample, + const simd_bit_table &table, SampleFormat format, char dets_prefix_1, char dets_prefix_2, + size_t dets_prefix_transition); + +#endif diff --git a/src/simulators/measure_record_writer.test.cc 
b/src/simulators/measure_record_writer.test.cc new file mode 100644 index 000000000..1646ca419 --- /dev/null +++ b/src/simulators/measure_record_writer.test.cc @@ -0,0 +1,356 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "measure_record_writer.h" + +#include "gtest/gtest.h" + +#include "../test_util.test.h" + +static std::string rewind_read_all(FILE *f) { + rewind(f); + std::string result; + while (true) { + int c = getc(f); + if (c == EOF) { + return result; + } + result.push_back((char)c); + } + fclose(f); +} + +TEST(MeasureRecordWriter, Format01) { + FILE *tmp = tmpfile(); + auto writer = MeasureRecordWriter::make(tmp, SAMPLE_FORMAT_01); + uint8_t bytes[]{0xF8}; + writer->write_bytes({bytes, bytes + 1}); + writer->write_bit(false); + writer->write_bytes({bytes, bytes + 1}); + writer->write_bit(true); + writer->write_end(); + ASSERT_EQ(rewind_read_all(tmp), "000111110000111111\n"); +} + +TEST(MeasureRecordWriter, FormatB8) { + FILE *tmp = tmpfile(); + auto writer = MeasureRecordWriter::make(tmp, SAMPLE_FORMAT_B8); + uint8_t bytes[]{0xF8}; + writer->write_bytes({bytes, bytes + 1}); + writer->write_bit(false); + writer->write_bytes({bytes, bytes + 1}); + writer->write_bit(true); + writer->write_end(); + auto s = rewind_read_all(tmp); + ASSERT_EQ(s.size(), 3); + ASSERT_EQ(s[0], (char)0xF8); + ASSERT_EQ(s[1], (char)0xF0); + ASSERT_EQ(s[2], (char)0x03); +} + +TEST(MeasureRecordWriter, FormatHits) { + FILE 
*tmp = tmpfile(); + auto writer = MeasureRecordWriter::make(tmp, SAMPLE_FORMAT_HITS); + uint8_t bytes[]{0xF8}; + writer->write_bytes({bytes, bytes + 1}); + writer->write_bit(false); + writer->write_bytes({bytes, bytes + 1}); + writer->write_bit(true); + writer->write_end(); + ASSERT_EQ(rewind_read_all(tmp), "3,4,5,6,7,12,13,14,15,16,17\n"); +} + +TEST(MeasureRecordWriter, FormatDets) { + FILE *tmp = tmpfile(); + auto writer = MeasureRecordWriter::make(tmp, SAMPLE_FORMAT_DETS); + uint8_t bytes[]{0xF8}; + writer->begin_result_type('D'); + writer->write_bytes({bytes, bytes + 1}); + writer->write_bit(false); + writer->write_bytes({bytes, bytes + 1}); + writer->begin_result_type('L'); + writer->write_bit(false); + writer->write_bit(true); + writer->write_end(); + ASSERT_EQ(rewind_read_all(tmp), "shot D3 D4 D5 D6 D7 D12 D13 D14 D15 D16 L1\n"); +} + +TEST(MeasureRecordWriter, FormatR8) { + FILE *tmp = tmpfile(); + auto writer = MeasureRecordWriter::make(tmp, SAMPLE_FORMAT_R8); + uint8_t bytes[]{0xF8}; + writer->write_bytes({bytes, bytes + 1}); + writer->write_bit(false); + writer->write_bytes({bytes, bytes + 1}); + writer->write_bit(true); + writer->write_end(); + auto s = rewind_read_all(tmp); + ASSERT_EQ(s.size(), 12); + ASSERT_EQ(s[0], (char)3); + ASSERT_EQ(s[1], (char)0); + ASSERT_EQ(s[2], (char)0); + ASSERT_EQ(s[3], (char)0); + ASSERT_EQ(s[4], (char)0); + ASSERT_EQ(s[5], (char)4); + ASSERT_EQ(s[6], (char)0); + ASSERT_EQ(s[7], (char)0); + ASSERT_EQ(s[8], (char)0); + ASSERT_EQ(s[9], (char)0); + ASSERT_EQ(s[10], (char)0); + ASSERT_EQ(s[11], (char)0); +} + +TEST(MeasureRecordWriter, FormatR8_LongGap) { + FILE *tmp = tmpfile(); + auto writer = MeasureRecordWriter::make(tmp, SAMPLE_FORMAT_R8); + uint8_t bytes[]{0, 0, 0, 0, 0, 0, 0, 0}; + writer->write_bytes({bytes, bytes + 8}); + writer->write_bytes({bytes, bytes + 8}); + writer->write_bytes({bytes, bytes + 8}); + writer->write_bytes({bytes, bytes + 8}); + writer->write_bytes({bytes, bytes + 8}); + 
writer->write_bytes({bytes, bytes + 8}); + writer->write_bytes({bytes, bytes + 8}); + writer->write_bytes({bytes, bytes + 8}); + writer->write_bit(true); + writer->write_bytes({bytes, bytes + 4}); + writer->write_end(); + auto s = rewind_read_all(tmp); + ASSERT_EQ(s.size(), 4); + ASSERT_EQ(s[0], (char)255); + ASSERT_EQ(s[1], (char)255); + ASSERT_EQ(s[2], (char)2); + ASSERT_EQ(s[3], (char)32); +} + +TEST(MeasureRecordWriter, write_table_data_small) { + simd_bit_table results(4, 5); + simd_bits ref_sample(0); + results[1][0] ^= 1; + results[1][1] ^= 1; + results[1][2] ^= 1; + results[1][3] ^= 1; + results[1][4] ^= 1; + + FILE *tmp; + + tmp = tmpfile(); + write_table_data(tmp, 5, 4, ref_sample, results, SAMPLE_FORMAT_01, 'M', 'M', 0); + ASSERT_EQ(rewind_read_all(tmp), "0100\n0100\n0100\n0100\n0100\n"); + + tmp = tmpfile(); + write_table_data(tmp, 5, 4, ref_sample, results, SAMPLE_FORMAT_HITS, 'M', 'M', 0); + ASSERT_EQ(rewind_read_all(tmp), "1\n1\n1\n1\n1\n"); + + tmp = tmpfile(); + write_table_data(tmp, 5, 4, ref_sample, results, SAMPLE_FORMAT_DETS, 'M', 'M', 0); + ASSERT_EQ(rewind_read_all(tmp), "shot M1\nshot M1\nshot M1\nshot M1\nshot M1\n"); + + tmp = tmpfile(); + write_table_data(tmp, 5, 4, ref_sample, results, SAMPLE_FORMAT_DETS, 'D', 'L', 1); + ASSERT_EQ(rewind_read_all(tmp), "shot L0\nshot L0\nshot L0\nshot L0\nshot L0\n"); + + tmp = tmpfile(); + write_table_data(tmp, 5, 4, ref_sample, results, SAMPLE_FORMAT_R8, 'M', 'M', 0); + ASSERT_EQ(rewind_read_all(tmp), "\x01\x02\x01\x02\x01\x02\x01\x02\x01\x02"); + + tmp = tmpfile(); + write_table_data(tmp, 5, 4, ref_sample, results, SAMPLE_FORMAT_B8, 'M', 'M', 0); + ASSERT_EQ(rewind_read_all(tmp), "\x02\x02\x02\x02\x02"); + + tmp = tmpfile(); + write_table_data(tmp, 5, 4, ref_sample, results, SAMPLE_FORMAT_PTB64, 'M', 'M', 0); + ASSERT_EQ( + rewind_read_all(tmp), std::string( + "\0\0\0\0\0\0\0\0" + "\x1F\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0", + 8 * 4)); +} + +TEST(MeasureRecordWriter, 
write_table_data_large) { + simd_bit_table results(100, 2); + simd_bits ref_sample(100); + ref_sample[2] ^= true; + ref_sample[3] ^= true; + ref_sample[5] ^= true; + ref_sample[7] ^= true; + ref_sample[11] ^= true; + results[7][1] ^= true; + + FILE *tmp; + + tmp = tmpfile(); + write_table_data(tmp, 2, 100, ref_sample, results, SAMPLE_FORMAT_01, 'M', 'M', 0); + ASSERT_EQ( + rewind_read_all(tmp), + "0011010100" + "0100000000" + "0000000000" + "0000000000" + "0000000000" + "0000000000" + "0000000000" + "0000000000" + "0000000000" + "0000000000\n" + "0011010000" + "0100000000" + "0000000000" + "0000000000" + "0000000000" + "0000000000" + "0000000000" + "0000000000" + "0000000000" + "0000000000\n"); + + tmp = tmpfile(); + write_table_data(tmp, 2, 100, ref_sample, results, SAMPLE_FORMAT_HITS, 'M', 'M', 0); + ASSERT_EQ(rewind_read_all(tmp), "2,3,5,7,11\n2,3,5,11\n"); + + tmp = tmpfile(); + write_table_data(tmp, 2, 100, ref_sample, results, SAMPLE_FORMAT_DETS, 'D', 'L', 5); + ASSERT_EQ(rewind_read_all(tmp), "shot D2 D3 L0 L2 L6\nshot D2 D3 L0 L6\n"); + + tmp = tmpfile(); + write_table_data(tmp, 2, 100, ref_sample, results, SAMPLE_FORMAT_DETS, 'D', 'L', 90); + ASSERT_EQ(rewind_read_all(tmp), "shot D2 D3 D5 D7 D11\nshot D2 D3 D5 D11\n"); + + tmp = tmpfile(); + write_table_data(tmp, 2, 100, ref_sample, results, SAMPLE_FORMAT_R8, 'M', 'M', 0); + ASSERT_EQ( + rewind_read_all(tmp), std::string( + "\x02\x00\x01\x01\x03\x58" + "\x02\x00\x01\x05\x58", + 11)); + + tmp = tmpfile(); + write_table_data(tmp, 2, 100, ref_sample, results, SAMPLE_FORMAT_B8, 'M', 'M', 0); + ASSERT_EQ( + rewind_read_all(tmp), std::string( + "\xAC\x08\0\0\0\0\0\0\0\0\0\0\0" + "\x2C\x08\0\0\0\0\0\0\0\0\0\0\0", + 26)); + + tmp = tmpfile(); + write_table_data(tmp, 2, 100, ref_sample, results, SAMPLE_FORMAT_PTB64, 'M', 'M', 0); + auto actual = rewind_read_all(tmp); + ASSERT_EQ( + rewind_read_all(tmp), std::string( + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\x03\0\0\0\0\0\0\0" + "\x03\0\0\0\0\0\0\0" + 
"\0\0\0\0\0\0\0\0" + "\x03\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\x01\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\x03\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" + "\0\0\0\0\0\0\0\0" 
+ "\0\0\0\0\0\0\0\0", + 8 * 100)); +} diff --git a/src/simulators/tableau_simulator.cc b/src/simulators/tableau_simulator.cc index 21172750a..447b46364 100644 --- a/src/simulators/tableau_simulator.cc +++ b/src/simulators/tableau_simulator.cc @@ -17,11 +17,11 @@ #include "../circuit/gate_data.h" #include "../probability_util.h" -TableauSimulator::TableauSimulator(size_t num_qubits, std::mt19937_64 &rng, int8_t sign_bias) +TableauSimulator::TableauSimulator(size_t num_qubits, std::mt19937_64 &rng, int8_t sign_bias, MeasureRecord record) : inv_state(Tableau::identity(num_qubits)), rng(rng), sign_bias(sign_bias), - measurement_record(), + measurement_record(record), last_correlated_error_occurred(false) { } @@ -38,7 +38,7 @@ void TableauSimulator::measure(const OperationData &target_data) { auto q = qf & TARGET_VALUE_MASK; bool flipped = qf & TARGET_INVERTED_BIT; bool b = inv_state.zs.signs[q] ^ flipped; - measurement_record.push_back(b); + measurement_record.record_result(b); } } @@ -53,7 +53,7 @@ void TableauSimulator::measure_reset(const OperationData &target_data) { auto q = qf & TARGET_VALUE_MASK; bool flipped = qf & TARGET_INVERTED_BIT; bool b = inv_state.zs.signs[q] ^ flipped; - measurement_record.push_back(b); + measurement_record.record_result(b); inv_state.zs.signs[q] = false; } } @@ -141,11 +141,8 @@ void TableauSimulator::SQRT_Y_DAG(const OperationData &target_data) { } bool TableauSimulator::read_measurement_record(uint32_t encoded_target) const { - uint8_t b = encoded_target ^ TARGET_RECORD_BIT; - if (b == 0 || b > measurement_record.size()) { - throw std::out_of_range("Referred to a measurement record before the beginning of time."); - } - return measurement_record[measurement_record.size() - b]; + assert(encoded_target & TARGET_RECORD_BIT); + return measurement_record.lookback(encoded_target ^ TARGET_RECORD_BIT); } void TableauSimulator::single_cx(uint32_t c, uint32_t t) { @@ -382,15 +379,15 @@ void TableauSimulator::Z(const OperationData &target_data) 
{ } simd_bits TableauSimulator::sample_circuit(const Circuit &circuit, std::mt19937_64 &rng, int8_t sign_bias) { - TableauSimulator sim(circuit.num_qubits, rng, sign_bias); - for (const auto &op : circuit.operations) { + TableauSimulator sim(circuit.count_qubits(), rng, sign_bias); + circuit.for_each_operation([&](const Operation &op) { (sim.*op.gate->tableau_simulator_function)(op.target_data); - } + }); - assert(sim.measurement_record.size() == circuit.num_measurements); - simd_bits result(circuit.num_measurements); - for (size_t k = 0; k < circuit.num_measurements; k++) { - result[k] = sim.measurement_record[k]; + const std::vector &v = sim.measurement_record.storage; + simd_bits result(v.size()); + for (size_t k = 0; k < v.size(); k++) { + result[k] = v[k]; } return result; } @@ -402,29 +399,28 @@ void TableauSimulator::ensure_large_enough_for_qubits(size_t num_qubits) { inv_state.expand(num_qubits); } -void TableauSimulator::sample_stream(FILE *in, FILE *out, bool newline_after_measurements, std::mt19937_64 &rng) { - Circuit unprocessed; +void TableauSimulator::sample_stream(FILE *in, FILE *out, SampleFormat format, bool interactive, std::mt19937_64 &rng) { TableauSimulator sim(1, rng); - size_t reported = 0; - while (unprocessed.append_from_file(in, newline_after_measurements)) { - sim.ensure_large_enough_for_qubits(unprocessed.num_qubits); + auto writer = MeasureRecordWriter::make(out, format); + Circuit unprocessed; + while (true) { + unprocessed.clear(); + unprocessed.append_from_file(in, true); + if (unprocessed.operations.empty()) { + break; + } + sim.ensure_large_enough_for_qubits(unprocessed.count_qubits()); - for (const auto &op : unprocessed.operations) { + unprocessed.for_each_operation([&](const Operation &op) { (sim.*op.gate->tableau_simulator_function)(op.target_data); - while (reported < sim.measurement_record.size()) { - putc('0' + sim.measurement_record[reported++], out); - } - if (newline_after_measurements && (op.gate->flags & 
GATE_PRODUCES_RESULTS)) { + sim.measurement_record.write_unwritten_results_to(*writer); + if (interactive && (op.gate->flags & GATE_PRODUCES_RESULTS)) { putc('\n', out); fflush(out); } - } - - unprocessed.clear(); - } - if (!newline_after_measurements) { - putc('\n', out); + }); } + writer->write_end(); } VectorSimulator TableauSimulator::to_vector_sim() const { @@ -491,15 +487,21 @@ void TableauSimulator::collapse_qubit(size_t target, TableauTransposedRaii &tran } } -simd_bits TableauSimulator::reference_sample_circuit(const Circuit &circuit) { - Circuit filtered; - std::vector deterministic_operations{}; +Circuit aliased_noiseless_subset(const Circuit &circuit) { + // HACK: result has pointers into `circuit`! + Circuit result; for (const auto &op : circuit.operations) { if (!(op.gate->flags & GATE_IS_NOISE)) { - filtered.append_operation(op); + result.operations.push_back(op); } } + for (const auto &block : circuit.blocks) { + result.blocks.push_back(aliased_noiseless_subset(block)); + } + return result; +} +simd_bits TableauSimulator::reference_sample_circuit(const Circuit &circuit) { std::mt19937_64 irrelevant_rng(0); - return TableauSimulator::sample_circuit(filtered, irrelevant_rng, +1); + return TableauSimulator::sample_circuit(aliased_noiseless_subset(circuit), irrelevant_rng, +1); } diff --git a/src/simulators/tableau_simulator.h b/src/simulators/tableau_simulator.h index aa7ef113a..6f87f078e 100644 --- a/src/simulators/tableau_simulator.h +++ b/src/simulators/tableau_simulator.h @@ -27,27 +27,30 @@ #include "../circuit/circuit.h" #include "../stabilizers/tableau.h" #include "../stabilizers/tableau_transposed_raii.h" +#include "measure_record.h" #include "vector_simulator.h" struct TableauSimulator { Tableau inv_state; std::mt19937_64 &rng; int8_t sign_bias; - std::vector measurement_record; + MeasureRecord measurement_record; bool last_correlated_error_occurred; /// Args: /// num_qubits: The initial number of qubits in the simulator state. 
/// rng: The random number generator to use for random operations. /// sign_bias: 0 means collapse randomly, -1 means collapse towards True, +1 means collapse towards False. - explicit TableauSimulator(size_t num_qubits, std::mt19937_64 &rng, int8_t sign_bias = 0); + /// record: Measurement record configuration. + explicit TableauSimulator( + size_t num_qubits, std::mt19937_64 &rng, int8_t sign_bias = 0, MeasureRecord record = MeasureRecord()); /// Samples the given circuit in a deterministic fashion. /// /// Discards all noisy operations, and biases all collapse events towards +Z instead of randomly +Z/-Z. static simd_bits reference_sample_circuit(const Circuit &circuit); static simd_bits sample_circuit(const Circuit &circuit, std::mt19937_64 &rng, int8_t sign_bias = 0); - static void sample_stream(FILE *in, FILE *out, bool newline_after_measurements, std::mt19937_64 &rng); + static void sample_stream(FILE *in, FILE *out, SampleFormat format, bool interactive, std::mt19937_64 &rng); /// Expands the internal state of the simulator (if needed) to ensure the given qubit exists. 
/// diff --git a/src/simulators/tableau_simulator.perf.cc b/src/simulators/tableau_simulator.perf.cc index 54460fb1b..b2960a4bd 100644 --- a/src/simulators/tableau_simulator.perf.cc +++ b/src/simulators/tableau_simulator.perf.cc @@ -15,7 +15,6 @@ #include "tableau_simulator.h" #include "../benchmark_util.h" -#include "../circuit/gate_data.h" BENCHMARK(TableauSimulator_CX_10Kqubits) { size_t num_qubits = 10 * 1000; @@ -26,7 +25,7 @@ BENCHMARK(TableauSimulator_CX_10Kqubits) { for (size_t k = 0; k < num_qubits; k++) { targets.push_back(k); } - OperationData op_data{0, {&targets, 0, targets.size()}}; + OperationData op_data{0, targets}; benchmark_go([&]() { sim.ZCX(op_data); diff --git a/src/simulators/tableau_simulator.pybind.cc b/src/simulators/tableau_simulator.pybind.cc index ec033fb61..4c83deb98 100644 --- a/src/simulators/tableau_simulator.pybind.cc +++ b/src/simulators/tableau_simulator.pybind.cc @@ -22,10 +22,8 @@ struct TempViewableData { std::vector targets; TempViewableData(std::vector targets) : targets(std::move(targets)) { } - operator OperationData() const { - // Temporarily remove const correctness but then immediately restore it. - VectorView v{(std::vector *)&targets, 0, targets.size()}; - return {0, v}; + operator OperationData() { + return {0, targets}; } }; @@ -61,8 +59,7 @@ TempViewableData args_to_target_pairs(TableauSimulator &self, const pybind11::ar void pybind_tableau_simulator(pybind11::module &m) { pybind11::class_( - m, - "TableauSimulator", + m, "TableauSimulator", R"DOC( A quantum stabilizer circuit simulator whose internal state is an inverse stabilizer tableau. @@ -133,7 +130,7 @@ void pybind_tableau_simulator(pybind11::module &m) { .def( "current_measurement_record", [](TableauSimulator &self) { - return self.measurement_record; + return self.measurement_record.storage; }, R"DOC( Returns a copy of the record of all measurements performed by the simulator. 
@@ -160,10 +157,10 @@ void pybind_tableau_simulator(pybind11::module &m) { .def( "do", [](TableauSimulator &self, const Circuit &circuit) { - self.ensure_large_enough_for_qubits(circuit.num_qubits); - for (const auto &op : circuit.operations) { + self.ensure_large_enough_for_qubits(circuit.count_qubits()); + circuit.for_each_operation([&](const Operation &op) { (self.*op.gate->tableau_simulator_function)(op.target_data); - } + }); }, pybind11::arg("circuit"), R"DOC( @@ -460,7 +457,7 @@ void pybind_tableau_simulator(pybind11::module &m) { .def( "ycz", [](TableauSimulator &self, pybind11::args args) { - self.YCZ(args_to_target_pairs(self, args)); + self.YCZ(args_to_target_pairs(self, args)); }, R"DOC( Applies a Y-controlled Z gate to the simulator's state. @@ -485,7 +482,7 @@ void pybind_tableau_simulator(pybind11::module &m) { "measure", [](TableauSimulator &self, uint32_t target) { self.measure(TempViewableData({target})); - return (bool)self.measurement_record.back(); + return (bool)self.measurement_record.storage.back(); }, pybind11::arg("target"), R"DOC( @@ -508,7 +505,7 @@ void pybind_tableau_simulator(pybind11::module &m) { [](TableauSimulator &self, pybind11::args args) { auto converted_args = args_to_targets(self, args); self.measure(converted_args); - auto e = self.measurement_record.end(); + auto e = self.measurement_record.storage.end(); return std::vector(e - converted_args.targets.size(), e); }, R"DOC( diff --git a/src/simulators/tableau_simulator.test.cc b/src/simulators/tableau_simulator.test.cc index e5d8423b2..e0d417391 100644 --- a/src/simulators/tableau_simulator.test.cc +++ b/src/simulators/tableau_simulator.test.cc @@ -22,11 +22,11 @@ TEST(TableauSimulator, identity) { auto s = TableauSimulator(1, SHARED_TEST_RNG()); - ASSERT_EQ(s.measurement_record, (std::vector{})); + ASSERT_EQ(s.measurement_record.storage, (std::vector{})); s.measure(OpDat(0)); - ASSERT_EQ(s.measurement_record, (std::vector{false})); + ASSERT_EQ(s.measurement_record.storage, 
(std::vector{false})); s.measure(OpDat::flipped(0)); - ASSERT_EQ(s.measurement_record, (std::vector{false, true})); + ASSERT_EQ(s.measurement_record.storage, (std::vector{false, true})); } TEST(TableauSimulator, bit_flip) { @@ -38,15 +38,15 @@ TEST(TableauSimulator, bit_flip) { s.measure(OpDat(0)); s.X(OpDat(0)); s.measure(OpDat(0)); - ASSERT_EQ(s.measurement_record, (std::vector{true, false})); + ASSERT_EQ(s.measurement_record.storage, (std::vector{true, false})); } TEST(TableauSimulator, identity2) { auto s = TableauSimulator(2, SHARED_TEST_RNG()); s.measure(OpDat(0)); - ASSERT_EQ(s.measurement_record, (std::vector{false})); + ASSERT_EQ(s.measurement_record.storage, (std::vector{false})); s.measure(OpDat(1)); - ASSERT_EQ(s.measurement_record, (std::vector{false, false})); + ASSERT_EQ(s.measurement_record.storage, (std::vector{false, false})); } TEST(TableauSimulator, bit_flip_2) { @@ -56,9 +56,9 @@ TEST(TableauSimulator, bit_flip_2) { s.SQRT_Z(OpDat(0)); s.H_XZ(OpDat(0)); s.measure(OpDat(0)); - ASSERT_EQ(s.measurement_record, (std::vector{true})); + ASSERT_EQ(s.measurement_record.storage, (std::vector{true})); s.measure(OpDat(1)); - ASSERT_EQ(s.measurement_record, (std::vector{true, false})); + ASSERT_EQ(s.measurement_record.storage, (std::vector{true, false})); } TEST(TableauSimulator, epr) { @@ -71,7 +71,7 @@ TEST(TableauSimulator, epr) { ASSERT_EQ(s.is_deterministic(0), true); ASSERT_EQ(s.is_deterministic(1), true); s.measure(OpDat(1)); - ASSERT_EQ(s.measurement_record[0], s.measurement_record[1]); + ASSERT_EQ(s.measurement_record.storage[0], s.measurement_record.storage[1]); } TEST(TableauSimulator, big_determinism) { @@ -92,7 +92,7 @@ TEST(TableauSimulator, phase_kickback_consume_s_state) { s.ZCX(OpDat({0, 1})); ASSERT_EQ(s.is_deterministic(1), false); s.measure(OpDat(1)); - auto v1 = s.measurement_record.back(); + auto v1 = s.measurement_record.storage.back(); if (v1) { s.SQRT_Z(OpDat(0)); s.SQRT_Z(OpDat(0)); @@ -101,7 +101,7 @@ TEST(TableauSimulator, 
phase_kickback_consume_s_state) { s.H_XZ(OpDat(0)); ASSERT_EQ(s.is_deterministic(0), true); s.measure(OpDat(0)); - ASSERT_EQ(s.measurement_record.back(), true); + ASSERT_EQ(s.measurement_record.storage.back(), true); } } @@ -126,12 +126,12 @@ TEST(TableauSimulator, phase_kickback_preserve_s_state) { s.H_XZ(OpDat(0)); ASSERT_EQ(s.is_deterministic(0), true); s.measure(OpDat(0)); - ASSERT_EQ(s.measurement_record.back(), true); + ASSERT_EQ(s.measurement_record.storage.back(), true); s.SQRT_Z(OpDat(1)); s.H_XZ(OpDat(1)); ASSERT_EQ(s.is_deterministic(1), true); s.measure(OpDat(1)); - ASSERT_EQ(s.measurement_record.back(), true); + ASSERT_EQ(s.measurement_record.storage.back(), true); } TEST(TableauSimulator, kickback_vs_stabilizer) { @@ -179,7 +179,7 @@ TEST(TableauSimulator, s_state_distillation_low_depth) { sim.H_XZ(OpDat(anc)); ASSERT_EQ(sim.is_deterministic(anc), false); sim.measure(OpDat(anc)); - bool v = sim.measurement_record.back(); + bool v = sim.measurement_record.storage.back(); if (v) { sim.X(OpDat(anc)); } @@ -191,7 +191,7 @@ TEST(TableauSimulator, s_state_distillation_low_depth) { sim.SQRT_Z(OpDat(k)); sim.H_XZ(OpDat(k)); sim.measure(OpDat(k)); - qubit_measurements.push_back(sim.measurement_record.back()); + qubit_measurements.push_back(sim.measurement_record.storage.back()); } bool sum = false; @@ -209,7 +209,7 @@ TEST(TableauSimulator, s_state_distillation_low_depth) { sim.H_XZ(OpDat(7)); ASSERT_EQ(sim.is_deterministic(7), true); sim.measure(OpDat(7)); - ASSERT_EQ(sim.measurement_record.back(), false); + ASSERT_EQ(sim.measurement_record.storage.back(), false); for (const auto &c : checks) { bool r = false; @@ -255,7 +255,7 @@ TEST(TableauSimulator, s_state_distillation_low_space) { sim.H_XZ(OpDat(anc)); ASSERT_EQ(sim.is_deterministic(anc), false); sim.measure(OpDat(anc)); - bool v = sim.measurement_record.back(); + bool v = sim.measurement_record.storage.back(); if (v) { for (const auto &k : phasor) { sim.X(OpDat(k)); @@ -267,13 +267,13 @@ 
TEST(TableauSimulator, s_state_distillation_low_space) { for (size_t k = 0; k < 3; k++) { ASSERT_EQ(sim.is_deterministic(k), true); sim.measure(OpDat(k)); - ASSERT_EQ(sim.measurement_record.back(), false); + ASSERT_EQ(sim.measurement_record.storage.back(), false); } sim.SQRT_Z(OpDat(3)); sim.H_XZ(OpDat(3)); ASSERT_EQ(sim.is_deterministic(3), true); sim.measure(OpDat(3)); - ASSERT_EQ(sim.measurement_record.back(), true); + ASSERT_EQ(sim.measurement_record.storage.back(), true); } } @@ -355,16 +355,16 @@ TEST(TableauSimulator, to_vector_sim) { ASSERT_TRUE(sim_tab.to_vector_sim().approximate_equals(sim_vec, true)); } -bool vec_sim_corroborates_measurement_process( - const TableauSimulator &sim, const std::vector &measurement_targets) { - TableauSimulator sim_tab = sim; +bool vec_sim_corroborates_measurement_process(const Tableau &state, const std::vector &measurement_targets) { + TableauSimulator sim_tab(2, SHARED_TEST_RNG()); + sim_tab.inv_state = state; auto vec_sim = sim_tab.to_vector_sim(); sim_tab.measure(OpDat(measurement_targets)); PauliString buf(sim_tab.inv_state.num_qubits); size_t k = 0; for (auto t : measurement_targets) { buf.zs[t] = true; - buf.sign = sim_tab.measurement_record[k++]; + buf.sign = sim_tab.measurement_record.storage[k++]; float f = vec_sim.project(buf); if (fabs(f - 0.5) > 1e-4 && fabsf(f - 1) > 1e-4) { return false; @@ -376,26 +376,23 @@ bool vec_sim_corroborates_measurement_process( TEST(TableauSimulator, measurement_vs_vector_sim) { for (size_t k = 0; k < 10; k++) { - TableauSimulator sim_tab(2, SHARED_TEST_RNG()); - sim_tab.inv_state = Tableau::random(2, SHARED_TEST_RNG()); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(sim_tab, {0})); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(sim_tab, {1})); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(sim_tab, {0, 1})); + Tableau state = Tableau::random(2, SHARED_TEST_RNG()); + ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0})); + 
ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {1})); + ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 1})); } for (size_t k = 0; k < 10; k++) { - TableauSimulator sim_tab(4, SHARED_TEST_RNG()); - sim_tab.inv_state = Tableau::random(4, SHARED_TEST_RNG()); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(sim_tab, {0, 1})); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(sim_tab, {2, 1})); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(sim_tab, {0, 1, 2, 3})); + Tableau state = Tableau::random(4, SHARED_TEST_RNG()); + ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 1})); + ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {2, 1})); + ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 1, 2, 3})); } { - TableauSimulator sim_tab(12, SHARED_TEST_RNG()); - sim_tab.inv_state = Tableau::random(12, SHARED_TEST_RNG()); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(sim_tab, {0, 1, 2, 3})); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(sim_tab, {0, 10, 11})); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(sim_tab, {11, 5, 7})); - ASSERT_TRUE(vec_sim_corroborates_measurement_process(sim_tab, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})); + Tableau state = Tableau::random(12, SHARED_TEST_RNG()); + ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 1, 2, 3})); + ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 10, 11})); + ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {11, 5, 7})); + ASSERT_TRUE(vec_sim_corroborates_measurement_process(state, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11})); } } diff --git a/src/stabilizers/pauli_string.pybind.cc b/src/stabilizers/pauli_string.pybind.cc index 9b80adadc..fa0865935 100644 --- a/src/stabilizers/pauli_string.pybind.cc +++ b/src/stabilizers/pauli_string.pybind.cc @@ -330,8 +330,7 @@ void pybind_pauli_string(pybind11::module &m) { int z = self.zs[u]; return (x ^ z) | (z << 1); }, - 
pybind11::arg("index"), - GET_ITEM_DOC) + pybind11::arg("index"), GET_ITEM_DOC) .def( "__getitem__", [](const PauliString &self, pybind11::slice slice) { @@ -347,11 +346,9 @@ void pybind_pauli_string(pybind11::module &m) { return "_XZY"[self.xs[j] + self.zs[j] * 2]; }); }, - pybind11::arg("slice"), - GET_ITEM_DOC) + pybind11::arg("slice"), GET_ITEM_DOC) .def( - pybind11::init(&PauliString::from_str), - pybind11::arg("text"), + pybind11::init(&PauliString::from_str), pybind11::arg("text"), R"DOC( Creates a stim.PauliString from a text string. diff --git a/src/stabilizers/tableau.perf.cc b/src/stabilizers/tableau.perf.cc index 2eaf9a250..571bfa7b7 100644 --- a/src/stabilizers/tableau.perf.cc +++ b/src/stabilizers/tableau.perf.cc @@ -57,5 +57,5 @@ BENCHMARK(tableau_cnot_10Kqubits) { Tableau t(n); benchmark_go([&]() { t.prepend_ZCX(5, 20); - }).goal_nanos(120); + }).goal_nanos(220); } diff --git a/src/stabilizers/tableau.test.cc b/src/stabilizers/tableau.test.cc index 5042192af..eab6e6f48 100644 --- a/src/stabilizers/tableau.test.cc +++ b/src/stabilizers/tableau.test.cc @@ -702,10 +702,10 @@ TEST(tableau, raised_to) { Tableau s = GATE_DATA.at("S").tableau(); Tableau z = GATE_DATA.at("Z").tableau(); Tableau s_dag = GATE_DATA.at("S_DAG").tableau(); - ASSERT_EQ(s.raised_to(4*-437829 + 0), Tableau(1)); - ASSERT_EQ(s.raised_to(4*-437829 + 1), s); - ASSERT_EQ(s.raised_to(4*-437829 + 2), z); - ASSERT_EQ(s.raised_to(4*-437829 + 3), s_dag); + ASSERT_EQ(s.raised_to(4 * -437829 + 0), Tableau(1)); + ASSERT_EQ(s.raised_to(4 * -437829 + 1), s); + ASSERT_EQ(s.raised_to(4 * -437829 + 2), z); + ASSERT_EQ(s.raised_to(4 * -437829 + 3), s_dag); ASSERT_EQ(s.raised_to(-5), s_dag); ASSERT_EQ(s.raised_to(-4), Tableau(1)); ASSERT_EQ(s.raised_to(-3), s); @@ -719,10 +719,10 @@ TEST(tableau, raised_to) { ASSERT_EQ(s.raised_to(5), s); ASSERT_EQ(s.raised_to(6), z); ASSERT_EQ(s.raised_to(7), s_dag); - ASSERT_EQ(s.raised_to(4*437829 + 0), Tableau(1)); - ASSERT_EQ(s.raised_to(4*437829 + 1), s); - 
ASSERT_EQ(s.raised_to(4*437829 + 2), z); - ASSERT_EQ(s.raised_to(4*437829 + 3), s_dag); + ASSERT_EQ(s.raised_to(4 * 437829 + 0), Tableau(1)); + ASSERT_EQ(s.raised_to(4 * 437829 + 1), s); + ASSERT_EQ(s.raised_to(4 * 437829 + 2), z); + ASSERT_EQ(s.raised_to(4 * 437829 + 3), s_dag); Tableau p15(3); p15.inplace_scatter_append(GATE_DATA.at("SQRT_X").tableau(), {0}); @@ -733,7 +733,7 @@ TEST(tableau, raised_to) { ASSERT_NE(p15.raised_to(k), Tableau(3)); } ASSERT_EQ(p15.raised_to(15), Tableau(3)); - ASSERT_EQ(p15.raised_to(15*47321 + 4), p15.raised_to(4)); - ASSERT_EQ(p15.raised_to(15*47321 + 1), p15); - ASSERT_EQ(p15.raised_to(15*-47321 + 1), p15); + ASSERT_EQ(p15.raised_to(15 * 47321 + 4), p15.raised_to(4)); + ASSERT_EQ(p15.raised_to(15 * 47321 + 1), p15); + ASSERT_EQ(p15.raised_to(15 * -47321 + 1), p15); }