Skip to content

Commit 1b65798

Browse files
committed
Fixed device signatures, updated tests.
1 parent 8224add commit 1b65798

6 files changed

+89
-50
lines changed

c3/generator/devices.py

+64-32
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import tempfile
33
import warnings
44
import hjson
5-
from typing import Callable, Dict, Any
5+
from typing import Callable, Dict, Any, List
66
import tensorflow as tf
77
import numpy as np
88
from c3.signal.pulse import Envelope, Carrier
@@ -26,7 +26,7 @@ class Device(C3obj):
2626
2727
Parameters
2828
----------
29-
resolution: np.float64
29+
resolution: float
3030
Number of samples per second this device operates at.
3131
"""
3232

@@ -69,32 +69,32 @@ def asdict(self) -> Dict[str, Any]:
6969
def __str__(self) -> str:
7070
return hjson.dumps(self.asdict(), default=hjson_encode)
7171

72-
def calc_slice_num(self, t_start: np.float64, t_end: np.float64) -> None:
72+
def calc_slice_num(self, t_start: float = 0.0, t_end: float = 0.0) -> None:
7373
"""
7474
Effective number of time slices given start, end and resolution.
7575
7676
Parameters
7777
----------
78-
t_start: np.float64
78+
t_start: float
7979
Starting time for this device.
80-
t_end: np.float64
80+
t_end: float
8181
End time for this device.
8282
"""
8383
res = self.resolution
8484
self.slice_num = int(np.abs(t_start - t_end) * res)
8585
# return self.slice_num
8686

8787
def create_ts(
88-
self, t_start: np.float64, t_end: np.float64, centered: bool = True
88+
self, t_start: float = 0, t_end: float = 0, centered: bool = True
8989
) -> tf.constant:
9090
"""
9191
Compute time samples.
9292
9393
Parameters
9494
----------
95-
t_start: np.float64
95+
t_start: float
9696
Starting time for this device.
97-
t_end: np.float64
97+
t_end: float
9898
End time for this device.
9999
centered: boolean
100100
Sample in the middle of an interval, otherwise at the beginning.
@@ -121,6 +121,31 @@ def create_ts(
121121
ts = tf.linspace(t_start, t_end, num)
122122
return ts
123123

124+
def process(
125+
self, instr: Instruction, chan: str, signals: List[Dict[str, Any]]
126+
) -> Dict[str, Any]:
127+
"""To be implemented by inheriting class.
128+
129+
Parameters
130+
----------
131+
instr : Instruction
132+
Information about the instruction or gate to be excecuted.
133+
chan : str
134+
Identifier of the drive line
135+
signals : List[Dict[str, Any]]
136+
List of potentially multiple input signals to this device.
137+
138+
Returns
139+
-------
140+
Dict[str, Any]
141+
Output signal of this device.
142+
143+
Raises
144+
------
145+
NotImplementedError
146+
"""
147+
raise NotImplementedError()
148+
124149

125150
@dev_reg_deco
126151
class Readout(Device):
@@ -176,7 +201,7 @@ def __init__(self, **props):
176201
self.outputs = props.pop("outputs", 1)
177202

178203
def process(
179-
self, instr: Instruction, chan: str, mixed_signal: Dict[str, Any]
204+
self, instr: Instruction, chan: str, mixed_signal: List[Dict[str, Any]]
180205
) -> Dict[str, Any]:
181206
"""Transform signal from value of V to Hz.
182207
@@ -191,8 +216,8 @@ def process(
191216
Waveform as control amplitudes
192217
"""
193218
v2hz = self.params["V_to_Hz"].get_value()
194-
self.signal["values"] = mixed_signal["values"] * v2hz
195-
self.signal["ts"] = mixed_signal["ts"]
219+
self.signal["values"] = mixed_signal[0]["values"] * v2hz
220+
self.signal["ts"] = mixed_signal[0]["ts"]
196221
return self.signal
197222

198223

@@ -202,7 +227,7 @@ class Crosstalk(Device):
202227
Device to phenomenologically include crosstalk in the model by explicitly mixing
203228
drive lines.
204229
205-
Parameters
230+
Parameters^
206231
----------
207232
208233
crosstalk_matrix: tf.constant
@@ -234,7 +259,9 @@ def __init__(self, **props):
234259
self.outputs = props.pop("outputs", 1)
235260
self.params["crosstalk_matrix"] = props.pop("crosstalk_matrix", None)
236261

237-
def process(self, signal: Dict[str, Any]) -> Dict[str, Any]:
262+
def process(
263+
self, instr: Instruction, chan: str, signals: List[Dict[str, Any]]
264+
) -> Dict[str, Any]:
238265
"""
239266
Mix channels in the input signal according to a crosstalk matrix.
240267
@@ -258,11 +285,11 @@ def process(self, signal: Dict[str, Any]) -> Dict[str, Any]:
258285
259286
"""
260287
xtalk = self.params["crosstalk_matrix"]
261-
signals = [signal[ch]["values"] for ch in self.crossed_channels]
262-
crossed_signals = xtalk.get_value() @ signals
288+
signal = [signals[0][ch]["values"] for ch in self.crossed_channels]
289+
crossed_signals = xtalk.get_value() @ signal
263290
for indx, ch in enumerate(self.crossed_channels):
264-
signal[ch]["values"] = crossed_signals[indx]
265-
return signal
291+
signals[0][ch]["values"] = crossed_signals[indx]
292+
return signals[0]
266293

267294

268295
@dev_reg_deco
@@ -277,7 +304,7 @@ def __init__(self, **props):
277304
self.sampling_method = props.pop("sampling_method", "nearest")
278305

279306
def process(
280-
self, instr: Instruction, chan: str, awg_signal: Dict[str, Any]
307+
self, instr: Instruction, chan: str, awg_signal: List[Dict[str, Any]]
281308
) -> Dict[str, Any]:
282309
"""Resample the awg values to higher resolution.
283310
@@ -297,12 +324,12 @@ def process(
297324
Inphase and Quadrature compontent of the upsampled signal.
298325
"""
299326
ts = self.create_ts(instr.t_start, instr.t_end, centered=True)
300-
old_dim = awg_signal["inphase"].shape[0]
327+
old_dim = awg_signal[0]["inphase"].shape[0]
301328
new_dim = ts.shape[0]
302329
# TODO add following zeros
303330
inphase = tf.reshape(
304331
tf.image.resize(
305-
tf.reshape(awg_signal["inphase"], shape=[1, old_dim, 1]),
332+
tf.reshape(awg_signal[0]["inphase"], shape=[1, old_dim, 1]),
306333
size=[1, new_dim],
307334
method=self.sampling_method,
308335
),
@@ -311,7 +338,7 @@ def process(
311338
inphase = tf.cast(inphase, tf.float64)
312339
quadrature = tf.reshape(
313340
tf.image.resize(
314-
tf.reshape(awg_signal["quadrature"], shape=[1, old_dim, 1]),
341+
tf.reshape(awg_signal[0]["quadrature"], shape=[1, old_dim, 1]),
315342
size=[1, new_dim],
316343
method=self.sampling_method,
317344
),
@@ -335,7 +362,7 @@ def __init__(self, **props):
335362
# super().__init__(**props)
336363

337364
def process(
338-
self, instr: Instruction, chan: str, Hz_signal: Dict[str, Any]
365+
self, instr: Instruction, chan: str, Hz_signal: List[Dict[str, Any]]
339366
) -> Dict[str, Any]:
340367
"""Apply a filter function to the signal."""
341368
raise Exception("C3:ERROR Not yet implemented.")
@@ -560,9 +587,11 @@ def process(self, instr, chan, iq_signal):
560587
offset = tf.exp(-((-1 - cen) ** 2) / (2 * sigma * sigma))
561588
# TODO make sure ratio of risetime and resolution is an integer
562589
risefun = gauss - offset
563-
inphase = self.convolve(iq_signal["inphase"], risefun / tf.reduce_sum(risefun))
590+
inphase = self.convolve(
591+
iq_signal[0]["inphase"], risefun / tf.reduce_sum(risefun)
592+
)
564593
quadrature = self.convolve(
565-
iq_signal["quadrature"], risefun / tf.reduce_sum(risefun)
594+
iq_signal[0]["quadrature"], risefun / tf.reduce_sum(risefun)
566595
)
567596
self.signal = {
568597
"inphase": inphase,
@@ -602,7 +631,7 @@ def process(self, instr, chan, iq_signal):
602631
Bandwidth limited IQ signal.
603632
604633
"""
605-
res_diff = (iq_signal["ts"][1] - iq_signal["ts"][0]) / self.resolution - 1
634+
res_diff = (iq_signal[0]["ts"][1] - iq_signal[0]["ts"][0]) / self.resolution - 1
606635
if res_diff > 1e-8:
607636
raise Exception(
608637
"C3:Error:Actual time resolution differs from desired by {res_diff:1.3g}."
@@ -621,17 +650,17 @@ def process(self, instr, chan, iq_signal):
621650
offset = tf.exp(-((-1 - cen) ** 2) / (2 * sigma * sigma))
622651

623652
risefun = gauss - offset
624-
inphase = tf_convolve(iq_signal["inphase"], risefun / tf.reduce_sum(risefun))
653+
inphase = tf_convolve(iq_signal[0]["inphase"], risefun / tf.reduce_sum(risefun))
625654
quadrature = tf_convolve(
626-
iq_signal["quadrature"], risefun / tf.reduce_sum(risefun)
655+
iq_signal[0]["quadrature"], risefun / tf.reduce_sum(risefun)
627656
)
628657

629658
inphase = tf.math.real(inphase)
630659
quadrature = tf.math.real(quadrature)
631660
self.signal = {
632661
"inphase": inphase,
633662
"quadrature": quadrature,
634-
"ts": iq_signal["ts"],
663+
"ts": iq_signal[0]["ts"],
635664
}
636665
return self.signal
637666

@@ -845,7 +874,7 @@ def __init__(self, **props):
845874
self.inputs = props.pop("inputs", 2)
846875
self.outputs = props.pop("outputs", 1)
847876

848-
def process(self, instr: Instruction, chan: str, in1: dict, in2: dict):
877+
def process(self, instr: Instruction, chan: str, inputs: List[Dict[str, Any]]):
849878
"""Combine signal from AWG and LO.
850879
851880
Parameters
@@ -860,6 +889,7 @@ def process(self, instr: Instruction, chan: str, in1: dict, in2: dict):
860889
dict
861890
Mixed signal.
862891
"""
892+
in1, in2 = inputs
863893
i1 = in1["inphase"]
864894
q1 = in1["quadrature"]
865895
i2 = in2["inphase"]
@@ -1004,7 +1034,9 @@ def __init__(self, **props):
10041034
self.freq_noise = props.pop("freq_noise", 0)
10051035
self.amp_noise = props.pop("amp_noise", 0)
10061036

1007-
def process(self, instr: Instruction, chan: str) -> dict:
1037+
def process(
1038+
self, instr: Instruction, chan: str, signal: List[Dict[str, Any]]
1039+
) -> dict:
10081040
# TODO check somewhere that there is only 1 carrier per instruction
10091041
ts = self.create_ts(instr.t_start, instr.t_end, centered=True)
10101042
dt = ts[1] - ts[0]
@@ -1099,7 +1131,7 @@ def asdict(self) -> dict:
10991131
# TODO create DC function
11001132

11011133
# TODO make AWG take offset from the previous point
1102-
def create_IQ(self, instr: Instruction, chan: str) -> dict:
1134+
def create_IQ(self, instr: Instruction, chan: str, inputs) -> dict:
11031135
"""
11041136
Construct the in-phase (I) and quadrature (Q) components of the signal.
11051137
These are universal to either experiment or simulation.
@@ -1137,7 +1169,7 @@ def create_IQ(self, instr: Instruction, chan: str) -> dict:
11371169
}
11381170
return self.signal[chan]
11391171

1140-
def create_IQ_pwc(self, instr: Instruction, chan: str) -> dict:
1172+
def create_IQ_pwc(self, instr: Instruction, chan: str, inputs) -> dict:
11411173
"""
11421174
Construct the in-phase (I) and quadrature (Q) components of the signal.
11431175
These are universal to either experiment or simulation.

c3/generator/generator.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111

1212
from typing import List, Callable, Dict
1313
import hjson
14-
import numpy as np
1514
import tensorflow as tf
1615
from c3.c3objs import hjson_decode, hjson_encode
1716
from c3.signal.gates import Instruction
1817
from c3.generator.devices import devices as dev_lib
18+
from c3.generator.devices import Device
1919

2020

2121
class Generator:
@@ -37,10 +37,10 @@ def __init__(
3737
self,
3838
devices: dict = None,
3939
chains: dict = None,
40-
resolution: np.float64 = 0.0,
40+
resolution: float = 0.0,
4141
callback: Callable = None,
4242
):
43-
self.devices = {}
43+
self.devices: Dict[str, Device] = {}
4444
if devices:
4545
self.devices = devices
4646
self.chains = {}
@@ -199,7 +199,7 @@ def generate_signals(self, instr: Instruction) -> dict:
199199

200200
# calculate the output and store it in the stack
201201
dev = self.devices[dev_id]
202-
output = dev.process(instr, chan, *inputs)
202+
output = dev.process(instr, chan, inputs)
203203
signal_stack[chan][dev_id] = output
204204

205205
# remove inputs if they are not needed anymore
@@ -219,5 +219,7 @@ def generate_signals(self, instr: Instruction) -> dict:
219219
# Hack to use crosstalk. Will be generalized to a post-processing module.
220220
# TODO: Rework of the signal generation for larger chips, similar to qiskit
221221
if "crosstalk" in self.devices:
222-
gen_signal = self.devices["crosstalk"].process(signal=gen_signal)
222+
gen_signal = self.devices["crosstalk"].process(
223+
None, None, signals=[gen_signal]
224+
)
223225
return gen_signal

test/test_crosstalk.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
@pytest.mark.unit
77
def test_crosstalk(get_xtalk_pmap, get_test_signal) -> None:
88
xtalk = get_xtalk_pmap.generator.devices["crosstalk"]
9-
new_sig = xtalk.process(signal=get_test_signal)
9+
new_sig = xtalk.process(None, None, signals=[get_test_signal])
1010
assert new_sig == get_test_signal
1111

1212

1313
@pytest.mark.unit
1414
def test_crosstalk_flip(get_xtalk_pmap, get_test_signal) -> None:
1515
xtalk = get_xtalk_pmap.generator.devices["crosstalk"]
1616
xtalk.params["crosstalk_matrix"].set_value([[0, 1], [1, 0]])
17-
new_sig = xtalk.process(signal=get_test_signal)
17+
new_sig = xtalk.process(None, None, signals=[get_test_signal])
1818
assert (new_sig["TC2"]["values"].numpy() == np.linspace(0, 100, 101)).all()
1919
assert (new_sig["TC1"]["values"].numpy() == np.linspace(100, 200, 101)).all()
2020

@@ -23,7 +23,7 @@ def test_crosstalk_flip(get_xtalk_pmap, get_test_signal) -> None:
2323
def test_crosstalk_mix(get_xtalk_pmap, get_test_signal) -> None:
2424
xtalk = get_xtalk_pmap.generator.devices["crosstalk"]
2525
xtalk.params["crosstalk_matrix"].set_value([[0.5, 0.5], [0.5, 0.5]])
26-
new_sig = xtalk.process(signal=get_test_signal)
26+
new_sig = xtalk.process(None, None, signals=[get_test_signal])
2727
assert (new_sig["TC2"]["values"].numpy() == new_sig["TC1"]["values"].numpy()).all()
2828

2929

test/test_envelope.py

+3
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,6 @@ def test_envelope_netzero() -> None:
4747
shape.numpy()
4848
== np.array([1.0, 1.0, 1.0, 1.0, 1.0, 0.0, -1.0, -1.0, -1.0, -1.0, -1.0, -0.0])
4949
)
50+
51+
52+
test_envelope_netzero()

0 commit comments

Comments
 (0)