luci-app-mosdns/mosdns/patches/102-server-simplify-udp-oob-handling.patch
2023-09-18 14:59:04 +08:00

351 lines
9.8 KiB
Diff

From f0274e3664a690d8ef1dd05ba7dc68869df7806f Mon Sep 17 00:00:00 2001
From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com>
Date: Thu, 31 Aug 2023 18:09:09 +0800
Subject: [PATCH 03/10] server: simplify udp oob handling
---
pkg/server/udp.go | 71 +++++++---------
pkg/server/udp_linux.go | 112 +++++++++++--------------
pkg/server/udp_others.go | 4 +-
plugin/server/udp_server/udp_server.go | 5 +-
4 files changed, 86 insertions(+), 106 deletions(-)
diff --git a/pkg/server/udp.go b/pkg/server/udp.go
index 45d689c..8980a08 100644
--- a/pkg/server/udp.go
+++ b/pkg/server/udp.go
@@ -22,14 +22,14 @@ package server
import (
"context"
"fmt"
+ "net"
+
"github.com/IrineSistiana/mosdns/v5/mlog"
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
"github.com/IrineSistiana/mosdns/v5/pkg/server/dns_handler"
- "github.com/IrineSistiana/mosdns/v5/pkg/utils"
"github.com/miekg/dns"
"go.uber.org/zap"
- "net"
)
type UDPServer struct {
@@ -53,39 +53,31 @@ func (opts *UDPServerOpts) init() {
return
}
-// cmcUDPConn can read and write cmsg.
-type cmcUDPConn interface {
- readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error)
- writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error)
-}
-
// ServeUDP starts a server at c. It returns if c had a read error.
// It always returns a non-nil error.
-func (s *UDPServer) ServeUDP(c net.PacketConn) error {
+func (s *UDPServer) ServeUDP(c *net.UDPConn) error {
listenerCtx, cancel := context.WithCancel(context.Background())
defer cancel()
rb := pool.GetBuf(dns.MaxMsgSize)
defer pool.ReleaseBuf(rb)
- var cmc cmcUDPConn
- var err error
- uc, ok := c.(*net.UDPConn)
- if ok && uc.LocalAddr().(*net.UDPAddr).IP.IsUnspecified() {
- cmc, err = newCmc(uc)
- if err != nil {
- return fmt.Errorf("failed to control socket cmsg, %w", err)
- }
- } else {
- cmc = newDummyCmc(c)
+ oobReader, oobWriter, err := initOobHandler(c)
+ if err != nil {
+ return fmt.Errorf("failed to init oob handler, %w", err)
+ }
+ var ob []byte // oob receive buffer; stays nil when oob handling is disabled
+ if oobReader != nil {
+ ob = pool.GetBuf(1024) // assign (not :=) so ReadMsgUDPAddrPort below actually gets this buffer
+ defer pool.ReleaseBuf(ob)
}
for {
- n, localAddr, ifIndex, remoteAddr, err := cmc.readFrom(rb)
+ n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(rb, ob) // oobn: bytes of control data written into ob
if err != nil {
return fmt.Errorf("unexpected read err: %w", err)
}
- clientAddr := utils.GetAddrFromAddr(remoteAddr)
+ clientAddr := remoteAddr.Addr()
q := new(dns.Msg)
if err := q.Unpack(rb[:n]); err != nil {
@@ -93,6 +85,15 @@ func (s *UDPServer) ServeUDP(c net.PacketConn) error {
continue
}
+ var dstIpFromCm net.IP // dst IP of the query, reused as the reply's src IP
+ if oobReader != nil {
+ var err error
+ dstIpFromCm, err = oobReader(ob[:oobn])
+ if err != nil {
+ s.opts.Logger.Error("failed to get dst address from oob", zap.Error(err))
+ }
+ }
+
// handle query
go func() {
qCtx := query_context.NewContext(q)
@@ -110,7 +111,12 @@ func (s *UDPServer) ServeUDP(c net.PacketConn) error {
return
}
defer pool.ReleaseBuf(buf)
- if _, err := cmc.writeTo(b, localAddr, ifIndex, remoteAddr); err != nil {
+ var oob []byte // nil oob: the kernel picks the reply's src address
+
+ if oobWriter != nil && dstIpFromCm != nil {
+ oob = oobWriter(dstIpFromCm)
+ }
+ if _, _, err := c.WriteMsgUDPAddrPort(b, oob, remoteAddr); err != nil {
s.opts.Logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err))
}
}
@@ -129,22 +135,5 @@ func getUDPSize(m *dns.Msg) int {
return int(s)
}
-// newDummyCmc returns a dummyCmcWrapper.
-func newDummyCmc(c net.PacketConn) cmcUDPConn {
- return dummyCmcWrapper{c: c}
-}
-
-// dummyCmcWrapper is just a wrapper that implements cmcUDPConn but does not
-// write or read any control msg.
-type dummyCmcWrapper struct {
- c net.PacketConn
-}
-
-func (w dummyCmcWrapper) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) {
- n, src, err = w.c.ReadFrom(b)
- return
-}
-
-func (w dummyCmcWrapper) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) {
- return w.c.WriteTo(b, dst)
-}
+type getSrcAddrFromOOB func(oob []byte) (net.IP, error) // extracts the received packet's dst IP from its oob bytes
+type writeSrcAddrToOOB func(a net.IP) []byte // builds oob bytes that request a as the src IP of an outgoing packet
diff --git a/pkg/server/udp_linux.go b/pkg/server/udp_linux.go
index 4eb466d..9728a39 100644
--- a/pkg/server/udp_linux.go
+++ b/pkg/server/udp_linux.go
@@ -22,84 +22,71 @@
package server
import (
+ "errors"
"fmt"
+ "net"
+ "os"
+
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
- "net"
- "os"
)
-type ipv4cmc struct {
- c *ipv4.PacketConn
-}
-
-func newIpv4cmc(c *ipv4.PacketConn) *ipv4cmc {
- return &ipv4cmc{c: c}
-}
+var (
+ errCmNoDstAddr = errors.New("control msg does not have dst address") // parsed cmsg carried no Dst
+)
-func (i *ipv4cmc) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) {
- n, cm, src, err := i.c.ReadFrom(b)
- if cm != nil {
- dst, IfIndex = cm.Dst, cm.IfIndex
+func getOOBFromCM4(oob []byte) (net.IP, error) { // parses oob as an ipv4 control message and returns its Dst
+ var cm ipv4.ControlMessage
+ if err := cm.Parse(oob); err != nil {
+ return nil, err
}
- return
-}
-
-func (i *ipv4cmc) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) {
- cm := &ipv4.ControlMessage{
- Src: src,
- IfIndex: IfIndex,
+ if cm.Dst == nil {
+ return nil, errCmNoDstAddr
}
- return i.c.WriteTo(b, cm, dst)
+ return cm.Dst, nil
}
-type ipv6cmc struct {
- c4 *ipv4.PacketConn // ipv4 entrypoint for sending ipv4 packages.
- c6 *ipv6.PacketConn
+func getOOBFromCM6(oob []byte) (net.IP, error) { // parses oob as an ipv6 control message and returns its Dst
+ var cm ipv6.ControlMessage
+ if err := cm.Parse(oob); err != nil {
+ return nil, err
+ }
+ if cm.Dst == nil {
+ return nil, errCmNoDstAddr
+ }
+ return cm.Dst, nil
}
-func newIpv6PacketConn(c4 *ipv4.PacketConn, c6 *ipv6.PacketConn) *ipv6cmc {
- return &ipv6cmc{c4: c4, c6: c6}
-}
+func srcIP2Cm(ip net.IP) []byte { // marshals ip as the Src of a control message; nil for an invalid ip
+ if ip4 := ip.To4(); ip4 != nil { // ipv4 (incl. v4-mapped) addresses get an ipv4 cmsg
+ return (&ipv4.ControlMessage{
+ Src: ip,
+ }).Marshal()
+ }
-func (i *ipv6cmc) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) {
- n, cm, src, err := i.c6.ReadFrom(b)
- if cm != nil {
- dst, IfIndex = cm.Dst, cm.IfIndex
+ if ip6 := ip.To16(); ip6 != nil { // remaining valid addresses get an ipv6 cmsg
+ return (&ipv6.ControlMessage{
+ Src: ip,
+ }).Marshal()
}
- return
+
+ return nil
}
-func (i *ipv6cmc) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) {
- if src != nil {
- // If src is ipv4, use IP_PKTINFO instead of IPV6_PKTINFO.
- // Otherwise, sendmsg will raise "invalid argument" error.
- // No official doc found.
- if src4 := src.To4(); src4 != nil {
- cm4 := &ipv4.ControlMessage{
- Src: src4,
- IfIndex: IfIndex,
- }
- return i.c4.WriteTo(b, cm4, dst)
- }
- }
- cm6 := &ipv6.ControlMessage{
- Src: src,
- IfIndex: IfIndex,
+func initOobHandler(c *net.UDPConn) (getSrcAddrFromOOB, writeSrcAddrToOOB, error) { // nil handlers mean oob is not used
+ if !c.LocalAddr().(*net.UDPAddr).IP.IsUnspecified() { // bound to a specific IP: replies already use it as src
+ return nil, nil, nil
}
- return i.c6.WriteTo(b, cm6, dst)
-}
-func newCmc(c *net.UDPConn) (cmcUDPConn, error) {
sc, err := c.SyscallConn()
if err != nil {
- return nil, err
+ return nil, nil, err
}
+ var getter getSrcAddrFromOOB
+ var setter writeSrcAddrToOOB
var controlErr error
- var cmc cmcUDPConn
-
if err := sc.Control(func(fd uintptr) {
v, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_DOMAIN)
if err != nil {
@@ -109,27 +96,30 @@ func newCmc(c *net.UDPConn) (cmcUDPConn, error) {
switch v {
case unix.AF_INET:
c4 := ipv4.NewPacketConn(c)
- if err := c4.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true); err != nil {
+ if err := c4.SetControlMessage(ipv4.FlagDst, true); err != nil {
controlErr = fmt.Errorf("failed to set ipv4 cmsg flags, %w", err)
}
- cmc = newIpv4cmc(c4)
+
+ getter = getOOBFromCM4
+ setter = srcIP2Cm
return
case unix.AF_INET6:
c6 := ipv6.NewPacketConn(c)
- if err := c6.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true); err != nil {
+ if err := c6.SetControlMessage(ipv6.FlagDst, true); err != nil {
controlErr = fmt.Errorf("failed to set ipv6 cmsg flags, %w", err)
}
- cmc = newIpv6PacketConn(ipv4.NewPacketConn(c), c6)
+ getter = getOOBFromCM6
+ setter = srcIP2Cm
return
default:
controlErr = fmt.Errorf("socket protocol %d is not supported", v)
}
}); err != nil {
- return nil, fmt.Errorf("control fd err, %w", controlErr)
+ return nil, nil, fmt.Errorf("control fd err, %w", err) // wrap Control's own error; controlErr may still be nil here
}
if controlErr != nil {
- return nil, fmt.Errorf("failed to set up socket, %w", controlErr)
+ return nil, nil, fmt.Errorf("failed to set up socket, %w", controlErr)
}
- return cmc, nil
+ return getter, setter, nil
}
diff --git a/pkg/server/udp_others.go b/pkg/server/udp_others.go
index 8ce6280..1e42651 100644
--- a/pkg/server/udp_others.go
+++ b/pkg/server/udp_others.go
@@ -23,6 +23,6 @@ package server
import "net"
-func newCmc(c *net.UDPConn) (cmcUDPConn, error) {
- return newDummyCmc(c), nil
+func initOobHandler(c *net.UDPConn) (getSrcAddrFromOOB, writeSrcAddrToOOB, error) {
+ return nil, nil, nil // nil handlers: ServeUDP then neither reads nor writes oob data
}
diff --git a/plugin/server/udp_server/udp_server.go b/plugin/server/udp_server/udp_server.go
index 9ed94f5..293f720 100644
--- a/plugin/server/udp_server/udp_server.go
+++ b/plugin/server/udp_server/udp_server.go
@@ -21,11 +21,12 @@ package udp_server
import (
"fmt"
+ "net"
+
"github.com/IrineSistiana/mosdns/v5/coremain"
"github.com/IrineSistiana/mosdns/v5/pkg/server"
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
"github.com/IrineSistiana/mosdns/v5/plugin/server/server_utils"
- "net"
)
const PluginType = "udp_server"
@@ -71,7 +72,7 @@ func StartServer(bp *coremain.BP, args *Args) (*UdpServer, error) {
}
go func() {
defer c.Close()
- err := s.ServeUDP(c)
+ err := s.ServeUDP(c.(*net.UDPConn)) // NOTE(review): unchecked assertion panics if c is not a *net.UDPConn; assumes c comes from a "udp" listener — confirm
bp.M().GetSafeClose().SendCloseSignal(err)
}()
return &UdpServer{
--
2.34.8