Skip to content

Commit

Permalink
refactor default value logic
Browse files Browse the repository at this point in the history
  • Loading branch information
starwing committed Jul 25, 2022
1 parent 739bf67 commit 30a6f35
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 46 deletions.
72 changes: 31 additions & 41 deletions pb.c
Original file line number Diff line number Diff line change
Expand Up @@ -1171,7 +1171,12 @@ LUALIB_API int luaopen_pb_slice(lua_State *L) {

/* high level typeinfo/encode/decode routines */

static int lpb_pushdeffield(lua_State *L, lpb_State *LS, const pb_Field *f, int is_proto3);
typedef enum {USE_FIELD = 1, USE_REPEAT = 2, USE_MESSAGE = 4} lpb_DefFlags;

static void lpb_pushtypetable(lua_State *L, lpb_State *LS, const pb_Type *t);

static void lpb_newmsgtable(lua_State *L, const pb_Type *t)
{ lua_createtable(L, 0, t->field_count - t->oneof_field + t->oneof_count*2); }

LUALIB_API const pb_Type *lpb_type(lpb_State *LS, pb_Slice s) {
const pb_Type *t;
Expand Down Expand Up @@ -1318,35 +1323,11 @@ static int Lpb_enum(lua_State *L) {
return 1;
}

static void lpb_newtypetable(lua_State *L, const pb_Type *t, int with_repeat) {
const pb_Field *f = NULL;
if (t == NULL) { lua_newtable(L); return; }
lua_createtable(L, 0, t->field_count - t->oneof_field + t->oneof_count*2);
if (!with_repeat) return;
while (pb_nextfield(t, &f)) {
if (f->repeated) {
lua_newtable(L);
lua_setfield(L, -2, (const char*)f->name);
}
}
}

static int lpb_pushdefmsg(lua_State *L, lpb_State *LS, const pb_Type *t) {
const pb_Field *f = NULL;
if (t == NULL) return 0;
lpb_newtypetable(L, t, 0);
while (pb_nextfield(t, &f))
if (!f->oneof_idx && lpb_pushdeffield(L, LS, f, t->is_proto3))
lua_setfield(L, -2, (const char*)f->name);
return 1;
}

static int lpb_pushdeffield(lua_State *L, lpb_State *LS, const pb_Field *f, int is_proto3) {
int ret = 0;
const pb_Type *type;
char *end;
if (f == NULL) return 0;
if (f->repeated) return is_proto3 ? (lua_newtable(L), 1) : 0;
switch (f->type_id) {
case PB_Tbytes: case PB_Tstring:
if (f->default_value)
Expand All @@ -1369,9 +1350,8 @@ static int lpb_pushdeffield(lua_State *L, lpb_State *LS, const pb_Field *f, int
}
break;
case PB_Tmessage:
if (LS->decode_default_message)
return lpb_pushdefmsg(L, LS, f->type);
return 0;
ret = (lpb_pushtypetable(L, LS, f->type), 1);
break;
case PB_Tbool:
if (f->default_value) {
if (f->default_value == lpb_name(LS, pb_slice("true")))
Expand All @@ -1398,18 +1378,26 @@ static int lpb_pushdeffield(lua_State *L, lpb_State *LS, const pb_Field *f, int
return ret;
}

static void lpb_setdeffields(lua_State *L, lpb_State *LS, const pb_Type *t, lpb_DefFlags flags) {
const pb_Field *f = NULL;
while (pb_nextfield(t, &f)) {
int has_field = f->repeated ?
(flags & USE_REPEAT) && (t->is_proto3 || LS->decode_default_array)
&& (lua_newtable(L), 1) :
!f->oneof_idx && (f->type_id != PB_Tmessage ?
(flags & USE_FIELD) :
(flags & USE_MESSAGE) && LS->decode_default_message)
&& lpb_pushdeffield(L, LS, f, t->is_proto3);
if (has_field) lua_setfield(L, -2, (const char*)f->name);
}
}

static void lpb_pushdefmeta(lua_State *L, lpb_State *LS, const pb_Type *t) {
lpb_pushdeftable(L, LS);
if (lua53_rawgetp(L, -1, t) != LUA_TTABLE) {
const pb_Field *f = NULL;
lua_pop(L, 1);
lpb_newtypetable(L, t, 0);
while (pb_nextfield(t, &f))
if (!f->oneof_idx /* not oneof */
&& f->type_id != PB_Tmessage /* not message */
&& !f->repeated /* not repeated */
&& lpb_pushdeffield(L, LS, f, t->is_proto3))
lua_setfield(L, -2, (const char*)f->name);
lpb_newmsgtable(L, t);
lpb_setdeffields(L, LS, t, USE_FIELD);
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
lua_pushvalue(L, -1);
Expand Down Expand Up @@ -1640,7 +1628,7 @@ static void lpbE_repeated(lpb_Env *e, const pb_Field *f) {
lpbE_field(e, f, NULL);
lua_pop(L, 1);
}
if (i == 1)
if (i == 1 && !e->LS->encode_default_values)
pb_bufflen(b) = bufflen;
else
lpb_addlength(L, b, len);
Expand Down Expand Up @@ -1727,17 +1715,19 @@ static void lpb_usedechooks(lua_State *L, lpb_State *LS, const pb_Type *t) {

static void lpb_pushtypetable(lua_State *L, lpb_State *LS, const pb_Type *t) {
int mode = LS->default_mode;
lpb_newmsgtable(L, t);
switch (t->is_proto3 && mode == LPB_DEFDEF ? LPB_COPYDEF : mode) {
case LPB_COPYDEF:
lpb_pushdefmsg(L, LS, t);
lpb_setdeffields(L, LS, t, USE_FIELD|USE_REPEAT|USE_MESSAGE);
break;
case LPB_METADEF:
lpb_newtypetable(L, t, t->is_proto3);
lpb_setdeffields(L, LS, t, USE_REPEAT|USE_MESSAGE);
lpb_pushdefmeta(L, LS, t);
lua_setmetatable(L, -2);
break;
default: /* no default value */
lpb_newtypetable(L, t, LS->decode_default_array);
default:
if (LS->decode_default_array || LS->decode_default_message)
lpb_setdeffields(L, LS, t, USE_REPEAT|USE_MESSAGE);
break;
}
}
Expand Down
12 changes: 7 additions & 5 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,7 @@ function _G.test_default()
pb.option "decode_default_array"
local dt = pb.decode("TestDefault", "")
eq(getmetatable(dt), nil)
table_eq(dt,{
array = {},
})
table_eq(dt,{ array = {} })
local chunk2, _ = pb.encode("TestDefault", {defaulted_int = 0,defaulted_bool = true})
local dt = pb.decode("TestDefault", chunk2)
eq(dt.defaulted_int, 0)
Expand Down Expand Up @@ -674,7 +672,11 @@ function _G.test_map()
local chunk = pb.encode("TestNum", {f = 123})
pb.decode("TestMap", chunk)
end)
--eq(pb.decode("TestMap", "\10\4\3\10\1\1"), { map = {} })
eq(pb.decode("TestMap", "\10\4\3\10\1\1"), {
map = {["\1"] = 0},
packed_map = {},
msg_map = {},
})
eq(pb.decode("TestMap", "\10\0"), {
map = { [""] = 0 },
packed_map = {},
Expand All @@ -683,7 +685,7 @@ function _G.test_map()
eq(pb.decode("TestMap", "\26\0"), {
map = {},
packed_map = {},
msg_map = {}
msg_map = {[""] = {}}
})

check_load [[
Expand Down

0 comments on commit 30a6f35

Please sign in to comment.