Skip to content

Commit 7c2a978

Browse files
committed
Merge remote-tracking branch 'upstream/master'
2 parents 34d235b + 13b1261 commit 7c2a978

File tree

10 files changed

+166
-57
lines changed

10 files changed

+166
-57
lines changed

src/cmd/compile/internal/devirtualize/pgo.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ func findHotConcreteCallee(p *pgoir.Profile, caller *ir.Func, call *ir.CallExpr,
741741
hottest = e
742742
}
743743

744-
if hottest == nil {
744+
if hottest == nil || hottest.Weight == 0 {
745745
if base.Debug.PGODebug >= 2 {
746746
fmt.Printf("%v: call %s:%d: no hot callee\n", ir.Line(call), callerName, callOffset)
747747
}

src/cmd/compile/internal/inline/inl.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ func inlineCallCheck(callerfn *ir.Func, call *ir.CallExpr) (bool, bool) {
786786
if call.Op() != ir.OCALLFUNC {
787787
return false, false
788788
}
789-
if call.GoDefer || call.NoInline {
789+
if call.GoDefer {
790790
return false, false
791791
}
792792

src/cmd/compile/internal/ir/expr.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ type CallExpr struct {
191191
KeepAlive []*Name // vars to be kept alive until call returns
192192
IsDDD bool
193193
GoDefer bool // whether this call is part of a go or defer statement
194-
NoInline bool // whether this call must not be inlined
195194
}
196195

197196
func NewCallExpr(pos src.XPos, op Op, fun Node, args []Node) *CallExpr {

src/cmd/compile/internal/ir/scc.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ package ir
2222
// Second, each function becomes two virtual nodes in the graph,
2323
// with numbers n and n+1. We record the function's node number as n
2424
// but search from node n+1. If the search tells us that the component
25-
// number (min) is n+1, we know that this is a trivial component: one function
25+
// number (minVisitGen) is n+1, we know that this is a trivial component: one function
2626
// plus its closures. If the search tells us that the component number is
2727
// n, then there was a path from node n+1 back to node n, meaning that
2828
// the function set is mutually recursive. The escape analysis can be
@@ -70,13 +70,13 @@ func (v *bottomUpVisitor) visit(n *Func) uint32 {
7070
id := v.visitgen
7171
v.nodeID[n] = id
7272
v.visitgen++
73-
min := v.visitgen
73+
minVisitGen := v.visitgen
7474
v.stack = append(v.stack, n)
7575

7676
do := func(defn Node) {
7777
if defn != nil {
78-
if m := v.visit(defn.(*Func)); m < min {
79-
min = m
78+
if m := v.visit(defn.(*Func)); m < minVisitGen {
79+
minVisitGen = m
8080
}
8181
}
8282
}
@@ -97,13 +97,13 @@ func (v *bottomUpVisitor) visit(n *Func) uint32 {
9797
}
9898
})
9999

100-
if (min == id || min == id+1) && !n.IsClosure() {
100+
if (minVisitGen == id || minVisitGen == id+1) && !n.IsClosure() {
101101
// This node is the root of a strongly connected component.
102102

103-
// The original min was id+1. If the bottomUpVisitor found its way
103+
// The original minVisitGen was id+1. If the bottomUpVisitor found its way
104104
// back to id, then this block is a set of mutually recursive functions.
105105
// Otherwise, it's just a lone function that does not recurse.
106-
recursive := min == id
106+
recursive := minVisitGen == id
107107

108108
// Remove connected component from stack and mark v.nodeID so that future
109109
// visits return a large number, which will not affect the caller's min.
@@ -121,5 +121,5 @@ func (v *bottomUpVisitor) visit(n *Func) uint32 {
121121
v.analyze(block, recursive)
122122
}
123123

124-
return min
124+
return minVisitGen
125125
}

src/cmd/compile/internal/test/pgo_devirtualize_test.go

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ const profFileName = "devirt.pprof"
2323
const preProfFileName = "devirt.pprof.node_map"
2424

2525
// testPGODevirtualize tests that specific PGO devirtualize rewrites are performed.
26-
func testPGODevirtualize(t *testing.T, dir string, want []devirtualization, pgoProfileName string) {
26+
func testPGODevirtualize(t *testing.T, dir string, want, nowant []devirtualization, pgoProfileName string) {
2727
testenv.MustHaveGoRun(t)
2828
t.Parallel()
2929

@@ -69,24 +69,32 @@ go 1.21
6969
}
7070

7171
got := make(map[devirtualization]struct{})
72+
gotNoHot := make(map[devirtualization]struct{})
7273

7374
devirtualizedLine := regexp.MustCompile(`(.*): PGO devirtualizing \w+ call .* to (.*)`)
75+
noHotLine := regexp.MustCompile(`(.*): call .*: no hot callee`)
7476

7577
scanner := bufio.NewScanner(pr)
7678
for scanner.Scan() {
7779
line := scanner.Text()
7880
t.Logf("child: %s", line)
7981

8082
m := devirtualizedLine.FindStringSubmatch(line)
81-
if m == nil {
83+
if m != nil {
84+
d := devirtualization{
85+
pos: m[1],
86+
callee: m[2],
87+
}
88+
got[d] = struct{}{}
8289
continue
8390
}
84-
85-
d := devirtualization{
86-
pos: m[1],
87-
callee: m[2],
91+
m = noHotLine.FindStringSubmatch(line)
92+
if m != nil {
93+
d := devirtualization{
94+
pos: m[1],
95+
}
96+
gotNoHot[d] = struct{}{}
8897
}
89-
got[d] = struct{}{}
9098
}
9199
if err := cmd.Wait(); err != nil {
92100
t.Fatalf("error running go test: %v", err)
@@ -104,6 +112,11 @@ go 1.21
104112
}
105113
t.Errorf("devirtualization %v missing; got %v", w, got)
106114
}
115+
for _, nw := range nowant {
116+
if _, ok := gotNoHot[nw]; !ok {
117+
t.Errorf("unwanted devirtualization %v; got %v", nw, got)
118+
}
119+
}
107120

108121
// Run test with PGO to ensure the assertions are still true.
109122
cmd = testenv.CleanCmdEnv(testenv.Command(t, out))
@@ -174,8 +187,18 @@ func TestPGODevirtualize(t *testing.T) {
174187
// callee: "mult.MultClosure.func1",
175188
//},
176189
}
190+
nowant := []devirtualization{
191+
// ExerciseIfaceZeroWeight
192+
{
193+
pos: "./devirt.go:256:29",
194+
},
195+
// ExerciseIndirCallZeroWeight
196+
{
197+
pos: "./devirt.go:282:37",
198+
},
199+
}
177200

178-
testPGODevirtualize(t, dir, want, profFileName)
201+
testPGODevirtualize(t, dir, want, nowant, profFileName)
179202
}
180203

181204
// TestPGOPreprocessDevirtualize tests that specific functions are devirtualized when PGO
@@ -237,8 +260,18 @@ func TestPGOPreprocessDevirtualize(t *testing.T) {
237260
// callee: "mult.MultClosure.func1",
238261
//},
239262
}
263+
nowant := []devirtualization{
264+
// ExerciseIfaceZeroWeight
265+
{
266+
pos: "./devirt.go:256:29",
267+
},
268+
// ExerciseIndirCallZeroWeight
269+
{
270+
pos: "./devirt.go:282:37",
271+
},
272+
}
240273

241-
testPGODevirtualize(t, dir, want, preProfFileName)
274+
testPGODevirtualize(t, dir, want, nowant, preProfFileName)
242275
}
243276

244277
// Regression test for https://go.dev/issue/65615. If a target function changes
@@ -303,8 +336,18 @@ func TestLookupFuncGeneric(t *testing.T) {
303336
// callee: "mult.MultClosure.func1",
304337
//},
305338
}
339+
nowant := []devirtualization{
340+
// ExerciseIfaceZeroWeight
341+
{
342+
pos: "./devirt.go:256:29",
343+
},
344+
// ExerciseIndirCallZeroWeight
345+
{
346+
pos: "./devirt.go:282:37",
347+
},
348+
}
306349

307-
testPGODevirtualize(t, dir, want, profFileName)
350+
testPGODevirtualize(t, dir, want, nowant, profFileName)
308351
}
309352

310353
var multFnRe = regexp.MustCompile(`func MultFn\(a, b int64\) int64`)

src/cmd/compile/internal/test/testdata/pgo/devirtualize/devirt.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,3 +250,45 @@ func ExerciseFuncClosure(iter int, a1, a2 AddFunc, m1, m2 mult.MultFunc) int {
250250
}
251251
return val
252252
}
253+
254+
//go:noinline
255+
func IfaceZeroWeight(a *Add, b Adder) bool {
256+
return a.Add(1, 2) == b.Add(3, 4) // unwanted devirtualization
257+
}
258+
259+
// ExerciseIfaceZeroWeight never calls IfaceZeroWeight, so the callee
260+
// is not expected to appear in the profile.
261+
//
262+
//go:noinline
263+
func ExerciseIfaceZeroWeight() {
264+
if false {
265+
a := &Add{}
266+
b := &Sub{}
267+
// Unreachable call
268+
IfaceZeroWeight(a, b)
269+
}
270+
}
271+
272+
func DirectCall() bool {
273+
return true
274+
}
275+
276+
func IndirectCall() bool {
277+
return false
278+
}
279+
280+
//go:noinline
281+
func IndirCallZeroWeight(indirectCall func() bool) bool {
282+
return DirectCall() && indirectCall() // unwanted devirtualization
283+
}
284+
285+
// ExerciseIndirCallZeroWeight never calls IndirCallZeroWeight, so the
286+
// callee is not expected to appear in the profile.
287+
//
288+
//go:noinline
289+
func ExerciseIndirCallZeroWeight() {
290+
if false {
291+
// Unreachable call
292+
IndirCallZeroWeight(IndirectCall)
293+
}
294+
}

src/cmd/compile/internal/test/testdata/pgo/devirtualize/devirt_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,19 @@ func TestDevirtFuncClosure(t *testing.T) {
7171
t.Errorf("ExerciseFuncClosure(10) got %d want 1176", v)
7272
}
7373
}
74+
75+
func BenchmarkDevirtIfaceZeroWeight(t *testing.B) {
76+
ExerciseIfaceZeroWeight()
77+
}
78+
79+
func TestDevirtIfaceZeroWeight(t *testing.T) {
80+
ExerciseIfaceZeroWeight()
81+
}
82+
83+
func BenchmarkDevirtIndirCallZeroWeight(t *testing.B) {
84+
ExerciseIndirCallZeroWeight()
85+
}
86+
87+
func TestDevirtIndirCallZeroWeight(t *testing.T) {
88+
ExerciseIndirCallZeroWeight()
89+
}

src/crypto/tls/handshake_client_tls13.go

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ import (
88
"bytes"
99
"context"
1010
"crypto"
11+
"crypto/hkdf"
1112
"crypto/hmac"
12-
"crypto/internal/fips140/hkdf"
1313
"crypto/internal/fips140/mlkem"
1414
"crypto/internal/fips140/tls13"
1515
"crypto/rsa"
@@ -90,12 +90,13 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
9090
confTranscript.Write(hs.serverHello.original[:30])
9191
confTranscript.Write(make([]byte, 8))
9292
confTranscript.Write(hs.serverHello.original[38:])
93-
acceptConfirmation := tls13.ExpandLabel(hs.suite.hash.New,
94-
hkdf.Extract(hs.suite.hash.New, hs.echContext.innerHello.random, nil),
95-
"ech accept confirmation",
96-
confTranscript.Sum(nil),
97-
8,
98-
)
93+
h := hs.suite.hash.New
94+
prk, err := hkdf.Extract(h, hs.echContext.innerHello.random, nil)
95+
if err != nil {
96+
c.sendAlert(alertInternalError)
97+
return err
98+
}
99+
acceptConfirmation := tls13.ExpandLabel(h, prk, "ech accept confirmation", confTranscript.Sum(nil), 8)
99100
if subtle.ConstantTimeCompare(acceptConfirmation, hs.serverHello.random[len(hs.serverHello.random)-8:]) == 1 {
100101
hs.hello = hs.echContext.innerHello
101102
c.serverName = c.config.ServerName
@@ -264,12 +265,13 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
264265
copy(hrrHello, hs.serverHello.original)
265266
hrrHello = bytes.Replace(hrrHello, hs.serverHello.encryptedClientHello, make([]byte, 8), 1)
266267
confTranscript.Write(hrrHello)
267-
acceptConfirmation := tls13.ExpandLabel(hs.suite.hash.New,
268-
hkdf.Extract(hs.suite.hash.New, hs.echContext.innerHello.random, nil),
269-
"hrr ech accept confirmation",
270-
confTranscript.Sum(nil),
271-
8,
272-
)
268+
h := hs.suite.hash.New
269+
prk, err := hkdf.Extract(h, hs.echContext.innerHello.random, nil)
270+
if err != nil {
271+
c.sendAlert(alertInternalError)
272+
return err
273+
}
274+
acceptConfirmation := tls13.ExpandLabel(h, prk, "hrr ech accept confirmation", confTranscript.Sum(nil), 8)
273275
if subtle.ConstantTimeCompare(acceptConfirmation, hs.serverHello.encryptedClientHello) == 1 {
274276
hello = hs.echContext.innerHello
275277
c.serverName = c.config.ServerName

src/crypto/tls/handshake_server_tls13.go

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ import (
88
"bytes"
99
"context"
1010
"crypto"
11+
"crypto/hkdf"
1112
"crypto/hmac"
12-
"crypto/internal/fips140/hkdf"
1313
"crypto/internal/fips140/mlkem"
1414
"crypto/internal/fips140/tls13"
1515
"crypto/internal/hpke"
@@ -572,12 +572,13 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID)
572572
if err := transcriptMsg(helloRetryRequest, confTranscript); err != nil {
573573
return nil, err
574574
}
575-
acceptConfirmation := tls13.ExpandLabel(hs.suite.hash.New,
576-
hkdf.Extract(hs.suite.hash.New, hs.clientHello.random, nil),
577-
"hrr ech accept confirmation",
578-
confTranscript.Sum(nil),
579-
8,
580-
)
575+
h := hs.suite.hash.New
576+
prf, err := hkdf.Extract(h, hs.clientHello.random, nil)
577+
if err != nil {
578+
c.sendAlert(alertInternalError)
579+
return nil, err
580+
}
581+
acceptConfirmation := tls13.ExpandLabel(h, prf, "hrr ech accept confirmation", confTranscript.Sum(nil), 8)
581582
helloRetryRequest.encryptedClientHello = acceptConfirmation
582583
}
583584

@@ -735,12 +736,13 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
735736
return err
736737
}
737738
// compute the acceptance message
738-
acceptConfirmation := tls13.ExpandLabel(hs.suite.hash.New,
739-
hkdf.Extract(hs.suite.hash.New, hs.clientHello.random, nil),
740-
"ech accept confirmation",
741-
echTranscript.Sum(nil),
742-
8,
743-
)
739+
h := hs.suite.hash.New
740+
prk, err := hkdf.Extract(h, hs.clientHello.random, nil)
741+
if err != nil {
742+
c.sendAlert(alertInternalError)
743+
return err
744+
}
745+
acceptConfirmation := tls13.ExpandLabel(h, prk, "ech accept confirmation", echTranscript.Sum(nil), 8)
744746
copy(hs.hello.random[32-8:], acceptConfirmation)
745747
}
746748

0 commit comments

Comments
 (0)