Skip to content

Commit ed2d7a2

Browse files
committed
More sx manager additions:
- When subscribing, create typed subscriptions based on the subscription params. - Add ``get_by_id()`` and ``get_by_label()`` methods to the ``SubscriptionManager`` class. - Allow event to be passed to logs subscriptions for easier processing when the messages are received. - Add ability to count handler calls on both the manager and on a per-subscription basis - Add tests for subscription manager; tighten up typing. - Fix some types: - ``process_log`` expects a ``LogReceipt``, not a hex. - ``FormattedEthSubscriptionResponse`` should also include ``HexBytes`` for pending tx hashes (full_transactions=False). - Create live integration tests for ``eth_subscribe`` for "newHeads", "newPendingTransactions", "logs". Syncing still has to be stubbed out.
1 parent eebcd84 commit ed2d7a2

File tree

13 files changed

+907
-313
lines changed

13 files changed

+907
-313
lines changed
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import itertools
2+
import pytest
3+
from unittest.mock import (
4+
AsyncMock,
5+
)
6+
7+
import pytest_asyncio
8+
9+
from web3 import (
10+
AsyncWeb3,
11+
PersistentConnectionProvider,
12+
)
13+
from web3.exceptions import (
14+
Web3ValueError,
15+
)
16+
from web3.providers.persistent.subscription_manager import (
17+
SubscriptionManager,
18+
)
19+
from web3.utils.subscriptions import (
20+
LogsSubscription,
21+
NewHeadsSubscription,
22+
PendingTxSubscription,
23+
)
24+
25+
26+
class MockProvider(PersistentConnectionProvider):
27+
socket_recv = AsyncMock()
28+
socket_send = AsyncMock()
29+
30+
31+
@pytest_asyncio.fixture
32+
async def subscription_manager():
33+
countr = itertools.count()
34+
_w3 = AsyncWeb3(MockProvider())
35+
_w3.eth._subscribe = AsyncMock()
36+
_w3.eth._subscribe.side_effect = lambda *_: f"0x{str(next(countr))}"
37+
_w3.eth._unsubscribe = AsyncMock()
38+
_w3.eth._unsubscribe.return_value = True
39+
yield SubscriptionManager(_w3)
40+
41+
42+
@pytest.mark.asyncio
43+
async def test_subscription_manager_raises_for_sx_with_the_same_label(
44+
subscription_manager,
45+
):
46+
sx1 = NewHeadsSubscription(label="foo")
47+
await subscription_manager.subscribe(sx1)
48+
49+
with pytest.raises(
50+
Web3ValueError,
51+
match="Subscription label already exists. Subscriptions must have unique "
52+
"labels.\n label: foo",
53+
):
54+
sx2 = LogsSubscription(label="foo")
55+
await subscription_manager.subscribe(sx2)
56+
57+
# make sure the subscription was subscribed to and not added to the manager
58+
assert subscription_manager.subscriptions == [sx1]
59+
assert subscription_manager._subscriptions_by_label == {"foo": sx1}
60+
assert subscription_manager._subscriptions_by_id == {"0x0": sx1}
61+
62+
63+
@pytest.mark.asyncio
64+
async def test_subscription_manager_get_by_id(subscription_manager):
65+
sx = NewHeadsSubscription(label="foo")
66+
await subscription_manager.subscribe(sx)
67+
assert subscription_manager.get_by_id("0x0") == sx
68+
assert subscription_manager.get_by_id("0x1") is None
69+
70+
71+
@pytest.mark.asyncio
72+
async def test_subscription_manager_get_by_label(subscription_manager):
73+
sx = NewHeadsSubscription(label="foo")
74+
await subscription_manager.subscribe(sx)
75+
assert subscription_manager.get_by_label("foo") == sx
76+
assert subscription_manager.get_by_label("bar") is None
77+
78+
79+
@pytest.mark.asyncio
80+
async def test_unsubscribe_one_by_one_clears_all_subscriptions(
81+
subscription_manager,
82+
):
83+
sx1 = NewHeadsSubscription(label="foo")
84+
sx2 = PendingTxSubscription(label="bar")
85+
await subscription_manager.subscribe(sx1)
86+
await subscription_manager.subscribe(sx2)
87+
88+
await subscription_manager.unsubscribe(sx1)
89+
assert subscription_manager.subscriptions == [sx2]
90+
91+
await subscription_manager.unsubscribe(sx2)
92+
assert subscription_manager.subscriptions == []
93+
94+
95+
@pytest.mark.asyncio
96+
async def test_unsubscribe_all_clears_all_subscriptions(subscription_manager):
97+
sx1 = NewHeadsSubscription(label="foo")
98+
sx2 = PendingTxSubscription(label="bar")
99+
await subscription_manager.subscribe([sx1, sx2])
100+
101+
await subscription_manager.unsubscribe_all()
102+
assert subscription_manager.subscriptions == []
103+
assert subscription_manager._subscriptions_by_id == {}
104+
assert subscription_manager._subscriptions_by_label == {}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import pytest
2+
from unittest.mock import (
3+
Mock,
4+
)
5+
6+
from web3.utils.subscriptions import (
7+
LogsSubscription,
8+
NewHeadsSubscription,
9+
PendingTxSubscription,
10+
SyncingSubscription,
11+
)
12+
13+
14+
@pytest.fixture
15+
def handler():
16+
pass
17+
18+
19+
def test_new_heads_subscription_properties(handler):
20+
new_heads_subscription = NewHeadsSubscription(
21+
handler=handler, label="new heads label"
22+
)
23+
assert new_heads_subscription._handler is handler
24+
assert new_heads_subscription.label == "new heads label"
25+
assert new_heads_subscription.subscription_params == ("newHeads",)
26+
27+
28+
def test_pending_tx_subscription_properties_default(handler):
29+
pending_tx_subscription = PendingTxSubscription(
30+
handler=handler, label="pending tx label"
31+
)
32+
assert pending_tx_subscription._handler is handler
33+
assert pending_tx_subscription.label == "pending tx label"
34+
assert pending_tx_subscription.subscription_params == (
35+
"newPendingTransactions",
36+
False,
37+
)
38+
39+
40+
def test_pending_tx_subscription_properties_full_transactions(handler):
41+
pending_tx_subscription = PendingTxSubscription(
42+
full_transactions=True, handler=handler, label="pending tx label"
43+
)
44+
assert pending_tx_subscription._handler is handler
45+
assert pending_tx_subscription.label == "pending tx label"
46+
assert pending_tx_subscription.subscription_params == (
47+
"newPendingTransactions",
48+
True,
49+
)
50+
51+
52+
def test_logs_subscription_properties_default(handler):
53+
logs_subscription = LogsSubscription(handler=handler, label="logs label")
54+
assert logs_subscription._handler is handler
55+
assert logs_subscription.label == "logs label"
56+
assert logs_subscription.subscription_params == ("logs", {})
57+
assert logs_subscription.address is None
58+
assert logs_subscription.topics is None
59+
60+
61+
def test_logs_subscription_properties(handler):
62+
address = "0x1234567890123456789012345678901234567890"
63+
topics = [
64+
"0x0000000000000000000000000000000000000000000000000000000000000001",
65+
"0x0000000000000000000000000000000000000000000000000000000000000002",
66+
]
67+
event = Mock()
68+
logs_subscription = LogsSubscription(
69+
address=address, topics=topics, handler=handler, event=event, label="logs label"
70+
)
71+
assert logs_subscription._handler is handler
72+
assert logs_subscription.label == "logs label"
73+
assert logs_subscription.subscription_params == (
74+
"logs",
75+
{"address": address, "topics": topics},
76+
)
77+
assert logs_subscription.address == address
78+
assert logs_subscription.topics == topics
79+
assert logs_subscription.event is event
80+
81+
82+
def test_syncing_subscription_properties(handler):
83+
syncing_subscription = SyncingSubscription(handler=handler, label="syncing label")
84+
assert syncing_subscription._handler is handler
85+
assert syncing_subscription.label == "syncing label"
86+
assert syncing_subscription.subscription_params == ("syncing",)

web3/_utils/module_testing/eth_module.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,9 +1991,8 @@ async def test_async_eth_get_logs_with_logs_none_topic_args(
19911991
# Test with None overflowing
19921992
filter_params: FilterParams = {
19931993
"fromBlock": BlockNumber(0),
1994-
"topics": [None, None, None],
1994+
"topics": [None, None, None, None],
19951995
}
1996-
19971996
result = await async_w3.eth.get_logs(filter_params)
19981997
assert len(result) == 0
19991998

web3/_utils/module_testing/module_testing_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ def assert_contains_log(
106106
emitter_contract_address: ChecksumAddress,
107107
txn_hash_with_log: HexStr,
108108
) -> None:
109-
assert len(result) == 1
110109
log_entry = result[0]
111110
assert log_entry["blockNumber"] == block_with_txn_with_log["number"]
112111
assert log_entry["blockHash"] == block_with_txn_with_log["hash"]

0 commit comments

Comments
 (0)