diff --git a/unix/syscall_internal_bsd_test.go b/unix/syscall_internal_bsd_test.go index 5e3e0529..6a796c6f 100644 --- a/unix/syscall_internal_bsd_test.go +++ b/unix/syscall_internal_bsd_test.go @@ -13,20 +13,12 @@ import ( "unsafe" ) -// as per socket(2) -type SocketSpec struct { - domain int - typ int - protocol int -} - func Test_anyToSockaddr(t *testing.T) { tests := []struct { name string rsa *RawSockaddrAny sa Sockaddr err error - skt SocketSpec }{ { name: "AF_UNIX zero length", @@ -80,19 +72,6 @@ func Test_anyToSockaddr(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { fd := int(0) - var err error - if tt.skt.domain != 0 { - fd, err = Socket(tt.skt.domain, tt.skt.typ, tt.skt.protocol) - // Some sockaddr types need specific kernel modules running: if these - // are not present we'll get EPROTONOSUPPORT back when trying to create - // the socket. Skip the test in this situation. - if err == EPROTONOSUPPORT { - t.Skip("socket family/protocol not supported by kernel") - } else if err != nil { - t.Fatalf("socket(%v): %v", tt.skt, err) - } - defer Close(fd) - } sa, err := anyToSockaddr(fd, tt.rsa) if err != tt.err { t.Fatalf("unexpected error: %v, want: %v", err, tt.err) diff --git a/unix/syscall_internal_darwin_test.go b/unix/syscall_internal_darwin_test.go index 529e9778..e4431447 100644 --- a/unix/syscall_internal_darwin_test.go +++ b/unix/syscall_internal_darwin_test.go @@ -55,7 +55,6 @@ func Test_anyToSockaddr_darwin(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { fd := int(0) - var err error sa, err := anyToSockaddr(fd, tt.rsa) if err != tt.err { t.Fatalf("unexpected error: %v, want: %v", err, tt.err) diff --git a/unix/syscall_internal_linux_test.go b/unix/syscall_internal_linux_test.go index 5a4f4f6d..4c64f54f 100644 --- a/unix/syscall_internal_linux_test.go +++ b/unix/syscall_internal_linux_test.go @@ -13,20 +13,17 @@ import ( "unsafe" ) -// as per socket(2) -type SocketSpec struct { - domain int - typ int - protocol int +func makeProto(proto int) *int { + return &proto } func Test_anyToSockaddr(t *testing.T) { tests := []struct { - name string - rsa *RawSockaddrAny - sa Sockaddr - err error - skt SocketSpec + name string + rsa *RawSockaddrAny + sa Sockaddr + err error + proto *int }{ { name: "AF_TIPC bad addrtype", @@ -109,7 +106,7 @@ func Test_anyToSockaddr(t *testing.T) { Addr: [4]byte{0xef, 0x10, 0x5b, 0xa2}, ConnId: 0x1234abcd, }, - skt: SocketSpec{domain: AF_INET, typ: SOCK_DGRAM, protocol: IPPROTO_L2TP}, + proto: makeProto(IPPROTO_L2TP), }, { name: "AF_INET6 IPPROTO_L2TP", @@ -135,7 +132,7 @@ func Test_anyToSockaddr(t *testing.T) { ZoneId: 90210, ConnId: 0x1234abcd, }, - skt: SocketSpec{domain: AF_INET6, typ: SOCK_DGRAM, protocol: IPPROTO_L2TP}, + proto: makeProto(IPPROTO_L2TP), }, { name: "AF_UNIX unnamed/abstract", @@ -185,7 +182,7 @@ func Test_anyToSockaddr(t *testing.T) { RxID: 0xAAAAAAAA, TxID: 0xBBBBBBBB, }, - skt: SocketSpec{domain: AF_CAN, typ: SOCK_RAW, protocol: CAN_RAW}, + proto: makeProto(CAN_RAW), }, { name: "AF_CAN CAN_J1939", @@ -205,7 +202,7 @@ func Test_anyToSockaddr(t *testing.T) { PGN: 0xBBBBBBBB, Addr: 0xCC, }, - skt: SocketSpec{domain: AF_CAN, typ: SOCK_DGRAM, protocol: CAN_J1939}, + proto: makeProto(CAN_J1939), }, { name: "AF_MAX EAFNOSUPPORT", @@ -219,27 +216,15 @@ func Test_anyToSockaddr(t *testing.T) { // TODO: expand to support other families. } + realSocketProtocol := socketProtocol + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { fd := int(0) - var err error - if tt.skt.domain != 0 { - fd, err = Socket(tt.skt.domain, tt.skt.typ, tt.skt.protocol) - // Some sockaddr types need specific kernel modules running: if these - // are not present we'll get EPROTONOSUPPORT/EAFNOSUPPORT back when - // trying to create the socket. Skip the test in this situation. - if err == EPROTONOSUPPORT { - t.Skip("socket family/protocol not supported by kernel") - } else if err == EAFNOSUPPORT { - t.Skip("socket address family not supported by kernel") - } else if err == EACCES { - // Some platforms might require elevated privileges to perform - // actions on sockets. Skip the test in this situation. - t.Skip("socket operation not permitted") - } else if err != nil { - t.Fatalf("socket(%v): %v", tt.skt, err) - } - defer Close(fd) + if tt.proto != nil { + socketProtocol = func(fd int) (int, error) { return *tt.proto, nil } + } else { + socketProtocol = realSocketProtocol } sa, err := anyToSockaddr(fd, tt.rsa) if err != tt.err { diff --git a/unix/syscall_linux.go b/unix/syscall_linux.go index 1fdfe48e..28be1306 100644 --- a/unix/syscall_linux.go +++ b/unix/syscall_linux.go @@ -982,6 +982,10 @@ func (sa *SockaddrIUCV) sockaddr() (unsafe.Pointer, _Socklen, error) { return unsafe.Pointer(&sa.raw), SizeofSockaddrIUCV, nil } +var socketProtocol = func(fd int) (int, error) { + return GetsockoptInt(fd, SOL_SOCKET, SO_PROTOCOL) +} + func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) { switch rsa.Addr.Family { case AF_NETLINK: @@ -1032,7 +1036,7 @@ func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) { return sa, nil case AF_INET: - proto, err := GetsockoptInt(fd, SOL_SOCKET, SO_PROTOCOL) + proto, err := socketProtocol(fd) if err != nil { return nil, err } @@ -1058,7 +1062,7 @@ func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) { } case AF_INET6: - proto, err := GetsockoptInt(fd, SOL_SOCKET, SO_PROTOCOL) + proto, err := socketProtocol(fd) if err != nil { return nil, err } @@ -1093,7 +1097,7 @@ func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) { } return sa, nil case AF_BLUETOOTH: - proto, err := GetsockoptInt(fd, SOL_SOCKET, SO_PROTOCOL) + proto, err := socketProtocol(fd) if err != nil { return nil, err } @@ -1180,7 +1184,7 @@ func anyToSockaddr(fd int, rsa *RawSockaddrAny) (Sockaddr, error) { return sa, nil case AF_CAN: - proto, err := GetsockoptInt(fd, SOL_SOCKET, SO_PROTOCOL) + proto, err := socketProtocol(fd) if err != nil { return nil, err }