Skip to content

Commit 6df0a92

Browse files
cdeckerrustyrussell
authored andcommitted
pyln-testing: Add a couple of methods used in tests
These are the most used methods in tests, so we can start getting our test coverage up.
1 parent ca8c46c commit 6df0a92

File tree

1 file changed

+151
-13
lines changed
  • contrib/pyln-testing/pyln/testing

1 file changed

+151
-13
lines changed

contrib/pyln-testing/pyln/testing/grpc.py

+151-13
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""A drop-in replacement for the JSON-RPC LightningRpc
22
"""
33

4-
from pyln.testing import node_pb2_grpc as pbgrpc
5-
from pyln.testing import node_pb2 as pb
4+
import logging
5+
from binascii import unhexlify
6+
from typing import List, Optional, Tuple
7+
68
import grpc
7-
import json
8-
from google.protobuf.json_format import MessageToJson
99
from pyln.testing import grpc2py
10-
10+
from pyln.testing import node_pb2 as pb
11+
from pyln.testing import node_pb2_grpc as pbgrpc
12+
from pyln.testing import primitives_pb2 as primpb
1113

1214
DUMMY_CA_PEM = b"""-----BEGIN CERTIFICATE-----
1315
MIIBcTCCARigAwIBAgIJAJhah1bqO05cMAoGCCqGSM49BAMCMBYxFDASBgNVBAMM
@@ -46,6 +48,26 @@
4648
-----END CERTIFICATE-----"""
4749

4850

51+
def int2msat(amount: int) -> primpb.Amount:
52+
return primpb.Amount(msat=amount)
53+
54+
55+
def int2amount_or_all(amount: Tuple[int, str]) -> primpb.AmountOrAll:
56+
if amount == "all":
57+
return primpb.AmountOrAll(all=True)
58+
else:
59+
assert isinstance(amount, int)
60+
return primpb.AmountOrAll(amount=int2msat(amount))
61+
62+
63+
def int2amount_or_any(amount: Tuple[int, str]) -> primpb.AmountOrAny:
64+
if amount == "any":
65+
return primpb.AmountOrAny(any=True)
66+
else:
67+
assert isinstance(amount, int)
68+
return primpb.AmountOrAny(amount=int2msat(amount))
69+
70+
4971
class LightningGrpc(object):
5072
def __init__(
5173
self,
@@ -55,11 +77,13 @@ def __init__(
5577
private_key: bytes = DUMMY_CLIENT_KEY_PEM,
5678
certificate_chain: bytes = DUMMY_CLIENT_PEM,
5779
):
80+
self.logger = logging.getLogger("LightningGrpc")
5881
self.credentials = grpc.ssl_channel_credentials(
5982
root_certificates=root_certificates,
6083
private_key=private_key,
6184
certificate_chain=certificate_chain,
6285
)
86+
self.logger.debug(f"Connecting to grpc interface at {host}:{port}")
6387
self.channel = grpc.secure_channel(
6488
f"{host}:{port}",
6589
self.credentials,
@@ -68,17 +92,131 @@ def __init__(
6892
self.stub = pbgrpc.NodeStub(self.channel)
6993

7094
def getinfo(self):
71-
return grpc2py.getinfo2py(
72-
self.stub.Getinfo(pb.GetinfoRequest())
73-
)
95+
return grpc2py.getinfo2py(self.stub.Getinfo(pb.GetinfoRequest()))
7496

7597
def connect(self, peer_id, host=None, port=None):
7698
"""
7799
Connect to {peer_id} at {host} and {port}.
78100
"""
79-
payload = pb.ConnectRequest(
80-
id=peer_id,
81-
host=host,
82-
port=port
83-
)
101+
payload = pb.ConnectRequest(id=peer_id, host=host, port=port)
84102
return grpc2py.connect2py(self.stub.ConnectPeer(payload))
103+
104+
def listpeers(self, peerid=None, level=None):
105+
payload = pb.ListpeersRequest(
106+
id=unhexlify(peerid) if peerid is not None else None,
107+
level=level,
108+
)
109+
return grpc2py.listpeers2py(self.stub.ListPeers(payload))
110+
111+
def getpeer(self, peer_id, level=None):
112+
"""
113+
Show peer with {peer_id}, if {level} is set, include {log}s.
114+
"""
115+
res = self.listpeers(peer_id, level)
116+
return res.get("peers") and res["peers"][0] or None
117+
118+
def newaddr(self, addresstype=None):
119+
"""Get a new address of type {addresstype} of the internal wallet."""
120+
enum = {
121+
None: 0,
122+
"BECH32": 0,
123+
"P2SH_SEGWIT": 1,
124+
"P2SH-SEGWIT": 1,
125+
"ALL": 2
126+
}
127+
if addresstype is not None:
128+
addresstype = addresstype.upper()
129+
atype = enum.get(addresstype, None)
130+
if atype is None:
131+
raise ValueError(
132+
f"Unknown addresstype {addresstype}, known values are {enum.values()}"
133+
)
134+
135+
payload = pb.NewaddrRequest(addresstype=atype)
136+
res = grpc2py.newaddr2py(self.stub.NewAddr(payload))
137+
138+
# Need to remap the bloody spelling of p2sh-segwit to match
139+
# addresstype.
140+
if 'p2sh_segwit' in res:
141+
res['p2sh-segwit'] = res['p2sh_segwit']
142+
del res['p2sh_segwit']
143+
return res
144+
145+
def listfunds(self, spent=None):
146+
payload = pb.ListfundsRequest(spent=spent)
147+
return grpc2py.listfunds2py(self.stub.ListFunds(payload))
148+
149+
def fundchannel(
150+
self,
151+
node_id: str,
152+
amount: int,
153+
# TODO map the following arguments
154+
# feerate=None,
155+
announce: Optional[bool] = True,
156+
minconf: Optional[int] = None,
157+
# utxos=None,
158+
# push_msat=None,
159+
close_to: Optional[str] = None,
160+
# request_amt=None,
161+
compact_lease: Optional[str] = None,
162+
):
163+
payload = pb.FundchannelRequest(
164+
id=unhexlify(node_id),
165+
amount=int2amount_or_all(amount * 1000), # This is satoshis after all
166+
# TODO Parse and insert `feerate`
167+
announce=announce,
168+
utxos=None,
169+
minconf=minconf,
170+
close_to=close_to,
171+
compact_lease=compact_lease,
172+
)
173+
return grpc2py.fundchannel2py(self.stub.FundChannel(payload))
174+
175+
def listchannels(self, short_channel_id=None, source=None, destination=None):
176+
payload = pb.ListchannelsRequest(
177+
short_channel_id=short_channel_id,
178+
source=unhexlify(source) if source else None,
179+
destination=unhexlify(destination) if destination else None,
180+
)
181+
return grpc2py.listchannels2py(self.stub.ListChannels(payload))
182+
183+
def pay(
184+
self,
185+
bolt11: str,
186+
amount_msat: Optional[int] = None,
187+
label: Optional[str] = None,
188+
riskfactor: Optional[float] = None,
189+
maxfeepercent: Optional[float] = None,
190+
retry_for: Optional[int] = None,
191+
maxdelay: Optional[int] = None,
192+
exemptfee: Optional[int] = None,
193+
localofferid: Optional[str] = None,
194+
# TODO map the following arguments
195+
# exclude: Optional[List[str]] = None,
196+
# maxfee=None,
197+
description: Optional[str] = None,
198+
msatoshi: Optional[int] = None,
199+
):
200+
payload = pb.PayRequest(
201+
bolt11=bolt11,
202+
amount_msat=int2msat(amount_msat),
203+
label=label,
204+
riskfactor=riskfactor,
205+
maxfeepercent=maxfeepercent,
206+
retry_for=retry_for,
207+
maxdelay=maxdelay,
208+
exemptfee=exemptfee,
209+
localofferid=localofferid,
210+
# Needs conversion
211+
# exclude=exclude,
212+
# maxfee=maxfee
213+
description=description,
214+
)
215+
return grpc2py.pay2py(self.stub.Pay(payload))
216+
217+
def stop(self):
218+
payload = pb.StopRequest()
219+
try:
220+
self.stub.Stop(payload)
221+
except Exception:
222+
pass

0 commit comments

Comments
 (0)