Skip to content

Commit

Permalink
common/tlvstream: put TLV checking back in the generic function.
Browse files Browse the repository at this point in the history
Callers were supposed to call "tlv_fields_valid" after fromwire_tlv,
but few did.  Make this the default, and call the underlying function
directly where we want to be more flexible (one place).

This loses the ability to allow misordered fields, or to pass through
*any* even fields.  We restore that for special cases in the next
patch.

Signed-off-by: Rusty Russell <[email protected]>
  • Loading branch information
rustyrussell committed Mar 25, 2022
1 parent a770f51 commit 83ee68a
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 155 deletions.
8 changes: 1 addition & 7 deletions common/blindedpath.c
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ static struct tlv_encrypted_data_tlv *decrypt_encmsg(const tal_t *ctx,
const struct secret *ss,
const u8 *enctlv)
{
struct tlv_encrypted_data_tlv *encmsg;
const u8 *cursor = decrypt_encmsg_raw(tmpctx, blinding, ss, enctlv);
size_t maxlen = tal_bytelen(cursor);

Expand All @@ -197,12 +196,7 @@ static struct tlv_encrypted_data_tlv *decrypt_encmsg(const tal_t *ctx,
* - if the `enctlv` is not a valid TLV...
* - MUST drop the message.
*/
encmsg = fromwire_tlv_encrypted_data_tlv(ctx, &cursor, &maxlen);
if (!encmsg
|| !tlv_fields_valid(encmsg->fields, NULL, NULL))
return tal_free(encmsg);

return encmsg;
return fromwire_tlv_encrypted_data_tlv(ctx, &cursor, &maxlen);
}

bool decrypt_enctlv(const struct pubkey *blinding,
Expand Down
33 changes: 14 additions & 19 deletions common/onion.c
Original file line number Diff line number Diff line change
Expand Up @@ -209,22 +209,21 @@ struct onion_payload *onion_decode(const tal_t *ctx,
const u8 *cursor = rs->raw_payload;
size_t max = tal_bytelen(cursor), len;
struct tlv_tlv_payload *tlv;
size_t badfield;

if (!pull_payload_length(&cursor, &max, true, &len))
goto general_fail;

tlv = fromwire_tlv_tlv_payload(p, &cursor, &max);
if (!tlv) {
/* FIXME: Fill in correct thing here! */
goto general_fail;
if (!pull_payload_length(&cursor, &max, true, &len)) {
*failtlvtype = 0;
*failtlvpos = tal_bytelen(rs->raw_payload);
goto fail_no_tlv;
}

/* FIXME: This API makes it really hard to get the actual
* offset of field. */
if (!tlv_fields_valid(tlv->fields, accepted_extra_tlvs, &badfield)) {
*failtlvtype = tlv->fields[badfield].numtype;
goto field_bad;
/* We do this manually so we can accept extra types, and get
* error off and type. */
tlv = tlv_tlv_payload_new(p);
if (!fromwire_tlv(&cursor, &max, tlvs_tlv_tlv_payload,
TLVS_ARRAY_SIZE_tlv_tlv_payload,
tlv, &tlv->fields, accepted_extra_tlvs,
failtlvpos, failtlvtype)) {
goto fail;
}

/* BOLT #4:
Expand Down Expand Up @@ -336,14 +335,10 @@ struct onion_payload *onion_decode(const tal_t *ctx,
field_bad:
*failtlvpos = tlv_field_offset(rs->raw_payload, tal_bytelen(rs->raw_payload),
*failtlvtype);
goto fail;

general_fail:
*failtlvtype = 0;
*failtlvpos = tal_bytelen(rs->raw_payload);
goto fail;
fail:
tal_free(tlv);

fail_no_tlv:
tal_free(p);
return NULL;
}
7 changes: 2 additions & 5 deletions common/test/run-json.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@
/* Generated stub for fromwire_tlv */
bool fromwire_tlv(const u8 **cursor UNNEEDED, size_t *max UNNEEDED,
const struct tlv_record_type *types UNNEEDED, size_t num_types UNNEEDED,
void *record UNNEEDED, struct tlv_field **fields UNNEEDED)
void *record UNNEEDED, struct tlv_field **fields UNNEEDED,
const u64 *extra_types UNNEEDED, size_t *err_off UNNEEDED, u64 *err_type UNNEEDED)
{ fprintf(stderr, "fromwire_tlv called!\n"); abort(); }
/* Generated stub for tlv_fields_valid */
bool tlv_fields_valid(const struct tlv_field *fields UNNEEDED, u64 *allow_extra UNNEEDED,
size_t *err_index UNNEEDED)
{ fprintf(stderr, "tlv_fields_valid called!\n"); abort(); }
/* Generated stub for towire_tlv */
void towire_tlv(u8 **pptr UNNEEDED,
const struct tlv_record_type *types UNNEEDED, size_t num_types UNNEEDED,
Expand Down
7 changes: 2 additions & 5 deletions common/test/run-param.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ struct command_result *command_fail(struct command *cmd,
/* Generated stub for fromwire_tlv */
bool fromwire_tlv(const u8 **cursor UNNEEDED, size_t *max UNNEEDED,
const struct tlv_record_type *types UNNEEDED, size_t num_types UNNEEDED,
void *record UNNEEDED, struct tlv_field **fields UNNEEDED)
void *record UNNEEDED, struct tlv_field **fields UNNEEDED,
const u64 *extra_types UNNEEDED, size_t *err_off UNNEEDED, u64 *err_type UNNEEDED)
{ fprintf(stderr, "fromwire_tlv called!\n"); abort(); }
/* Generated stub for json_to_channel_id */
bool json_to_channel_id(const char *buffer UNNEEDED, const jsmntok_t *tok UNNEEDED,
Expand Down Expand Up @@ -76,10 +77,6 @@ int segwit_addr_decode(
const char* addr
)
{ fprintf(stderr, "segwit_addr_decode called!\n"); abort(); }
/* Generated stub for tlv_fields_valid */
bool tlv_fields_valid(const struct tlv_field *fields UNNEEDED, u64 *allow_extra UNNEEDED,
size_t *err_index UNNEEDED)
{ fprintf(stderr, "tlv_fields_valid called!\n"); abort(); }
/* Generated stub for towire_tlv */
void towire_tlv(u8 **pptr UNNEEDED,
const struct tlv_record_type *types UNNEEDED, size_t num_types UNNEEDED,
Expand Down
7 changes: 2 additions & 5 deletions common/test/run-route-specific.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,12 @@ bool fromwire_channel_id(const u8 **cursor UNNEEDED, size_t *max UNNEEDED,
/* Generated stub for fromwire_tlv */
bool fromwire_tlv(const u8 **cursor UNNEEDED, size_t *max UNNEEDED,
const struct tlv_record_type *types UNNEEDED, size_t num_types UNNEEDED,
void *record UNNEEDED, struct tlv_field **fields UNNEEDED)
void *record UNNEEDED, struct tlv_field **fields UNNEEDED,
const u64 *extra_types UNNEEDED, size_t *err_off UNNEEDED, u64 *err_type UNNEEDED)
{ fprintf(stderr, "fromwire_tlv called!\n"); abort(); }
/* Generated stub for fromwire_wireaddr */
bool fromwire_wireaddr(const u8 **cursor UNNEEDED, size_t *max UNNEEDED, struct wireaddr *addr UNNEEDED)
{ fprintf(stderr, "fromwire_wireaddr called!\n"); abort(); }
/* Generated stub for tlv_fields_valid */
bool tlv_fields_valid(const struct tlv_field *fields UNNEEDED, u64 *allow_extra UNNEEDED,
size_t *err_index UNNEEDED)
{ fprintf(stderr, "tlv_fields_valid called!\n"); abort(); }
/* Generated stub for towire_bigsize */
void towire_bigsize(u8 **pptr UNNEEDED, const bigsize_t val UNNEEDED)
{ fprintf(stderr, "towire_bigsize called!\n"); abort(); }
Expand Down
7 changes: 2 additions & 5 deletions common/test/run-route.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,12 @@ bool fromwire_channel_id(const u8 **cursor UNNEEDED, size_t *max UNNEEDED,
/* Generated stub for fromwire_tlv */
bool fromwire_tlv(const u8 **cursor UNNEEDED, size_t *max UNNEEDED,
const struct tlv_record_type *types UNNEEDED, size_t num_types UNNEEDED,
void *record UNNEEDED, struct tlv_field **fields UNNEEDED)
void *record UNNEEDED, struct tlv_field **fields UNNEEDED,
const u64 *extra_types UNNEEDED, size_t *err_off UNNEEDED, u64 *err_type UNNEEDED)
{ fprintf(stderr, "fromwire_tlv called!\n"); abort(); }
/* Generated stub for fromwire_wireaddr */
bool fromwire_wireaddr(const u8 **cursor UNNEEDED, size_t *max UNNEEDED, struct wireaddr *addr UNNEEDED)
{ fprintf(stderr, "fromwire_wireaddr called!\n"); abort(); }
/* Generated stub for tlv_fields_valid */
bool tlv_fields_valid(const struct tlv_field *fields UNNEEDED, u64 *allow_extra UNNEEDED,
size_t *err_index UNNEEDED)
{ fprintf(stderr, "tlv_fields_valid called!\n"); abort(); }
/* Generated stub for towire_bigsize */
void towire_bigsize(u8 **pptr UNNEEDED, const bigsize_t val UNNEEDED)
{ fprintf(stderr, "towire_bigsize called!\n"); abort(); }
Expand Down
7 changes: 2 additions & 5 deletions common/test/run-sphinx.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,15 @@ bigsize_t fromwire_bigsize(const u8 **cursor UNNEEDED, size_t *max UNNEEDED)
/* Generated stub for fromwire_tlv */
bool fromwire_tlv(const u8 **cursor UNNEEDED, size_t *max UNNEEDED,
const struct tlv_record_type *types UNNEEDED, size_t num_types UNNEEDED,
void *record UNNEEDED, struct tlv_field **fields UNNEEDED)
void *record UNNEEDED, struct tlv_field **fields UNNEEDED,
const u64 *extra_types UNNEEDED, size_t *err_off UNNEEDED, u64 *err_type UNNEEDED)
{ fprintf(stderr, "fromwire_tlv called!\n"); abort(); }
/* Generated stub for pubkey_from_node_id */
bool pubkey_from_node_id(struct pubkey *key UNNEEDED, const struct node_id *id UNNEEDED)
{ fprintf(stderr, "pubkey_from_node_id called!\n"); abort(); }
/* Generated stub for tlv_field_offset */
size_t tlv_field_offset(const u8 *tlvstream UNNEEDED, size_t tlvlen UNNEEDED, u64 fieldtype UNNEEDED)
{ fprintf(stderr, "tlv_field_offset called!\n"); abort(); }
/* Generated stub for tlv_fields_valid */
bool tlv_fields_valid(const struct tlv_field *fields UNNEEDED, u64 *allow_extra UNNEEDED,
size_t *err_index UNNEEDED)
{ fprintf(stderr, "tlv_fields_valid called!\n"); abort(); }
/* Generated stub for towire_amount_msat */
void towire_amount_msat(u8 **pptr UNNEEDED, const struct amount_msat msat UNNEEDED)
{ fprintf(stderr, "towire_amount_msat called!\n"); abort(); }
Expand Down
14 changes: 11 additions & 3 deletions plugins/keysend.c
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,9 @@ static struct command_result *htlc_accepted_call(struct command *cmd,
struct out_req *req;
struct timeabs now = time_now();
const char *err;
u64 *allowed = tal_arr(cmd, u64, 1);
size_t err_off;
u64 err_type;

err = json_scan(tmpctx, buf, params,
"{onion:{payload:%},htlc:{payment_hash:%}}",
Expand All @@ -356,10 +359,15 @@ static struct command_result *htlc_accepted_call(struct command *cmd,
if (s != max) {
return htlc_accepted_continue(cmd, NULL);
}
payload = fromwire_tlv_tlv_payload(cmd, &rawpayload, &max);
if (!payload) {

/* We explicitly allow our type. */
allowed[0] = 5482373484;
payload = tlv_tlv_payload_new(cmd);
if (!fromwire_tlv(&rawpayload, &max, tlvs_tlv_tlv_payload, TLVS_ARRAY_SIZE_tlv_tlv_payload,
payload, &payload->fields, allowed, &err_off, &err_type)) {
plugin_log(
cmd->plugin, LOG_UNUSUAL, "Malformed TLV payload %.*s",
cmd->plugin, LOG_UNUSUAL, "Malformed TLV payload type %"PRIu64" at off %zu %.*s",
err_type, err_off,
json_tok_full_len(params),
json_tok_full(buf, params));
return htlc_accepted_continue(cmd, NULL);
Expand Down
8 changes: 1 addition & 7 deletions tools/gen/impl_template
Original file line number Diff line number Diff line change
Expand Up @@ -269,16 +269,10 @@ void towire_${tlv.name}(u8 **pptr, const struct ${tlv.struct_name()} *record)
struct ${tlv.name} *fromwire_${tlv.name}(const tal_t *ctx, const u8 **cursor, size_t *max)
{
struct ${tlv.name} *record = ${tlv.name}_new(ctx);
if (!fromwire_tlv(cursor, max, tlvs_${tlv.name}, ${len(tlv.messages)}, record, &record->fields))
if (!fromwire_tlv(cursor, max, tlvs_${tlv.name}, ${len(tlv.messages)}, record, &record->fields, NULL, NULL, NULL))
return tal_free(record);
return record;
}

bool ${tlv.name}_is_valid(const struct ${tlv.struct_name()} *record, size_t *err_index)
{
return tlv_fields_valid(record->fields, NULL, err_index);
}

% endfor ## END TLV's
% for msg in messages: ## START Wire Messages

Expand Down
35 changes: 15 additions & 20 deletions wire/test/run-tlvstream.c
Original file line number Diff line number Diff line change
Expand Up @@ -468,12 +468,12 @@ int main(int argc, char *argv[])
max = tal_count(orig_p);
p = orig_p;
tlv_n1 = fromwire_tlv_n1(tmpctx, &p, &max);
assert((!tlv_n1 && !p) || !tlv_n1_is_valid(tlv_n1, NULL));
assert(!tlv_n1 && !p);
assert(strstr(invalid_streams_either[i].reason, reason));
max = tal_count(orig_p);
p = orig_p;
tlv_n2 = fromwire_tlv_n2(tmpctx, &p, &max);
assert((!tlv_n2 && !p) || !tlv_n2_is_valid(tlv_n2, NULL));
assert(!tlv_n2 && !p);
assert(strstr(invalid_streams_either[i].reason, reason));
}

Expand All @@ -485,7 +485,7 @@ int main(int argc, char *argv[])
p = stream(tmpctx, invalid_streams_n1[i].hex);
max = tal_count(p);
tlv_n1 = fromwire_tlv_n1(tmpctx, &p, &max);
assert((!tlv_n1 && !p) || !tlv_n1_is_valid(tlv_n1, NULL));
assert(!tlv_n1 && !p);
assert(strstr(invalid_streams_n1[i].reason, reason));
}

Expand All @@ -497,7 +497,7 @@ int main(int argc, char *argv[])
p = stream(tmpctx, invalid_streams_n1_combo[i].hex);
max = tal_count(p);
tlv_n1 = fromwire_tlv_n1(tmpctx, &p, &max);
assert((!tlv_n1 && !p) || !tlv_n1_is_valid(tlv_n1, NULL));
assert(!tlv_n1 && !p);
assert(strstr(invalid_streams_n1_combo[i].reason, reason));
}

Expand All @@ -509,8 +509,7 @@ int main(int argc, char *argv[])
p = stream(tmpctx, invalid_streams_n2_combo[i].hex);
max = tal_count(p);
tlv_n2 = fromwire_tlv_n2(tmpctx, &p, &max);
assert((!tlv_n2 && !p) ||
!tlv_n2_is_valid(tlv_n2, NULL));
assert(!tlv_n2 && !p);
assert(strstr(invalid_streams_n2_combo[i].reason, reason));
}

Expand All @@ -525,8 +524,7 @@ int main(int argc, char *argv[])
max = tal_count(orig_p);
p = orig_p;
tlv_n1 = fromwire_tlv_n1(tmpctx, &p, &max);
assert(tlv_n1 &&
tlv_n1_is_valid(tlv_n1, NULL));
assert(tlv_n1);
assert(max == 0);
assert(tlv_n1_eq(tlv_n1, &valid_streams[i].expect));

Expand Down Expand Up @@ -558,13 +556,11 @@ int main(int argc, char *argv[])
max = tal_count(orig_p);
p = orig_p;
tlv_n1 = fromwire_tlv_n1(tmpctx, &p, &max);
assert((!tlv_n1 && !p) ||
!tlv_n1_is_valid(tlv_n1, NULL));
assert(!tlv_n1 && !p);
max = tal_count(orig_p);
p = orig_p;
tlv_n2 = fromwire_tlv_n2(tmpctx, &p, &max);
assert((!tlv_n2 && !p) ||
!tlv_n2_is_valid(tlv_n2, NULL));
assert(!tlv_n2 && !p);
}
}

Expand All @@ -578,8 +574,7 @@ int main(int argc, char *argv[])
invalid_streams_n1[i].hex);
max = tal_count(p);
tlv_n1 = fromwire_tlv_n1(tmpctx, &p, &max);
assert((!tlv_n1 && !p) ||
!tlv_n1_is_valid(tlv_n1, NULL));
assert(!tlv_n1 && !p);
}
}

Expand All @@ -593,8 +588,7 @@ int main(int argc, char *argv[])
invalid_streams_n1_combo[i].hex);
max = tal_count(p);
tlv_n1 = fromwire_tlv_n1(tmpctx, &p, &max);
assert((!tlv_n1 && !p) ||
!tlv_n1_is_valid(tlv_n1, NULL));
assert(!tlv_n1 && !p);
}
}

Expand Down Expand Up @@ -624,11 +618,12 @@ int main(int argc, char *argv[])
< pull_type(valid_streams[j].hex);

tlv_n1 = fromwire_tlv_n1(tmpctx, &p, &max);
assert(tlv_n1 &&
tlv_n1_is_valid(tlv_n1, NULL) == expect_success);

if (!expect_success)
if (!expect_success) {
assert(!tlv_n1);
continue;
}

assert(tlv_n1);

/* Re-encoding should give the same results (except
* ignored fields tests!) */
Expand Down
Loading

0 comments on commit 83ee68a

Please sign in to comment.