Skip to content

Commit

Permalink
utils: make tal_arr_expand safer.
Browse files Browse the repository at this point in the history
Christian and I both unwittingly used it in form:

	*tal_arr_expand(&x) = tal(x, ...)

Since '=' isn't a sequence point, the compiler can (and does!) cache
the value of x, handing it to tal *after* tal_arr_expand() moves it
due to tal_resize().

The new version is somewhat less convenient to use, but doesn't have
this problem, since the assignment is always evaluated after the
resize.

Signed-off-by: Rusty Russell <[email protected]>
  • Loading branch information
rustyrussell authored and cdecker committed Jan 15, 2019
1 parent a5ed98a commit 26dda57
Show file tree
Hide file tree
Showing 26 changed files with 173 additions and 154 deletions.
47 changes: 26 additions & 21 deletions channeld/channeld.c
Original file line number Diff line number Diff line change
Expand Up @@ -1224,28 +1224,31 @@ static u8 *got_commitsig_msg(const tal_t *ctx,
for (size_t i = 0; i < tal_count(changed_htlcs); i++) {
const struct htlc *htlc = changed_htlcs[i];
if (htlc->state == RCVD_ADD_COMMIT) {
struct added_htlc *a = tal_arr_expand(&added);
struct secret *s = tal_arr_expand(&shared_secret);
a->id = htlc->id;
a->amount_msat = htlc->msatoshi;
a->payment_hash = htlc->rhash;
a->cltv_expiry = abs_locktime_to_blocks(&htlc->expiry);
memcpy(a->onion_routing_packet,
struct added_htlc a;
struct secret s;

a.id = htlc->id;
a.amount_msat = htlc->msatoshi;
a.payment_hash = htlc->rhash;
a.cltv_expiry = abs_locktime_to_blocks(&htlc->expiry);
memcpy(a.onion_routing_packet,
htlc->routing,
sizeof(a->onion_routing_packet));
sizeof(a.onion_routing_packet));
/* Invalid shared secret gets set to all-zero: our
* code generator can't make arrays of optional values */
if (!htlc->shared_secret)
memset(s, 0, sizeof(*s));
memset(&s, 0, sizeof(s));
else
*s = *htlc->shared_secret;
s = *htlc->shared_secret;
tal_arr_expand(&added, a);
tal_arr_expand(&shared_secret, s);
} else if (htlc->state == RCVD_REMOVE_COMMIT) {
if (htlc->r) {
struct fulfilled_htlc *f;
struct fulfilled_htlc f;
assert(!htlc->fail && !htlc->failcode);
f = tal_arr_expand(&fulfilled);
f->id = htlc->id;
f->payment_preimage = *htlc->r;
f.id = htlc->id;
f.payment_preimage = *htlc->r;
tal_arr_expand(&fulfilled, f);
} else {
struct failed_htlc *f;
assert(htlc->fail || htlc->failcode);
Expand All @@ -1255,15 +1258,16 @@ static u8 *got_commitsig_msg(const tal_t *ctx,
f->failreason = cast_const(u8 *, htlc->fail);
f->scid = cast_const(struct short_channel_id *,
htlc->failed_scid);
*tal_arr_expand(&failed) = f;
tal_arr_expand(&failed, f);
}
} else {
struct changed_htlc *c = tal_arr_expand(&changed);
struct changed_htlc c;
assert(htlc->state == RCVD_REMOVE_ACK_COMMIT
|| htlc->state == RCVD_ADD_ACK_COMMIT);

c->id = htlc->id;
c->newstate = htlc->state;
c.id = htlc->id;
c.newstate = htlc->state;
tal_arr_expand(&changed, c);
}
}

Expand Down Expand Up @@ -1418,15 +1422,16 @@ static u8 *got_revoke_msg(const tal_t *ctx, u64 revoke_num,
struct changed_htlc *changed = tal_arr(tmpctx, struct changed_htlc, 0);

for (size_t i = 0; i < tal_count(changed_htlcs); i++) {
struct changed_htlc *c = tal_arr_expand(&changed);
struct changed_htlc c;
const struct htlc *htlc = changed_htlcs[i];

status_trace("HTLC %"PRIu64"[%s] => %s",
htlc->id, side_to_str(htlc_owner(htlc)),
htlc_state_name(htlc->state));

c->id = changed_htlcs[i]->id;
c->newstate = changed_htlcs[i]->state;
c.id = changed_htlcs[i]->id;
c.newstate = changed_htlcs[i]->state;
tal_arr_expand(&changed, c);
}

msg = towire_channel_got_revoke(ctx, revoke_num, per_commitment_secret,
Expand Down
6 changes: 3 additions & 3 deletions channeld/full_channel.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ static void htlc_arr_append(const struct htlc ***arr, const struct htlc *htlc)
{
if (!arr)
return;
*tal_arr_expand(arr) = htlc;
tal_arr_expand(arr, htlc);
}

/* What does adding the HTLC do to the balance for this side */
Expand Down Expand Up @@ -227,8 +227,8 @@ static void add_htlcs(struct bitcoin_tx ***txs,
/* Append to array. */
assert(tal_count(*txs) == tal_count(*wscripts));

*tal_arr_expand(wscripts) = wscript;
*tal_arr_expand(txs) = tx;
tal_arr_expand(wscripts, wscript);
tal_arr_expand(txs, tx);
}
}

Expand Down
6 changes: 4 additions & 2 deletions common/bolt11.c
Original file line number Diff line number Diff line change
Expand Up @@ -424,13 +424,15 @@ static char *decode_r(struct bolt11 *b11,
pull_bits_certain(hu5, data, data_len, r8, data_length * 5, false);

do {
if (!fromwire_route_info(&cursor, &rlen, tal_arr_expand(&r))) {
struct route_info ri;
if (!fromwire_route_info(&cursor, &rlen, &ri)) {
return tal_fmt(b11, "r: hop %zu truncated", n);
}
tal_arr_expand(&r, ri);
} while (rlen);

/* Append route */
*tal_arr_expand(&b11->routes) = tal_steal(b11, r);
tal_arr_expand(&b11->routes, tal_steal(b11, r));
return NULL;
}

Expand Down
5 changes: 3 additions & 2 deletions common/decode_short_channel_ids.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,9 @@ struct short_channel_id *decode_short_ids(const tal_t *ctx, const u8 *encoded)
case SHORTIDS_UNCOMPRESSED:
scids = tal_arr(ctx, struct short_channel_id, 0);
while (max) {
fromwire_short_channel_id(&encoded, &max,
tal_arr_expand(&scids));
struct short_channel_id scid;
fromwire_short_channel_id(&encoded, &max, &scid);
tal_arr_expand(&scids, scid);
}

/* encoded is set to NULL if we ran over */
Expand Down
4 changes: 2 additions & 2 deletions common/memleak.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ void *notleak_(const void *ptr, bool plus_children)
if (!notleaks)
return cast_const(void *, ptr);

*tal_arr_expand(&notleaks) = ptr;
*tal_arr_expand(&notleak_children) = plus_children;
tal_arr_expand(&notleaks, ptr);
tal_arr_expand(&notleak_children, plus_children);

tal_add_notifier(ptr, TAL_NOTIFY_FREE|TAL_NOTIFY_MOVE, notleak_change);
return cast_const(void *, ptr);
Expand Down
2 changes: 1 addition & 1 deletion common/msg_queue.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct msg_queue *msg_queue_new(const tal_t *ctx)

static void do_enqueue(struct msg_queue *q, const u8 *add TAKES)
{
*tal_arr_expand(&q->q) = tal_dup_arr(q, u8, add, tal_count(add), 0);
tal_arr_expand(&q->q, tal_dup_arr(q, u8, add, tal_count(add), 0));

/* In case someone is waiting */
io_wake(q);
Expand Down
14 changes: 8 additions & 6 deletions common/param.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ static bool param_add(struct param **params,
if (!(name && cbx && arg))
return false;
#endif
struct param *last = tal_arr_expand(params);
struct param last;

last->is_set = false;
last->name = name;
last->required = required;
last->cbx = cbx;
last->arg = arg;
last.is_set = false;
last.name = name;
last.required = required;
last.cbx = cbx;
last.arg = arg;

tal_arr_expand(params, last);
return true;
}

Expand Down
16 changes: 7 additions & 9 deletions common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@ char *tal_hex(const tal_t *ctx, const tal_t *data);
/* Allocate and fill a buffer with the data of this hex string. */
u8 *tal_hexdata(const tal_t *ctx, const void *str, size_t len);

/* Helper macro to extend tal_arr and return pointer new last element. */
#if HAVE_STATEMENT_EXPR
/* More efficient version calls tal_count() once */
#define tal_arr_expand(p) \
({ size_t n = tal_count(*p); tal_resize((p), n+1); *(p) + n; })
#else
#define tal_arr_expand(p) \
(tal_resize((p), tal_count(*(p))+1), (*p) + tal_count(*(p))-1)
#endif
/* Note: p is never a complex expression, otherwise this multi-evaluates! */
#define tal_arr_expand(p, s) \
do { \
size_t n = tal_count(*(p)); \
tal_resize((p), n+1); \
(*(p))[n] = (s); \
} while(0)

/**
* Remove an element from an array
Expand Down
19 changes: 10 additions & 9 deletions connectd/connectd.c
Original file line number Diff line number Diff line change
Expand Up @@ -767,9 +767,10 @@ static void add_listen_fd(struct daemon *daemon, int fd, bool mayfail)
/*~ utils.h contains a convenience macro tal_arr_expand which
* reallocates a tal_arr to make it one longer, then returns a pointer
* to the (new) last element. */
struct listen_fd *l = tal_arr_expand(&daemon->listen_fds);
l->fd = fd;
l->mayfail = mayfail;
struct listen_fd l;
l.fd = fd;
l.mayfail = mayfail;
tal_arr_expand(&daemon->listen_fds, l);
}

/*~ Helper routine to create and bind a socket of a given type; like many
Expand Down Expand Up @@ -876,13 +877,13 @@ static bool public_address(struct daemon *daemon, struct wireaddr *wireaddr)
static void add_announcable(struct wireaddr **announcable,
const struct wireaddr *addr)
{
*tal_arr_expand(announcable) = *addr;
tal_arr_expand(announcable, *addr);
}

static void add_binding(struct wireaddr_internal **binding,
const struct wireaddr_internal *addr)
{
*tal_arr_expand(binding) = *addr;
tal_arr_expand(binding, *addr);
}

/*~ ccan/asort provides a type-safe sorting function; it requires a comparison
Expand Down Expand Up @@ -1223,7 +1224,7 @@ static void add_seed_addrs(struct wireaddr_internal **addrs,
status_trace("Resolved %s to %s", addr,
type_to_string(tmpctx, struct wireaddr,
&a.u.wireaddr));
*tal_arr_expand(addrs) = a;
tal_arr_expand(addrs, a);
}
}

Expand Down Expand Up @@ -1254,7 +1255,7 @@ static void add_gossip_addrs(struct wireaddr_internal **addrs,
struct wireaddr_internal addr;
addr.itype = ADDR_INTERNAL_WIREADDR;
addr.u.wireaddr = normal_addrs[i];
*tal_arr_expand(addrs) = addr;
tal_arr_expand(addrs, addr);
}
}

Expand Down Expand Up @@ -1284,7 +1285,7 @@ static void try_connect_peer(struct daemon *daemon,

/* They can supply an optional address for the connect RPC */
if (addrhint)
*tal_arr_expand(&addrs) = *addrhint;
tal_arr_expand(&addrs, *addrhint);

add_gossip_addrs(&addrs, id);

Expand All @@ -1297,7 +1298,7 @@ static void try_connect_peer(struct daemon *daemon,
wireaddr_from_unresolved(&unresolved,
seedname(tmpctx, id),
DEFAULT_PORT);
*tal_arr_expand(&addrs) = unresolved;
tal_arr_expand(&addrs, unresolved);
} else if (daemon->use_dns) {
add_seed_addrs(&addrs, id,
daemon->broken_resolver_response);
Expand Down
68 changes: 34 additions & 34 deletions gossipd/gossipd.c
Original file line number Diff line number Diff line change
Expand Up @@ -1022,8 +1022,8 @@ static void maybe_create_next_scid_reply(struct peer *peer)
queue_peer_msg(peer, chan->half[1].channel_update);

/* Record node ids for later transmission of node_announcement */
*tal_arr_expand(&peer->scid_query_nodes) = chan->nodes[0]->id;
*tal_arr_expand(&peer->scid_query_nodes) = chan->nodes[1]->id;
tal_arr_expand(&peer->scid_query_nodes, chan->nodes[0]->id);
tal_arr_expand(&peer->scid_query_nodes, chan->nodes[1]->id);
sent = true;
}

Expand Down Expand Up @@ -1919,14 +1919,12 @@ static void append_half_channel(struct gossip_getchannels_entry **entries,
int idx)
{
const struct half_chan *c = &chan->half[idx];
struct gossip_getchannels_entry *e;
struct gossip_getchannels_entry e;

/* If we've never seen a channel_update for this direction... */
if (!is_halfchan_defined(c))
return;

e = tal_arr_expand(entries);

/* Our 'struct chan' contains two nodes: they are in pubkey_cmp order
* (ie. chan->nodes[0] is the lesser pubkey) and this is the same as
* the direction bit in `channel_update`s `channel_flags`.
Expand All @@ -1936,18 +1934,20 @@ static void append_half_channel(struct gossip_getchannels_entry **entries,
* pubkeys to DER and back: that proves quite expensive, and we assume
* we're on the same architecture as lightningd, so we just send them
* raw in this case. */
raw_pubkey(e->source, &chan->nodes[idx]->id);
raw_pubkey(e->destination, &chan->nodes[!idx]->id);
e->satoshis = chan->satoshis;
e->channel_flags = c->channel_flags;
e->message_flags = c->message_flags;
e->local_disabled = chan->local_disabled;
e->public = is_chan_public(chan);
e->short_channel_id = chan->scid;
e->last_update_timestamp = c->last_timestamp;
e->base_fee_msat = c->base_fee;
e->fee_per_millionth = c->proportional_fee;
e->delay = c->delay;
raw_pubkey(e.source, &chan->nodes[idx]->id);
raw_pubkey(e.destination, &chan->nodes[!idx]->id);
e.satoshis = chan->satoshis;
e.channel_flags = c->channel_flags;
e.message_flags = c->message_flags;
e.local_disabled = chan->local_disabled;
e.public = is_chan_public(chan);
e.short_channel_id = chan->scid;
e.last_update_timestamp = c->last_timestamp;
e.base_fee_msat = c->base_fee;
e.fee_per_millionth = c->proportional_fee;
e.delay = c->delay;

tal_arr_expand(entries, e);
}

/*~ Marshal (possibly) both channel directions into entries */
Expand Down Expand Up @@ -2002,21 +2002,21 @@ static void append_node(const struct gossip_getnodes_entry ***entries,
{
struct gossip_getnodes_entry *e;

*tal_arr_expand(entries) = e
= tal(*entries, struct gossip_getnodes_entry);
e = tal(*entries, struct gossip_getnodes_entry);
raw_pubkey(e->nodeid, &n->id);
e->last_timestamp = n->last_timestamp;
/* Timestamp on wire is an unsigned 32 bit: we use a 64-bit signed, so
* -1 means "we never received a channel_update". */
if (e->last_timestamp < 0)
return;
if (e->last_timestamp >= 0) {
e->globalfeatures = n->globalfeatures;
e->addresses = n->addresses;
BUILD_ASSERT(ARRAY_SIZE(e->alias) == ARRAY_SIZE(n->alias));
BUILD_ASSERT(ARRAY_SIZE(e->color) == ARRAY_SIZE(n->rgb_color));
memcpy(e->alias, n->alias, ARRAY_SIZE(e->alias));
memcpy(e->color, n->rgb_color, ARRAY_SIZE(e->color));
}

e->globalfeatures = n->globalfeatures;
e->addresses = n->addresses;
BUILD_ASSERT(ARRAY_SIZE(e->alias) == ARRAY_SIZE(n->alias));
BUILD_ASSERT(ARRAY_SIZE(e->color) == ARRAY_SIZE(n->rgb_color));
memcpy(e->alias, n->alias, ARRAY_SIZE(e->alias));
memcpy(e->color, n->rgb_color, ARRAY_SIZE(e->color));
tal_arr_expand(entries, e);
}

/* Simply routine when they ask for `listnodes` */
Expand Down Expand Up @@ -2125,7 +2125,7 @@ static struct io_plan *get_incoming_channels(struct io_conn *conn,
for (size_t i = 0; i < tal_count(node->chans); i++) {
const struct chan *c = node->chans[i];
const struct half_chan *hc;
struct route_info *ri;
struct route_info ri;

/* Don't leak private channels. */
if (!is_chan_public(c))
Expand All @@ -2136,12 +2136,12 @@ static struct io_plan *get_incoming_channels(struct io_conn *conn,
if (!is_halfchan_enabled(hc))
continue;

ri = tal_arr_expand(&r);
ri->pubkey = other_node(node, c)->id;
ri->short_channel_id = c->scid;
ri->fee_base_msat = hc->base_fee;
ri->fee_proportional_millionths = hc->proportional_fee;
ri->cltv_expiry_delta = hc->delay;
ri.pubkey = other_node(node, c)->id;
ri.short_channel_id = c->scid;
ri.fee_base_msat = hc->base_fee;
ri.fee_proportional_millionths = hc->proportional_fee;
ri.cltv_expiry_delta = hc->delay;
tal_arr_expand(&r, ri);
}
}

Expand Down
Loading

0 comments on commit 26dda57

Please sign in to comment.