78 lines
2.5 KiB
Python
78 lines
2.5 KiB
Python
import io
|
|
import numpy as np
|
|
import scipy.io.wavfile as wavfile
|
|
from typing import Any
|
|
from pocket_tts import TTSModel
|
|
|
|
from audio_preprocessor import (
|
|
AudioPreprocessor,
|
|
PreprocessingConfig,
|
|
print_audio_analysis,
|
|
)
|
|
|
|
|
|
class TTSHandler:
    """Handles text-to-speech generation using Pocket TTS.

    Construct with the path to a reference voice WAV, call ``load()`` once,
    then call ``generate_wav_bytes()`` for each piece of text to synthesize.
    """

    # Not referenced inside this class; presumably read by callers that
    # resample for Discord playback — TODO confirm.
    DISCORD_SAMPLE_RATE = 48000

    def __init__(self, voice_wav_path: str, preprocess_audio: bool = True):
        """Store configuration; nothing is loaded until ``load()`` is called.

        Args:
            voice_wav_path: Path to the reference voice WAV used for cloning.
            preprocess_audio: If True, ``load()`` analyzes and preprocesses the
                reference audio before building the voice state from it.
        """
        self.voice_wav_path = voice_wav_path
        self.preprocess_audio = preprocess_audio
        # Both populated by load(); generate_wav_bytes() requires both.
        self.model: TTSModel | None = None
        self.voice_state: Any = None
        # Path returned by the preprocessor, or None if preprocessing was skipped.
        self._preprocessed_path: str | None = None

    def load(self) -> None:
        """Load the TTS model and voice state from the WAV file.

        When ``preprocess_audio`` is set, the reference WAV is first analyzed
        and run through ``AudioPreprocessor``, and the voice state is built
        from the preprocessed file instead of the original.
        """
        print("Loading Pocket TTS model...")
        self.model = TTSModel.load_model()

        voice_path = self.voice_wav_path
        if self.preprocess_audio:
            voice_path = self._preprocess_voice()

        print(f"Loading voice state from: {voice_path}")
        self.voice_state = self.model.get_state_for_audio_prompt(voice_path)
        print("TTS handler ready!")

    def _preprocess_voice(self) -> str:
        """Analyze and preprocess the reference WAV; return the path to clone from.

        The returned path is whatever ``AudioPreprocessor.preprocess_file``
        returns; it is also stored on ``self._preprocessed_path``.
        """
        print("\nAnalyzing original audio...")
        print_audio_analysis(self.voice_wav_path)

        print("Preprocessing audio for optimal voice cloning...")
        config = PreprocessingConfig(
            target_sample_rate=22050,
            normalize=True,
            trim_silence=True,
            trim_top_db=20,
            reduce_noise=True,
            target_length_seconds=15.0,  # Limit to 15 seconds for best results
        )
        preprocessor = AudioPreprocessor(config)
        voice_path = preprocessor.preprocess_file(self.voice_wav_path)
        self._preprocessed_path = voice_path
        print("")
        return voice_path

    def generate_wav_bytes(self, text: str) -> bytes:
        """Generate audio and return as WAV file bytes (for FFmpeg).

        Args:
            text: Text to synthesize with the loaded voice.

        Returns:
            A complete in-memory WAV file of 16-bit PCM samples, written at
            ``self.model.sample_rate`` and peak-normalized to full scale.
            An empty model output yields a valid, empty WAV.

        Raises:
            RuntimeError: If ``load()`` has not completed successfully.
        """
        if self.model is None or self.voice_state is None:
            raise RuntimeError("TTS handler not loaded. Call load() first.")

        audio = self.model.generate_audio(self.voice_state, text)
        # NOTE(review): ``.numpy()`` presumably converts a CPU torch tensor; it
        # would fail for a GPU tensor or one that requires grad — confirm.
        audio_int16 = self._to_int16_pcm(audio.numpy())

        wav_buffer = io.BytesIO()
        wavfile.write(wav_buffer, self.model.sample_rate, audio_int16)
        return wav_buffer.getvalue()

    @staticmethod
    def _to_int16_pcm(audio_np: np.ndarray) -> np.ndarray:
        """Peak-normalize audio samples and convert them to 16-bit PCM.

        1-D input is treated as mono and reshaped to ``(samples, 1)``.
        Non-finite samples (NaN/inf) are replaced with silence, and an empty
        array yields an empty int16 array instead of raising.
        """
        if audio_np.ndim == 1:
            audio_np = audio_np.reshape(-1, 1)
        # NOTE(review): multi-dimensional input is passed through unchanged,
        # which assumes it is already (samples, channels) — confirm.

        if audio_np.size == 0:
            # np.max on a zero-size array raises ValueError; emit an empty WAV.
            return np.zeros(audio_np.shape, dtype=np.int16)

        # Casting NaN/inf to int16 is undefined, and one inf sample would turn
        # the peak normalization below into 0/NaN everywhere, so zero them.
        audio_np = np.nan_to_num(audio_np, nan=0.0, posinf=0.0, neginf=0.0)

        max_val = np.max(np.abs(audio_np))
        if max_val > 0:
            audio_np = audio_np / max_val
        # Samples now lie in [-1, 1], so scaling by 32767 stays within int16.
        return (audio_np * 32767).astype(np.int16)
|
|
|
|
|