Skip to content

Commit

Permalink
Add integration tests for time dependent pulse templates
Browse files Browse the repository at this point in the history
  • Loading branch information
terrorfisch committed Dec 5, 2022
1 parent 775d459 commit c0d2a54
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 20 deletions.
17 changes: 16 additions & 1 deletion tests/pulses/arithmetic_pulse_template_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from qupulse.parameter_scope import DictScope
from qupulse.expressions import ExpressionScalar
from qupulse.pulses import MappingPT
from qupulse.pulses import MappingPT, ConstantPT, RepetitionPT
from qupulse.pulses.plotting import render
from qupulse.pulses.arithmetic_pulse_template import ArithmeticAtomicPulseTemplate, ArithmeticPulseTemplate,\
ImplicitAtomicityInArithmeticPT, UnequalDurationWarningInArithmeticPT, try_operation
from qupulse._program.waveforms import TransformingWaveform
Expand Down Expand Up @@ -585,6 +586,20 @@ def test_repr(self):
arith = ArithmeticPulseTemplate(pt, '-', scalar, identifier='id')
self.assertEqual(super(ArithmeticPulseTemplate, arith).__repr__(), repr(arith))

def test_time_dependence(self):
inner = ConstantPT(1.4, {'a': ExpressionScalar('x'), 'b': 1.1})
with self.assertRaises(TypeError):
ArithmeticPulseTemplate(RepetitionPT(inner, 3), '*', {'a': 'sin(t)', 'b': 'cos(t)'})

pc = ArithmeticPulseTemplate(inner, '*', {'a': 'sin(t)', 'b': 'cos(t)'})
prog = pc.create_program(parameters={'x': -1})
t, vals, _ = render(prog, sample_rate=10)
expected_values = {
'a': -np.sin(t),
'b': 1.1 * np.cos(t)
}
np.testing.assert_equal(expected_values, vals)


class ArithmeticUsageTests(unittest.TestCase):
def setUp(self) -> None:
Expand Down
55 changes: 36 additions & 19 deletions tests/pulses/multi_channel_pulse_template_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from unittest import mock

import numpy
import numpy as np

from qupulse.parameter_scope import DictScope
from qupulse.pulses import RepetitionPT
from qupulse.pulses import RepetitionPT, ConstantPT
from qupulse.pulses.plotting import render
from qupulse.pulses.multi_channel_pulse_template import MultiChannelWaveform, MappingPulseTemplate,\
ChannelMappingException, AtomicMultiChannelPulseTemplate, ParallelConstantChannelPulseTemplate,\
ChannelMappingException, AtomicMultiChannelPulseTemplate, ParallelChannelPulseTemplate,\
TransformingWaveform, ParallelChannelTransformation
from qupulse.pulses.parameters import ParameterConstraint, ParameterConstraintViolation, ConstantParameter
from qupulse.expressions import ExpressionScalar, Expression
Expand Down Expand Up @@ -343,14 +345,14 @@ def serialize_callback(obj) -> str:
self.assertEqual(expected_data, data)


class ParallelConstantChannelPulseTemplateTests(unittest.TestCase):
class ParallelChannelPulseTemplateTests(unittest.TestCase):
def test_init(self):
template = DummyPulseTemplate(duration='t1', defined_channels={'X', 'Y'}, parameter_names={'a', 'b'}, measurement_names={'M'})
overwritten_channels = {'Y': 'c', 'Z': 'a'}

expected_overwritten_channels = {'Y': ExpressionScalar('c'), 'Z': ExpressionScalar('a')}

pccpt = ParallelConstantChannelPulseTemplate(template, overwritten_channels)
pccpt = ParallelChannelPulseTemplate(template, overwritten_channels)
self.assertIs(template, pccpt.template)
self.assertEqual(expected_overwritten_channels, pccpt.overwritten_channels)

Expand All @@ -361,15 +363,15 @@ def test_init(self):
self.assertIs(template.duration, pccpt.duration)

non_atomic_pt = RepetitionPT(template, 5)
ParallelConstantChannelPulseTemplate(non_atomic_pt, overwritten_channels)
ParallelChannelPulseTemplate(non_atomic_pt, overwritten_channels)
with self.assertRaises(TypeError):
overwritten_channels['T'] = 'a * t'
ParallelConstantChannelPulseTemplate(non_atomic_pt, overwritten_channels)
ParallelChannelPulseTemplate(non_atomic_pt, overwritten_channels)

ParallelConstantChannelPulseTemplate(template, overwritten_channels)
ParallelChannelPulseTemplate(template, overwritten_channels)

def test_missing_implementations(self):
pccpt = ParallelConstantChannelPulseTemplate(DummyPulseTemplate(), {})
pccpt = ParallelChannelPulseTemplate(DummyPulseTemplate(), {})
with self.assertRaises(NotImplementedError):
pccpt.get_serialization_data(object())

Expand All @@ -378,7 +380,7 @@ def test_integral(self):
measurement_names={'M'},
integrals={'X': ExpressionScalar('a'), 'Y': ExpressionScalar(4)})
overwritten_channels = {'Y': 'c', 'Z': 'a'}
pccpt = ParallelConstantChannelPulseTemplate(template, overwritten_channels)
pccpt = ParallelChannelPulseTemplate(template, overwritten_channels)

expected_integral = {'X': ExpressionScalar('a'),
'Y': ExpressionScalar('c*t1'),
Expand All @@ -387,12 +389,12 @@ def test_integral(self):

def test_initial_values(self):
dpt = DummyPulseTemplate(initial_values={'A': 'a', 'B': 'b'})
par = ParallelConstantChannelPulseTemplate(dpt, {'B': 'b2', 'C': 'c'})
par = ParallelChannelPulseTemplate(dpt, {'B': 'b2', 'C': 'c'})
self.assertEqual({'A': 'a', 'B': 'b2', 'C': 'c'}, par.initial_values)

def test_final_values(self):
dpt = DummyPulseTemplate(final_values={'A': 'a', 'B': 'b'})
par = ParallelConstantChannelPulseTemplate(dpt, {'B': 'b2', 'C': 'c'})
par = ParallelChannelPulseTemplate(dpt, {'B': 'b2', 'C': 'c'})
self.assertEqual({'A': 'a', 'B': 'b2', 'C': 'c'}, par.final_values)

def test_get_overwritten_channels_values(self):
Expand All @@ -402,7 +404,7 @@ def test_get_overwritten_channels_values(self):
channel_mapping = {'X': 'X', 'Y': 'K', 'Z': 'Z', 'ToNone': None}
expected_overwritten_channel_values = {'K': 1.2, 'Z': 3.4}

pccpt = ParallelConstantChannelPulseTemplate(template, overwritten_channels)
pccpt = ParallelChannelPulseTemplate(template, overwritten_channels)

real_parameters = {'c': 1.2, 'a': 3.4}
self.assertEqual(expected_overwritten_channel_values, pccpt._get_overwritten_channels_values(real_parameters,
Expand All @@ -422,7 +424,7 @@ def test_internal_create_program(self):
channel_mapping=channel_mapping,
to_single_waveform=to_single_waveform,
parent_loop=parent_loop)
pccpt = ParallelConstantChannelPulseTemplate(template, overwritten_channels)
pccpt = ParallelChannelPulseTemplate(template, overwritten_channels)

scope = DictScope.from_kwargs(c=1.2, a=3.4)
kwargs = {**other_kwargs, 'scope': scope, 'global_transformation': None}
Expand All @@ -449,7 +451,7 @@ def test_build_waveform(self):
measurement_names={'M'}, waveform=DummyWaveform())
overwritten_channels = {'Y': 'c', 'Z': 'a'}
channel_mapping = {'X': 'X', 'Y': 'K', 'Z': 'Z'}
pccpt = ParallelConstantChannelPulseTemplate(template, overwritten_channels)
pccpt = ParallelChannelPulseTemplate(template, overwritten_channels)

parameters = {'c': 1.2, 'a': 3.4}
expected_overwritten_channels = {'K': 1.2, 'Z': 3.4}
Expand All @@ -465,12 +467,27 @@ def test_build_waveform(self):
resulting_waveform = pccpt.build_waveform(parameters.copy(), channel_mapping.copy())
self.assertEqual(None, resulting_waveform)
self.assertEqual([(parameters, channel_mapping), (parameters, channel_mapping)], template.build_waveform_calls)

def test_time_dependence(self):
inner = ConstantPT(1.4, {'a': ExpressionScalar('x'), 'b': 1.})
with self.assertRaises(TypeError):
ParallelChannelPulseTemplate(RepetitionPT(inner, 3), {'c': 'sin(t)'})

pc = ParallelChannelPulseTemplate(inner, {'c': 'sin(t)'})
prog = pc.create_program(parameters={'x': -1})
t, vals, _ = render(prog, sample_rate=10)
expected_values = {
'a': np.broadcast_to(-1, t.shape),
'b': np.broadcast_to(1., t.shape),
'c': np.sin(t)
}
np.testing.assert_equal(expected_values, vals)


class ParallelConstantChannelPulseTemplateSerializationTests(SerializableTests, unittest.TestCase):
class ParallelChannelPulseTemplateSerializationTests(SerializableTests, unittest.TestCase):
@property
def class_to_test(self):
return ParallelConstantChannelPulseTemplate
return ParallelChannelPulseTemplate

@staticmethod
def make_kwargs(*args, **kwargs):
Expand All @@ -479,9 +496,9 @@ def make_kwargs(*args, **kwargs):
'overwritten_channels': {'Y': 'c', 'Z': 'a'}
}

def assert_equal_instance_except_id(self, lhs: ParallelConstantChannelPulseTemplate, rhs: ParallelConstantChannelPulseTemplate):
self.assertIsInstance(lhs, ParallelConstantChannelPulseTemplate)
self.assertIsInstance(rhs, ParallelConstantChannelPulseTemplate)
def assert_equal_instance_except_id(self, lhs: ParallelChannelPulseTemplate, rhs: ParallelChannelPulseTemplate):
self.assertIsInstance(lhs, ParallelChannelPulseTemplate)
self.assertIsInstance(rhs, ParallelChannelPulseTemplate)
self.assertEqual(lhs.template, rhs.template)
self.assertEqual(lhs.overwritten_channels, rhs.overwritten_channels)

Expand Down

0 comments on commit c0d2a54

Please sign in to comment.