diff --git a/src/consensus/blockchain.py b/src/consensus/blockchain.py index 92587ca97f34..a6a5522912e7 100644 --- a/src/consensus/blockchain.py +++ b/src/consensus/blockchain.py @@ -584,6 +584,9 @@ async def get_header_block_by_height(self, height: int, header_hash: bytes32) -> return None return header_dict[header_hash] + async def get_block_records_at(self, heights: List[uint32]) -> List[BlockRecord]: + return await self.block_store.get_block_records_at(heights) + async def get_block_record_from_db(self, header_hash: bytes32) -> Optional[BlockRecord]: if header_hash in self.__block_records: return self.__block_records[header_hash] diff --git a/src/consensus/blockchain_interface.py b/src/consensus/blockchain_interface.py index 934e95485af2..1133767265da 100644 --- a/src/consensus/blockchain_interface.py +++ b/src/consensus/blockchain_interface.py @@ -55,6 +55,9 @@ async def get_header_blocks_in_range(self, start: int, stop: int) -> Dict[bytes3 async def get_header_block_by_height(self, height: int, header_hash: bytes32) -> Optional[HeaderBlock]: pass + async def get_block_records_at(self, heights: List[uint32]) -> List[BlockRecord]: + pass + def try_block_record(self, header_hash: bytes32) -> Optional[BlockRecord]: if self.contains_block(header_hash): return self.block_record(header_hash) diff --git a/src/full_node/block_store.py b/src/full_node/block_store.py index fb142dbd39cf..317021d117b7 100644 --- a/src/full_node/block_store.py +++ b/src/full_node/block_store.py @@ -154,6 +154,18 @@ async def get_full_blocks_at(self, heights: List[uint32]) -> List[FullBlock]: await cursor.close() return [FullBlock.from_bytes(row[0]) for row in rows] + async def get_block_records_at(self, heights: List[uint32]) -> List[BlockRecord]: + if len(heights) == 0: + return [] + heights_db = tuple(heights) + formatted_str = ( + f'SELECT block from block_records WHERE height in ({"?," * (len(heights_db) - 1)}?) ORDER BY height ASC;' + ) + cursor = await self.db.execute(formatted_str, heights_db) + rows = await cursor.fetchall() + await cursor.close() + return [BlockRecord.from_bytes(row[0]) for row in rows] + async def get_blocks_by_hash(self, header_hashes: List[bytes32]) -> List[FullBlock]: """ Returns a list of Full Blocks blocks, ordered by the same order in which header_hashes are passed in. diff --git a/src/full_node/weight_proof.py b/src/full_node/weight_proof.py index cd3def72a126..abfa521a7623 100644 --- a/src/full_node/weight_proof.py +++ b/src/full_node/weight_proof.py @@ -26,7 +26,14 @@ from src.types.blockchain_format.vdf import VDFInfo from src.types.end_of_slot_bundle import EndOfSubSlotBundle from src.types.header_block import HeaderBlock -from src.types.weight_proof import SubEpochChallengeSegment, SubEpochData, SubSlotData, WeightProof +from src.types.weight_proof import ( + SubEpochChallengeSegment, + SubEpochData, + SubSlotData, + WeightProof, + SubEpochSegments, + RecentChainData, +) from src.util.block_cache import BlockCache from src.util.hash import std_hash from src.util.ints import uint8, uint32, uint64, uint128 @@ -64,6 +71,9 @@ async def get_proof_of_weight(self, tip: bytes32) -> Optional[WeightProof]: return None async with self.lock: + if self.proof is not None: + if self.proof.recent_chain_data[-1].header_hash == tip: + return self.proof wp = await self._create_proof_of_weight(tip) if wp is None: return None @@ -71,12 +81,21 @@ async def get_proof_of_weight(self, tip: bytes32) -> Optional[WeightProof]: self.tip = tip return wp + def get_sub_epoch_data(self, tip_height: uint32, summary_heights: List[uint32]) -> List[SubEpochData]: + sub_epoch_data: List[SubEpochData] = [] + for sub_epoch_n, ses_height in enumerate(summary_heights): + if ses_height > tip_height: + break + ses = self.blockchain.get_ses(ses_height) + log.debug(f"handle sub epoch summary {sub_epoch_n} at height: {ses_height} ses {ses}") + sub_epoch_data.append(_create_sub_epoch_data(ses)) + return sub_epoch_data + async def _create_proof_of_weight(self, tip: bytes32) -> Optional[WeightProof]: """ Creates a weight proof object """ assert self.blockchain is not None - sub_epoch_data: List[SubEpochData] = [] sub_epoch_segments: List[SubEpochChallengeSegment] = [] tip_rec = self.blockchain.try_block_record(tip) if tip_rec is None: @@ -88,34 +107,34 @@ async def _create_proof_of_weight(self, tip: bytes32) -> Optional[WeightProof]: return None summary_heights = self.blockchain.get_ses_heights() + prev_ses_block = await self.blockchain.get_block_record_from_db(self.blockchain.height_to_hash(uint32(0))) + if prev_ses_block is None: + return None + sub_epoch_data = self.get_sub_epoch_data(tip_rec.height, summary_heights) # use second to last ses as seed seed = self.get_seed_for_proof(summary_heights, tip_rec.height) rng = random.Random(seed) weight_to_check = _get_weights_for_sampling(rng, tip_rec.weight, recent_chain) - - prev_ses_block = await self.blockchain.get_block_record_from_db(self.blockchain.height_to_hash(uint32(0))) - if prev_ses_block is None: + sample_n = 0 + ses_blocks = await self.blockchain.get_block_records_at(summary_heights) + if ses_blocks is None: return None - sample_n = 0 - summary_heights = self.blockchain.get_ses_heights() for sub_epoch_n, ses_height in enumerate(summary_heights): if ses_height > tip_rec.height: break - # next sub block - ses_block = await self.blockchain.get_block_record_from_db(self.blockchain.height_to_hash(ses_height)) - if ses_block is None or ses_block.sub_epoch_summary_included is None: - log.error("error while building proof") - return None - - log.debug(f"handle sub epoch summary {sub_epoch_n} at height: {ses_height} weight: {ses_block.weight}") - sub_epoch_data.append(_create_sub_epoch_data(ses_block.sub_epoch_summary_included)) # if we have enough sub_epoch samples, dont sample if sample_n >= self.MAX_SAMPLES: log.debug("reached sampled sub epoch cap") - continue + break # sample sub epoch + # next sub block + ses_block = ses_blocks[sub_epoch_n] + if ses_block is None or ses_block.sub_epoch_summary_included is None: + log.error("error while building proof") + return None + if _sample_sub_epoch(prev_ses_block.weight, ses_block.weight, weight_to_check): # type: ignore sample_n += 1 segments = await self.blockchain.get_sub_epoch_challenge_segments(ses_block.height) @@ -148,8 +167,17 @@ def get_seed_for_proof(self, summary_heights: List[uint32], tip_height) -> bytes async def _get_recent_chain(self, tip_height: uint32) -> Optional[List[HeaderBlock]]: recent_chain: List[HeaderBlock] = [] - min_height = max(0, tip_height - self.constants.WEIGHT_PROOF_RECENT_BLOCKS * 2) - headers: Dict[bytes32, HeaderBlock] = await self.blockchain.get_header_blocks_in_range(min_height, tip_height) + ses_heights = self.blockchain.get_ses_heights() + min_height = 0 + count_ses = 0 + for ses_height in reversed(ses_heights): + if ses_height <= tip_height: + count_ses += 1 + if count_ses == 2: + min_height = ses_height - 1 + break + log.debug(f"start {min_height} end {tip_height}") + headers = await self.blockchain.get_header_blocks_in_range(min_height, tip_height) blocks = await self.blockchain.get_block_records_in_range(min_height, tip_height) ses_count = 0 curr_height = tip_height @@ -464,7 +492,9 @@ def validate_weight_proof_single_proc(self, weight_proof: WeightProof) -> Tuple[ if summaries is None: log.warning("weight proof failed sub epoch data validation") return False, uint32(0) - constants, summary_bytes, wp_bytes = vars_to_bytes(self.constants, summaries, weight_proof) + constants, summary_bytes, wp_segment_bytes, wp_recent_chain_bytes = vars_to_bytes( + self.constants, summaries, weight_proof + ) log.info("validate sub epoch challenge segments") seed = summaries[-2].get_hash() rng = random.Random(seed) @@ -472,10 +502,10 @@ def validate_weight_proof_single_proc(self, weight_proof: WeightProof) -> Tuple[ log.error("failed weight proof sub epoch sample validation") return False, uint32(0) - if not _validate_sub_epoch_segments(constants, rng, wp_bytes, summary_bytes): + if not _validate_sub_epoch_segments(constants, rng, wp_segment_bytes, summary_bytes): return False, uint32(0) log.info("validate weight proof recent blocks") - if not _validate_recent_blocks(constants, wp_bytes, summary_bytes): + if not _validate_recent_blocks(constants, wp_recent_chain_bytes, summary_bytes): return False, uint32(0) return True, self.get_fork_point(summaries) @@ -512,13 +542,15 @@ async def validate_weight_proof(self, weight_proof: WeightProof) -> Tuple[bool, return False, uint32(0) executor = ProcessPoolExecutor(1) - constants, summary_bytes, wp_bytes = vars_to_bytes(self.constants, summaries, weight_proof) + constants, summary_bytes, wp_segment_bytes, wp_recent_chain_bytes = vars_to_bytes( + self.constants, summaries, weight_proof + ) segment_validation_task = asyncio.get_running_loop().run_in_executor( - executor, _validate_sub_epoch_segments, constants, rng, wp_bytes, summary_bytes + executor, _validate_sub_epoch_segments, constants, rng, wp_segment_bytes, summary_bytes ) recent_blocks_validation_task = asyncio.get_running_loop().run_in_executor( - executor, _validate_recent_blocks, constants, wp_bytes, summary_bytes + executor, _validate_recent_blocks, constants, wp_recent_chain_bytes, summary_bytes ) valid_segment_task = segment_validation_task @@ -839,13 +871,14 @@ def _validate_sub_epoch_segments( weight_proof_bytes: bytes, summaries_bytes: List[bytes], ): - constants, summaries, weight_proof = bytes_to_vars(constants_dict, summaries_bytes, weight_proof_bytes) + constants, summaries = bytes_to_vars(constants_dict, summaries_bytes) + sub_epoch_segments: SubEpochSegments = SubEpochSegments.from_bytes(weight_proof_bytes) rc_sub_slot_hash = constants.GENESIS_CHALLENGE total_blocks, total_ip_iters = 0, 0 total_slot_iters, total_slots = 0, 0 total_ip_iters = 0 prev_ses: Optional[SubEpochSummary] = None - segments_by_sub_epoch = map_segments_by_sub_epoch(weight_proof.sub_epoch_segments) + segments_by_sub_epoch = map_segments_by_sub_epoch(sub_epoch_segments.challenge_segments) curr_ssi = constants.SUB_SLOT_ITERS_STARTING for sub_epoch_n, segments in segments_by_sub_epoch.items(): prev_ssi = curr_ssi @@ -1089,10 +1122,11 @@ def sub_slot_data_vdf_input( return cc_input -def _validate_recent_blocks(constants_dict: Dict, weight_proof_bytes: bytes, summaries_bytes: List[bytes]) -> bool: - constants, summaries, weight_proof = bytes_to_vars(constants_dict, summaries_bytes, weight_proof_bytes) +def _validate_recent_blocks(constants_dict: Dict, recent_chain_bytes: bytes, summaries_bytes: List[bytes]) -> bool: + constants, summaries = bytes_to_vars(constants_dict, summaries_bytes) + recent_chain: RecentChainData = RecentChainData.from_bytes(recent_chain_bytes) sub_blocks = BlockCache({}) - first_ses_idx = _get_ses_idx(weight_proof.recent_chain_data) + first_ses_idx = _get_ses_idx(recent_chain.recent_chain_data) ses_idx = len(summaries) - len(first_ses_idx) ssi: uint64 = constants.SUB_SLOT_ITERS_STARTING diff: Optional[uint64] = constants.DIFFICULTY_STARTING @@ -1105,10 +1139,10 @@ def _validate_recent_blocks(constants_dict: Dict, weight_proof_bytes: bytes, sum ses_blocks, sub_slots, transaction_blocks = 0, 0, 0 challenge, prev_challenge = None, None - tip_height = weight_proof.recent_chain_data[-1].height + tip_height = recent_chain.recent_chain_data[-1].height prev_block_record = None deficit = uint8(0) - for idx, block in enumerate(weight_proof.recent_chain_data): + for idx, block in enumerate(recent_chain.recent_chain_data): required_iters = uint64(0) overflow = False ses = False @@ -1360,20 +1394,20 @@ def _get_curr_diff_ssi(constants: ConsensusConstants, idx, summaries): def vars_to_bytes(constants, summaries, weight_proof): constants_dict = recurse_jsonify(dataclasses.asdict(constants)) - wp_bytes = bytes(weight_proof) + wp_recent_chain_bytes = bytes(RecentChainData(weight_proof.recent_chain_data)) + wp_segment_bytes = bytes(SubEpochSegments(weight_proof.sub_epoch_segments)) summary_bytes = [] for summary in summaries: summary_bytes.append(bytes(summary)) - return constants_dict, summary_bytes, wp_bytes + return constants_dict, summary_bytes, wp_segment_bytes, wp_recent_chain_bytes -def bytes_to_vars(constants_dict, summaries_bytes, weight_proof_bytes): +def bytes_to_vars(constants_dict, summaries_bytes): summaries = [] for summary in summaries_bytes: summaries.append(SubEpochSummary.from_bytes(summary)) constants: ConsensusConstants = dataclass_from_dict(ConsensusConstants, constants_dict) - weight_proof = WeightProof.from_bytes(weight_proof_bytes) - return constants, summaries, weight_proof + return constants, summaries def _get_last_ses_hash( diff --git a/src/types/weight_proof.py b/src/types/weight_proof.py index 6d457c53be0d..e465039391be 100644 --- a/src/types/weight_proof.py +++ b/src/types/weight_proof.py @@ -80,6 +80,13 @@ class SubEpochSegments(Streamable): challenge_segments: List[SubEpochChallengeSegment] +@dataclass(frozen=True) +@streamable +# this is used only for serialization to database +class RecentChainData(Streamable): + recent_chain_data: List[HeaderBlock] + + @dataclass(frozen=True) @streamable class ProofBlockHeader(Streamable): diff --git a/src/util/block_cache.py b/src/util/block_cache.py index b59c2428e30d..63dc44ea651c 100644 --- a/src/util/block_cache.py +++ b/src/util/block_cache.py @@ -59,6 +59,12 @@ def contains_height(self, height: uint32) -> bool: async def get_block_records_in_range(self, start: int, stop: int) -> Dict[bytes32, BlockRecord]: return self._block_records + async def get_block_records_at(self, heights: List[uint32]) -> List[BlockRecord]: + block_records: List[BlockRecord] = [] + for height in heights: + block_records.append(self.height_to_block_record(height)) + return block_records + async def get_block_record_from_db(self, header_hash: bytes32) -> Optional[BlockRecord]: return self._block_records[header_hash]