diff --git a/unix/creds_test.go b/unix/creds_test.go index 1d4c0911..9ab57ecb 100644 --- a/unix/creds_test.go +++ b/unix/creds_test.go @@ -195,3 +195,102 @@ func TestPktInfo(t *testing.T) { }) } } + +func TestParseOrigDstAddr(t *testing.T) { + testcases := []struct { + network string + address *net.UDPAddr + }{ + {"udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}}, + {"udp6", &net.UDPAddr{IP: net.IPv6loopback}}, + } + + for _, test := range testcases { + t.Run(test.network, func(t *testing.T) { + conn, err := net.ListenUDP(test.network, test.address) + if errors.Is(err, unix.EADDRNOTAVAIL) { + t.Skipf("%v is not available", test.address) + } + if err != nil { + t.Fatal("Listen:", err) + } + defer conn.Close() + + raw, err := conn.SyscallConn() + if err != nil { + t.Fatal("SyscallConn:", err) + } + + var opErr error + err = raw.Control(func(fd uintptr) { + switch test.network { + case "udp4": + opErr = unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1) + case "udp6": + opErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1) + } + }) + if err != nil { + t.Fatal("Control:", err) + } + if opErr != nil { + t.Fatal("Can't enable RECVORIGDSTADDR:", err) + } + + msg := []byte{1} + addr := conn.LocalAddr().(*net.UDPAddr) + _, err = conn.WriteToUDP(msg, addr) + if err != nil { + t.Fatal("WriteToUDP:", err) + } + + conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + oob := make([]byte, unix.CmsgSpace(unix.SizeofSockaddrInet6)) + _, oobn, _, _, err := conn.ReadMsgUDP(msg, oob) + if err != nil { + t.Fatal("ReadMsgUDP:", err) + } + + scms, err := unix.ParseSocketControlMessage(oob[:oobn]) + if err != nil { + t.Fatal("ParseSocketControlMessage:", err) + } + + sa, err := unix.ParseOrigDstAddr(&scms[0]) + if err != nil { + t.Fatal("ParseOrigDstAddr:", err) + } + + switch test.network { + case "udp4": + sa4, ok := sa.(*unix.SockaddrInet4) + if !ok { + t.Fatalf("Got %T not *SockaddrInet4", sa) + } + + lo := net.IPv4(127, 0, 0, 1) + if addr := net.IP(sa4.Addr[:]); !lo.Equal(addr) { + t.Errorf("Got address %v, want %v", addr, lo) + } + + if sa4.Port != addr.Port { + t.Errorf("Got port %d, want %d", sa4.Port, addr.Port) + } + + case "udp6": + sa6, ok := sa.(*unix.SockaddrInet6) + if !ok { + t.Fatalf("Got %T, want *SockaddrInet6", sa) + } + + if addr := net.IP(sa6.Addr[:]); !net.IPv6loopback.Equal(addr) { + t.Errorf("Got address %v, want %v", addr, net.IPv6loopback) + } + + if sa6.Port != addr.Port { + t.Errorf("Got port %d, want %d", sa6.Port, addr.Port) + } + } + }) + } +} diff --git a/unix/sockcmsg_linux.go b/unix/sockcmsg_linux.go index e86d543b..326fb04a 100644 --- a/unix/sockcmsg_linux.go +++ b/unix/sockcmsg_linux.go @@ -56,3 +56,34 @@ func PktInfo6(info *Inet6Pktinfo) []byte { *(*Inet6Pktinfo)(h.data(0)) = *info return b } + +// ParseOrigDstAddr decodes a socket control message containing the original +// destination address. To receive such a message the IP_RECVORIGDSTADDR or +// IPV6_RECVORIGDSTADDR option must be enabled on the socket. +func ParseOrigDstAddr(m *SocketControlMessage) (Sockaddr, error) { + switch { + case m.Header.Level == SOL_IP && m.Header.Type == IP_ORIGDSTADDR: + pp := (*RawSockaddrInet4)(unsafe.Pointer(&m.Data[0])) + sa := new(SockaddrInet4) + p := (*[2]byte)(unsafe.Pointer(&pp.Port)) + sa.Port = int(p[0])<<8 + int(p[1]) + for i := 0; i < len(sa.Addr); i++ { + sa.Addr[i] = pp.Addr[i] + } + return sa, nil + + case m.Header.Level == SOL_IPV6 && m.Header.Type == IPV6_ORIGDSTADDR: + pp := (*RawSockaddrInet6)(unsafe.Pointer(&m.Data[0])) + sa := new(SockaddrInet6) + p := (*[2]byte)(unsafe.Pointer(&pp.Port)) + sa.Port = int(p[0])<<8 + int(p[1]) + sa.ZoneId = pp.Scope_id + for i := 0; i < len(sa.Addr); i++ { + sa.Addr[i] = pp.Addr[i] + } + return sa, nil + + default: + return nil, EINVAL + } +}