Skip to content

Commit

Permalink
move address resolution into lua
Browse files Browse the repository at this point in the history
  • Loading branch information
wg committed Feb 21, 2015
1 parent 93348e2 commit 6f0aa32
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 55 deletions.
1 change: 0 additions & 1 deletion src/main.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include <fcntl.h>
#include <getopt.h>
#include <math.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <stdarg.h>
Expand Down
132 changes: 127 additions & 5 deletions src/script.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,28 @@
#include <string.h>
#include "script.h"
#include "http_parser.h"
#include "zmalloc.h"

typedef struct {
char *name;
int type;
void *value;
} table_field;

static int script_addr_tostring(lua_State *);
static int script_addr_gc(lua_State *);
static int script_stats_len(lua_State *);
static int script_stats_get(lua_State *);
static int script_wrk_lookup(lua_State *);
static int script_wrk_connect(lua_State *);
static void set_fields(lua_State *, int index, const table_field *);

static const struct luaL_reg addrlib[] = {
{ "__tostring", script_addr_tostring },
{ "__gc" , script_addr_gc },
{ NULL, NULL }
};

static const struct luaL_reg statslib[] = {
{ "__index", script_stats_get },
{ "__len", script_stats_len },
Expand All @@ -26,16 +37,22 @@ lua_State *script_create(char *scheme, char *host, char *port, char *path) {
luaL_openlibs(L);
luaL_dostring(L, "wrk = require \"wrk\"");

luaL_newmetatable(L, "wrk.addr");
luaL_register(L, NULL, addrlib);
lua_pop(L, 1);

luaL_newmetatable(L, "wrk.stats");
luaL_register(L, NULL, statslib);
lua_pop(L, 1);

const table_field fields[] = {
{ "scheme", LUA_TSTRING, scheme },
{ "host", LUA_TSTRING, host },
{ "port", LUA_TSTRING, port },
{ "path", LUA_TSTRING, path },
{ NULL, 0, NULL },
{ "scheme", LUA_TSTRING, scheme },
{ "host", LUA_TSTRING, host },
{ "port", LUA_TSTRING, port },
{ "path", LUA_TSTRING, path },
{ "lookup", LUA_TFUNCTION, script_wrk_lookup },
{ "connect", LUA_TFUNCTION, script_wrk_connect },
{ NULL, 0, NULL },
};

lua_getglobal(L, "wrk");
Expand All @@ -45,6 +62,36 @@ lua_State *script_create(char *scheme, char *host, char *port, char *path) {
return L;
}

void script_prepare_setup(lua_State *L, char *script) {
if (script && luaL_dofile(L, script)) {
const char *cause = lua_tostring(L, -1);
fprintf(stderr, "%s: %s\n", script, cause);
}
}

bool script_resolve(lua_State *L, char *host, char *service) {
lua_getglobal(L, "wrk");

lua_getfield(L, -1, "resolve");
lua_pushstring(L, host);
lua_pushstring(L, service);
lua_call(L, 2, 0);

lua_getfield(L, -1, "addrs");
size_t count = lua_objlen(L, -1);
lua_pop(L, 2);
return count > 0;
}

struct addrinfo *script_peek_addr(lua_State *L) {
lua_getglobal(L, "wrk");
lua_getfield(L, -1, "addrs");
lua_rawgeti(L, -1, 1);
struct addrinfo *addr = lua_touserdata(L, -1);
lua_pop(L, 3);
return addr;
}

void script_headers(lua_State *L, char **headers) {
lua_getglobal(L, "wrk");
lua_getfield(L, 1, "headers");
Expand Down Expand Up @@ -225,6 +272,34 @@ size_t script_verify_request(lua_State *L) {
return count;
}

static struct addrinfo *checkaddr(lua_State *L) {
struct addrinfo *addr = luaL_checkudata(L, -1, "wrk.addr");
luaL_argcheck(L, addr != NULL, 1, "`addr' expected");
return addr;
}

static int script_addr_tostring(lua_State *L) {
struct addrinfo *addr = checkaddr(L);
char host[NI_MAXHOST];
char service[NI_MAXSERV];

int flags = NI_NUMERICHOST | NI_NUMERICSERV;
int rc = getnameinfo(addr->ai_addr, addr->ai_addrlen, host, NI_MAXHOST, service, NI_MAXSERV, flags);
if (rc != 0) {
const char *msg = gai_strerror(rc);
return luaL_error(L, "addr tostring failed %s", msg);
}

lua_pushfstring(L, "%s:%s", host, service);
return 1;
}

static int script_addr_gc(lua_State *L) {
struct addrinfo *addr = checkaddr(L);
zfree(addr->ai_addr);
return 0;
}

static stats *checkstats(lua_State *L) {
stats **s = luaL_checkudata(L, 1, "wrk.stats");
luaL_argcheck(L, s != NULL, 1, "`stats' expected");
Expand Down Expand Up @@ -262,10 +337,57 @@ static int script_stats_len(lua_State *L) {
return 1;
}

static int script_wrk_lookup(lua_State *L) {
struct addrinfo *addrs;
struct addrinfo hints = {
.ai_family = AF_UNSPEC,
.ai_socktype = SOCK_STREAM
};
int rc, index = 1;

const char *host = lua_tostring(L, -2);
const char *service = lua_tostring(L, -1);

if ((rc = getaddrinfo(host, service, &hints, &addrs)) != 0) {
const char *msg = gai_strerror(rc);
fprintf(stderr, "unable to resolve %s:%s %s\n", host, service, msg);
exit(1);
}

lua_newtable(L);
for (struct addrinfo *addr = addrs; addr != NULL; addr = addr->ai_next) {
struct addrinfo *udata = lua_newuserdata(L, sizeof(*udata));
luaL_getmetatable(L, "wrk.addr");
lua_setmetatable(L, -2);

*udata = *addr;
udata->ai_addr = zmalloc(addr->ai_addrlen);
memcpy(udata->ai_addr, addr->ai_addr, addr->ai_addrlen);
lua_rawseti(L, -2, index++);
}

freeaddrinfo(addrs);
return 1;
}

static int script_wrk_connect(lua_State *L) {
struct addrinfo *addr = checkaddr(L);
int fd, connected = 0;
if ((fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol)) != -1) {
connected = connect(fd, addr->ai_addr, addr->ai_addrlen) == 0;
close(fd);
}
lua_pushboolean(L, connected);
return 1;
}

static void set_fields(lua_State *L, int index, const table_field *fields) {
for (int i = 0; fields[i].name; i++) {
table_field f = fields[i];
switch (f.value == NULL ? LUA_TNIL : f.type) {
case LUA_TFUNCTION:
lua_pushcfunction(L, (lua_CFunction) f.value);
break;
case LUA_TNUMBER:
lua_pushinteger(L, *((lua_Integer *) f.value));
break;
Expand Down
6 changes: 6 additions & 0 deletions src/script.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#include <lua.h>
#include <lualib.h>
#include <lauxlib.h>
#include <sys/types.h>
#include <netdb.h>
#include <unistd.h>
#include "stats.h"

typedef struct {
Expand All @@ -14,6 +17,9 @@ typedef struct {
} buffer;

lua_State *script_create(char *, char *, char *, char *);
void script_prepare_setup(lua_State *, char *);
bool script_resolve(lua_State *, char *, char *);
struct addrinfo *script_peek_addr(lua_State *);
void script_headers(lua_State *, char **);
size_t script_verify_request(lua_State *L);

Expand Down
35 changes: 7 additions & 28 deletions src/wrk.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "main.h"

static struct config {
struct addrinfo addr;
uint64_t threads;
uint64_t connections;
uint64_t duration;
Expand Down Expand Up @@ -58,10 +57,8 @@ static void usage() {
}

int main(int argc, char **argv) {
struct addrinfo *addrs, *addr;
struct http_parser_url parser_url;
char *url, **headers;
int rc;

headers = zmalloc((argc / 2) * sizeof(char *));

Expand All @@ -85,26 +82,9 @@ int main(int argc, char **argv) {
path = &url[parser_url.field_data[UF_PATH].off];
}

struct addrinfo hints = {
.ai_family = AF_UNSPEC,
.ai_socktype = SOCK_STREAM
};

if ((rc = getaddrinfo(host, service, &hints, &addrs)) != 0) {
const char *msg = gai_strerror(rc);
fprintf(stderr, "unable to resolve %s:%s %s\n", host, service, msg);
exit(1);
}

for (addr = addrs; addr != NULL; addr = addr->ai_next) {
int fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
if (fd == -1) continue;
rc = connect(fd, addr->ai_addr, addr->ai_addrlen);
close(fd);
if (rc == 0) break;
}

if (addr == NULL) {
lua_State *L = script_create(schema, host, port, path);
script_prepare_setup(L, cfg.script);
if (!script_resolve(L, host, service)) {
char *msg = strerror(errno);
fprintf(stderr, "unable to connect to %s:%s %s\n", host, service, msg);
exit(1);
Expand All @@ -125,7 +105,6 @@ int main(int argc, char **argv) {

signal(SIGPIPE, SIG_IGN);
signal(SIGINT, SIG_IGN);
cfg.addr = *addr;

pthread_mutex_init(&statistics.mutex, NULL);
statistics.latency = stats_alloc(SAMPLES);
Expand All @@ -138,6 +117,7 @@ int main(int argc, char **argv) {
for (uint64_t i = 0; i < cfg.threads; i++) {
thread *t = &threads[i];
t->loop = aeCreateEventLoop(10 + cfg.connections * 3);
t->addr = script_peek_addr(L);
t->connections = connections;
t->stop_at = stop_at;

Expand Down Expand Up @@ -217,7 +197,6 @@ int main(int argc, char **argv) {
printf("Requests/sec: %9.2Lf\n", req_per_s);
printf("Transfer/sec: %10sB\n", format_binary(bytes_per_s));

lua_State *L = threads[0].L;
if (script_has_done(L)) {
script_summary(L, runtime_us, complete, bytes);
script_errors(L, &errors);
Expand Down Expand Up @@ -274,16 +253,16 @@ void *thread_main(void *arg) {
}

static int connect_socket(thread *thread, connection *c) {
struct addrinfo addr = cfg.addr;
struct addrinfo *addr = thread->addr;
struct aeEventLoop *loop = thread->loop;
int fd, flags;

fd = socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);

flags = fcntl(fd, F_GETFL, 0);
fcntl(fd, F_SETFL, flags | O_NONBLOCK);

if (connect(fd, addr.ai_addr, addr.ai_addrlen) == -1) {
if (connect(fd, addr->ai_addr, addr->ai_addrlen) == -1) {
if (errno != EINPROGRESS) goto error;
}

Expand Down
2 changes: 2 additions & 0 deletions src/wrk.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <pthread.h>
#include <inttypes.h>
#include <sys/types.h>
#include <netdb.h>

#include <openssl/ssl.h>
#include <openssl/err.h>
Expand All @@ -25,6 +26,7 @@
typedef struct {
pthread_t thread;
aeEventLoop *loop;
struct addrinfo *addr;
uint64_t connections;
int interval;
uint64_t stop_at;
Expand Down
52 changes: 31 additions & 21 deletions src/wrk.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,14 @@ local wrk = {
body = nil
}

function wrk.format(method, path, headers, body)
local method = method or wrk.method
local path = path or wrk.path
local headers = headers or wrk.headers
local body = body or wrk.body
local s = {}

if not headers["Host"] then
headers["Host"] = wrk.headers["Host"]
function wrk.resolve(host, service)
local addrs = wrk.lookup(host, service)
for i = #addrs, 1, -1 do
if not wrk.connect(addrs[i]) then
table.remove(addrs, i)
end
end

headers["Content-Length"] = body and string.len(body)

s[1] = string.format("%s %s HTTP/1.1", method, path)
for name, value in pairs(headers) do
s[#s+1] = string.format("%s: %s", name, value)
end

s[#s+1] = ""
s[#s+1] = body or ""

return table.concat(s, "\r\n")
wrk.addrs = addrs
end

function wrk.init(args)
Expand All @@ -53,4 +39,28 @@ function wrk.init(args)
end
end

function wrk.format(method, path, headers, body)
local method = method or wrk.method
local path = path or wrk.path
local headers = headers or wrk.headers
local body = body or wrk.body
local s = {}

if not headers["Host"] then
headers["Host"] = wrk.headers["Host"]
end

headers["Content-Length"] = body and string.len(body)

s[1] = string.format("%s %s HTTP/1.1", method, path)
for name, value in pairs(headers) do
s[#s+1] = string.format("%s: %s", name, value)
end

s[#s+1] = ""
s[#s+1] = body or ""

return table.concat(s, "\r\n")
end

return wrk

0 comments on commit 6f0aa32

Please sign in to comment.