Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add openAPI schema for some internal endpoints #5037

Open
wants to merge 6 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine/apps/alerts/models/escalation_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def sorted_users_queue(self):
return sorted(self.notify_to_users_queue.all(), key=lambda user: (user.username or "", user.pk))

@property
def slack_integration_required(self):
def slack_integration_required(self) -> bool:
if self.step in self.SLACK_INTEGRATION_REQUIRED_STEPS:
return True
else:
Expand Down
16 changes: 8 additions & 8 deletions engine/apps/api/serializers/alert_group_escalation_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,17 @@ class Meta(EscalationPolicySerializer.Meta):
class AlertGroupEscalationSnapshotAPISerializer(serializers.Serializer):
"""Serializes AlertGroup escalation snapshot for API endpoint"""

escalation_chain = serializers.SerializerMethodField()
channel_filter = serializers.SerializerMethodField()
class EscalationChainSnapshotAPISerializer(serializers.Serializer):
name = serializers.CharField()

class ChannelFilterSnapshotAPISerializer(serializers.Serializer):
name = serializers.CharField(source="str_for_clients")

escalation_chain = EscalationChainSnapshotAPISerializer(read_only=True, source="escalation_chain_snapshot")
channel_filter = ChannelFilterSnapshotAPISerializer(read_only=True, source="channel_filter_snapshot")
escalation_policies = EscalationPolicySnapshotAPISerializer(
source="escalation_policies_snapshots", many=True, read_only=True
)

class Meta:
fields = ["escalation_chain", "channel_filter", "escalation_policies"]

def get_escalation_chain(self, obj):
return {"name": obj.escalation_chain_snapshot.name}

def get_channel_filter(self, obj):
return {"name": obj.channel_filter_snapshot.str_for_clients}
58 changes: 36 additions & 22 deletions engine/apps/api/serializers/channel_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@
from common.utils import is_regex_valid


class SlackChannelDetails(typing.TypedDict):
display_name: str
slack_id: str
id: str


class ChannelFilterSerializer(EagerLoadingMixin, serializers.ModelSerializer):
class TelegramChannelDetailsSerializer(serializers.Serializer):
display_name = serializers.CharField(source="channel_name")
id = serializers.CharField(source="channel_chat_id")

id = serializers.CharField(read_only=True, source="public_primary_key")
alert_receive_channel = OrganizationFilteredPrimaryKeyRelatedField(queryset=AlertReceiveChannel.objects)
escalation_chain = OrganizationFilteredPrimaryKeyRelatedField(
Expand All @@ -28,7 +38,9 @@ class ChannelFilterSerializer(EagerLoadingMixin, serializers.ModelSerializer):
telegram_channel = OrganizationFilteredPrimaryKeyRelatedField(
queryset=TelegramToOrganizationConnector.objects, filter_field="organization", allow_null=True, required=False
)
telegram_channel_details = serializers.SerializerMethodField()
telegram_channel_details = TelegramChannelDetailsSerializer(
source="telegram_channel", read_only=True, allow_null=True
)
filtering_term_as_jinja2 = serializers.SerializerMethodField()
filtering_term = serializers.CharField(required=False, allow_null=True, allow_blank=True)
filtering_labels = LabelPairSerializer(many=True, required=False)
Expand Down Expand Up @@ -86,7 +98,7 @@ def validate(self, data):
raise serializers.ValidationError(["Expression type is incorrect"])
return data

def get_slack_channel(self, obj):
def get_slack_channel(self, obj) -> SlackChannelDetails | None:
if obj.slack_channel_id is None:
return None
# display_name and id appears via annotate in ChannelFilterView.get_queryset()
Expand All @@ -96,18 +108,6 @@ def get_slack_channel(self, obj):
"id": obj.slack_channel_pk,
}

def get_telegram_channel_details(self, obj) -> dict[str, typing.Any] | None:
if obj.telegram_channel_id is None:
return None
try:
telegram_channel = TelegramToOrganizationConnector.objects.get(pk=obj.telegram_channel_id)
return {
"display_name": telegram_channel.channel_name,
"id": telegram_channel.channel_chat_id,
}
except TelegramToOrganizationConnector.DoesNotExist:
return None

def validate_slack_channel(self, slack_channel_id):
from apps.slack.models import SlackChannel

Expand Down Expand Up @@ -182,21 +182,23 @@ class Meta:
]
read_only_fields = ["created_at", "is_default"]

def to_representation(self, obj):
"""add correct slack channel data to result after instance creation/update"""
result = super().to_representation(obj)
if obj.slack_channel_id is None:
result["slack_channel"] = None
else:
def _get_slack_channel(self, obj) -> SlackChannelDetails | None:
if obj.slack_channel_id is not None:
slack_team_identity = self.context["request"].auth.organization.slack_team_identity
if slack_team_identity is not None:
slack_channel = slack_team_identity.get_cached_channels(slack_id=obj.slack_channel_id).first()
if slack_channel:
result["slack_channel"] = {
return {
"display_name": slack_channel.name,
"slack_id": obj.slack_channel_id,
"slack_id": slack_channel.slack_id,
"id": slack_channel.public_primary_key,
}
return None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: invert IFs to reduce nesting

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


def to_representation(self, obj):
"""add correct slack channel data to result after instance creation/update"""
result = super().to_representation(obj)
result["slack_channel"] = self._get_slack_channel(obj)
return result

def create(self, validated_data):
Expand All @@ -218,3 +220,15 @@ def update(self, instance, validated_data):
raise BadRequest(detail="Filtering term of default channel filter cannot be changed")

return super().update(instance, validated_data)


class ChannelFilterRetrieveResponseSerializer(ChannelFilterUpdateSerializer):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe ChannelFilterUpdateResponseSerializer would be a better name here? this way it indicates that it's a response from update/create, not retrieve

"""
This serializer is used in OpenAPI schema to show proper response structure,
as `slack_channel` field expects string on create/update and returns dict on response
"""

slack_channel = serializers.SerializerMethodField()

def get_slack_channel(self, obj) -> SlackChannelDetails | None:
return self._get_slack_channel(obj)
4 changes: 2 additions & 2 deletions engine/apps/api/serializers/escalation_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ class EscalationChainListSerializer(EscalationChainSerializer):
class Meta(EscalationChainSerializer.Meta):
fields = [*EscalationChainSerializer.Meta.fields, "number_of_integrations", "number_of_routes"]

def get_number_of_integrations(self, obj):
def get_number_of_integrations(self, obj) -> int:
# num_integrations param added in queryset via annotate. Check EscalationChainViewSet.get_queryset
return getattr(obj, "num_integrations", 0)

def get_number_of_routes(self, obj):
def get_number_of_routes(self, obj) -> int:
# num_routes param added in queryset via annotate. Check EscalationChainViewSet.get_queryset
return getattr(obj, "num_routes", 0)

Expand Down
27 changes: 5 additions & 22 deletions engine/apps/api/views/alert_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from rest_framework.response import Response

from apps.alerts.constants import ActionSource
from apps.alerts.models import Alert, AlertGroup, AlertReceiveChannel, EscalationChain, ResolutionNote
from apps.alerts.models import Alert, AlertGroup, AlertReceiveChannel, ResolutionNote
from apps.alerts.paging import unpage_user
from apps.alerts.tasks import delete_alert_group, send_update_resolution_note_signal
from apps.api.errors import AlertGroupAPIError
Expand All @@ -35,32 +35,14 @@
DateRangeFilterMixin,
ModelFieldFilterMixin,
MultipleChoiceCharFilter,
get_escalation_chain_queryset,
get_integration_queryset,
get_user_queryset,
)
from common.api_helpers.mixins import PreviewTemplateMixin, PublicPrimaryKeyMixin, TeamFilteringMixin
from common.api_helpers.paginators import AlertGroupCursorPaginator


def get_integration_queryset(request):
if request is None:
return AlertReceiveChannel.objects.none()

return AlertReceiveChannel.objects_with_maintenance.filter(organization=request.user.organization)


def get_escalation_chain_queryset(request):
if request is None:
return EscalationChain.objects.none()

return EscalationChain.objects.filter(organization=request.user.organization)


def get_user_queryset(request):
if request is None:
return User.objects.none()

return User.objects.filter(organization=request.user.organization).distinct()


class AlertGroupFilter(DateRangeFilterMixin, ModelFieldFilterMixin, filters.FilterSet):
"""
Examples of possible date formats here https://docs.djangoproject.com/en/1.9/ref/settings/#datetime-input-formats
Expand Down Expand Up @@ -925,6 +907,7 @@ def bulk_action_options(self, request):
def get_alert_to_template(self, payload=None):
return self.get_object().alerts.first()

@extend_schema(responses=AlertGroupEscalationSnapshotAPISerializer)
@action(methods=["get"], detail=True)
def escalation_snapshot(self, request, pk=None):
alert_group = self.get_object()
Expand Down
38 changes: 32 additions & 6 deletions engine/apps/api/views/channel_filter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from django.db.models import OuterRef, Subquery
from django_filters import rest_framework as filters
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticated
Expand All @@ -8,12 +10,14 @@
from apps.api.permissions import RBACPermission
from apps.api.serializers.channel_filter import (
ChannelFilterCreateSerializer,
ChannelFilterRetrieveResponseSerializer,
ChannelFilterSerializer,
ChannelFilterUpdateSerializer,
)
from apps.auth_token.auth import PluginAuthentication
from apps.slack.models import SlackChannel
from common.api_helpers.exceptions import BadRequest
from common.api_helpers.filters import ModelFieldFilterMixin, MultipleChoiceCharFilter, get_integration_queryset
from common.api_helpers.mixins import (
CreateSerializerMixin,
PublicPrimaryKeyMixin,
Expand All @@ -24,13 +28,34 @@
from common.ordered_model.viewset import OrderedModelViewSet


class ChannelFilterFilter(ModelFieldFilterMixin, filters.FilterSet):
alert_receive_channel = MultipleChoiceCharFilter(
queryset=get_integration_queryset,
to_field_name="public_primary_key",
method=ModelFieldFilterMixin.filter_model_field.__name__,
)


@extend_schema_view(
list=extend_schema(responses=ChannelFilterSerializer),
retrieve=extend_schema(responses=ChannelFilterRetrieveResponseSerializer),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be just ChannelFilterSerializer?

create=extend_schema(request=ChannelFilterCreateSerializer, responses=ChannelFilterRetrieveResponseSerializer),
update=extend_schema(request=ChannelFilterUpdateSerializer, responses=ChannelFilterRetrieveResponseSerializer),
partial_update=extend_schema(
request=ChannelFilterUpdateSerializer, responses=ChannelFilterRetrieveResponseSerializer
),
)
class ChannelFilterView(
TeamFilteringMixin,
PublicPrimaryKeyMixin[ChannelFilter],
CreateSerializerMixin,
UpdateSerializerMixin,
OrderedModelViewSet,
):
"""
Internal API endpoints for channel filters (routes).
"""

authentication_classes = (PluginAuthentication,)
permission_classes = (IsAuthenticated, RBACPermission)
rbac_permissions = {
Expand All @@ -45,19 +70,19 @@ class ChannelFilterView(
"convert_from_regex_to_jinja2": [RBACPermission.Permissions.INTEGRATIONS_WRITE],
}

queryset = ChannelFilter.objects.none() # needed for drf-spectacular introspection

model = ChannelFilter
serializer_class = ChannelFilterSerializer
update_serializer_class = ChannelFilterUpdateSerializer
create_serializer_class = ChannelFilterCreateSerializer

filter_backends = (filters.DjangoFilterBackend,)
filterset_class = ChannelFilterFilter

TEAM_LOOKUP = "alert_receive_channel__team"

def get_queryset(self, ignore_filtering_by_available_teams=False):
alert_receive_channel_id = self.request.query_params.get("alert_receive_channel", None)
lookup_kwargs = {}
if alert_receive_channel_id:
lookup_kwargs = {"alert_receive_channel__public_primary_key": alert_receive_channel_id}

slack_channels_subq = SlackChannel.objects.filter(
slack_id=OuterRef("slack_channel_id"),
slack_team_identity=self.request.auth.organization.slack_team_identity,
Expand All @@ -66,7 +91,6 @@ def get_queryset(self, ignore_filtering_by_available_teams=False):
queryset = ChannelFilter.objects.filter(
alert_receive_channel__organization=self.request.auth.organization,
alert_receive_channel__deleted_at=None,
**lookup_kwargs,
).annotate(
slack_channel_name=Subquery(slack_channels_subq.values("name")[:1]),
slack_channel_pk=Subquery(slack_channels_subq.values("public_primary_key")[:1]),
Expand Down Expand Up @@ -109,6 +133,7 @@ def perform_update(self, serializer):
new_state=new_state,
)

@extend_schema(request=None, responses={status.HTTP_200_OK: None})
@action(detail=True, methods=["put"])
def move_to_position(self, request, pk):
instance = self.get_object()
Expand All @@ -117,6 +142,7 @@ def move_to_position(self, request, pk):

return super().move_to_position(request, pk)

@extend_schema(request=None, responses=ChannelFilterSerializer)
@action(detail=True, methods=["post"])
def convert_from_regex_to_jinja2(self, request, pk):
instance = self.get_object()
Expand Down
Loading
Loading