Skip to content

Commit

Permalink
wp reduce serialization and db calls (Chia-Network#1569)
Browse files Browse the repository at this point in the history
* reduce db calls, split wp serialisation, add get_block_records_at

* return prev proof is same tip

* change log to debug

* brake
  • Loading branch information
almogdepaz authored Apr 2, 2021
1 parent 6001198 commit 489acc4
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 36 deletions.
3 changes: 3 additions & 0 deletions src/consensus/blockchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,9 @@ async def get_header_block_by_height(self, height: int, header_hash: bytes32) ->
return None
return header_dict[header_hash]

async def get_block_records_at(self, heights: List[uint32]) -> List[BlockRecord]:
return await self.block_store.get_block_records_at(heights)

async def get_block_record_from_db(self, header_hash: bytes32) -> Optional[BlockRecord]:
if header_hash in self.__block_records:
return self.__block_records[header_hash]
Expand Down
3 changes: 3 additions & 0 deletions src/consensus/blockchain_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ async def get_header_blocks_in_range(self, start: int, stop: int) -> Dict[bytes3
async def get_header_block_by_height(self, height: int, header_hash: bytes32) -> Optional[HeaderBlock]:
pass

async def get_block_records_at(self, heights: List[uint32]) -> List[BlockRecord]:
pass

def try_block_record(self, header_hash: bytes32) -> Optional[BlockRecord]:
if self.contains_block(header_hash):
return self.block_record(header_hash)
Expand Down
12 changes: 12 additions & 0 deletions src/full_node/block_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,18 @@ async def get_full_blocks_at(self, heights: List[uint32]) -> List[FullBlock]:
await cursor.close()
return [FullBlock.from_bytes(row[0]) for row in rows]

async def get_block_records_at(self, heights: List[uint32]) -> List[BlockRecord]:
if len(heights) == 0:
return []
heights_db = tuple(heights)
formatted_str = (
f'SELECT block from block_records WHERE height in ({"?," * (len(heights_db) - 1)}?) ORDER BY height ASC;'
)
cursor = await self.db.execute(formatted_str, heights_db)
rows = await cursor.fetchall()
await cursor.close()
return [BlockRecord.from_bytes(row[0]) for row in rows]

async def get_blocks_by_hash(self, header_hashes: List[bytes32]) -> List[FullBlock]:
"""
Returns a list of Full Blocks blocks, ordered by the same order in which header_hashes are passed in.
Expand Down
106 changes: 70 additions & 36 deletions src/full_node/weight_proof.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@
from src.types.blockchain_format.vdf import VDFInfo
from src.types.end_of_slot_bundle import EndOfSubSlotBundle
from src.types.header_block import HeaderBlock
from src.types.weight_proof import SubEpochChallengeSegment, SubEpochData, SubSlotData, WeightProof
from src.types.weight_proof import (
SubEpochChallengeSegment,
SubEpochData,
SubSlotData,
WeightProof,
SubEpochSegments,
RecentChainData,
)
from src.util.block_cache import BlockCache
from src.util.hash import std_hash
from src.util.ints import uint8, uint32, uint64, uint128
Expand Down Expand Up @@ -64,19 +71,31 @@ async def get_proof_of_weight(self, tip: bytes32) -> Optional[WeightProof]:
return None

async with self.lock:
if self.proof is not None:
if self.proof.recent_chain_data[-1].header_hash == tip:
return self.proof
wp = await self._create_proof_of_weight(tip)
if wp is None:
return None
self.proof = wp
self.tip = tip
return wp

def get_sub_epoch_data(self, tip_height: uint32, summary_heights: List[uint32]) -> List[SubEpochData]:
sub_epoch_data: List[SubEpochData] = []
for sub_epoch_n, ses_height in enumerate(summary_heights):
if ses_height > tip_height:
break
ses = self.blockchain.get_ses(ses_height)
log.debug(f"handle sub epoch summary {sub_epoch_n} at height: {ses_height} ses {ses}")
sub_epoch_data.append(_create_sub_epoch_data(ses))
return sub_epoch_data

async def _create_proof_of_weight(self, tip: bytes32) -> Optional[WeightProof]:
"""
Creates a weight proof object
"""
assert self.blockchain is not None
sub_epoch_data: List[SubEpochData] = []
sub_epoch_segments: List[SubEpochChallengeSegment] = []
tip_rec = self.blockchain.try_block_record(tip)
if tip_rec is None:
Expand All @@ -88,34 +107,34 @@ async def _create_proof_of_weight(self, tip: bytes32) -> Optional[WeightProof]:
return None

summary_heights = self.blockchain.get_ses_heights()
prev_ses_block = await self.blockchain.get_block_record_from_db(self.blockchain.height_to_hash(uint32(0)))
if prev_ses_block is None:
return None
sub_epoch_data = self.get_sub_epoch_data(tip_rec.height, summary_heights)
# use second to last ses as seed
seed = self.get_seed_for_proof(summary_heights, tip_rec.height)
rng = random.Random(seed)
weight_to_check = _get_weights_for_sampling(rng, tip_rec.weight, recent_chain)

prev_ses_block = await self.blockchain.get_block_record_from_db(self.blockchain.height_to_hash(uint32(0)))
if prev_ses_block is None:
sample_n = 0
ses_blocks = await self.blockchain.get_block_records_at(summary_heights)
if ses_blocks is None:
return None

sample_n = 0
summary_heights = self.blockchain.get_ses_heights()
for sub_epoch_n, ses_height in enumerate(summary_heights):
if ses_height > tip_rec.height:
break
# next sub block
ses_block = await self.blockchain.get_block_record_from_db(self.blockchain.height_to_hash(ses_height))
if ses_block is None or ses_block.sub_epoch_summary_included is None:
log.error("error while building proof")
return None

log.debug(f"handle sub epoch summary {sub_epoch_n} at height: {ses_height} weight: {ses_block.weight}")
sub_epoch_data.append(_create_sub_epoch_data(ses_block.sub_epoch_summary_included))

# if we have enough sub_epoch samples, dont sample
if sample_n >= self.MAX_SAMPLES:
log.debug("reached sampled sub epoch cap")
continue
break
# sample sub epoch
# next sub block
ses_block = ses_blocks[sub_epoch_n]
if ses_block is None or ses_block.sub_epoch_summary_included is None:
log.error("error while building proof")
return None

if _sample_sub_epoch(prev_ses_block.weight, ses_block.weight, weight_to_check): # type: ignore
sample_n += 1
segments = await self.blockchain.get_sub_epoch_challenge_segments(ses_block.height)
Expand Down Expand Up @@ -148,8 +167,17 @@ def get_seed_for_proof(self, summary_heights: List[uint32], tip_height) -> bytes

async def _get_recent_chain(self, tip_height: uint32) -> Optional[List[HeaderBlock]]:
recent_chain: List[HeaderBlock] = []
min_height = max(0, tip_height - self.constants.WEIGHT_PROOF_RECENT_BLOCKS * 2)
headers: Dict[bytes32, HeaderBlock] = await self.blockchain.get_header_blocks_in_range(min_height, tip_height)
ses_heights = self.blockchain.get_ses_heights()
min_height = 0
count_ses = 0
for ses_height in reversed(ses_heights):
if ses_height <= tip_height:
count_ses += 1
if count_ses == 2:
min_height = ses_height - 1
break
log.debug(f"start {min_height} end {tip_height}")
headers = await self.blockchain.get_header_blocks_in_range(min_height, tip_height)
blocks = await self.blockchain.get_block_records_in_range(min_height, tip_height)
ses_count = 0
curr_height = tip_height
Expand Down Expand Up @@ -464,18 +492,20 @@ def validate_weight_proof_single_proc(self, weight_proof: WeightProof) -> Tuple[
if summaries is None:
log.warning("weight proof failed sub epoch data validation")
return False, uint32(0)
constants, summary_bytes, wp_bytes = vars_to_bytes(self.constants, summaries, weight_proof)
constants, summary_bytes, wp_segment_bytes, wp_recent_chain_bytes = vars_to_bytes(
self.constants, summaries, weight_proof
)
log.info("validate sub epoch challenge segments")
seed = summaries[-2].get_hash()
rng = random.Random(seed)
if not validate_sub_epoch_sampling(rng, sub_epoch_weight_list, weight_proof):
log.error("failed weight proof sub epoch sample validation")
return False, uint32(0)

if not _validate_sub_epoch_segments(constants, rng, wp_bytes, summary_bytes):
if not _validate_sub_epoch_segments(constants, rng, wp_segment_bytes, summary_bytes):
return False, uint32(0)
log.info("validate weight proof recent blocks")
if not _validate_recent_blocks(constants, wp_bytes, summary_bytes):
if not _validate_recent_blocks(constants, wp_recent_chain_bytes, summary_bytes):
return False, uint32(0)
return True, self.get_fork_point(summaries)

Expand Down Expand Up @@ -512,13 +542,15 @@ async def validate_weight_proof(self, weight_proof: WeightProof) -> Tuple[bool,
return False, uint32(0)

executor = ProcessPoolExecutor(1)
constants, summary_bytes, wp_bytes = vars_to_bytes(self.constants, summaries, weight_proof)
constants, summary_bytes, wp_segment_bytes, wp_recent_chain_bytes = vars_to_bytes(
self.constants, summaries, weight_proof
)
segment_validation_task = asyncio.get_running_loop().run_in_executor(
executor, _validate_sub_epoch_segments, constants, rng, wp_bytes, summary_bytes
executor, _validate_sub_epoch_segments, constants, rng, wp_segment_bytes, summary_bytes
)

recent_blocks_validation_task = asyncio.get_running_loop().run_in_executor(
executor, _validate_recent_blocks, constants, wp_bytes, summary_bytes
executor, _validate_recent_blocks, constants, wp_recent_chain_bytes, summary_bytes
)

valid_segment_task = segment_validation_task
Expand Down Expand Up @@ -839,13 +871,14 @@ def _validate_sub_epoch_segments(
weight_proof_bytes: bytes,
summaries_bytes: List[bytes],
):
constants, summaries, weight_proof = bytes_to_vars(constants_dict, summaries_bytes, weight_proof_bytes)
constants, summaries = bytes_to_vars(constants_dict, summaries_bytes)
sub_epoch_segments: SubEpochSegments = SubEpochSegments.from_bytes(weight_proof_bytes)
rc_sub_slot_hash = constants.GENESIS_CHALLENGE
total_blocks, total_ip_iters = 0, 0
total_slot_iters, total_slots = 0, 0
total_ip_iters = 0
prev_ses: Optional[SubEpochSummary] = None
segments_by_sub_epoch = map_segments_by_sub_epoch(weight_proof.sub_epoch_segments)
segments_by_sub_epoch = map_segments_by_sub_epoch(sub_epoch_segments.challenge_segments)
curr_ssi = constants.SUB_SLOT_ITERS_STARTING
for sub_epoch_n, segments in segments_by_sub_epoch.items():
prev_ssi = curr_ssi
Expand Down Expand Up @@ -1089,10 +1122,11 @@ def sub_slot_data_vdf_input(
return cc_input


def _validate_recent_blocks(constants_dict: Dict, weight_proof_bytes: bytes, summaries_bytes: List[bytes]) -> bool:
constants, summaries, weight_proof = bytes_to_vars(constants_dict, summaries_bytes, weight_proof_bytes)
def _validate_recent_blocks(constants_dict: Dict, recent_chain_bytes: bytes, summaries_bytes: List[bytes]) -> bool:
constants, summaries = bytes_to_vars(constants_dict, summaries_bytes)
recent_chain: RecentChainData = RecentChainData.from_bytes(recent_chain_bytes)
sub_blocks = BlockCache({})
first_ses_idx = _get_ses_idx(weight_proof.recent_chain_data)
first_ses_idx = _get_ses_idx(recent_chain.recent_chain_data)
ses_idx = len(summaries) - len(first_ses_idx)
ssi: uint64 = constants.SUB_SLOT_ITERS_STARTING
diff: Optional[uint64] = constants.DIFFICULTY_STARTING
Expand All @@ -1105,10 +1139,10 @@ def _validate_recent_blocks(constants_dict: Dict, weight_proof_bytes: bytes, sum

ses_blocks, sub_slots, transaction_blocks = 0, 0, 0
challenge, prev_challenge = None, None
tip_height = weight_proof.recent_chain_data[-1].height
tip_height = recent_chain.recent_chain_data[-1].height
prev_block_record = None
deficit = uint8(0)
for idx, block in enumerate(weight_proof.recent_chain_data):
for idx, block in enumerate(recent_chain.recent_chain_data):
required_iters = uint64(0)
overflow = False
ses = False
Expand Down Expand Up @@ -1360,20 +1394,20 @@ def _get_curr_diff_ssi(constants: ConsensusConstants, idx, summaries):

def vars_to_bytes(constants, summaries, weight_proof):
constants_dict = recurse_jsonify(dataclasses.asdict(constants))
wp_bytes = bytes(weight_proof)
wp_recent_chain_bytes = bytes(RecentChainData(weight_proof.recent_chain_data))
wp_segment_bytes = bytes(SubEpochSegments(weight_proof.sub_epoch_segments))
summary_bytes = []
for summary in summaries:
summary_bytes.append(bytes(summary))
return constants_dict, summary_bytes, wp_bytes
return constants_dict, summary_bytes, wp_segment_bytes, wp_recent_chain_bytes


def bytes_to_vars(constants_dict, summaries_bytes, weight_proof_bytes):
def bytes_to_vars(constants_dict, summaries_bytes):
summaries = []
for summary in summaries_bytes:
summaries.append(SubEpochSummary.from_bytes(summary))
constants: ConsensusConstants = dataclass_from_dict(ConsensusConstants, constants_dict)
weight_proof = WeightProof.from_bytes(weight_proof_bytes)
return constants, summaries, weight_proof
return constants, summaries


def _get_last_ses_hash(
Expand Down
7 changes: 7 additions & 0 deletions src/types/weight_proof.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ class SubEpochSegments(Streamable):
challenge_segments: List[SubEpochChallengeSegment]


@dataclass(frozen=True)
@streamable
# this is used only for serialization to database
class RecentChainData(Streamable):
recent_chain_data: List[HeaderBlock]


@dataclass(frozen=True)
@streamable
class ProofBlockHeader(Streamable):
Expand Down
6 changes: 6 additions & 0 deletions src/util/block_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def contains_height(self, height: uint32) -> bool:
async def get_block_records_in_range(self, start: int, stop: int) -> Dict[bytes32, BlockRecord]:
return self._block_records

async def get_block_records_at(self, heights: List[uint32]) -> List[BlockRecord]:
block_records: List[BlockRecord] = []
for height in heights:
block_records.append(self.height_to_block_record(height))
return block_records

async def get_block_record_from_db(self, header_hash: bytes32) -> Optional[BlockRecord]:
return self._block_records[header_hash]

Expand Down

0 comments on commit 489acc4

Please sign in to comment.