Skip to content

Commit

Permalink
add slash commands responses get/edit/delete endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
RuslanUC committed Oct 27, 2023
1 parent 532b9a5 commit 7a68d95
Show file tree
Hide file tree
Showing 35 changed files with 298 additions and 102 deletions.
16 changes: 8 additions & 8 deletions src/rest_api/models/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import annotations

from time import mktime
from typing import Optional, List
from typing import Optional

from dateutil.parser import parse as dparse
from pydantic import BaseModel, Field, field_validator
Expand Down Expand Up @@ -237,7 +237,7 @@ class EmbedModel(BaseModel):
thumbnail: Optional[EmbedImage] = None
video: Optional[EmbedImage] = None
author: Optional[EmbedAuthor] = None
fields: List[EmbedField] = Field(default_factory=list)
fields: list[EmbedField] = Field(default_factory=list)

@field_validator("title")
def validate_title(cls, value: str):
Expand Down Expand Up @@ -302,7 +302,7 @@ def validate_author(cls, value: Optional[EmbedAuthor]):
return value

@field_validator("fields")
def validate_fields(cls, value: List[EmbedField]):
def validate_fields(cls, value: list[EmbedField]):
if len(value) > 25:
value = value[:25]
return value
Expand All @@ -327,8 +327,8 @@ def model_dump(self, *args, **kwargs):
class MessageCreate(BaseModel):
content: Optional[str] = None
nonce: Optional[str] = None
embeds: List[EmbedModel] = Field(default_factory=list)
sticker_ids: List[int] = Field(default_factory=list)
embeds: list[EmbedModel] = Field(default_factory=list)
sticker_ids: list[int] = Field(default_factory=list)
message_reference: Optional[MessageReferenceModel] = None
flags: Optional[int] = None

Expand All @@ -343,15 +343,15 @@ def validate_content(cls, value: Optional[str]):
return value

@field_validator("embeds")
def validate_embeds(cls, value: List[EmbedModel]):
def validate_embeds(cls, value: list[EmbedModel]):
if len(value) > 10:
raise InvalidDataErr(400, Errors.make(50035, {"embeds": {
"code": "BASE_TYPE_BAD_LENGTH", "message": "Must be between 1 and 10 in length."
}}))
return value

@field_validator("sticker_ids")
def validate_sticker_ids(cls, value: List[int]):
def validate_sticker_ids(cls, value: list[int]):
if len(value) > 3:
raise InvalidDataErr(400, Errors.make(50035, {"sticker_ids": {
"code": "BASE_TYPE_BAD_LENGTH", "message": "Must be between 1 and 3 in length."
Expand Down Expand Up @@ -392,7 +392,7 @@ def to_json(self) -> dict:
# noinspection PyMethodParameters
class MessageUpdate(BaseModel):
content: Optional[str] = None
embeds: List[EmbedModel] = Field(default_factory=list)
embeds: list[EmbedModel] = Field(default_factory=list)

@field_validator("content")
def validate_content(cls, value: Optional[str]):
Expand Down
7 changes: 4 additions & 3 deletions src/rest_api/models/interactions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Literal, Optional, Any

from pydantic import BaseModel, create_model, field_validator
from pydantic import BaseModel, create_model, field_validator, Field
from pydantic_core.core_schema import ValidationInfo

from .channels import EmbedModel
from ...yepcord.enums import ApplicationCommandOptionType

OPTION_MODELS = {
Expand Down Expand Up @@ -54,9 +55,9 @@ class InteractionCreate(BaseModel):

class InteractionRespondData(BaseModel):
content: Optional[str] = None
#embeds: Optional[list] = None
embeds: list[EmbedModel] = Field(default_factory=list)
flags: int = 0
#components: Optional[list] = None # components validation are not supported now :(
#components: Optional[list] = None # components validation are not supported yet :(


class InteractionRespond(BaseModel):
Expand Down
5 changes: 2 additions & 3 deletions src/rest_api/routes/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
DMChannelDeleteEvent, MessageReactionAddEvent, MessageReactionRemoveEvent, ChannelUpdateEvent, ChannelDeleteEvent, \
WebhooksUpdateEvent, ThreadCreateEvent, ThreadMemberUpdateEvent, MessageAckEvent, GuildAuditLogEntryCreateEvent
from ...yepcord.ctx import getCore, getCDNStorage, getGw
from ...yepcord.enums import GuildPermissions, MessageType, ChannelType, WebhookType, GUILD_CHANNELS
from ...yepcord.enums import GuildPermissions, MessageType, ChannelType, WebhookType, GUILD_CHANNELS, MessageFlags
from ...yepcord.errors import InvalidDataErr, Errors
from ...yepcord.models import User, Channel, Message, ReadState, Emoji, PermissionOverwrite, Webhook, ThreadMember, \
ThreadMetadata, AuditLogEntry, Relationship, ApplicationCommand, Integration, Bot
Expand Down Expand Up @@ -553,8 +553,7 @@ async def create_thread(data: CreateThread, user: User, channel: Channel, messag
await getGw().dispatch(ThreadCreateEvent(await thread.ds_json() | {"newly_created": True}),
guild_id=channel.guild.id)
await getGw().dispatch(ThreadMemberUpdateEvent(thread_member.ds_json()), guild_id=channel.guild.id)
message.thread = thread
await message.save(update_fields=["thread"])
await message.update(thread=thread, flags=message.flags | MessageFlags.HAS_THREAD)

Check warning on line 556 in src/rest_api/routes/channels.py

View check run for this annotation

Codecov / codecov/patch

src/rest_api/routes/channels.py#L556

Added line #L556 was not covered by tests
await getCore().sendMessage(thread_message)
await getCore().sendMessage(thread_create_message)

Expand Down
76 changes: 56 additions & 20 deletions src/rest_api/routes/webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,21 @@
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from time import time
from datetime import datetime
from typing import Optional

from quart import Blueprint, request
from quart_schema import validate_request, validate_querystring

from ..models.channels import MessageUpdate
from ..models.webhooks import WebhookUpdate, WebhookMessageCreate, WebhookMessageCreateQuery
from ..utils import getUser, multipleDecorators, allowWithoutUser, processMessageData, allowBots, process_stickers, \
validate_reply, processMessage
validate_reply, processMessage, getWebhook, getMessage, getInteractionW
from ...gateway.events import MessageCreateEvent, WebhooksUpdateEvent, MessageUpdateEvent
from ...yepcord.ctx import getCore, getCDNStorage, getGw
from ...yepcord.enums import GuildPermissions, MessageType, MessageFlags
from ...yepcord.enums import GuildPermissions, MessageFlags
from ...yepcord.errors import InvalidDataErr, Errors
from ...yepcord.models import User, Channel, Message, Interaction
from ...yepcord.snowflake import Snowflake
from ...yepcord.models import User, Channel, Message, Interaction, Webhook
from ...yepcord.utils import getImage

# Base path is /api/vX/webhooks
Expand All @@ -39,7 +39,7 @@
@webhooks.delete("/<int:webhook>")
@webhooks.delete("/<int:webhook>/<string:token>")
@multipleDecorators(allowWithoutUser, allowBots, getUser)
async def api_webhooks_webhook_delete(user: Optional[User], webhook: int, token: Optional[str]=None):
async def delete_webhook(user: Optional[User], webhook: int, token: Optional[str]=None):
if webhook := await getCore().getWebhook(webhook):
if webhook.token != token:
guild = webhook.channel.guild
Expand All @@ -56,7 +56,7 @@ async def api_webhooks_webhook_delete(user: Optional[User], webhook: int, token:
@webhooks.patch("/<int:webhook>")
@webhooks.patch("/<int:webhook>/<string:token>")
@multipleDecorators(validate_request(WebhookUpdate), allowWithoutUser, allowBots, getUser)
async def api_webhooks_webhook_patch(data: WebhookUpdate, user: Optional[User], webhook: int, token: Optional[str]=None):
async def edit_webhook(data: WebhookUpdate, user: Optional[User], webhook: int, token: Optional[str]=None):
if not (webhook := await getCore().getWebhook(webhook)):
raise InvalidDataErr(404, Errors.make(10015))
channel = webhook.channel
Expand Down Expand Up @@ -93,7 +93,7 @@ async def api_webhooks_webhook_patch(data: WebhookUpdate, user: Optional[User],
@webhooks.get("/<int:webhook>")
@webhooks.get("/<int:webhook>/<string:token>")
@multipleDecorators(allowWithoutUser, allowBots, getUser)
async def api_webhooks_webhook_get(user: Optional[User], webhook: int, token: Optional[str]=None):
async def get_webhook(user: Optional[User], webhook: int, token: Optional[str]=None):
if not (webhook := await getCore().getWebhook(webhook)):
raise InvalidDataErr(404, Errors.make(10015))
if webhook.token != token:
Expand All @@ -107,7 +107,7 @@ async def api_webhooks_webhook_get(user: Optional[User], webhook: int, token: Op

@webhooks.post("/<int:webhook>/<string:token>")
@validate_querystring(WebhookMessageCreateQuery)
async def api_webhooks_webhook_post(query_args: WebhookMessageCreateQuery, webhook: int, token: str):
async def post_webhook_message(query_args: WebhookMessageCreateQuery, webhook: int, token: str):
if not (webhook := await getCore().getWebhook(webhook)):
raise InvalidDataErr(404, Errors.make(10015))
if webhook.token != token:
Expand All @@ -126,16 +126,31 @@ async def api_webhooks_webhook_post(query_args: WebhookMessageCreateQuery, webho
return "", 204


@webhooks.get("/<int:webhook>/<string:token>/messages/<int:message>")
@multipleDecorators(getWebhook, getMessage)
async def get_webhook_message(webhook: Webhook, message: Message):
return await message.ds_json()


@webhooks.delete("/<int:webhook>/<string:token>/messages/<int:message>")
@multipleDecorators(getWebhook, getMessage)
async def delete_webhook_message(webhook: Webhook, message: Message):
await message.delete()
return "", 204


@webhooks.patch("/<int:webhook>/<string:token>/messages/<int:message>")
@multipleDecorators(validate_request(MessageUpdate), getWebhook, getMessage)
async def edit_webhook_message(data: MessageUpdate, webhook: Webhook, message: Message):
await message.update(**data.to_json(), edit_timestamp=datetime.now())
await getGw().dispatch(MessageUpdateEvent(await message.ds_json()), channel_id=webhook.channel.id)
return await message.ds_json()


@webhooks.post("/<int:application_id>/int___<string:token>")
async def interaction_followup_create(application_id: int, token: str):
if not (inter := await Interaction.from_token(f"int___{token}")) or inter.application.id != application_id:
raise InvalidDataErr(404, Errors.make(10002))
message = await Message.get_or_none(interaction=inter, id__gt=Snowflake.fromTimestamp(time() - 15 * 60))\
.select_related(*Message.DEFAULT_RELATED)
if message is None:
raise InvalidDataErr(404, Errors.make(10008))

channel = inter.channel
@getInteractionW
async def interaction_followup_create(interaction: Interaction, message: Message):
channel = interaction.channel
data = await request.get_json()

data, attachments = await processMessageData(data, channel)
Expand All @@ -151,8 +166,29 @@ async def interaction_followup_create(application_id: int, token: str):
message_obj = await message.ds_json()

if message.ephemeral:
await getGw().dispatch(MessageUpdateEvent(message_obj), users=[inter.user.id, inter.application.id])
await getGw().dispatch(MessageUpdateEvent(message_obj), users=[interaction.user.id, interaction.application.id])

Check warning on line 169 in src/rest_api/routes/webhooks.py

View check run for this annotation

Codecov / codecov/patch

src/rest_api/routes/webhooks.py#L169

Added line #L169 was not covered by tests
else:
await getGw().dispatch(MessageUpdateEvent(message_obj), channel_id=inter.channel.id)
await getGw().dispatch(MessageUpdateEvent(message_obj), channel_id=interaction.channel.id)

return message_obj


@webhooks.get("/<int:application_id>/int___<string:token>/messages/<string:message>")
@getInteractionW
async def get_interaction_message(interaction: Interaction, message: Message):
return await message.ds_json()


@webhooks.delete("/<int:application_id>/int___<string:token>/messages/<string:message>")
@getInteractionW
async def delete_interaction_message(interaction: Interaction, message: Message):
await message.delete()
return "", 204


@webhooks.patch("/<int:application_id>/int___<string:token>/messages/<string:message>")
@multipleDecorators(validate_request(MessageUpdate), getInteractionW)
async def edit_interaction_message(data: MessageUpdate, interaction: Interaction, message: Message):
await message.update(**data.to_json(), edit_timestamp=datetime.now())
await getGw().dispatch(MessageUpdateEvent(await message.ds_json()), channel_id=interaction.channel.id)
return await message.ds_json()
60 changes: 56 additions & 4 deletions src/rest_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations
from functools import wraps
from json import loads
from time import time
from typing import Optional, Union, TYPE_CHECKING

from PIL import Image
Expand All @@ -29,12 +30,13 @@
from ..yepcord.ctx import Ctx, getCore, getCDNStorage
from ..yepcord.enums import MessageType
from ..yepcord.errors import Errors, InvalidDataErr
from ..yepcord.models import Session, User, Channel, Attachment, Application, Authorization, Bot, Interaction, Webhook
from ..yepcord.models import Session, User, Channel, Attachment, Application, Authorization, Bot, Interaction, Webhook, \
Message
from ..yepcord.snowflake import Snowflake
from ..yepcord.utils import b64decode
import src.yepcord.models as models

if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
from .models.channels import MessageCreate


Expand Down Expand Up @@ -137,7 +139,7 @@ async def wrapped(*args, **kwargs):
return wrapped


async def _getMessage(user: User, channel: Channel, message_id: int):
async def _getMessage(user: User, channel: Channel, message_id: int) -> Message:
if not channel:
raise InvalidDataErr(404, Errors.make(10003))
if not user:
Expand All @@ -147,10 +149,22 @@ async def _getMessage(user: User, channel: Channel, message_id: int):
return message


async def _getMessageWebhook(webhook_id: int, message_id: int) -> Message:
if not message_id or not webhook_id:
raise InvalidDataErr(404, Errors.make(10008))

Check warning on line 154 in src/rest_api/utils.py

View check run for this annotation

Codecov / codecov/patch

src/rest_api/utils.py#L154

Added line #L154 was not covered by tests
message = await Message.get_or_none(id=message_id, webhook_id=webhook_id).select_related(*Message.DEFAULT_RELATED)
if message is None:
raise InvalidDataErr(404, Errors.make(10008))
return message


def getMessage(f):
@wraps(f)
async def wrapped(*args, **kwargs):
kwargs["message"] = await _getMessage(kwargs.get("user"), kwargs.get("channel"), kwargs.get("message"))
if "webhook" in kwargs:
kwargs["message"] = await _getMessageWebhook(kwargs["webhook"].id, kwargs.get("message"))
else:
kwargs["message"] = await _getMessage(kwargs.get("user"), kwargs.get("channel"), kwargs.get("message"))
return await f(*args, **kwargs)
return wrapped

Expand Down Expand Up @@ -264,6 +278,42 @@ async def wrapped(*args, **kwargs):
return wrapped


def getWebhook(f):
@wraps(f)
async def wrapped(*args, **kwargs):
if not (webhook_id := kwargs.get("webhook")) or not (token := kwargs.get("token")):
raise InvalidDataErr(404, Errors.make(10015))

Check warning on line 285 in src/rest_api/utils.py

View check run for this annotation

Codecov / codecov/patch

src/rest_api/utils.py#L285

Added line #L285 was not covered by tests
webhook = await (Webhook.get_or_none(id=webhook_id).select_related("channel"))
if webhook is None or webhook.token != token:
raise InvalidDataErr(404, Errors.make(10015))

Check warning on line 288 in src/rest_api/utils.py

View check run for this annotation

Codecov / codecov/patch

src/rest_api/utils.py#L288

Added line #L288 was not covered by tests
del kwargs["webhook"]
del kwargs["token"]
kwargs["webhook"] = webhook
return await f(*args, **kwargs)

return wrapped


def getInteractionW(f):
@wraps(f)
async def wrapped(*args, **kwargs):
if not (application_id := kwargs.get("application_id")) or not (token := kwargs.get("token")):
raise InvalidDataErr(404, Errors.make(10002))

Check warning on line 301 in src/rest_api/utils.py

View check run for this annotation

Codecov / codecov/patch

src/rest_api/utils.py#L301

Added line #L301 was not covered by tests
if not (inter := await Interaction.from_token(f"int___{token}")) or inter.application.id != application_id:
raise InvalidDataErr(404, Errors.make(10002))

Check warning on line 303 in src/rest_api/utils.py

View check run for this annotation

Codecov / codecov/patch

src/rest_api/utils.py#L303

Added line #L303 was not covered by tests
message = await Message.get_or_none(interaction=inter, id__gt=Snowflake.fromTimestamp(time() - 15 * 60)) \
.select_related(*Message.DEFAULT_RELATED)
if message is None:
raise InvalidDataErr(404, Errors.make(10008))
del kwargs["application_id"]
del kwargs["token"]
kwargs["interaction"] = inter
kwargs["message"] = message
return await f(*args, **kwargs)

return wrapped


getGuildWithMember = getGuild(with_member=True)
getGuildWithoutMember = getGuild(with_member=False)
getGuildWM = getGuildWithMember
Expand Down Expand Up @@ -413,6 +463,8 @@ async def processMessage(data: dict, channel: Channel, author: Optional[User], v
raise InvalidDataErr(400, Errors.make(50006))

Check warning on line 463 in src/rest_api/utils.py

View check run for this annotation

Codecov / codecov/patch

src/rest_api/utils.py#L463

Added line #L463 was not covered by tests

data_json = data.to_json()
if webhook is not None:
data_json["webhook_id"] = webhook.id
message = await models.Message.create(
id=Snowflake.makeId(), channel=channel, author=author, **data_json, **stickers_data, type=message_type,
guild=channel.guild, webhook_author=w_author)
Expand Down
4 changes: 2 additions & 2 deletions src/yepcord/models/attachment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@

from tortoise import fields

from src.yepcord.config import Config
from ..config import Config
import src.yepcord.models as models
from src.yepcord.models._utils import SnowflakeField, Model
from ._utils import SnowflakeField, Model


class Attachment(Model):
Expand Down
6 changes: 3 additions & 3 deletions src/yepcord/models/audit_log_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@

from tortoise import fields

from src.yepcord.enums import AuditLogEntryType
from src.yepcord.models._utils import SnowflakeField, Model
from src.yepcord.snowflake import Snowflake
from ..enums import AuditLogEntryType
from ._utils import SnowflakeField, Model
from ..snowflake import Snowflake

import src.yepcord.models as models

Expand Down
6 changes: 3 additions & 3 deletions src/yepcord/models/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
from tortoise.expressions import Q
from tortoise.fields import SET_NULL

from src.yepcord.ctx import getCore
from src.yepcord.enums import ChannelType
from ..ctx import getCore
from ..enums import ChannelType
import src.yepcord.models as models
from src.yepcord.models._utils import SnowflakeField, Model
from ._utils import SnowflakeField, Model


class Channel(Model):
Expand Down
Loading

0 comments on commit 7a68d95

Please sign in to comment.