luci-app-mosdns/mosdns/patches/123-add-plugin-rate_limiter.patch
2023-09-23 15:17:07 +08:00

298 lines
8.3 KiB
Diff

From 11436dd9cde412f83d1bfbd06b4163445c52bb12 Mon Sep 17 00:00:00 2001
From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com>
Date: Fri, 22 Sep 2023 20:55:49 +0800
Subject: [PATCH 7/9] add plugin rate_limiter
---
go.mod | 1 +
go.sum | 2 +
pkg/rate_limiter/rate_limiter.go | 145 ++++++++++++++++++
plugin/enabled_plugins.go | 1 +
.../executable/rate_limiter/rate_limiter.go | 85 ++++++++++
5 files changed, 234 insertions(+)
create mode 100644 pkg/rate_limiter/rate_limiter.go
create mode 100644 plugin/executable/rate_limiter/rate_limiter.go
diff --git a/go.mod b/go.mod
index 7c2b96a..aea0c99 100644
--- a/go.mod
+++ b/go.mod
@@ -63,6 +63,7 @@ require (
golang.org/x/crypto v0.13.0 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect
+ golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.13.0 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
diff --git a/go.sum b/go.sum
index dd20043..d2b393f 100644
--- a/go.sum
+++ b/go.sum
@@ -413,6 +413,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
+golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
+golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
diff --git a/pkg/rate_limiter/rate_limiter.go b/pkg/rate_limiter/rate_limiter.go
new file mode 100644
index 0000000..30fa516
--- /dev/null
+++ b/pkg/rate_limiter/rate_limiter.go
@@ -0,0 +1,175 @@
+package rate_limiter
+
+import (
+	"io"
+	"net/netip"
+	"sync"
+	"time"
+
+	"golang.org/x/time/rate"
+)
+
+// RateLimiter limits events per client address.
+type RateLimiter interface {
+	// Allow reports whether one more event from addr may happen now.
+	Allow(addr netip.Addr) bool
+	io.Closer
+}
+
+// limiter is the RateLimiter returned by NewRateLimiter. It keeps one
+// token-bucket limiter per masked client address and periodically drops
+// entries that have been idle for longer than the GC interval.
+type limiter struct {
+	limit rate.Limit
+	burst int
+	mask4 int // prefix length for IPv4 and IPv4-mapped IPv6 addresses
+	mask6 int // prefix length for other IPv6 addresses
+
+	closeOnce   sync.Once
+	closeNotify chan struct{} // closed by Close to stop gcLoop
+	m           sync.Mutex    // guards tables and every entry's lastSeen
+	tables      map[netip.Addr]*limiterEntry
+}
+
+// limiterEntry is the per-prefix state stored in limiter.tables.
+type limiterEntry struct {
+	l        *rate.Limiter
+	lastSeen time.Time // time of the latest Allow call for this prefix
+}
+
+// NewRateLimiter returns a RateLimiter that allows, per masked client
+// address, limit events per second on average with bursts of up to burst.
+//
+// limit and burst should be greater than zero.
+// If gcInterval is <= 0, it is set to the token refill time (burst/limit)
+// clamped to 2~10s. In that case, if the refill time is greater than 10s,
+// idle entries may be dropped before they refill, so the actual average qps
+// limit may be higher than expected.
+// If mask4/mask6 is zero or greater than 32/128, the default 32/48 is used.
+// If a mask is negative, 0 is used.
+// The limiter runs a background GC goroutine; call Close to stop it.
+func NewRateLimiter(limit rate.Limit, burst int, gcInterval time.Duration, mask4, mask6 int) RateLimiter {
+	if mask4 > 32 || mask4 == 0 {
+		mask4 = 32
+	}
+	if mask4 < 0 {
+		mask4 = 0
+	}
+
+	if mask6 > 128 || mask6 == 0 {
+		mask6 = 48
+	}
+	if mask6 < 0 {
+		mask6 = 0
+	}
+
+	if gcInterval <= 0 {
+		gcInterval = autoGcInterval(limit, burst)
+	}
+
+	l := &limiter{
+		limit:       limit,
+		burst:       burst,
+		mask4:       mask4,
+		mask6:       mask6,
+		closeNotify: make(chan struct{}),
+		tables:      make(map[netip.Addr]*limiterEntry),
+	}
+	go l.gcLoop(gcInterval)
+	return l
+}
+
+// autoGcInterval returns the token refill time burst/limit clamped to
+// [2s, 10s], or 2s if limit or burst is not positive.
+func autoGcInterval(limit rate.Limit, burst int) time.Duration {
+	if limit <= 0 || burst <= 0 {
+		return 2 * time.Second
+	}
+	refillSec := float64(burst) / float64(limit)
+	if refillSec < 2 {
+		refillSec = 2
+	}
+	if refillSec > 10 {
+		refillSec = 10
+	}
+	// Keep sub-second precision: truncating to whole seconds would make the
+	// interval shorter than the refill time, so GC could replace an entry
+	// with a fresh, full bucket before the old one had refilled.
+	return time.Duration(refillSec * float64(time.Second))
+}
+
+// Allow masks a to its configured prefix, so all addresses in one prefix
+// share a bucket, and reports whether one more event may happen now.
+func (l *limiter) Allow(a netip.Addr) bool {
+	a = l.applyMask(a)
+	now := time.Now()
+
+	l.m.Lock()
+	e, ok := l.tables[a]
+	if !ok {
+		e = &limiterEntry{l: rate.NewLimiter(l.limit, l.burst)}
+		l.tables[a] = e
+	}
+	e.lastSeen = now
+	clientLimiter := e.l
+	l.m.Unlock()
+
+	// rate.Limiter synchronizes internally, so the token check can run
+	// outside l.m and keep the critical section short.
+	return clientLimiter.AllowN(now, 1)
+}
+
+// Close stops the background GC goroutine. It always returns nil and is
+// safe to call more than once.
+func (l *limiter) Close() error {
+	l.closeOnce.Do(func() {
+		close(l.closeNotify)
+	})
+	return nil
+}
+
+// gcLoop calls doGc every gcInterval until Close is called.
+func (l *limiter) gcLoop(gcInterval time.Duration) {
+	ticker := time.NewTicker(gcInterval)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-l.closeNotify:
+			return
+		case now := <-ticker.C:
+			l.doGc(now, gcInterval)
+		}
+	}
+}
+
+// doGc deletes entries whose latest Allow call is more than gcInterval
+// before now.
+func (l *limiter) doGc(now time.Time, gcInterval time.Duration) {
+	l.m.Lock()
+	defer l.m.Unlock()
+
+	for a, e := range l.tables {
+		if now.Sub(e.lastSeen) > gcInterval {
+			delete(l.tables, a)
+		}
+	}
+}
+
+// applyMask truncates a to its network address: IPv4 and IPv4-mapped IPv6
+// addresses are unmapped and masked with mask4, all others with mask6.
+func (l *limiter) applyMask(a netip.Addr) netip.Addr {
+	// The masks were clamped to 0~32 / 0~128 in NewRateLimiter, so Prefix
+	// cannot fail here.
+	switch {
+	case a.Is4():
+		m, _ := a.Prefix(l.mask4)
+		return m.Addr()
+	case a.Is4In6():
+		m, _ := netip.AddrFrom4(a.As4()).Prefix(l.mask4)
+		return m.Addr()
+	default:
+		m, _ := a.Prefix(l.mask6)
+		return m.Addr()
+	}
+}
diff --git a/plugin/enabled_plugins.go b/plugin/enabled_plugins.go
index dfb311b..d72ed07 100644
--- a/plugin/enabled_plugins.go
+++ b/plugin/enabled_plugins.go
@@ -54,6 +54,7 @@ import (
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/metrics_collector"
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/nftset"
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/query_summary"
+ _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/rate_limiter"
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/redirect"
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/reverse_lookup"
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
diff --git a/plugin/executable/rate_limiter/rate_limiter.go b/plugin/executable/rate_limiter/rate_limiter.go
new file mode 100644
index 0000000..241f947
--- /dev/null
+++ b/plugin/executable/rate_limiter/rate_limiter.go
@@ -0,0 +1,106 @@
+/*
+ * Copyright (C) 2020-2022, IrineSistiana
+ *
+ * This file is part of mosdns.
+ *
+ * mosdns is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * mosdns is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package rate_limiter
+
+import (
+	"context"
+
+	"github.com/IrineSistiana/mosdns/v5/coremain"
+	"github.com/IrineSistiana/mosdns/v5/pkg/query_context"
+	"github.com/IrineSistiana/mosdns/v5/pkg/rate_limiter"
+	"github.com/IrineSistiana/mosdns/v5/pkg/utils"
+	"github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
+	"github.com/miekg/dns"
+	"golang.org/x/time/rate"
+)
+
+// PluginType is the name this plugin is registered under.
+const PluginType = "rate_limiter"
+
+func init() {
+	coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) })
+}
+
+// Args is the configuration of the rate_limiter plugin. Unset fields get
+// the defaults noted below (see Args.init).
+type Args struct {
+	Qps   float64 `yaml:"qps"`   // average queries per second per client prefix (default 20)
+	Burst int     `yaml:"burst"` // maximum burst size per client prefix (default 40)
+	Mask4 int     `yaml:"mask4"` // IPv4 prefix length used to group clients (default 32)
+	Mask6 int     `yaml:"mask6"` // IPv6 prefix length used to group clients (default 48)
+}
+
+// init fills unset fields with their default values.
+func (args *Args) init() {
+	utils.SetDefaultUnsignNum(&args.Qps, 20)
+	utils.SetDefaultUnsignNum(&args.Burst, 40)
+	utils.SetDefaultUnsignNum(&args.Mask4, 32)
+	utils.SetDefaultUnsignNum(&args.Mask6, 48)
+}
+
+var _ sequence.Executable = (*RateLimiter)(nil)
+
+// RateLimiter is a sequence executable that answers REFUSED to clients
+// whose query rate exceeds the configured limit.
+type RateLimiter struct {
+	l rate_limiter.RateLimiter
+}
+
+// Init is the coremain plugin constructor. args must be a *Args.
+func Init(_ *coremain.BP, args any) (any, error) {
+	return New(*(args.(*Args))), nil
+}
+
+// New creates a RateLimiter from args, filling in defaults for unset fields.
+// The GC interval of the underlying limiter is chosen automatically.
+func New(args Args) *RateLimiter {
+	args.init()
+	l := rate_limiter.NewRateLimiter(rate.Limit(args.Qps), args.Burst, 0, args.Mask4, args.Mask6)
+	return &RateLimiter{l: l}
+}
+
+// Close stops the background GC goroutine of the underlying limiter so that
+// it does not outlive the plugin.
+func (s *RateLimiter) Close() error {
+	return s.l.Close()
+}
+
+// Exec sets a REFUSED response on qCtx if the client address of the query
+// is over its limit. Queries without a valid client address are not limited.
+// It only sets the response and always returns nil.
+// NOTE(review): stopping further processing of a refused query is left to the
+// sequence configuration (e.g. a following response check) — confirm.
+func (s *RateLimiter) Exec(ctx context.Context, qCtx *query_context.Context) error {
+	clientAddr := qCtx.QueryMeta().ClientAddr
+	if clientAddr.IsValid() {
+		if !s.l.Allow(clientAddr) {
+			qCtx.SetResponse(refuse(qCtx.Q()))
+		}
+	}
+	return nil
+}
+
+// refuse builds a reply to q with RCODE REFUSED.
+func refuse(q *dns.Msg) *dns.Msg {
+	r := new(dns.Msg)
+	r.SetReply(q)
+	r.Rcode = dns.RcodeRefused
+	return r
+}
--
2.34.8