Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common_utility/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from .sessionProvider import *
from .fileDownloader import *
from .configLoader import *
from .rateLimiter import *
46 changes: 46 additions & 0 deletions common_utility/rateLimiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# SPDX-FileCopyrightText: 2024 Ferenc Nandor Janky <ferenj@effective-range.com>
# SPDX-FileCopyrightText: 2024 Attila Gombos <attila.gombos@effective-range.com>
# SPDX-License-Identifier: MIT

import time

from context_logger import get_logger

log = get_logger('RateLimiter')


class IRateLimiter(object):

def acquire(self, size: int = 1) -> bool:
raise NotImplementedError()


class TokenBucketLimiter(IRateLimiter):

def __init__(self, fill_rate: int, time_base: int = 1, burst_factor: int = 3) -> None:
self._scale_factor = time_base
self._fill_rate = fill_rate * time_base
self._capacity = self._fill_rate * burst_factor
self._tokens = self._capacity
self._last_time = self._get_time()

def acquire(self, size: int = 1) -> bool:
current_time = self._get_time()
elapsed_time = current_time - self._last_time
self._last_time = current_time

self._tokens += round(elapsed_time * self._fill_rate)
self._tokens = min(self._tokens, self._capacity)

required_size = size * self._scale_factor

if required_size > self._tokens:
log.debug("Rate limit exceeded", tokens=self._tokens, required=required_size)
return False
else:
log.debug("Acquired", tokens=self._tokens, required=required_size)
self._tokens -= required_size
return True

def _get_time(self) -> float:
return time.monotonic() / self._scale_factor
77 changes: 77 additions & 0 deletions tests/rateLimiterTest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import unittest
from unittest import TestCase
from unittest.mock import patch

from context_logger import setup_logging

from common_utility import TokenBucketLimiter


class RateLimiterTest(TestCase):

@classmethod
def setUpClass(cls):
setup_logging('python-common-utility', 'DEBUG', warn_on_overwrite=False)

def setUp(self):
print()

@patch('time.monotonic', side_effect=[0.0, 100.0, 100.9, 101.2])
def test_acquire_when_not_exceeds_limit(self, mock_monotonic):
# Given
rate_limiter = TokenBucketLimiter(10, 1, 2)

rate_limiter.acquire(10)
rate_limiter.acquire(10)

# When
result = rate_limiter.acquire(10)

# Then
self.assertTrue(result)

@patch('time.monotonic', side_effect=[0.0, 100.0, 100.9, 101.2])
def test_acquire_when_exceeds_limit(self, mock_monotonic):
# Given
rate_limiter = TokenBucketLimiter(10, 1, 2)

rate_limiter.acquire(12)
rate_limiter.acquire(12)

# When
result = rate_limiter.acquire(12)

# Then
self.assertFalse(result)

@patch('time.monotonic', side_effect=[0.0, 100.0, 130.0, 160.0])
def test_acquire_when_not_exceeds_limit_and_minute_based(self, mock_monotonic):
# Given
rate_limiter = TokenBucketLimiter(10, 60, 2)

rate_limiter.acquire(10)
rate_limiter.acquire(10)

# When
result = rate_limiter.acquire(10)

# Then
self.assertTrue(result)

@patch('time.monotonic', side_effect=[0.0, 100.0, 130.0, 160.0])
def test_acquire_when_exceeds_limit_and_minute_based(self, mock_monotonic):
# Given
rate_limiter = TokenBucketLimiter(10, 60, 2)

rate_limiter.acquire(12)
rate_limiter.acquire(12)

# When
result = rate_limiter.acquire(12)

# Then
self.assertFalse(result)


if __name__ == '__main__':
unittest.main()
Loading