351 lines
9.8 KiB
Diff
351 lines
9.8 KiB
Diff
From f0274e3664a690d8ef1dd05ba7dc68869df7806f Mon Sep 17 00:00:00 2001
|
|
From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com>
|
|
Date: Thu, 31 Aug 2023 18:09:09 +0800
|
|
Subject: [PATCH 03/10] server: simplify udp oob handling
|
|
|
|
---
|
|
pkg/server/udp.go | 71 +++++++---------
|
|
pkg/server/udp_linux.go | 112 +++++++++++--------------
|
|
pkg/server/udp_others.go | 4 +-
|
|
plugin/server/udp_server/udp_server.go | 5 +-
|
|
4 files changed, 86 insertions(+), 106 deletions(-)
|
|
|
|
diff --git a/pkg/server/udp.go b/pkg/server/udp.go
|
|
index 45d689c..8980a08 100644
|
|
--- a/pkg/server/udp.go
|
|
+++ b/pkg/server/udp.go
|
|
@@ -22,14 +22,14 @@ package server
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
+ "net"
|
|
+
|
|
"github.com/IrineSistiana/mosdns/v5/mlog"
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/server/dns_handler"
|
|
- "github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
|
"github.com/miekg/dns"
|
|
"go.uber.org/zap"
|
|
- "net"
|
|
)
|
|
|
|
type UDPServer struct {
|
|
@@ -53,39 +53,31 @@ func (opts *UDPServerOpts) init() {
|
|
return
|
|
}
|
|
|
|
-// cmcUDPConn can read and write cmsg.
|
|
-type cmcUDPConn interface {
|
|
- readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error)
|
|
- writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error)
|
|
-}
|
|
-
|
|
// ServeUDP starts a server at c. It returns if c had a read error.
|
|
// It always returns a non-nil error.
|
|
-func (s *UDPServer) ServeUDP(c net.PacketConn) error {
|
|
+func (s *UDPServer) ServeUDP(c *net.UDPConn) error {
|
|
listenerCtx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
rb := pool.GetBuf(dns.MaxMsgSize)
|
|
defer pool.ReleaseBuf(rb)
|
|
|
|
- var cmc cmcUDPConn
|
|
- var err error
|
|
- uc, ok := c.(*net.UDPConn)
|
|
- if ok && uc.LocalAddr().(*net.UDPAddr).IP.IsUnspecified() {
|
|
- cmc, err = newCmc(uc)
|
|
- if err != nil {
|
|
- return fmt.Errorf("failed to control socket cmsg, %w", err)
|
|
- }
|
|
- } else {
|
|
- cmc = newDummyCmc(c)
|
|
+ oobReader, oobWriter, err := initOobHandler(c)
|
|
+ if err != nil {
|
|
+ return fmt.Errorf("failed to init oob handler, %w", err)
|
|
+ }
|
|
+ var ob []byte
|
|
+ if oobReader != nil {
|
|
+ ob = pool.GetBuf(1024)
|
|
+ defer pool.ReleaseBuf(ob)
|
|
}
|
|
|
|
for {
|
|
- n, localAddr, ifIndex, remoteAddr, err := cmc.readFrom(rb)
|
|
+ n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(rb, ob)
|
|
if err != nil {
|
|
return fmt.Errorf("unexpected read err: %w", err)
|
|
}
|
|
- clientAddr := utils.GetAddrFromAddr(remoteAddr)
|
|
+ clientAddr := remoteAddr.Addr()
|
|
|
|
q := new(dns.Msg)
|
|
if err := q.Unpack(rb[:n]); err != nil {
|
|
@@ -93,6 +85,15 @@ func (s *UDPServer) ServeUDP(c net.PacketConn) error {
|
|
continue
|
|
}
|
|
|
|
+ var dstIpFromCm net.IP
|
|
+ if oobReader != nil {
|
|
+ var err error
|
|
+ dstIpFromCm, err = oobReader(ob[:oobn])
|
|
+ if err != nil {
|
|
+ s.opts.Logger.Error("failed to get dst address from oob", zap.Error(err))
|
|
+ }
|
|
+ }
|
|
+
|
|
// handle query
|
|
go func() {
|
|
qCtx := query_context.NewContext(q)
|
|
@@ -110,7 +111,12 @@ func (s *UDPServer) ServeUDP(c net.PacketConn) error {
|
|
return
|
|
}
|
|
defer pool.ReleaseBuf(buf)
|
|
- if _, err := cmc.writeTo(b, localAddr, ifIndex, remoteAddr); err != nil {
|
|
+ var oob []byte
|
|
+
|
|
+ if oobWriter != nil && dstIpFromCm != nil {
|
|
+ oob = oobWriter(dstIpFromCm)
|
|
+ }
|
|
+ if _, _, err := c.WriteMsgUDPAddrPort(b, oob, remoteAddr); err != nil {
|
|
s.opts.Logger.Warn("failed to write response", zap.Stringer("client", remoteAddr), zap.Error(err))
|
|
}
|
|
}
|
|
@@ -129,22 +135,5 @@ func getUDPSize(m *dns.Msg) int {
|
|
return int(s)
|
|
}
|
|
|
|
-// newDummyCmc returns a dummyCmcWrapper.
|
|
-func newDummyCmc(c net.PacketConn) cmcUDPConn {
|
|
- return dummyCmcWrapper{c: c}
|
|
-}
|
|
-
|
|
-// dummyCmcWrapper is just a wrapper that implements cmcUDPConn but does not
|
|
-// write or read any control msg.
|
|
-type dummyCmcWrapper struct {
|
|
- c net.PacketConn
|
|
-}
|
|
-
|
|
-func (w dummyCmcWrapper) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) {
|
|
- n, src, err = w.c.ReadFrom(b)
|
|
- return
|
|
-}
|
|
-
|
|
-func (w dummyCmcWrapper) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) {
|
|
- return w.c.WriteTo(b, dst)
|
|
-}
|
|
+type getSrcAddrFromOOB func(oob []byte) (net.IP, error)
|
|
+type writeSrcAddrToOOB func(a net.IP) []byte
|
|
diff --git a/pkg/server/udp_linux.go b/pkg/server/udp_linux.go
|
|
index 4eb466d..9728a39 100644
|
|
--- a/pkg/server/udp_linux.go
|
|
+++ b/pkg/server/udp_linux.go
|
|
@@ -22,84 +22,71 @@
|
|
package server
|
|
|
|
import (
|
|
+ "errors"
|
|
"fmt"
|
|
+ "net"
|
|
+ "os"
|
|
+
|
|
"golang.org/x/net/ipv4"
|
|
"golang.org/x/net/ipv6"
|
|
"golang.org/x/sys/unix"
|
|
- "net"
|
|
- "os"
|
|
)
|
|
|
|
-type ipv4cmc struct {
|
|
- c *ipv4.PacketConn
|
|
-}
|
|
-
|
|
-func newIpv4cmc(c *ipv4.PacketConn) *ipv4cmc {
|
|
- return &ipv4cmc{c: c}
|
|
-}
|
|
+var (
|
|
+ errCmNoDstAddr = errors.New("control msg does not have dst address")
|
|
+)
|
|
|
|
-func (i *ipv4cmc) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) {
|
|
- n, cm, src, err := i.c.ReadFrom(b)
|
|
- if cm != nil {
|
|
- dst, IfIndex = cm.Dst, cm.IfIndex
|
|
+func getOOBFromCM4(oob []byte) (net.IP, error) {
|
|
+ var cm ipv4.ControlMessage
|
|
+ if err := cm.Parse(oob); err != nil {
|
|
+ return nil, err
|
|
}
|
|
- return
|
|
-}
|
|
-
|
|
-func (i *ipv4cmc) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) {
|
|
- cm := &ipv4.ControlMessage{
|
|
- Src: src,
|
|
- IfIndex: IfIndex,
|
|
+ if cm.Dst == nil {
|
|
+ return nil, errCmNoDstAddr
|
|
}
|
|
- return i.c.WriteTo(b, cm, dst)
|
|
+ return cm.Dst, nil
|
|
}
|
|
|
|
-type ipv6cmc struct {
|
|
- c4 *ipv4.PacketConn // ipv4 entrypoint for sending ipv4 packages.
|
|
- c6 *ipv6.PacketConn
|
|
+func getOOBFromCM6(oob []byte) (net.IP, error) {
|
|
+ var cm ipv6.ControlMessage
|
|
+ if err := cm.Parse(oob); err != nil {
|
|
+ return nil, err
|
|
+ }
|
|
+ if cm.Dst == nil {
|
|
+ return nil, errCmNoDstAddr
|
|
+ }
|
|
+ return cm.Dst, nil
|
|
}
|
|
|
|
-func newIpv6PacketConn(c4 *ipv4.PacketConn, c6 *ipv6.PacketConn) *ipv6cmc {
|
|
- return &ipv6cmc{c4: c4, c6: c6}
|
|
-}
|
|
+func srcIP2Cm(ip net.IP) []byte {
|
|
+ if ip4 := ip.To4(); ip4 != nil {
|
|
+ return (&ipv4.ControlMessage{
|
|
+ Src: ip,
|
|
+ }).Marshal()
|
|
+ }
|
|
|
|
-func (i *ipv6cmc) readFrom(b []byte) (n int, dst net.IP, IfIndex int, src net.Addr, err error) {
|
|
- n, cm, src, err := i.c6.ReadFrom(b)
|
|
- if cm != nil {
|
|
- dst, IfIndex = cm.Dst, cm.IfIndex
|
|
+ if ip6 := ip.To16(); ip6 != nil {
|
|
+ return (&ipv6.ControlMessage{
|
|
+ Src: ip,
|
|
+ }).Marshal()
|
|
}
|
|
- return
|
|
+
|
|
+ return nil
|
|
}
|
|
|
|
-func (i *ipv6cmc) writeTo(b []byte, src net.IP, IfIndex int, dst net.Addr) (n int, err error) {
|
|
- if src != nil {
|
|
- // If src is ipv4, use IP_PKTINFO instead of IPV6_PKTINFO.
|
|
- // Otherwise, sendmsg will raise "invalid argument" error.
|
|
- // No official doc found.
|
|
- if src4 := src.To4(); src4 != nil {
|
|
- cm4 := &ipv4.ControlMessage{
|
|
- Src: src4,
|
|
- IfIndex: IfIndex,
|
|
- }
|
|
- return i.c4.WriteTo(b, cm4, dst)
|
|
- }
|
|
- }
|
|
- cm6 := &ipv6.ControlMessage{
|
|
- Src: src,
|
|
- IfIndex: IfIndex,
|
|
+func initOobHandler(c *net.UDPConn) (getSrcAddrFromOOB, writeSrcAddrToOOB, error) {
|
|
+ if !c.LocalAddr().(*net.UDPAddr).IP.IsUnspecified() {
|
|
+ return nil, nil, nil
|
|
}
|
|
- return i.c6.WriteTo(b, cm6, dst)
|
|
-}
|
|
|
|
-func newCmc(c *net.UDPConn) (cmcUDPConn, error) {
|
|
sc, err := c.SyscallConn()
|
|
if err != nil {
|
|
- return nil, err
|
|
+ return nil, nil, err
|
|
}
|
|
|
|
+ var getter getSrcAddrFromOOB
|
|
+ var setter writeSrcAddrToOOB
|
|
var controlErr error
|
|
- var cmc cmcUDPConn
|
|
-
|
|
if err := sc.Control(func(fd uintptr) {
|
|
v, err := unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_DOMAIN)
|
|
if err != nil {
|
|
@@ -109,27 +96,30 @@ func newCmc(c *net.UDPConn) (cmcUDPConn, error) {
|
|
switch v {
|
|
case unix.AF_INET:
|
|
c4 := ipv4.NewPacketConn(c)
|
|
- if err := c4.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true); err != nil {
|
|
+ if err := c4.SetControlMessage(ipv4.FlagDst, true); err != nil {
|
|
controlErr = fmt.Errorf("failed to set ipv4 cmsg flags, %w", err)
|
|
}
|
|
- cmc = newIpv4cmc(c4)
|
|
+
|
|
+ getter = getOOBFromCM4
|
|
+ setter = srcIP2Cm
|
|
return
|
|
case unix.AF_INET6:
|
|
c6 := ipv6.NewPacketConn(c)
|
|
- if err := c6.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true); err != nil {
|
|
+ if err := c6.SetControlMessage(ipv6.FlagDst, true); err != nil {
|
|
controlErr = fmt.Errorf("failed to set ipv6 cmsg flags, %w", err)
|
|
}
|
|
- cmc = newIpv6PacketConn(ipv4.NewPacketConn(c), c6)
|
|
+ getter = getOOBFromCM6
|
|
+ setter = srcIP2Cm
|
|
return
|
|
default:
|
|
controlErr = fmt.Errorf("socket protocol %d is not supported", v)
|
|
}
|
|
}); err != nil {
|
|
- return nil, fmt.Errorf("control fd err, %w", controlErr)
|
|
+ return nil, nil, fmt.Errorf("control fd err, %w", err)
|
|
}
|
|
|
|
if controlErr != nil {
|
|
- return nil, fmt.Errorf("failed to set up socket, %w", controlErr)
|
|
+ return nil, nil, fmt.Errorf("failed to set up socket, %w", controlErr)
|
|
}
|
|
- return cmc, nil
|
|
+ return getter, setter, nil
|
|
}
|
|
diff --git a/pkg/server/udp_others.go b/pkg/server/udp_others.go
|
|
index 8ce6280..1e42651 100644
|
|
--- a/pkg/server/udp_others.go
|
|
+++ b/pkg/server/udp_others.go
|
|
@@ -23,6 +23,6 @@ package server
|
|
|
|
import "net"
|
|
|
|
-func newCmc(c *net.UDPConn) (cmcUDPConn, error) {
|
|
- return newDummyCmc(c), nil
|
|
+func initOobHandler(c *net.UDPConn) (getSrcAddrFromOOB, writeSrcAddrToOOB, error) {
|
|
+ return nil, nil, nil
|
|
}
|
|
diff --git a/plugin/server/udp_server/udp_server.go b/plugin/server/udp_server/udp_server.go
|
|
index 9ed94f5..293f720 100644
|
|
--- a/plugin/server/udp_server/udp_server.go
|
|
+++ b/plugin/server/udp_server/udp_server.go
|
|
@@ -21,11 +21,12 @@ package udp_server
|
|
|
|
import (
|
|
"fmt"
|
|
+ "net"
|
|
+
|
|
"github.com/IrineSistiana/mosdns/v5/coremain"
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/server"
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
|
"github.com/IrineSistiana/mosdns/v5/plugin/server/server_utils"
|
|
- "net"
|
|
)
|
|
|
|
const PluginType = "udp_server"
|
|
@@ -71,7 +72,7 @@ func StartServer(bp *coremain.BP, args *Args) (*UdpServer, error) {
|
|
}
|
|
go func() {
|
|
defer c.Close()
|
|
- err := s.ServeUDP(c)
|
|
+ err := s.ServeUDP(c.(*net.UDPConn))
|
|
bp.M().GetSafeClose().SendCloseSignal(err)
|
|
}()
|
|
return &UdpServer{
|
|
--
|
|
2.34.8
|
|
|