From bb24a47a89eac6c1227fbcb2ae37a8b9ed323366 Mon Sep 17 00:00:00 2001 From: Luca Bruno Date: Wed, 30 Aug 2017 08:51:02 +0000 Subject: [PATCH] unix: drop dummy byte for oob in unixgram SendmsgN This commit relaxes SendmsgN behavior of introducing a dummy 1-byte payload when sending ancillary-only messages. The fake payload is not needed for SOCK_DGRAM type sockets, and actually breaks interoperability with other fd-passing software (journald is one known example). This introduces an additional check to avoid injecting dummy payload in such case. Backport of https://go.googlesource.com/go/+/93da0b6e66f24c4c307e0df37ceb102a33306174 Full reference at https:/golang.org/issue/6476#issue-51285243 Change-Id: I7cf00a1c7cde75fe62e00b98ccba6ac8469b0493 Reviewed-on: https://go-review.googlesource.com/60190 Reviewed-by: Ian Lance Taylor Run-TryBot: Ian Lance Taylor TryBot-Result: Gobot Gobot --- unix/creds_test.go | 185 +++++++++++++++++++++++------------------- unix/syscall_linux.go | 14 +++- 2 files changed, 112 insertions(+), 87 deletions(-) diff --git a/unix/creds_test.go b/unix/creds_test.go index eaae7c36..7ae33053 100644 --- a/unix/creds_test.go +++ b/unix/creds_test.go @@ -21,101 +21,116 @@ import ( // sockets. The SO_PASSCRED socket option is enabled on the sending // socket for this to work. func TestSCMCredentials(t *testing.T) { - fds, err := unix.Socketpair(unix.AF_LOCAL, unix.SOCK_STREAM, 0) - if err != nil { - t.Fatalf("Socketpair: %v", err) - } - defer unix.Close(fds[0]) - defer unix.Close(fds[1]) - - err = unix.SetsockoptInt(fds[0], unix.SOL_SOCKET, unix.SO_PASSCRED, 1) - if err != nil { - t.Fatalf("SetsockoptInt: %v", err) + socketTypeTests := []struct { + socketType int + dataLen int + }{ + { + unix.SOCK_STREAM, + 1, + }, { + unix.SOCK_DGRAM, + 0, + }, } - srvFile := os.NewFile(uintptr(fds[0]), "server") - defer srvFile.Close() - srv, err := net.FileConn(srvFile) - if err != nil { - t.Errorf("FileConn: %v", err) - return - } - defer srv.Close() + for _, tt := range socketTypeTests { + fds, err := unix.Socketpair(unix.AF_LOCAL, tt.socketType, 0) + if err != nil { + t.Fatalf("Socketpair: %v", err) + } + defer unix.Close(fds[0]) + defer unix.Close(fds[1]) - cliFile := os.NewFile(uintptr(fds[1]), "client") - defer cliFile.Close() - cli, err := net.FileConn(cliFile) - if err != nil { - t.Errorf("FileConn: %v", err) - return - } - defer cli.Close() + err = unix.SetsockoptInt(fds[0], unix.SOL_SOCKET, unix.SO_PASSCRED, 1) + if err != nil { + t.Fatalf("SetsockoptInt: %v", err) + } + + srvFile := os.NewFile(uintptr(fds[0]), "server") + defer srvFile.Close() + srv, err := net.FileConn(srvFile) + if err != nil { + t.Errorf("FileConn: %v", err) + return + } + defer srv.Close() + + cliFile := os.NewFile(uintptr(fds[1]), "client") + defer cliFile.Close() + cli, err := net.FileConn(cliFile) + if err != nil { + t.Errorf("FileConn: %v", err) + return + } + defer cli.Close() + + var ucred unix.Ucred + if os.Getuid() != 0 { + ucred.Pid = int32(os.Getpid()) + ucred.Uid = 0 + ucred.Gid = 0 + oob := unix.UnixCredentials(&ucred) + _, _, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil) + if op, ok := err.(*net.OpError); ok { + err = op.Err + } + if sys, ok := err.(*os.SyscallError); ok { + err = sys.Err + } + if err != syscall.EPERM { + t.Fatalf("WriteMsgUnix failed with %v, want EPERM", err) + } + } - var ucred unix.Ucred - if os.Getuid() != 0 { ucred.Pid = int32(os.Getpid()) - ucred.Uid = 0 - ucred.Gid = 0 + ucred.Uid = uint32(os.Getuid()) + ucred.Gid = uint32(os.Getgid()) oob := unix.UnixCredentials(&ucred) - _, _, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil) - if op, ok := err.(*net.OpError); ok { - err = op.Err + + // On SOCK_STREAM, this is internally going to send a dummy byte + n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil) + if err != nil { + t.Fatalf("WriteMsgUnix: %v", err) } - if sys, ok := err.(*os.SyscallError); ok { - err = sys.Err + if n != 0 { + t.Fatalf("WriteMsgUnix n = %d, want 0", n) } - if err != syscall.EPERM { - t.Fatalf("WriteMsgUnix failed with %v, want EPERM", err) + if oobn != len(oob) { + t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob)) } - } - ucred.Pid = int32(os.Getpid()) - ucred.Uid = uint32(os.Getuid()) - ucred.Gid = uint32(os.Getgid()) - oob := unix.UnixCredentials(&ucred) + oob2 := make([]byte, 10*len(oob)) + n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2) + if err != nil { + t.Fatalf("ReadMsgUnix: %v", err) + } + if flags != 0 { + t.Fatalf("ReadMsgUnix flags = 0x%x, want 0", flags) + } + if n != tt.dataLen { + t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen) + } + if oobn2 != oobn { + // without SO_PASSCRED set on the socket, ReadMsgUnix will + // return zero oob bytes + t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn) + } + oob2 = oob2[:oobn2] + if !bytes.Equal(oob, oob2) { + t.Fatal("ReadMsgUnix oob bytes don't match") + } - // this is going to send a dummy byte - n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil) - if err != nil { - t.Fatalf("WriteMsgUnix: %v", err) - } - if n != 0 { - t.Fatalf("WriteMsgUnix n = %d, want 0", n) - } - if oobn != len(oob) { - t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob)) - } - - oob2 := make([]byte, 10*len(oob)) - n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2) - if err != nil { - t.Fatalf("ReadMsgUnix: %v", err) - } - if flags != 0 { - t.Fatalf("ReadMsgUnix flags = 0x%x, want 0", flags) - } - if n != 1 { - t.Fatalf("ReadMsgUnix n = %d, want 1 (dummy byte)", n) - } - if oobn2 != oobn { - // without SO_PASSCRED set on the socket, ReadMsgUnix will - // return zero oob bytes - t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn) - } - oob2 = oob2[:oobn2] - if !bytes.Equal(oob, oob2) { - t.Fatal("ReadMsgUnix oob bytes don't match") - } - - scm, err := unix.ParseSocketControlMessage(oob2) - if err != nil { - t.Fatalf("ParseSocketControlMessage: %v", err) - } - newUcred, err := unix.ParseUnixCredentials(&scm[0]) - if err != nil { - t.Fatalf("ParseUnixCredentials: %v", err) - } - if *newUcred != ucred { - t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred) + scm, err := unix.ParseSocketControlMessage(oob2) + if err != nil { + t.Fatalf("ParseSocketControlMessage: %v", err) + } + newUcred, err := unix.ParseUnixCredentials(&scm[0]) + if err != nil { + t.Fatalf("ParseUnixCredentials: %v", err) + } + if *newUcred != ucred { + t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred) + } } } diff --git a/unix/syscall_linux.go b/unix/syscall_linux.go index 2afe62bf..1b7d59d8 100644 --- a/unix/syscall_linux.go +++ b/unix/syscall_linux.go @@ -931,8 +931,13 @@ func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from } var dummy byte if len(oob) > 0 { + var sockType int + sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE) + if err != nil { + return + } // receive at least one normal byte - if len(p) == 0 { + if sockType != SOCK_DGRAM && len(p) == 0 { iov.Base = &dummy iov.SetLen(1) } @@ -978,8 +983,13 @@ func SendmsgN(fd int, p, oob []byte, to Sockaddr, flags int) (n int, err error) } var dummy byte if len(oob) > 0 { + var sockType int + sockType, err = GetsockoptInt(fd, SOL_SOCKET, SO_TYPE) + if err != nil { + return 0, err + } // send at least one normal byte - if len(p) == 0 { + if sockType != SOCK_DGRAM && len(p) == 0 { iov.Base = &dummy iov.SetLen(1) }