Skip to content

Commit

Permalink
Revolutionize E2BIG
Browse files Browse the repository at this point in the history
  • Loading branch information
tbodt committed Feb 8, 2019
1 parent 2627dec commit c3217ca
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 76 deletions.
2 changes: 1 addition & 1 deletion kernel/calls.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ int must_check user_write_string(addr_t addr, const char *buf);
dword_t sys_clone(dword_t flags, addr_t stack, addr_t ptid, addr_t tls, addr_t ctid);
dword_t sys_fork(void);
dword_t sys_vfork(void);
int sys_execve(const char *file, char *const argv[], char *const envp[]);
int sys_execve(const char *file, const char *argv, const char *envp);
dword_t _sys_execve(addr_t file, addr_t argv, addr_t envp);
dword_t sys_exit(dword_t status);
noreturn void do_exit(int status);
Expand Down
160 changes: 94 additions & 66 deletions kernel/exec.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ static inline dword_t align_stack(dword_t sp);
static inline ssize_t user_strlen(dword_t p);
static inline int user_memset(addr_t start, byte_t val, dword_t len);
static inline dword_t copy_string(dword_t sp, const char *string);
static inline dword_t copy_strings(dword_t sp, char *const strings[]);
static unsigned count_args(char *const args[]);
static inline dword_t copy_strings(dword_t sp, const char *strings);
static size_t strings_size(const char *args);
static unsigned strings_count(const char *args);

static int read_header(struct fd *fd, struct elf_header *header) {
int err;
Expand Down Expand Up @@ -122,7 +123,7 @@ static addr_t find_hole_for_elf(struct elf_header *header, struct prg_header *ph
return pt_find_hole(current->mem, size) << PAGE_BITS;
}

static int elf_exec(struct fd *fd, const char *file, char *const argv[], char *const envp[]) {
static int elf_exec(struct fd *fd, const char *file, const char *argv, const char *envp) {
int err = 0;

// read the headers
Expand Down Expand Up @@ -317,8 +318,8 @@ static int elf_exec(struct fd *fd, const char *file, char *const argv[], char *c
{AX_PLATFORM, platform_addr},
{0, 0}
};
dword_t argc = count_args(argv);
dword_t envc = count_args(envp);
dword_t argc = strings_count(argv);
dword_t envc = strings_count(envp);
sp -= ((argc + 1) + (envc + 1) + 1) * sizeof(dword_t);
sp -= sizeof(aux);
sp &=~ 0xf;
Expand Down Expand Up @@ -377,11 +378,24 @@ static int elf_exec(struct fd *fd, const char *file, char *const argv[], char *c
goto out_free_interp;
}

static unsigned count_args(char *const args[]) {
unsigned i;
for (i = 0; args[i] != NULL; i++)
;
return i;
// Returns the total memory needed to store the arguments, including all the terminating nulls
static size_t strings_size(const char *args) {
const char *args_end = args;
size_t arg_len;
do {
arg_len = strlen(args_end);
args_end += arg_len + 1;
} while (arg_len != 0);
return args_end - args;
}

static unsigned strings_count(const char *args) {
unsigned n = 0;
while (*args != '\0') {
args += strlen(args) + 1;
n++;
}
return n;
}

static inline dword_t align_stack(addr_t sp) {
Expand All @@ -395,12 +409,11 @@ static inline dword_t copy_string(addr_t sp, const char *string) {
return sp;
}

static inline dword_t copy_strings(addr_t sp, char *const strings[]) {
for (unsigned i = count_args(strings); i > 0; i--) {
sp = copy_string(sp, strings[i - 1]);
if (sp == 0)
return 0;
}
static inline dword_t copy_strings(addr_t sp, const char *strings) {
size_t size = strings_size(strings);
sp -= size;
if (user_write(sp, strings, size))
return 0;
return sp;
}

Expand All @@ -422,15 +435,15 @@ static inline int user_memset(addr_t start, byte_t val, dword_t len) {
return 0;
}

static int format_exec(struct fd *fd, const char *file, char *const argv[], char *const envp[]) {
static int format_exec(struct fd *fd, const char *file, const char *argv, const char *envp) {
int err = elf_exec(fd, file, argv, envp);
if (err != _ENOEXEC)
return err;
// other formats would go here
return _ENOEXEC;
}

static int shebang_exec(struct fd *fd, const char *file, char *const argv[], char *const envp[]) {
static int shebang_exec(struct fd *fd, const char *file, const char *argv, const char *envp) {
// read the first 128 bytes to get the shebang line out of
if (fd->ops->lseek(fd, 0, SEEK_SET))
return _EIO;
Expand Down Expand Up @@ -473,15 +486,22 @@ static int shebang_exec(struct fd *fd, const char *file, char *const argv[], cha
if (*argument == '\0')
argument = NULL;

int args_extra = 2;
size_t args_size = strings_size(argv);
size_t extra_args_size = strlen(interpreter) + 1 + strlen(file) + 1;
if (argument)
args_extra++;
char *real_argv[count_args(argv) + args_extra];
real_argv[0] = interpreter;
if (argument)
real_argv[1] = argument;
real_argv[args_extra - 1] = (char *) file; // maybe you'll have better luck getting rid of this cast
memcpy(real_argv + args_extra, argv + 1, (count_args(argv)) * sizeof(argv[0]));
extra_args_size += strlen(argument) + 1;
if (args_size + extra_args_size >= 4096)
return _E2BIG;

char real_argv[4096];
size_t n = 0;
strcpy(real_argv, interpreter);
n += strlen(interpreter) + 1;
if (argument) {
strcpy(real_argv + n, argument);
n += strlen(argument) + 1;
}
strcpy(real_argv + n, file);

struct fd *interpreter_fd = generic_open(interpreter, O_RDONLY_, 0);
if (IS_ERR(interpreter_fd))
Expand All @@ -491,7 +511,7 @@ static int shebang_exec(struct fd *fd, const char *file, char *const argv[], cha
return err;
}

int sys_execve(const char *file, char *const argv[], char *const envp[]) {
int sys_execve(const char *file, const char *argv, const char *envp) {
struct fd *fd = generic_open(file, O_RDONLY, 0);
if (IS_ERR(fd))
return PTR_ERR(fd);
Expand Down Expand Up @@ -565,52 +585,60 @@ int sys_execve(const char *file, char *const argv[], char *const envp[]) {
return 0;
}

static int user_read_string_array(addr_t addr, char *buf, size_t max) {
size_t i = 0;
size_t p = 0;
for (;;) {
addr_t str_addr;
if (user_get(addr + i * sizeof(addr_t), str_addr))
return _EFAULT;
if (str_addr == 0)
break;
size_t str_p = 0;
for (;;) {
if (p >= max)
return _E2BIG;
if (user_get(str_addr + str_p, buf[p]))
return _EFAULT;
str_p++;
p++;
if (buf[p - 1] == '\0')
break;
}
i++;
}
if (p >= max)
return _E2BIG;
buf[p] = '\0';
return 0;
}

#define MAX_ARGS 256 // for now
dword_t _sys_execve(addr_t filename_addr, addr_t argv_addr, addr_t envp_addr) {
// TODO this code is shit, fix it
char filename[MAX_PATH];
if (user_read_string(filename_addr, filename, sizeof(filename)))
return _EFAULT;
char *argv[MAX_ARGS + 1];
int i;
addr_t arg;
char argv[4096];
int err = user_read_string_array(argv_addr, argv, sizeof(argv));
if (err < 0)
return err;
char envp[4096];
err = user_read_string_array(envp_addr, envp, sizeof(envp));
if (err < 0)
return err;

STRACE("execve(\"%s\", {", filename);
for (i = 0; ; i++) {
if (user_get(argv_addr + i * 4, arg))
return _EFAULT;
if (arg == 0)
break;
if (i >= MAX_ARGS)
return _E2BIG;
argv[i] = malloc(MAX_PATH);
if (user_read_string(arg, argv[i], MAX_PATH))
return _EFAULT;
if (i < 16)
STRACE("\"%.100s\", ", argv[i]);
else if (i == 16)
STRACE("...");
const char *args = argv;
while (*args != '\0') {
STRACE("\"%s\", ", args);
args += strlen(args) + 1;
}
argv[i] = NULL;
char *envp[MAX_ARGS + 1];
STRACE("}, {");
for (i = 0; ; i++) {
if (user_get(envp_addr + i * 4, arg))
return _EFAULT;
if (arg == 0)
break;
if (i >= MAX_ARGS)
return _E2BIG;
envp[i] = malloc(MAX_PATH);
if (user_read_string(arg, envp[i], MAX_PATH))
return _EFAULT;
STRACE("\"%.100s\", ", envp[i]);
args = envp;
while (*args != '\0') {
STRACE("\"%s\", ", args);
args += strlen(args) + 1;
}
envp[i] = NULL;
STRACE("})");
int res = sys_execve(filename, argv, envp);
for (i = 0; argv[i] != NULL; i++)
free(argv[i]);
for (i = 0; envp[i] != NULL; i++)
free(envp[i]);
return res;

return sys_execve(filename, argv, envp);
}
2 changes: 1 addition & 1 deletion kernel/log.c
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ static void output_line(const char *line) {
void vprintk(const char *msg, va_list args) {
// format the message
// I'm trusting you to not pass an absurdly long message
static __thread char buf[4096] = "";
static __thread char buf[16384] = "";
static __thread size_t buf_size = 0;
buf_size += vsprintf(buf + buf_size, msg, args);

Expand Down
4 changes: 2 additions & 2 deletions main.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
#include "xX_main_Xx.h"

int main(int argc, char *const argv[]) {
char *const *envp = NULL;
char envp[100] = {0};
if (getenv("TERM"))
envp = (char *[]) {getenv("TERM") - strlen("TERM") - 1, NULL};
strcpy(envp, getenv("TERM") - strlen("TERM") - 1);
int err = xX_main_Xx(argc, argv, envp);
if (err < 0) {
fprintf(stderr, "%s\n", strerror(-err));
Expand Down
4 changes: 2 additions & 2 deletions tools/ptraceomatic.c
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,9 @@ static void prepare_tracee(int pid) {
}

int main(int argc, char *const argv[]) {
char *const *envp = NULL;
char envp[100] = {0};
if (getenv("TERM"))
envp = (char *[]) {getenv("TERM") - strlen("TERM") - 1, NULL};
strcpy(envp, getenv("TERM") - strlen("TERM") - 1);
int err = xX_main_Xx(argc, argv, envp);
if (err < 0) {
fprintf(stderr, "%s\n", strerror(-err));
Expand Down
15 changes: 11 additions & 4 deletions xX_main_Xx.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ static void exit_handler(int code) {
// this function parses command line arguments and initializes global
// data structures. thanks programming discussions discord server for the name.
// https://discord.gg/9zT7NHP
static inline int xX_main_Xx(int argc, char *const argv[], char *const *envp) {
static inline int xX_main_Xx(int argc, char *const argv[], const char *envp) {
// parse cli options
int opt;
const char *root = "";
Expand Down Expand Up @@ -51,9 +51,16 @@ static inline int xX_main_Xx(int argc, char *const argv[], char *const *envp) {
fs_chdir(current->fs, pwd);
}

if (envp == NULL)
envp = (char *[]) {NULL};
err = sys_execve(argv[optind], argv + optind, envp);
char argv_copy[4096];
int i = optind;
size_t p = 0;
while (i < argc) {
strcpy(&argv_copy[p], argv[i]);
p += strlen(argv[i]) + 1;
i++;
}
argv_copy[p] = '\0';
err = sys_execve(argv[optind], argv_copy, envp == NULL ? "\0" : envp);
if (err < 0)
return err;
err = create_stdio(&real_tty_driver);
Expand Down

0 comments on commit c3217ca

Please sign in to comment.