From 7a68d9569ab272f8b829dac1fc7d977c286a7f1d Mon Sep 17 00:00:00 2001 From: RuslanUC Date: Fri, 27 Oct 2023 17:55:17 +0300 Subject: [PATCH] add slash commands responses get/edit/delete endpoints --- src/rest_api/models/channels.py | 16 ++-- src/rest_api/models/interactions.py | 7 +- src/rest_api/routes/channels.py | 5 +- src/rest_api/routes/webhooks.py | 76 ++++++++++++++----- src/rest_api/utils.py | 60 ++++++++++++++- src/yepcord/models/attachment.py | 4 +- src/yepcord/models/audit_log_entry.py | 6 +- src/yepcord/models/channel.py | 6 +- src/yepcord/models/connected_account.py | 4 +- src/yepcord/models/emoji.py | 2 +- src/yepcord/models/frecency_settings.py | 6 +- src/yepcord/models/guild.py | 8 +- src/yepcord/models/guild_ban.py | 2 +- src/yepcord/models/guild_event.py | 6 +- src/yepcord/models/guild_member.py | 10 +-- src/yepcord/models/guild_template.py | 10 +-- src/yepcord/models/hidden_dm_channel.py | 2 +- src/yepcord/models/invite.py | 10 +-- src/yepcord/models/mfa_code.py | 2 +- src/yepcord/models/permission_overwrite.py | 2 +- src/yepcord/models/reaction.py | 2 +- src/yepcord/models/readstate.py | 4 +- src/yepcord/models/relationship.py | 6 +- src/yepcord/models/remote_auth_session.py | 2 +- src/yepcord/models/role.py | 2 +- src/yepcord/models/session.py | 4 +- src/yepcord/models/sticker.py | 2 +- src/yepcord/models/thread_member.py | 4 +- src/yepcord/models/thread_metadata.py | 4 +- src/yepcord/models/user.py | 8 +- src/yepcord/models/user_note.py | 2 +- src/yepcord/models/user_settings.py | 9 ++- src/yepcord/models/userdata.py | 2 +- ...interactions.py => test_slash_commands.py} | 59 ++++++++++++++ tests/api/test_webhooks.py | 46 +++++++++++ 35 files changed, 298 insertions(+), 102 deletions(-) rename tests/api/{test_interactions.py => test_slash_commands.py} (84%) diff --git a/src/rest_api/models/channels.py b/src/rest_api/models/channels.py index 19d984b..4cc9fdf 100644 --- a/src/rest_api/models/channels.py +++ b/src/rest_api/models/channels.py @@ -19,7 +19,7 @@ from __future__ import annotations from time import mktime -from typing import Optional, List +from typing import Optional from dateutil.parser import parse as dparse from pydantic import BaseModel, Field, field_validator @@ -237,7 +237,7 @@ class EmbedModel(BaseModel): thumbnail: Optional[EmbedImage] = None video: Optional[EmbedImage] = None author: Optional[EmbedAuthor] = None - fields: List[EmbedField] = Field(default_factory=list) + fields: list[EmbedField] = Field(default_factory=list) @field_validator("title") def validate_title(cls, value: str): @@ -302,7 +302,7 @@ def validate_author(cls, value: Optional[EmbedAuthor]): return value @field_validator("fields") - def validate_fields(cls, value: List[EmbedField]): + def validate_fields(cls, value: list[EmbedField]): if len(value) > 25: value = value[:25] return value @@ -327,8 +327,8 @@ def model_dump(self, *args, **kwargs): class MessageCreate(BaseModel): content: Optional[str] = None nonce: Optional[str] = None - embeds: List[EmbedModel] = Field(default_factory=list) - sticker_ids: List[int] = Field(default_factory=list) + embeds: list[EmbedModel] = Field(default_factory=list) + sticker_ids: list[int] = Field(default_factory=list) message_reference: Optional[MessageReferenceModel] = None flags: Optional[int] = None @@ -343,7 +343,7 @@ def validate_content(cls, value: Optional[str]): return value @field_validator("embeds") - def validate_embeds(cls, value: List[EmbedModel]): + def validate_embeds(cls, value: list[EmbedModel]): if len(value) > 10: raise InvalidDataErr(400, Errors.make(50035, {"embeds": { "code": "BASE_TYPE_BAD_LENGTH", "message": "Must be between 1 and 10 in length." @@ -351,7 +351,7 @@ def validate_embeds(cls, value: List[EmbedModel]): return value @field_validator("sticker_ids") - def validate_sticker_ids(cls, value: List[int]): + def validate_sticker_ids(cls, value: list[int]): if len(value) > 3: raise InvalidDataErr(400, Errors.make(50035, {"sticker_ids": { "code": "BASE_TYPE_BAD_LENGTH", "message": "Must be between 1 and 3 in length." @@ -392,7 +392,7 @@ def to_json(self) -> dict: # noinspection PyMethodParameters class MessageUpdate(BaseModel): content: Optional[str] = None - embeds: List[EmbedModel] = Field(default_factory=list) + embeds: list[EmbedModel] = Field(default_factory=list) @field_validator("content") def validate_content(cls, value: Optional[str]): diff --git a/src/rest_api/models/interactions.py b/src/rest_api/models/interactions.py index 24dd947..ab8e936 100644 --- a/src/rest_api/models/interactions.py +++ b/src/rest_api/models/interactions.py @@ -1,8 +1,9 @@ from typing import Literal, Optional, Any -from pydantic import BaseModel, create_model, field_validator +from pydantic import BaseModel, create_model, field_validator, Field from pydantic_core.core_schema import ValidationInfo +from .channels import EmbedModel from ...yepcord.enums import ApplicationCommandOptionType OPTION_MODELS = { @@ -54,9 +55,9 @@ class InteractionCreate(BaseModel): class InteractionRespondData(BaseModel): content: Optional[str] = None - #embeds: Optional[list] = None + embeds: list[EmbedModel] = Field(default_factory=list) flags: int = 0 - #components: Optional[list] = None # components validation are not supported now :( + #components: Optional[list] = None # components validation are not supported yet :( class InteractionRespond(BaseModel): diff --git a/src/rest_api/routes/channels.py b/src/rest_api/routes/channels.py index ae21a2e..06cb251 100644 --- a/src/rest_api/routes/channels.py +++ b/src/rest_api/routes/channels.py @@ -35,7 +35,7 @@ DMChannelDeleteEvent, MessageReactionAddEvent, MessageReactionRemoveEvent, ChannelUpdateEvent, ChannelDeleteEvent, \ WebhooksUpdateEvent, ThreadCreateEvent, ThreadMemberUpdateEvent, MessageAckEvent, GuildAuditLogEntryCreateEvent from ...yepcord.ctx import getCore, getCDNStorage, getGw -from ...yepcord.enums import GuildPermissions, MessageType, ChannelType, WebhookType, GUILD_CHANNELS +from ...yepcord.enums import GuildPermissions, MessageType, ChannelType, WebhookType, GUILD_CHANNELS, MessageFlags from ...yepcord.errors import InvalidDataErr, Errors from ...yepcord.models import User, Channel, Message, ReadState, Emoji, PermissionOverwrite, Webhook, ThreadMember, \ ThreadMetadata, AuditLogEntry, Relationship, ApplicationCommand, Integration, Bot @@ -553,8 +553,7 @@ async def create_thread(data: CreateThread, user: User, channel: Channel, messag await getGw().dispatch(ThreadCreateEvent(await thread.ds_json() | {"newly_created": True}), guild_id=channel.guild.id) await getGw().dispatch(ThreadMemberUpdateEvent(thread_member.ds_json()), guild_id=channel.guild.id) - message.thread = thread - await message.save(update_fields=["thread"]) + await message.update(thread=thread, flags=message.flags | MessageFlags.HAS_THREAD) await getCore().sendMessage(thread_message) await getCore().sendMessage(thread_create_message) diff --git a/src/rest_api/routes/webhooks.py b/src/rest_api/routes/webhooks.py index ab51cc8..9332890 100644 --- a/src/rest_api/routes/webhooks.py +++ b/src/rest_api/routes/webhooks.py @@ -15,21 +15,21 @@ You should have received a copy of the GNU Affero General Public License along with this program. If not, see . """ -from time import time +from datetime import datetime from typing import Optional from quart import Blueprint, request from quart_schema import validate_request, validate_querystring +from ..models.channels import MessageUpdate from ..models.webhooks import WebhookUpdate, WebhookMessageCreate, WebhookMessageCreateQuery from ..utils import getUser, multipleDecorators, allowWithoutUser, processMessageData, allowBots, process_stickers, \ - validate_reply, processMessage + validate_reply, processMessage, getWebhook, getMessage, getInteractionW from ...gateway.events import MessageCreateEvent, WebhooksUpdateEvent, MessageUpdateEvent from ...yepcord.ctx import getCore, getCDNStorage, getGw -from ...yepcord.enums import GuildPermissions, MessageType, MessageFlags +from ...yepcord.enums import GuildPermissions, MessageFlags from ...yepcord.errors import InvalidDataErr, Errors -from ...yepcord.models import User, Channel, Message, Interaction -from ...yepcord.snowflake import Snowflake +from ...yepcord.models import User, Channel, Message, Interaction, Webhook from ...yepcord.utils import getImage # Base path is /api/vX/webhooks @@ -39,7 +39,7 @@ @webhooks.delete("/") @webhooks.delete("//") @multipleDecorators(allowWithoutUser, allowBots, getUser) -async def api_webhooks_webhook_delete(user: Optional[User], webhook: int, token: Optional[str]=None): +async def delete_webhook(user: Optional[User], webhook: int, token: Optional[str]=None): if webhook := await getCore().getWebhook(webhook): if webhook.token != token: guild = webhook.channel.guild @@ -56,7 +56,7 @@ async def api_webhooks_webhook_delete(user: Optional[User], webhook: int, token: @webhooks.patch("/") @webhooks.patch("//") @multipleDecorators(validate_request(WebhookUpdate), allowWithoutUser, allowBots, getUser) -async def api_webhooks_webhook_patch(data: WebhookUpdate, user: Optional[User], webhook: int, token: Optional[str]=None): +async def edit_webhook(data: WebhookUpdate, user: Optional[User], webhook: int, token: Optional[str]=None): if not (webhook := await getCore().getWebhook(webhook)): raise InvalidDataErr(404, Errors.make(10015)) channel = webhook.channel @@ -93,7 +93,7 @@ async def api_webhooks_webhook_patch(data: WebhookUpdate, user: Optional[User], @webhooks.get("/") @webhooks.get("//") @multipleDecorators(allowWithoutUser, allowBots, getUser) -async def api_webhooks_webhook_get(user: Optional[User], webhook: int, token: Optional[str]=None): +async def get_webhook(user: Optional[User], webhook: int, token: Optional[str]=None): if not (webhook := await getCore().getWebhook(webhook)): raise InvalidDataErr(404, Errors.make(10015)) if webhook.token != token: @@ -107,7 +107,7 @@ async def api_webhooks_webhook_get(user: Optional[User], webhook: int, token: Op @webhooks.post("//") @validate_querystring(WebhookMessageCreateQuery) -async def api_webhooks_webhook_post(query_args: WebhookMessageCreateQuery, webhook: int, token: str): +async def post_webhook_message(query_args: WebhookMessageCreateQuery, webhook: int, token: str): if not (webhook := await getCore().getWebhook(webhook)): raise InvalidDataErr(404, Errors.make(10015)) if webhook.token != token: @@ -126,16 +126,31 @@ async def api_webhooks_webhook_post(query_args: WebhookMessageCreateQuery, webho return "", 204 +@webhooks.get("///messages/") +@multipleDecorators(getWebhook, getMessage) +async def get_webhook_message(webhook: Webhook, message: Message): + return await message.ds_json() + + +@webhooks.delete("///messages/") +@multipleDecorators(getWebhook, getMessage) +async def delete_webhook_message(webhook: Webhook, message: Message): + await message.delete() + return "", 204 + + +@webhooks.patch("///messages/") +@multipleDecorators(validate_request(MessageUpdate), getWebhook, getMessage) +async def edit_webhook_message(data: MessageUpdate, webhook: Webhook, message: Message): + await message.update(**data.to_json(), edit_timestamp=datetime.now()) + await getGw().dispatch(MessageUpdateEvent(await message.ds_json()), channel_id=webhook.channel.id) + return await message.ds_json() + + @webhooks.post("//int___") -async def interaction_followup_create(application_id: int, token: str): - if not (inter := await Interaction.from_token(f"int___{token}")) or inter.application.id != application_id: - raise InvalidDataErr(404, Errors.make(10002)) - message = await Message.get_or_none(interaction=inter, id__gt=Snowflake.fromTimestamp(time() - 15 * 60))\ - .select_related(*Message.DEFAULT_RELATED) - if message is None: - raise InvalidDataErr(404, Errors.make(10008)) - - channel = inter.channel +@getInteractionW +async def interaction_followup_create(interaction: Interaction, message: Message): + channel = interaction.channel data = await request.get_json() data, attachments = await processMessageData(data, channel) @@ -151,8 +166,29 @@ async def interaction_followup_create(application_id: int, token: str): message_obj = await message.ds_json() if message.ephemeral: - await getGw().dispatch(MessageUpdateEvent(message_obj), users=[inter.user.id, inter.application.id]) + await getGw().dispatch(MessageUpdateEvent(message_obj), users=[interaction.user.id, interaction.application.id]) else: - await getGw().dispatch(MessageUpdateEvent(message_obj), channel_id=inter.channel.id) + await getGw().dispatch(MessageUpdateEvent(message_obj), channel_id=interaction.channel.id) return message_obj + + +@webhooks.get("//int___/messages/") +@getInteractionW +async def get_interaction_message(interaction: Interaction, message: Message): + return await message.ds_json() + + +@webhooks.delete("//int___/messages/") +@getInteractionW +async def delete_interaction_message(interaction: Interaction, message: Message): + await message.delete() + return "", 204 + + +@webhooks.patch("//int___/messages/") +@multipleDecorators(validate_request(MessageUpdate), getInteractionW) +async def edit_interaction_message(data: MessageUpdate, interaction: Interaction, message: Message): + await message.update(**data.to_json(), edit_timestamp=datetime.now()) + await getGw().dispatch(MessageUpdateEvent(await message.ds_json()), channel_id=interaction.channel.id) + return await message.ds_json() diff --git a/src/rest_api/utils.py b/src/rest_api/utils.py index d4fe676..07ffcf4 100644 --- a/src/rest_api/utils.py +++ b/src/rest_api/utils.py @@ -18,6 +18,7 @@ from __future__ import annotations from functools import wraps from json import loads +from time import time from typing import Optional, Union, TYPE_CHECKING from PIL import Image @@ -29,12 +30,13 @@ from ..yepcord.ctx import Ctx, getCore, getCDNStorage from ..yepcord.enums import MessageType from ..yepcord.errors import Errors, InvalidDataErr -from ..yepcord.models import Session, User, Channel, Attachment, Application, Authorization, Bot, Interaction, Webhook +from ..yepcord.models import Session, User, Channel, Attachment, Application, Authorization, Bot, Interaction, Webhook, \ + Message from ..yepcord.snowflake import Snowflake from ..yepcord.utils import b64decode import src.yepcord.models as models -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from .models.channels import MessageCreate @@ -137,7 +139,7 @@ async def wrapped(*args, **kwargs): return wrapped -async def _getMessage(user: User, channel: Channel, message_id: int): +async def _getMessage(user: User, channel: Channel, message_id: int) -> Message: if not channel: raise InvalidDataErr(404, Errors.make(10003)) if not user: @@ -147,10 +149,22 @@ async def _getMessage(user: User, channel: Channel, message_id: int): return message +async def _getMessageWebhook(webhook_id: int, message_id: int) -> Message: + if not message_id or not webhook_id: + raise InvalidDataErr(404, Errors.make(10008)) + message = await Message.get_or_none(id=message_id, webhook_id=webhook_id).select_related(*Message.DEFAULT_RELATED) + if message is None: + raise InvalidDataErr(404, Errors.make(10008)) + return message + + def getMessage(f): @wraps(f) async def wrapped(*args, **kwargs): - kwargs["message"] = await _getMessage(kwargs.get("user"), kwargs.get("channel"), kwargs.get("message")) + if "webhook" in kwargs: + kwargs["message"] = await _getMessageWebhook(kwargs["webhook"].id, kwargs.get("message")) + else: + kwargs["message"] = await _getMessage(kwargs.get("user"), kwargs.get("channel"), kwargs.get("message")) return await f(*args, **kwargs) return wrapped @@ -264,6 +278,42 @@ async def wrapped(*args, **kwargs): return wrapped +def getWebhook(f): + @wraps(f) + async def wrapped(*args, **kwargs): + if not (webhook_id := kwargs.get("webhook")) or not (token := kwargs.get("token")): + raise InvalidDataErr(404, Errors.make(10015)) + webhook = await (Webhook.get_or_none(id=webhook_id).select_related("channel")) + if webhook is None or webhook.token != token: + raise InvalidDataErr(404, Errors.make(10015)) + del kwargs["webhook"] + del kwargs["token"] + kwargs["webhook"] = webhook + return await f(*args, **kwargs) + + return wrapped + + +def getInteractionW(f): + @wraps(f) + async def wrapped(*args, **kwargs): + if not (application_id := kwargs.get("application_id")) or not (token := kwargs.get("token")): + raise InvalidDataErr(404, Errors.make(10002)) + if not (inter := await Interaction.from_token(f"int___{token}")) or inter.application.id != application_id: + raise InvalidDataErr(404, Errors.make(10002)) + message = await Message.get_or_none(interaction=inter, id__gt=Snowflake.fromTimestamp(time() - 15 * 60)) \ + .select_related(*Message.DEFAULT_RELATED) + if message is None: + raise InvalidDataErr(404, Errors.make(10008)) + del kwargs["application_id"] + del kwargs["token"] + kwargs["interaction"] = inter + kwargs["message"] = message + return await f(*args, **kwargs) + + return wrapped + + getGuildWithMember = getGuild(with_member=True) getGuildWithoutMember = getGuild(with_member=False) getGuildWM = getGuildWithMember @@ -413,6 +463,8 @@ async def processMessage(data: dict, channel: Channel, author: Optional[User], v raise InvalidDataErr(400, Errors.make(50006)) data_json = data.to_json() + if webhook is not None: + data_json["webhook_id"] = webhook.id message = await models.Message.create( id=Snowflake.makeId(), channel=channel, author=author, **data_json, **stickers_data, type=message_type, guild=channel.guild, webhook_author=w_author) diff --git a/src/yepcord/models/attachment.py b/src/yepcord/models/attachment.py index 2505898..160fc80 100644 --- a/src/yepcord/models/attachment.py +++ b/src/yepcord/models/attachment.py @@ -20,9 +20,9 @@ from tortoise import fields -from src.yepcord.config import Config +from ..config import Config import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class Attachment(Model): diff --git a/src/yepcord/models/audit_log_entry.py b/src/yepcord/models/audit_log_entry.py index 4ab8096..3f9205e 100644 --- a/src/yepcord/models/audit_log_entry.py +++ b/src/yepcord/models/audit_log_entry.py @@ -22,9 +22,9 @@ from tortoise import fields -from src.yepcord.enums import AuditLogEntryType -from src.yepcord.models._utils import SnowflakeField, Model -from src.yepcord.snowflake import Snowflake +from ..enums import AuditLogEntryType +from ._utils import SnowflakeField, Model +from ..snowflake import Snowflake import src.yepcord.models as models diff --git a/src/yepcord/models/channel.py b/src/yepcord/models/channel.py index e93eb4a..a61c672 100644 --- a/src/yepcord/models/channel.py +++ b/src/yepcord/models/channel.py @@ -24,10 +24,10 @@ from tortoise.expressions import Q from tortoise.fields import SET_NULL -from src.yepcord.ctx import getCore -from src.yepcord.enums import ChannelType +from ..ctx import getCore +from ..enums import ChannelType import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class Channel(Model): diff --git a/src/yepcord/models/connected_account.py b/src/yepcord/models/connected_account.py index cc69dc4..6d251dc 100644 --- a/src/yepcord/models/connected_account.py +++ b/src/yepcord/models/connected_account.py @@ -20,8 +20,8 @@ from tortoise import fields -from src.yepcord.models._utils import ChoicesValidator, SnowflakeField, Model -from src.yepcord.snowflake import Snowflake +from ._utils import ChoicesValidator, SnowflakeField, Model +from ..snowflake import Snowflake import src.yepcord.models as models diff --git a/src/yepcord/models/emoji.py b/src/yepcord/models/emoji.py index befd69f..71c0fa6 100644 --- a/src/yepcord/models/emoji.py +++ b/src/yepcord/models/emoji.py @@ -21,7 +21,7 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class Emoji(Model): diff --git a/src/yepcord/models/frecency_settings.py b/src/yepcord/models/frecency_settings.py index 68a1d7a..35d36f8 100644 --- a/src/yepcord/models/frecency_settings.py +++ b/src/yepcord/models/frecency_settings.py @@ -19,9 +19,9 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model -from src.yepcord.proto import FrecencyUserSettings -from src.yepcord.utils import b64decode +from ._utils import SnowflakeField, Model +from ..proto import FrecencyUserSettings +from ..utils import b64decode class FrecencySettings(Model): diff --git a/src/yepcord/models/guild.py b/src/yepcord/models/guild.py index 770feaa..904f7af 100644 --- a/src/yepcord/models/guild.py +++ b/src/yepcord/models/guild.py @@ -21,9 +21,10 @@ from tortoise import fields -from src.yepcord.ctx import getCore +from ..ctx import getCore import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ..enums import Locales +from ._utils import SnowflakeField, Model, ChoicesValidator class Guild(Model): @@ -47,7 +48,8 @@ class Guild(Model): system_channel_flags: int = fields.BigIntField(default=0) max_members: int = fields.IntField(default=100) vanity_url_code: Optional[str] = fields.CharField(max_length=64, null=True, default=None) - preferred_locale: str = fields.CharField(max_length=8, default="en-US") + preferred_locale: str = fields.CharField(max_length=8, default="en-US", + validators=[ChoicesValidator(Locales.values_set())]) premium_progress_bar_enabled: bool = fields.BooleanField(default=False) nsfw: bool = fields.BooleanField(default=False) nsfw_level: int = fields.IntField(default=0) diff --git a/src/yepcord/models/guild_ban.py b/src/yepcord/models/guild_ban.py index bae784e..23e8dee 100644 --- a/src/yepcord/models/guild_ban.py +++ b/src/yepcord/models/guild_ban.py @@ -19,7 +19,7 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class GuildBan(Model): diff --git a/src/yepcord/models/guild_event.py b/src/yepcord/models/guild_event.py index 0d2962d..998cebc 100644 --- a/src/yepcord/models/guild_event.py +++ b/src/yepcord/models/guild_event.py @@ -22,9 +22,9 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.ctx import getCore -from src.yepcord.enums import ScheduledEventEntityType -from src.yepcord.models._utils import SnowflakeField, Model +from ..ctx import getCore +from ..enums import ScheduledEventEntityType +from ._utils import SnowflakeField, Model class GuildEvent(Model): diff --git a/src/yepcord/models/guild_member.py b/src/yepcord/models/guild_member.py index b51c9f6..44f2d68 100644 --- a/src/yepcord/models/guild_member.py +++ b/src/yepcord/models/guild_member.py @@ -23,11 +23,11 @@ from tortoise import fields -from src.yepcord.ctx import getCore -from src.yepcord.enums import GuildPermissions -from src.yepcord.errors import InvalidDataErr, Errors -from src.yepcord.models._utils import SnowflakeField, Model -from src.yepcord.snowflake import Snowflake +from ..ctx import getCore +from ..enums import GuildPermissions +from ..errors import InvalidDataErr, Errors +from ._utils import SnowflakeField, Model +from ..snowflake import Snowflake import src.yepcord.models as models diff --git a/src/yepcord/models/guild_template.py b/src/yepcord/models/guild_template.py index 07a757c..9c23e36 100644 --- a/src/yepcord/models/guild_template.py +++ b/src/yepcord/models/guild_template.py @@ -21,11 +21,11 @@ from tortoise import fields -from src.yepcord.ctx import getCore -from src.yepcord.enums import ChannelType -from src.yepcord.models._utils import SnowflakeField, Model -from src.yepcord.snowflake import Snowflake -from src.yepcord.utils import b64encode, int_size, NoneType +from ..ctx import getCore +from ..enums import ChannelType +from ._utils import SnowflakeField, Model +from ..snowflake import Snowflake +from ..utils import b64encode, int_size, NoneType import src.yepcord.models as models diff --git a/src/yepcord/models/hidden_dm_channel.py b/src/yepcord/models/hidden_dm_channel.py index 9593c94..9802abb 100644 --- a/src/yepcord/models/hidden_dm_channel.py +++ b/src/yepcord/models/hidden_dm_channel.py @@ -19,7 +19,7 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class HiddenDmChannel(Model): diff --git a/src/yepcord/models/invite.py b/src/yepcord/models/invite.py index 07b1bda..b62ff0f 100644 --- a/src/yepcord/models/invite.py +++ b/src/yepcord/models/invite.py @@ -21,11 +21,11 @@ from tortoise import fields -from src.yepcord.ctx import getCore -from src.yepcord.enums import ChannelType -from src.yepcord.models._utils import SnowflakeField, Model -from src.yepcord.snowflake import Snowflake -from src.yepcord.utils import b64encode, int_size +from ..ctx import getCore +from ..enums import ChannelType +from ._utils import SnowflakeField, Model +from ..snowflake import Snowflake +from ..utils import b64encode, int_size import src.yepcord.models as models diff --git a/src/yepcord/models/mfa_code.py b/src/yepcord/models/mfa_code.py index 5ab3d7c..6140b5b 100644 --- a/src/yepcord/models/mfa_code.py +++ b/src/yepcord/models/mfa_code.py @@ -19,7 +19,7 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class MfaCode(Model): diff --git a/src/yepcord/models/permission_overwrite.py b/src/yepcord/models/permission_overwrite.py index da75cde..27f7a98 100644 --- a/src/yepcord/models/permission_overwrite.py +++ b/src/yepcord/models/permission_overwrite.py @@ -19,7 +19,7 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class PermissionOverwrite(Model): diff --git a/src/yepcord/models/reaction.py b/src/yepcord/models/reaction.py index 02771a4..ec86d99 100644 --- a/src/yepcord/models/reaction.py +++ b/src/yepcord/models/reaction.py @@ -21,7 +21,7 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class Reaction(Model): diff --git a/src/yepcord/models/readstate.py b/src/yepcord/models/readstate.py index fe6092d..69c247d 100644 --- a/src/yepcord/models/readstate.py +++ b/src/yepcord/models/readstate.py @@ -18,9 +18,9 @@ from tortoise import fields -from src.yepcord.ctx import getCore +from ..ctx import getCore import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class ReadState(Model): diff --git a/src/yepcord/models/relationship.py b/src/yepcord/models/relationship.py index 7094337..1842f5d 100644 --- a/src/yepcord/models/relationship.py +++ b/src/yepcord/models/relationship.py @@ -22,9 +22,9 @@ from tortoise.expressions import Q from tortoise import fields -from src.yepcord.enums import RelationshipType, RelTypeDiscord -from src.yepcord.errors import InvalidDataErr, Errors -from src.yepcord.models._utils import ChoicesValidator, SnowflakeField, Model +from ..enums import RelationshipType, RelTypeDiscord +from ..errors import InvalidDataErr, Errors +from ._utils import ChoicesValidator, SnowflakeField, Model import src.yepcord.models as models diff --git a/src/yepcord/models/remote_auth_session.py b/src/yepcord/models/remote_auth_session.py index 4317390..65cbc99 100644 --- a/src/yepcord/models/remote_auth_session.py +++ b/src/yepcord/models/remote_auth_session.py @@ -22,7 +22,7 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model def time_plus_150s(): diff --git a/src/yepcord/models/role.py b/src/yepcord/models/role.py index 080dd9d..c64ac74 100644 --- a/src/yepcord/models/role.py +++ b/src/yepcord/models/role.py @@ -21,7 +21,7 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class Role(Model): diff --git a/src/yepcord/models/session.py b/src/yepcord/models/session.py index 0aa6eb4..35807e1 100644 --- a/src/yepcord/models/session.py +++ b/src/yepcord/models/session.py @@ -22,8 +22,8 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model -from src.yepcord.utils import b64encode, int_size, b64decode +from ._utils import SnowflakeField, Model +from ..utils import b64encode, int_size, b64decode class Session(Model): diff --git a/src/yepcord/models/sticker.py b/src/yepcord/models/sticker.py index ded8b29..b128cd8 100644 --- a/src/yepcord/models/sticker.py +++ b/src/yepcord/models/sticker.py @@ -21,7 +21,7 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class Sticker(Model): diff --git a/src/yepcord/models/thread_member.py b/src/yepcord/models/thread_member.py index ef0e0da..690839b 100644 --- a/src/yepcord/models/thread_member.py +++ b/src/yepcord/models/thread_member.py @@ -22,8 +22,8 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model -from src.yepcord.snowflake import Snowflake +from ._utils import SnowflakeField, Model +from ..snowflake import Snowflake class ThreadMember(Model): diff --git a/src/yepcord/models/thread_metadata.py b/src/yepcord/models/thread_metadata.py index a767150..c1ff166 100644 --- a/src/yepcord/models/thread_metadata.py +++ b/src/yepcord/models/thread_metadata.py @@ -21,8 +21,8 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model -from src.yepcord.snowflake import Snowflake +from ._utils import SnowflakeField, Model +from ..snowflake import Snowflake class ThreadMetadata(Model): diff --git a/src/yepcord/models/user.py b/src/yepcord/models/user.py index 0ac55bc..50e7861 100644 --- a/src/yepcord/models/user.py +++ b/src/yepcord/models/user.py @@ -24,10 +24,10 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.classes.other import MFA -from src.yepcord.ctx import getCore -from src.yepcord.models._utils import SnowflakeField, Model -from src.yepcord.snowflake import Snowflake +from ._utils import SnowflakeField, Model +from ..classes.other import MFA +from ..ctx import getCore +from ..snowflake import Snowflake class UserUtils: diff --git a/src/yepcord/models/user_note.py b/src/yepcord/models/user_note.py index fb2555e..6f3a031 100644 --- a/src/yepcord/models/user_note.py +++ b/src/yepcord/models/user_note.py @@ -19,7 +19,7 @@ from tortoise import fields import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class UserNote(Model): diff --git a/src/yepcord/models/user_settings.py b/src/yepcord/models/user_settings.py index e285f8e..5aea38a 100644 --- a/src/yepcord/models/user_settings.py +++ b/src/yepcord/models/user_settings.py @@ -24,12 +24,13 @@ from protobuf_to_dict import protobuf_to_dict from tortoise import fields -from src.yepcord.models._utils import ChoicesValidator, SnowflakeField, Model +from ..enums import Locales +from ._utils import ChoicesValidator, SnowflakeField, Model import src.yepcord.models as models -from src.yepcord.proto import PreloadedUserSettings, Versions, UserContentSettings, VoiceAndVideoSettings, \ +from ..proto import PreloadedUserSettings, Versions, UserContentSettings, VoiceAndVideoSettings, \ TextAndImagesSettings, PrivacySettings, StatusSettings, CustomStatus, LocalizationSettings, AppearanceSettings, \ GuildFolders, GuildFolder, Theme -from src.yepcord.utils import dict_get, freeze, unfreeze +from ..utils import dict_get, freeze, unfreeze class UserSettings(Model): @@ -68,7 +69,7 @@ class UserSettings(Model): friend_discovery_flags: int = fields.IntField(default=0) animate_stickers: int = fields.IntField(default=0) theme: str = fields.CharField(max_length=8, default="dark", validators=[ChoicesValidator({"dark", "light"})]) - locale: str = fields.CharField(max_length=8, default="en-US") # TODO: add choices validator + locale: str = fields.CharField(max_length=8, default="en-US", validators=[ChoicesValidator(Locales.values_set())]) mfa: str = fields.CharField(max_length=64, null=True, default=None) render_spoilers: str = fields.CharField(max_length=16, default="ON_CLICK", validators=[ChoicesValidator({"ALWAYS", "ON_CLICK", "IF_MODERATOR"})]) diff --git a/src/yepcord/models/userdata.py b/src/yepcord/models/userdata.py index 440da78..12fd8fc 100644 --- a/src/yepcord/models/userdata.py +++ b/src/yepcord/models/userdata.py @@ -24,7 +24,7 @@ from tortoise.validators import MinValueValidator, MaxValueValidator import src.yepcord.models as models -from src.yepcord.models._utils import SnowflakeField, Model +from ._utils import SnowflakeField, Model class UserData(Model): diff --git a/tests/api/test_interactions.py b/tests/api/test_slash_commands.py similarity index 84% rename from tests/api/test_interactions.py rename to tests/api/test_slash_commands.py index f270e05..0b0df96 100644 --- a/tests/api/test_interactions.py +++ b/tests/api/test_slash_commands.py @@ -317,3 +317,62 @@ async def _resolved(type_: int, name: str, value: str, resolved_types: set[str]= await _resolved(7, "channel", channel2["id"]) await _resolved(8, "role", str(Snowflake.makeId())) await _resolved(8, "role", guild2["id"]) + + +@pt.mark.asyncio +async def test_get_update_delete_slash_command_response(): + client: TestClientType = app.test_client() + user = (await create_users(client, 1))[0] + guild = await create_guild(client, user, "Test") + application = await create_application(client, user, "testApp") + await add_bot_to_guild(client, user, guild, application) + headers = {"Authorization": user["token"]} + bot_token_ = await bot_token(client, user, application) + bot_headers = {"Authorization": f"Bot {bot_token_}"} + channel = [channel for channel in guild["channels"] if channel["type"] == ChannelType.GUILD_TEXT][0] + + resp = await client.post(f"/api/v9/applications/{application['id']}/commands", headers=bot_headers, json={ + "type": 1, "name": "test", "description": "test"}) + assert resp.status_code == 200 + command = await resp.get_json() + + payload = generate_slash_command_payload(application, guild, channel, command, []) + async with gateway_cm(gw_app): + gw_client = gw_app.test_client() + cl = GatewayClient(bot_token_) + async with gw_client.websocket('/') as ws: + event_coro = await cl.awaitable_wait_for(GatewayOp.DISPATCH, "INTERACTION_CREATE") + await cl.run(ws) + + resp = await client.post(f"/api/v9/interactions", headers=headers, form={"payload_json": dumps(payload)}) + assert resp.status_code == 204 + + event = await event_coro + + int_id = event["id"] + int_token = event["token"] + + resp = await client.post(f"/api/v9/interactions/{int_id}/{int_token}/callback", + json={"type": 4, "data": {"content": "test", "flags": 64}}) + assert resp.status_code == 204 + + resp = await client.get(f"/api/v9/webhooks/{application['id']}/{int_token}/messages/@original") + assert resp.status_code == 200, await resp.get_json() + json = await resp.get_json() + assert json["application_id"] == application["id"] + assert json["webhook_id"] == int_id + assert json["content"] == "test" + assert json["flags"] == 64 + + resp = await client.patch(f"/api/v9/webhooks/{application['id']}/{int_token}/messages/@original", + json={"content": "changed"}) + assert resp.status_code == 200 + json = await resp.get_json() + assert json["content"] == "changed" + assert json["flags"] == 64 + + resp = await client.delete(f"/api/v9/webhooks/{application['id']}/{int_token}/messages/@original") + assert resp.status_code == 204 + + resp = await client.get(f"/api/v9/webhooks/{application['id']}/{int_token}/messages/@original") + assert resp.status_code == 404 diff --git a/tests/api/test_webhooks.py b/tests/api/test_webhooks.py index 48e4ce8..a980e29 100644 --- a/tests/api/test_webhooks.py +++ b/tests/api/test_webhooks.py @@ -268,3 +268,49 @@ async def test_get_webhook_fail(): resp = await client.get(f"/api/v9/webhooks/{Snowflake.makeId()}/wrong-token") assert resp.status_code == 404 + + +@pt.mark.asyncio +async def test_get_delete_webhook_message(): + client: TestClientType = app.test_client() + user = (await create_users(client, 1))[0] + guild = await create_guild(client, user, "Test Guild") + channel = await create_guild_channel(client, user, guild, 'test_text_channel') + webhook = await create_webhook(client, user, channel["id"]) + + resp = await client.post(f"/api/webhooks/{webhook['id']}/{webhook['token']}?wait=true", + json={'content': 'test message sent from webhook'}) + assert resp.status_code == 200 + message = await resp.get_json() + + resp = await client.get(f"/api/webhooks/{webhook['id']}/{webhook['token']}/messages/{message['id']}") + assert resp.status_code == 200 + json = await resp.get_json() + assert json == message + + resp = await client.delete(f"/api/webhooks/{webhook['id']}/{webhook['token']}/messages/{message['id']}") + assert resp.status_code == 204 + + resp = await client.get(f"/api/webhooks/{webhook['id']}/{webhook['token']}/messages/{message['id']}") + assert resp.status_code == 404 + + +@pt.mark.asyncio +async def test_edit_webhook_message(): + client: TestClientType = app.test_client() + user = (await create_users(client, 1))[0] + guild = await create_guild(client, user, "Test Guild") + channel = await create_guild_channel(client, user, guild, 'test_text_channel') + webhook = await create_webhook(client, user, channel["id"]) + + resp = await client.post(f"/api/webhooks/{webhook['id']}/{webhook['token']}?wait=true", + json={'content': 'test message sent from webhook'}) + assert resp.status_code == 200 + message = await resp.get_json() + + resp = await client.patch(f"/api/webhooks/{webhook['id']}/{webhook['token']}/messages/{message['id']}", + json={"content": "test changed"}) + assert resp.status_code == 200 + json = await resp.get_json() + assert json["id"] == message["id"] + assert json["content"] == "test changed"