298 lines
8.3 KiB
Diff
298 lines
8.3 KiB
Diff
From 11436dd9cde412f83d1bfbd06b4163445c52bb12 Mon Sep 17 00:00:00 2001
|
|
From: Irine Sistiana <49315432+IrineSistiana@users.noreply.github.com>
|
|
Date: Fri, 22 Sep 2023 20:55:49 +0800
|
|
Subject: [PATCH 7/9] add plugin rate_limiter
|
|
|
|
---
|
|
go.mod | 1 +
|
|
go.sum | 2 +
|
|
pkg/rate_limiter/rate_limiter.go | 145 ++++++++++++++++++
|
|
plugin/enabled_plugins.go | 1 +
|
|
.../executable/rate_limiter/rate_limiter.go | 85 ++++++++++
|
|
5 files changed, 234 insertions(+)
|
|
create mode 100644 pkg/rate_limiter/rate_limiter.go
|
|
create mode 100644 plugin/executable/rate_limiter/rate_limiter.go
|
|
|
|
diff --git a/go.mod b/go.mod
|
|
index 7c2b96a..aea0c99 100644
|
|
--- a/go.mod
|
|
+++ b/go.mod
|
|
@@ -63,6 +63,7 @@ require (
|
|
golang.org/x/crypto v0.13.0 // indirect
|
|
golang.org/x/mod v0.12.0 // indirect
|
|
golang.org/x/text v0.13.0 // indirect
|
|
+ golang.org/x/time v0.3.0 // indirect
|
|
golang.org/x/tools v0.13.0 // indirect
|
|
gopkg.in/ini.v1 v1.67.0 // indirect
|
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
|
diff --git a/go.sum b/go.sum
|
|
index dd20043..d2b393f 100644
|
|
--- a/go.sum
|
|
+++ b/go.sum
|
|
@@ -413,6 +413,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
|
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
|
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
|
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
|
+golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
|
+golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
|
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
|
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
|
diff --git a/pkg/rate_limiter/rate_limiter.go b/pkg/rate_limiter/rate_limiter.go
|
|
new file mode 100644
|
|
index 0000000..30fa516
|
|
--- /dev/null
|
|
+++ b/pkg/rate_limiter/rate_limiter.go
|
|
@@ -0,0 +1,145 @@
|
|
+package rate_limiter
|
|
+
|
|
+import (
|
|
+ "io"
|
|
+ "net/netip"
|
|
+ "sync"
|
|
+ "time"
|
|
+
|
|
+ "golang.org/x/time/rate"
|
|
+)
|
|
+
|
|
// RateLimiter decides, per client address, whether a request may proceed.
// It also carries a Close method through the embedded io.Closer.
type RateLimiter interface {
	io.Closer

	// Allow reports whether a request coming from addr is permitted now.
	Allow(addr netip.Addr) bool
}
|
|
+
|
|
+type limiter struct {
|
|
+ limit rate.Limit
|
|
+ burst int
|
|
+ mask4 int
|
|
+ mask6 int
|
|
+
|
|
+ closeOnce sync.Once
|
|
+ closeNotify chan struct{}
|
|
+ m sync.Mutex
|
|
+ tables map[netip.Addr]*limiterEntry
|
|
+}
|
|
+
|
|
+type limiterEntry struct {
|
|
+ l *rate.Limiter
|
|
+ lastSeen time.Time
|
|
+ sync.Once
|
|
+}
|
|
+
|
|
+// limit and burst should be greater than zero.
|
|
+// If gcInterval is <= 0, it will be automatically chosen between 2~10s.
|
|
+// In this case, if the token refill time (burst/limit) is greater than 10s,
|
|
+// the actual average qps limit may be higher than expected.
|
|
+// If mask is zero or greater than 32/128. The default is 32/48.
|
|
+// If mask is negative, the masks will be 0.
|
|
+func NewRateLimiter(limit rate.Limit, burst int, gcInterval time.Duration, mask4, mask6 int) RateLimiter {
|
|
+ if mask4 > 32 || mask4 == 0 {
|
|
+ mask4 = 32
|
|
+ }
|
|
+ if mask4 < 0 {
|
|
+ mask4 = 0
|
|
+ }
|
|
+
|
|
+ if mask6 > 128 || mask6 == 0 {
|
|
+ mask6 = 48
|
|
+ }
|
|
+ if mask6 < 0 {
|
|
+ mask6 = 0
|
|
+ }
|
|
+
|
|
+ if gcInterval <= 0 {
|
|
+ if limit <= 0 || burst <= 0 {
|
|
+ gcInterval = time.Second * 2
|
|
+ } else {
|
|
+ refillSec := float64(burst) / float64(limit)
|
|
+ if refillSec < 2 {
|
|
+ refillSec = 2
|
|
+ }
|
|
+ if refillSec > 10 {
|
|
+ refillSec = 10
|
|
+ }
|
|
+ gcInterval = time.Duration(refillSec) * time.Second
|
|
+ }
|
|
+ }
|
|
+
|
|
+ l := &limiter{
|
|
+ limit: limit,
|
|
+ burst: burst,
|
|
+ mask4: mask4,
|
|
+ mask6: mask6,
|
|
+ closeNotify: make(chan struct{}),
|
|
+ tables: make(map[netip.Addr]*limiterEntry),
|
|
+ }
|
|
+ go l.gcLoop(gcInterval)
|
|
+ return l
|
|
+}
|
|
+
|
|
+func (l *limiter) Allow(a netip.Addr) bool {
|
|
+ a = l.applyMask(a)
|
|
+ now := time.Now()
|
|
+ l.m.Lock()
|
|
+ e, ok := l.tables[a]
|
|
+ if !ok {
|
|
+ e = &limiterEntry{
|
|
+ l: rate.NewLimiter(l.limit, l.burst),
|
|
+ lastSeen: now,
|
|
+ }
|
|
+ l.tables[a] = e
|
|
+ }
|
|
+ e.lastSeen = now
|
|
+ clientLimiter := e.l
|
|
+ l.m.Unlock()
|
|
+ return clientLimiter.AllowN(now, 1)
|
|
+}
|
|
+
|
|
+func (l *limiter) Close() error {
|
|
+ l.closeOnce.Do(func() {
|
|
+ close(l.closeNotify)
|
|
+ })
|
|
+ return nil
|
|
+}
|
|
+
|
|
+func (l *limiter) gcLoop(gcInterval time.Duration) {
|
|
+ ticker := time.NewTicker(gcInterval)
|
|
+ defer ticker.Stop()
|
|
+
|
|
+ for {
|
|
+ select {
|
|
+ case <-l.closeNotify:
|
|
+ return
|
|
+ case now := <-ticker.C:
|
|
+ l.doGc(now, gcInterval)
|
|
+ }
|
|
+ }
|
|
+}
|
|
+
|
|
+func (l *limiter) doGc(now time.Time, gcInterval time.Duration) {
|
|
+ l.m.Lock()
|
|
+ defer l.m.Unlock()
|
|
+
|
|
+ for a, e := range l.tables {
|
|
+ if now.Sub(e.lastSeen) > gcInterval {
|
|
+ delete(l.tables, a)
|
|
+ }
|
|
+ }
|
|
+}
|
|
+
|
|
+func (l *limiter) applyMask(a netip.Addr) netip.Addr {
|
|
+ switch {
|
|
+ case a.Is4():
|
|
+ m, _ := a.Prefix(l.mask4)
|
|
+ return m.Addr()
|
|
+ case a.Is4In6():
|
|
+ m, _ := netip.AddrFrom4(a.As4()).Prefix(l.mask4)
|
|
+ return m.Addr()
|
|
+ default:
|
|
+ m, _ := a.Prefix(l.mask6)
|
|
+ return m.Addr()
|
|
+ }
|
|
+}
|
|
diff --git a/plugin/enabled_plugins.go b/plugin/enabled_plugins.go
|
|
index dfb311b..d72ed07 100644
|
|
--- a/plugin/enabled_plugins.go
|
|
+++ b/plugin/enabled_plugins.go
|
|
@@ -54,6 +54,7 @@ import (
|
|
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/metrics_collector"
|
|
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/nftset"
|
|
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/query_summary"
|
|
+ _ "github.com/IrineSistiana/mosdns/v5/plugin/executable/rate_limiter"
|
|
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/redirect"
|
|
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/reverse_lookup"
|
|
_ "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
|
|
diff --git a/plugin/executable/rate_limiter/rate_limiter.go b/plugin/executable/rate_limiter/rate_limiter.go
|
|
new file mode 100644
|
|
index 0000000..241f947
|
|
--- /dev/null
|
|
+++ b/plugin/executable/rate_limiter/rate_limiter.go
|
|
@@ -0,0 +1,85 @@
|
|
+/*
|
|
+ * Copyright (C) 2020-2022, IrineSistiana
|
|
+ *
|
|
+ * This file is part of mosdns.
|
|
+ *
|
|
+ * mosdns is free software: you can redistribute it and/or modify
|
|
+ * it under the terms of the GNU General Public License as published by
|
|
+ * the Free Software Foundation, either version 3 of the License, or
|
|
+ * (at your option) any later version.
|
|
+ *
|
|
+ * mosdns is distributed in the hope that it will be useful,
|
|
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
+ * GNU General Public License for more details.
|
|
+ *
|
|
+ * You should have received a copy of the GNU General Public License
|
|
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
+ */
|
|
+
|
|
+package rate_limiter
|
|
+
|
|
+import (
|
|
+ "context"
|
|
+
|
|
+ "github.com/IrineSistiana/mosdns/v5/coremain"
|
|
+ "github.com/IrineSistiana/mosdns/v5/pkg/query_context"
|
|
+ "github.com/IrineSistiana/mosdns/v5/pkg/rate_limiter"
|
|
+ "github.com/IrineSistiana/mosdns/v5/pkg/utils"
|
|
+ "github.com/IrineSistiana/mosdns/v5/plugin/executable/sequence"
|
|
+ "github.com/miekg/dns"
|
|
+ "golang.org/x/time/rate"
|
|
+)
|
|
+
|
|
+const PluginType = "rate_limiter"
|
|
+
|
|
+func init() {
|
|
+ coremain.RegNewPluginFunc(PluginType, Init, func() any { return new(Args) })
|
|
+}
|
|
+
|
|
// Args is the YAML configuration of the rate_limiter plugin.
type Args struct {
	// Qps is converted to the rate.Limit of each client bucket.
	Qps float64 `yaml:"qps"`
	// Burst is the burst size of each client bucket.
	Burst int `yaml:"burst"`
	// Mask4 is the prefix length used to group IPv4 client addresses.
	Mask4 int `yaml:"mask4"`
	// Mask6 is the prefix length used to group IPv6 client addresses.
	Mask6 int `yaml:"mask6"`
}
|
|
+
|
|
+func (args *Args) init() {
|
|
+ utils.SetDefaultUnsignNum(&args.Qps, 20)
|
|
+ utils.SetDefaultUnsignNum(&args.Burst, 40)
|
|
+ utils.SetDefaultUnsignNum(&args.Mask4, 32)
|
|
+ utils.SetDefaultUnsignNum(&args.Mask4, 48)
|
|
+}
|
|
+
|
|
+var _ sequence.Executable = (*RateLimiter)(nil)
|
|
+
|
|
+type RateLimiter struct {
|
|
+ l rate_limiter.RateLimiter
|
|
+}
|
|
+
|
|
+func Init(_ *coremain.BP, args any) (any, error) {
|
|
+ return New(*(args.(*Args))), nil
|
|
+}
|
|
+
|
|
+func New(args Args) *RateLimiter {
|
|
+ args.init()
|
|
+ l := rate_limiter.NewRateLimiter(rate.Limit(args.Qps), args.Burst, 0, args.Mask4, args.Mask6)
|
|
+ return &RateLimiter{l: l}
|
|
+}
|
|
+
|
|
+func (s *RateLimiter) Exec(ctx context.Context, qCtx *query_context.Context) error {
|
|
+ clientAddr := qCtx.QueryMeta().ClientAddr
|
|
+ if clientAddr.IsValid() {
|
|
+ if !s.l.Allow(clientAddr) {
|
|
+ qCtx.SetResponse(refuse(qCtx.Q()))
|
|
+ }
|
|
+ }
|
|
+ return nil
|
|
+}
|
|
+
|
|
+func refuse(q *dns.Msg) *dns.Msg {
|
|
+ r := new(dns.Msg)
|
|
+ r.SetReply(q)
|
|
+ r.Rcode = dns.RcodeRefused
|
|
+ return r
|
|
+}
|
|
--
|
|
2.34.8
|
|
|