Skip to content

Commit

Permalink
Add custom ratelimits per org (#5004)
Browse files Browse the repository at this point in the history
# What this PR does

This PR refactors Throttling for public API and integrations API and
allows to specify organization ratelimits.


## Which issue(s) this PR closes

Related to [issue link here]

<!--
*Note*: If you want the issue to be auto-closed once the PR is merged,
change "Related to" to "Closes" in the line above.
If you have more than one GitHub issue that this PR closes, be sure to
preface
each issue link with a [closing
keyword](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/using-keywords-in-issues-and-pull-requests#linking-a-pull-request-to-an-issue).
This ensures that the issue(s) are auto-closed once the PR has been
merged.
-->

## Checklist

- [x] Unit, integration, and e2e (if applicable) tests updated
- [ ] Documentation added (or `pr:no public docs` PR label added if not
required)
- [x] Added the relevant release notes label (see labels prefixed w/
`release:`). These labels dictate how your PR will
    show up in the autogenerated release notes.
  • Loading branch information
iskhakov authored Sep 17, 2024
1 parent dd6d2ab commit c718863
Show file tree
Hide file tree
Showing 15 changed files with 268 additions and 144 deletions.
89 changes: 49 additions & 40 deletions engine/apps/api/tests/test_user.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from unittest import mock
from unittest.mock import Mock, patch
from unittest.mock import Mock, PropertyMock, patch

import pytest
from django.core.cache import cache
Expand Down Expand Up @@ -1775,17 +1775,10 @@ def test_invalid_working_hours(

@patch("apps.phone_notifications.phone_backend.PhoneBackend.send_verification_sms", return_value=Mock())
@patch("apps.phone_notifications.phone_backend.PhoneBackend.verify_phone_number", return_value=True)
@patch(
"apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerUser.get_throttle_limits",
return_value=(1, 10 * 60),
)
@patch("apps.api.throttlers.VerifyPhoneNumberThrottlerPerUser.get_throttle_limits", return_value=(1, 10 * 60))
@pytest.mark.django_db
def test_phone_number_verification_flow_ratelimit_per_user(
mock_verification_start,
mocked_verification_check,
mocked_get_phone_verification_code_get_throttle_limits,
mocked_get_phone_verify_phone_number_limits,
make_organization_and_user_with_plugin_token,
make_user_auth_headers,
):
Expand All @@ -1794,40 +1787,44 @@ def test_phone_number_verification_flow_ratelimit_per_user(
client = APIClient()
url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key})

# first get_verification_code request is succesfull
response = client.get(url, format="json", **make_user_auth_headers(user, token))
assert response.status_code == status.HTTP_200_OK
with patch(
"apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerUser.rate",
new_callable=PropertyMock,
) as mocked_rate:
mocked_rate.return_value = "1/10m"
# first get_verification_code request is succesfull
response = client.get(url, format="json", **make_user_auth_headers(user, token))
assert response.status_code == status.HTTP_200_OK

# second get_verification_code request is ratelimited
response = client.get(url, format="json", **make_user_auth_headers(user, token))
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
# second get_verification_code request is ratelimited
response = client.get(url, format="json", **make_user_auth_headers(user, token))
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS

url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})

# first verify_number request is succesfull, because it uses different ratelimit scope
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
assert response.status_code == status.HTTP_200_OK
with patch(
"apps.api.throttlers.VerifyPhoneNumberThrottlerPerUser.rate",
new_callable=PropertyMock,
) as mocked_rate:
mocked_rate.return_value = "1/10m"

url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})
# first verify_number request is succesfull, because it uses different ratelimit scope
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
assert response.status_code == status.HTTP_200_OK

# second verify_number request is succesfull, because it ratelimited
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})

# second verify_number request is succesfull, because it ratelimited
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS


@patch("apps.phone_notifications.phone_backend.PhoneBackend.send_verification_sms", return_value=Mock())
@patch("apps.phone_notifications.phone_backend.PhoneBackend.verify_phone_number", return_value=True)
@patch(
"apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerOrg.get_throttle_limits",
return_value=(1, 10 * 60),
)
@patch("apps.api.throttlers.VerifyPhoneNumberThrottlerPerOrg.get_throttle_limits", return_value=(1, 10 * 60))
@pytest.mark.django_db
def test_phone_number_verification_flow_ratelimit_per_org(
mock_verification_start,
mocked_verification_check,
mocked_get_phone_verification_code_get_throttle_limits,
mocked_get_phone_verify_phone_number_limits,
make_organization_and_user_with_plugin_token,
make_user_auth_headers,
make_user_for_organization,
Expand All @@ -1841,21 +1838,33 @@ def test_phone_number_verification_flow_ratelimit_per_org(

client = APIClient()

url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key})
response = client.get(url, format="json", **make_user_auth_headers(user, token))
assert response.status_code == status.HTTP_200_OK
with patch(
"apps.api.throttlers.GetPhoneVerificationCodeThrottlerPerOrg.rate",
new_callable=PropertyMock,
) as mocked_rate:
mocked_rate.return_value = "1/10m"

url = reverse("api-internal:user-get-verification-code", kwargs={"pk": second_user.public_primary_key})
response = client.get(url, format="json", **make_user_auth_headers(second_user, token))
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
url = reverse("api-internal:user-get-verification-code", kwargs={"pk": user.public_primary_key})
response = client.get(url, format="json", **make_user_auth_headers(user, token))
assert response.status_code == status.HTTP_200_OK

url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
assert response.status_code == status.HTTP_200_OK
url = reverse("api-internal:user-get-verification-code", kwargs={"pk": second_user.public_primary_key})
response = client.get(url, format="json", **make_user_auth_headers(second_user, token))
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS

url = reverse("api-internal:user-verify-number", kwargs={"pk": second_user.public_primary_key})
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(second_user, token))
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
with patch(
"apps.api.throttlers.VerifyPhoneNumberThrottlerPerOrg.rate",
new_callable=PropertyMock,
) as mocked_rate:
mocked_rate.return_value = "1/10m"

url = reverse("api-internal:user-verify-number", kwargs={"pk": user.public_primary_key})
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(user, token))
assert response.status_code == status.HTTP_200_OK

url = reverse("api-internal:user-verify-number", kwargs={"pk": second_user.public_primary_key})
response = client.put(f"{url}?token=12345", format="json", **make_user_auth_headers(second_user, token))
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS


@patch("apps.phone_notifications.phone_backend.PhoneBackend.send_verification_sms", return_value=Mock())
Expand Down
4 changes: 2 additions & 2 deletions engine/apps/api/throttlers/demo_alert_throttler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from rest_framework.throttling import UserRateThrottle
from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler


class DemoAlertThrottler(UserRateThrottle):
class DemoAlertThrottler(CustomRateUserThrottler):
scope = "send_demo_alert"
rate = "30/m"
54 changes: 13 additions & 41 deletions engine/apps/api/throttlers/phone_verification_throttler.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,21 @@
from common.api_helpers.custom_rate_scoped_throttler import CustomRateScopedThrottler
from common.api_helpers.custom_rate_scoped_throttler import CustomRateOrganizationThrottler, CustomRateUserThrottler


class GetPhoneVerificationCodeThrottlerPerUser(CustomRateScopedThrottler):
def get_scope(self):
return "get_phone_verification_code_per_user"
class GetPhoneVerificationCodeThrottlerPerUser(CustomRateUserThrottler):
rate = "5/10m"
scope = "get_phone_verification_code_per_user"

def get_throttle_limits(self):
return 5, 10 * 60

class VerifyPhoneNumberThrottlerPerUser(CustomRateUserThrottler):
rate = "50/10m"
scope = "verify_phone_number_per_user"

class VerifyPhoneNumberThrottlerPerUser(CustomRateScopedThrottler):
def get_scope(self):
return "verify_phone_number_per_user"

def get_throttle_limits(self):
return 50, 10 * 60
class GetPhoneVerificationCodeThrottlerPerOrg(CustomRateOrganizationThrottler):
rate = "50/10m"
scope = "get_phone_verification_code_per_org"


class GetPhoneVerificationCodeThrottlerPerOrg(CustomRateScopedThrottler):
def get_scope(self):
return "get_phone_verification_code_per_org"

def get_throttle_limits(self):
return 50, 10 * 60

def get_cache_key(self, request, view):
if request.user.is_authenticated:
ident = request.user.organization.pk
else:
ident = self.get_ident(request)

return self.cache_format % {"scope": self.scope, "ident": ident}


class VerifyPhoneNumberThrottlerPerOrg(CustomRateScopedThrottler):
def get_scope(self):
return "verify_phone_number_per_org"

def get_throttle_limits(self):
return 50, 10 * 60

def get_cache_key(self, request, view):
if request.user.is_authenticated:
ident = request.user.organization.pk
else:
ident = self.get_ident(request)

return self.cache_format % {"scope": self.scope, "ident": ident}
class VerifyPhoneNumberThrottlerPerOrg(CustomRateOrganizationThrottler):
rate = "50/10m"
scope = "verify_phone_number_per_org"
6 changes: 3 additions & 3 deletions engine/apps/api/throttlers/test_call_throttler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from rest_framework.throttling import UserRateThrottle
from common.api_helpers.custom_rate_scoped_throttler import CustomRateUserThrottler


class TestCallThrottler(UserRateThrottle):
class TestCallThrottler(CustomRateUserThrottler):
"""
set a __test__ = False attribute in classes that pytest should ignore otherwise we end up getting the following:
PytestCollectionWarning: cannot collect test class 'TestCallThrottler' because it has a __init__ constructor
Expand All @@ -13,7 +13,7 @@ class TestCallThrottler(UserRateThrottle):
rate = "5/m"


class TestPushThrottler(UserRateThrottle):
class TestPushThrottler(CustomRateUserThrottler):
"""
set a __test__ = False attribute in classes that pytest should ignore otherwise we end up getting the following:
PytestCollectionWarning: cannot collect test class 'TestPushThrottler' because it has a __init__ constructor
Expand Down
1 change: 1 addition & 0 deletions engine/apps/integrations/mixins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from .ratelimit_mixin import ( # noqa: F401
IntegrationHeartBeatRateLimitMixin,
IntegrationRateLimitMixin,
RateLimitMixin,
is_ratelimit_ignored,
)
38 changes: 33 additions & 5 deletions engine/apps/integrations/mixins/ratelimit_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from functools import wraps

from django.conf import settings
from django.core.cache import cache
from django.http import HttpRequest, HttpResponse
from django.views import View
Expand All @@ -16,6 +17,8 @@

RATELIMIT_INTEGRATION = "300/5m"
RATELIMIT_TEAM = "900/5m"
RATELIMIT_INTEGRATION_GROUP_NAME = "integration"
RATELIMIT_TEAM_GROUP_NAME = "team"
RATELIMIT_REASON_INTEGRATION = "channel"
RATELIMIT_REASON_TEAM = "team"
INTEGRATION_TOKEN_TO_IGNORE_KEY = "integration_tokens_to_ignore_ratelimit"
Expand All @@ -30,13 +33,30 @@ def get_rate_limit_per_channel_key(_, request):
return str(request.alert_receive_channel.pk)


def get_rate_limit_per_team_key(_, request):
def get_rate_limit_per_organization_key(_, request):
"""
Rate limiting based on AlertReceiveChannel's team PK
"""
return str(request.alert_receive_channel.organization_id)


def get_rate_limit(group, request):
custom_ratelimits = settings.CUSTOM_RATELIMITS

organization_id = str(request.alert_receive_channel.organization_id)

if group == RATELIMIT_INTEGRATION_GROUP_NAME:
if organization_id in custom_ratelimits:
return custom_ratelimits[organization_id]["integration"]
return RATELIMIT_INTEGRATION
elif group == RATELIMIT_TEAM_GROUP_NAME:
if organization_id in custom_ratelimits:
return custom_ratelimits[organization_id]["organization"]
return RATELIMIT_TEAM
else:
raise Exception("Unknown group")


def ratelimit(group=None, key=None, rate=None, method=ALL, block=False, reason=None):
"""
This decorator is an updated version of:
Expand Down Expand Up @@ -171,7 +191,11 @@ def notify(self):
block=True, # use block=True so integration rate limit 429s are not counted towards the team rate limit
)
@ratelimit(
key=get_rate_limit_per_team_key, rate=RATELIMIT_TEAM, group="team", reason=RATELIMIT_REASON_TEAM, block=True
key=get_rate_limit_per_organization_key,
rate=RATELIMIT_TEAM,
group="team",
reason=RATELIMIT_REASON_TEAM,
block=True,
)
def execute_rate_limit(self, *args, **kwargs):
pass
Expand Down Expand Up @@ -201,13 +225,17 @@ class IntegrationRateLimitMixin(RateLimitMixin, View):

@ratelimit(
key=get_rate_limit_per_channel_key,
rate=RATELIMIT_INTEGRATION,
group="integration",
rate=get_rate_limit,
group=RATELIMIT_INTEGRATION_GROUP_NAME,
reason=RATELIMIT_REASON_INTEGRATION,
block=True, # use block=True so integration rate limit 429s are not counted towards the team rate limit
)
@ratelimit(
key=get_rate_limit_per_team_key, rate=RATELIMIT_TEAM, group="team", reason=RATELIMIT_REASON_TEAM, block=True
key=get_rate_limit_per_organization_key,
rate=get_rate_limit,
group=RATELIMIT_TEAM_GROUP_NAME,
reason=RATELIMIT_REASON_TEAM,
block=True,
)
def execute_rate_limit(self, *args, **kwargs):
pass
Expand Down
Loading

0 comments on commit c718863

Please sign in to comment.