Skip to content

Commit

Permalink
Enable topological linking of functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
axel22 committed Dec 12, 2019
1 parent 7416e51 commit 4d41bf4
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ private void readImportSection() {
switch (importType) {
case ImportIdentifier.FUNCTION: {
int typeIndex = readTypeIndex();
module.symbolTable().importFunction(moduleName, memberName, typeIndex);
module.symbolTable().importFunction(context, moduleName, memberName, typeIndex);
moduleFunctionIndex++;
break;
}
Expand Down Expand Up @@ -1065,7 +1065,7 @@ private void readExportSection(WasmContext context) {
switch (exportType) {
case ExportIdentifier.FUNCTION: {
int functionIndex = readFunctionIndex();
module.symbolTable().exportFunction(exportName, functionIndex);
module.symbolTable().exportFunction(context, functionIndex, exportName);
break;
}
case ExportIdentifier.TABLE: {
Expand Down
180 changes: 135 additions & 45 deletions wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@

import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.interop.UnknownIdentifierException;
import org.graalvm.wasm.Linker.ResolutionDag.DataDecl;
import org.graalvm.wasm.Linker.ResolutionDag.Decl;
import org.graalvm.wasm.Linker.ResolutionDag.ExportMemoryDecl;
import org.graalvm.wasm.Linker.ResolutionDag.ImportMemoryDecl;
import org.graalvm.wasm.Linker.ResolutionDag.DataSym;
import org.graalvm.wasm.Linker.ResolutionDag.Sym;
import org.graalvm.wasm.Linker.ResolutionDag.ExportMemorySym;
import org.graalvm.wasm.Linker.ResolutionDag.ImportMemorySym;
import org.graalvm.wasm.Linker.ResolutionDag.Resolver;
import org.graalvm.wasm.constants.GlobalModifier;
import org.graalvm.wasm.constants.GlobalResolution;
Expand All @@ -58,6 +58,8 @@
import java.util.Map;
import java.util.function.Consumer;

import static org.graalvm.wasm.Linker.ResolutionDag.*;

public class Linker {
private enum LinkState {
notLinked,
Expand Down Expand Up @@ -100,13 +102,12 @@ private void tryLinkOutsidePartialEvaluation() {
if (linkState == LinkState.notLinked) {
linkState = LinkState.inProgress;
Map<String, WasmModule> modules = WasmContext.getCurrent().modules();
for (WasmModule module : modules.values()) {
linkFunctions(module);
module.setLinked();
}
// TODO: Once topological linking starts handling all the import kinds,
// remove the previous loop.
linkTopologically();
for (WasmModule module : modules.values()) {
module.setLinked();
}
for (WasmModule module : modules.values()) {
final WasmFunction start = module.symbolTable().startFunction();
if (start != null) {
Expand Down Expand Up @@ -248,6 +249,37 @@ int importTable(WasmContext context, WasmModule module, String importedModuleNam
}
}

void resolveFunctionImport(WasmContext context, WasmModule module, WasmFunction function) {
final Runnable resolveAction = () -> {
final WasmModule importedModule = context.modules().get(function.importedModuleName());
if (importedModule == null) {
throw new WasmLinkerException("The module '" + function.importedModuleName() + "', referenced by the import '" + function.importedFunctionName() + "' in the module '" + module.name() +
"', does not exist.");
}
WasmFunction importedFunction;
try {
importedFunction = (WasmFunction) importedModule.readMember(function.importedFunctionName());
} catch (UnknownIdentifierException e) {
importedFunction = null;
}
if (importedFunction == null) {
throw new WasmLinkerException("The imported function '" + function.importedFunctionName() + "', referenced in the module '" + module.name() +
"', does not exist in the imported module '" + function.importedModuleName() + "'.");
}
function.setCallTarget(importedFunction.resolveCallTarget());
};
Sym[] dependencies = new Sym[]{new ExportFunctionSym(function.importDescriptor().moduleName, function.importDescriptor().memberName)};
resolutionDag.resolveLater(new ImportFunctionSym(module.name(), function.importDescriptor()), dependencies, resolveAction);
}

void resolveFunctionExport(WasmModule module, int functionIndex, String exportedFunctionName) {
final Runnable resolveAction = () -> {
};
final ImportDescriptor importDescriptor = module.symbolTable().function(functionIndex).importDescriptor();
final Sym[] dependencies = importDescriptor != null ? new Sym[]{new ImportFunctionSym(module.name(), importDescriptor)} : new Sym[0];
resolutionDag.resolveLater(new ExportFunctionSym(module.name(), exportedFunctionName), dependencies, resolveAction);
}

void resolveMemoryImport(WasmContext context, WasmModule module, ImportDescriptor importDescriptor, int initSize, int maxSize, Consumer<WasmMemory> setMemory) {
String importedModuleName = importDescriptor.moduleName;
String importedMemoryName = importDescriptor.memberName;
Expand Down Expand Up @@ -277,15 +309,15 @@ void resolveMemoryImport(WasmContext context, WasmModule module, ImportDescripto
}
setMemory.accept(memory);
};
resolutionDag.resolveLater(new ImportMemoryDecl(module.name(), importDescriptor), new Decl[]{new ExportMemoryDecl(importedModuleName, importedMemoryName)}, resolveAction);
resolutionDag.resolveLater(new ImportMemorySym(module.name(), importDescriptor), new Sym[]{new ExportMemorySym(importedModuleName, importedMemoryName)}, resolveAction);
}

void resolveMemoryExport(WasmModule module, String exportedMemoryName) {
final Runnable resolveAction = () -> {
};
final ImportDescriptor importDescriptor = module.symbolTable().importedMemory();
final Decl[] dependencies = importDescriptor != null ? new Decl[]{new ImportMemoryDecl(module.name(), importDescriptor)} : new Decl[0];
resolutionDag.resolveLater(new ExportMemoryDecl(module.name(), exportedMemoryName), dependencies, resolveAction);
final Sym[] dependencies = importDescriptor != null ? new Sym[]{new ImportMemorySym(module.name(), importDescriptor)} : new Sym[0];
resolutionDag.resolveLater(new ExportMemorySym(module.name(), exportedMemoryName), dependencies, resolveAction);
}

void resolveDataSection(WasmModule module, int dataSectionId, long baseAddress, int byteLength, byte[] data, boolean priorDataSectionsResolved) {
Expand All @@ -299,20 +331,78 @@ void resolveDataSection(WasmModule module, int dataSectionId, long baseAddress,
memory.store_i32_8(baseAddress + writeOffset, b);
}
};
final ImportMemoryDecl importMemoryDecl = new ImportMemoryDecl(module.name(), module.symbolTable().importedMemory());
final Decl[] dependencies = priorDataSectionsResolved ? new Decl[]{importMemoryDecl} : new Decl[]{importMemoryDecl, new DataDecl(module.name(), dataSectionId - 1)};
resolutionDag.resolveLater(new DataDecl(module.name(), dataSectionId), dependencies, resolveAction);
final ImportMemorySym importMemoryDecl = new ImportMemorySym(module.name(), module.symbolTable().importedMemory());
final Sym[] dependencies = priorDataSectionsResolved ? new Sym[]{importMemoryDecl} : new Sym[]{importMemoryDecl, new DataSym(module.name(), dataSectionId - 1)};
resolutionDag.resolveLater(new DataSym(module.name(), dataSectionId), dependencies, resolveAction);
}

static class ResolutionDag {
abstract static class Decl {
abstract static class Sym {
}

static class ImportFunctionSym extends Sym {
final String moduleName;
final ImportDescriptor importDescriptor;

ImportFunctionSym(String moduleName, ImportDescriptor importDescriptor) {
this.moduleName = moduleName;
this.importDescriptor = importDescriptor;
}

@Override
public String toString() {
return String.format("(import %s from %s into %s)", importDescriptor.memberName, importDescriptor.moduleName, moduleName);
}

@Override
public int hashCode() {
return moduleName.hashCode() ^ importDescriptor.hashCode();
}

@Override
public boolean equals(Object object) {
if (!(object instanceof ImportFunctionSym)) {
return false;
}
final ImportFunctionSym that = (ImportFunctionSym) object;
return this.moduleName.equals(that.moduleName) && this.importDescriptor.equals(that.importDescriptor);
}
}

static class ExportFunctionSym extends Sym {
final String moduleName;
final String memoryName;

ExportFunctionSym(String moduleName, String memoryName) {
this.moduleName = moduleName;
this.memoryName = memoryName;
}

@Override
public String toString() {
return String.format("(export %s from %s)", memoryName, moduleName);
}

@Override
public int hashCode() {
return moduleName.hashCode() ^ memoryName.hashCode();
}

@Override
public boolean equals(Object object) {
if (!(object instanceof ExportFunctionSym)) {
return false;
}
final ExportFunctionSym that = (ExportFunctionSym) object;
return this.moduleName.equals(that.moduleName) && this.memoryName.equals(that.memoryName);
}
}

static class ImportMemoryDecl extends Decl {
static class ImportMemorySym extends Sym {
final String moduleName;
final ImportDescriptor importDescriptor;

ImportMemoryDecl(String moduleName, ImportDescriptor importDescriptor) {
ImportMemorySym(String moduleName, ImportDescriptor importDescriptor) {
this.moduleName = moduleName;
this.importDescriptor = importDescriptor;
}
Expand All @@ -329,19 +419,19 @@ public int hashCode() {

@Override
public boolean equals(Object object) {
if (!(object instanceof ImportMemoryDecl)) {
if (!(object instanceof ImportMemorySym)) {
return false;
}
final ImportMemoryDecl that = (ImportMemoryDecl) object;
final ImportMemorySym that = (ImportMemorySym) object;
return this.moduleName.equals(that.moduleName) && this.importDescriptor.equals(that.importDescriptor);
}
}

static class ExportMemoryDecl extends Decl {
static class ExportMemorySym extends Sym {
final String moduleName;
final String memoryName;

ExportMemoryDecl(String moduleName, String memoryName) {
ExportMemorySym(String moduleName, String memoryName) {
this.moduleName = moduleName;
this.memoryName = memoryName;
}
Expand All @@ -358,19 +448,19 @@ public int hashCode() {

@Override
public boolean equals(Object object) {
if (!(object instanceof ExportMemoryDecl)) {
if (!(object instanceof ExportMemorySym)) {
return false;
}
final ExportMemoryDecl that = (ExportMemoryDecl) object;
final ExportMemorySym that = (ExportMemorySym) object;
return this.moduleName.equals(that.moduleName) && this.memoryName.equals(that.memoryName);
}
}

static class DataDecl extends Decl {
static class DataSym extends Sym {
final String moduleName;
final int dataSectionId;

DataDecl(String moduleName, int dataSectionId) {
DataSym(String moduleName, int dataSectionId) {
this.moduleName = moduleName;
this.dataSectionId = dataSectionId;
}
Expand All @@ -387,20 +477,20 @@ public int hashCode() {

@Override
public boolean equals(Object object) {
if (!(object instanceof DataDecl)) {
if (!(object instanceof DataSym)) {
return false;
}
final DataDecl that = (DataDecl) object;
final DataSym that = (DataSym) object;
return this.dataSectionId == that.dataSectionId && this.moduleName.equals(that.moduleName);
}
}

static class Resolver {
final Decl element;
final Decl[] dependencies;
final Sym element;
final Sym[] dependencies;
final Runnable action;

Resolver(Decl element, Decl[] dependencies, Runnable action) {
Resolver(Sym element, Sym[] dependencies, Runnable action) {
this.element = element;
this.dependencies = dependencies;
this.action = action;
Expand All @@ -412,34 +502,34 @@ public String toString() {
}
}

private final Map<Decl, Resolver> resolutions;
private final Map<Sym, Resolver> resolutions;

ResolutionDag() {
this.resolutions = new HashMap<>();
}

void resolveLater(Decl element, Decl[] dependencies, Runnable action) {
void resolveLater(Sym element, Sym[] dependencies, Runnable action) {
resolutions.put(element, new Resolver(element, dependencies, action));
}

void clear() {
resolutions.clear();
}

private static String renderCycle(List<Decl> stack) {
private static String renderCycle(List<Sym> stack) {
StringBuilder result = new StringBuilder();
String arrow = "";
for (Decl decl : stack) {
result.append(arrow).append(decl.toString());
for (Sym sym : stack) {
result.append(arrow).append(sym.toString());
arrow = " -> ";
}
return result.toString();
}

private void toposort(Decl decl, Map<Decl, Boolean> marks, ArrayList<Resolver> sorted, List<Decl> stack) {
final Resolver resolver = resolutions.get(decl);
private void toposort(Sym sym, Map<Sym, Boolean> marks, ArrayList<Resolver> sorted, List<Sym> stack) {
final Resolver resolver = resolutions.get(sym);
if (resolver != null) {
final Boolean mark = marks.get(decl);
final Boolean mark = marks.get(sym);
if (Boolean.TRUE.equals(mark)) {
// This node was already sorted.
return;
Expand All @@ -449,22 +539,22 @@ private void toposort(Decl decl, Map<Decl, Boolean> marks, ArrayList<Resolver> s
throw new WasmLinkerException(String.format("Detected a cycle in the import dependencies: %s",
renderCycle(stack)));
}
marks.put(decl, Boolean.FALSE);
stack.add(decl);
for (Decl dependency : resolver.dependencies) {
marks.put(sym, Boolean.FALSE);
stack.add(sym);
for (Sym dependency : resolver.dependencies) {
toposort(dependency, marks, sorted, stack);
}
marks.put(decl, Boolean.TRUE);
marks.put(sym, Boolean.TRUE);
stack.remove(stack.size() - 1);
sorted.add(resolver);
}
}

Resolver[] toposort() {
Map<Decl, Boolean> marks = new HashMap<>();
Map<Sym, Boolean> marks = new HashMap<>();
ArrayList<Resolver> sorted = new ArrayList<>();
for (Decl decl : resolutions.keySet()) {
toposort(decl, marks, sorted, new ArrayList<>());
for (Sym sym : resolutions.keySet()) {
toposort(sym, marks, sorted, new ArrayList<>());
}
return sorted.toArray(new Resolver[sorted.size()]);
}
Expand Down
15 changes: 9 additions & 6 deletions wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,10 @@ WasmFunction declareFunction(int typeIndex) {
return function;
}

public WasmFunction declareExportedFunction(int typeIndex, String exportedName) {
public WasmFunction declareExportedFunction(WasmContext context, int typeIndex, String exportedName) {
checkNotLinked();
final WasmFunction function = declareFunction(typeIndex);
exportFunction(exportedName, function.index());
exportFunction(context, function.index(), exportedName);
return function;
}

Expand Down Expand Up @@ -421,20 +421,23 @@ ByteArrayList functionTypeArgumentTypes(int typeIndex) {
return types;
}

void exportFunction(String exportName, int functionIndex) {
void exportFunction(WasmContext context, int functionIndex, String exportName) {
checkNotLinked();
exportedFunctions.put(exportName, functions[functionIndex]);
exportedFunctionsByIndex.put(functionIndex, exportName);
context.linker().resolveFunctionExport(module, functionIndex, exportName);
}

Map<String, WasmFunction> exportedFunctions() {
return exportedFunctions;
}

WasmFunction importFunction(String moduleName, String functionName, int typeIndex) {
WasmFunction importFunction(WasmContext context, String moduleName, String functionName, int typeIndex) {
checkNotLinked();
WasmFunction function = allocateFunction(typeIndex, new ImportDescriptor(moduleName, functionName));
final ImportDescriptor importDescriptor = new ImportDescriptor(moduleName, functionName);
WasmFunction function = allocateFunction(typeIndex, importDescriptor);
importedFunctions.add(function);
context.linker().resolveFunctionImport(context, module, function);
return function;
}

Expand Down Expand Up @@ -681,8 +684,8 @@ public void exportMemory(WasmContext context, String name) {
if (!memoryExists()) {
throw new WasmException("No memory has been declared or imported, so memory cannot be exported.");
}
context.linker().resolveMemoryExport(module, name);
exportedMemory = name;
context.linker().resolveMemoryExport(module, name);
}

public WasmMemory memory() {
Expand Down
Loading

0 comments on commit 4d41bf4

Please sign in to comment.