From 87e55d71481061dc6dcfb9a4953c896af893c130 Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Wed, 15 Jun 2022 17:34:44 -0700 Subject: [PATCH] unix: add RecvmsgBuffers and SendmsgBuffers Fixes golang/go#52885 Change-Id: I04b5be1ac9543a3791ebc4cd59b9e35e958e0ba2 Reviewed-on: https://go-review.googlesource.com/c/sys/+/412497 Auto-Submit: Tobias Klauser Reviewed-by: Tobias Klauser Reviewed-by: Ian Lance Taylor Reviewed-by: Carlos Amedee Auto-Submit: Ian Lance Taylor TryBot-Result: Gopher Robot Run-TryBot: Ian Lance Taylor --- unix/syscall_aix.go | 4 +-- unix/syscall_bsd.go | 46 ++++++++++++------------ unix/syscall_linux.go | 45 ++++++++++++------------ unix/syscall_solaris.go | 46 ++++++++++++------------ unix/syscall_unix.go | 74 +++++++++++++++++++++++++++++++++++++-- unix/syscall_unix_test.go | 64 +++++++++++++++++++++++++++++++++ 6 files changed, 206 insertions(+), 73 deletions(-) diff --git a/unix/syscall_aix.go b/unix/syscall_aix.go index ad22c33d..ac579c60 100644 --- a/unix/syscall_aix.go +++ b/unix/syscall_aix.go @@ -217,12 +217,12 @@ func Accept(fd int) (nfd int, sa Sockaddr, err error) { return } -func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) { +func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) { // Recvmsg not implemented on AIX return -1, -1, -1, ENOSYS } -func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) { +func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) { // SendmsgN not implemented on AIX return -1, ENOSYS } diff --git a/unix/syscall_bsd.go b/unix/syscall_bsd.go index 9c87c5f0..c437fc5d 100644 --- a/unix/syscall_bsd.go +++ b/unix/syscall_bsd.go @@ -325,27 +325,26 @@ func GetsockoptString(fd, level, opt int) (string, error) { //sys sendto(s int, buf []byte, flags int, to unsafe.Pointer, addrlen _Socklen) (err error) //sys recvmsg(s int, msg *Msghdr, flags int) (n int, err error) -func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) { +func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) { var msg Msghdr msg.Name = (*byte)(unsafe.Pointer(rsa)) msg.Namelen = uint32(SizeofSockaddrAny) - var iov Iovec - if len(p) > 0 { - iov.Base = (*byte)(unsafe.Pointer(&p[0])) - iov.SetLen(len(p)) - } var dummy byte if len(oob) > 0 { // receive at least one normal byte - if len(p) == 0 { - iov.Base = &dummy - iov.SetLen(1) + if emptyIovecs(iov) { + var iova [1]Iovec + iova[0].Base = &dummy + iova[0].SetLen(1) + iov = iova[:] } msg.Control = (*byte)(unsafe.Pointer(&oob[0])) msg.SetControllen(len(oob)) } - msg.Iov = &iov - msg.Iovlen = 1 + if len(iov) > 0 { + msg.Iov = &iov[0] + msg.SetIovlen(len(iov)) + } if n, err = recvmsg(fd, &msg, flags); err != nil { return } @@ -356,31 +355,32 @@ func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn //sys sendmsg(s int, msg *Msghdr, flags int) (n int, err error) -func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) { +func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) { var msg Msghdr msg.Name = (*byte)(unsafe.Pointer(ptr)) msg.Namelen = uint32(salen) - var iov Iovec - if len(p) > 0 { - iov.Base = (*byte)(unsafe.Pointer(&p[0])) - iov.SetLen(len(p)) - } var dummy byte + var empty bool if len(oob) > 0 { // send at least one normal byte - if len(p) == 0 { - iov.Base = &dummy - iov.SetLen(1) + empty := emptyIovecs(iov) + if empty { + var iova [1]Iovec + iova[0].Base = &dummy + iova[0].SetLen(1) + iov = iova[:] } msg.Control = (*byte)(unsafe.Pointer(&oob[0])) msg.SetControllen(len(oob)) } - msg.Iov = &iov - msg.Iovlen = 1 + if len(iov) > 0 { + msg.Iov = &iov[0] + msg.SetIovlen(len(iov)) + } if n, err = sendmsg(fd, &msg, flags); err != nil { return 0, err } - if len(oob) > 0 && len(p) == 0 { + if len(oob) > 0 && empty { n = 0 } return n, nil diff --git a/unix/syscall_linux.go b/unix/syscall_linux.go index c8d20321..5e4a94f7 100644 --- a/unix/syscall_linux.go +++ b/unix/syscall_linux.go @@ -1499,18 +1499,13 @@ func KeyctlRestrictKeyring(ringid int, keyType string, restriction string) error //sys keyctlRestrictKeyringByType(cmd int, arg2 int, keyType string, restriction string) (err error) = SYS_KEYCTL //sys keyctlRestrictKeyring(cmd int, arg2 int) (err error) = SYS_KEYCTL -func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) { +func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) { var msg Msghdr msg.Name = (*byte)(unsafe.Pointer(rsa)) msg.Namelen = uint32(SizeofSockaddrAny) - var iov Iovec - if len(p) > 0 { - iov.Base = &p[0] - iov.SetLen(len(p)) - } var dummy byte if len(oob) > 0 { - if len(p) == 0 { + if emptyIovecs(iov) { var sockType int sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE) if err != nil { @@ -1518,15 +1513,19 @@ func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn } // receive at least one normal byte if sockType != SOCK_DGRAM { - iov.Base = &dummy - iov.SetLen(1) + var iova [1]Iovec + iova[0].Base = &dummy + iova[0].SetLen(1) + iov = iova[:] } } msg.Control = &oob[0] msg.SetControllen(len(oob)) } - msg.Iov = &iov - msg.Iovlen = 1 + if len(iov) > 0 { + msg.Iov = &iov[0] + msg.SetIovlen(len(iov)) + } if n, err = recvmsg(fd, &msg, flags); err != nil { return } @@ -1535,18 +1534,15 @@ func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn return } -func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) { +func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) { var msg Msghdr msg.Name = (*byte)(ptr) msg.Namelen = uint32(salen) - var iov Iovec - if len(p) > 0 { - iov.Base = &p[0] - iov.SetLen(len(p)) - } var dummy byte + var empty bool if len(oob) > 0 { - if len(p) == 0 { + empty := emptyIovecs(iov) + if empty { var sockType int sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE) if err != nil { @@ -1554,19 +1550,22 @@ func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags i } // send at least one normal byte if sockType != SOCK_DGRAM { - iov.Base = &dummy - iov.SetLen(1) + var iova [1]Iovec + iova[0].Base = &dummy + iova[0].SetLen(1) } } msg.Control = &oob[0] msg.SetControllen(len(oob)) } - msg.Iov = &iov - msg.Iovlen = 1 + if len(iov) > 0 { + msg.Iov = &iov[0] + msg.SetIovlen(len(iov)) + } if n, err = sendmsg(fd, &msg, flags); err != nil { return 0, err } - if len(oob) > 0 && len(p) == 0 { + if len(oob) > 0 && empty { n = 0 } return n, nil diff --git a/unix/syscall_solaris.go b/unix/syscall_solaris.go index cd492d7a..b5ec457c 100644 --- a/unix/syscall_solaris.go +++ b/unix/syscall_solaris.go @@ -451,26 +451,25 @@ func Accept(fd int) (nfd int, sa Sockaddr, err error) { //sys recvmsg(s int, msg *Msghdr, flags int) (n int, err error) = libsocket.__xnet_recvmsg -func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) { +func recvmsgRaw(fd int, iov []Iovec, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn int, recvflags int, err error) { var msg Msghdr msg.Name = (*byte)(unsafe.Pointer(rsa)) msg.Namelen = uint32(SizeofSockaddrAny) - var iov Iovec - if len(p) > 0 { - iov.Base = &p[0] - iov.SetLen(len(p)) - } var dummy byte if len(oob) > 0 { // receive at least one normal byte - if len(p) == 0 { - iov.Base = &dummy - iov.SetLen(1) + if emptyIovecs(iov) { + var iova [1]Iovec + iova[0].Base = &dummy + iova[0].SetLen(1) + iov = iova[:] } msg.Accrightslen = int32(len(oob)) } - msg.Iov = &iov - msg.Iovlen = 1 + if len(iov) > 0 { + msg.Iov = &iov[0] + msg.SetIovlen(len(iov)) + } if n, err = recvmsg(fd, &msg, flags); n == -1 { return } @@ -480,30 +479,31 @@ func recvmsgRaw(fd int, p, oob []byte, flags int, rsa *RawSockaddrAny) (n, oobn //sys sendmsg(s int, msg *Msghdr, flags int) (n int, err error) = libsocket.__xnet_sendmsg -func sendmsgN(fd int, p, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) { +func sendmsgN(fd int, iov []Iovec, oob []byte, ptr unsafe.Pointer, salen _Socklen, flags int) (n int, err error) { var msg Msghdr msg.Name = (*byte)(unsafe.Pointer(ptr)) msg.Namelen = uint32(salen) - var iov Iovec - if len(p) > 0 { - iov.Base = &p[0] - iov.SetLen(len(p)) - } var dummy byte + var empty bool if len(oob) > 0 { // send at least one normal byte - if len(p) == 0 { - iov.Base = &dummy - iov.SetLen(1) + empty = emptyIovecs(iov) + if empty { + var iova [1]Iovec + iova[0].Base = &dummy + iova[0].SetLen(1) + iov = iova[:] } msg.Accrightslen = int32(len(oob)) } - msg.Iov = &iov - msg.Iovlen = 1 + if len(iov) > 0 { + msg.Iov = &iov[0] + msg.SetIovlen(len(iov)) + } if n, err = sendmsg(fd, &msg, flags); err != nil { return 0, err } - if len(oob) > 0 && len(p) == 0 { + if len(oob) > 0 && empty { n = 0 } return n, nil diff --git a/unix/syscall_unix.go b/unix/syscall_unix.go index 70508afc..1ff5060b 100644 --- a/unix/syscall_unix.go +++ b/unix/syscall_unix.go @@ -338,8 +338,13 @@ func Recvfrom(fd int, p []byte, flags int) (n int, from Sockaddr, err error) { } func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) { + var iov [1]Iovec + if len(p) > 0 { + iov[0].Base = &p[0] + iov[0].SetLen(len(p)) + } var rsa RawSockaddrAny - n, oobn, recvflags, err = recvmsgRaw(fd, p, oob, flags, &rsa) + n, oobn, recvflags, err = recvmsgRaw(fd, iov[:], oob, flags, &rsa) // source address is only specified if the socket is unconnected if rsa.Addr.Family != AF_UNSPEC { from, err = anyToSockaddr(fd, &rsa) @@ -347,12 +352,42 @@ func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from return } +// RecvmsgBuffers receives a message from a socket using the recvmsg +// system call. The flags are passed to recvmsg. Any non-control data +// read is scattered into the buffers slices. The results are: +// - n is the number of non-control data read into bufs +// - oobn is the number of control data read into oob; this may be interpreted using [ParseSocketControlMessage] +// - recvflags is flags returned by recvmsg +// - from is the address of the sender +func RecvmsgBuffers(fd int, buffers [][]byte, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) { + iov := make([]Iovec, len(buffers)) + for i := range buffers { + if len(buffers[i]) > 0 { + iov[i].Base = &buffers[i][0] + iov[i].SetLen(len(buffers[i])) + } else { + iov[i].Base = (*byte)(unsafe.Pointer(&_zero)) + } + } + var rsa RawSockaddrAny + n, oobn, recvflags, err = recvmsgRaw(fd, iov, oob, flags, &rsa) + if err == nil && rsa.Addr.Family != AF_UNSPEC { + from, err = anyToSockaddr(fd, &rsa) + } + return +} + func Sendmsg(fd int, p, oob []byte, to Sockaddr, flags int) (err error) { _, err = SendmsgN(fd, p, oob, to, flags) return } func SendmsgN(fd int, p, oob []byte, to Sockaddr, flags int) (n int, err error) { + var iov [1]Iovec + if len(p) > 0 { + iov[0].Base = &p[0] + iov[0].SetLen(len(p)) + } var ptr unsafe.Pointer var salen _Socklen if to != nil { @@ -361,7 +396,32 @@ func SendmsgN(fd int, p, oob []byte, to Sockaddr, flags int) (n int, err error) return 0, err } } - return sendmsgN(fd, p, oob, ptr, salen, flags) + return sendmsgN(fd, iov[:], oob, ptr, salen, flags) +} + +// SendmsgBuffers sends a message on a socket to an address using the sendmsg +// system call. The flags are passed to sendmsg. Any non-control data written +// is gathered from buffers. The function returns the number of bytes written +// to the socket. +func SendmsgBuffers(fd int, buffers [][]byte, oob []byte, to Sockaddr, flags int) (n int, err error) { + iov := make([]Iovec, len(buffers)) + for i := range buffers { + if len(buffers[i]) > 0 { + iov[i].Base = &buffers[i][0] + iov[i].SetLen(len(buffers[i])) + } else { + iov[i].Base = (*byte)(unsafe.Pointer(&_zero)) + } + } + var ptr unsafe.Pointer + var salen _Socklen + if to != nil { + ptr, salen, err = to.sockaddr() + if err != nil { + return 0, err + } + } + return sendmsgN(fd, iov, oob, ptr, salen, flags) } func Send(s int, buf []byte, flags int) (err error) { @@ -484,3 +544,13 @@ func Lutimes(path string, tv []Timeval) error { } return UtimesNanoAt(AT_FDCWD, path, ts, AT_SYMLINK_NOFOLLOW) } + +// emptyIovec reports whether there are no bytes in the slice of Iovec. +func emptyIovecs(iov []Iovec) bool { + for i := range iov { + if iov[i].Len > 0 { + return false + } + } + return true +} diff --git a/unix/syscall_unix_test.go b/unix/syscall_unix_test.go index 81db934e..c20afbef 100644 --- a/unix/syscall_unix_test.go +++ b/unix/syscall_unix_test.go @@ -16,8 +16,10 @@ import ( "os" "os/exec" "path/filepath" + "reflect" "runtime" "strconv" + "sync" "syscall" "testing" "time" @@ -954,6 +956,68 @@ func TestSend(t *testing.T) { } } +func TestSendmsgBuffers(t *testing.T) { + if runtime.GOOS == "aix" { + t.Skipf("SendmsgBuffers not supported on %s", runtime.GOOS) + } + + fds, err := unix.Socketpair(unix.AF_LOCAL, unix.SOCK_STREAM, 0) + if err != nil { + t.Fatal(err) + } + defer unix.Close(fds[0]) + defer unix.Close(fds[1]) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + bufs := [][]byte{ + make([]byte, 5), + nil, + make([]byte, 5), + } + n, oobn, recvflags, _, err := unix.RecvmsgBuffers(fds[1], bufs, nil, 0) + if err != nil { + t.Fatal(err) + } + if n != 10 { + t.Errorf("got %d bytes, want 10", n) + } + if oobn != 0 { + t.Errorf("got %d OOB bytes, want 0", oobn) + } + if recvflags != 0 { + t.Errorf("got flags %#x, want %#x", recvflags, 0) + } + want := [][]byte{ + []byte("01234"), + nil, + []byte("56789"), + } + if !reflect.DeepEqual(bufs, want) { + t.Errorf("got data %q, want %q", bufs, want) + } + }() + + defer wg.Wait() + + bufs := [][]byte{ + []byte("012"), + []byte("34"), + nil, + []byte("5678"), + []byte("9"), + } + n, err := unix.SendmsgBuffers(fds[0], bufs, nil, nil, 0) + if err != nil { + t.Fatal(err) + } + if n != 10 { + t.Errorf("sent %d bytes, want 10", n) + } +} + // mktmpfifo creates a temporary FIFO and provides a cleanup function. func mktmpfifo(t *testing.T) (*os.File, func()) { err := unix.Mkfifo("fifo", 0666)