Skip to content

Commit

Permalink
Divert cirq.Sampler's run_batch to run_batch_async (#6813)
Browse files Browse the repository at this point in the history
* run_batch diverts to run_batch_async

* fix typecheck
  • Loading branch information
senecameeks authored Dec 2, 2024
1 parent 429bb9b commit 495d913
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 20 deletions.
22 changes: 4 additions & 18 deletions cirq-core/cirq/work/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,13 @@ async def run_sweep_async(
"""
raise NotImplementedError

def run_batch(
async def run_batch_async(
self,
programs: Sequence['cirq.AbstractCircuit'],
params_list: Optional[Sequence['cirq.Sweepable']] = None,
repetitions: Union[int, Sequence[int]] = 1,
) -> Sequence[Sequence['cirq.Result']]:
"""Runs the supplied circuits.
"""Runs the supplied circuits asynchronously.
Each circuit provided in `programs` will pair with the optional
associated parameter sweep provided in the `params_list`, and be run
Expand Down Expand Up @@ -281,26 +281,12 @@ def run_batch(
of `params_list` or the length of `repetitions`.
"""
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
return [
self.run_sweep(circuit, params=params, repetitions=repetitions)
for circuit, params, repetitions in zip(programs, params_list, repetitions)
]

async def run_batch_async(
self,
programs: Sequence['cirq.AbstractCircuit'],
params_list: Optional[Sequence['cirq.Sweepable']] = None,
repetitions: Union[int, Sequence[int]] = 1,
) -> Sequence[Sequence['cirq.Result']]:
"""Runs the supplied circuits asynchronously.
See docs for `cirq.Sampler.run_batch`.
"""
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
return await duet.pstarmap_async(
self.run_sweep_async, zip(programs, params_list, repetitions)
)

run_batch = duet.sync(run_batch_async)

def _normalize_batch_args(
self,
programs: Sequence['cirq.AbstractCircuit'],
Expand Down
7 changes: 5 additions & 2 deletions cirq-google/cirq_google/engine/validating_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Callable, Optional, Sequence, Union

import cirq
import duet

VALIDATOR_TYPE = Callable[
[Sequence[cirq.AbstractCircuit], Sequence[cirq.Sweepable], Union[int, Sequence[int]]], None
Expand Down Expand Up @@ -64,12 +65,14 @@ def run_sweep(
self._validate_circuit([program], [params], repetitions)
return self._sampler.run_sweep(program, params, repetitions)

def run_batch(
async def run_batch_async(
self,
programs: Sequence[cirq.AbstractCircuit],
params_list: Optional[Sequence[cirq.Sweepable]] = None,
repetitions: Union[int, Sequence[int]] = 1,
) -> Sequence[Sequence[cirq.Result]]:
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
self._validate_circuit(programs, params_list, repetitions)
return self._sampler.run_batch(programs, params_list, repetitions)
return await self._sampler.run_batch_async(programs, params_list, repetitions)

run_batch = duet.sync(run_batch_async)

0 comments on commit 495d913

Please sign in to comment.