Skip to content

Commit

Permalink
pyln: Add code to unwrap an encrypted onion at the intended node
Browse files Browse the repository at this point in the history
Changelog-Added: pyln-proto: Added pure python implementation of the sphinx onion creation and processing functionality.
  • Loading branch information
cdecker authored and rustyrussell committed Sep 24, 2020
1 parent e8dcd59 commit 04462f6
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 10 deletions.
111 changes: 102 additions & 9 deletions contrib/pyln-proto/pyln/proto/onion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
from hashlib import sha256
from io import BytesIO, SEEK_CUR
from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple
import coincurve
import io
import os
import struct

Expand Down Expand Up @@ -44,7 +45,7 @@ def from_hex(cls, s):
s = s.encode('ASCII')
return cls.from_bytes(bytes(unhexlify(s)))

def to_bytes(self):
def to_bytes(self, include_prefix):
raise ValueError("OnionPayload is an abstract class, use "
"LegacyOnionPayload or TlvPayload instead")

Expand Down Expand Up @@ -92,20 +93,20 @@ def from_bytes(cls, b):
padding = b.read(12)
return LegacyOnionPayload(a, o, s, padding)

def to_bytes(self, include_realm=True):
def to_bytes(self, include_prefix=True):
b = b''
if include_realm:
if include_prefix:
b += b'\x00'

b += struct.pack("!Q", self.short_channel_id)
b += struct.pack("!Q", self.amt_to_forward)
b += struct.pack("!L", self.outgoing_cltv_value)
b += self.padding
assert(len(b) == 32 + include_realm)
assert(len(b) == 32 + include_prefix)
return b

def to_hex(self, include_realm=True):
return hexlify(self.to_bytes(include_realm)).decode('ASCII')
def to_hex(self, include_prefix=True):
return hexlify(self.to_bytes(include_prefix)).decode('ASCII')

def __str__(self):
return ("LegacyOnionPayload[scid={self.short_channel_id}, "
Expand Down Expand Up @@ -143,6 +144,12 @@ def from_bytes(cls, b, skip_length=False):
raise ValueError(
"Unable to read length at position {}".format(b.tell())
)

elif length > start + payload_length - b.tell():
b.seek(start + payload_length)
raise ValueError("Failed to parse TLV payload: value length "
"is longer than available bytes.")

val = b.read(length)

# Get the subclass that is the correct interpretation of this
Expand All @@ -167,10 +174,11 @@ def get(self, key, default=None):
return f
return default

def to_bytes(self):
def to_bytes(self, include_prefix=True) -> bytes:
ser = [f.to_bytes() for f in self.fields]
b = BytesIO()
varint_encode(sum([len(b) for b in ser]), b)
if include_prefix:
varint_encode(sum([len(b) for b in ser]), b)
for f in ser:
b.write(f)
return b.getvalue()
Expand All @@ -179,6 +187,40 @@ def __str__(self):
return "TlvPayload[" + ', '.join([str(f) for f in self.fields]) + "]"


class RawPayload(OnionPayload):
"""A payload that doesn't deserialize correctly as TLV stream.
Mainly used if TLV parsing fails, but we still want access to the raw
payload.
"""

def __init__(self):
self.content: Optional[bytes] = None

@classmethod
def from_bytes(cls, b):
if isinstance(b, str):
b = b.encode('ASCII')
if isinstance(b, bytes):
b = BytesIO(b)

self = cls()
payload_length = varint_decode(b)
self.content = b.read(payload_length)
return self

def to_bytes(self, include_prefix=True) -> bytes:
b = BytesIO()
if self.content is None:
raise ValueError("Cannot serialize empty TLV payload")

if include_prefix:
varint_encode(len(self.content), b)
b.write(self.content)
return b.getvalue()


class TlvField(object):

def __init__(self, typenum, value=None, description=None):
Expand Down Expand Up @@ -319,6 +361,57 @@ def to_bin(self) -> bytes:
def to_hex(self):
return hexlify(self.to_bin())

def unwrap(self, privkey: PrivateKey, assocdata: Optional[bytes]) \
-> Tuple[OnionPayload, Optional['RoutingOnion']]:
shared_secret = ecdh(privkey, self.ephemeralkey)
keys = generate_keyset(shared_secret)

h = hmac.HMAC(keys.mu, hashes.SHA256(),
backend=default_backend())
h.update(self.payloads)
if assocdata is not None:
h.update(assocdata)
hh = h.finalize()

if hh != self.hmac:
raise ValueError("HMAC does not match, onion might have been "
"tampered with: {hh} != {hmac}".format(
hh=hexlify(hh).decode('ascii'),
hmac=hexlify(self.hmac).decode('ascii'),
))

# Create the scratch twice as large as the original packet, since we
# need to left-shift a single payload off, which may itself be up to
# ROUTING_INFO_SIZE in length.
payloads = bytearray(2 * ROUTING_INFO_SIZE)
payloads[:ROUTING_INFO_SIZE] = self.payloads
chacha20_stream(keys.rho, payloads)

r = io.BytesIO(payloads)
start = r.tell()

try:
payload = OnionPayload.from_bytes(r)
except ValueError:
r.seek(start)
payload = RawPayload.from_bytes(r)

next_hmac = r.read(32)
shift_size = r.tell()

if next_hmac == bytes(32):
return payload, None
else:
b = blind(self.ephemeralkey, shared_secret)
ek = blind_group_element(self.ephemeralkey, b)
payloads = payloads[shift_size:shift_size + ROUTING_INFO_SIZE]
return payload, RoutingOnion(
version=self.version,
ephemeralkey=ek,
payloads=payloads,
hmac=next_hmac,
)


KeySet = namedtuple('KeySet', ['rho', 'mu', 'um', 'pad', 'gamma', 'pi'])

Expand Down
20 changes: 19 additions & 1 deletion contrib/pyln-proto/tests/test_onion.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_legacy_payload():
b'00000067000001000100000000000003e800000075000000000000000000000000'
)
payload = onion.OnionPayload.from_bytes(legacy)
assert(payload.to_bytes(include_realm=True) == legacy)
assert(payload.to_bytes(include_prefix=True) == legacy)


def test_tlv_payload():
Expand Down Expand Up @@ -325,3 +325,21 @@ def test_sphinx_path_compile():
o = sp.compile()

assert(o.to_bin() == unhexlify(v['onion']))


def test_unwrap():
f = 'tests/vectors/onion-test-multi-frame.json'
sp, v = sphinx_path_from_test_vector(f)
o = onion.RoutingOnion.from_hex(v['onion'])
assocdata = unhexlify(v['generate']['associated_data'])
privkeys = [onion.PrivateKey(unhexlify(h)) for h in v['decode']]

for pk, h in zip(privkeys, v['generate']['hops']):
pl, o = o.unwrap(pk, assocdata=assocdata)

b = hexlify(pl.to_bytes(include_prefix=False))
if h['type'] == 'legacy':
assert(b == h['payload'].encode('ascii') + b'00' * 12)
else:
assert(b == h['payload'].encode('ascii'))
assert(o is None)

0 comments on commit 04462f6

Please sign in to comment.