feat: Implement multi-voice support and management

Refactor the TTS handling to support multiple, user-selectable voices. This replaces the previous single-voice system.

Key changes:
- Introduce VoiceManager to handle loading and managing voices from a dedicated oices/ directory.
- Add slash commands (/voice list, /set, /current, /refresh) for users to manage their personal TTS voice.
- Implement on-demand voice loading to improve startup time and memory usage.
- Remove the old 	ts_handler.py and single voice .wav files in favor of the new system.
- Update configuration to specify a voices directory instead of a single file path.
This commit is contained in:
2026-01-18 17:24:12 -06:00
parent ae1c2a65d3
commit 92dfcb1d39
14 changed files with 463 additions and 85 deletions

232
bot.py
View File

@@ -1,9 +1,15 @@
import asyncio import asyncio
import io import io
from typing import Any
import discord import discord
import numpy as np
import scipy.io.wavfile as wavfile
from discord import app_commands
from discord.ext import commands from discord.ext import commands
from config import Config from config import Config
from tts_handler import TTSHandler from voice_manager import VoiceManager
class TTSBot(commands.Bot): class TTSBot(commands.Bot):
@@ -15,18 +21,195 @@ class TTSBot(commands.Bot):
intents.voice_states = True intents.voice_states = True
super().__init__(command_prefix="!", intents=intents) super().__init__(command_prefix="!", intents=intents)
self.tts_handler = TTSHandler(Config.VOICE_WAV_PATH) self.voice_manager = VoiceManager(Config.VOICES_DIR, Config.DEFAULT_VOICE)
self.message_queue: asyncio.Queue[tuple[discord.Message, str]] = asyncio.Queue() self.message_queue: asyncio.Queue[tuple[discord.Message, str]] = asyncio.Queue()
self._setup_slash_commands()
def _setup_slash_commands(self) -> None:
"""Set up slash commands for voice management."""
@self.tree.command(name="voice", description="Manage your TTS voice")
@app_commands.describe(
action="What to do",
voice_name="Name of the voice (for 'set' action)"
)
@app_commands.choices(action=[
app_commands.Choice(name="list", value="list"),
app_commands.Choice(name="set", value="set"),
app_commands.Choice(name="current", value="current"),
app_commands.Choice(name="refresh", value="refresh"),
])
async def voice_command(
interaction: discord.Interaction,
action: app_commands.Choice[str],
voice_name: str | None = None
):
if action.value == "list":
await self._handle_voice_list(interaction)
elif action.value == "set":
await self._handle_voice_set(interaction, voice_name)
elif action.value == "current":
await self._handle_voice_current(interaction)
elif action.value == "refresh":
await self._handle_voice_refresh(interaction)
@voice_command.autocomplete("voice_name")
async def voice_name_autocomplete(
interaction: discord.Interaction,
current: str
) -> list[app_commands.Choice[str]]:
voices = self.voice_manager.get_available_voices()
return [
app_commands.Choice(name=v, value=v)
for v in voices
if current.lower() in v.lower()
][:25]
async def _handle_voice_list(self, interaction: discord.Interaction) -> None:
"""Handle /voice list command."""
voices = self.voice_manager.get_available_voices()
loaded = self.voice_manager.get_loaded_voices()
user_voice = self.voice_manager.get_user_voice(interaction.user.id)
if not voices:
await interaction.response.send_message(
"❌ No voices available. Add .wav files to the voices directory.",
ephemeral=True
)
return
lines = ["**Available Voices:**\n"]
for voice in voices:
status = []
if voice == user_voice:
status.append("✅ your voice")
if voice in loaded:
status.append("📦 loaded")
status_str = f" ({', '.join(status)})" if status else ""
lines.append(f"• `{voice}`{status_str}")
lines.append(f"\n*Use `/voice set <name>` to change your voice.*")
await interaction.response.send_message(
"\n".join(lines),
ephemeral=True
)
async def _handle_voice_set(self, interaction: discord.Interaction, voice_name: str | None) -> None:
"""Handle /voice set command."""
if not voice_name:
await interaction.response.send_message(
"❌ Please provide a voice name. Use `/voice list` to see available voices.",
ephemeral=True
)
return
voice_name = voice_name.lower()
if not self.voice_manager.is_voice_available(voice_name):
voices = self.voice_manager.get_available_voices()
await interaction.response.send_message(
f"❌ Voice `{voice_name}` not found.\n"
f"Available voices: {', '.join(f'`{v}`' for v in voices)}",
ephemeral=True
)
return
# Check if voice needs to be loaded
needs_loading = not self.voice_manager.is_voice_loaded(voice_name)
if needs_loading:
await interaction.response.send_message(
f"⏳ Loading voice `{voice_name}` for the first time... This may take a moment.",
ephemeral=True
)
try:
await asyncio.to_thread(self.voice_manager.get_voice_state, voice_name)
except Exception as e:
await interaction.followup.send(
f"❌ Failed to load voice `{voice_name}`: {e}",
ephemeral=True
)
return
self.voice_manager.set_user_voice(interaction.user.id, voice_name)
if needs_loading:
await interaction.followup.send(
f"✅ Voice changed to `{voice_name}`!",
ephemeral=True
)
else:
await interaction.response.send_message(
f"✅ Voice changed to `{voice_name}`!",
ephemeral=True
)
async def _handle_voice_current(self, interaction: discord.Interaction) -> None:
"""Handle /voice current command."""
voice = self.voice_manager.get_user_voice(interaction.user.id)
if voice:
loaded = "(loaded)" if self.voice_manager.is_voice_loaded(voice) else "(not yet loaded)"
await interaction.response.send_message(
f"🎤 Your current voice: `{voice}` {loaded}",
ephemeral=True
)
else:
await interaction.response.send_message(
"❌ No voice set. Use `/voice set <name>` to choose a voice.",
ephemeral=True
)
async def _handle_voice_refresh(self, interaction: discord.Interaction) -> None:
"""Handle /voice refresh command."""
await interaction.response.send_message(
"🔄 Scanning for new voices...",
ephemeral=True
)
added, removed = await asyncio.to_thread(self.voice_manager.refresh_voices)
lines = []
if added:
lines.append(f"✅ **New voices found:** {', '.join(f'`{v}`' for v in added)}")
if removed:
lines.append(f"❌ **Voices removed:** {', '.join(f'`{v}`' for v in removed)}")
if not added and not removed:
lines.append("No changes detected.")
total = len(self.voice_manager.get_available_voices())
lines.append(f"\n*Total voices available: {total}*")
await interaction.followup.send(
"\n".join(lines),
ephemeral=True
)
async def setup_hook(self) -> None: async def setup_hook(self) -> None:
"""Called when the bot is starting up.""" """Called when the bot is starting up."""
print("Initializing TTS...") print("Initializing TTS...")
await asyncio.to_thread(self.tts_handler.load) print("Discovering available voices...")
await asyncio.to_thread(self.voice_manager.discover_voices)
await asyncio.to_thread(self.voice_manager.load_model)
# Pre-load the default voice if one is set
default = self.voice_manager.default_voice
if default:
print(f"Pre-loading default voice: {default}")
await asyncio.to_thread(self.voice_manager.get_voice_state, default)
self.loop.create_task(self.process_queue()) self.loop.create_task(self.process_queue())
# Sync slash commands
print("Syncing slash commands...")
await self.tree.sync()
print("Slash commands synced!")
async def on_ready(self) -> None: async def on_ready(self) -> None:
print(f"Logged in as {self.user}") print(f"Logged in as {self.user}")
print(f"Monitoring channel ID: {Config.TEXT_CHANNEL_ID}") print(f"Monitoring channel ID: {Config.TEXT_CHANNEL_ID}")
print(f"Available voices: {', '.join(self.voice_manager.get_available_voices())}")
print("Bot is ready!") print("Bot is ready!")
async def on_message(self, message: discord.Message) -> None: async def on_message(self, message: discord.Message) -> None:
@@ -75,7 +258,24 @@ class TTSBot(commands.Bot):
return return
print(f"Generating TTS for: {text[:50]}...") print(f"Generating TTS for: {text[:50]}...")
wav_bytes = await asyncio.to_thread(self.tts_handler.generate_wav_bytes, text)
# Get user's voice (loads on-demand if needed)
user_id = message.author.id
try:
voice_state = await asyncio.to_thread(
self.voice_manager.get_user_voice_state, user_id
)
except Exception as e:
print(f"Error loading voice for user {user_id}: {e}")
await message.channel.send(
f"{message.author.mention}, failed to load your voice. Use `/voice set` to choose a voice.",
delete_after=5
)
return
wav_bytes = await asyncio.to_thread(
self._generate_wav_bytes, voice_state, text
)
audio_source = discord.FFmpegPCMAudio( audio_source = discord.FFmpegPCMAudio(
io.BytesIO(wav_bytes), io.BytesIO(wav_bytes),
@@ -88,7 +288,7 @@ class TTSBot(commands.Bot):
play_complete = asyncio.Event() play_complete = asyncio.Event()
def after_playing(error): def after_playing(error: Exception | None) -> None:
if error: if error:
print(f"Playback error: {error}") print(f"Playback error: {error}")
self.loop.call_soon_threadsafe(play_complete.set) self.loop.call_soon_threadsafe(play_complete.set)
@@ -98,6 +298,28 @@ class TTSBot(commands.Bot):
await play_complete.wait() await play_complete.wait()
def _generate_wav_bytes(self, voice_state: Any, text: str) -> bytes:
"""Generate audio and return as WAV file bytes."""
model = self.voice_manager.model
if model is None:
raise RuntimeError("Model not loaded")
audio = model.generate_audio(voice_state, text)
audio_np = audio.numpy()
if audio_np.ndim == 1:
audio_np = audio_np.reshape(-1, 1)
max_val = np.max(np.abs(audio_np))
if max_val > 0:
audio_np = audio_np / max_val
audio_int16 = (audio_np * 32767).astype(np.int16)
wav_buffer = io.BytesIO()
wavfile.write(wav_buffer, model.sample_rate, audio_int16)
wav_buffer.seek(0)
return wav_buffer.read()
async def ensure_voice_connection(self, channel: discord.VoiceChannel) -> discord.VoiceClient | None: async def ensure_voice_connection(self, channel: discord.VoiceChannel) -> discord.VoiceClient | None:
"""Ensure we're connected to the specified voice channel.""" """Ensure we're connected to the specified voice channel."""
guild = channel.guild guild = channel.guild

View File

@@ -7,7 +7,8 @@ load_dotenv()
class Config: class Config:
DISCORD_TOKEN: str = os.getenv("DISCORD_TOKEN", "") DISCORD_TOKEN: str = os.getenv("DISCORD_TOKEN", "")
TEXT_CHANNEL_ID: int = int(os.getenv("TEXT_CHANNEL_ID", "0")) TEXT_CHANNEL_ID: int = int(os.getenv("TEXT_CHANNEL_ID", "0"))
VOICE_WAV_PATH: str = os.getenv("VOICE_WAV_PATH", "./voice.wav") VOICES_DIR: str = os.getenv("VOICES_DIR", "./voices")
DEFAULT_VOICE: str | None = os.getenv("DEFAULT_VOICE", None)
@classmethod @classmethod
def validate(cls) -> list[str]: def validate(cls) -> list[str]:
@@ -17,6 +18,6 @@ class Config:
errors.append("DISCORD_TOKEN is not set") errors.append("DISCORD_TOKEN is not set")
if cls.TEXT_CHANNEL_ID == 0: if cls.TEXT_CHANNEL_ID == 0:
errors.append("TEXT_CHANNEL_ID is not set") errors.append("TEXT_CHANNEL_ID is not set")
if not os.path.exists(cls.VOICE_WAV_PATH): if not os.path.exists(cls.VOICES_DIR):
errors.append(f"Voice WAV file not found: {cls.VOICE_WAV_PATH}") errors.append(f"Voices directory not found: {cls.VOICES_DIR}")
return errors return errors

37
pockettts.service Normal file
View File

@@ -0,0 +1,37 @@
[Unit]
Description=Pocket TTS Discord Bot
After=network-online.target
Wants=network-online.target
[Service]
# Replace with your username
User=YOUR_USERNAME
Group=YOUR_USERNAME
# Replace with the actual path to your bot directory
WorkingDirectory=/home/YOUR_USERNAME/PocketTTSBot
# Use the Python from the virtual environment
ExecStart=/home/YOUR_USERNAME/PocketTTSBot/venv/bin/python bot.py
# Restart on failure
Restart=on-failure
RestartSec=10
# Give the bot time to gracefully shutdown
TimeoutStopSec=30
# Logging
StandardOutput=journal
StandardError=journal
SyslogIdentifier=pockettts
# Security hardening (optional but recommended)
NoNewPrivileges=true
ProtectSystem=strict
ProtectHome=read-only
ReadWritePaths=/home/YOUR_USERNAME/PocketTTSBot/voices
PrivateTmp=true
[Install]
WantedBy=multi-user.target

View File

@@ -1,77 +0,0 @@
import io
import numpy as np
import scipy.io.wavfile as wavfile
from typing import Any
from pocket_tts import TTSModel
from audio_preprocessor import (
AudioPreprocessor,
PreprocessingConfig,
print_audio_analysis,
)
class TTSHandler:
"""Handles text-to-speech generation using Pocket TTS."""
DISCORD_SAMPLE_RATE = 48000
def __init__(self, voice_wav_path: str, preprocess_audio: bool = True):
self.voice_wav_path = voice_wav_path
self.preprocess_audio = preprocess_audio
self.model: TTSModel | None = None
self.voice_state: Any = None
self._preprocessed_path: str | None = None
def load(self) -> None:
"""Load the TTS model and voice state from the WAV file."""
print("Loading Pocket TTS model...")
self.model = TTSModel.load_model()
voice_path = self.voice_wav_path
# Analyze and preprocess the audio if enabled
if self.preprocess_audio:
print("\nAnalyzing original audio...")
print_audio_analysis(self.voice_wav_path)
print("Preprocessing audio for optimal voice cloning...")
config = PreprocessingConfig(
target_sample_rate=22050,
normalize=True,
trim_silence=True,
trim_top_db=20,
reduce_noise=True,
target_length_seconds=15.0, # Limit to 15 seconds for best results
)
preprocessor = AudioPreprocessor(config)
voice_path = preprocessor.preprocess_file(self.voice_wav_path)
self._preprocessed_path = voice_path
print("")
print(f"Loading voice state from: {voice_path}")
self.voice_state = self.model.get_state_for_audio_prompt(voice_path)
print("TTS handler ready!")
def generate_wav_bytes(self, text: str) -> bytes:
"""Generate audio and return as WAV file bytes (for FFmpeg)."""
if self.model is None or self.voice_state is None:
raise RuntimeError("TTS handler not loaded. Call load() first.")
audio = self.model.generate_audio(self.voice_state, text)
audio_np = audio.numpy()
if audio_np.ndim == 1:
audio_np = audio_np.reshape(-1, 1)
max_val = np.max(np.abs(audio_np))
if max_val > 0:
audio_np = audio_np / max_val
audio_int16 = (audio_np * 32767).astype(np.int16)
wav_buffer = io.BytesIO()
wavfile.write(wav_buffer, self.model.sample_rate, audio_int16)
wav_buffer.seek(0)
return wav_buffer.read()

190
voice_manager.py Normal file
View File

@@ -0,0 +1,190 @@
"""Voice management for per-user voice selection and on-demand loading."""
import json
from pathlib import Path
from typing import Any
from pocket_tts import TTSModel
from audio_preprocessor import (
AudioPreprocessor,
PreprocessingConfig,
print_audio_analysis,
)
class VoiceManager:
"""Manages available voices, per-user preferences, and on-demand voice loading."""
def __init__(self, voices_dir: str, default_voice: str | None = None):
self.voices_dir = Path(voices_dir)
self.default_voice = default_voice
self.model: TTSModel | None = None
self.preferences_file = self.voices_dir / "preferences.json"
# Cache of loaded voice states: voice_name -> voice_state
self._voice_states: dict[str, Any] = {}
# Per-user voice preferences: user_id -> voice_name
self._user_voices: dict[int, str] = {}
# Available voices: voice_name -> file_path
self._available_voices: dict[str, Path] = {}
def discover_voices(self) -> dict[str, Path]:
"""Discover all available voice WAV files in the voices directory."""
old_voices = set(self._available_voices.keys())
self._available_voices = {}
if not self.voices_dir.exists():
print(f"Voices directory not found: {self.voices_dir}")
return self._available_voices
for wav_file in self.voices_dir.glob("*.wav"):
voice_name = wav_file.stem.lower()
self._available_voices[voice_name] = wav_file
print(f" Found voice: {voice_name} ({wav_file.name})")
# Set default voice if not specified
if self.default_voice is None and self._available_voices:
self.default_voice = next(iter(self._available_voices.keys()))
# Load saved preferences
self._load_preferences()
return self._available_voices
def refresh_voices(self) -> tuple[list[str], list[str]]:
"""Re-scan for voices and return (new_voices, removed_voices)."""
old_voices = set(self._available_voices.keys())
self._available_voices = {}
if self.voices_dir.exists():
for wav_file in self.voices_dir.glob("*.wav"):
voice_name = wav_file.stem.lower()
self._available_voices[voice_name] = wav_file
new_voices = set(self._available_voices.keys())
added = sorted(new_voices - old_voices)
removed = sorted(old_voices - new_voices)
# Update default if needed
if self.default_voice not in self._available_voices and self._available_voices:
self.default_voice = next(iter(self._available_voices.keys()))
return added, removed
def load_model(self) -> None:
"""Load the TTS model (does not load any voices yet)."""
print("Loading Pocket TTS model...")
self.model = TTSModel.load_model()
print("TTS model loaded!")
def get_available_voices(self) -> list[str]:
"""Get list of available voice names."""
return sorted(self._available_voices.keys())
def is_voice_available(self, voice_name: str) -> bool:
"""Check if a voice is available."""
return voice_name.lower() in self._available_voices
def get_voice_state(self, voice_name: str) -> Any:
"""Get or load a voice state on-demand."""
if self.model is None:
raise RuntimeError("Model not loaded. Call load_model() first.")
voice_name = voice_name.lower()
if voice_name not in self._available_voices:
raise ValueError(f"Voice '{voice_name}' not found")
# Return cached state if already loaded
if voice_name in self._voice_states:
return self._voice_states[voice_name]
# Load the voice on-demand
voice_path = self._available_voices[voice_name]
print(f"Loading voice '{voice_name}' from {voice_path}...")
# Preprocess the audio
print(f" Analyzing audio...")
print_audio_analysis(str(voice_path))
print(f" Preprocessing audio...")
config = PreprocessingConfig(
target_sample_rate=22050,
normalize=True,
trim_silence=True,
trim_top_db=20,
reduce_noise=True,
target_length_seconds=15.0,
)
preprocessor = AudioPreprocessor(config)
processed_path = preprocessor.preprocess_file(str(voice_path))
# Load voice state
voice_state = self.model.get_state_for_audio_prompt(processed_path)
self._voice_states[voice_name] = voice_state
print(f" Voice '{voice_name}' loaded and cached!")
return voice_state
def is_voice_loaded(self, voice_name: str) -> bool:
"""Check if a voice is already loaded in cache."""
return voice_name.lower() in self._voice_states
def get_user_voice(self, user_id: int) -> str:
"""Get the voice preference for a user, or default voice."""
return self._user_voices.get(user_id, self.default_voice or "")
def set_user_voice(self, user_id: int, voice_name: str) -> None:
"""Set the voice preference for a user."""
voice_name = voice_name.lower()
if voice_name not in self._available_voices:
raise ValueError(f"Voice '{voice_name}' not found")
self._user_voices[user_id] = voice_name
self._save_preferences()
def get_user_voice_state(self, user_id: int) -> Any:
"""Get the voice state for a user (loads on-demand if needed)."""
voice_name = self.get_user_voice(user_id)
if not voice_name:
raise RuntimeError("No default voice available")
return self.get_voice_state(voice_name)
def get_loaded_voices(self) -> list[str]:
"""Get list of currently loaded voice names."""
return list(self._voice_states.keys())
def _load_preferences(self) -> None:
"""Load user voice preferences from JSON file."""
if not self.preferences_file.exists():
return
try:
with open(self.preferences_file, "r") as f:
data = json.load(f)
# Load user preferences (convert string keys back to int)
for user_id_str, voice_name in data.get("user_voices", {}).items():
user_id = int(user_id_str)
# Only load if voice still exists
if voice_name.lower() in self._available_voices:
self._user_voices[user_id] = voice_name.lower()
print(f" Loaded {len(self._user_voices)} user voice preferences")
except Exception as e:
print(f" Warning: Failed to load preferences: {e}")
def _save_preferences(self) -> None:
"""Save user voice preferences to JSON file."""
try:
# Ensure directory exists
self.preferences_file.parent.mkdir(parents=True, exist_ok=True)
data = {
"user_voices": {str(k): v for k, v in self._user_voices.items()}
}
with open(self.preferences_file, "w") as f:
json.dump(data, f, indent=2)
except Exception as e:
print(f"Warning: Failed to save preferences: {e}")

5
voices/preferences.json Normal file
View File

@@ -0,0 +1,5 @@
{
"user_voices": {
"122139828182712322": "hankhill"
}
}