Skip to content

Commit

Permalink
add request limiter when polling for results (#6774)
Browse files Browse the repository at this point in the history
* add limiter

* throttle in-flight requests

* review comments

* typecheck

* fix

* review comments

* lint
  • Loading branch information
senecameeks authored Oct 21, 2024
1 parent 351a08e commit 81f66b9
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 4 deletions.
50 changes: 46 additions & 4 deletions cirq-core/cirq/work/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,24 @@
"""Abstract base class for things sampling quantum circuits."""

import collections
from typing import Dict, FrozenSet, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from itertools import islice
from typing import (
Dict,
FrozenSet,
Iterator,
List,
Optional,
Sequence,
Tuple,
TypeVar,
TYPE_CHECKING,
Union,
)

import duet
import pandas as pd


from cirq import ops, protocols, study, value
from cirq.work.observable_measurement import (
measure_observables,
Expand All @@ -30,10 +43,17 @@
if TYPE_CHECKING:
import cirq

T = TypeVar('T')


class Sampler(metaclass=value.ABCMetaImplementAnyOneOf):
"""Something capable of sampling quantum circuits. Simulator or hardware."""

# Users have a rate limit of 1000 QPM for read/write requests to
# the Quantum Engine. 1000/60 ~= 16 QPS. So requests are sent
# in chunks of size 16 per second.
CHUNK_SIZE: int = 16

def run(
self,
program: 'cirq.AbstractCircuit',
Expand Down Expand Up @@ -294,9 +314,26 @@ async def run_batch_async(
See docs for `cirq.Sampler.run_batch`.
"""
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
return await duet.pstarmap_async(
self.run_sweep_async, zip(programs, params_list, repetitions)
)
if len(programs) <= self.CHUNK_SIZE:
return await duet.pstarmap_async(
self.run_sweep_async, zip(programs, params_list, repetitions)
)

results = []
for program_chunk, params_chunk, reps_chunk in zip(
_chunked(programs, self.CHUNK_SIZE),
_chunked(params_list, self.CHUNK_SIZE),
_chunked(repetitions, self.CHUNK_SIZE),
):
# Run_sweep_async for the current chunk
await duet.sleep(1) # Delay for 1 second between chunk
results.extend(
await duet.pstarmap_async(
self.run_sweep_async, zip(program_chunk, params_chunk, reps_chunk)
)
)

return results

def _normalize_batch_args(
self,
Expand Down Expand Up @@ -449,3 +486,8 @@ def _get_measurement_shapes(
)
num_instances[key] += 1
return {k: (num_instances[k], qid_shape) for k, qid_shape in qid_shapes.items()}


def _chunked(iterable: Sequence[T], n: int) -> Iterator[tuple[T, ...]]:
it = iter(iterable)
return iter(lambda: tuple(islice(it, n)), ())
50 changes: 50 additions & 0 deletions cirq-core/cirq/work/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for cirq.Sampler."""
from typing import Sequence
from unittest import mock

import pytest

Expand Down Expand Up @@ -266,6 +267,55 @@ def test_sampler_run_batch_bad_input_lengths():
)


@mock.patch('duet.pstarmap_async')
@pytest.mark.parametrize('call_count', [1, 2, 3])
@duet.sync
async def test_run_batch_async_sends_circuits_in_chunks(spy, call_count):
class AsyncSampler(cirq.Sampler):
CHUNK_SIZE = 3

async def run_sweep_async(self, _, params, __: int = 1):
pass # pragma: no cover

sampler = AsyncSampler()
a = cirq.LineQubit(0)
circuit_list = [cirq.Circuit(cirq.X(a) ** sympy.Symbol('t'), cirq.measure(a, key='m'))] * (
sampler.CHUNK_SIZE * call_count
)
param_list = [cirq.Points('t', [0.3, 0.7])] * (sampler.CHUNK_SIZE * call_count)

await sampler.run_batch_async(circuit_list, params_list=param_list)

assert spy.call_count == call_count


@pytest.mark.parametrize('call_count', [1, 2, 3])
@duet.sync
async def test_run_batch_async_runs_runs_sequentially(call_count):
a = cirq.LineQubit(0)
finished = []
circuit1 = cirq.Circuit(cirq.X(a) ** sympy.Symbol('t'), cirq.measure(a, key='m'))
circuit2 = cirq.Circuit(cirq.Y(a) ** sympy.Symbol('t'), cirq.measure(a, key='m'))
params1 = cirq.Points('t', [0.3, 0.7])
params2 = cirq.Points('t', [0.4, 0.6])

class AsyncSampler(cirq.Sampler):
CHUNK_SIZE = 1

async def run_sweep_async(self, _, params, __: int = 1):
if params == params1:
await duet.sleep(0.001)

finished.append(params)

sampler = AsyncSampler()
circuit_list = [circuit1, circuit2] * call_count
param_list = [params1, params2] * call_count
await sampler.run_batch_async(circuit_list, params_list=param_list)

assert finished == param_list


def test_sampler_simple_sample_expectation_values():
a = cirq.LineQubit(0)
sampler = cirq.Simulator()
Expand Down

0 comments on commit 81f66b9

Please sign in to comment.