476 lines
13 KiB
Diff
476 lines
13 KiB
Diff
From 9bb813bf50ba0eed3a35acc8f41764613c11687d Mon Sep 17 00:00:00 2001
|
|
From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com>
|
|
Date: Sun, 17 Sep 2023 08:29:31 +0800
|
|
Subject: [PATCH 04/10] use go-bytes-pool
|
|
|
|
---
|
|
go.mod | 1 +
|
|
go.sum | 2 +
|
|
pkg/dnsutils/net_io.go | 45 +++++----
|
|
pkg/pool/allocator.go | 78 +---------------
|
|
pkg/pool/allocator_test.go | 119 ------------------------
|
|
pkg/pool/msg_buf.go | 4 +-
|
|
pkg/pool/msg_buf_test.go | 4 +-
|
|
pkg/server/udp.go | 6 +-
|
|
pkg/upstream/transport/dns_conn_test.go | 4 +-
|
|
plugin/executable/cache/cache.go | 8 +-
|
|
10 files changed, 43 insertions(+), 228 deletions(-)
|
|
delete mode 100644 pkg/pool/allocator_test.go
|
|
|
|
diff --git a/go.mod b/go.mod
|
|
index b56078b..2c359e3 100644
|
|
--- a/go.mod
|
|
+++ b/go.mod
|
|
@@ -5,6 +5,7 @@ go 1.21
|
|
toolchain go1.21.1
|
|
|
|
require (
|
|
+ github.com/IrineSistiana/go-bytes-pool v0.0.0-20230419012903-2f1f26674686
|
|
github.com/go-chi/chi/v5 v5.0.10
|
|
github.com/google/nftables v0.1.0
|
|
github.com/kardianos/service v1.2.2
|
|
diff --git a/go.sum b/go.sum
|
|
index 633f864..d4bbaa2 100644
|
|
--- a/go.sum
|
|
+++ b/go.sum
|
|
@@ -38,6 +38,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f
|
|
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
|
|
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
|
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
|
+github.com/IrineSistiana/go-bytes-pool v0.0.0-20230419012903-2f1f26674686 h1:5R32cCep3VUDTKf3aurFKfgbvg+RScuBmZsw/DyyXco=
|
|
+github.com/IrineSistiana/go-bytes-pool v0.0.0-20230419012903-2f1f26674686/go.mod h1:pQ/FSsWSNYmNdgIKmulKlmVC/R2PEpq2vIEi3J9IijI=
|
|
github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a h1:GQdh/h0q0ni3L//CXusyk+7QdhBL289vdNaes1WKkHI=
|
|
github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
|
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
|
diff --git a/pkg/dnsutils/net_io.go b/pkg/dnsutils/net_io.go
|
|
index 1b8a7b1..8fc769b 100644
|
|
--- a/pkg/dnsutils/net_io.go
|
|
+++ b/pkg/dnsutils/net_io.go
|
|
@@ -23,9 +23,10 @@ import (
|
|
"encoding/binary"
|
|
"errors"
|
|
"fmt"
|
|
+ "io"
|
|
+
|
|
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
|
|
"github.com/miekg/dns"
|
|
- "io"
|
|
)
|
|
|
|
var (
|
|
@@ -35,45 +36,43 @@ var (
|
|
// ReadRawMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed
|
|
// with a two byte length field).
|
|
// n represents how many bytes are read from c.
|
|
-// The returned the []byte should be released by pool.ReleaseBuf.
|
|
-func ReadRawMsgFromTCP(c io.Reader) ([]byte, int, error) {
|
|
- n := 0
|
|
+// The returned *[]byte should be released by pool.ReleaseBuf.
|
|
+func ReadRawMsgFromTCP(c io.Reader) (*[]byte, error) {
|
|
h := pool.GetBuf(2)
|
|
defer pool.ReleaseBuf(h)
|
|
- nh, err := io.ReadFull(c, h)
|
|
- n += nh
|
|
+ _, err := io.ReadFull(c, *h)
|
|
+
|
|
if err != nil {
|
|
- return nil, n, err
|
|
+ return nil, err
|
|
}
|
|
|
|
// dns length
|
|
- length := binary.BigEndian.Uint16(h)
|
|
+ length := binary.BigEndian.Uint16(*h)
|
|
if length == 0 {
|
|
- return nil, 0, errZeroLenMsg
|
|
+ return nil, errZeroLenMsg
|
|
}
|
|
|
|
- buf := pool.GetBuf(int(length))
|
|
- nm, err := io.ReadFull(c, buf)
|
|
- n += nm
|
|
+ b := pool.GetBuf(int(length))
|
|
+ _, err = io.ReadFull(c, *b)
|
|
if err != nil {
|
|
- pool.ReleaseBuf(buf)
|
|
- return nil, n, err
|
|
+ pool.ReleaseBuf(b)
|
|
+ return nil, err
|
|
}
|
|
- return buf, n, nil
|
|
+ return b, nil
|
|
}
|
|
|
|
// ReadMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed
|
|
// with a two byte length field).
|
|
// n represents how many bytes are read from c.
|
|
func ReadMsgFromTCP(c io.Reader) (*dns.Msg, int, error) {
|
|
- b, n, err := ReadRawMsgFromTCP(c)
|
|
+ b, err := ReadRawMsgFromTCP(c)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
defer pool.ReleaseBuf(b)
|
|
|
|
- m, err := unpackMsgWithDetailedErr(b)
|
|
- return m, n, err
|
|
+ m, err := unpackMsgWithDetailedErr(*b)
|
|
+ return m, len(*b) + 2, err
|
|
}
|
|
|
|
// WriteMsgToTCP packs and writes m to c in RFC 1035 format.
|
|
@@ -96,9 +95,9 @@ func WriteRawMsgToTCP(c io.Writer, b []byte) (n int, err error) {
|
|
buf := pool.GetBuf(len(b) + 2)
|
|
defer pool.ReleaseBuf(buf)
|
|
|
|
- binary.BigEndian.PutUint16(buf[:2], uint16(len(b)))
|
|
- copy(buf[2:], b)
|
|
- return c.Write(buf)
|
|
+ binary.BigEndian.PutUint16((*buf)[:2], uint16(len(b)))
|
|
+ copy((*buf)[2:], b)
|
|
+ return c.Write((*buf))
|
|
}
|
|
|
|
func WriteMsgToUDP(c io.Writer, m *dns.Msg) (int, error) {
|
|
@@ -118,12 +117,12 @@ func ReadMsgFromUDP(c io.Reader, bufSize int) (*dns.Msg, int, error) {
|
|
|
|
b := pool.GetBuf(bufSize)
|
|
defer pool.ReleaseBuf(b)
|
|
- n, err := c.Read(b)
|
|
+ n, err := c.Read(*b)
|
|
if err != nil {
|
|
return nil, n, err
|
|
}
|
|
|
|
- m, err := unpackMsgWithDetailedErr(b[:n])
|
|
+ m, err := unpackMsgWithDetailedErr((*b)[:n])
|
|
return m, n, err
|
|
}
|
|
|
|
diff --git a/pkg/pool/allocator.go b/pkg/pool/allocator.go
|
|
index 0bfc0bc..d8ac4ef 100644
|
|
--- a/pkg/pool/allocator.go
|
|
+++ b/pkg/pool/allocator.go
|
|
@@ -20,79 +20,11 @@
|
|
package pool
|
|
|
|
import (
|
|
- "fmt"
|
|
- "math"
|
|
- "math/bits"
|
|
- "sync"
|
|
+ bytesPool "github.com/IrineSistiana/go-bytes-pool"
|
|
)
|
|
|
|
// defaultBufPool is an Allocator that has a maximum capacity.
|
|
-var defaultBufPool = NewAllocator()
|
|
-
|
|
-// GetBuf returns a []byte from pool with most appropriate cap.
|
|
-// It panics if size < 0.
|
|
-func GetBuf(size int) []byte {
|
|
- return defaultBufPool.Get(size)
|
|
-}
|
|
-
|
|
-// ReleaseBuf puts the buf to the pool.
|
|
-func ReleaseBuf(b []byte) {
|
|
- defaultBufPool.Release(b)
|
|
-}
|
|
-
|
|
-type Allocator struct {
|
|
- buffers []sync.Pool
|
|
-}
|
|
-
|
|
-// NewAllocator initiates a []byte Allocator.
|
|
-// The waste(memory fragmentation) of space allocation is guaranteed to be
|
|
-// no more than 50%.
|
|
-func NewAllocator() *Allocator {
|
|
- alloc := &Allocator{
|
|
- buffers: make([]sync.Pool, bits.UintSize+1),
|
|
- }
|
|
-
|
|
- for i := range alloc.buffers {
|
|
- var bufSize uint
|
|
- if i == bits.UintSize {
|
|
- bufSize = math.MaxUint
|
|
- } else {
|
|
- bufSize = 1 << i
|
|
- }
|
|
- alloc.buffers[i].New = func() any {
|
|
- b := make([]byte, bufSize)
|
|
- return &b
|
|
- }
|
|
- }
|
|
- return alloc
|
|
-}
|
|
-
|
|
-// Get returns a []byte from pool with most appropriate cap
|
|
-func (alloc *Allocator) Get(size int) []byte {
|
|
- if size < 0 {
|
|
- panic(fmt.Sprintf("invalid slice size %d", size))
|
|
- }
|
|
-
|
|
- i := shard(size)
|
|
- v := alloc.buffers[i].Get()
|
|
- buf := v.(*[]byte)
|
|
- return (*buf)[0:size]
|
|
-}
|
|
-
|
|
-// Release releases the buf to the allocatorL.
|
|
-func (alloc *Allocator) Release(buf []byte) {
|
|
- c := cap(buf)
|
|
- i := shard(c)
|
|
- if c == 0 || c != 1<<i {
|
|
- panic("unexpected cap size")
|
|
- }
|
|
- alloc.buffers[i].Put(&buf)
|
|
-}
|
|
-
|
|
-// shard returns the shard index that is suitable for the size.
|
|
-func shard(size int) int {
|
|
- if size <= 1 {
|
|
- return 0
|
|
- }
|
|
- return bits.Len64(uint64(size - 1))
|
|
-}
|
|
+var (
|
|
+ GetBuf = bytesPool.Get
|
|
+ ReleaseBuf = bytesPool.Release
|
|
+)
|
|
diff --git a/pkg/pool/allocator_test.go b/pkg/pool/allocator_test.go
|
|
deleted file mode 100644
|
|
index 3f6d2a9..0000000
|
|
--- a/pkg/pool/allocator_test.go
|
|
+++ /dev/null
|
|
@@ -1,119 +0,0 @@
|
|
-/*
|
|
- * Copyright (C) 2020-2022, IrineSistiana
|
|
- *
|
|
- * This file is part of mosdns.
|
|
- *
|
|
- * mosdns is free software: you can redistribute it and/or modify
|
|
- * it under the terms of the GNU General Public License as published by
|
|
- * the Free Software Foundation, either version 3 of the License, or
|
|
- * (at your option) any later version.
|
|
- *
|
|
- * mosdns is distributed in the hope that it will be useful,
|
|
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
- * GNU General Public License for more details.
|
|
- *
|
|
- * You should have received a copy of the GNU General Public License
|
|
- * along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
- */
|
|
-
|
|
-package pool
|
|
-
|
|
-import (
|
|
- "fmt"
|
|
- "strconv"
|
|
- "testing"
|
|
-)
|
|
-
|
|
-func TestAllocator_Get(t *testing.T) {
|
|
- alloc := NewAllocator()
|
|
- tests := []struct {
|
|
- size int
|
|
- wantCap int
|
|
- wantPanic bool
|
|
- }{
|
|
- {-1, 0, true}, // invalid
|
|
- {0, 1, false},
|
|
- {1, 1, false},
|
|
- {2, 2, false},
|
|
- {12, 16, false},
|
|
- {256, 256, false},
|
|
- {257, 512, false},
|
|
- }
|
|
- for _, tt := range tests {
|
|
- t.Run(strconv.Itoa(tt.size), func(t *testing.T) {
|
|
- if tt.wantPanic {
|
|
- defer func() {
|
|
- msg := recover()
|
|
- if msg == nil {
|
|
- t.Error("no panic")
|
|
- }
|
|
- }()
|
|
- }
|
|
-
|
|
- for i := 0; i < 5; i++ {
|
|
- b := alloc.Get(tt.size)
|
|
- if len(b) != tt.size {
|
|
- t.Fatalf("buffer size, want %d, got %d", tt.size, len(b))
|
|
- }
|
|
- if cap(b) != tt.wantCap {
|
|
- t.Fatalf("buffer cap, want %d, got %d", tt.wantCap, cap(b))
|
|
- }
|
|
- alloc.Release(b)
|
|
- }
|
|
- })
|
|
- }
|
|
-}
|
|
-
|
|
-func Test_shard(t *testing.T) {
|
|
- tests := []struct {
|
|
- size int
|
|
- want int
|
|
- }{
|
|
- {-1, 0},
|
|
- {0, 0},
|
|
- {1, 0},
|
|
- {2, 1},
|
|
- {3, 2},
|
|
- {4, 2},
|
|
- {5, 3},
|
|
- {8, 3},
|
|
- {1023, 10},
|
|
- {1024, 10},
|
|
- {1025, 11},
|
|
- }
|
|
- for _, tt := range tests {
|
|
- t.Run(strconv.Itoa(tt.size), func(t *testing.T) {
|
|
- if got := shard(tt.size); got != tt.want {
|
|
- t.Errorf("shard() = %v, want %v", got, tt.want)
|
|
- }
|
|
- })
|
|
- }
|
|
-}
|
|
-
|
|
-func Benchmark_Allocator(b *testing.B) {
|
|
- allocator := NewAllocator()
|
|
-
|
|
- for l := 0; l <= 16; l += 4 {
|
|
- bufLen := 1 << l
|
|
- b.Run(fmt.Sprintf("length %d", bufLen), func(b *testing.B) {
|
|
- b.ReportAllocs()
|
|
- for i := 0; i < b.N; i++ {
|
|
- buf := allocator.Get(bufLen)
|
|
- allocator.Release(buf)
|
|
- }
|
|
- })
|
|
- }
|
|
-}
|
|
-
|
|
-func Benchmark_MakeByteSlice(b *testing.B) {
|
|
- for l := 0; l <= 8; l++ {
|
|
- bufLen := 1 << l
|
|
- b.Run(fmt.Sprintf("length %d", bufLen), func(b *testing.B) {
|
|
- b.ReportAllocs()
|
|
- for i := 0; i < b.N; i++ {
|
|
- _ = make([]byte, bufLen)
|
|
- }
|
|
- })
|
|
- }
|
|
-}
|
|
diff --git a/pkg/pool/msg_buf.go b/pkg/pool/msg_buf.go
|
|
index 83eea80..980a08b 100644
|
|
--- a/pkg/pool/msg_buf.go
|
|
+++ b/pkg/pool/msg_buf.go
|
|
@@ -31,9 +31,9 @@ const packBufSize = 4096
|
|
// PackBuffer packs the dns msg m to wire format.
|
|
// Callers should release the buf by calling ReleaseBuf after they have done
|
|
// with the wire []byte.
|
|
-func PackBuffer(m *dns.Msg) (wire, buf []byte, err error) {
|
|
+func PackBuffer(m *dns.Msg) (wire []byte, buf *[]byte, err error) {
|
|
buf = GetBuf(packBufSize)
|
|
- wire, err = m.PackBuffer(buf)
|
|
+ wire, err = m.PackBuffer(*buf)
|
|
if err != nil {
|
|
ReleaseBuf(buf)
|
|
return nil, nil, err
|
|
diff --git a/pkg/pool/msg_buf_test.go b/pkg/pool/msg_buf_test.go
|
|
index 0685864..02d7348 100644
|
|
--- a/pkg/pool/msg_buf_test.go
|
|
+++ b/pkg/pool/msg_buf_test.go
|
|
@@ -32,7 +32,7 @@ func TestPackBuffer_No_Allocation(t *testing.T) {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
- if cap(wire) != cap(buf) {
|
|
- t.Fatalf("wire and buf have different cap, wire %d, buf %d", cap(wire), cap(buf))
|
|
+ if cap(wire) != cap(*buf) {
|
|
+ t.Fatalf("wire and buf have different cap, wire %d, buf %d", cap(wire), cap(*buf))
|
|
}
|
|
}
|
|
diff --git a/pkg/server/udp.go b/pkg/server/udp.go
|
|
index 8980a08..c1c9aa9 100644
|
|
--- a/pkg/server/udp.go
|
|
+++ b/pkg/server/udp.go
|
|
@@ -73,15 +73,15 @@ func (s *UDPServer) ServeUDP(c *net.UDPConn) error {
|
|
}
|
|
|
|
for {
|
|
- n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(rb, ob)
|
|
+ n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(*rb, ob)
|
|
if err != nil {
|
|
return fmt.Errorf("unexpected read err: %w", err)
|
|
}
|
|
clientAddr := remoteAddr.Addr()
|
|
|
|
q := new(dns.Msg)
|
|
- if err := q.Unpack(rb[:n]); err != nil {
|
|
- s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", rb[:n]), zap.Stringer("from", remoteAddr))
|
|
+ if err := q.Unpack((*rb)[:n]); err != nil {
|
|
+ s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", (*rb)[:n]), zap.Stringer("from", remoteAddr))
|
|
continue
|
|
}
|
|
|
|
diff --git a/pkg/upstream/transport/dns_conn_test.go b/pkg/upstream/transport/dns_conn_test.go
|
|
index 8c49131..c677797 100644
|
|
--- a/pkg/upstream/transport/dns_conn_test.go
|
|
+++ b/pkg/upstream/transport/dns_conn_test.go
|
|
@@ -40,13 +40,13 @@ var (
|
|
c1, c2 := net.Pipe()
|
|
go func() {
|
|
for {
|
|
- m, _, readErr := dnsutils.ReadRawMsgFromTCP(c2)
|
|
+ m, readErr := dnsutils.ReadRawMsgFromTCP(c2)
|
|
if m != nil {
|
|
go func() {
|
|
defer pool.ReleaseBuf(m)
|
|
latency := time.Millisecond * time.Duration(rand.Intn(20))
|
|
time.Sleep(latency)
|
|
- _, _ = dnsutils.WriteRawMsgToTCP(c2, m)
|
|
+ _, _ = dnsutils.WriteRawMsgToTCP(c2, *m)
|
|
}()
|
|
}
|
|
if readErr != nil {
|
|
diff --git a/plugin/executable/cache/cache.go b/plugin/executable/cache/cache.go
|
|
index f67d740..4091b50 100644
|
|
--- a/plugin/executable/cache/cache.go
|
|
+++ b/plugin/executable/cache/cache.go
|
|
@@ -420,27 +420,27 @@ func (c *Cache) readDump(r io.Reader) (int, error) {
|
|
readBlock := func() error {
|
|
h := pool.GetBuf(8)
|
|
defer pool.ReleaseBuf(h)
|
|
- _, err := io.ReadFull(gr, h)
|
|
+ _, err := io.ReadFull(gr, *h)
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
return errReadHeaderEOF
|
|
}
|
|
return fmt.Errorf("failed to read block header, %w", err)
|
|
}
|
|
- u := binary.BigEndian.Uint64(h)
|
|
+ u := binary.BigEndian.Uint64(*h)
|
|
if u > dumpMaximumBlockLength {
|
|
return fmt.Errorf("invalid header, block length is big, %d", u)
|
|
}
|
|
|
|
b := pool.GetBuf(int(u))
|
|
defer pool.ReleaseBuf(b)
|
|
- _, err = io.ReadFull(gr, b)
|
|
+ _, err = io.ReadFull(gr, *b)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read block data, %w", err)
|
|
}
|
|
|
|
block := new(CacheDumpBlock)
|
|
- if err := proto.Unmarshal(b, block); err != nil {
|
|
+ if err := proto.Unmarshal(*b, block); err != nil {
|
|
return fmt.Errorf("failed to decode block data, %w", err)
|
|
}
|
|
|
|
--
|
|
2.34.8
|
|
|