diff --git a/bot.py b/bot.py index 9f3de008a1..2037bf7074 100644 --- a/bot.py +++ b/bot.py @@ -1309,9 +1309,18 @@ async def get_contexts(self, message, *, cls=commands.Context): # Check if a snippet is being called. # This needs to be done before checking for aliases since # snippets can have multiple words. + snippet_invoked = False try: # Use removeprefix once PY3.9+ - snippet_text = self.snippets[message.content[len(invoked_prefix) :]] + snippet_data = self.snippets[message.content[len(invoked_prefix) :]] + # Extract text from snippet (handle both old string format and new dict format) + if isinstance(snippet_data, str): + snippet_text = snippet_data + elif isinstance(snippet_data, dict): + snippet_text = snippet_data.get("text", "") + else: + snippet_text = None + snippet_invoked = True except KeyError: snippet_text = None @@ -1327,9 +1336,43 @@ async def get_contexts(self, message, *, cls=commands.Context): for alias in aliases: command = None try: - snippet_text = self.snippets[alias] + snippet_data = self.snippets[alias] + # Extract text from snippet (handle both old string format and new dict format) + if isinstance(snippet_data, str): + snippet_text = snippet_data + elif isinstance(snippet_data, dict): + snippet_text = snippet_data.get("text", "") + # Download attachment if present + if snippet_data.get("file_id"): + try: + import io + + file_data, metadata = await self.api.download_snippet_attachment( + snippet_data["file_id"] + ) + + class AttachmentWrapper: + def __init__(self, file_data, metadata): + self.file_data = file_data + self.id = 0 + self.url = f"attachment://{metadata['filename']}" + self.filename = metadata["filename"] + self.size = metadata["length"] + self.width = None + + async def to_file(self): + return discord.File( + io.BytesIO(self.file_data), filename=self.filename + ) + + message.attachments = [AttachmentWrapper(file_data, metadata)] + except Exception as e: + logger.warning("Failed to download snippet attachment: %s", e) + else: + snippet_text = None except KeyError: command_invocation_text = alias + snippet_text = None else: command = self._get_snippet_command() command_invocation_text = f"{invoked_prefix}{command} {snippet_text}" @@ -1346,11 +1389,38 @@ async def get_contexts(self, message, *, cls=commands.Context): if snippet_text is not None: # Process snippets + snippet_name = message.content[len(invoked_prefix) :] + snippet_data = self.snippets.get(snippet_name) + # Download attachment if present + if isinstance(snippet_data, dict) and snippet_data.get("file_id"): + try: + import io + + file_data, metadata = await self.api.download_snippet_attachment(snippet_data["file_id"]) + + # Create a list-like object that mimics message.attachments + class AttachmentWrapper: + def __init__(self, file_data, metadata): + self.file_data = file_data + self.id = 0 + self.url = f"attachment://{metadata['filename']}" + self.filename = metadata["filename"] + self.size = metadata["length"] + self.width = None + + async def to_file(self): + return discord.File(io.BytesIO(self.file_data), filename=self.filename) + + message.attachments = [AttachmentWrapper(file_data, metadata)] + except Exception as e: + logger.warning("Failed to download snippet attachment: %s", e) ctx.command = self._get_snippet_command() reply_view = StringView(f"{invoked_prefix}{ctx.command} {snippet_text}") discord.utils.find(reply_view.skip_string, prefixes) ctx.invoked_with = reply_view.get_word().lower() ctx.view = reply_view + # Mark that a snippet was invoked so we can delete the command message + ctx.snippet_invoked = snippet_invoked else: ctx.command = self.all_commands.get(invoker) ctx.invoked_with = invoker diff --git a/cogs/modmail.py b/cogs/modmail.py index 0e39da920c..2e19c6e788 100644 --- a/cogs/modmail.py +++ b/cogs/modmail.py @@ -112,6 +112,30 @@ def _resolve_user(self, user_str): return int(match.group(1)) return None + def _get_snippet_text(self, snippet_data) -> str: + """ + Extract text from a snippet, handling both old string format and new dict format. + + Parameters + ---------- + snippet_data : str or dict + The snippet data (either old string format or new dict format). + + Returns + ------- + str + The text content of the snippet. + """ + if isinstance(snippet_data, str): + return snippet_data + elif isinstance(snippet_data, dict): + return snippet_data.get("text", "") + return "" + + def _has_snippet_attachment(self, snippet_data) -> bool: + """Check if a snippet has an attachment.""" + return isinstance(snippet_data, dict) and bool(snippet_data.get("file_id")) + @commands.command() @trigger_typing @checks.has_permissions(PermissionLevel.OWNER) @@ -246,10 +270,17 @@ async def snippet(self, ctx, *, name: str.lower = None): if snippet_name is None: embed = create_not_found_embed(name, self.bot.snippets.keys(), "Snippet") else: - val = self.bot.snippets[snippet_name] + snippet_data = self.bot.snippets[snippet_name] + snippet_text = self._get_snippet_text(snippet_data) + has_attachment = self._has_snippet_attachment(snippet_data) + + description = snippet_text if snippet_text else "(No text content)" + if has_attachment: + description += "\n\nšŸ“Ž *This snippet has an attachment.*" + embed = discord.Embed( title=f'Snippet - "{snippet_name}":', - description=val, + description=description, color=self.bot.main_color, ) return await ctx.send(embed=embed) @@ -270,10 +301,15 @@ async def snippet(self, ctx, *, name: str.lower = None): for embed in embeds: embed.set_author(name="Snippets", icon_url=self.bot.get_guild_icon(guild=ctx.guild, size=128)) - for i, snippet in enumerate(sorted(self.bot.snippets.items())): - embeds[i // 10].add_field( - name=snippet[0], value=return_or_truncate(snippet[1], 350), inline=False - ) + for i, (snippet_name, snippet_data) in enumerate(sorted(self.bot.snippets.items())): + snippet_text = self._get_snippet_text(snippet_data) + has_attachment = self._has_snippet_attachment(snippet_data) + + display_value = return_or_truncate(snippet_text, 350) if snippet_text else "(No text)" + if has_attachment: + display_value = "šŸ“Ž " + display_value + + embeds[i // 10].add_field(name=snippet_name, value=display_value, inline=False) session = EmbedPaginatorSession(ctx, *embeds) await session.run() @@ -288,10 +324,18 @@ async def snippet_raw(self, ctx, *, name: str.lower): if snippet_name is None: embed = create_not_found_embed(name, self.bot.snippets.keys(), "Snippet") else: - val = truncate(escape_code_block(self.bot.snippets[snippet_name]), 2048 - 7) + snippet_data = self.bot.snippets[snippet_name] + snippet_text = self._get_snippet_text(snippet_data) + has_attachment = self._has_snippet_attachment(snippet_data) + + val = truncate(escape_code_block(snippet_text), 2048 - 7) if snippet_text else "(No text content)" + description = f"```\n{val}```" + if has_attachment: + description += "\n\nšŸ“Ž *This snippet has an attachment.*" + embed = discord.Embed( title=f'Raw snippet - "{snippet_name}":', - description=f"```\n{val}```", + description=description, color=self.bot.main_color, ) @@ -299,9 +343,9 @@ async def snippet_raw(self, ctx, *, name: str.lower): @snippet.command(name="add", aliases=["create", "make"]) @checks.has_permissions(PermissionLevel.SUPPORTER) - async def snippet_add(self, ctx, name: str.lower, *, value: commands.clean_content): + async def snippet_add(self, ctx, name: str.lower, *, value: commands.clean_content = None): """ - Add a snippet. + Add a snippet with an optional attachment. Simply to add a snippet, do: ``` {prefix}snippet add hey hello there :) @@ -311,6 +355,8 @@ async def snippet_add(self, ctx, name: str.lower, *, value: commands.clean_conte To add a multi-word snippet name, use quotes: ``` {prefix}snippet add "two word" this is a two word snippet. ``` + + You can also attach a file (max 10 MB) to include with the snippet. """ if self.bot.get_command(name): embed = discord.Embed( @@ -343,14 +389,122 @@ async def snippet_add(self, ctx, name: str.lower, *, value: commands.clean_conte ) return await ctx.send(embed=embed) - self.bot.snippets[name] = value + # Handle optional attachment + file_id = None + attachment_info = None + if ctx.message.attachments: + attachment = ctx.message.attachments[0] + + # Validate file size + max_size_mb = self.bot.config.get("snippet_attachment_max_size") + max_size_bytes = max_size_mb * 1024 * 1024 + if attachment.size > max_size_bytes: + embed = discord.Embed( + title="Error", + color=self.bot.error_color, + description=f"Attachment exceeds the maximum file size of {max_size_mb} MB. " + f"Your file is {attachment.size / (1024 * 1024):.2f} MB.", + ) + return await ctx.send(embed=embed) + + # Confirmation for attachments 2MB or higher + confirm_msg = None + if attachment.size >= 2 * 1024 * 1024: + view = discord.ui.View(timeout=30) + confirmed = None + + async def confirm_callback(interaction: discord.Interaction): + nonlocal confirmed + if interaction.user.id != ctx.author.id: + return await interaction.response.send_message( + "Only the command author can confirm.", ephemeral=True + ) + confirmed = True + await interaction.response.defer() + view.stop() + + async def cancel_callback(interaction: discord.Interaction): + nonlocal confirmed + if interaction.user.id != ctx.author.id: + return await interaction.response.send_message( + "Only the command author can cancel.", ephemeral=True + ) + confirmed = False + await interaction.response.edit_message( + content="āŒ Cancelled. Snippet not created.", view=None, embed=None + ) + view.stop() + + confirm_button = discord.ui.Button(label="āœ“ Confirm", style=discord.ButtonStyle.green) + cancel_button = discord.ui.Button(label="āœ— Cancel", style=discord.ButtonStyle.red) + confirm_button.callback = confirm_callback + cancel_button.callback = cancel_callback + view.add_item(confirm_button) + view.add_item(cancel_button) + + embed = discord.Embed( + title="Confirm Large Attachment", + description=f"The attachment is {attachment.size / (1024 * 1024):.2f} MB (≄2 MB).\n" + f"Do you want to create the snippet `{name}` with this attachment?", + color=self.bot.main_color, + ) + confirm_msg = await ctx.send(embed=embed, view=view) + await view.wait() + + if confirmed is None or not confirmed: + if confirmed is None: + await confirm_msg.edit( + content="ā±ļø Timed out. Snippet not created.", view=None, embed=None + ) + return + + # Download and upload to GridFS + try: + file_data = await attachment.read() + file_id = await self.bot.api.upload_snippet_attachment( + file_data, + attachment.filename, + attachment.content_type or "application/octet-stream", + ) + attachment_info = attachment.filename + except Exception as e: + logger.error("Failed to upload snippet attachment: %s", e) + embed = discord.Embed( + title="Error", + color=self.bot.error_color, + description="Failed to upload attachment. Please try again.", + ) + return await ctx.send(embed=embed) + + # Require at least text or attachment + if not value and not file_id: + embed = discord.Embed( + title="Error", + color=self.bot.error_color, + description="You must provide either text content or an attachment for the snippet.", + ) + return await ctx.send(embed=embed) + + # Store snippet as dict with text and optional file_id + snippet_data = {"text": value or ""} + if file_id: + snippet_data["file_id"] = file_id + + self.bot.snippets[name] = snippet_data await self.bot.config.update() + description = "Successfully created snippet." + if attachment_info: + description += f"\nšŸ“Ž Attachment: `{attachment_info}`" + embed = discord.Embed( title="Added snippet", color=self.bot.main_color, - description="Successfully created snippet.", + description=description, ) + + if confirm_msg: + return await confirm_msg.edit(content=None, embed=embed, view=None) return await ctx.send(embed=embed) def _fix_aliases(self, snippet_being_deleted: str) -> Tuple[List[str]]: @@ -408,6 +562,14 @@ def _fix_aliases(self, snippet_being_deleted: str) -> Tuple[List[str]]: async def snippet_remove(self, ctx, *, name: str.lower): """Remove a snippet.""" if name in self.bot.snippets: + # Delete GridFS attachment if present + snippet_data = self.bot.snippets[name] + if isinstance(snippet_data, dict) and snippet_data.get("file_id"): + try: + await self.bot.api.delete_snippet_attachment(snippet_data["file_id"]) + except Exception as e: + logger.warning("Failed to delete snippet attachment for %s: %s", name, e) + deleted_aliases, edited_aliases = self._fix_aliases(name) deleted_aliases_string = ",".join(f"`{alias}`" for alias in deleted_aliases) @@ -459,22 +621,103 @@ async def snippet_remove(self, ctx, *, name: str.lower): @snippet.command(name="edit") @checks.has_permissions(PermissionLevel.SUPPORTER) - async def snippet_edit(self, ctx, name: str.lower, *, value): + async def snippet_edit(self, ctx, name: str.lower, *, value: commands.clean_content = None): """ - Edit a snippet. + Edit a snippet's text and/or attachment. To edit a multi-word snippet name, use quotes: ``` {prefix}snippet edit "two word" this is a new two word snippet. ``` + + Attach a new file to replace the existing attachment. + Provide text without attachment to keep the existing attachment. """ if name in self.bot.snippets: - self.bot.snippets[name] = value + snippet_data = self.bot.snippets[name] + + # Handle old string format + if isinstance(snippet_data, str): + old_text = snippet_data + old_file_id = None + else: + old_text = snippet_data.get("text", "") + old_file_id = snippet_data.get("file_id") + + # Handle new attachment if provided + new_file_id = old_file_id + attachment_info = None + if ctx.message.attachments: + attachment = ctx.message.attachments[0] + + # Validate file size + max_size_mb = self.bot.config.get("snippet_attachment_max_size") + max_size_bytes = max_size_mb * 1024 * 1024 + if attachment.size > max_size_bytes: + embed = discord.Embed( + title="Error", + color=self.bot.error_color, + description=f"Attachment exceeds the maximum file size of {max_size_mb} MB. " + f"Your file is {attachment.size / (1024 * 1024):.2f} MB.", + ) + return await ctx.send(embed=embed) + + # Delete old attachment if present + if old_file_id: + try: + await self.bot.api.delete_snippet_attachment(old_file_id) + except Exception as e: + logger.warning("Failed to delete old attachment for %s: %s", name, e) + + # Upload new attachment + try: + file_data = await attachment.read() + new_file_id = await self.bot.api.upload_snippet_attachment( + file_data, + attachment.filename, + attachment.content_type or "application/octet-stream", + ) + attachment_info = attachment.filename + except Exception as e: + logger.error("Failed to upload snippet attachment: %s", e) + embed = discord.Embed( + title="Error", + color=self.bot.error_color, + description="Failed to upload attachment. Please try again.", + ) + return await ctx.send(embed=embed) + + # Use new text if provided, otherwise keep old text + new_text = value if value is not None else old_text + + # Require at least text or attachment + if not new_text and not new_file_id: + embed = discord.Embed( + title="Error", + color=self.bot.error_color, + description="Snippet must have either text content or an attachment.", + ) + return await ctx.send(embed=embed) + + # Update snippet + updated_snippet = {"text": new_text or ""} + if new_file_id: + updated_snippet["file_id"] = new_file_id + + self.bot.snippets[name] = updated_snippet await self.bot.config.update() + description = f"`{name}` has been updated." + if value: + description += f'\nText: "{truncate(value, 100)}"' + if attachment_info: + description += f"\nšŸ“Ž New attachment: `{attachment_info}`" + elif new_file_id: + description += f"\nšŸ“Ž Attachment kept." + embed = discord.Embed( title="Edited snippet", color=self.bot.main_color, - description=f'`{name}` will now send "{value}".', + description=description, ) else: embed = create_not_found_embed(name, self.bot.snippets.keys(), "Snippet") @@ -1541,6 +1784,13 @@ async def freply(self, ctx, *, msg: str = ""): async with safe_typing(ctx): await ctx.thread.reply(ctx.message, msg) + # Delete the snippet command message if it was invoked via snippet + if getattr(ctx, "snippet_invoked", False): + try: + await ctx.message.delete() + except Exception as e: + logger.warning("Failed to delete snippet command message: %s", e) + @commands.command(aliases=["formatanonreply"]) @checks.has_permissions(PermissionLevel.SUPPORTER) @checks.thread_only() diff --git a/core/clients.py b/core/clients.py index 90f09b3b48..a30b237cb5 100644 --- a/core/clients.py +++ b/core/clients.py @@ -1,14 +1,15 @@ import secrets import sys from json import JSONDecodeError -from typing import Any, Dict, Union, Optional +from typing import Any, Dict, Union, Optional, Tuple import discord from discord import Member, DMChannel, TextChannel, Message from discord.ext import commands from aiohttp import ClientResponseError, ClientResponse -from motor.motor_asyncio import AsyncIOMotorClient +from bson import ObjectId +from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket from pymongo.errors import ConfigurationError from core.models import InvalidConfigError, getLogger @@ -460,6 +461,7 @@ def __init__(self, bot): sys.exit(0) super().__init__(bot, db) + self.fs = AsyncIOMotorGridFSBucket(db, bucket_name="snippet_attachments") async def setup_indexes(self): """Setup text indexes so we can use the $search operator""" @@ -779,6 +781,86 @@ async def get_user_info(self) -> Optional[dict]: } } + # ==================== GridFS Methods for Snippet Attachments ==================== + + async def upload_snippet_attachment( + self, file_data: bytes, filename: str, content_type: str = "application/octet-stream" + ) -> str: + """ + Upload a file to GridFS for snippet attachments. + + Parameters + ---------- + file_data : bytes + The raw file data to upload. + filename : str + The original filename. + content_type : str + The MIME type of the file. + + Returns + ------- + str + The string representation of the GridFS file ID. + """ + file_id = await self.fs.upload_from_stream( + filename, + file_data, + metadata={"content_type": content_type, "filename": filename}, + ) + logger.debug("Uploaded snippet attachment %s with file_id %s.", filename, file_id) + return str(file_id) + + async def download_snippet_attachment(self, file_id: str) -> Tuple[bytes, Dict[str, Any]]: + """ + Download a file from GridFS. + + Parameters + ---------- + file_id : str + The string representation of the GridFS file ID. + + Returns + ------- + Tuple[bytes, Dict[str, Any]] + A tuple of (file_data, metadata) where metadata includes filename and content_type. + """ + grid_out = await self.fs.open_download_stream(ObjectId(file_id)) + file_data = await grid_out.read() + metadata = { + "filename": grid_out.filename, + "content_type": ( + grid_out.metadata.get("content_type", "application/octet-stream") + if grid_out.metadata + else "application/octet-stream" + ), + "length": grid_out.length, + } + logger.debug("Downloaded snippet attachment with file_id %s.", file_id) + return file_data, metadata + + async def delete_snippet_attachment(self, file_id: str) -> bool: + """ + Delete a file from GridFS. + + Parameters + ---------- + file_id : str + The string representation of the GridFS file ID. + + Returns + ------- + bool + True if deletion was successful. + """ + try: + await self.fs.delete(ObjectId(file_id)) + logger.debug("Deleted snippet attachment with file_id %s.", file_id) + return True + except Exception as e: + logger.warning("Failed to delete snippet attachment %s: %s", file_id, e) + return False + class PluginDatabaseClient: def __init__(self, bot): diff --git a/core/config.py b/core/config.py index 0e45b00175..df0f5e4503 100644 --- a/core/config.py +++ b/core/config.py @@ -164,6 +164,8 @@ class ConfigManager: "thread_creation_menu_embed_large_image": False, "thread_creation_menu_embed_footer_icon_url": None, "thread_creation_menu_embed_color": str(discord.Color.green()), + # snippet attachments + "snippet_attachment_max_size": 10, # in MB } private_keys = { @@ -242,6 +244,8 @@ class ConfigManager: duration_seconds = {"snooze_default_duration"} + megabytes = {"snippet_attachment_max_size"} + booleans = { "use_user_id_channel_name", "use_timestamp_channel_name", @@ -421,6 +425,14 @@ def get(self, key: str, *, convert: bool = True) -> typing.Any: logger.warning("Invalid %s %s.", key, value) value = self.remove(key) + elif key in self.megabytes: + if not isinstance(value, int): + try: + value = int(value) + except (ValueError, TypeError): + logger.warning("Invalid %s %s.", key, value) + value = self.remove(key) + elif key in self.force_str: # Temporary: as we saved in int previously, leading to int32 overflow, # this is transitioning IDs to strings diff --git a/core/config_help.json b/core/config_help.json index fedf9279ed..8d6c825c38 100644 --- a/core/config_help.json +++ b/core/config_help.json @@ -856,6 +856,18 @@ "See also: `anonymous_snippets`." ] }, + "snippet_attachment_max_size": { + "default": "10 (MB)", + "description": "Maximum file size in megabytes (MB) for attachments when creating or editing snippets.", + "examples": [ + "`{prefix}config set snippet_attachment_max_size 5` (5 MB)", + "`{prefix}config set snippet_attachment_max_size 20` (20 MB)" + ], + "notes": [ + "Attachments larger than this size will be rejected when adding or editing snippets.", + "Value is specified in megabytes (MB)." + ] + }, "require_close_reason": { "default" : "No", "description": "Require a reason to close threads.", diff --git a/core/thread.py b/core/thread.py index 45a6cb9c71..0df9e96a61 100644 --- a/core/thread.py +++ b/core/thread.py @@ -1957,11 +1957,16 @@ async def send( images = [] attachments = [] - for attachment in ext: + files_to_upload = [] + for i, a in enumerate(message.attachments): + attachment = ext[i] if is_image_url(attachment[0]): images.append(attachment) else: - attachments.append(attachment) + if hasattr(a, "to_file") and callable(a.to_file): + files_to_upload.append(a) + else: + attachments.append(attachment) image_urls = re.findall( r"http[s]?:\/\/(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*(),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+", @@ -2176,6 +2181,13 @@ def lottie_to_png(data): else: mentions = None + discord_files = [] + for att in files_to_upload: + try: + discord_files.append(await att.to_file()) + except Exception: + logger.warning("Failed to convert AttachmentWrapper to file.", exc_info=True) + if plain: if from_mod and not isinstance(destination, discord.TextChannel): # Plain to user (DM) @@ -2187,12 +2199,13 @@ def lottie_to_png(data): body = embed.description or "" plain_message = f"{prefix}{embed.author.name}:** {body}" - files = [] + files = discord_files[:] for att in message.attachments: - try: - files.append(await att.to_file()) - except Exception: - logger.warning("Failed to attach file in plain DM.", exc_info=True) + if not (hasattr(att, "to_file") and callable(att.to_file)): + try: + files.append(await att.to_file()) + except Exception: + logger.warning("Failed to attach file in plain DM.", exc_info=True) msg = await destination.send(plain_message, files=files or None) else: @@ -2200,10 +2213,14 @@ def lottie_to_png(data): footer_text = embed.footer.text if embed.footer else "" embed.set_footer(text=f"[PLAIN] {footer_text}".strip()) msg = await destination.send(mentions, embed=embed) + if discord_files: + await destination.send(files=discord_files) else: try: msg = await destination.send(mentions, embed=embed) + if discord_files: + await destination.send(files=discord_files) except discord.NotFound: if ( isinstance(destination, discord.TextChannel) @@ -2214,6 +2231,8 @@ def lottie_to_png(data): await self.restore_from_snooze() destination = self.channel or destination msg = await destination.send(mentions, embed=embed) + if discord_files: + await destination.send(files=discord_files) else: logger.warning("Channel not found during send.") raise