Skip to content

Commit

Permalink
chia|tests: Add node_type parameter to get_connections RPC (Chia-…
Browse files Browse the repository at this point in the history
…Network#7492)

* rpc|server: Add Optional `node_type` paramter to `get_connections`

* farmer: Query `HARVESTER` connections only

* tests: Basic test for `node_type` parameter
  • Loading branch information
xdustinface authored Jul 15, 2021
1 parent 1ea65f0 commit 1b196e6
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 15 deletions.
8 changes: 2 additions & 6 deletions chia/farmer/farmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,9 +579,7 @@ async def update_cached_harvesters(self):
for key in remove_hosts:
del self.harvester_cache[key]
# Now query each harvester and update caches
for connection in self.server.get_connections():
if connection.connection_type != NodeType.HARVESTER:
continue
for connection in self.server.get_connections(NodeType.HARVESTER):
cache_entry = await self.get_cached_harvesters(connection)
if cache_entry.needs_update():
self.log.debug(f"update_cached_harvesters update harvester: {connection.peer_node_id}")
Expand Down Expand Up @@ -615,9 +613,7 @@ async def get_cached_harvesters(self, connection: WSChiaConnection) -> Harvester

async def get_harvesters(self) -> Dict:
harvesters: List = []
for connection in self.server.get_connections():
if connection.connection_type != NodeType.HARVESTER:
continue
for connection in self.server.get_connections(NodeType.HARVESTER):
self.log.debug(f"get_harvesters host: {connection.peer_host}, node_id: {connection.peer_node_id}")
cache_entry = await self.get_cached_harvesters(connection)
if cache_entry.data is not None:
Expand Down
9 changes: 6 additions & 3 deletions chia/rpc/rpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import aiohttp

from chia.server.server import ssl_context_for_client
from chia.server.server import NodeType, ssl_context_for_client
from chia.server.ssl_context import private_ssl_ca_paths
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
Expand Down Expand Up @@ -45,8 +45,11 @@ async def fetch(self, path, request_json) -> Any:
raise ValueError(res_json)
return res_json

async def get_connections(self) -> List[Dict]:
response = await self.fetch("get_connections", {})
async def get_connections(self, node_type: Optional[NodeType] = None) -> List[Dict]:
request = {}
if node_type is not None:
request["node_type"] = node_type.value
response = await self.fetch("get_connections", request)
for connection in response["connections"]:
connection["node_id"] = hexstr_to_bytes(connection["node_id"])
return response["connections"]
Expand Down
7 changes: 5 additions & 2 deletions chia/rpc/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,14 @@ async def inner(request) -> aiohttp.web.Response:
return inner

async def get_connections(self, request: Dict) -> Dict:
request_node_type: Optional[NodeType] = None
if "node_type" in request:
request_node_type = NodeType(request["node_type"])
if self.rpc_api.service.server is None:
raise ValueError("Global connections is not set")
if self.rpc_api.service.server._local_type is NodeType.FULL_NODE:
# TODO add peaks for peers
connections = self.rpc_api.service.server.get_connections()
connections = self.rpc_api.service.server.get_connections(request_node_type)
con_info = []
if self.rpc_api.service.sync_store is not None:
peak_store = self.rpc_api.service.sync_store.peer_to_peak
Expand Down Expand Up @@ -130,7 +133,7 @@ async def get_connections(self, request: Dict) -> Dict:
}
con_info.append(con_dict)
else:
connections = self.rpc_api.service.server.get_connections()
connections = self.rpc_api.service.server.get_connections(request_node_type)
con_info = [
{
"type": con.connection_type,
Expand Down
5 changes: 3 additions & 2 deletions chia/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,10 +614,11 @@ def get_full_node_outgoing_connections(self) -> List[WSChiaConnection]:
def get_full_node_connections(self) -> List[WSChiaConnection]:
return list(self.connection_by_type[NodeType.FULL_NODE].values())

def get_connections(self) -> List[WSChiaConnection]:
def get_connections(self, node_type: Optional[NodeType] = None) -> List[WSChiaConnection]:
result = []
for _, connection in self.all_connections.items():
result.append(connection)
if node_type is None or connection.connection_type == node_type:
result.append(connection)
return result

async def close_all_connections(self) -> None:
Expand Down
6 changes: 4 additions & 2 deletions tests/core/test_full_node_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from chia.protocols import full_node_protocol
from chia.rpc.full_node_rpc_api import FullNodeRpcApi
from chia.rpc.full_node_rpc_client import FullNodeRpcClient
from chia.rpc.rpc_server import start_rpc_server
from chia.rpc.rpc_server import NodeType, start_rpc_server
from chia.simulator.simulator_protocol import FarmNewBlockProtocol
from chia.types.spend_bundle import SpendBundle
from chia.types.unfinished_block import UnfinishedBlock
Expand Down Expand Up @@ -192,7 +192,9 @@ async def num_connections():

await time_out_assert(10, num_connections, 1)
connections = await client.get_connections()

assert NodeType(connections[0]["type"]) == NodeType.FULL_NODE.value
assert len(await client.get_connections(NodeType.FULL_NODE)) == 1
assert len(await client.get_connections(NodeType.FARMER)) == 0
await client.close_connection(connections[0]["node_id"])
await time_out_assert(10, num_connections, 0)
finally:
Expand Down

0 comments on commit 1b196e6

Please sign in to comment.