cmd/compile: reorg equality functions a bit

Use signature for closure name instead of type.
Use signature instead of type to decide to use a runtime builtin comparator.
Remove trailing skips from signatures.

Change-Id: I73b2dcd3c6e2f1b2857985e14c24b290941b3ca3
Reviewed-on: https://go-review.googlesource.com/c/go/+/725604
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Cuong Manh Le <cuong.manhle.vn@gmail.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Keith Randall <khr@google.com>
This commit is contained in:
Keith Randall
2025-12-01 15:58:57 -08:00
parent 30dff416e4
commit 6eec9bcdb2

View File

@@ -299,61 +299,83 @@ func geneq(t *types.Type) *obj.LSym {
// The runtime will panic if it tries to compare
// a type with a nil equality function.
return nil
case types.AMEM0:
}
return geneqSig(eqSignature(t))
}
// geneqSig returns a symbol which is the closure used to compute
// equality for two objects with equality signature sig.
func geneqSig(sig string) *obj.LSym {
align := int64(types.PtrSize)
if len(sig) > 0 && sig[0] == sigAlign {
align, sig = parseNum(sig[1:])
}
if base.Ctxt.Arch.CanMergeLoads {
align = 8
}
switch sig {
case "":
return sysClosure("memequal0")
case types.AMEM8:
case string(sigMemory) + "1":
return sysClosure("memequal8")
case types.AMEM16:
return sysClosure("memequal16")
case types.AMEM32:
return sysClosure("memequal32")
case types.AMEM64:
return sysClosure("memequal64")
case types.AMEM128:
return sysClosure("memequal128")
case types.ASTRING:
case string(sigMemory) + "2":
if align >= 2 {
return sysClosure("memequal16")
}
case string(sigMemory) + "4":
if align >= 4 {
return sysClosure("memequal32")
}
case string(sigMemory) + "8":
if align >= 8 {
return sysClosure("memequal64")
}
case string(sigMemory) + "16":
if align >= 8 {
return sysClosure("memequal128")
}
case string(sigString):
return sysClosure("strequal")
case types.AINTER:
case string(sigIface):
return sysClosure("interequal")
case types.ANILINTER:
case string(sigEface):
return sysClosure("nilinterequal")
case types.AFLOAT32:
case string(sigFloat32):
return sysClosure("f32equal")
case types.AFLOAT64:
case string(sigFloat64):
return sysClosure("f64equal")
case types.ACPLX64:
case string(sigFloat32) + string(sigFloat32):
return sysClosure("c64equal")
case types.ACPLX128:
case string(sigFloat64) + string(sigFloat64):
return sysClosure("c128equal")
case types.AMEM:
// make equality closure. The size of the type
// is encoded in the closure.
closure := TypeLinksymLookup(fmt.Sprintf(".eqfunc%d", t.Size()))
if len(closure.P) != 0 {
return closure
}
if memequalvarlen == nil {
memequalvarlen = typecheck.LookupRuntimeFunc("memequal_varlen")
}
ot := 0
ot = objw.SymPtr(closure, ot, memequalvarlen, 0)
ot = objw.Uintptr(closure, ot, uint64(t.Size()))
objw.Global(closure, int32(ot), obj.DUPOK|obj.RODATA)
return closure
case types.ASPECIAL:
break
}
closure := TypeLinksymPrefix(".eqfunc", t)
closure := TypeLinksymLookup(".eqfunc." + sig)
if len(closure.P) > 0 { // already generated
return closure
}
if base.Flag.LowerR != 0 {
fmt.Printf("geneq %v\n", t)
if sig[0] == sigMemory {
n, rest := parseNum(sig[1:])
if rest == "" {
// Just M%d. We can make a memequal_varlen closure.
// The size of the memory region to compare is encoded in the closure.
if memequalvarlen == nil {
memequalvarlen = typecheck.LookupRuntimeFunc("memequal_varlen")
}
ot := 0
ot = objw.SymPtr(closure, ot, memequalvarlen, 0)
ot = objw.Uintptr(closure, ot, uint64(n))
objw.Global(closure, int32(ot), obj.DUPOK|obj.RODATA)
return closure
}
}
fn := eqFunc(eqSignature(t))
if base.Flag.LowerR != 0 {
fmt.Printf("geneqSig %s\n", sig)
}
fn := eqFunc(sig)
// Generate a closure which points at the function we just generated.
objw.SymPtr(closure, 0, fn.Linksym(), 0)
@@ -572,7 +594,7 @@ func eqFunc(sig string) *ir.Func {
// for i := off; i < off + N*elemSize; i += elemSize {
// if !eqfn(p+i, q+i) { goto neq }
// }
elemFn := eqFunc(elemSig).Nname
elemFn := eqFunc(sigTrimSkip(elemSig)).Nname
idx := typecheck.TempAt(pos, ir.CurFunc, types.Types[types.TUINTPTR])
init := ir.NewAssignStmt(pos, idx, ir.NewInt(pos, off))
cond := ir.NewBinaryExpr(pos, ir.OLT, idx, ir.NewInt(pos, off+n*elemSize))
@@ -702,6 +724,7 @@ func hashmem(t *types.Type) ir.Node {
// An alignment directive is only needed on platforms that can't do
// unaligned loads.
// If an alignment directive is present, it must be first.
// Signatures can end early; a K%d is not required at the end.
func eqSignature(t *types.Type) string {
var e eqSigBuilder
if !base.Ctxt.Arch.CanMergeLoads { // alignment only matters if we can't use unaligned loads
@@ -710,7 +733,7 @@ func eqSignature(t *types.Type) string {
}
}
e.build(t)
e.flush()
e.flush(true)
return e.r.String()
}
@@ -733,46 +756,48 @@ type eqSigBuilder struct {
skipMem int64 // queued up region of memory to skip
}
func (e *eqSigBuilder) flush() {
func (e *eqSigBuilder) flush(atEnd bool) {
if e.regMem > 0 {
e.r.WriteString(fmt.Sprintf("%c%d", sigMemory, e.regMem))
e.regMem = 0
}
if e.skipMem > 0 {
e.r.WriteString(fmt.Sprintf("%c%d", sigSkip, e.skipMem))
if !atEnd {
e.r.WriteString(fmt.Sprintf("%c%d", sigSkip, e.skipMem))
}
e.skipMem = 0
}
}
func (e *eqSigBuilder) regular(n int64) {
if e.regMem == 0 {
e.flush()
e.flush(false)
}
e.regMem += n
}
func (e *eqSigBuilder) skip(n int64) {
if e.skipMem == 0 {
e.flush()
e.flush(false)
}
e.skipMem += n
}
func (e *eqSigBuilder) float32() {
e.flush()
e.flush(false)
e.r.WriteByte(sigFloat32)
}
func (e *eqSigBuilder) float64() {
e.flush()
e.flush(false)
e.r.WriteByte(sigFloat64)
}
func (e *eqSigBuilder) string() {
e.flush()
e.flush(false)
e.r.WriteByte(sigString)
}
func (e *eqSigBuilder) eface() {
e.flush()
e.flush(false)
e.r.WriteByte(sigEface)
}
func (e *eqSigBuilder) iface() {
e.flush()
e.flush(false)
e.r.WriteByte(sigIface)
}
@@ -865,12 +890,12 @@ func (e *eqSigBuilder) build(t *types.Type) {
}
break
}
e.flush()
e.flush(false)
e.r.WriteString(fmt.Sprintf("%c%d", sigArrayStart, n/unroll))
for range unroll {
e.build(et)
}
e.flush()
e.flush(false)
e.r.WriteByte(sigArrayEnd)
}
default:
@@ -937,3 +962,16 @@ func sigSize(sig string) int64 {
}
return size
}
func sigTrimSkip(s string) string {
i := strings.LastIndexByte(s, sigSkip)
if i < 0 {
return s
}
for j := i + 1; j < len(s); j++ {
if s[j] < '0' || s[j] > '9' {
return s
}
}
return s[:i]
}