diff --git a/windows/svc/security.go b/windows/svc/security.go index 8cc67784..ef719c17 100644 --- a/windows/svc/security.go +++ b/windows/svc/security.go @@ -7,8 +7,8 @@ package svc import ( - "errors" - "syscall" + "path/filepath" + "strings" "unsafe" "golang.org/x/sys/windows" @@ -64,101 +64,45 @@ func IsAnInteractiveSession() (bool, error) { return false, nil } -var ( - ntdll = windows.NewLazySystemDLL("ntdll.dll") - _NtQueryInformationProcess = ntdll.NewProc("NtQueryInformationProcess") - - kernel32 = windows.NewLazySystemDLL("kernel32.dll") - _QueryFullProcessImageNameA = kernel32.NewProc("QueryFullProcessImageNameA") -) - // IsWindowsService reports whether the process is currently executing // as a Windows service. func IsWindowsService() (bool, error) { - // This code was copied from runtime.isWindowsService function. - // The below technique looks a bit hairy, but it's actually // exactly what the .NET framework does for the similarly named function: // https://github.com/dotnet/extensions/blob/f4066026ca06984b07e90e61a6390ac38152ba93/src/Hosting/WindowsServices/src/WindowsServiceHelpers.cs#L26-L31 // Specifically, it looks up whether the parent process has session ID zero // and is called "services". - const _CURRENT_PROCESS = ^uintptr(0) - // pbi is a PROCESS_BASIC_INFORMATION struct, where we just care about - // the 6th pointer inside of it, which contains the pid of the process - // parent: - // https://github.com/wine-mirror/wine/blob/42cb7d2ad1caba08de235e6319b9967296b5d554/include/winternl.h#L1294 - var pbi [6]uintptr - var pbiLen uint32 - r0, _, _ := syscall.Syscall6(_NtQueryInformationProcess.Addr(), 5, _CURRENT_PROCESS, 0, uintptr(unsafe.Pointer(&pbi[0])), uintptr(unsafe.Sizeof(pbi)), uintptr(unsafe.Pointer(&pbiLen)), 0) - if r0 != 0 { - return false, errors.New("NtQueryInformationProcess failed: error=" + itoa(int(r0))) - } - var psid uint32 - err := windows.ProcessIdToSessionId(uint32(pbi[5]), &psid) + + var pbi windows.PROCESS_BASIC_INFORMATION + pbiLen := uint32(unsafe.Sizeof(pbi)) + err := windows.NtQueryInformationProcess(windows.CurrentProcess(), windows.ProcessBasicInformation, unsafe.Pointer(&pbi), pbiLen, &pbiLen) if err != nil { - if err == windows.ERROR_INVALID_PARAMETER { - // This error happens when Windows cannot find process parent. - // Perhaps process parent exited. - // Assume we are not running in a service, because service - // parent process (services.exe) cannot exit. - return false, nil - } return false, err } - if psid != 0 { - // parent session id should be 0 for service process + var psid uint32 + err = windows.ProcessIdToSessionId(uint32(pbi.InheritedFromUniqueProcessId), &psid) + if err != nil || psid != 0 { return false, nil } - - pproc, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, uint32(pbi[5])) + pproc, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, uint32(pbi.InheritedFromUniqueProcessId)) if err != nil { return false, err } defer windows.CloseHandle(pproc) - - // exeName gets the path to the executable image of the parent process - var exeName [261]byte - exeNameLen := uint32(len(exeName) - 1) - r0, _, e0 := syscall.Syscall6(_QueryFullProcessImageNameA.Addr(), 4, uintptr(pproc), 0, uintptr(unsafe.Pointer(&exeName[0])), uintptr(unsafe.Pointer(&exeNameLen)), 0, 0) - if r0 == 0 { - if e0 != 0 { - return false, e0 - } else { - return false, syscall.EINVAL - } + var exeNameBuf [261]uint16 + exeNameLen := uint32(len(exeNameBuf) - 1) + err = windows.QueryFullProcessImageName(pproc, 0, &exeNameBuf[0], &exeNameLen) + if err != nil { + return false, err } - const ( - servicesLower = "services.exe" - servicesUpper = "SERVICES.EXE" - ) - i := int(exeNameLen) - 1 - j := len(servicesLower) - 1 - if i < j { + exeName := windows.UTF16ToString(exeNameBuf[:exeNameLen]) + if !strings.EqualFold(filepath.Base(exeName), "services.exe") { return false, nil } - for { - if j == -1 { - return i == -1 || exeName[i] == '\\', nil - } - if exeName[i] != servicesLower[j] && exeName[i] != servicesUpper[j] { - return false, nil - } - i-- - j-- + system32, err := windows.GetSystemDirectory() + if err != nil { + return false, err } -} - -func itoa(val int) string { // do it here rather than with fmt to avoid dependency - if val < 0 { - return "-" + itoa(-val) - } - var buf [32]byte // big enough for int64 - i := len(buf) - 1 - for val >= 10 { - buf[i] = byte(val%10 + '0') - i-- - val /= 10 - } - buf[i] = byte(val + '0') - return string(buf[i:]) + targetExeName := filepath.Join(system32, "services.exe") + return strings.EqualFold(exeName, targetExeName), nil }