Skip to content

Commit

Permalink
feat(Java reachability): Detect classes that use dynamic code loading (
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverchang authored Jan 17, 2025
1 parent ec5f69c commit 44f1371
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 16 deletions.
109 changes: 93 additions & 16 deletions experimental/javareach/cmd/reachable/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,16 @@ import (
"slices"
"strings"

"github.com/google/osv-scalibr/extractor"
"github.com/google/osv-scalibr/extractor/filesystem/language/java/archive"
"github.com/google/osv-scanner/experimental/javareach"
)

type ReachabilityResult struct {
Classes []string
UsesDynamicCodeLoading []string
}

// Usage:
//
// go run ./cmd/reachable -classpath=<classpath> path/to/root/class
Expand All @@ -25,7 +31,7 @@ import (
// This is unlike classpaths supported by Java runtimes (which supports
// specifying multiple directories and .jar files)
//
// TODO: Support non-uber jars by downloading dependencie on demand from registries. This requires
// TODO: Support non-uber jars by downloading dependencies on demand from registries. This requires
// a reliable index of class -> Maven jar mappings for the entire Maven universe.
func main() {
classPath := flag.String("classpath", "", "A single directory containing Java class files with a directory structure that mirrors the package hierarchy.")
Expand All @@ -52,35 +58,49 @@ func main() {
os.Exit(1)
}

classes, err := EnumerateReachabilityFromClass(arg, *classPath)
result, err := EnumerateReachabilityFromClass(arg, *classPath)
if err != nil {
slog.Error("Failed to enumerate reachability for", "class", arg, "error", err)
os.Exit(1)
}

for _, class := range classes {
for _, class := range result.Classes {
slog.Info("Reachable", "class", class)
}
}
}
}

func fmtJavaInventory(i *extractor.Inventory) string {
return fmt.Sprintf("%s:%s", i.Metadata.(*archive.Metadata).GroupID, i.Name)
}

func enumerateReachabilityForJar(jarPath string) error {
jarfile, err := os.Open(jarPath)
if err != nil {
return err
}

// Extract dependencies from the .jar (from META-INF/maven/**/pom.properties)
allDeps, err := javareach.ExtractDependencies(jarfile)
if err != nil {
return err
}
slices.SortFunc(allDeps, func(i1 *extractor.Inventory, i2 *extractor.Inventory) int {
return strings.Compare(fmtJavaInventory(i1), fmtJavaInventory(i2))
})
for _, dep := range allDeps {
slog.Debug("extracted dep",
"group id", dep.Metadata.(*archive.Metadata).GroupID, "artifact id", dep.Name, "version", dep.Version)
}

// Build .class -> Maven group ID:artifact ID mappings.
classFinder, err := javareach.NewDefaultPackageFinder(allDeps)
if err != nil {
return err
}

// Unpack .jar
tmpDir, err := os.MkdirTemp("", "")
if err != nil {
return err
Expand All @@ -93,6 +113,7 @@ func enumerateReachabilityForJar(jarPath string) error {
return err
}

// Extract the main entrypoint.
manifest, err := os.Open(filepath.Join(tmpDir, "META-INF/MANIFEST.MF"))
if err != nil {
return err
Expand All @@ -103,16 +124,19 @@ func enumerateReachabilityForJar(jarPath string) error {
return err
}
slog.Info("Found", "main class", mainClass)
classes, err := EnumerateReachabilityFromClass(mainClass, tmpDir)

// Enumerate reachable classes.
result, err := EnumerateReachabilityFromClass(mainClass, tmpDir)
if err != nil {
return err
}

// Map reachable classes back to Maven group ID:artifact ID.
reachableDeps := map[string]struct{}{}
for _, class := range classes {
for _, class := range result.Classes {
deps, err := classFinder.Find(class)
if err != nil {
slog.Error("Failed to find", "class", class, "error", err)
slog.Error("Failed to find dep mapping", "class", class, "error", err)
continue
}

Expand All @@ -121,20 +145,39 @@ func enumerateReachabilityForJar(jarPath string) error {
}
}

for dep := range reachableDeps {
slog.Info("Reachable", "dep", dep)
// Find Maven deps that use dynamic code loading.
// TODO: consider all declared dependencies of the Maven dependency to be
// reachable. We can find this within uber jars via the META-INF/maven
// directory by parsing pom.xml files, or by querying deps.dev / Maven.
dynamicLoadingDeps := map[string]struct{}{}
slices.Sort(result.UsesDynamicCodeLoading)
for _, class := range result.UsesDynamicCodeLoading {
slog.Info("Found use of dynamic code loading", "class", class)
deps, err := classFinder.Find(class)
if err != nil {
slog.Error("Failed to find dep mapping", "class", class, "error", err)
continue
}
for _, dep := range deps {
dynamicLoadingDeps[dep] = struct{}{}
}
}

for _, dep := range slices.Sorted(maps.Keys(reachableDeps)) {
_, dynamicLoading := dynamicLoadingDeps[dep]
slog.Info("Reachable", "dep", dep, "dynamic code", dynamicLoading)
}

for _, dep := range allDeps {
name := fmt.Sprintf("%s:%s", dep.Metadata.(*archive.Metadata).GroupID, dep.Name)
name := fmtJavaInventory(dep)
if _, ok := reachableDeps[name]; !ok {
slog.Info("Not reachable", "dep", name)
}
}
return nil
}

func EnumerateReachabilityFromClass(mainClass string, classPath string) ([]string, error) {
func EnumerateReachabilityFromClass(mainClass string, classPath string) (*ReachabilityResult, error) {
cf, err := findClass(classPath, mainClass)
if err != nil {
return nil, err
Expand All @@ -160,21 +203,38 @@ func findClass(classPath string, className string) (*javareach.ClassFile, error)
}

// TODO:
// - Detect uses of reflection and dynamic class loading -> Consider all dependencies used.
// - See if we should do a finer grained analysis to only consider referenced
// classes where a method is called/referenced.
func EnumerateReachability(roots []*javareach.ClassFile, classPath string) ([]string, error) {
func EnumerateReachability(roots []*javareach.ClassFile, classPath string) (*ReachabilityResult, error) {
seen := map[string]struct{}{}
codeLoading := map[string]struct{}{}
for _, root := range roots {
if err := enumerateReachability(root, classPath, seen); err != nil {
if err := enumerateReachability(root, classPath, seen, codeLoading); err != nil {
return nil, err
}
}

return slices.Collect(maps.Keys(seen)), nil
return &ReachabilityResult{
Classes: slices.Collect(maps.Keys(seen)),
UsesDynamicCodeLoading: slices.Collect(maps.Keys(codeLoading)),
}, nil
}

func enumerateReachability(cf *javareach.ClassFile, classPath string, seen map[string]struct{}) error {
func isDynamicCodeLoading(method string, descriptor string) bool {
// https://docs.oracle.com/en/java/javase/23/docs/api/java.base/java/lang/ClassLoader.html#loadClass(java.lang.String)
if strings.Contains(method, "loadClass") && strings.HasSuffix(descriptor, "Ljava/lang/Class;") {
return true
}

// https://docs.oracle.com/en/java/javase/23/docs/api/java.base/java/lang/Class.html#forName(java.lang.String)
if strings.Contains(method, "forName") && strings.HasSuffix(descriptor, "Ljava/lang/Class;") {
return true
}

return false
}

func enumerateReachability(cf *javareach.ClassFile, classPath string, seen map[string]struct{}, codeLoading map[string]struct{}) error {
thisClass, err := cf.ConstantPoolClass(int(cf.ThisClass))
if err != nil {
return err
Expand All @@ -186,6 +246,23 @@ func enumerateReachability(cf *javareach.ClassFile, classPath string, seen map[s
slog.Debug("Analyzing", "class", thisClass)
seen[thisClass] = struct{}{}

for i, cp := range cf.ConstantPool {
if cp.Type() != javareach.ConstantKindMethodref {
continue
}

_, method, descriptor, err := cf.ConstantPoolMethodref(i)
if err != nil {
return err
}

if isDynamicCodeLoading(method, descriptor) {
slog.Debug("found dynamic class loading", "thisClass", thisClass, "method", method, "descriptor", descriptor)
codeLoading[thisClass] = struct{}{}
break
}
}

for i, cp := range cf.ConstantPool {
if int(cf.ThisClass) == i {
// Don't consider this class itself.
Expand Down Expand Up @@ -233,7 +310,7 @@ func enumerateReachability(cf *javareach.ClassFile, classPath string, seen map[s
continue
}

if err := enumerateReachability(depcf, classPath, seen); err != nil {
if err := enumerateReachability(depcf, classPath, seen, codeLoading); err != nil {
return err
}
}
Expand Down
35 changes: 35 additions & 0 deletions experimental/javareach/javaclass.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,41 @@ func (cf *ClassFile) checkIndex(idx int) error {
return nil
}

func (cf *ClassFile) ConstantPoolMethodref(idx int) (class string, method string, descriptor string, err error) {
err = cf.checkIndex(idx)
if err != nil {
return
}

if cf.ConstantPool[idx].Type() != ConstantKindMethodref {
err = errors.New("constant pool idx does not point to a method ref")
return
}

methodRef := cf.ConstantPool[idx].(*ConstantMethodref)
class, err = cf.ConstantPoolClass(int(methodRef.ClassIndex))
if err != nil {
return
}

err = cf.checkIndex(int(methodRef.NameAndTypeIndex))
if err != nil {
return
}

nameAndType, ok := cf.ConstantPool[methodRef.NameAndTypeIndex].(*ConstantNameAndType)
if !ok {
err = errors.New("invalid constant name and type")
return
}
method, err = cf.ConstantPoolUtf8(int(nameAndType.NameIndex))
if err != nil {
return
}
descriptor, err = cf.ConstantPoolUtf8(int(nameAndType.DescriptorIndex))
return
}

func (cf *ClassFile) ConstantPoolClass(idx int) (string, error) {
if err := cf.checkIndex(idx); err != nil {
return "", err
Expand Down

0 comments on commit 44f1371

Please sign in to comment.