Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ venv
*.env
*.local

.vscode

api/test
endpoints*.json
endpoints
70 changes: 42 additions & 28 deletions api/support/update_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,54 @@
import json
import logging
import os
from collections.abc import Coroutine
from http.server import BaseHTTPRequestHandler
from typing import Any, Dict, Set, Tuple
from typing import Any

from common.state.blob_storage import BlobConfig, BlobStorageHandler
from common.state.blockchain_fetcher import BlockchainData, BlockchainDataFetcher
from common.state.blockchain_state import BlockchainState

# Chains the state updater collects data for.
SUPPORTED_BLOCKCHAINS: list[str] = ["ethereum", "solana", "ton", "base"]

# To reduce the number of RPC calls, use only one provider here.
ALLOWED_PROVIDERS: set[str] = {"Chainstack"}

# To reduce the number of RPC calls, use only one region here.
ALLOWED_REGIONS: set[str] = {"fra1"}


class MissingEndpointsError(Exception):
    """Raised when required blockchain endpoints are not found.

    Attributes:
        missing_chains: Blockchain names for which no Chainstack endpoint exists.
    """

    def __init__(self, missing_chains: set[str]) -> None:
        self.missing_chains: set[str] = missing_chains
        # Sort before joining: set iteration order is nondeterministic, and a
        # stable message makes logs and tests reproducible.
        chains: str = ", ".join(sorted(missing_chains))
        super().__init__(f"Missing Chainstack endpoints for: {chains}")


class StateUpdateManager:
def __init__(self):
store_id = os.getenv("STORE_ID")
token = os.getenv("VERCEL_BLOB_TOKEN")
"""Manages the collection, processing, and storage of blockchain state data.

This class orchestrates the retrieval of blockchain state data from configured endpoints,
handles fallback to previous data in case of errors, and updates the centralized blob storage.
It enforces provider and region filtering to optimize RPC calls and ensures data consistency.
"""

def __init__(self) -> None:
store_id: str | None = os.getenv("STORE_ID")
token: str | None = os.getenv("VERCEL_BLOB_TOKEN")
if not all([store_id, token]):
raise ValueError("Missing required blob storage configuration")

self.blob_config = BlobConfig(store_id=store_id, token=token) # type: ignore
self.logger = logging.getLogger(__name__)
self.logger: logging.Logger = logging.getLogger(__name__)

async def _get_chainstack_endpoints(self) -> Dict[str, str]:
async def _get_chainstack_endpoints(self) -> dict[str, str]:
"""Get Chainstack endpoints for supported blockchains."""
endpoints = json.loads(os.getenv("ENDPOINTS", "{}"))
chainstack_endpoints: Dict[str, str] = {}
missing_chains: Set[str] = set(SUPPORTED_BLOCKCHAINS)
chainstack_endpoints: dict[str, str] = {}
missing_chains: set[str] = set(SUPPORTED_BLOCKCHAINS)

for provider in endpoints.get("providers", []):
blockchain = provider["blockchain"].lower()
Expand All @@ -56,8 +68,8 @@ async def _get_chainstack_endpoints(self) -> Dict[str, str]:

return chainstack_endpoints

async def _get_previous_data(self) -> Dict[str, Any]:
"""Fetch previous blockchain state data"""
async def _get_previous_data(self) -> dict[str, Any]:
"""Fetch previous blockchain state data."""
try:
state = BlockchainState()
previous_data = {}
Expand All @@ -76,11 +88,11 @@ async def _get_previous_data(self) -> Dict[str, Any]:
return {}

async def _collect_blockchain_data(
self, providers: Dict[str, str], previous_data: Dict[str, Any]
) -> Dict[str, dict]:
self, providers: dict[str, str], previous_data: dict[str, Any]
) -> dict[str, dict]:
async def fetch_single(
blockchain: str, endpoint: str
) -> Tuple[str, Dict[str, str]]:
) -> tuple[str, dict[str, str]]:
try:
fetcher = BlockchainDataFetcher(endpoint)
data: BlockchainData = await fetcher.fetch_latest_data(blockchain)
Expand All @@ -89,7 +101,7 @@ async def fetch_single(
return blockchain, {
"block": data.block_id,
"tx": data.transaction_id,
"old_block": data.old_block_id, # Add new field
"old_block": data.old_block_id,
}

if blockchain in previous_data:
Expand All @@ -108,7 +120,7 @@ async def fetch_single(
self.logger.warning(f"Returning empty data for {blockchain}")
return blockchain, {"block": "", "tx": "", "old_block": ""}

tasks = [
tasks: list[Coroutine[Any, Any, tuple[str, dict[str, str]]]] = [
fetch_single(blockchain, endpoint)
for blockchain, endpoint in providers.items()
]
Expand All @@ -125,9 +137,11 @@ async def update(self) -> str:
return "Region not authorized for state updates"

try:
previous_data = await self._get_previous_data()
previous_data: dict[str, Any] = await self._get_previous_data()

chainstack_endpoints = await self._get_chainstack_endpoints()
chainstack_endpoints: dict[str, str] = (
await self._get_chainstack_endpoints()
)
blockchain_data = await self._collect_blockchain_data(
chainstack_endpoints, previous_data
)
Expand All @@ -136,7 +150,7 @@ async def update(self) -> str:
if not blockchain_data:
if previous_data:
self.logger.warning("Using complete previous state as fallback")
blockchain_data = previous_data
blockchain_data: dict[str, Any] = previous_data
else:
return "No blockchain data collected and no previous data available"

Expand All @@ -157,21 +171,21 @@ class handler(BaseHTTPRequestHandler):
def _check_auth(self) -> bool:
if os.getenv("SKIP_AUTH", "").lower() == "true":
return True
token = self.headers.get("Authorization", "")
token: str = self.headers.get("Authorization", "")
return token == f"Bearer {os.getenv('CRON_SECRET', '')}"

def do_GET(self):
def do_GET(self) -> None:
if not self._check_auth():
self.send_response(401)
self.end_headers()
self.wfile.write(b"Unauthorized")
return

loop = asyncio.new_event_loop()
loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

try:
result = loop.run_until_complete(StateUpdateManager().update())
result: str = loop.run_until_complete(StateUpdateManager().update())
self.send_response(200)
self.send_header("Content-type", "text/plain")
self.end_headers()
Expand Down
2 changes: 1 addition & 1 deletion api/write/solana.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
target_region = "fra1"

# Run this metric only in EU (fra1)
METRICS = (
METRICS: list[tuple[type[SolanaLandingMetric], str]] = (
[]
if os.getenv("VERCEL_REGION") != target_region # System env var, standard name
else [(SolanaLandingMetric, metric_name)]
Expand Down
30 changes: 15 additions & 15 deletions common/base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union

from common.metric_config import MetricConfig, MetricLabelKey, MetricLabels

Expand All @@ -16,28 +16,28 @@ class MetricValue:
"""Container for a single metric value and its specific labels."""

value: Union[int, float]
labels: Optional[Dict[str, str]] = None
labels: Optional[dict[str, str]] = None


class BaseMetric(ABC):
"""Base class for collecting and formatting metrics in single-invocation environments."""

def __init__(
    self,
    handler: "MetricsHandler",  # type: ignore # noqa: F821
    metric_name: str,
    labels: MetricLabels,
    config: MetricConfig,
    ws_endpoint: Optional[str] = None,
    http_endpoint: Optional[str] = None,
) -> None:
    """Initialize the metric and register it with the metrics handler.

    Args:
        handler: Collector that tracks every live metric instance.
        metric_name: Measurement name emitted in Influx line protocol.
        labels: Shared labels attached to every value of this metric.
        config: Metric-level configuration.
        ws_endpoint: Optional WebSocket endpoint used for collection.
        http_endpoint: Optional HTTP endpoint used for collection.
    """
    # Fresh UUID per instance so concurrent collections can be correlated.
    self.metric_id = str(uuid.uuid4())
    self.metric_name: str = metric_name
    self.labels: MetricLabels = labels
    self.config: MetricConfig = config
    self.ws_endpoint: str | None = ws_endpoint
    self.http_endpoint: str | None = http_endpoint
    # Maps value_type -> MetricValue; filled in by update_metric_value().
    self.values: dict[str, MetricValue] = {}
    # NOTE(review): registers via the handler's private list — presumably
    # MetricsHandler exposes no public registration API; confirm.
    handler._instances.append(self)

@abstractmethod
Expand All @@ -48,24 +48,24 @@ async def collect_metric(self) -> None:
def process_data(self, data: Any) -> Union[int, float]:
    """Processes raw data into metric value.

    Args:
        data: Raw payload fetched by the concrete metric implementation.

    Returns:
        The numeric value to record for this metric.
    """

def get_influx_format(self) -> List[str]:
def get_influx_format(self) -> list[str]:
"""Returns metrics in Influx line protocol format."""
if not self.values:
raise ValueError("No metric values set")

metrics = []
base_tags = ",".join(
base_tags: str = ",".join(
[f"{label.key.value}={label.value}" for label in self.labels.labels]
)

for value_type, metric_value in self.values.items():
tags = base_tags
tags: str = base_tags
if tags:
tags = f"{base_tags},metric_type={value_type}"
else:
tags = f"metric_type={value_type}"

metric_line = f"{self.metric_name}"
metric_line: str = f"{self.metric_name}"
if tags:
metric_line += f",{tags}"
metric_line += f" value={metric_value.value}"
Expand All @@ -78,7 +78,7 @@ def update_metric_value(
self,
value: Union[int, float],
value_type: str = "response_time",
labels: Optional[Dict[str, str]] = None,
labels: Optional[dict[str, str]] = None,
) -> None:
"""Updates metric value, preserving existing labels if present."""
if value_type in self.values:
Expand All @@ -101,7 +101,7 @@ def handle_error(self, error: Exception) -> None:
if not self.values:
self.update_metric_value(0)

error_type = error.__class__.__name__
error_type: str = error.__class__.__name__
error_details = getattr(error, "error_msg", str(error))

logging.error(
Expand Down
Loading