From 71da6904945ac440253cb5c132d64712f80ca497 Mon Sep 17 00:00:00 2001 From: Davis Goodin Date: Mon, 23 Jan 2023 14:48:25 -0800 Subject: [PATCH] windows/mkwinsyscall: support "." and "-" in DLL name This change adds "." and "-" support for DLL filenames in "//sys". Supporting "." requires a change in how mkwinsyscall handles the "= ." syntax. Instead of assuming that only one "." can appear in this string, now mkwinsyscall assumes that any additional "." belongs to the filename. Supporting "." also requires changing how Go identifiers are created for each DLL. This change also allows mkwinsyscall to support "-". When creating a Go identifier, "." and "-" in the DLL filename are replaced with "_". Otherwise, mkwinsyscall would produce invalid Go code, causing "format.Source" to fail. Includes a test for the new behavior. There aren't yet any cases where this code is executed while generating the x/sys/windows syscalls. The syscalls "SetSocketMediaStreamingMode" from "windows.networking.dll" and "WslRegisterDistribution" from "api-ms-win-wsl-api-l1-1-0.dll" can be successfully called using this change, but these syscalls have no known use in Go so they are not included in this change. Fixes golang/go#57913 Change-Id: If64deeb8c7738d61520e7392fd2d81ef8920f08d Reviewed-on: https://go-review.googlesource.com/c/sys/+/463215 TryBot-Result: Gopher Robot Reviewed-by: Alex Brainman Reviewed-by: Michael Knyszek Run-TryBot: Alex Brainman Reviewed-by: Quim Muntal Reviewed-by: Bryan Mills --- windows/mkwinsyscall/mkwinsyscall.go | 62 ++++++++++++++++------- windows/mkwinsyscall/mkwinsyscall_test.go | 50 ++++++++++++++++++ 2 files changed, 93 insertions(+), 19 deletions(-) create mode 100644 windows/mkwinsyscall/mkwinsyscall_test.go diff --git a/windows/mkwinsyscall/mkwinsyscall.go b/windows/mkwinsyscall/mkwinsyscall.go index b080c539..7fe4efa9 100644 --- a/windows/mkwinsyscall/mkwinsyscall.go +++ b/windows/mkwinsyscall/mkwinsyscall.go @@ -480,15 +480,14 @@ func newFn(s string) (*Fn, error) { return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") } s = trim(s[1:]) - a := strings.Split(s, ".") - switch len(a) { - case 1: - f.dllfuncname = a[0] - case 2: - f.dllname = a[0] - f.dllfuncname = a[1] - default: - return nil, errors.New("Could not extract dll name from \"" + f.src + "\"") + if i := strings.LastIndex(s, "."); i >= 0 { + f.dllname = s[:i] + f.dllfuncname = s[i+1:] + } else { + f.dllfuncname = s + } + if f.dllfuncname == "" { + return nil, fmt.Errorf("function name is not specified in %q", s) } if n := f.dllfuncname; strings.HasSuffix(n, "?") { f.dllfuncname = n[:len(n)-1] @@ -505,7 +504,23 @@ func (f *Fn) DLLName() string { return f.dllname } -// DLLName returns DLL function name for function f. +// DLLVar returns a valid Go identifier that represents DLLName. +func (f *Fn) DLLVar() string { + id := strings.Map(func(r rune) rune { + switch r { + case '.', '-': + return '_' + default: + return r + } + }, f.DLLName()) + if !token.IsIdentifier(id) { + panic(fmt.Errorf("could not create Go identifier for DLLName %q", f.DLLName())) + } + return id +} + +// DLLFuncName returns DLL function name for function f. func (f *Fn) DLLFuncName() string { if f.dllfuncname == "" { return f.Name @@ -650,6 +665,13 @@ func (f *Fn) HelperName() string { return "_" + f.Name } +// DLL is a DLL's filename and a string that is valid in a Go identifier that should be used when +// naming a variable that refers to the DLL. +type DLL struct { + Name string + Var string +} + // Source files and functions. type Source struct { Funcs []*Fn @@ -699,17 +721,19 @@ func ParseFiles(fs []string) (*Source, error) { } // DLLs return dll names for a source set src. -func (src *Source) DLLs() []string { +func (src *Source) DLLs() []DLL { uniq := make(map[string]bool) - r := make([]string, 0) + r := make([]DLL, 0) for _, f := range src.Funcs { - name := f.DLLName() - if _, found := uniq[name]; !found { - uniq[name] = true - r = append(r, name) + id := f.DLLVar() + if _, found := uniq[id]; !found { + uniq[id] = true + r = append(r, DLL{f.DLLName(), id}) } } - sort.Strings(r) + sort.Slice(r, func(i, j int) bool { + return r[i].Var < r[j].Var + }) return r } @@ -936,10 +960,10 @@ var ( {{/* help functions */}} -{{define "dlls"}}{{range .DLLs}} mod{{.}} = {{newlazydll .}} +{{define "dlls"}}{{range .DLLs}} mod{{.Var}} = {{newlazydll .Name}} {{end}}{{end}} -{{define "funcnames"}}{{range .DLLFuncNames}} proc{{.DLLFuncName}} = mod{{.DLLName}}.NewProc("{{.DLLFuncName}}") +{{define "funcnames"}}{{range .DLLFuncNames}} proc{{.DLLFuncName}} = mod{{.DLLVar}}.NewProc("{{.DLLFuncName}}") {{end}}{{end}} {{define "helperbody"}} diff --git a/windows/mkwinsyscall/mkwinsyscall_test.go b/windows/mkwinsyscall/mkwinsyscall_test.go new file mode 100644 index 00000000..cabbf403 --- /dev/null +++ b/windows/mkwinsyscall/mkwinsyscall_test.go @@ -0,0 +1,50 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "go/format" + "os" + "path/filepath" + "testing" +) + +func TestDLLFilenameEscaping(t *testing.T) { + tests := []struct { + name string + filename string + }{ + {"no escaping necessary", "kernel32"}, + {"escape period", "windows.networking"}, + {"escape dash", "api-ms-win-wsl-api-l1-1-0"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Write a made-up syscall into a temp file for testing. + const prefix = "package windows\n//sys Example() = " + const suffix = ".Example" + name := filepath.Join(t.TempDir(), "syscall.go") + if err := os.WriteFile(name, []byte(prefix+tt.filename+suffix), 0666); err != nil { + t.Fatal(err) + } + + // Ensure parsing, generating, and formatting run without errors. + // This is good enough to show that escaping is working. + src, err := ParseFiles([]string{name}) + if err != nil { + t.Fatal(err) + } + var buf bytes.Buffer + if err := src.Generate(&buf); err != nil { + t.Fatal(err) + } + if _, err := format.Source(buf.Bytes()); err != nil { + t.Log(buf.String()) + t.Fatal(err) + } + }) + } +}