Skip to content

Commit

Permalink
Merge pull request #215 from Dash-Industry-Forum/feat/whitelist-ips
Browse files Browse the repository at this point in the history
feat: whitelistblocks option
  • Loading branch information
tobbee authored Oct 1, 2024
2 parents 70363a7 + e6b3b5a commit a1daf85
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 29 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Short HEVC + AC-3 test content
- Generation of CMAF ingest streams with a REST-based API
- New program `cmaf-ingest-receiver` that can receive one or more CMAF ingest streams
- New option `--whitelistblocks` for unlimited number of requests to some CIDR blocks

### Fixed

Expand Down
12 changes: 8 additions & 4 deletions cmd/livesim2/app/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ type ServerConfig struct {
LiveWindowS int `json:"livewindowS"`
TimeoutS int `json:"timeoutS"`
MaxRequests int `json:"maxrequests"`
VodRoot string `json:"vodroot"`
// WhiteListBlocks is a comma-separated list of CIDR blocks that are not rate limited
WhiteListBlocks string `json:"whitelistblocks"`
VodRoot string `json:"vodroot"`
// RepDataRoot is the root directory for representation metadata
RepDataRoot string `json:"repdataroot"`
// WriteRepData is true if representation metadata should be written (will override existing metadata)
Expand Down Expand Up @@ -71,9 +73,10 @@ var DefaultConfig = ServerConfig{
ReqLimitInt: defaultReqIntervalS,
VodRoot: "./vod",
// MetaRoot + means follow VodRoot, _ means no metadata
RepDataRoot: "+",
WriteRepData: false,
PlayURL: defaultPlayURL,
RepDataRoot: "+",
WriteRepData: false,
PlayURL: defaultPlayURL,
WhiteListBlocks: "",
}

type Config struct {
Expand Down Expand Up @@ -112,6 +115,7 @@ func LoadConfig(args []string, cwd string) (*ServerConfig, error) {
f.String("vodroot", k.String("vodroot"), "VoD root directory")
f.String("repdataroot", k.String("repdataroot"), `Representation metadata root directory. "+" copies vodroot value. "-" disables usage.`)
f.Bool("writerepdata", k.Bool("writerepdata"), "Write representation metadata if not present")
f.String("whitelistblocks", k.String("whitelistblocks"), "comma-separated list of CIDR blocks that are not rate limited")
f.Int("timeout", k.Int("timeoutS"), "timeout for all requests (seconds)")
f.Int("maxrequests", k.Int("maxrequests"), "max nr of request per IP address per 24 hours")
f.String("reqlimitlog", k.String("reqlimitlog"), "path to request limit log file (only written if maxrequests > 0)")
Expand Down
74 changes: 53 additions & 21 deletions cmd/livesim2/app/ipreqlimit.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,50 @@ import (
"net"
"net/http"
"os"
"strings"
"sync"
"time"
)

// IPRequestLimiter limits the number of requests per interval
type IPRequestLimiter struct {
MaxNrRequests int `json:"maxNrRequests"`
Interval time.Duration `json:"interval"`
ResetTime time.Time `json:"resetTime"`
Counters map[string]int `json:"counters"`
logFile string `json:"-"`
mux sync.Mutex `json:"-"`
MaxNrRequests int `json:"maxNrRequests"`
Interval time.Duration `json:"interval"`
ResetTime time.Time `json:"resetTime"`
Counters map[string]int `json:"counters"`
WhiteListBlocks string `json:"whiteListBlocks"`
logFile string `json:"-"`
mux sync.Mutex `json:"-"`
cidrBlocks []*net.IPNet `json:"-"`
}

// NewIPRequestLimiter returns a new IPRequestLimiter with maxNrRequests per interval starting now.
// If logFile is not empty, the IPRequestLimiter is dumped to the logFile at the end of each interval.
func NewIPRequestLimiter(maxNrRequests int, interval time.Duration, start time.Time, logFile string) *IPRequestLimiter {
return &IPRequestLimiter{
MaxNrRequests: maxNrRequests,
Interval: interval,
ResetTime: start,
Counters: make(map[string]int),
logFile: logFile,
mux: sync.Mutex{},
func NewIPRequestLimiter(maxNrRequests int, interval time.Duration, start time.Time,
whiteListBlocks string, logFile string) (*IPRequestLimiter, error) {
var cidrBlocks []*net.IPNet
if whiteListBlocks != "" {
blocks := strings.Split(whiteListBlocks, ",")
cidrBlocks = make([]*net.IPNet, 0, len(blocks))
for _, cidrBlock := range blocks {
_, ciBlock, err := net.ParseCIDR(cidrBlock)
if err != nil {
return nil, fmt.Errorf("invalid CIDR block %s: %w", cidrBlock, err)
}
cidrBlocks = append(cidrBlocks, ciBlock)
}
}

return &IPRequestLimiter{
MaxNrRequests: maxNrRequests,
Interval: interval,
ResetTime: start,
Counters: make(map[string]int),
WhiteListBlocks: whiteListBlocks,
logFile: logFile,
mux: sync.Mutex{},
cidrBlocks: cidrBlocks,
}, nil
}

// NewLimiterMiddleware returns a middleware that limits the number of requests per IP address per interval
Expand All @@ -51,25 +70,26 @@ func NewLimiterMiddleware(hdrName string, reqLimiter *IPRequestLimiter) func(nex
return
}
now := time.Now()
count, ok := reqLimiter.Inc(now, ip)
count, maxNr, ok := reqLimiter.Inc(now, ip)
if !ok {
if hdrName != "" {
w.Header().Set(hdrName, fmt.Sprintf("%d (max %d)", count, reqLimiter.MaxNrRequests))
w.Header().Set(hdrName, fmt.Sprintf("%d (max %d)", count, maxNr))
}
w.WriteHeader(http.StatusTooManyRequests)
return
}
if hdrName != "" {
w.Header().Set(hdrName, fmt.Sprintf("%d (max %d)", count, reqLimiter.MaxNrRequests))
w.Header().Set(hdrName, fmt.Sprintf("%d (max %d)", count, maxNr))
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
}

// Inc increments the number of requests and returns number and ok value
func (il *IPRequestLimiter) Inc(now time.Time, ip string) (int, bool) {
// Inc increments the number of requests and returns number and ok value.
// If the address is in a white list block, maxNr is set to -1.
func (il *IPRequestLimiter) Inc(now time.Time, ip string) (nr, maxNr int, ok bool) {
il.mux.Lock()
defer il.mux.Unlock()
if now.Sub(il.ResetTime) > il.Interval {
Expand All @@ -80,8 +100,20 @@ func (il *IPRequestLimiter) Inc(now time.Time, ip string) (int, bool) {
il.ResetTime = now
}
il.Counters[ip]++
val := il.Counters[ip]
return val, val <= il.MaxNrRequests
nr = il.Counters[ip]
maxNr = il.MaxNrRequests
ok = nr <= maxNr
if len(il.cidrBlocks) > 0 {
parsedIP := net.ParseIP(ip)
for _, cidrBlock := range il.cidrBlocks {
if cidrBlock.Contains(parsedIP) {
ok = true
maxNr = -1
break
}
}
}
return nr, maxNr, ok
}

// Count returns the counter value for an IP address
Expand Down
35 changes: 32 additions & 3 deletions cmd/livesim2/app/ipreqlimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestRequestLimiter(t *testing.T) {
Expand All @@ -19,8 +21,9 @@ func TestRequestLimiter(t *testing.T) {

maxNrRequests := 5
maxTime := 100 * time.Millisecond
ltr := NewIPRequestLimiter(maxNrRequests, maxTime, time.Now(), "")
l := NewLimiterMiddleware("limiter", ltr)
ltr, err := NewIPRequestLimiter(maxNrRequests, maxTime, time.Now(), "192.168.5.0/24", "")
require.NoError(t, err)
lmw := NewLimiterMiddleware("limiter", ltr)

handler := func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&endpointCalledCount, 1)
Expand All @@ -29,7 +32,7 @@ func TestRequestLimiter(t *testing.T) {
mux := http.NewServeMux()

finalHandler := http.HandlerFunc(handler)
mux.Handle("/", l(finalHandler))
mux.Handle("/", lmw(finalHandler))

ts := httptest.NewServer(mux)
defer ts.Close()
Expand All @@ -46,6 +49,32 @@ func TestRequestLimiter(t *testing.T) {
}
}

func TestWhiteList(t *testing.T) {
endpointCalledCount := int64(0)

maxNrRequests := 3
maxTime := 100 * time.Millisecond
ltr, err := NewIPRequestLimiter(maxNrRequests, maxTime, time.Now(), "127.0.0.3/24", "")
require.NoError(t, err)
lmw := NewLimiterMiddleware("limiter", ltr)

handler := func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&endpointCalledCount, 1)
}

mux := http.NewServeMux()

finalHandler := http.HandlerFunc(handler)
mux.Handle("/", lmw(finalHandler))

ts := httptest.NewServer(mux)
defer ts.Close()

for i := 0; i < maxNrRequests+2; i++ {
doRequestAndCheckResponse(t, ts, i+1, -1, http.StatusOK)
}
}

func doRequestAndCheckResponse(t *testing.T, ts *httptest.Server, reqNr, maxNrRequests int, wantedStatus int) {
t.Helper()
res, err := http.Get(ts.URL)
Expand Down
6 changes: 5 additions & 1 deletion cmd/livesim2/app/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ func SetupServer(ctx context.Context, cfg *ServerConfig) (*Server, error) {
l := chi.NewRouter()
v := chi.NewRouter()
if cfg.MaxRequests > 0 {
reqLimiter = NewIPRequestLimiter(cfg.MaxRequests, time.Duration(cfg.ReqLimitInt)*time.Second, time.Now(), cfg.ReqLimitLog)
reqLimiter, err = NewIPRequestLimiter(cfg.MaxRequests, time.Duration(cfg.ReqLimitInt)*time.Second,
time.Now(), cfg.WhiteListBlocks, cfg.ReqLimitLog)
if err != nil {
return nil, fmt.Errorf("newIPLimiter: %w", err)
}
ltrMw := NewLimiterMiddleware("Livesim2-Requests", reqLimiter)
l.Use(ltrMw)
v.Use(ltrMw)
Expand Down

0 comments on commit a1daf85

Please sign in to comment.