feat: Implement multi-voice support and management
Refactor the TTS handling to support multiple, user-selectable voices. This replaces the previous single-voice system. Key changes: - Introduce VoiceManager to handle loading and managing voices from a dedicated oices/ directory. - Add slash commands (/voice list, /set, /current, /refresh) for users to manage their personal TTS voice. - Implement on-demand voice loading to improve startup time and memory usage. - Remove the old ts_handler.py and single voice .wav files in favor of the new system. - Update configuration to specify a voices directory instead of a single file path.
This commit is contained in:
232
bot.py
232
bot.py
@@ -1,9 +1,15 @@
|
||||
import asyncio
|
||||
import io
|
||||
from typing import Any
|
||||
|
||||
import discord
|
||||
import numpy as np
|
||||
import scipy.io.wavfile as wavfile
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
from config import Config
|
||||
from tts_handler import TTSHandler
|
||||
from voice_manager import VoiceManager
|
||||
|
||||
|
||||
class TTSBot(commands.Bot):
|
||||
@@ -15,18 +21,195 @@ class TTSBot(commands.Bot):
|
||||
intents.voice_states = True
|
||||
super().__init__(command_prefix="!", intents=intents)
|
||||
|
||||
self.tts_handler = TTSHandler(Config.VOICE_WAV_PATH)
|
||||
self.voice_manager = VoiceManager(Config.VOICES_DIR, Config.DEFAULT_VOICE)
|
||||
self.message_queue: asyncio.Queue[tuple[discord.Message, str]] = asyncio.Queue()
|
||||
|
||||
self._setup_slash_commands()
|
||||
|
||||
def _setup_slash_commands(self) -> None:
|
||||
"""Set up slash commands for voice management."""
|
||||
|
||||
@self.tree.command(name="voice", description="Manage your TTS voice")
|
||||
@app_commands.describe(
|
||||
action="What to do",
|
||||
voice_name="Name of the voice (for 'set' action)"
|
||||
)
|
||||
@app_commands.choices(action=[
|
||||
app_commands.Choice(name="list", value="list"),
|
||||
app_commands.Choice(name="set", value="set"),
|
||||
app_commands.Choice(name="current", value="current"),
|
||||
app_commands.Choice(name="refresh", value="refresh"),
|
||||
])
|
||||
async def voice_command(
|
||||
interaction: discord.Interaction,
|
||||
action: app_commands.Choice[str],
|
||||
voice_name: str | None = None
|
||||
):
|
||||
if action.value == "list":
|
||||
await self._handle_voice_list(interaction)
|
||||
elif action.value == "set":
|
||||
await self._handle_voice_set(interaction, voice_name)
|
||||
elif action.value == "current":
|
||||
await self._handle_voice_current(interaction)
|
||||
elif action.value == "refresh":
|
||||
await self._handle_voice_refresh(interaction)
|
||||
|
||||
@voice_command.autocomplete("voice_name")
|
||||
async def voice_name_autocomplete(
|
||||
interaction: discord.Interaction,
|
||||
current: str
|
||||
) -> list[app_commands.Choice[str]]:
|
||||
voices = self.voice_manager.get_available_voices()
|
||||
return [
|
||||
app_commands.Choice(name=v, value=v)
|
||||
for v in voices
|
||||
if current.lower() in v.lower()
|
||||
][:25]
|
||||
|
||||
async def _handle_voice_list(self, interaction: discord.Interaction) -> None:
|
||||
"""Handle /voice list command."""
|
||||
voices = self.voice_manager.get_available_voices()
|
||||
loaded = self.voice_manager.get_loaded_voices()
|
||||
user_voice = self.voice_manager.get_user_voice(interaction.user.id)
|
||||
|
||||
if not voices:
|
||||
await interaction.response.send_message(
|
||||
"❌ No voices available. Add .wav files to the voices directory.",
|
||||
ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
lines = ["**Available Voices:**\n"]
|
||||
for voice in voices:
|
||||
status = []
|
||||
if voice == user_voice:
|
||||
status.append("✅ your voice")
|
||||
if voice in loaded:
|
||||
status.append("📦 loaded")
|
||||
status_str = f" ({', '.join(status)})" if status else ""
|
||||
lines.append(f"• `{voice}`{status_str}")
|
||||
|
||||
lines.append(f"\n*Use `/voice set <name>` to change your voice.*")
|
||||
|
||||
await interaction.response.send_message(
|
||||
"\n".join(lines),
|
||||
ephemeral=True
|
||||
)
|
||||
|
||||
async def _handle_voice_set(self, interaction: discord.Interaction, voice_name: str | None) -> None:
|
||||
"""Handle /voice set command."""
|
||||
if not voice_name:
|
||||
await interaction.response.send_message(
|
||||
"❌ Please provide a voice name. Use `/voice list` to see available voices.",
|
||||
ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
voice_name = voice_name.lower()
|
||||
|
||||
if not self.voice_manager.is_voice_available(voice_name):
|
||||
voices = self.voice_manager.get_available_voices()
|
||||
await interaction.response.send_message(
|
||||
f"❌ Voice `{voice_name}` not found.\n"
|
||||
f"Available voices: {', '.join(f'`{v}`' for v in voices)}",
|
||||
ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
# Check if voice needs to be loaded
|
||||
needs_loading = not self.voice_manager.is_voice_loaded(voice_name)
|
||||
|
||||
if needs_loading:
|
||||
await interaction.response.send_message(
|
||||
f"⏳ Loading voice `{voice_name}` for the first time... This may take a moment.",
|
||||
ephemeral=True
|
||||
)
|
||||
try:
|
||||
await asyncio.to_thread(self.voice_manager.get_voice_state, voice_name)
|
||||
except Exception as e:
|
||||
await interaction.followup.send(
|
||||
f"❌ Failed to load voice `{voice_name}`: {e}",
|
||||
ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
self.voice_manager.set_user_voice(interaction.user.id, voice_name)
|
||||
|
||||
if needs_loading:
|
||||
await interaction.followup.send(
|
||||
f"✅ Voice changed to `{voice_name}`!",
|
||||
ephemeral=True
|
||||
)
|
||||
else:
|
||||
await interaction.response.send_message(
|
||||
f"✅ Voice changed to `{voice_name}`!",
|
||||
ephemeral=True
|
||||
)
|
||||
|
||||
async def _handle_voice_current(self, interaction: discord.Interaction) -> None:
|
||||
"""Handle /voice current command."""
|
||||
voice = self.voice_manager.get_user_voice(interaction.user.id)
|
||||
if voice:
|
||||
loaded = "(loaded)" if self.voice_manager.is_voice_loaded(voice) else "(not yet loaded)"
|
||||
await interaction.response.send_message(
|
||||
f"🎤 Your current voice: `{voice}` {loaded}",
|
||||
ephemeral=True
|
||||
)
|
||||
else:
|
||||
await interaction.response.send_message(
|
||||
"❌ No voice set. Use `/voice set <name>` to choose a voice.",
|
||||
ephemeral=True
|
||||
)
|
||||
|
||||
async def _handle_voice_refresh(self, interaction: discord.Interaction) -> None:
|
||||
"""Handle /voice refresh command."""
|
||||
await interaction.response.send_message(
|
||||
"🔄 Scanning for new voices...",
|
||||
ephemeral=True
|
||||
)
|
||||
|
||||
added, removed = await asyncio.to_thread(self.voice_manager.refresh_voices)
|
||||
|
||||
lines = []
|
||||
if added:
|
||||
lines.append(f"✅ **New voices found:** {', '.join(f'`{v}`' for v in added)}")
|
||||
if removed:
|
||||
lines.append(f"❌ **Voices removed:** {', '.join(f'`{v}`' for v in removed)}")
|
||||
if not added and not removed:
|
||||
lines.append("No changes detected.")
|
||||
|
||||
total = len(self.voice_manager.get_available_voices())
|
||||
lines.append(f"\n*Total voices available: {total}*")
|
||||
|
||||
await interaction.followup.send(
|
||||
"\n".join(lines),
|
||||
ephemeral=True
|
||||
)
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Called when the bot is starting up."""
|
||||
print("Initializing TTS...")
|
||||
await asyncio.to_thread(self.tts_handler.load)
|
||||
print("Discovering available voices...")
|
||||
await asyncio.to_thread(self.voice_manager.discover_voices)
|
||||
await asyncio.to_thread(self.voice_manager.load_model)
|
||||
|
||||
# Pre-load the default voice if one is set
|
||||
default = self.voice_manager.default_voice
|
||||
if default:
|
||||
print(f"Pre-loading default voice: {default}")
|
||||
await asyncio.to_thread(self.voice_manager.get_voice_state, default)
|
||||
|
||||
self.loop.create_task(self.process_queue())
|
||||
|
||||
# Sync slash commands
|
||||
print("Syncing slash commands...")
|
||||
await self.tree.sync()
|
||||
print("Slash commands synced!")
|
||||
|
||||
async def on_ready(self) -> None:
|
||||
print(f"Logged in as {self.user}")
|
||||
print(f"Monitoring channel ID: {Config.TEXT_CHANNEL_ID}")
|
||||
print(f"Available voices: {', '.join(self.voice_manager.get_available_voices())}")
|
||||
print("Bot is ready!")
|
||||
|
||||
async def on_message(self, message: discord.Message) -> None:
|
||||
@@ -75,7 +258,24 @@ class TTSBot(commands.Bot):
|
||||
return
|
||||
|
||||
print(f"Generating TTS for: {text[:50]}...")
|
||||
wav_bytes = await asyncio.to_thread(self.tts_handler.generate_wav_bytes, text)
|
||||
|
||||
# Get user's voice (loads on-demand if needed)
|
||||
user_id = message.author.id
|
||||
try:
|
||||
voice_state = await asyncio.to_thread(
|
||||
self.voice_manager.get_user_voice_state, user_id
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error loading voice for user {user_id}: {e}")
|
||||
await message.channel.send(
|
||||
f"{message.author.mention}, failed to load your voice. Use `/voice set` to choose a voice.",
|
||||
delete_after=5
|
||||
)
|
||||
return
|
||||
|
||||
wav_bytes = await asyncio.to_thread(
|
||||
self._generate_wav_bytes, voice_state, text
|
||||
)
|
||||
|
||||
audio_source = discord.FFmpegPCMAudio(
|
||||
io.BytesIO(wav_bytes),
|
||||
@@ -88,7 +288,7 @@ class TTSBot(commands.Bot):
|
||||
|
||||
play_complete = asyncio.Event()
|
||||
|
||||
def after_playing(error):
|
||||
def after_playing(error: Exception | None) -> None:
|
||||
if error:
|
||||
print(f"Playback error: {error}")
|
||||
self.loop.call_soon_threadsafe(play_complete.set)
|
||||
@@ -98,6 +298,28 @@ class TTSBot(commands.Bot):
|
||||
|
||||
await play_complete.wait()
|
||||
|
||||
def _generate_wav_bytes(self, voice_state: Any, text: str) -> bytes:
|
||||
"""Generate audio and return as WAV file bytes."""
|
||||
model = self.voice_manager.model
|
||||
if model is None:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
audio = model.generate_audio(voice_state, text)
|
||||
audio_np = audio.numpy()
|
||||
|
||||
if audio_np.ndim == 1:
|
||||
audio_np = audio_np.reshape(-1, 1)
|
||||
|
||||
max_val = np.max(np.abs(audio_np))
|
||||
if max_val > 0:
|
||||
audio_np = audio_np / max_val
|
||||
audio_int16 = (audio_np * 32767).astype(np.int16)
|
||||
|
||||
wav_buffer = io.BytesIO()
|
||||
wavfile.write(wav_buffer, model.sample_rate, audio_int16)
|
||||
wav_buffer.seek(0)
|
||||
return wav_buffer.read()
|
||||
|
||||
async def ensure_voice_connection(self, channel: discord.VoiceChannel) -> discord.VoiceClient | None:
|
||||
"""Ensure we're connected to the specified voice channel."""
|
||||
guild = channel.guild
|
||||
|
||||
Reference in New Issue
Block a user