From f0274e3664a690d8ef1dd05ba7dc68869df7806f Mon Sep 17 00:00:00 2001
From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com>
Date: Thu, 31 Aug 2023 18:09:09 +0800
Subject: [PATCH 03/10] server: simplify udp oob handling

---
Review notes (text in this section is not part of the commit message):
 * ServeUDP: the pooled oob buffer is now assigned to the outer "ob"
   instead of being shadowed with ":=", so ReadMsgUDPAddrPort receives a
   non-nil oob buffer when an oob reader is configured.
 * initOobHandler: when sc.Control itself returns an error, that error is
   wrapped instead of controlErr.
 * Hunk line counts and the diffstat are unchanged by these two edits; the
   post-image blob hashes in the "index" lines were not recomputed.

 pkg/server/udp.go                      |  71 +++++++---------
 pkg/server/udp_linux.go                | 112 +++++++++++--------------
 pkg/server/udp_others.go               |   4 +-
 plugin/server/udp_server/udp_server.go |   5 +-
 4 files changed, 86 insertions(+), 106 deletions(-)

diff --git a/pkg/server/udp.go b/pkg/server/udp.go
index 45d689c..8980a08 100644
--- a/pkg/server/udp.go
+++ b/pkg/server/udp.go
@@ -22,14 +22,14 @@ package server
 import (
 	"context"
 	"fmt"
+	"net"
+
 	"github.com/IrineSistiana/mosdns/v5/mlog"
 	"github.com/IrineSistiana/mosdns/v5/pkg/pool"
 	"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
 	"github.com/IrineSistiana/mosdns/v5/pkg/server/dns_handler"
-	"github.com/IrineSistiana/mosdns/v5/pkg/utils"
 	"github.com/miekg/dns"
 	"go.uber.org/zap"
-	"net"
 )
 
 type UDPServer struct {
@@ -53,39 +53,31 @@ func (opts *UDPServerOpts) init() {
 	return
 }
 
-// cmcUDPConn can read and write cmsg.
-type cmcUDPConn interface {
-	readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error)
-	writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error)
-}
-
 // ServeUDP starts a server at c. It returns if c had a read error.
 // It always returns a non-nil error.
-func (s *UDPServer) ServeUDP(c net.PacketConn) error {
+func (s *UDPServer) ServeUDP(c *net.UDPConn) error {
 	listenerCtx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
 	rb := pool.GetBuf(dns.MaxMsgSize)
 	defer pool.ReleaseBuf(rb)
 
-	var cmc cmcUDPConn
-	var err error
-	uc, ok := c.(*net.UDPConn)
-	if ok && uc.LocalAddr().(*net.UDPAddr).IP.IsUnspecified() {
-		cmc, err = newCmc(uc)
-		if err != nil {
-			return fmt.Errorf("failed to control socket cmsg, %w", err)
-		}
-	} else {
-		cmc = newDummyCmc(c)
+	oobReader, oobWriter, err := initOobHandler(c)
+	if err != nil {
+		return fmt.Errorf("failed to init oob handler, %w", err)
+	}
+	var ob []byte
+	if oobReader != nil {
+		ob = pool.GetBuf(1024) // plain "=": ":=" would shadow ob and leave the read loop's oob buffer nil
+		defer pool.ReleaseBuf(ob)
 	}
 
 	for {
-		n, localAddr, ifIndex, remoteAddr, err := cmc.readFrom(rb)
+		n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(rb, ob)
 		if err != nil {
 			return fmt.Errorf("unexpected read err: %w", err)
 		}
-		clientAddr := utils.GetAddrFromAddr(remoteAddr)
+		clientAddr := remoteAddr.Addr()
 
 		q := new(dns.Msg)
 		if err := q.Unpack(rb[:n]); err != nil {
@@ -93,6 +85,15 @@ func (s *UDPServer) ServeUDP(c net.PacketConn) error {
 			continue
 		}
 
+		var dstIpFromCm net.IP
+		if oobReader != nil {
+			var err error
+			dstIpFromCm, err = oobReader(ob[:oobn])
+			if err != nil {
+				s.opts.Logger.Error("failed to get dst address from oob", zap.Error(err))
+			}
+		}
+
 		// handle query
 		go func() {
 			qCtx := query_context.NewContext(q)
@@ -110,7 +111,12 @@ func (s *UDPServer) ServeUDP(c net.PacketConn) error {
 					return
 				}
 				defer pool.ReleaseBuf(buf)
-				if _, err := cmc.writeTo(b, localAddr, ifIndex, remoteAddr); err != nil {
+				var oob []byte
+
+				if oobWriter != nil && dstIpFromCm != nil {
+					oob = oobWriter(dstIpFromCm)
+				}
+				if _, _, err := c.WriteMsgUDPAddrPort(b, oob, remoteAddr); err != nil {
 					s.opts.Logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err))
 				}
 			}
@@ -129,22 +135,5 @@ func getUDPSize(m *dns.Msg) int {
 	return int(s)
 }
 
-// newDummyCmc returns a dummyCmcWrapper.
-func newDummyCmc(c net.PacketConn) cmcUDPConn {
-	return dummyCmcWrapper{c: c}
-}
-
-// dummyCmcWrapper is just a wrapper that implements cmcUDPConn but does not
-// write or read any control msg.
-type dummyCmcWrapper struct {
-	c net.PacketConn
-}
-
-func (w dummyCmcWrapper) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) {
-	n, src, err = w.c.ReadFrom(b)
-	return
-}
-
-func (w dummyCmcWrapper) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) {
-	return w.c.WriteTo(b, dst)
-}
+type getSrcAddrFromOOB func(oob []byte) (net.IP, error)
+type writeSrcAddrToOOB func(a net.IP) []byte
diff --git a/pkg/server/udp_linux.go b/pkg/server/udp_linux.go
index 4eb466d..9728a39 100644
--- a/pkg/server/udp_linux.go
+++ b/pkg/server/udp_linux.go
@@ -22,84 +22,71 @@
 package server
 
 import (
+	"errors"
 	"fmt"
+	"net"
+	"os"
+
 	"golang.org/x/net/ipv4"
 	"golang.org/x/net/ipv6"
 	"golang.org/x/sys/unix"
-	"net"
-	"os"
 )
 
-type ipv4cmc struct {
-	c *ipv4.PacketConn
-}
-
-func newIpv4cmc(c *ipv4.PacketConn) *ipv4cmc {
-	return &ipv4cmc{c: c}
-}
+var (
+	errCmNoDstAddr = errors.New("control msg does not have dst address")
+)
 
-func (i *ipv4cmc) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) {
-	n, cm, src, err := i.c.ReadFrom(b)
-	if cm != nil {
-		dst, IfIndex = cm.Dst, cm.IfIndex
+func getOOBFromCM4(oob []byte) (net.IP, error) {
+	var cm ipv4.ControlMessage
+	if err := cm.Parse(oob); err != nil {
+		return nil, err
 	}
-	return
-}
-
-func (i *ipv4cmc) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) {
-	cm := &ipv4.ControlMessage{
-		Src:     src,
-		IfIndex: IfIndex,
+	if cm.Dst == nil {
+		return nil, errCmNoDstAddr
 	}
-	return i.c.WriteTo(b, cm, dst)
+	return cm.Dst, nil
 }
 
-type ipv6cmc struct {
-	c4 *ipv4.PacketConn // ipv4 entrypoint for sending ipv4 packages.
-	c6 *ipv6.PacketConn
+func getOOBFromCM6(oob []byte) (net.IP, error) {
+	var cm ipv6.ControlMessage
+	if err := cm.Parse(oob); err != nil {
+		return nil, err
+	}
+	if cm.Dst == nil {
+		return nil, errCmNoDstAddr
+	}
+	return cm.Dst, nil
 }
 
-func newIpv6PacketConn(c4 *ipv4.PacketConn, c6 *ipv6.PacketConn) *ipv6cmc {
-	return &ipv6cmc{c4: c4, c6: c6}
-}
+func srcIP2Cm(ip net.IP) []byte {
+	if ip4 := ip.To4(); ip4 != nil {
+		return (&ipv4.ControlMessage{
+			Src: ip,
+		}).Marshal()
+	}
 
-func (i *ipv6cmc) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) {
-	n, cm, src, err := i.c6.ReadFrom(b)
-	if cm != nil {
-		dst, IfIndex = cm.Dst, cm.IfIndex
+	if ip6 := ip.To16(); ip6 != nil {
+		return (&ipv6.ControlMessage{
+			Src: ip,
+		}).Marshal()
 	}
-	return
+
+	return nil
 }
 
-func (i *ipv6cmc) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) {
-	if src != nil {
-		// If src is ipv4, use IP_PKTINFO instead of IPV6_PKTINFO.
-		// Otherwise, sendmsg will raise "invalid argument" error.
-		// No official doc found.
-		if src4 := src.To4(); src4 != nil {
-			cm4 := &ipv4.ControlMessage{
-				Src:     src4,
-				IfIndex: IfIndex,
-			}
-			return i.c4.WriteTo(b, cm4, dst)
-		}
-	}
-	cm6 := &ipv6.ControlMessage{
-		Src:     src,
-		IfIndex: IfIndex,
+func initOobHandler(c *net.UDPConn) (getSrcAddrFromOOB, writeSrcAddrToOOB, error) {
+	if !c.LocalAddr().(*net.UDPAddr).IP.IsUnspecified() {
+		return nil, nil, nil
 	}
-	return i.c6.WriteTo(b, cm6, dst)
-}
 
-func newCmc(c *net.UDPConn) (cmcUDPConn, error) {
 	sc, err := c.SyscallConn()
 	if err != nil {
-		return nil, err
+		return nil, nil, err
 	}
 
+	var getter getSrcAddrFromOOB
+	var setter writeSrcAddrToOOB
 	var controlErr error
-	var cmc cmcUDPConn
-
 	if err := sc.Control(func(fd uintptr) {
 		v, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_DOMAIN)
 		if err != nil {
@@ -109,27 +96,30 @@ func newCmc(c *net.UDPConn) (cmcUDPConn, error) {
 		switch v {
 		case unix.AF_INET:
 			c4 := ipv4.NewPacketConn(c)
-			if err := c4.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true); err != nil {
+			if err := c4.SetControlMessage(ipv4.FlagDst, true); err != nil {
 				controlErr = fmt.Errorf("failed to set ipv4 cmsg flags, %w", err)
 			}
-			cmc = newIpv4cmc(c4)
+
+			getter = getOOBFromCM4
+			setter = srcIP2Cm
 			return
 		case unix.AF_INET6:
 			c6 := ipv6.NewPacketConn(c)
-			if err := c6.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true); err != nil {
+			if err := c6.SetControlMessage(ipv6.FlagDst, true); err != nil {
 				controlErr = fmt.Errorf("failed to set ipv6 cmsg flags, %w", err)
 			}
-			cmc = newIpv6PacketConn(ipv4.NewPacketConn(c), c6)
+			getter = getOOBFromCM6
+			setter = srcIP2Cm
 			return
 		default:
 			controlErr = fmt.Errorf("socket protocol %d is not supported", v)
 		}
 	}); err != nil {
-		return nil, fmt.Errorf("control fd err, %w", controlErr)
+		return nil, nil, fmt.Errorf("control fd err, %w", err) // wrap sc.Control's own error, not the callback's controlErr
 	}
 
 	if controlErr != nil {
-		return nil, fmt.Errorf("failed to set up socket, %w", controlErr)
+		return nil, nil, fmt.Errorf("failed to set up socket, %w", controlErr)
 	}
-	return cmc, nil
+	return getter, setter, nil
 }
diff --git a/pkg/server/udp_others.go b/pkg/server/udp_others.go
index 8ce6280..1e42651 100644
--- a/pkg/server/udp_others.go
+++ b/pkg/server/udp_others.go
@@ -23,6 +23,6 @@ package server
 
 import "net"
 
-func newCmc(c *net.UDPConn) (cmcUDPConn, error) {
-	return newDummyCmc(c), nil
+func initOobHandler(c *net.UDPConn) (getSrcAddrFromOOB, writeSrcAddrToOOB, error) {
+	return nil, nil, nil
 }
diff --git a/plugin/server/udp_server/udp_server.go b/plugin/server/udp_server/udp_server.go
index 9ed94f5..293f720 100644
--- a/plugin/server/udp_server/udp_server.go
+++ b/plugin/server/udp_server/udp_server.go
@@ -21,11 +21,12 @@ package udp_server
 
 import (
 	"fmt"
+	"net"
+
 	"github.com/IrineSistiana/mosdns/v5/coremain"
 	"github.com/IrineSistiana/mosdns/v5/pkg/server"
 	"github.com/IrineSistiana/mosdns/v5/pkg/utils"
 	"github.com/IrineSistiana/mosdns/v5/plugin/server/server_utils"
-	"net"
 )
 
 const PluginType = "udp_server"
@@ -71,7 +72,7 @@ func StartServer(bp *coremain.BP, args *Args) (*UdpServer, error) {
 	}
 	go func() {
 		defer c.Close()
-		err := s.ServeUDP(c)
+		err := s.ServeUDP(c.(*net.UDPConn))
 		bp.M().GetSafeClose().SendCloseSignal(err)
 	}()
 	return &UdpServer{
-- 
2.34.8