134 lines
4.1 KiB
Python
134 lines
4.1 KiB
Python
import asyncio
|
|
import io
|
|
import discord
|
|
from discord.ext import commands
|
|
from config import Config
|
|
from tts_handler import TTSHandler
|
|
|
|
|
|
class TTSBot(commands.Bot):
    """Discord bot that reads messages aloud using Pocket TTS.

    Non-empty messages posted in ``Config.TEXT_CHANNEL_ID`` by users who are
    in a voice channel are queued and spoken, one at a time, in that user's
    voice channel. Bot commands (prefix ``!``) are dispatched for every
    non-bot message.
    """

    def __init__(self):
        intents = discord.Intents.default()
        intents.message_content = True  # on_message reads message.content
        intents.voice_states = True  # speak_message reads member.voice
        super().__init__(command_prefix="!", intents=intents)

        self.tts_handler = TTSHandler(Config.VOICE_WAV_PATH)
        # (message, text) pairs waiting to be spoken; drained by process_queue().
        self.message_queue: asyncio.Queue[tuple[discord.Message, str]] = asyncio.Queue()
        # Strong reference to the queue worker task: the event loop keeps only
        # weak references to tasks, so an unreferenced task can be garbage
        # collected while still running.
        self._queue_task: asyncio.Task[None] | None = None

    async def setup_hook(self) -> None:
        """Called when the bot is starting up: load TTS and start the queue worker."""
        print("Initializing TTS...")
        # load() is presumably blocking (model loading); keep it off the event loop.
        await asyncio.to_thread(self.tts_handler.load)
        self._queue_task = asyncio.create_task(self.process_queue())

    async def on_ready(self) -> None:
        print(f"Logged in as {self.user}")
        print(f"Monitoring channel ID: {Config.TEXT_CHANNEL_ID}")
        print("Bot is ready!")

    async def on_message(self, message: discord.Message) -> None:
        """Queue eligible messages for speech, then dispatch bot commands.

        Messages from bots are ignored entirely. For all other messages,
        command processing runs regardless of whether the message was queued.
        """
        if message.author.bot:
            return

        await self._maybe_queue_for_tts(message)
        # Overriding on_message replaces the default command dispatch, so it
        # must be called explicitly -- for every message, not only the ones
        # that passed the TTS filters.
        await self.process_commands(message)

    async def _maybe_queue_for_tts(self, message: discord.Message) -> None:
        """Queue *message* for speech if it qualifies; nudge the author if not in voice."""
        if message.channel.id != Config.TEXT_CHANNEL_ID:
            return

        if not message.content.strip():
            return

        # message.author may be a discord.User (no ``voice`` attribute)
        # rather than a guild Member, so don't assume the attribute exists.
        if getattr(message.author, "voice", None) is None:
            await message.channel.send(
                f"{message.author.mention}, you need to be in a voice channel for me to speak!",
                delete_after=5
            )
            return

        await self.message_queue.put((message, message.content))
        print(f"Queued message from {message.author}: {message.content[:50]}...")

    async def process_queue(self) -> None:
        """Process messages from the queue one at a time, forever.

        A failure while speaking one message is reported and does not stop
        the worker.
        """
        while True:
            message, text = await self.message_queue.get()

            try:
                await self.speak_message(message, text)
            except Exception as e:
                # Keep the worker alive; one bad message must not stall the queue.
                print(f"Error processing message: {e}")
            finally:
                self.message_queue.task_done()

    async def speak_message(self, message: discord.Message, text: str) -> None:
        """Generate TTS for *text* and play it in the author's current voice channel.

        Returns after playback finishes, or immediately if the author is no
        longer in voice or the bot cannot connect.
        """
        # The author may have left voice while the message waited in the queue.
        voice_state = getattr(message.author, "voice", None)
        if voice_state is None or voice_state.channel is None:
            return

        voice_channel = voice_state.channel

        voice_client = await self.ensure_voice_connection(voice_channel)
        if voice_client is None:
            return

        print(f"Generating TTS for: {text[:50]}...")
        # Generation is presumably CPU-heavy; run it in a worker thread.
        wav_bytes = await asyncio.to_thread(self.tts_handler.generate_wav_bytes, text)

        audio_source = discord.FFmpegPCMAudio(
            io.BytesIO(wav_bytes),
            pipe=True,
            options="-loglevel panic"
        )

        # Pre-empt any audio already playing on this connection.
        if voice_client.is_playing():
            voice_client.stop()

        play_complete = asyncio.Event()

        def after_playing(error):
            # discord.py invokes this from its audio player thread, so the
            # event must be set via the loop in a thread-safe way.
            if error:
                print(f"Playback error: {error}")
            self.loop.call_soon_threadsafe(play_complete.set)

        try:
            voice_client.play(audio_source, after=after_playing)
        except Exception:
            # play() did not take ownership of the source; release the FFmpeg
            # process that constructing FFmpegPCMAudio started.
            audio_source.cleanup()
            raise

        print(f"Playing audio in {voice_channel.name}")

        await play_complete.wait()

    async def ensure_voice_connection(self, channel: discord.VoiceChannel) -> discord.VoiceClient | None:
        """Ensure we're connected to *channel*, connecting or moving as needed.

        Returns the guild's voice client, or None (after logging) when the
        connect or move attempt fails.
        """
        guild = channel.guild
        voice_client = guild.voice_client

        if voice_client is not None:
            if voice_client.channel.id == channel.id:
                return voice_client

            try:
                await voice_client.move_to(channel)
            except Exception as e:
                print(f"Failed to move to voice channel: {e}")
                return None
            return voice_client

        try:
            return await channel.connect(timeout=10.0)
        except Exception as e:
            print(f"Failed to connect to voice channel: {e}")
            return None
|
|
|
|
|
|
def main() -> int:
    """Validate configuration, then run the bot until it shuts down.

    Returns:
        Process exit status: 1 if ``Config.validate()`` reported errors
        (the bot is not started), otherwise 0 after ``bot.run`` returns.
    """
    import sys  # local: only needed to report startup errors on stderr

    errors = Config.validate()
    if errors:
        print("Configuration errors:", file=sys.stderr)
        for error in errors:
            print(f" - {error}", file=sys.stderr)
        print("\nPlease create a .env file based on .env.example", file=sys.stderr)
        # Non-zero so process supervisors see that startup failed.
        return 1

    bot = TTSBot()
    bot.run(Config.DISCORD_TOKEN)
    return 0


if __name__ == "__main__":
    # Propagate main()'s status as the process exit code.
    raise SystemExit(main())
|