Skip to content

Commit

Permalink
[GR-19378] Added unit test.
Browse files Browse the repository at this point in the history
  • Loading branch information
tzezula committed Nov 6, 2019
1 parent e7f079c commit 0a2d1cc
Showing 1 changed file with 188 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,37 @@
package com.oracle.truffle.api.test.polyglot;

import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URISyntaxException;
import java.net.JarURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLConnection;
import java.net.URLStreamHandler;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.CodeSource;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.jar.JarFile;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;

import org.junit.Assert;
import org.junit.Assume;
import org.junit.Test;

import com.oracle.truffle.api.TruffleLanguage;
Expand All @@ -66,22 +80,52 @@ public class LanguageCacheTest {

@Test
public void testDuplicateLanguageIds() throws Throwable {
ClassLoader testClassLoader = new TestClassLoader();
CodeSource codeSource = LanguageCacheTest.class.getProtectionDomain().getCodeSource();
Assume.assumeNotNull(codeSource);
Path location = Paths.get(codeSource.getLocation().toURI());
Function<String, List<URL>> loader = new Function<String, List<URL>>() {
@Override
public List<URL> apply(String binaryName) {
try {
if (Files.isRegularFile(location)) {
return Collections.singletonList(new URL("jar:" + location.toUri().toString() + "!/" + binaryName));
} else {
return Collections.singletonList(new URL(location.toUri().toString() + binaryName));
}
} catch (MalformedURLException e) {
throw new RuntimeException(e);
}
}
};
ClassLoader testClassLoader = new TestClassLoader(loader);
try {
invokeLanguageCacheCreateLanguages(testClassLoader);
invokeLanguageCacheCreateLanguages(LanguageCacheTest.class.getClassLoader(), testClassLoader);
Assert.fail("Expected IllegalStateException");
} catch (IllegalStateException ise) {
// Expected exception
}
}

@Test
public void testNestedArchives() throws Throwable {
CodeSource codeSource = LanguageCacheTest.class.getProtectionDomain().getCodeSource();
Assume.assumeNotNull(codeSource);
URL location = codeSource.getLocation();
Path source = Paths.get(location.toURI());
Assume.assumeTrue(Files.isRegularFile(source));
try (NestedJarLoader loader = new NestedJarLoader(source, location + "!/inner.jar!/")) {
ClassLoader testClassLoader = new TestClassLoader(loader);
invokeLanguageCacheCreateLanguages(testClassLoader);
}
}

@SuppressWarnings("unchecked")
private static Map<String, Object> invokeLanguageCacheCreateLanguages(ClassLoader loader) throws Throwable {
private static Map<String, Object> invokeLanguageCacheCreateLanguages(ClassLoader... loaders) throws Throwable {
try {
final Class<?> langCacheClz = Class.forName("com.oracle.truffle.polyglot.LanguageCache", true, LanguageCacheTest.class.getClassLoader());
final Method createLanguages = langCacheClz.getDeclaredMethod("createLanguages", List.class);
createLanguages.setAccessible(true);
return (Map<String, Object>) createLanguages.invoke(null, Arrays.asList(LanguageCacheTest.class.getClassLoader(), loader));
return (Map<String, Object>) createLanguages.invoke(null, Arrays.asList(loaders));
} catch (InvocationTargetException ite) {
throw ite.getCause();
} catch (ReflectiveOperationException re) {
Expand Down Expand Up @@ -111,20 +155,23 @@ protected boolean isObjectOfLanguage(Object object) {
*/
private static final class TestClassLoader extends ClassLoader {

private static final Set<String> IMPORTANT_CLASSES;
private static final Set<String> IMPORTANT_RESOURCES;
static {
IMPORTANT_CLASSES = new HashSet<>();
IMPORTANT_CLASSES.add(DuplicateIdLanguage.class.getName());
IMPORTANT_CLASSES.add(LanguageCacheTestDuplicateIdLanguageProvider.class.getName());
IMPORTANT_RESOURCES = new HashSet<>();
IMPORTANT_RESOURCES.add(binaryName(DuplicateIdLanguage.class.getName()) + ".class");
IMPORTANT_RESOURCES.add(binaryName(LanguageCacheTestDuplicateIdLanguageProvider.class.getName()) + ".class");
}

TestClassLoader() {
private final Function<String, List<URL>> loader;

TestClassLoader(Function<String, List<URL>> loader) {
super(TestClassLoader.class.getClassLoader());
this.loader = loader;
}

@Override
protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
if (!IMPORTANT_CLASSES.contains(name)) {
if (!IMPORTANT_RESOURCES.contains(binaryName(name) + ".class")) {
return super.loadClass(name, resolve);
} else {
synchronized (getClassLoadingLock(name)) {
Expand All @@ -142,28 +189,71 @@ protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundE

@Override
protected Class<?> findClass(String name) throws ClassNotFoundException {
if (!IMPORTANT_CLASSES.contains(name)) {
throw new IllegalArgumentException("Only " + String.join(", ", IMPORTANT_CLASSES) + " can be loaded.");
String filePath = binaryName(name) + ".class";
if (!IMPORTANT_RESOURCES.contains(filePath)) {
throw new IllegalArgumentException("Only " + String.join(", ", IMPORTANT_RESOURCES) + " can be loaded.");
}
try {
URL location = DuplicateIdLanguage.class.getProtectionDomain().getCodeSource().getLocation();
Path path = Paths.get(location.toURI());
if (Files.isRegularFile(path)) {
location = new URL("jar:" + location.toExternalForm() + "!/" + binaryName(name) + ".class");
} else {
location = new URL(location.toExternalForm() + binaryName(name) + ".class");
URL location = findResource(filePath);
if (location == null) {
throw new ClassNotFoundException("Cannot load class: " + name);
}
try (InputStream in = location.openStream(); ByteArrayOutputStream out = new ByteArrayOutputStream()) {
copy(in, out);
byte[] content = out.toByteArray();
definePackage(name);
return defineClass(name, content, 0, content.length);
}
} catch (URISyntaxException | IOException e) {
} catch (IOException e) {
throw new ClassNotFoundException("Cannot load class: " + name, e);
}
}

@Override
public URL getResource(String name) {
if (!IMPORTANT_RESOURCES.contains(name)) {
return super.getResource(name);
} else {
URL url = findResource(name);
return url != null ? url : getParent().getResource(name);
}
}

@Override
public Enumeration<URL> getResources(String name) throws IOException {
if (!IMPORTANT_RESOURCES.contains(name)) {
return super.getResources(name);
} else {
Enumeration<URL> e1 = findResources(name);
Enumeration<URL> e2 = getParent().getResources(name);
List<URL> result = new ArrayList<>();
addAll(result, e1);
addAll(result, e2);
return Collections.enumeration(result);
}
}

@Override
protected URL findResource(String name) {
try {
Enumeration<URL> e = findResources(name);
return e.hasMoreElements() ? e.nextElement() : null;
} catch (IOException ioe) {
return null;
}
}

@Override
protected Enumeration<URL> findResources(String name) throws IOException {
return Collections.enumeration(loader.apply(name));
}

private static <T> void addAll(Collection<? super T> dest, Enumeration<? extends T> src) {
while (src.hasMoreElements()) {
dest.add(src.nextElement());
}
}

@SuppressWarnings("deprecation")
private void definePackage(String className) {
String packageName = getPackageName(className);
Expand Down Expand Up @@ -192,4 +282,82 @@ private static String binaryName(String name) {
return name.replace(".", "/");
}
}

/**
* Simulates a jar file inside a container jar (war, ear) file.
*/
private static final class NestedJarLoader implements Function<String, List<URL>>, Closeable {

private final ZipFile zipFile;
private final String relocation;

private NestedJarLoader(Path delegate, String relocation) throws IOException {
if (!relocation.endsWith("!/")) {
throw new IllegalArgumentException("Relocation must point into an archive file.");
}
this.zipFile = new ZipFile(delegate.toFile());
this.relocation = relocation;
}

@Override
public List<URL> apply(String binaryName) {
String entryName = binaryName.charAt(0) == '/' ? binaryName.substring(1) : binaryName;
ZipEntry e = zipFile.getEntry(entryName);
if (e != null) {
try {
URL url = new URL("jar", null, -1, relocation + binaryName, new NestedJarURLStreamHandler(zipFile, e));
return Collections.singletonList(url);
} catch (MalformedURLException murl) {
throw new RuntimeException(murl);
}
}
return Collections.emptyList();
}

@Override
public void close() throws IOException {
zipFile.close();
}

private static final class NestedJarURLStreamHandler extends URLStreamHandler {
private final ZipFile zipFile;
private final ZipEntry entry;

NestedJarURLStreamHandler(ZipFile zipFile, ZipEntry entry) {
this.zipFile = zipFile;
this.entry = entry;
}

@Override
protected URLConnection openConnection(URL u) throws IOException {
return new JarURLConnection(u) {

@Override
public JarFile getJarFile() throws IOException {
throw new UnsupportedOperationException("Not supported.");
}

@Override
public URL getJarFileURL() {
try {
String surl = u.toString();
int index = surl.lastIndexOf("!/");
return new URL(surl.substring(0, index));
} catch (MalformedURLException mue) {
throw new IllegalArgumentException(mue);
}
}

@Override
public InputStream getInputStream() throws IOException {
return zipFile.getInputStream(entry);
}

@Override
public void connect() throws IOException {
}
};
}
}
}
}

0 comments on commit 0a2d1cc

Please sign in to comment.