Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow symbolic scalars in LinearDict #7003

Merged
merged 11 commits into from
Feb 4, 2025
Prev Previous commit
Next Next commit
Revert changes to Scalar, and just use TParamValComplex everywhere. A…
…dd tests.
  • Loading branch information
daxfohl committed Jan 30, 2025
commit cbc54476cf8a7618c5885b4da36167876d0b635e
48 changes: 27 additions & 21 deletions cirq-core/cirq/value/linear_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@
)
from typing_extensions import Self

import numpy as np
import sympy
from cirq.value import type_alias
from cirq import protocols

if TYPE_CHECKING:
import cirq

Check warning on line 44 in cirq-core/cirq/value/linear_dict.py

View check run for this annotation

Codecov / codecov/patch

cirq-core/cirq/value/linear_dict.py#L44

Added line #L44 was not covered by tests

Scalar = type_alias.TParamValComplex
Scalar = Union[complex, np.number]
TVector = TypeVar('TVector')

TDefault = TypeVar('TDefault')
Expand All @@ -62,7 +62,7 @@
return super()._print(expr, **kwargs)


def _format_coefficient(format_spec: str, coefficient: Scalar) -> str:
def _format_coefficient(format_spec: str, coefficient: 'cirq.TParamValComplex') -> str:
if isinstance(coefficient, sympy.Basic):
printer = _SympyPrinter(format_spec)
return printer.doprint(coefficient)
Expand All @@ -82,7 +82,7 @@
return f'({real_str}+{imag_str}j)'


def _format_term(format_spec: str, vector: TVector, coefficient: Scalar) -> str:
def _format_term(format_spec: str, vector: TVector, coefficient: 'cirq.TParamValComplex') -> str:
coefficient_str = _format_coefficient(format_spec, coefficient)
if not coefficient_str:
return coefficient_str
Expand All @@ -92,7 +92,7 @@
return '+' + result


def _format_terms(terms: Iterable[Tuple[TVector, Scalar]], format_spec: str):
def _format_terms(terms: Iterable[Tuple[TVector, 'cirq.TParamValComplex']], format_spec: str):
formatted_terms = [_format_term(format_spec, vector, coeff) for vector, coeff in terms]
s = ''.join(formatted_terms)
if not s:
Expand All @@ -102,7 +102,7 @@
return s


class LinearDict(Generic[TVector], MutableMapping[TVector, Scalar]):
class LinearDict(Generic[TVector], MutableMapping[TVector, 'cirq.TParamValComplex']):
"""Represents linear combination of things.

LinearDict implements the basic linear algebraic operations of vector
Expand All @@ -119,7 +119,7 @@

def __init__(
self,
terms: Optional[Mapping[TVector, Scalar]] = None,
terms: Optional[Mapping[TVector, 'cirq.TParamValComplex']] = None,
validator: Optional[Callable[[TVector], bool]] = None,
) -> None:
"""Initializes linear combination from a collection of terms.
Expand All @@ -135,7 +135,7 @@
"""
self._has_validator = validator is not None
self._is_valid = validator or (lambda x: True)
self._terms: Dict[TVector, Scalar] = {}
self._terms: Dict[TVector, 'cirq.TParamValComplex'] = {}
if terms is not None:
self.update(terms)

Expand Down Expand Up @@ -171,25 +171,31 @@
snapshot = self.copy().clean(atol=0)
return snapshot._terms.keys()

def values(self) -> ValuesView[Scalar]:
def values(self) -> ValuesView['cirq.TParamValComplex']:
snapshot = self.copy().clean(atol=0)
return snapshot._terms.values()

def items(self) -> ItemsView[TVector, Scalar]:
def items(self) -> ItemsView[TVector, 'cirq.TParamValComplex']:
snapshot = self.copy().clean(atol=0)
return snapshot._terms.items()

# pylint: disable=function-redefined
@overload
def update(self, other: Mapping[TVector, Scalar], **kwargs: Scalar) -> None:
def update(
self, other: Mapping[TVector, 'cirq.TParamValComplex'], **kwargs: 'cirq.TParamValComplex'
) -> None:
pass

@overload
def update(self, other: Iterable[Tuple[TVector, Scalar]], **kwargs: Scalar) -> None:
def update(
self,
other: Iterable[Tuple[TVector, 'cirq.TParamValComplex']],
**kwargs: 'cirq.TParamValComplex',
) -> None:
pass

@overload
def update(self, *args: Any, **kwargs: Scalar) -> None:
def update(self, *args: Any, **kwargs: 'cirq.TParamValComplex') -> None:
pass

def update(self, *args, **kwargs):
Expand All @@ -204,11 +210,11 @@
self.clean(atol=0)

@overload
def get(self, vector: TVector) -> Scalar:
def get(self, vector: TVector) -> 'cirq.TParamValComplex':
pass

@overload
def get(self, vector: TVector, default: TDefault) -> Union[Scalar, TDefault]:
def get(self, vector: TVector, default: TDefault) -> Union['cirq.TParamValComplex', TDefault]:
pass

def get(self, vector, default=0):
Expand All @@ -221,10 +227,10 @@
def __contains__(self, vector: Any) -> bool:
return vector in self._terms and self._terms[vector] != 0

def __getitem__(self, vector: TVector) -> Scalar:
def __getitem__(self, vector: TVector) -> 'cirq.TParamValComplex':
return self._terms.get(vector, 0)

def __setitem__(self, vector: TVector, coefficient: Scalar) -> None:
def __setitem__(self, vector: TVector, coefficient: 'cirq.TParamValComplex') -> None:
self._check_vector_valid(vector)
if coefficient != 0:
self._terms[vector] = coefficient
Expand Down Expand Up @@ -272,21 +278,21 @@
factory = type(self)
return factory({v: -c for v, c in self.items()})

def __imul__(self, a: Scalar) -> Self:
def __imul__(self, a: 'cirq.TParamValComplex') -> Self:
for vector in self:
self._terms[vector] *= a
self.clean(atol=0)
return self

def __mul__(self, a: Scalar) -> Self:
def __mul__(self, a: 'cirq.TParamValComplex') -> Self:
result = self.copy()
result *= a
return result.copy()

def __rmul__(self, a: Scalar) -> Self:
def __rmul__(self, a: 'cirq.TParamValComplex') -> Self:
return self.__mul__(a)

def __truediv__(self, a: Scalar) -> Self:
def __truediv__(self, a: 'cirq.TParamValComplex') -> Self:
return self.__mul__(1 / a)

def __bool__(self) -> bool:
Expand Down
31 changes: 31 additions & 0 deletions cirq-core/cirq/value/linear_dict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,3 +607,34 @@ def test_repr_pretty(terms):
def test_json_fails_with_validator():
with pytest.raises(ValueError, match='not json serializable'):
_ = cirq.to_json(cirq.LinearDict({}, validator=lambda: True))


@pytest.mark.parametrize(
'terms, names',
(
({'X': sym}, {'sym'}),
({'X': sym * sympy.Symbol('a')}, {'sym', 'a'}),
({'X': expr}, {'sym'}),
({'X': sym, 'Y': sympy.Symbol('a')}, {'sym', 'a'}),
({'X': symval}, set()),
),
)
def test_parameter_names(terms, names):
linear_dict = cirq.LinearDict(terms)
assert cirq.parameter_names(linear_dict) == names


@pytest.mark.parametrize(
'terms, expected',
(
({'X': sym}, {'X': 2}),
({'X': sym * sympy.Symbol('a')}, {'X': 6}),
({'X': expr}, {'X': -4 - 6j}),
({'X': sym, 'Y': sympy.Symbol('a')}, {'X': 2, 'Y': 3}),
({'X': symval}, {'X': symvalresolved}),
),
)
def test_resolve_parameters(terms, expected):
linear_dict = cirq.LinearDict(terms)
expected_dict = cirq.LinearDict(expected)
assert cirq.resolve_parameters(linear_dict, {'sym': 2, 'a': 3}) == expected_dict