Skip to content

Commit

Permalink
Change WASM direct heap access to use helper functions (dotnet#61355)
Browse files Browse the repository at this point in the history
Direct heap writes via Module.HEAPxx[y] = func(...) are incorrect because the left-hand side (according to spec) is evaluated before the right, so if evaluating func(...) causes the heap to grow, the assignment target becomes a detached buffer and the write goes nowhere, breaking your application. This PR introduces a new set of helper functions for memory reads and writes.
  • Loading branch information
kg authored Nov 10, 2021
1 parent e5eafc9 commit 3dce93b
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 38 deletions.
17 changes: 9 additions & 8 deletions src/mono/wasm/runtime/cs-to-js.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { get_js_owned_object_by_gc_handle, js_owned_gc_handle_symbol, mono_wasm_
import { mono_method_get_call_signature, call_method, wrap_error } from "./method-calls";
import { _js_to_mono_obj } from "./js-to-cs";
import { _are_promises_supported, _create_cancelable_promise } from "./cancelable-promise";
import { getU32, getI32, getF32, getF64 } from "./memory";

// see src/mono/wasm/driver.c MARSHAL_TYPE_xxx and Runtime.cs MarshalType
export enum MarshalType {
Expand Down Expand Up @@ -132,7 +133,7 @@ export function _unbox_mono_obj_root_with_known_nonprimitive_type(root: WasmRoot

let typePtr = MonoTypeNull;
if ((type === MarshalType.VT) || (type == MarshalType.OBJECT)) {
typePtr = <MonoType><any>Module.HEAPU32[<any>unbox_buffer >>> 2];
typePtr = <MonoType><any>getU32(unbox_buffer);
if (<number><any>typePtr < 1024)
throw new Error(`Got invalid MonoType ${typePtr} for object at address ${root.value} (root located at ${root.get_address()})`);
}
Expand All @@ -148,20 +149,20 @@ export function _unbox_mono_obj_root(root: WasmRoot<any>): any {
const type = cwraps.mono_wasm_try_unbox_primitive_and_get_type(root.value, unbox_buffer, runtimeHelpers._unbox_buffer_size);
switch (type) {
case MarshalType.INT:
return Module.HEAP32[<any>unbox_buffer >>> 2];
return getI32(unbox_buffer);
case MarshalType.UINT32:
return Module.HEAPU32[<any>unbox_buffer >>> 2];
return getU32(unbox_buffer);
case MarshalType.POINTER:
// FIXME: Is this right?
return Module.HEAPU32[<any>unbox_buffer >>> 2];
return getU32(unbox_buffer);
case MarshalType.FP32:
return Module.HEAPF32[<any>unbox_buffer >>> 2];
return getF32(unbox_buffer);
case MarshalType.FP64:
return Module.HEAPF64[<any>unbox_buffer >>> 3];
return getF64(unbox_buffer);
case MarshalType.BOOL:
return (Module.HEAP32[<any>unbox_buffer >>> 2]) !== 0;
return (getI32(unbox_buffer)) !== 0;
case MarshalType.CHAR:
return String.fromCharCode(Module.HEAP32[<any>unbox_buffer >>> 2]);
return String.fromCharCode(getI32(unbox_buffer));
case MarshalType.NULL:
return null;
default:
Expand Down
26 changes: 26 additions & 0 deletions src/mono/wasm/runtime/exports.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ import { mono_wasm_release_cs_owned_object } from "./gc-handles";
import { mono_wasm_web_socket_open, mono_wasm_web_socket_send, mono_wasm_web_socket_receive, mono_wasm_web_socket_close, mono_wasm_web_socket_abort } from "./web-socket";
import cwraps from "./cwraps";
import { ArgsMarshalString } from "./method-binding";
import {
setI8, setI16, setI32, setI64,
setU8, setU16, setU32, setF32, setF64,
getI8, getI16, getI32, getI64,
getU8, getU16, getU32, getF32, getF64,
} from "./memory";

export const MONO: MONO = <any>{
// current "public" MONO API
Expand Down Expand Up @@ -251,6 +257,26 @@ export const INTERNAL: any = {
mono_wasm_detach_debugger,
mono_wasm_raise_debug_event,
mono_wasm_runtime_is_ready: runtimeHelpers.mono_wasm_runtime_is_ready,

// memory accessors
setI8,
setI16,
setI32,
setI64,
setU8,
setU16,
setU32,
setF32,
setF64,
getI8,
getI16,
getI32,
getI64,
getU8,
getU16,
getU32,
getF32,
getF64,
};

// this represents visibility in the javascript
Expand Down
9 changes: 5 additions & 4 deletions src/mono/wasm/runtime/js-to-cs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { js_string_to_mono_string, js_string_to_mono_string_interned } from "./s
import { isThenable } from "./cancelable-promise";
import { has_backing_array_buffer } from "./buffers";
import { Int32Ptr, JSHandle, MonoArray, MonoMethod, MonoObject, MonoObjectNull, MonoString, wasm_type_symbol } from "./types";
import { setI32, setU32, setF64 } from "./memory";

// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types
export function _js_to_mono_uri(should_add_in_flight: boolean, js_obj: any): MonoObject {
Expand Down Expand Up @@ -109,22 +110,22 @@ function _extract_mono_obj(should_add_in_flight: boolean, js_obj: any): MonoObje
}

function _box_js_int(js_obj: number) {
Module.HEAP32[<any>runtimeHelpers._box_buffer >>> 2] = js_obj;
setI32(runtimeHelpers._box_buffer, js_obj);
return cwraps.mono_wasm_box_primitive(runtimeHelpers._class_int32, runtimeHelpers._box_buffer, 4);
}

function _box_js_uint(js_obj: number) {
Module.HEAPU32[<any>runtimeHelpers._box_buffer >>> 2] = js_obj;
setU32(runtimeHelpers._box_buffer, js_obj);
return cwraps.mono_wasm_box_primitive(runtimeHelpers._class_uint32, runtimeHelpers._box_buffer, 4);
}

function _box_js_double(js_obj: number) {
Module.HEAPF64[<any>runtimeHelpers._box_buffer >>> 3] = js_obj;
setF64(runtimeHelpers._box_buffer, js_obj);
return cwraps.mono_wasm_box_primitive(runtimeHelpers._class_double, runtimeHelpers._box_buffer, 8);
}

export function _box_js_bool(js_obj: boolean): MonoObject {
Module.HEAP32[<any>runtimeHelpers._box_buffer >>> 2] = js_obj ? 1 : 0;
setI32(runtimeHelpers._box_buffer, js_obj ? 1 : 0);
return cwraps.mono_wasm_box_primitive(runtimeHelpers._class_boolean, runtimeHelpers._box_buffer, 4);
}

Expand Down
77 changes: 77 additions & 0 deletions src/mono/wasm/runtime/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,80 @@ export function _release_temp_frame(): void {
for (let i = 0, l = frame.length; i < l; i++)
Module._free(frame[i]);
}

type _MemOffset = number | VoidPtr | NativePointer;

export function setU8 (offset: _MemOffset, value: number) : void {
Module.HEAPU8[<any>offset] = value;
}

export function setU16 (offset: _MemOffset, value: number) : void {
Module.HEAPU16[<any>offset >>> 1] = value;
}

export function setU32 (offset: _MemOffset, value: number) : void {
Module.HEAPU32[<any>offset >>> 2] = value;
}

export function setI8 (offset: _MemOffset, value: number) : void {
Module.HEAP8[<any>offset] = value;
}

export function setI16 (offset: _MemOffset, value: number) : void {
Module.HEAP16[<any>offset >>> 1] = value;
}

export function setI32 (offset: _MemOffset, value: number) : void {
Module.HEAP32[<any>offset >>> 2] = value;
}

// NOTE: Accepts a number, not a BigInt, so values over Number.MAX_SAFE_INTEGER will be corrupted
export function setI64 (offset: _MemOffset, value: number) : void {
Module.setValue(<VoidPtr><any>offset, value, "i64");
}

export function setF32 (offset: _MemOffset, value: number) : void {
Module.HEAPF32[<any>offset >>> 2] = value;
}

export function setF64 (offset: _MemOffset, value: number) : void {
Module.HEAPF64[<any>offset >>> 3] = value;
}


export function getU8 (offset: _MemOffset) : number {
return Module.HEAPU8[<any>offset];
}

export function getU16 (offset: _MemOffset) : number {
return Module.HEAPU16[<any>offset >>> 1];
}

export function getU32 (offset: _MemOffset) : number {
return Module.HEAPU32[<any>offset >>> 2];
}

export function getI8 (offset: _MemOffset) : number {
return Module.HEAP8[<any>offset];
}

export function getI16 (offset: _MemOffset) : number {
return Module.HEAP16[<any>offset >>> 1];
}

export function getI32 (offset: _MemOffset) : number {
return Module.HEAP32[<any>offset >>> 2];
}

// NOTE: Returns a number, not a BigInt. This means values over Number.MAX_SAFE_INTEGER will be corrupted
export function getI64 (offset: _MemOffset) : number {
return Module.getValue(<number><any>offset, "i64");
}

export function getF32 (offset: _MemOffset) : number {
return Module.HEAPF32[<any>offset >>> 2];
}

export function getF64 (offset: _MemOffset) : number {
return Module.HEAPF64[<any>offset >>> 3];
}
51 changes: 30 additions & 21 deletions src/mono/wasm/runtime/method-binding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ import { BINDING, runtimeHelpers } from "./modules";
import { js_to_mono_enum, _js_to_mono_obj, _js_to_mono_uri } from "./js-to-cs";
import { js_string_to_mono_string, js_string_to_mono_string_interned } from "./strings";
import { MarshalType, _unbox_mono_obj_root_with_known_nonprimitive_type } from "./cs-to-js";
import { _create_temp_frame } from "./memory";
import {
_create_temp_frame,
getI32, getU32, getF32, getF64,
setI32, setU32, setF32, setF64, setI64,
} from "./memory";
import {
_get_args_root_buffer_for_method_call, _get_buffer_for_method_call,
_handle_exception_for_call, _teardown_after_call
Expand Down Expand Up @@ -213,15 +217,18 @@ export function _compile_converter_for_marshal_string(args_marshal: ArgsMarshalS
Module,
_malloc: Module._malloc,
mono_wasm_unbox_rooted: cwraps.mono_wasm_unbox_rooted,
setI32,
setU32,
setF32,
setF64,
setI64
};
let indirectLocalOffset = 0;

body.push(
"if (!method) throw new Error('no method provided');",
`if (!buffer) buffer = _malloc (${bufferSizeBytes});`,
`let indirectStart = buffer + ${indirectBaseOffset};`,
"let indirect32 = indirectStart >>> 2, indirect64 = indirectStart >>> 3;",
"let buffer32 = buffer >>> 2;",
""
);

Expand Down Expand Up @@ -253,37 +260,35 @@ export function _compile_converter_for_marshal_string(args_marshal: ArgsMarshalS
body.push(`${valueKey} = mono_wasm_unbox_rooted (${valueKey});`);

if (step.indirect) {
let heapArrayName = null;
const offsetText = `(indirectStart + ${indirectLocalOffset})`;

switch (step.indirect) {
case "u32":
heapArrayName = "HEAPU32";
body.push(`setU32(${offsetText}, ${valueKey});`);
break;
case "i32":
heapArrayName = "HEAP32";
body.push(`setI32(${offsetText}, ${valueKey});`);
break;
case "float":
heapArrayName = "HEAPF32";
body.push(`setF32(${offsetText}, ${valueKey});`);
break;
case "double":
body.push(`Module.HEAPF64[indirect64 + ${(indirectLocalOffset >>> 3)}] = ${valueKey};`);
body.push(`setF64(${offsetText}, ${valueKey});`);
break;
case "i64":
body.push(`Module.setValue (indirectStart + ${indirectLocalOffset}, ${valueKey}, 'i64');`);
body.push(`setI64(${offsetText}, ${valueKey});`);
break;
default:
throw new Error("Unimplemented indirect type: " + step.indirect);
}

if (heapArrayName)
body.push(`Module.${heapArrayName}[indirect32 + ${(indirectLocalOffset >>> 2)}] = ${valueKey};`);

body.push(`Module.HEAP32[buffer32 + ${i}] = indirectStart + ${indirectLocalOffset};`, "");
body.push(`setU32(buffer + (${i} * 4), ${offsetText});`);
indirectLocalOffset += step.size!;
} else {
body.push(`Module.HEAP32[buffer32 + ${i}] = ${valueKey};`, "");
body.push(`setI32(buffer + (${i} * 4), ${valueKey});`);
indirectLocalOffset += 4;
}
body.push("");
}

body.push("return buffer;");
Expand Down Expand Up @@ -404,7 +409,11 @@ export function mono_bind_method(method: MonoMethod, this_arg: MonoObject | null
this_arg,
token,
unbox_buffer,
unbox_buffer_size
unbox_buffer_size,
getI32,
getU32,
getF32,
getF64
};

const converterKey = converter ? "converter_" + converter.name : "";
Expand Down Expand Up @@ -493,18 +502,18 @@ export function mono_bind_method(method: MonoMethod, this_arg: MonoObject | null
" let resultType = mono_wasm_try_unbox_primitive_and_get_type (resultPtr, unbox_buffer, unbox_buffer_size);",
" switch (resultType) {",
` case ${MarshalType.INT}:`,
" result = Module.HEAP32[unbox_buffer >>> 2]; break;",
" result = getI32(unbox_buffer); break;",
` case ${MarshalType.POINTER}:`, // FIXME: Is this right?
` case ${MarshalType.UINT32}:`,
" result = Module.HEAPU32[unbox_buffer >>> 2]; break;",
" result = getU32(unbox_buffer); break;",
` case ${MarshalType.FP32}:`,
" result = Module.HEAPF32[unbox_buffer >>> 2]; break;",
" result = getF32(unbox_buffer); break;",
` case ${MarshalType.FP64}:`,
" result = Module.HEAPF64[unbox_buffer >>> 3]; break;",
" result = getF64(unbox_buffer); break;",
` case ${MarshalType.BOOL}:`,
" result = (Module.HEAP32[unbox_buffer >>> 2]) !== 0; break;",
" result = getI32(unbox_buffer) !== 0; break;",
` case ${MarshalType.CHAR}:`,
" result = String.fromCharCode(Module.HEAP32[unbox_buffer >>> 2]); break;",
" result = String.fromCharCode(getI32(unbox_buffer)); break;",
" default:",
" result = _unbox_mono_obj_root_with_known_nonprimitive_type (resultRoot, resultType, unbox_buffer); break;",
" }",
Expand Down
9 changes: 7 additions & 2 deletions src/mono/wasm/runtime/roots.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,18 @@ export class WasmRootBuffer {
return this.__offset32 + index;
}

// NOTE: These functions do not use the helpers from memory.ts because WasmRoot.get and WasmRoot.set
// are hot-spots when you profile any application that uses the bindings extensively.

get(index: number): ManagedPointer {
this._check_in_range(index);
return <any>Module.HEAP32[this.get_address_32(index)];
const offset = this.get_address_32(index);
return <any>Module.HEAP32[offset];
}

set(index: number, value: ManagedPointer): ManagedPointer {
Module.HEAP32[this.get_address_32(index)] = <any>value;
const offset = this.get_address_32(index);
Module.HEAP32[offset] = <any>value;
return value;
}

Expand Down
7 changes: 4 additions & 3 deletions src/mono/wasm/runtime/strings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { CharPtr, MonoString, MonoStringNull, NativePointer } from "./types";
import { Module } from "./modules";
import cwraps from "./cwraps";
import { mono_wasm_new_root } from "./roots";
import { getI32 } from "./memory";

export class StringDecoder {

Expand All @@ -31,9 +32,9 @@ export class StringDecoder {
cwraps.mono_wasm_string_get_data(mono_string, <any>ppChars, <any>pLengthBytes, <any>pIsInterned);

let result = mono_wasm_empty_string;
const lengthBytes = Module.HEAP32[pLengthBytes >>> 2],
pChars = Module.HEAP32[ppChars >>> 2],
isInterned = Module.HEAP32[pIsInterned >>> 2];
const lengthBytes = getI32(pLengthBytes),
pChars = getI32(ppChars),
isInterned = getI32(pIsInterned);

if (pLengthBytes && pChars) {
if (
Expand Down

0 comments on commit 3dce93b

Please sign in to comment.