Skip to content

Commit

Permalink
checkpoint: into main from release/2.1.4 @ 55c064a (Chia-Network#17141)
Browse files Browse the repository at this point in the history
Source hash: 55c064a
Remaining commits: 0
  • Loading branch information
cmmarslender authored Dec 22, 2023
2 parents 04f0a74 + ae193db commit 739b0d8
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 19 deletions.
4 changes: 2 additions & 2 deletions chia/consensus/default_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
"3d8765d3a597ec1d99663f6c9816d915b9f68613ac94009884c4addaefcce6af"
),
MAX_VDF_WITNESS_SIZE=64,
# Size of mempool = 50x the size of block
MEMPOOL_BLOCK_BUFFER=50,
# Size of mempool = 10x the size of block
MEMPOOL_BLOCK_BUFFER=10,
# Max coin amount, fits into 64 bits
MAX_COIN_AMOUNT=uint64((1 << 64) - 1),
# Max block cost in clvm cost units
Expand Down
40 changes: 26 additions & 14 deletions chia/full_node/mempool.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,16 @@ class Mempool:
_block_height: uint32
_timestamp: uint64

_total_fee: int
_total_cost: int

def __init__(self, mempool_info: MempoolInfo, fee_estimator: FeeEstimatorInterface):
self._db_conn = sqlite3.connect(":memory:")
self._items = {}
self._block_height = uint32(0)
self._timestamp = uint64(0)
self._total_fee = 0
self._total_cost = 0

with self._db_conn:
# name means SpendBundle hash
Expand All @@ -75,8 +80,6 @@ def __init__(self, mempool_info: MempoolInfo, fee_estimator: FeeEstimatorInterfa
"""
)
self._db_conn.execute("CREATE INDEX name_idx ON tx(name)")
self._db_conn.execute("CREATE INDEX fee_sum ON tx(fee)")
self._db_conn.execute("CREATE INDEX cost_sum ON tx(cost)")
self._db_conn.execute("CREATE INDEX feerate ON tx(fee_per_cost)")
self._db_conn.execute(
"CREATE INDEX assert_before ON tx(assert_before_height, assert_before_seconds) "
Expand Down Expand Up @@ -121,16 +124,10 @@ def _row_to_item(self, row: sqlite3.Row) -> MempoolItem:
)

def total_mempool_fees(self) -> int:
with self._db_conn:
cursor = self._db_conn.execute("SELECT SUM(fee) FROM tx")
val = cursor.fetchone()[0]
return uint64(0) if val is None else uint64(val)
return self._total_fee

def total_mempool_cost(self) -> CLVMCost:
with self._db_conn:
cursor = self._db_conn.execute("SELECT SUM(cost) FROM tx")
val = cursor.fetchone()[0]
return CLVMCost(uint64(0) if val is None else uint64(val))
return CLVMCost(uint64(self._total_cost))

def all_items(self) -> Iterator[MempoolItem]:
with self._db_conn:
Expand Down Expand Up @@ -193,7 +190,7 @@ def get_min_fee_rate(self, cost: int) -> Optional[float]:
return 0

# TODO: make MempoolItem.cost be CLVMCost
current_cost = int(self.total_mempool_cost())
current_cost = self._total_cost

# Iterates through all spends in increasing fee per cost
with self._db_conn:
Expand Down Expand Up @@ -256,9 +253,19 @@ def remove_from_pool(self, items: List[bytes32], reason: MempoolRemoveReason) ->
for batch in to_batches(items, SQLITE_MAX_VARIABLE_NUMBER):
args = ",".join(["?"] * len(batch.entries))
with self._db_conn:
cursor = self._db_conn.execute(
f"SELECT SUM(cost), SUM(fee) FROM tx WHERE name in ({args})", batch.entries
)
cost_to_remove, fee_to_remove = cursor.fetchone()

self._db_conn.execute(f"DELETE FROM tx WHERE name in ({args})", batch.entries)
self._db_conn.execute(f"DELETE FROM spends WHERE tx in ({args})", batch.entries)

self._total_cost -= cost_to_remove
self._total_fee -= fee_to_remove
assert self._total_cost >= 0
assert self._total_fee >= 0

if reason != MempoolRemoveReason.BLOCK_INCLUSION:
info = FeeMempoolInfo(
self.mempool_info, self.total_mempool_cost(), self.total_mempool_fees(), datetime.now()
Expand Down Expand Up @@ -311,11 +318,12 @@ def add_to_pool(self, item: MempoolItem) -> Optional[Err]:
if fee_per_cost > item.fee_per_cost:
return Err.INVALID_FEE_LOW_FEE
to_remove.append(name)

self.remove_from_pool(to_remove, MempoolRemoveReason.EXPIRED)

# if we don't find any entries, it's OK to add this entry

total_cost = int(self.total_mempool_cost())
if total_cost + item.cost > self.mempool_info.max_size_in_cost:
if self._total_cost + item.cost > self.mempool_info.max_size_in_cost:
# pick the items with the lowest fee per cost to remove
cursor = self._db_conn.execute(
"""SELECT name FROM tx
Expand All @@ -329,6 +337,7 @@ def add_to_pool(self, item: MempoolItem) -> Optional[Err]:
(self.mempool_info.max_size_in_cost - item.cost,),
)
to_remove = [bytes32(row[0]) for row in cursor]

self.remove_from_pool(to_remove, MempoolRemoveReason.POOL_FULL)

# TODO: In the future, for the "fee_per_cost" field, opt for
Expand All @@ -355,6 +364,9 @@ def add_to_pool(self, item: MempoolItem) -> Optional[Err]:
item.spend_bundle, item.npc_result, item.height_added_to_mempool, item.bundle_coin_spends
)

self._total_cost += item.cost
self._total_fee += item.fee

info = FeeMempoolInfo(self.mempool_info, self.total_mempool_cost(), self.total_mempool_fees(), datetime.now())
self.fee_estimator.add_mempool_item(info, MempoolItemInfo(item.cost, item.fee, item.height_added_to_mempool))
return None
Expand All @@ -364,7 +376,7 @@ def at_full_capacity(self, cost: int) -> bool:
Checks whether the mempool is at full capacity and cannot accept a transaction with size cost.
"""

return self.total_mempool_cost() + cost > self.mempool_info.max_size_in_cost
return self._total_cost + cost > self.mempool_info.max_size_in_cost

def create_bundle_from_mempool_items(
self, item_inclusion_filter: Callable[[bytes32], bool]
Expand Down
6 changes: 6 additions & 0 deletions chia/full_node/mempool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,12 @@ async def new_peak(
spendbundle_ids_to_remove.add(item.name)
self.mempool.remove_from_pool(list(spendbundle_ids_to_remove), MempoolRemoveReason.BLOCK_INCLUSION)
else:
log.warning(
"updating the mempool using the slow-path. "
f"peak: {self.peak.header_hash} "
f"new-peak-prev: {new_peak.prev_transaction_block_hash} "
f"coins: {'not set' if spent_coins is None else 'set'}"
)
old_pool = self.mempool
self.mempool = Mempool(old_pool.mempool_info, old_pool.fee_estimator)
self.seen_bundle_hashes = {}
Expand Down
28 changes: 26 additions & 2 deletions tests/core/mempool/test_mempool.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
spend_bundle_from_conditions,
)
from tests.core.node_height import node_height_at_least
from tests.util.misc import BenchmarkRunner
from tests.util.misc import BenchmarkRunner, invariant_check_mempool
from tests.util.time_out_assert import time_out_assert

BURN_PUZZLE_HASH = bytes32(b"0" * 32)
Expand Down Expand Up @@ -335,7 +335,9 @@ async def respond_transaction(
self.full_node.full_node_store.pending_tx_request.pop(spend_name)
if spend_name in self.full_node.full_node_store.peers_with_tx:
self.full_node.full_node_store.peers_with_tx.pop(spend_name)
return await self.full_node.add_transaction(tx.transaction, spend_name, peer, test)
ret = await self.full_node.add_transaction(tx.transaction, spend_name, peer, test)
invariant_check_mempool(self.full_node.mempool_manager.mempool)
return ret


async def next_block(full_node_1, wallet_a, bt) -> Coin:
Expand Down Expand Up @@ -579,6 +581,7 @@ async def test_double_spend_with_higher_fee(self, two_nodes_one_block, wallet_a,
)
peer = await connect_and_get_peer(server_1, server_2, self_hostname)

invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)
for block in blocks:
await full_node_1.full_node.add_block(block)
await time_out_assert(60, node_height_at_least, True, full_node_1, start_height + 3)
Expand All @@ -594,12 +597,14 @@ async def test_double_spend_with_higher_fee(self, two_nodes_one_block, wallet_a,
# Fee increase is insufficient, the old spendbundle must stay
self.assert_sb_in_pool(full_node_1, sb1_1)
self.assert_sb_not_in_pool(full_node_1, sb1_2)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

sb1_3 = await self.gen_and_send_sb(full_node_1, peer, wallet_a, coin1, fee=MEMPOOL_MIN_FEE_INCREASE)

# Fee increase is sufficiently high, sb1_1 gets replaced with sb1_3
self.assert_sb_not_in_pool(full_node_1, sb1_1)
self.assert_sb_in_pool(full_node_1, sb1_3)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

sb2 = generate_test_spend_bundle(wallet_a, coin2, fee=MEMPOOL_MIN_FEE_INCREASE)
sb12 = SpendBundle.aggregate((sb2, sb1_3))
Expand All @@ -609,6 +614,7 @@ async def test_double_spend_with_higher_fee(self, two_nodes_one_block, wallet_a,
# of coins spent in sb1_3
self.assert_sb_in_pool(full_node_1, sb12)
self.assert_sb_not_in_pool(full_node_1, sb1_3)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

sb3 = generate_test_spend_bundle(wallet_a, coin3, fee=uint64(MEMPOOL_MIN_FEE_INCREASE * 2))
sb23 = SpendBundle.aggregate((sb2, sb3))
Expand All @@ -618,16 +624,19 @@ async def test_double_spend_with_higher_fee(self, two_nodes_one_block, wallet_a,
# coins that are spent in the latter (specifically, coin1)
self.assert_sb_in_pool(full_node_1, sb12)
self.assert_sb_not_in_pool(full_node_1, sb23)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

await self.send_sb(full_node_1, sb3)
# Adding non-conflicting sb3 should succeed
self.assert_sb_in_pool(full_node_1, sb3)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

sb4_1 = generate_test_spend_bundle(wallet_a, coin4, fee=MEMPOOL_MIN_FEE_INCREASE)
sb1234_1 = SpendBundle.aggregate((sb12, sb3, sb4_1))
await self.send_sb(full_node_1, sb1234_1)
# sb1234_1 should not be in pool as it decreases total fees per cost
self.assert_sb_not_in_pool(full_node_1, sb1234_1)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

sb4_2 = generate_test_spend_bundle(wallet_a, coin4, fee=uint64(MEMPOOL_MIN_FEE_INCREASE * 2))
sb1234_2 = SpendBundle.aggregate((sb12, sb3, sb4_2))
Expand All @@ -637,6 +646,7 @@ async def test_double_spend_with_higher_fee(self, two_nodes_one_block, wallet_a,
self.assert_sb_in_pool(full_node_1, sb1234_2)
self.assert_sb_not_in_pool(full_node_1, sb12)
self.assert_sb_not_in_pool(full_node_1, sb3)
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

@pytest.mark.anyio
async def test_invalid_signature(self, one_node_one_block, wallet_a):
Expand Down Expand Up @@ -668,6 +678,7 @@ async def test_invalid_signature(self, one_node_one_block, wallet_a):
ack: TransactionAck = TransactionAck.from_bytes(res.data)
assert ack.status == MempoolInclusionStatus.FAILED.value
assert ack.error == Err.BAD_AGGREGATE_SIGNATURE.name
invariant_check_mempool(full_node_1.full_node.mempool_manager.mempool)

async def condition_tester(
self,
Expand Down Expand Up @@ -2763,13 +2774,16 @@ def test_full_mempool(items: List[int], add: int, expected: List[int]) -> None:
CLVMCost(uint64(100)),
)
mempool = Mempool(mempool_info, fee_estimator)
invariant_check_mempool(mempool)
fee_rate: float = 3.0
for i in items:
mempool.add_to_pool(item_cost(i, fee_rate))
fee_rate -= 0.1
invariant_check_mempool(mempool)

# now, add the item we're testing
mempool.add_to_pool(item_cost(add, 3.1))
invariant_check_mempool(mempool)

ordered_items = list(mempool.items_by_feerate())

Expand Down Expand Up @@ -2808,12 +2822,14 @@ def test_limit_expiring_transactions(height: bool, items: List[int], expected: L
)
mempool = Mempool(mempool_info, fee_estimator)
mempool.new_tx_block(uint32(10), uint64(100000))
invariant_check_mempool(mempool)

# fill the mempool with regular transactions (without expiration)
fee_rate: float = 3.0
for i in range(1, 20):
mempool.add_to_pool(item_cost(i, fee_rate))
fee_rate -= 0.1
invariant_check_mempool(mempool)

# now add the expiring transactions from the test case
fee_rate = 2.7
Expand All @@ -2825,6 +2841,7 @@ def test_limit_expiring_transactions(height: bool, items: List[int], expected: L
ret = mempool.add_to_pool(mk_item([coin], cost=cost, fee=int(cost * fee_rate), assert_before_height=15))
else:
ret = mempool.add_to_pool(mk_item([coin], cost=cost, fee=int(cost * fee_rate), assert_before_seconds=10400))
invariant_check_mempool(mempool)
if increase_fee:
fee_rate += 0.1
assert ret is None
Expand All @@ -2848,6 +2865,7 @@ def test_limit_expiring_transactions(height: bool, items: List[int], expected: L
print(f"- cost: {item.cost} TTL: {ttl}")

assert mempool.total_mempool_cost() > 90
invariant_check_mempool(mempool)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -2883,6 +2901,7 @@ def test_get_items_by_coin_ids(items: List[MempoolItem], coin_ids: List[bytes32]
mempool = Mempool(mempool_info, fee_estimator)
for i in items:
mempool.add_to_pool(i)
invariant_check_mempool(mempool)
result = mempool.get_items_by_coin_ids(coin_ids)
assert set(result) == set(expected)

Expand All @@ -2905,6 +2924,7 @@ def agg_and_add_sb_returning_cost_info(mempool: Mempool, spend_bundles: List[Spe
sb = SpendBundle.aggregate(spend_bundles)
mi = mempool_item_from_spendbundle(sb)
mempool.add_to_pool(mi)
invariant_check_mempool(mempool)
saved_cost = run_for_cost(
sb.coin_spends[0].puzzle_reveal, sb.coin_spends[0].solution, len(mi.additions), mi.cost
)
Expand All @@ -2925,9 +2945,11 @@ def agg_and_add_sb_returning_cost_info(mempool: Mempool, spend_bundles: List[Spe
highest_fee = 58282830
sb_high_rate = make_test_spendbundle(coins[1], fee=highest_fee)
agg_and_add_sb_returning_cost_info(mempool, [sb_A, sb_high_rate])
invariant_check_mempool(mempool)
# Create a ~2 FPC item that spends the eligible coin using the same solution A
sb_low_rate = make_test_spendbundle(coins[2], fee=highest_fee // 5)
saved_cost_on_solution_A = agg_and_add_sb_returning_cost_info(mempool, [sb_A, sb_low_rate])
invariant_check_mempool(mempool)
result = mempool.create_bundle_from_mempool_items(always)
assert result is not None
agg, _ = result
Expand All @@ -2941,6 +2963,7 @@ def agg_and_add_sb_returning_cost_info(mempool: Mempool, spend_bundles: List[Spe
# (which has ~10 FPC) but before sb_A2 (which has ~2 FPC)
sb_mid_rate = make_test_spendbundle(coins[i], fee=38004852 - i)
saved_cost_on_solution_B = agg_and_add_sb_returning_cost_info(mempool, [sb_B, sb_mid_rate])
invariant_check_mempool(mempool)
# We'd save more cost if we went with solution B instead of A
assert saved_cost_on_solution_B > saved_cost_on_solution_A
# If we process everything now, the 3 x ~3 FPC items get skipped because
Expand All @@ -2953,6 +2976,7 @@ def agg_and_add_sb_returning_cost_info(mempool: Mempool, spend_bundles: List[Spe
# We ran with solution A and missed bigger savings on solution B
assert mempool.size() == 5
assert [c.coin for c in agg.coin_spends] == [coins[0], coins[1], coins[2]]
invariant_check_mempool(mempool)


def test_get_puzzle_and_solution_for_coin_failure():
Expand Down
8 changes: 7 additions & 1 deletion tests/core/mempool/test_mempool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from chia.wallet.wallet import Wallet
from chia.wallet.wallet_coin_record import WalletCoinRecord
from chia.wallet.wallet_node import WalletNode
from tests.util.misc import invariant_check_mempool
from tests.util.setup_nodes import OldSimulatorsAndWallets

IDENTITY_PUZZLE = SerializedProgram.to(1)
Expand Down Expand Up @@ -121,6 +122,7 @@ async def instantiate_mempool_manager(
mempool_manager = MempoolManager(get_coin_record, constants)
test_block_record = create_test_block_record(height=block_height, timestamp=block_timestamp)
await mempool_manager.new_peak(test_block_record, None)
invariant_check_mempool(mempool_manager.mempool)
return mempool_manager


Expand Down Expand Up @@ -348,7 +350,9 @@ async def add_spendbundle(
mempool_manager: MempoolManager, sb: SpendBundle, sb_name: bytes32
) -> Tuple[Optional[uint64], MempoolInclusionStatus, Optional[Err]]:
npc_result = await mempool_manager.pre_validate_spendbundle(sb, None, sb_name)
return await mempool_manager.add_spend_bundle(sb, npc_result, sb_name, TEST_HEIGHT)
ret = await mempool_manager.add_spend_bundle(sb, npc_result, sb_name, TEST_HEIGHT)
invariant_check_mempool(mempool_manager.mempool)
return ret


async def generate_and_add_spendbundle(
Expand Down Expand Up @@ -1033,6 +1037,7 @@ async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]:

block_record = create_test_block_record(height=uint32(11), timestamp=uint64(10019))
await mempool_manager.new_peak(block_record, None)
invariant_check_mempool(mempool_manager.mempool)

still_in_pool = mempool_manager.get_spendbundle(bundle_name) == bundle
assert still_in_pool != expect_eviction
Expand Down Expand Up @@ -1385,6 +1390,7 @@ async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]:
test_coin_records = {coin_id: CoinRecord(coin, uint32(0), TEST_HEIGHT, False, uint64(0))}
block_record = create_test_block_record(height=new_height)
await mempool_manager.new_peak(block_record, [coin_id])
invariant_check_mempool(mempool_manager.mempool)
# As the coin was a spend in all the mempool items we had, nothing should be left now
assert len(mempool_manager.mempool.get_items_by_coin_id(coin_id)) == 0
assert mempool_manager.mempool.size() == 0
Expand Down
17 changes: 17 additions & 0 deletions tests/util/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing_extensions import Protocol, final

import chia
from chia.full_node.mempool import Mempool
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.condition_opcodes import ConditionOpcode
from chia.util.hash import std_hash
Expand Down Expand Up @@ -407,3 +408,19 @@ def create_logger(file: TextIO = sys.stdout) -> logging.Logger:
logger.addHandler(hdlr=stream_handler)

return logger


def invariant_check_mempool(mempool: Mempool) -> None:
with mempool._db_conn:
cursor = mempool._db_conn.execute("SELECT SUM(cost) FROM tx")
val = cursor.fetchone()[0]
if val is None:
val = 0
assert mempool._total_cost == val

with mempool._db_conn:
cursor = mempool._db_conn.execute("SELECT SUM(fee) FROM tx")
val = cursor.fetchone()[0]
if val is None:
val = 0
assert mempool._total_fee == val

0 comments on commit 739b0d8

Please sign in to comment.