diff --git a/bot.py b/bot.py index 17d7e13..07b5efa 100644 --- a/bot.py +++ b/bot.py @@ -1,9 +1,15 @@ import asyncio import io +from typing import Any + import discord +import numpy as np +import scipy.io.wavfile as wavfile +from discord import app_commands from discord.ext import commands + from config import Config -from tts_handler import TTSHandler +from voice_manager import VoiceManager class TTSBot(commands.Bot): @@ -15,18 +21,195 @@ class TTSBot(commands.Bot): intents.voice_states = True super().__init__(command_prefix="!", intents=intents) - self.tts_handler = TTSHandler(Config.VOICE_WAV_PATH) + self.voice_manager = VoiceManager(Config.VOICES_DIR, Config.DEFAULT_VOICE) self.message_queue: asyncio.Queue[tuple[discord.Message, str]] = asyncio.Queue() + + self._setup_slash_commands() + + def _setup_slash_commands(self) -> None: + """Set up slash commands for voice management.""" + + @self.tree.command(name="voice", description="Manage your TTS voice") + @app_commands.describe( + action="What to do", + voice_name="Name of the voice (for 'set' action)" + ) + @app_commands.choices(action=[ + app_commands.Choice(name="list", value="list"), + app_commands.Choice(name="set", value="set"), + app_commands.Choice(name="current", value="current"), + app_commands.Choice(name="refresh", value="refresh"), + ]) + async def voice_command( + interaction: discord.Interaction, + action: app_commands.Choice[str], + voice_name: str | None = None + ): + if action.value == "list": + await self._handle_voice_list(interaction) + elif action.value == "set": + await self._handle_voice_set(interaction, voice_name) + elif action.value == "current": + await self._handle_voice_current(interaction) + elif action.value == "refresh": + await self._handle_voice_refresh(interaction) + + @voice_command.autocomplete("voice_name") + async def voice_name_autocomplete( + interaction: discord.Interaction, + current: str + ) -> list[app_commands.Choice[str]]: + voices = self.voice_manager.get_available_voices() + return [ + app_commands.Choice(name=v, value=v) + for v in voices + if current.lower() in v.lower() + ][:25] + + async def _handle_voice_list(self, interaction: discord.Interaction) -> None: + """Handle /voice list command.""" + voices = self.voice_manager.get_available_voices() + loaded = self.voice_manager.get_loaded_voices() + user_voice = self.voice_manager.get_user_voice(interaction.user.id) + + if not voices: + await interaction.response.send_message( + "❌ No voices available. Add .wav files to the voices directory.", + ephemeral=True + ) + return + + lines = ["**Available Voices:**\n"] + for voice in voices: + status = [] + if voice == user_voice: + status.append("✅ your voice") + if voice in loaded: + status.append("📦 loaded") + status_str = f" ({', '.join(status)})" if status else "" + lines.append(f"• `{voice}`{status_str}") + + lines.append(f"\n*Use `/voice set ` to change your voice.*") + + await interaction.response.send_message( + "\n".join(lines), + ephemeral=True + ) + + async def _handle_voice_set(self, interaction: discord.Interaction, voice_name: str | None) -> None: + """Handle /voice set command.""" + if not voice_name: + await interaction.response.send_message( + "❌ Please provide a voice name. Use `/voice list` to see available voices.", + ephemeral=True + ) + return + + voice_name = voice_name.lower() + + if not self.voice_manager.is_voice_available(voice_name): + voices = self.voice_manager.get_available_voices() + await interaction.response.send_message( + f"❌ Voice `{voice_name}` not found.\n" + f"Available voices: {', '.join(f'`{v}`' for v in voices)}", + ephemeral=True + ) + return + + # Check if voice needs to be loaded + needs_loading = not self.voice_manager.is_voice_loaded(voice_name) + + if needs_loading: + await interaction.response.send_message( + f"⏳ Loading voice `{voice_name}` for the first time... This may take a moment.", + ephemeral=True + ) + try: + await asyncio.to_thread(self.voice_manager.get_voice_state, voice_name) + except Exception as e: + await interaction.followup.send( + f"❌ Failed to load voice `{voice_name}`: {e}", + ephemeral=True + ) + return + + self.voice_manager.set_user_voice(interaction.user.id, voice_name) + + if needs_loading: + await interaction.followup.send( + f"✅ Voice changed to `{voice_name}`!", + ephemeral=True + ) + else: + await interaction.response.send_message( + f"✅ Voice changed to `{voice_name}`!", + ephemeral=True + ) + + async def _handle_voice_current(self, interaction: discord.Interaction) -> None: + """Handle /voice current command.""" + voice = self.voice_manager.get_user_voice(interaction.user.id) + if voice: + loaded = "(loaded)" if self.voice_manager.is_voice_loaded(voice) else "(not yet loaded)" + await interaction.response.send_message( + f"🎤 Your current voice: `{voice}` {loaded}", + ephemeral=True + ) + else: + await interaction.response.send_message( + "❌ No voice set. Use `/voice set ` to choose a voice.", + ephemeral=True + ) + + async def _handle_voice_refresh(self, interaction: discord.Interaction) -> None: + """Handle /voice refresh command.""" + await interaction.response.send_message( + "🔄 Scanning for new voices...", + ephemeral=True + ) + + added, removed = await asyncio.to_thread(self.voice_manager.refresh_voices) + + lines = [] + if added: + lines.append(f"✅ **New voices found:** {', '.join(f'`{v}`' for v in added)}") + if removed: + lines.append(f"❌ **Voices removed:** {', '.join(f'`{v}`' for v in removed)}") + if not added and not removed: + lines.append("No changes detected.") + + total = len(self.voice_manager.get_available_voices()) + lines.append(f"\n*Total voices available: {total}*") + + await interaction.followup.send( + "\n".join(lines), + ephemeral=True + ) async def setup_hook(self) -> None: """Called when the bot is starting up.""" print("Initializing TTS...") - await asyncio.to_thread(self.tts_handler.load) + print("Discovering available voices...") + await asyncio.to_thread(self.voice_manager.discover_voices) + await asyncio.to_thread(self.voice_manager.load_model) + + # Pre-load the default voice if one is set + default = self.voice_manager.default_voice + if default: + print(f"Pre-loading default voice: {default}") + await asyncio.to_thread(self.voice_manager.get_voice_state, default) + self.loop.create_task(self.process_queue()) + + # Sync slash commands + print("Syncing slash commands...") + await self.tree.sync() + print("Slash commands synced!") async def on_ready(self) -> None: print(f"Logged in as {self.user}") print(f"Monitoring channel ID: {Config.TEXT_CHANNEL_ID}") + print(f"Available voices: {', '.join(self.voice_manager.get_available_voices())}") print("Bot is ready!") async def on_message(self, message: discord.Message) -> None: @@ -75,7 +258,24 @@ class TTSBot(commands.Bot): return print(f"Generating TTS for: {text[:50]}...") - wav_bytes = await asyncio.to_thread(self.tts_handler.generate_wav_bytes, text) + + # Get user's voice (loads on-demand if needed) + user_id = message.author.id + try: + voice_state = await asyncio.to_thread( + self.voice_manager.get_user_voice_state, user_id + ) + except Exception as e: + print(f"Error loading voice for user {user_id}: {e}") + await message.channel.send( + f"{message.author.mention}, failed to load your voice. Use `/voice set` to choose a voice.", + delete_after=5 + ) + return + + wav_bytes = await asyncio.to_thread( + self._generate_wav_bytes, voice_state, text + ) audio_source = discord.FFmpegPCMAudio( io.BytesIO(wav_bytes), @@ -88,7 +288,7 @@ class TTSBot(commands.Bot): play_complete = asyncio.Event() - def after_playing(error): + def after_playing(error: Exception | None) -> None: if error: print(f"Playback error: {error}") self.loop.call_soon_threadsafe(play_complete.set) @@ -98,6 +298,28 @@ class TTSBot(commands.Bot): await play_complete.wait() + def _generate_wav_bytes(self, voice_state: Any, text: str) -> bytes: + """Generate audio and return as WAV file bytes.""" + model = self.voice_manager.model + if model is None: + raise RuntimeError("Model not loaded") + + audio = model.generate_audio(voice_state, text) + audio_np = audio.numpy() + + if audio_np.ndim == 1: + audio_np = audio_np.reshape(-1, 1) + + max_val = np.max(np.abs(audio_np)) + if max_val > 0: + audio_np = audio_np / max_val + audio_int16 = (audio_np * 32767).astype(np.int16) + + wav_buffer = io.BytesIO() + wavfile.write(wav_buffer, model.sample_rate, audio_int16) + wav_buffer.seek(0) + return wav_buffer.read() + async def ensure_voice_connection(self, channel: discord.VoiceChannel) -> discord.VoiceClient | None: """Ensure we're connected to the specified voice channel.""" guild = channel.guild diff --git a/config.py b/config.py index 8bd6635..92175e4 100644 --- a/config.py +++ b/config.py @@ -7,7 +7,8 @@ load_dotenv() class Config: DISCORD_TOKEN: str = os.getenv("DISCORD_TOKEN", "") TEXT_CHANNEL_ID: int = int(os.getenv("TEXT_CHANNEL_ID", "0")) - VOICE_WAV_PATH: str = os.getenv("VOICE_WAV_PATH", "./voice.wav") + VOICES_DIR: str = os.getenv("VOICES_DIR", "./voices") + DEFAULT_VOICE: str | None = os.getenv("DEFAULT_VOICE", None) @classmethod def validate(cls) -> list[str]: @@ -17,6 +18,6 @@ class Config: errors.append("DISCORD_TOKEN is not set") if cls.TEXT_CHANNEL_ID == 0: errors.append("TEXT_CHANNEL_ID is not set") - if not os.path.exists(cls.VOICE_WAV_PATH): - errors.append(f"Voice WAV file not found: {cls.VOICE_WAV_PATH}") + if not os.path.exists(cls.VOICES_DIR): + errors.append(f"Voices directory not found: {cls.VOICES_DIR}") return errors diff --git a/pockettts.service b/pockettts.service new file mode 100644 index 0000000..3b57b3d --- /dev/null +++ b/pockettts.service @@ -0,0 +1,37 @@ +[Unit] +Description=Pocket TTS Discord Bot +After=network-online.target +Wants=network-online.target + +[Service] +# Replace with your username +User=YOUR_USERNAME +Group=YOUR_USERNAME + +# Replace with the actual path to your bot directory +WorkingDirectory=/home/YOUR_USERNAME/PocketTTSBot + +# Use the Python from the virtual environment +ExecStart=/home/YOUR_USERNAME/PocketTTSBot/venv/bin/python bot.py + +# Restart on failure +Restart=on-failure +RestartSec=10 + +# Give the bot time to gracefully shutdown +TimeoutStopSec=30 + +# Logging +StandardOutput=journal +StandardError=journal +SyslogIdentifier=pockettts + +# Security hardening (optional but recommended) +NoNewPrivileges=true +ProtectSystem=strict +ProtectHome=read-only +ReadWritePaths=/home/YOUR_USERNAME/PocketTTSBot/voices +PrivateTmp=true + +[Install] +WantedBy=multi-user.target diff --git a/tts_handler.py b/tts_handler.py deleted file mode 100644 index f6aec23..0000000 --- a/tts_handler.py +++ /dev/null @@ -1,77 +0,0 @@ -import io -import numpy as np -import scipy.io.wavfile as wavfile -from typing import Any -from pocket_tts import TTSModel - -from audio_preprocessor import ( - AudioPreprocessor, - PreprocessingConfig, - print_audio_analysis, -) - - -class TTSHandler: - """Handles text-to-speech generation using Pocket TTS.""" - - DISCORD_SAMPLE_RATE = 48000 - - def __init__(self, voice_wav_path: str, preprocess_audio: bool = True): - self.voice_wav_path = voice_wav_path - self.preprocess_audio = preprocess_audio - self.model: TTSModel | None = None - self.voice_state: Any = None - self._preprocessed_path: str | None = None - - def load(self) -> None: - """Load the TTS model and voice state from the WAV file.""" - print("Loading Pocket TTS model...") - self.model = TTSModel.load_model() - - voice_path = self.voice_wav_path - - # Analyze and preprocess the audio if enabled - if self.preprocess_audio: - print("\nAnalyzing original audio...") - print_audio_analysis(self.voice_wav_path) - - print("Preprocessing audio for optimal voice cloning...") - config = PreprocessingConfig( - target_sample_rate=22050, - normalize=True, - trim_silence=True, - trim_top_db=20, - reduce_noise=True, - target_length_seconds=15.0, # Limit to 15 seconds for best results - ) - preprocessor = AudioPreprocessor(config) - voice_path = preprocessor.preprocess_file(self.voice_wav_path) - self._preprocessed_path = voice_path - print("") - - print(f"Loading voice state from: {voice_path}") - self.voice_state = self.model.get_state_for_audio_prompt(voice_path) - print("TTS handler ready!") - - def generate_wav_bytes(self, text: str) -> bytes: - """Generate audio and return as WAV file bytes (for FFmpeg).""" - if self.model is None or self.voice_state is None: - raise RuntimeError("TTS handler not loaded. Call load() first.") - - audio = self.model.generate_audio(self.voice_state, text) - audio_np = audio.numpy() - - if audio_np.ndim == 1: - audio_np = audio_np.reshape(-1, 1) - - max_val = np.max(np.abs(audio_np)) - if max_val > 0: - audio_np = audio_np / max_val - audio_int16 = (audio_np * 32767).astype(np.int16) - - wav_buffer = io.BytesIO() - wavfile.write(wav_buffer, self.model.sample_rate, audio_int16) - wav_buffer.seek(0) - return wav_buffer.read() - - diff --git a/voice_manager.py b/voice_manager.py new file mode 100644 index 0000000..e52db69 --- /dev/null +++ b/voice_manager.py @@ -0,0 +1,190 @@ +"""Voice management for per-user voice selection and on-demand loading.""" + +import json +from pathlib import Path +from typing import Any + +from pocket_tts import TTSModel + +from audio_preprocessor import ( + AudioPreprocessor, + PreprocessingConfig, + print_audio_analysis, +) + + +class VoiceManager: + """Manages available voices, per-user preferences, and on-demand voice loading.""" + + def __init__(self, voices_dir: str, default_voice: str | None = None): + self.voices_dir = Path(voices_dir) + self.default_voice = default_voice + self.model: TTSModel | None = None + self.preferences_file = self.voices_dir / "preferences.json" + + # Cache of loaded voice states: voice_name -> voice_state + self._voice_states: dict[str, Any] = {} + # Per-user voice preferences: user_id -> voice_name + self._user_voices: dict[int, str] = {} + # Available voices: voice_name -> file_path + self._available_voices: dict[str, Path] = {} + + def discover_voices(self) -> dict[str, Path]: + """Discover all available voice WAV files in the voices directory.""" + old_voices = set(self._available_voices.keys()) + self._available_voices = {} + + if not self.voices_dir.exists(): + print(f"Voices directory not found: {self.voices_dir}") + return self._available_voices + + for wav_file in self.voices_dir.glob("*.wav"): + voice_name = wav_file.stem.lower() + self._available_voices[voice_name] = wav_file + print(f" Found voice: {voice_name} ({wav_file.name})") + + # Set default voice if not specified + if self.default_voice is None and self._available_voices: + self.default_voice = next(iter(self._available_voices.keys())) + + # Load saved preferences + self._load_preferences() + + return self._available_voices + + def refresh_voices(self) -> tuple[list[str], list[str]]: + """Re-scan for voices and return (new_voices, removed_voices).""" + old_voices = set(self._available_voices.keys()) + + self._available_voices = {} + if self.voices_dir.exists(): + for wav_file in self.voices_dir.glob("*.wav"): + voice_name = wav_file.stem.lower() + self._available_voices[voice_name] = wav_file + + new_voices = set(self._available_voices.keys()) + added = sorted(new_voices - old_voices) + removed = sorted(old_voices - new_voices) + + # Update default if needed + if self.default_voice not in self._available_voices and self._available_voices: + self.default_voice = next(iter(self._available_voices.keys())) + + return added, removed + + def load_model(self) -> None: + """Load the TTS model (does not load any voices yet).""" + print("Loading Pocket TTS model...") + self.model = TTSModel.load_model() + print("TTS model loaded!") + + def get_available_voices(self) -> list[str]: + """Get list of available voice names.""" + return sorted(self._available_voices.keys()) + + def is_voice_available(self, voice_name: str) -> bool: + """Check if a voice is available.""" + return voice_name.lower() in self._available_voices + + def get_voice_state(self, voice_name: str) -> Any: + """Get or load a voice state on-demand.""" + if self.model is None: + raise RuntimeError("Model not loaded. Call load_model() first.") + + voice_name = voice_name.lower() + + if voice_name not in self._available_voices: + raise ValueError(f"Voice '{voice_name}' not found") + + # Return cached state if already loaded + if voice_name in self._voice_states: + return self._voice_states[voice_name] + + # Load the voice on-demand + voice_path = self._available_voices[voice_name] + print(f"Loading voice '{voice_name}' from {voice_path}...") + + # Preprocess the audio + print(f" Analyzing audio...") + print_audio_analysis(str(voice_path)) + + print(f" Preprocessing audio...") + config = PreprocessingConfig( + target_sample_rate=22050, + normalize=True, + trim_silence=True, + trim_top_db=20, + reduce_noise=True, + target_length_seconds=15.0, + ) + preprocessor = AudioPreprocessor(config) + processed_path = preprocessor.preprocess_file(str(voice_path)) + + # Load voice state + voice_state = self.model.get_state_for_audio_prompt(processed_path) + self._voice_states[voice_name] = voice_state + print(f" Voice '{voice_name}' loaded and cached!") + + return voice_state + + def is_voice_loaded(self, voice_name: str) -> bool: + """Check if a voice is already loaded in cache.""" + return voice_name.lower() in self._voice_states + + def get_user_voice(self, user_id: int) -> str: + """Get the voice preference for a user, or default voice.""" + return self._user_voices.get(user_id, self.default_voice or "") + + def set_user_voice(self, user_id: int, voice_name: str) -> None: + """Set the voice preference for a user.""" + voice_name = voice_name.lower() + if voice_name not in self._available_voices: + raise ValueError(f"Voice '{voice_name}' not found") + self._user_voices[user_id] = voice_name + self._save_preferences() + + def get_user_voice_state(self, user_id: int) -> Any: + """Get the voice state for a user (loads on-demand if needed).""" + voice_name = self.get_user_voice(user_id) + if not voice_name: + raise RuntimeError("No default voice available") + return self.get_voice_state(voice_name) + + def get_loaded_voices(self) -> list[str]: + """Get list of currently loaded voice names.""" + return list(self._voice_states.keys()) + + def _load_preferences(self) -> None: + """Load user voice preferences from JSON file.""" + if not self.preferences_file.exists(): + return + + try: + with open(self.preferences_file, "r") as f: + data = json.load(f) + + # Load user preferences (convert string keys back to int) + for user_id_str, voice_name in data.get("user_voices", {}).items(): + user_id = int(user_id_str) + # Only load if voice still exists + if voice_name.lower() in self._available_voices: + self._user_voices[user_id] = voice_name.lower() + + print(f" Loaded {len(self._user_voices)} user voice preferences") + except Exception as e: + print(f" Warning: Failed to load preferences: {e}") + + def _save_preferences(self) -> None: + """Save user voice preferences to JSON file.""" + try: + # Ensure directory exists + self.preferences_file.parent.mkdir(parents=True, exist_ok=True) + + data = { + "user_voices": {str(k): v for k, v in self._user_voices.items()} + } + + with open(self.preferences_file, "w") as f: + json.dump(data, f, indent=2) + except Exception as e: + print(f"Warning: Failed to save preferences: {e}") diff --git a/Estinien.wav b/voices/Estinien.wav similarity index 100% rename from Estinien.wav rename to voices/Estinien.wav diff --git a/Gaius.wav b/voices/Gaius.wav similarity index 100% rename from Gaius.wav rename to voices/Gaius.wav diff --git a/Gibralter_funny.wav b/voices/Gibralter_funny.wav similarity index 100% rename from Gibralter_funny.wav rename to voices/Gibralter_funny.wav diff --git a/Gibralter_good.wav b/voices/Gibralter_good.wav similarity index 100% rename from Gibralter_good.wav rename to voices/Gibralter_good.wav diff --git a/HankHill.wav b/voices/HankHill.wav similarity index 100% rename from HankHill.wav rename to voices/HankHill.wav diff --git a/Johnny.wav b/voices/Johnny.wav similarity index 100% rename from Johnny.wav rename to voices/Johnny.wav diff --git a/MasterChief.wav b/voices/MasterChief.wav similarity index 100% rename from MasterChief.wav rename to voices/MasterChief.wav diff --git a/Trump.wav b/voices/Trump.wav similarity index 100% rename from Trump.wav rename to voices/Trump.wav diff --git a/voices/preferences.json b/voices/preferences.json new file mode 100644 index 0000000..6edac7e --- /dev/null +++ b/voices/preferences.json @@ -0,0 +1,5 @@ +{ + "user_voices": { + "122139828182712322": "hankhill" + } +} \ No newline at end of file