Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,8 @@ The constructor performs validation on all parameters and returns descriptive er
- `CheckTokens(id []byte, n uint8) bool`: Checks if n tokens would be available without consuming them
- `TakeToken(id []byte) bool`: Attempts to take a single token, returns true if successful
- `TakeTokens(id []byte, n uint8) bool`: Attempts to take n tokens atomically, returns true if all n tokens were taken
- `SetRefillRate(refillRate float64) error`: Updates the refill rate in-place while preserving existing bucket state
- `RefillRate() float64`: Returns the current refill rate
- `RotationInterval() time.Duration`: Returns the automatically calculated rotation interval

#### Collision-Resistant Algorithm Explained
Expand Down
24 changes: 16 additions & 8 deletions bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,8 @@ func NewTokenBucketLimiter(
refillRate float64,
refillRateUnit time.Duration,
) (*TokenBucketLimiter, error) {
if math.IsNaN(refillRate) || math.IsInf(refillRate, 0) || refillRate <= 0 {
return nil, fmt.Errorf("refillRate must be a positive, finite number")
}

if rate := float64(refillRateUnit.Nanoseconds()); rate <= 0 {
return nil, fmt.Errorf("refillRateUnit must represent a positive duration")
} else if rate > math.MaxFloat64/refillRate {
return nil, fmt.Errorf("refillRate per duration is too large")
if err := validateRefillRate(refillRate, refillRateUnit); err != nil {
return nil, err
}

n := ceilPow2(uint64(numBuckets))
Expand All @@ -92,6 +86,20 @@ func NewTokenBucketLimiter(
}, nil
}

func validateRefillRate(refillRate float64, refillRateUnit time.Duration) error {
if math.IsNaN(refillRate) || math.IsInf(refillRate, 0) || refillRate <= 0 {
return fmt.Errorf("refillRate must be a positive, finite number")
}

if rate := float64(refillRateUnit.Nanoseconds()); rate <= 0 {
return fmt.Errorf("refillRateUnit must represent a positive duration")
} else if rate > math.MaxFloat64/refillRate {
return fmt.Errorf("refillRate per duration is too large")
}

return nil
}

// CheckToken returns whether a token would be available for the given
// ID without actually taking it. This is useful for preemptively
// checking if an operation would be rate limited before attempting
Expand Down
92 changes: 69 additions & 23 deletions rotating.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ type rotatingPair struct {
rotated time56.Time
}

type refillState struct {
refillRate float64
nanosPerToken int64
nanosPerRotation int64
}

// RotatingTokenBucketLimiter implements a collision-resistant token
// bucket rate limiter. It maintains two TokenBucketLimiters with
// different hash seeds and rotates between them periodically. This
Expand All @@ -40,8 +46,10 @@ type rotatingPair struct {
// last for the duration of the rotation period, providing better
// fairness and accuracy compared to a single TokenBucketLimiter.
type RotatingTokenBucketLimiter struct {
pair atomic.Pointer[rotatingPair] // Current limiter pair
nanosPerRotation int64 // Rotation interval in nanoseconds
pair atomic.Pointer[rotatingPair] // Current limiter pair
burstCapacity uint8
refillRateUnit time.Duration
state atomic.Pointer[refillState]
}

// Compile-time assertion that RotatingTokenBucketLimiter implements Limiter
Expand Down Expand Up @@ -97,18 +105,18 @@ func NewRotatingTokenBucketLimiter(
refillRate float64,
refillRateUnit time.Duration,
) (*RotatingTokenBucketLimiter, error) {
checked, err := NewTokenBucketLimiter(
if err := validateRefillRate(refillRate, refillRateUnit); err != nil {
return nil, err
}

// Validation passed above, and NewTokenBucketLimiter currently has no
// additional error cases beyond parameter validation.
checked, _ := NewTokenBucketLimiter(
numBuckets,
burstCapacity,
refillRate,
refillRateUnit,
)
if err != nil {
return nil, err
}

// validation passed for exact params above, continue w/o checking
// error for 100% coverage.

ignored, _ := NewTokenBucketLimiter(
numBuckets,
Expand All @@ -121,13 +129,11 @@ func NewRotatingTokenBucketLimiter(
// convergence of all token buckets to steady state before rotation
// occurs. This guarantees correctness by eliminating state
// inconsistency issues when hash mappings change during rotation.
refillTime := time.Duration(float64(burstCapacity) / refillRate * float64(refillRateUnit))
safetyFactor := 5.0
rotationRate := time.Duration(float64(refillTime) * safetyFactor)

limiter := &RotatingTokenBucketLimiter{
nanosPerRotation: rotationRate.Nanoseconds(),
burstCapacity: burstCapacity,
refillRateUnit: refillRateUnit,
}
limiter.setRefillRateState(refillRate)

pair := &rotatingPair{
checked: checked,
Expand Down Expand Up @@ -157,13 +163,13 @@ func NewRotatingTokenBucketLimiter(
// This approach ensures that hash collisions are resolved
// periodically without affecting the thread-safety or performance of
// the limiter.
func (r *RotatingTokenBucketLimiter) load(nowNS int64) *rotatingPair {
func (r *RotatingTokenBucketLimiter) load(nowNS int64, state *refillState) *rotatingPair {
now := time56.Unix(nowNS)

for {
pair := r.pair.Load()

if now.Since(pair.rotated) < r.nanosPerRotation {
if now.Since(pair.rotated) < state.nanosPerRotation {
return pair
}

Expand Down Expand Up @@ -220,9 +226,11 @@ func (r *RotatingTokenBucketLimiter) CheckToken(id []byte) bool {
// multiple goroutines.
func (r *RotatingTokenBucketLimiter) CheckTokens(id []byte, n uint8) bool {
now := nowfn()
pair := r.load(now)
pair.ignored.checkTokensWithNow(id, n, now)
return pair.checked.checkTokensWithNow(id, n, now)
state := r.loadRefillState()
pair := r.load(now, state)
rate := state.nanosPerToken
pair.ignored.checkInner(pair.ignored.index(id), rate, now, n)
return pair.checked.checkInner(pair.checked.index(id), rate, now, n)
}

// TakeToken attempts to take a token for the given ID. It returns
Expand Down Expand Up @@ -273,9 +281,29 @@ func (r *RotatingTokenBucketLimiter) TakeToken(id []byte) bool {
// multiple goroutines.
func (r *RotatingTokenBucketLimiter) TakeTokens(id []byte, n uint8) bool {
now := nowfn()
pair := r.load(now)
pair.ignored.takeTokensWithNow(id, n, now)
return pair.checked.takeTokensWithNow(id, n, now)
state := r.loadRefillState()
pair := r.load(now, state)
rate := state.nanosPerToken
pair.ignored.takeTokenInner(pair.ignored.index(id), rate, now, n)
return pair.checked.takeTokenInner(pair.checked.index(id), rate, now, n)
}

// SetRefillRate updates the refill rate used by the rotating limiter
// without rebuilding bucket state. Existing tokens are preserved, and
// subsequent checks, takes, and rotation timing use the new rate.
func (r *RotatingTokenBucketLimiter) SetRefillRate(refillRate float64) error {
if err := validateRefillRate(refillRate, r.refillRateUnit); err != nil {
return err
}

r.setRefillRateState(refillRate)
return nil
}

// RefillRate returns the current refill rate in tokens per
// refillRateUnit.
func (r *RotatingTokenBucketLimiter) RefillRate() float64 {
return r.loadRefillState().refillRate
}

// RotationInterval returns the automatically calculated rotation
Expand All @@ -290,5 +318,23 @@ func (r *RotatingTokenBucketLimiter) TakeTokens(id []byte, n uint8) bool {
// This method is thread-safe and can be called concurrently from
// multiple goroutines.
func (r *RotatingTokenBucketLimiter) RotationInterval() time.Duration {
return time.Duration(r.nanosPerRotation)
return time.Duration(r.loadRefillState().nanosPerRotation)
}

func (r *RotatingTokenBucketLimiter) setRefillRateState(refillRate float64) {
r.state.Store(&refillState{
refillRate: refillRate,
nanosPerToken: nanoRate(r.refillRateUnit, refillRate),
nanosPerRotation: calculateRotationInterval(r.burstCapacity, refillRate, r.refillRateUnit).Nanoseconds(),
})
}

func (r *RotatingTokenBucketLimiter) loadRefillState() *refillState {
return r.state.Load()
}

func calculateRotationInterval(burstCapacity uint8, refillRate float64, refillRateUnit time.Duration) time.Duration {
refillTime := time.Duration(float64(burstCapacity) / refillRate * float64(refillRateUnit))
safetyFactor := 5.0
return time.Duration(float64(refillTime) * safetyFactor)
}
Loading
Loading