# Files
# Vox/tts_handler.py
# 2026-01-18 17:08:37 -06:00
#
# 78 lines
# 2.5 KiB
# Python

import io
import numpy as np
import scipy.io.wavfile as wavfile
from typing import Any
from pocket_tts import TTSModel
from audio_preprocessor import (
AudioPreprocessor,
PreprocessingConfig,
print_audio_analysis,
)
class TTSHandler:
    """Handles text-to-speech generation using Pocket TTS.

    Usage: construct with the path of a reference voice recording, call
    ``load()`` once, then call ``generate_wav_bytes()`` for each utterance.
    """

    # Not referenced by any method of this class as written; presumably read
    # by the Discord playback code elsewhere -- TODO confirm before removing.
    DISCORD_SAMPLE_RATE = 48000

    def __init__(self, voice_wav_path: str, preprocess_audio: bool = True):
        """Store configuration; nothing is loaded until ``load()`` is called.

        Args:
            voice_wav_path: Path of the WAV file used as the voice prompt.
            preprocess_audio: When true, ``load()`` analyzes the file and
                runs it through ``AudioPreprocessor`` before handing it to
                the model.
        """
        self.voice_wav_path = voice_wav_path
        self.preprocess_audio = preprocess_audio
        # Both stay None until load() succeeds; generate_wav_bytes() checks
        # them to detect an unloaded handler.
        self.model: TTSModel | None = None
        self.voice_state: Any = None
        # Path returned by the preprocessor during load(); remains None when
        # preprocessing is disabled or load() has not run.
        self._preprocessed_path: str | None = None

    def load(self) -> None:
        """Load the TTS model and voice state from the WAV file.

        When preprocessing is enabled, the original recording is analyzed
        (printed to stdout) and preprocessed first, and the path returned by
        the preprocessor is used as the voice prompt instead of the original.
        """
        print("Loading Pocket TTS model...")
        self.model = TTSModel.load_model()

        voice_path = self.voice_wav_path

        # Analyze and preprocess the audio if enabled
        if self.preprocess_audio:
            print("\nAnalyzing original audio...")
            print_audio_analysis(self.voice_wav_path)

            print("Preprocessing audio for optimal voice cloning...")
            config = PreprocessingConfig(
                target_sample_rate=22050,
                normalize=True,
                trim_silence=True,
                trim_top_db=20,
                reduce_noise=True,
                target_length_seconds=15.0,  # Limit to 15 seconds for best results
            )
            preprocessor = AudioPreprocessor(config)
            # preprocess_file() returns the path to use from here on;
            # presumably a newly written file -- confirm in audio_preprocessor.
            voice_path = preprocessor.preprocess_file(self.voice_wav_path)
            self._preprocessed_path = voice_path
            print("")

        print(f"Loading voice state from: {voice_path}")
        self.voice_state = self.model.get_state_for_audio_prompt(voice_path)
        print("TTS handler ready!")

    def generate_wav_bytes(self, text: str) -> bytes:
        """Generate audio and return as WAV file bytes (for FFmpeg).

        The model output is peak-normalized (scaled so the largest absolute
        sample becomes 1.0), converted to 16-bit integer samples, and written
        as a WAV file at ``self.model.sample_rate``.

        Args:
            text: The text to synthesize.

        Returns:
            The complete WAV file contents.

        Raises:
            RuntimeError: If ``load()`` has not been called successfully.
            ValueError: If the model returns zero audio samples.
        """
        if self.model is None or self.voice_state is None:
            raise RuntimeError("TTS handler not loaded. Call load() first.")

        audio = self.model.generate_audio(self.voice_state, text)
        audio_np = audio.numpy()

        # np.max() on an empty array fails with an opaque "zero-size array"
        # ValueError; raise the same exception type with a message that
        # points at the actual cause.
        if audio_np.size == 0:
            raise ValueError(
                f"TTS model produced no audio samples for text: {text!r}"
            )

        # Give 1-D output an explicit single-channel column so the array is
        # laid out as (samples, channels).
        if audio_np.ndim == 1:
            audio_np = audio_np.reshape(-1, 1)

        # Peak-normalize; pure silence (peak 0) is left as-is to avoid
        # dividing by zero.
        max_val = np.max(np.abs(audio_np))
        if max_val > 0:
            audio_np = audio_np / max_val

        audio_int16 = (audio_np * 32767).astype(np.int16)

        wav_buffer = io.BytesIO()
        wavfile.write(wav_buffer, self.model.sample_rate, audio_int16)
        return wav_buffer.getvalue()