Skip to content

Commit

Permalink
pylightning: provide a class for Lightning JSONDecoder.
Browse files Browse the repository at this point in the history
Some JSON functions want a *class*, not just a hook, so provide one.
To make it clear that we want an encoding *class* and a decoding *object*,
rename the UnixDomainSocketRpc encode parameter to encode_cls.

Signed-off-by: Rusty Russell <[email protected]>
  • Loading branch information
rustyrussell committed Feb 25, 2019
1 parent 4648588 commit 5a7d038
Showing 1 changed file with 34 additions and 27 deletions.
61 changes: 34 additions & 27 deletions contrib/pylightning/lightning/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def __mod__(self, other):


class UnixDomainSocketRpc(object):
def __init__(self, socket_path, executor=None, logger=logging, encoder=json.JSONEncoder, decoder=json.JSONDecoder):
def __init__(self, socket_path, executor=None, logger=logging, encoder_cls=json.JSONEncoder, decoder=json.JSONDecoder()):
self.socket_path = socket_path
self.encoder = encoder
self.encoder_cls = encoder_cls
self.decoder = decoder
self.executor = executor
self.logger = logger
Expand All @@ -133,7 +133,7 @@ def __init__(self, socket_path, executor=None, logger=logging, encoder=json.JSON
self.next_id = 0

def _writeobj(self, sock, obj):
s = json.dumps(obj, cls=self.encoder)
s = json.dumps(obj, cls=self.encoder_cls)
sock.sendall(bytearray(s, 'UTF-8'))

def _readobj_compat(self, sock, buff=b''):
Expand Down Expand Up @@ -245,32 +245,39 @@ def default(self, o):
pass
return json.JSONEncoder.default(self, o)

@staticmethod
def lightning_json_hook(json_object):
return json_object

@staticmethod
def replace_amounts(obj):
"""
Recursively replace _msat fields with appropriate values with Millisatoshi.
"""
if isinstance(obj, dict):
for k, v in obj.items():
if k.endswith('msat'):
if isinstance(v, str) and v.endswith('msat'):
obj[k] = Millisatoshi(v)
# Special case for array of msat values
elif isinstance(v, list) and all(isinstance(e, str) and e.endswith('msat') for e in v):
obj[k] = [Millisatoshi(e) for e in v]
else:
obj[k] = LightningRpc.replace_amounts(v)
elif isinstance(obj, list):
obj = [LightningRpc.replace_amounts(e) for e in obj]

return obj
class LightningJSONDecoder(json.JSONDecoder):
def __init__(self, *, object_hook=None, parse_float=None, parse_int=None, parse_constant=None, strict=True, object_pairs_hook=None):
self.object_hook_next = object_hook
super().__init__(object_hook=self.millisatoshi_hook, parse_float=parse_float, parse_int=parse_int, parse_constant=parse_constant, strict=strict, object_pairs_hook=object_pairs_hook)

@staticmethod
def replace_amounts(obj):
"""
Recursively replace _msat fields with appropriate values with Millisatoshi.
"""
if isinstance(obj, dict):
for k, v in obj.items():
if k.endswith('msat'):
if isinstance(v, str) and v.endswith('msat'):
obj[k] = Millisatoshi(v)
# Special case for array of msat values
elif isinstance(v, list) and all(isinstance(e, str) and e.endswith('msat') for e in v):
obj[k] = [Millisatoshi(e) for e in v]
else:
obj[k] = LightningRpc.LightningJSONDecoder.replace_amounts(v)
elif isinstance(obj, list):
obj = [LightningRpc.LightningJSONDecoder.replace_amounts(e) for e in obj]

return obj

def millisatoshi_hook(self, obj):
obj = LightningRpc.LightningJSONDecoder.replace_amounts(obj)
if self.object_hook_next:
obj = self.object_hook_next(obj)
return obj

def __init__(self, socket_path, executor=None, logger=logging):
super().__init__(socket_path, executor, logging, self.LightningJSONEncoder, json.JSONDecoder(object_hook=self.replace_amounts))
super().__init__(socket_path, executor, logging, self.LightningJSONEncoder, self.LightningJSONDecoder())

def getpeer(self, peer_id, level=None):
"""
Expand Down

0 comments on commit 5a7d038

Please sign in to comment.