Skip to content

Commit

Permalink
pyln: Add type annotations to lightning.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cdecker authored and rustyrussell committed Sep 23, 2020
1 parent 49ec800 commit 8ecb157
Showing 1 changed file with 64 additions and 47 deletions.
111 changes: 64 additions & 47 deletions contrib/pyln-client/pyln/client/lightning.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from decimal import Decimal
from math import floor, log10
from typing import Optional, Union
import json
import logging
import os
Expand All @@ -23,9 +24,12 @@ def monkey_patch_json(patch=True):


class RpcError(ValueError):
def __init__(self, method, payload, error):
super(ValueError, self).__init__("RPC call failed: method: {}, payload: {}, error: {}"
.format(method, payload, error))
def __init__(self, method: str, payload: dict, error: str):
super(ValueError, self).__init__(
"RPC call failed: method: {}, payload: {}, error: {}".format(
method, payload, error
)
)

self.method = method
self.payload = payload
Expand All @@ -36,54 +40,61 @@ class Millisatoshi:
"""
A subtype to represent thousandths of a satoshi.
Many JSON API fields are expressed in millisatoshis: these automatically get
turned into Millisatoshi types. Converts to and from int.
Many JSON API fields are expressed in millisatoshis: these automatically
get turned into Millisatoshi types. Converts to and from int.
"""
def __init__(self, v):
def __init__(self, v: Union[int, str, Decimal]):
"""
Takes either a string ending in 'msat', 'sat', 'btc' or an integer.
"""
if isinstance(v, str):
if v.endswith("msat"):
self.millisatoshis = int(v[0:-4])
elif v.endswith("sat"):
self.millisatoshis = Decimal(v[0:-3]) * 1000
self.millisatoshis = int(v[0:-3]) * 1000
elif v.endswith("btc"):
self.millisatoshis = Decimal(v[0:-3]) * 1000 * 10**8
self.millisatoshis = int(v[0:-3]) * 1000 * 10**8
else:
raise TypeError("Millisatoshi must be string with msat/sat/btc suffix or int")
raise TypeError(
"Millisatoshi must be string with msat/sat/btc suffix or"
" int"
)
if self.millisatoshis != int(self.millisatoshis):
raise ValueError("Millisatoshi must be a whole number")
self.millisatoshis = int(self.millisatoshis)

elif isinstance(v, Millisatoshi):
self.millisatoshis = v.millisatoshis

elif int(v) == v:
self.millisatoshis = int(v)
else:
raise TypeError("Millisatoshi must be string with msat/sat/btc suffix or int")
raise TypeError(
"Millisatoshi must be string with msat/sat/btc suffix or int"
)

if self.millisatoshis < 0:
raise ValueError("Millisatoshi must be >= 0")

def __repr__(self):
def __repr__(self) -> str:
"""
Appends the 'msat' as expected for this type.
"""
return str(self.millisatoshis) + "msat"

def to_satoshi(self):
def to_satoshi(self) -> Decimal:
"""
Return a Decimal representing the number of satoshis.
"""
return Decimal(self.millisatoshis) / 1000

def to_btc(self):
def to_btc(self) -> Decimal:
"""
Return a Decimal representing the number of bitcoin.
"""
return Decimal(self.millisatoshis) / 1000 / 10**8

def to_satoshi_str(self):
def to_satoshi_str(self) -> str:
"""
Return a string of form 1234sat or 1234.567sat.
"""
Expand All @@ -92,7 +103,7 @@ def to_satoshi_str(self):
else:
return '{:.0f}sat'.format(self.to_satoshi())

def to_btc_str(self):
def to_btc_str(self) -> str:
"""
Return a string of form 12.34567890btc or 12.34567890123btc.
"""
Expand All @@ -101,13 +112,14 @@ def to_btc_str(self):
else:
return '{:.8f}btc'.format(self.to_btc())

def to_approx_str(self, digits: int = 3):
def to_approx_str(self, digits: int = 3) -> str:
"""Returns the shortmost string using common units representation.
Rounds to significant `digits`. Default: 3
"""
round_to_n = lambda x, n: round(x, -int(floor(log10(x))) + (n - 1))
result = None
def round_to_n(x: int, n: int) -> float:
return round(x, -int(floor(log10(x))) + (n - 1))
result = self.to_satoshi_str()

# we try to increase digits to check if we did loose out on precision
# without gaining a shorter string, since this is a rarely used UI
Expand All @@ -132,46 +144,51 @@ def to_approx_str(self, digits: int = 3):
else:
return result

def to_json(self):
def to_json(self) -> str:
return self.__repr__()

def __int__(self):
def __int__(self) -> int:
return self.millisatoshis

def __lt__(self, other):
def __lt__(self, other: 'Millisatoshi') -> bool:
return self.millisatoshis < other.millisatoshis

def __le__(self, other):
def __le__(self, other: 'Millisatoshi') -> bool:
return self.millisatoshis <= other.millisatoshis

def __eq__(self, other):
return self.millisatoshis == other.millisatoshis
def __eq__(self, other: object) -> bool:
if isinstance(other, Millisatoshi):
return self.millisatoshis == other.millisatoshis
elif isinstance(other, int):
return self.millisatoshis == other
else:
return False

def __gt__(self, other):
def __gt__(self, other: 'Millisatoshi') -> bool:
return self.millisatoshis > other.millisatoshis

def __ge__(self, other):
def __ge__(self, other: 'Millisatoshi') -> bool:
return self.millisatoshis >= other.millisatoshis

def __add__(self, other):
def __add__(self, other: 'Millisatoshi') -> 'Millisatoshi':
return Millisatoshi(int(self) + int(other))

def __sub__(self, other):
def __sub__(self, other: 'Millisatoshi') -> 'Millisatoshi':
return Millisatoshi(int(self) - int(other))

def __mul__(self, other):
return Millisatoshi(int(int(self) * other))
def __mul__(self, other: int) -> 'Millisatoshi':
return Millisatoshi(self.millisatoshis * other)

def __truediv__(self, other):
return Millisatoshi(int(int(self) / other))
def __truediv__(self, other: Union[int, float]) -> 'Millisatoshi':
return Millisatoshi(int(self.millisatoshis / other))

def __floordiv__(self, other):
return Millisatoshi(int(self) // other)
def __floordiv__(self, other: Union[int, float]) -> 'Millisatoshi':
return Millisatoshi(int(self.millisatoshis // float(other)))

def __mod__(self, other):
return Millisatoshi(int(self) % other)
def __mod__(self, other: Union[float, int]) -> 'Millisatoshi':
return Millisatoshi(int(self.millisatoshis % other))

def __radd__(self, other):
def __radd__(self, other: 'Millisatoshi') -> 'Millisatoshi':
return Millisatoshi(int(self) + int(other))


Expand All @@ -188,17 +205,17 @@ class UnixSocket(object):
"""

def __init__(self, path):
def __init__(self, path: str):
self.path = path
self.sock = None
self.sock: Optional[socket.SocketType] = None
self.connect()

def connect(self):
def connect(self) -> None:
try:
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
return self.sock.connect(self.path)
self.sock.connect(self.path)
except OSError as e:
self.sock.close()
self.close()

if (e.args[0] == "AF_UNIX path too long" and os.uname()[0] == "Linux"):
# If this is a Linux system we may be able to work around this
Expand All @@ -216,29 +233,29 @@ def connect(self):
dirfd = os.open(dirname, os.O_DIRECTORY | os.O_RDONLY)
short_path = "/proc/self/fd/%d/%s" % (dirfd, basename)
self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
return self.sock.connect(short_path)
self.sock.connect(short_path)
else:
# There is no good way to recover from this.
raise

def close(self):
def close(self) -> None:
if self.sock is not None:
self.sock.close()
self.sock = None

def sendall(self, b):
def sendall(self, b: bytes) -> None:
if self.sock is None:
raise socket.error("not connected")

self.sock.sendall(b)

def recv(self, length):
def recv(self, length: int) -> bytes:
if self.sock is None:
raise socket.error("not connected")

return self.sock.recv(length)

def __del__(self):
def __del__(self) -> None:
self.close()


Expand Down

0 comments on commit 8ecb157

Please sign in to comment.