Skip to content

Commit

Permalink
tlv: use var_ints for size of messages
Browse files Browse the repository at this point in the history
TLV's use var_int's for messages sizes, both internally and
in the top level (you should really stack a var_int inside a var_int!!)

this updates our automagick generator code to understand 'var_ints'
  • Loading branch information
niftynei authored and rustyrussell committed Apr 3, 2019
1 parent 74ae9f0 commit bad0ac6
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions tools/generate-wire.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
'u32': 4,
'u16': 2,
'u8': 1,
'bool': 1
'bool': 1,
'var_int': 8,
}

# These struct array helpers require a context to allocate from.
Expand All @@ -46,8 +47,11 @@ class FieldType(object):
def __init__(self, name):
self.name = name

def is_var_int(self):
return self.name == 'var_int'

def is_assignable(self):
return self.name in ['u8', 'u16', 'u32', 'u64', 'bool', 'struct amount_msat', 'struct amount_sat'] or self.name.startswith('enum ')
return self.name in ['u8', 'u16', 'u32', 'u64', 'bool', 'struct amount_msat', 'struct amount_sat', 'var_int'] or self.name.startswith('enum ')

# We only accelerate the u8 case: it's common and trivial.
def has_array_helper(self):
Expand All @@ -59,6 +63,8 @@ def base(self):
basetype = basetype[7:]
elif basetype.startswith('enum '):
basetype = basetype[5:]
elif self.name == 'var_int':
return 'u64'
return basetype

# Returns base size
Expand Down Expand Up @@ -160,8 +166,12 @@ def __init__(self, message, name, size, comments, prevname, includes):

# Bolts use just a number: Guess type based on size.
if options.bolt:
base_size = int(size)
self.fieldtype = Field._guess_type(message, self.name, base_size)
if size == 'var_int':
base_size = 8
self.fieldtype = FieldType(size)
else:
base_size = int(size)
self.fieldtype = Field._guess_type(message, self.name, base_size)
# There are some arrays which we have to guess, based on sizes.
tsize = FieldType._typesize(self.fieldtype.name)
if base_size % tsize != 0:
Expand Down Expand Up @@ -352,8 +362,8 @@ def checkLenField(self, field):
return
for f in self.fields:
if f.name == field.lenvar:
if f.fieldtype.name != 'u16' and options.bolt:
raise ValueError('Field {} has non-u16 length variable {} (type {})'
if not (f.fieldtype.name == 'u16' or f.fieldtype.name == 'var_int') and options.bolt:
raise ValueError('Field {} has non-u16 and non-var_int length variable {} (type {})'
.format(field.name, field.lenvar, f.fieldtype.name))

if f.is_array() or f.needs_ptr_to_ptr():
Expand Down Expand Up @@ -399,7 +409,7 @@ def print_tlv_fromwire(self, tlv_name):
to populate, instead of fields, as well as a length to read in
"""
ctx_arg = 'const tal_t *ctx, ' if self.has_variable_fields else ''
args = 'const u8 **cursor, size_t *plen, const u16 len, struct tlv_msg_{name} *{name}'.format(name=self.name)
args = 'const u8 **cursor, size_t *plen, const u64 len, struct tlv_msg_{name} *{name}'.format(name=self.name)
fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var]
subcalls = CCode()
for f in self.fields:
Expand Down Expand Up @@ -482,11 +492,13 @@ def print_fromwire(self, is_header, tlv_name):
args.append(', {} {}{}'.format(f.fieldtype.name, ptrs, f.name))

template = fromwire_header_templ if is_header else fromwire_impl_templ
fields = ['\t{} {};\n'.format(f.fieldtype.name, f.name) for f in self.fields if f.is_len_var]
fields = ['\t{} {};\n'.format(f.fieldtype.base(), f.name) for f in self.fields if f.is_len_var]

subcalls = CCode()
for f in self.fields:
basetype = f.fieldtype.base()
if f.fieldtype.is_var_int():
basetype = 'var_int'

for c in f.comments:
subcalls.append('/*{} */'.format(c))
Expand Down Expand Up @@ -636,7 +648,7 @@ def print_towire(self, is_header, tlv_name):
if f.is_len_var:
if f.lenvar_for.is_tlv:
# used below...
field_decls.append('\t{0} {1};'.format(f.fieldtype.name, f.name))
field_decls.append('\t{0} {1};'.format(f.fieldtype.base(), f.name))
else:
field_decls.append('\t{0} {1} = tal_count({2});'.format(
f.fieldtype.name, f.name, f.lenvar_for.name
Expand Down Expand Up @@ -833,20 +845,21 @@ def print_struct(self):
\t\ttowire_u8(p, {enum});
\t\ttowire_{tlv_name}_{name}(&tlv_msg, {tlv_name}->{name});
\t\tmsg_len = tal_count(tlv_msg);
\t\ttowire_u8(p, msg_len);
\t\ttowire_var_int(p, msg_len);
\t\ttowire_u8_array(p, tlv_msg, msg_len);
\t\ttal_free(tlv_msg);
\t}}
"""

tlv__type_impl_towire_template = """static void towire__{tlv_name}(const tal_t *ctx, u8 **p, const struct {tlv_name} *{tlv_name}) {{
\tu8 msg_len;
\tu64 msg_len;
\tu8 *tlv_msg;
{fields}}}
"""

tlv__type_impl_fromwire_template = """static struct {tlv_name} *fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, size_t *plen, const u16 *len) {{
\tu8 msg_type, msg_len;
tlv__type_impl_fromwire_template = """static struct {tlv_name} *fromwire__{tlv_name}(const tal_t *ctx, const u8 **p, size_t *plen, const u64 *len) {{
\tu8 msg_type;
\tu64 msg_len;
\tsize_t start_len = *plen;
\tif (*plen < *len)
\t\treturn NULL;
Expand All @@ -855,7 +868,7 @@ def print_struct(self):
\twhile (*plen) {{
\t\tmsg_type = fromwire_u8(p, plen);
\t\tmsg_len = fromwire_u8(p, plen);
\t\tmsg_len = fromwire_var_int(p, plen);
\t\tif (*plen < msg_len) {{
\t\t\tfromwire_fail(p, plen);
\t\t\tbreak;
Expand Down

0 comments on commit bad0ac6

Please sign in to comment.