Skip to content

Commit 414d63c

Browse files
committed
Test setup moved to fixtures
1 parent 9c85562 commit 414d63c

File tree

2 files changed

+49
-37
lines changed

2 files changed

+49
-37
lines changed

test/conftest.py

+32
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
import tensorflow as tf
12
from typing import Any, Dict
23
from qiskit import QuantumCircuit, ClassicalRegister, QuantumRegister
4+
from c3.parametermap import ParameterMap
5+
from c3.generator.generator import Generator
6+
from c3.generator.devices import Crosstalk
7+
from c3.c3objs import Quantity
38
import pytest
49

510

@@ -101,3 +106,30 @@ def get_result_qiskit() -> Dict[str, Dict[str, Any]]:
101106
"c3_qasm_perfect_simulator": perfect_counts,
102107
}
103108
return counts_dict
109+
110+
111+
@pytest.fixture()
112+
def get_xtalk_pmap() -> ParameterMap:
113+
xtalk = Crosstalk(
114+
name="crosstalk",
115+
channels=["TC1", "TC2"],
116+
crosstalk_matrix=Quantity(
117+
value=[[1, 0], [0, 1]],
118+
min_val=[[0, 0], [0, 0]],
119+
max_val=[[1, 1], [1, 1]],
120+
unit="",
121+
),
122+
)
123+
124+
gen = Generator(devices={"crosstalk": xtalk})
125+
pmap = ParameterMap(generator=gen)
126+
pmap.set_opt_map([[["crosstalk", "crosstalk_matrix"]]])
127+
return pmap
128+
129+
130+
@pytest.fixture()
131+
def get_test_signal() -> Dict:
132+
return {
133+
"TC1": {"values": tf.linspace(0, 100, 101)},
134+
"TC2": {"values": tf.linspace(100, 200, 101)},
135+
}

test/test_crosstalk.py

+17-37
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,37 @@
11
import pytest
22

33
import numpy as np
4-
import tensorflow as tf
5-
6-
from c3.generator.devices import Crosstalk
7-
from c3.generator.generator import Generator
8-
from c3.c3objs import Quantity as Qty
9-
from c3.parametermap import ParameterMap
10-
11-
12-
xtalk = Crosstalk(
13-
name="crosstalk",
14-
channels=["TC1", "TC2"],
15-
crosstalk_matrix=Qty(
16-
value=[[1, 0], [0, 1]],
17-
min_val=[[0, 0], [0, 0]],
18-
max_val=[[1, 1], [1, 1]],
19-
unit="",
20-
),
21-
)
22-
23-
signal = {
24-
"TC1": {"values": tf.linspace(0, 100, 101)},
25-
"TC2": {"values": tf.linspace(100, 200, 101)},
26-
}
27-
28-
gen = Generator(devices={"crosstalk": xtalk})
29-
pmap = ParameterMap(generator=gen)
30-
pmap.set_opt_map([[["crosstalk", "crosstalk_matrix"]]])
314

325

336
@pytest.mark.unit
34-
def test_crosstalk() -> None:
35-
new_sig = xtalk.process(signal=signal)
36-
assert new_sig == signal
7+
def test_crosstalk(get_xtalk_pmap, get_test_signal) -> None:
8+
xtalk = get_xtalk_pmap.generator.devices["crosstalk"]
9+
new_sig = xtalk.process(signal=get_test_signal)
10+
assert new_sig == get_test_signal
3711

3812

3913
@pytest.mark.unit
40-
def test_crosstalk_flip() -> None:
14+
def test_crosstalk_flip(get_xtalk_pmap, get_test_signal) -> None:
15+
xtalk = get_xtalk_pmap.generator.devices["crosstalk"]
4116
xtalk.params["crosstalk_matrix"].set_value([[0, 1], [1, 0]])
42-
new_sig = xtalk.process(signal=signal)
17+
new_sig = xtalk.process(signal=get_test_signal)
4318
assert (new_sig["TC2"]["values"].numpy() == np.linspace(0, 100, 101)).all()
4419
assert (new_sig["TC1"]["values"].numpy() == np.linspace(100, 200, 101)).all()
4520

4621

4722
@pytest.mark.unit
48-
def test_crosstalk_mix() -> None:
23+
def test_crosstalk_mix(get_xtalk_pmap, get_test_signal) -> None:
24+
xtalk = get_xtalk_pmap.generator.devices["crosstalk"]
4925
xtalk.params["crosstalk_matrix"].set_value([[0.5, 0.5], [0.5, 0.5]])
50-
new_sig = xtalk.process(signal=signal)
26+
new_sig = xtalk.process(signal=get_test_signal)
5127
assert (new_sig["TC2"]["values"].numpy() == new_sig["TC1"]["values"].numpy()).all()
5228

5329

5430
@pytest.mark.unit
55-
def test_crosstalk_set_get_parameters() -> None:
56-
pmap.set_parameters([[1, 1], [1, 1]], [[["crosstalk", "crosstalk_matrix"]]])
57-
assert (pmap.get_parameters()[0].get_value().numpy() == [[1, 1], [1, 1]]).all()
31+
def test_crosstalk_set_get_parameters(get_xtalk_pmap) -> None:
32+
get_xtalk_pmap.set_parameters(
33+
[[1, 1], [1, 1]], [[["crosstalk", "crosstalk_matrix"]]]
34+
)
35+
assert (
36+
get_xtalk_pmap.get_parameters()[0].get_value().numpy() == [[1, 1], [1, 1]]
37+
).all()

0 commit comments

Comments
 (0)