Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from minichain import Transaction, Blockchain, Block, State, Mempool, P2PNetwork, mine_block
from minichain.validators import is_valid_receiver
from minichain.chain import MAX_BLOCKS_PER_REQUEST


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -156,6 +157,51 @@ async def handler(data):
else:
logger.warning("📥 Received Block #%s — rejected", block.index)

elif msg_type == "status":
import json as _json
peer_height = payload["height"]
my_height = chain.height

if peer_height > my_height:
writer = data.get("_writer")
if writer:
from_h = my_height + 1
to_h = min(peer_height, from_h + MAX_BLOCKS_PER_REQUEST - 1)
request = _json.dumps({
"type": "get_blocks",
"data": {"from_height": from_h, "to_height": to_h},
}) + "\n"
writer.write(request.encode())
await writer.drain()
logger.info(
"📡 Requesting blocks %d~%d from %s",
from_h, to_h, peer_addr,
)

elif msg_type == "get_blocks":
import json as _json
from_h = payload["from_height"]
to_h = payload["to_height"]
blocks = chain.get_blocks_range(from_h, to_h)

writer = data.get("_writer")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Injecting _writer into the message payload and directly calling writer.write(...) in main.py breaks the abstraction of the P2P layer. If we ever switch the transport layer (e.g., from raw sockets to WebSockets or a library), we will have to rewrite the application logic in main.py.
Instead of passing the raw StreamWriter to main.py, consider adding a send_message_to_peer(peer_addr, msg_type, payload) method directly on the P2PNetwork class, and invoke that from main.py.

if writer and blocks:
response = _json.dumps({
"type": "blocks",
"data": {"blocks": blocks}
}) + "\n"
writer.write(response.encode())
await writer.drain()
logger.info("📤 Sent %d blocks to %s", len(blocks), peer_addr)

elif msg_type == "blocks":
received = payload["blocks"]
success, count = chain.add_blocks_bulk(received)
if success:
logger.info("✅ Chain synced: added %d blocks", count)
else:
logger.warning("❌ Chain sync failed — batch rejected")

return handler


Expand Down Expand Up @@ -318,13 +364,24 @@ async def run_node(port: int, host: str, connect_to: str | None, fund: int, data
# When a new peer connects, send our state so they can sync
async def on_peer_connected(writer):
import json as _json
accounts_snapshot, height_snapshot = chain.snapshot_state_and_height()

sync_msg = _json.dumps({
"type": "sync",
"data": {"accounts": chain.state.accounts}
"data": {"accounts": accounts_snapshot},
}) + "\n"
status_msg = _json.dumps({
"type": "status",
"data": {"height": height_snapshot},
}) + "\n"

writer.write(sync_msg.encode())
writer.write(status_msg.encode())
await writer.drain()
logger.info("🔄 Sent state sync to new peer")
logger.info(
"🔄 Sent state sync (%d accounts) and status (height=%d) to new peer",
len(accounts_snapshot), height_snapshot,
)

network.register_on_peer_connected(on_peer_connected)

Expand Down
69 changes: 69 additions & 0 deletions minichain/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import logging
import threading

MAX_BLOCKS_PER_REQUEST = 500

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -54,6 +56,12 @@ def last_block(self):
with self._lock: # Acquire lock for thread-safe access
return self.chain[-1]

@property
def height(self) -> int:
"""Returns the current chain height (genesis = 0)"""
with self._lock:
return len(self.chain) - 1

def add_block(self, block):
"""
Validates and adds a block to the chain if all transactions succeed.
Expand Down Expand Up @@ -82,3 +90,64 @@ def add_block(self, block):
self.state = temp_state
self.chain.append(block)
return True

def get_blocks_range(self, from_height: int, to_height: int) -> list:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@0u-Y
There is no limit enforced on the number of blocks a peer can request at once. If a malicious peer connects and sends a get_blocks message with from_height: 1 and to_height: 10000000, this list comprehension will attempt to serialize millions of blocks into memory all at once. This will immediately exhaust the node's memory (OOM) and crash it. We can definitely enforce a hard limit on the batch size.

"""Return serialized blocks in [from_height, to_height], capped at MAX_BLOCKS_PER_REQUEST."""
with self._lock:
to_height = min(
to_height,
len(self.chain) - 1,
from_height + MAX_BLOCKS_PER_REQUEST - 1,
)
if from_height > to_height or from_height < 0:
return []
return [b.to_dict() for b in self.chain[from_height:to_height + 1]]

def add_blocks_bulk(self, block_dicts: list) -> tuple:
"""
Atomically add a batch of blocks: validate each block's transactions
against a temporary state, and commit chain + state only if every
block passes. Any failure leaves the local chain and state untouched.

Returns (True, count) on full success, (False, 0) on any failure.
"""
with self._lock:
temp_state = self.state.copy()
prev_block = self.chain[-1]
new_blocks = []

for block_dict in block_dicts:
try:
block = Block.from_dict(block_dict)
except (KeyError, TypeError, ValueError) as exc:
logger.warning("Bulk add rejected: malformed block dict: %s", exc)
return False, 0

try:
validate_block_link_and_hash(prev_block, block)
except ValueError as exc:
logger.warning("Bulk add rejected at block %s: %s", block.index, exc)
return False, 0

for tx in block.transactions:
if not temp_state.validate_and_apply(tx):
logger.warning(
"Bulk add rejected at block %s: transaction failed validation",
block.index,
)
return False, 0

new_blocks.append(block)
prev_block = block

self.state = temp_state
self.chain.extend(new_blocks)
return True, len(new_blocks)

def snapshot_state_and_height(self) -> tuple:
"""Capture accounts and chain height under a single lock acquisition."""
with self._lock:
accounts_copy = {
addr: dict(acc) for addr, acc in self.state.accounts.items()
}
return accounts_copy, len(self.chain) - 1
40 changes: 39 additions & 1 deletion minichain/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@

from .serialization import canonical_json_hash
from .validators import is_valid_receiver
from .chain import MAX_BLOCKS_PER_REQUEST

logger = logging.getLogger(__name__)

TOPIC = "minichain-global"
SUPPORTED_MESSAGE_TYPES = {"sync", "tx", "block"}
SUPPORTED_MESSAGE_TYPES = {"sync", "tx", "block", "status", "get_blocks", "blocks"}


class P2PNetwork:
Expand Down Expand Up @@ -207,6 +208,39 @@ def _validate_block_payload(self, payload):
for tx_payload in payload["transactions"]
)

def _validate_status_payload(self, payload):
if not isinstance(payload, dict):
return False
if set(payload) != {"height"}:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that the new payload validators repeat the exact same type-checking logic, specifically checking if the payload is a dictionary and validating the exact keys using set(payload) != {...}. While I see this follows the pattern of older validators in the file, it introduces unnecessary boilerplate.

You can consider creating a generic helper method to reduce the duplication, or simply inline the dict check if the logic is very simple

return False
if not isinstance(payload["height"], int) or payload["height"] < 0:
return False
return True

def _validate_get_blocks_payload(self, payload):
if not isinstance(payload, dict):
return False
if set(payload) != {"from_height", "to_height"}:
return False
fh, th = payload.get("from_height"), payload.get("to_height")
if not isinstance(fh, int) or not isinstance(th, int):
return False
if fh < 0 or fh > th:
return False
return True

def _validate_blocks_payload(self, payload):
if not isinstance(payload, dict):
return False
if set(payload) != {"blocks"}:
return False
blocks = payload.get("blocks")
if not isinstance(blocks, list):
return False
if len(blocks) > MAX_BLOCKS_PER_REQUEST:
return False
return all(self._validate_block_payload(b) for b in blocks)

def _validate_message(self, message):
if not isinstance(message, dict):
return False
Expand All @@ -226,6 +260,9 @@ def _validate_message(self, message):
"sync": self._validate_sync_payload,
"tx": self._validate_transaction_payload,
"block": self._validate_block_payload,
"status": self._validate_status_payload,
"get_blocks": self._validate_get_blocks_payload,
"blocks": self._validate_blocks_payload,
}
return validators[msg_type](payload)

Expand Down Expand Up @@ -283,6 +320,7 @@ async def _listen_to_peer(
continue
self._mark_seen(msg_type, payload)
data["_peer_addr"] = addr
data["_writer"] = writer

if self._handler_callback:
try:
Expand Down
Loading