luci-app-mosdns/mosdns/patches/103-use-go-bytes-pool.patch
2023-09-18 14:59:04 +08:00

476 lines
13 KiB
Diff

From 9bb813bf50ba0eed3a35acc8f41764613c11687d Mon Sep 17 00:00:00 2001
From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com>
Date: Sun, 17 Sep 2023 08:29:31 +0800
Subject: [PATCH 04/10] use go-bytes-pool
---
go.mod | 1 +
go.sum | 2 +
pkg/dnsutils/net_io.go | 45 +++++----
pkg/pool/allocator.go | 78 +---------------
pkg/pool/allocator_test.go | 119 ------------------------
pkg/pool/msg_buf.go | 4 +-
pkg/pool/msg_buf_test.go | 4 +-
pkg/server/udp.go | 6 +-
pkg/upstream/transport/dns_conn_test.go | 4 +-
plugin/executable/cache/cache.go | 8 +-
10 files changed, 43 insertions(+), 228 deletions(-)
delete mode 100644 pkg/pool/allocator_test.go
diff --git a/go.mod b/go.mod
index b56078b..2c359e3 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.21
toolchain go1.21.1
require (
+ github.com/IrineSistiana/go-bytes-pool v0.0.0-20230419012903-2f1f26674686
github.com/go-chi/chi/v5 v5.0.10
github.com/google/nftables v0.1.0
github.com/kardianos/service v1.2.2
diff --git a/go.sum b/go.sum
index 633f864..d4bbaa2 100644
--- a/go.sum
+++ b/go.sum
@@ -38,6 +38,8 @@ cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3f
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
+github.com/IrineSistiana/go-bytes-pool v0.0.0-20230419012903-2f1f26674686 h1:5R32cCep3VUDTKf3aurFKfgbvg+RScuBmZsw/DyyXco=
+github.com/IrineSistiana/go-bytes-pool v0.0.0-20230419012903-2f1f26674686/go.mod h1:pQ/FSsWSNYmNdgIKmulKlmVC/R2PEpq2vIEi3J9IijI=
github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a h1:GQdh/h0q0ni3L//CXusyk+7QdhBL289vdNaes1WKkHI=
github.com/IrineSistiana/ipset v0.5.1-0.20220703061533-6e0fc3b04c0a/go.mod h1:rYF5DQLRGGoQ8ZSWeK+6eX5amAuPqwFkWjhQlEITGJQ=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
diff --git a/pkg/dnsutils/net_io.go b/pkg/dnsutils/net_io.go
index 1b8a7b1..8fc769b 100644
--- a/pkg/dnsutils/net_io.go
+++ b/pkg/dnsutils/net_io.go
@@ -23,9 +23,10 @@ import (
"encoding/binary"
"errors"
"fmt"
+ "io"
+
"github.com/IrineSistiana/mosdns/v5/pkg/pool"
"github.com/miekg/dns"
- "io"
)
var (
@@ -35,45 +36,43 @@ var (
// ReadRawMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed
// with a two byte length field).
// n represents how many bytes are read from c.
-// The returned the []byte should be released by pool.ReleaseBuf.
-func ReadRawMsgFromTCP(c io.Reader) ([]byte, int, error) {
- n := 0
+// The returned *[]byte should be released by pool.ReleaseBuf.
+func ReadRawMsgFromTCP(c io.Reader) (*[]byte, error) {
h := pool.GetBuf(2)
defer pool.ReleaseBuf(h)
- nh, err := io.ReadFull(c, h)
- n += nh
+ _, err := io.ReadFull(c, *h)
+
if err != nil {
- return nil, n, err
+ return nil, err
}
// dns length
- length := binary.BigEndian.Uint16(h)
+ length := binary.BigEndian.Uint16(*h)
if length == 0 {
- return nil, 0, errZeroLenMsg
+ return nil, errZeroLenMsg
}
- buf := pool.GetBuf(int(length))
- nm, err := io.ReadFull(c, buf)
- n += nm
+ b := pool.GetBuf(int(length))
+ _, err = io.ReadFull(c, *b)
if err != nil {
- pool.ReleaseBuf(buf)
- return nil, n, err
+ pool.ReleaseBuf(b)
+ return nil, err
}
- return buf, n, nil
+ return b, nil
}
// ReadMsgFromTCP reads msg from c in RFC 1035 format (msg is prefixed
// with a two byte length field).
// n represents how many bytes are read from c.
func ReadMsgFromTCP(c io.Reader) (*dns.Msg, int, error) {
- b, n, err := ReadRawMsgFromTCP(c)
+ b, err := ReadRawMsgFromTCP(c)
if err != nil {
return nil, 0, err
}
defer pool.ReleaseBuf(b)
- m, err := unpackMsgWithDetailedErr(b)
- return m, n, err
+ m, err := unpackMsgWithDetailedErr(*b)
+ return m, len(*b) + 2, err
}
// WriteMsgToTCP packs and writes m to c in RFC 1035 format.
@@ -96,9 +95,9 @@ func WriteRawMsgToTCP(c io.Writer, b []byte) (n int, err error) {
buf := pool.GetBuf(len(b) + 2)
defer pool.ReleaseBuf(buf)
- binary.BigEndian.PutUint16(buf[:2], uint16(len(b)))
- copy(buf[2:], b)
- return c.Write(buf)
+ binary.BigEndian.PutUint16((*buf)[:2], uint16(len(b)))
+ copy((*buf)[2:], b)
+ return c.Write((*buf))
}
func WriteMsgToUDP(c io.Writer, m *dns.Msg) (int, error) {
@@ -118,12 +117,12 @@ func ReadMsgFromUDP(c io.Reader, bufSize int) (*dns.Msg, int, error) {
b := pool.GetBuf(bufSize)
defer pool.ReleaseBuf(b)
- n, err := c.Read(b)
+ n, err := c.Read(*b)
if err != nil {
return nil, n, err
}
- m, err := unpackMsgWithDetailedErr(b[:n])
+ m, err := unpackMsgWithDetailedErr((*b)[:n])
return m, n, err
}
diff --git a/pkg/pool/allocator.go b/pkg/pool/allocator.go
index 0bfc0bc..d8ac4ef 100644
--- a/pkg/pool/allocator.go
+++ b/pkg/pool/allocator.go
@@ -20,79 +20,11 @@
package pool
import (
- "fmt"
- "math"
- "math/bits"
- "sync"
+ bytesPool "github.com/IrineSistiana/go-bytes-pool"
)
// defaultBufPool is an Allocator that has a maximum capacity.
-var defaultBufPool = NewAllocator()
-
-// GetBuf returns a []byte from pool with most appropriate cap.
-// It panics if size < 0.
-func GetBuf(size int) []byte {
- return defaultBufPool.Get(size)
-}
-
-// ReleaseBuf puts the buf to the pool.
-func ReleaseBuf(b []byte) {
- defaultBufPool.Release(b)
-}
-
-type Allocator struct {
- buffers []sync.Pool
-}
-
-// NewAllocator initiates a []byte Allocator.
-// The waste(memory fragmentation) of space allocation is guaranteed to be
-// no more than 50%.
-func NewAllocator() *Allocator {
- alloc := &Allocator{
- buffers: make([]sync.Pool, bits.UintSize+1),
- }
-
- for i := range alloc.buffers {
- var bufSize uint
- if i == bits.UintSize {
- bufSize = math.MaxUint
- } else {
- bufSize = 1 << i
- }
- alloc.buffers[i].New = func() any {
- b := make([]byte, bufSize)
- return &b
- }
- }
- return alloc
-}
-
-// Get returns a []byte from pool with most appropriate cap
-func (alloc *Allocator) Get(size int) []byte {
- if size < 0 {
- panic(fmt.Sprintf("invalid slice size %d", size))
- }
-
- i := shard(size)
- v := alloc.buffers[i].Get()
- buf := v.(*[]byte)
- return (*buf)[0:size]
-}
-
-// Release releases the buf to the allocatorL.
-func (alloc *Allocator) Release(buf []byte) {
- c := cap(buf)
- i := shard(c)
- if c == 0 || c != 1<<i {
- panic("unexpected cap size")
- }
- alloc.buffers[i].Put(&buf)
-}
-
-// shard returns the shard index that is suitable for the size.
-func shard(size int) int {
- if size <= 1 {
- return 0
- }
- return bits.Len64(uint64(size - 1))
-}
+var (
+ GetBuf = bytesPool.Get
+ ReleaseBuf = bytesPool.Release
+)
diff --git a/pkg/pool/allocator_test.go b/pkg/pool/allocator_test.go
deleted file mode 100644
index 3f6d2a9..0000000
--- a/pkg/pool/allocator_test.go
+++ /dev/null
@@ -1,119 +0,0 @@
-/*
- * Copyright (C) 2020-2022, IrineSistiana
- *
- * This file is part of mosdns.
- *
- * mosdns is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * mosdns is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <https://www.gnu.org/licenses/>.
- */
-
-package pool
-
-import (
- "fmt"
- "strconv"
- "testing"
-)
-
-func TestAllocator_Get(t *testing.T) {
- alloc := NewAllocator()
- tests := []struct {
- size int
- wantCap int
- wantPanic bool
- }{
- {-1, 0, true}, // invalid
- {0, 1, false},
- {1, 1, false},
- {2, 2, false},
- {12, 16, false},
- {256, 256, false},
- {257, 512, false},
- }
- for _, tt := range tests {
- t.Run(strconv.Itoa(tt.size), func(t *testing.T) {
- if tt.wantPanic {
- defer func() {
- msg := recover()
- if msg == nil {
- t.Error("no panic")
- }
- }()
- }
-
- for i := 0; i < 5; i++ {
- b := alloc.Get(tt.size)
- if len(b) != tt.size {
- t.Fatalf("buffer size, want %d, got %d", tt.size, len(b))
- }
- if cap(b) != tt.wantCap {
- t.Fatalf("buffer cap, want %d, got %d", tt.wantCap, cap(b))
- }
- alloc.Release(b)
- }
- })
- }
-}
-
-func Test_shard(t *testing.T) {
- tests := []struct {
- size int
- want int
- }{
- {-1, 0},
- {0, 0},
- {1, 0},
- {2, 1},
- {3, 2},
- {4, 2},
- {5, 3},
- {8, 3},
- {1023, 10},
- {1024, 10},
- {1025, 11},
- }
- for _, tt := range tests {
- t.Run(strconv.Itoa(tt.size), func(t *testing.T) {
- if got := shard(tt.size); got != tt.want {
- t.Errorf("shard() = %v, want %v", got, tt.want)
- }
- })
- }
-}
-
-func Benchmark_Allocator(b *testing.B) {
- allocator := NewAllocator()
-
- for l := 0; l <= 16; l += 4 {
- bufLen := 1 << l
- b.Run(fmt.Sprintf("length %d", bufLen), func(b *testing.B) {
- b.ReportAllocs()
- for i := 0; i < b.N; i++ {
- buf := allocator.Get(bufLen)
- allocator.Release(buf)
- }
- })
- }
-}
-
-func Benchmark_MakeByteSlice(b *testing.B) {
- for l := 0; l <= 8; l++ {
- bufLen := 1 << l
- b.Run(fmt.Sprintf("length %d", bufLen), func(b *testing.B) {
- b.ReportAllocs()
- for i := 0; i < b.N; i++ {
- _ = make([]byte, bufLen)
- }
- })
- }
-}
diff --git a/pkg/pool/msg_buf.go b/pkg/pool/msg_buf.go
index 83eea80..980a08b 100644
--- a/pkg/pool/msg_buf.go
+++ b/pkg/pool/msg_buf.go
@@ -31,9 +31,9 @@ const packBufSize = 4096
// PackBuffer packs the dns msg m to wire format.
// Callers should release the buf by calling ReleaseBuf after they have done
// with the wire []byte.
-func PackBuffer(m *dns.Msg) (wire, buf []byte, err error) {
+func PackBuffer(m *dns.Msg) (wire []byte, buf *[]byte, err error) {
buf = GetBuf(packBufSize)
- wire, err = m.PackBuffer(buf)
+ wire, err = m.PackBuffer(*buf)
if err != nil {
ReleaseBuf(buf)
return nil, nil, err
diff --git a/pkg/pool/msg_buf_test.go b/pkg/pool/msg_buf_test.go
index 0685864..02d7348 100644
--- a/pkg/pool/msg_buf_test.go
+++ b/pkg/pool/msg_buf_test.go
@@ -32,7 +32,7 @@ func TestPackBuffer_No_Allocation(t *testing.T) {
t.Fatal(err)
}
- if cap(wire) != cap(buf) {
- t.Fatalf("wire and buf have different cap, wire %d, buf %d", cap(wire), cap(buf))
+ if cap(wire) != cap(*buf) {
+ t.Fatalf("wire and buf have different cap, wire %d, buf %d", cap(wire), cap(*buf))
}
}
diff --git a/pkg/server/udp.go b/pkg/server/udp.go
index 8980a08..c1c9aa9 100644
--- a/pkg/server/udp.go
+++ b/pkg/server/udp.go
@@ -73,15 +73,15 @@ func (s *UDPServer) ServeUDP(c *net.UDPConn) error {
}
for {
- n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(rb, ob)
+ n, oobn, _, remoteAddr, err := c.ReadMsgUDPAddrPort(*rb, ob)
if err != nil {
return fmt.Errorf("unexpected read err: %w", err)
}
clientAddr := remoteAddr.Addr()
q := new(dns.Msg)
- if err := q.Unpack(rb[:n]); err != nil {
- s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", rb[:n]), zap.Stringer("from", remoteAddr))
+ if err := q.Unpack((*rb)[:n]); err != nil {
+ s.opts.Logger.Warn("invalid msg", zap.Error(err), zap.Binary("msg", (*rb)[:n]), zap.Stringer("from", remoteAddr))
continue
}
diff --git a/pkg/upstream/transport/dns_conn_test.go b/pkg/upstream/transport/dns_conn_test.go
index 8c49131..c677797 100644
--- a/pkg/upstream/transport/dns_conn_test.go
+++ b/pkg/upstream/transport/dns_conn_test.go
@@ -40,13 +40,13 @@ var (
c1, c2 := net.Pipe()
go func() {
for {
- m, _, readErr := dnsutils.ReadRawMsgFromTCP(c2)
+ m, readErr := dnsutils.ReadRawMsgFromTCP(c2)
if m != nil {
go func() {
defer pool.ReleaseBuf(m)
latency := time.Millisecond * time.Duration(rand.Intn(20))
time.Sleep(latency)
- _, _ = dnsutils.WriteRawMsgToTCP(c2, m)
+ _, _ = dnsutils.WriteRawMsgToTCP(c2, *m)
}()
}
if readErr != nil {
diff --git a/plugin/executable/cache/cache.go b/plugin/executable/cache/cache.go
index f67d740..4091b50 100644
--- a/plugin/executable/cache/cache.go
+++ b/plugin/executable/cache/cache.go
@@ -420,27 +420,27 @@ func (c *Cache) readDump(r io.Reader) (int, error) {
readBlock := func() error {
h := pool.GetBuf(8)
defer pool.ReleaseBuf(h)
- _, err := io.ReadFull(gr, h)
+ _, err := io.ReadFull(gr, *h)
if err != nil {
if errors.Is(err, io.EOF) {
return errReadHeaderEOF
}
return fmt.Errorf("failed to read block header, %w", err)
}
- u := binary.BigEndian.Uint64(h)
+ u := binary.BigEndian.Uint64(*h)
if u > dumpMaximumBlockLength {
return fmt.Errorf("invalid header, block length is big, %d", u)
}
b := pool.GetBuf(int(u))
defer pool.ReleaseBuf(b)
- _, err = io.ReadFull(gr, b)
+ _, err = io.ReadFull(gr, *b)
if err != nil {
return fmt.Errorf("failed to read block data, %w", err)
}
block := new(CacheDumpBlock)
- if err := proto.Unmarshal(b, block); err != nil {
+ if err := proto.Unmarshal(*b, block); err != nil {
return fmt.Errorf("failed to decode block data, %w", err)
}
--
2.34.8