Skip to content

Commit c57c9d3

Browse files
andrrosscwperks
andauthored
Limit stack walking to frames before AccessController.doPrivileged (#18029)
Signed-off-by: Craig Perkins <cwperx@amazon.com> Signed-off-by: Andrew Ross <andrross@amazon.com> Co-authored-by: Craig Perkins <cwperx@amazon.com>
1 parent d18982c commit c57c9d3

File tree

3 files changed

+124
-4
lines changed

3 files changed

+124
-4
lines changed

libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/SocketChannelInterceptor.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import org.opensearch.javaagent.bootstrap.AgentPolicy;
1212

13-
import java.lang.StackWalker.Option;
1413
import java.lang.reflect.Method;
1514
import java.net.InetSocketAddress;
1615
import java.net.NetPermission;
@@ -46,7 +45,7 @@ public static void intercept(@Advice.AllArguments Object[] args, @Origin Method
4645
return; /* noop */
4746
}
4847

49-
final StackWalker walker = StackWalker.getInstance(Option.RETAIN_CLASS_REFERENCE);
48+
final StackWalker walker = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE);
5049
final Collection<ProtectionDomain> callers = walker.walk(StackCallerProtectionDomainChainExtractor.INSTANCE);
5150

5251
if (args[0] instanceof InetSocketAddress address) {

libs/agent-sm/agent/src/main/java/org/opensearch/javaagent/StackCallerProtectionDomainChainExtractor.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,12 @@ private StackCallerProtectionDomainChainExtractor() {}
3535
*/
3636
@Override
3737
public Collection<ProtectionDomain> apply(Stream<StackFrame> frames) {
38-
return frames.map(StackFrame::getDeclaringClass)
38+
return frames.takeWhile(
39+
frame -> !(frame.getClassName().equals("java.security.AccessController") && frame.getMethodName().equals("doPrivileged"))
40+
)
41+
.map(StackFrame::getDeclaringClass)
3942
.map(Class::getProtectionDomain)
40-
.filter(pd -> pd.getCodeSource() != null) /* JDK */
43+
.filter(pd -> pd.getCodeSource() != null) // Filter out JDK classes
4144
.collect(Collectors.toSet());
4245
}
4346
}
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.javaagent;
10+
11+
import org.junit.Test;
12+
13+
import java.net.URI;
14+
import java.net.URISyntaxException;
15+
import java.nio.file.Path;
16+
import java.nio.file.Paths;
17+
import java.security.AccessController;
18+
import java.security.PrivilegedAction;
19+
import java.security.ProtectionDomain;
20+
import java.util.List;
21+
import java.util.Set;
22+
import java.util.stream.Collectors;
23+
24+
import static org.hamcrest.MatcherAssert.assertThat;
25+
import static org.hamcrest.Matchers.containsInAnyOrder;
26+
import static org.hamcrest.Matchers.hasItem;
27+
import static org.junit.Assert.assertEquals;
28+
29+
public class StackCallerProtectionDomainExtractorTests {
30+
31+
private static List<StackWalker.StackFrame> indirectlyCaptureStackFrames() {
32+
return captureStackFrames();
33+
}
34+
35+
private static List<StackWalker.StackFrame> captureStackFrames() {
36+
// OPTION.RETAIN_CLASS_REFERENCE lets you do f.getDeclaringClass() if you need it
37+
StackWalker walker = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE);
38+
return walker.walk(frames -> frames.collect(Collectors.toList()));
39+
}
40+
41+
@Test
42+
public void testSimpleProtectionDomainExtraction() throws Exception {
43+
StackCallerProtectionDomainChainExtractor extractor = StackCallerProtectionDomainChainExtractor.INSTANCE;
44+
Set<ProtectionDomain> protectionDomains = (Set<ProtectionDomain>) extractor.apply(captureStackFrames().stream());
45+
assertEquals(7, protectionDomains.size());
46+
List<String> simpleNames = protectionDomains.stream().map(pd -> {
47+
try {
48+
return pd.getCodeSource().getLocation().toURI();
49+
} catch (URISyntaxException e) {
50+
throw new RuntimeException(e);
51+
}
52+
})
53+
.map(URI::getPath)
54+
.map(Paths::get)
55+
.map(Path::getFileName)
56+
.map(Path::toString)
57+
// strip trailing “-VERSION.jar” if present
58+
.map(name -> name.replaceFirst("-\\d[\\d\\.]*\\.jar$", ""))
59+
// otherwise strip “.jar”
60+
.map(name -> name.replaceFirst("\\.jar$", ""))
61+
.toList();
62+
assertThat(
63+
simpleNames,
64+
containsInAnyOrder(
65+
"gradle-worker",
66+
"gradle-worker-main",
67+
"gradle-messaging",
68+
"gradle-testing-base-infrastructure",
69+
"test", // from the build/classes/java/test directory
70+
"junit",
71+
"gradle-testing-jvm-infrastructure"
72+
)
73+
);
74+
}
75+
76+
@Test
77+
public void testIndirectlyCaptureStackFramesInListOfFrames() throws Exception {
78+
List<StackWalker.StackFrame> stackFrames = indirectlyCaptureStackFrames();
79+
List<String> methodNames = stackFrames.stream().map(StackWalker.StackFrame::getMethodName).toList();
80+
assertThat(methodNames, hasItem("indirectlyCaptureStackFrames"));
81+
}
82+
83+
@Test
84+
@SuppressWarnings("removal")
85+
public void testStackTruncationWithAccessController() throws Exception {
86+
AccessController.doPrivileged(new PrivilegedAction<Void>() {
87+
@Override
88+
public Void run() {
89+
StackCallerProtectionDomainChainExtractor extractor = StackCallerProtectionDomainChainExtractor.INSTANCE;
90+
Set<ProtectionDomain> protectionDomains = (Set<ProtectionDomain>) extractor.apply(captureStackFrames().stream());
91+
assertEquals(1, protectionDomains.size());
92+
List<String> simpleNames = protectionDomains.stream().map(pd -> {
93+
try {
94+
return pd.getCodeSource().getLocation().toURI();
95+
} catch (URISyntaxException e) {
96+
throw new RuntimeException(e);
97+
}
98+
})
99+
.map(URI::getPath)
100+
.map(Paths::get)
101+
.map(Path::getFileName)
102+
.map(Path::toString)
103+
// strip trailing “-VERSION.jar” if present
104+
.map(name -> name.replaceFirst("-\\d[\\d\\.]*\\.jar$", ""))
105+
// otherwise strip “.jar”
106+
.map(name -> name.replaceFirst("\\.jar$", ""))
107+
.toList();
108+
assertThat(
109+
simpleNames,
110+
containsInAnyOrder(
111+
"test" // from the build/classes/java/test directory
112+
)
113+
);
114+
return null;
115+
}
116+
});
117+
}
118+
}

0 commit comments

Comments
 (0)