Initial commit
This commit is contained in:
77
tts_handler.py
Normal file
77
tts_handler.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import io
|
||||
import numpy as np
|
||||
import scipy.io.wavfile as wavfile
|
||||
from typing import Any
|
||||
from pocket_tts import TTSModel
|
||||
|
||||
from audio_preprocessor import (
|
||||
AudioPreprocessor,
|
||||
PreprocessingConfig,
|
||||
print_audio_analysis,
|
||||
)
|
||||
|
||||
|
||||
class TTSHandler:
    """Handles text-to-speech generation using Pocket TTS.

    Typical use: construct with the path of a reference voice WAV, call
    ``load()`` once, then call ``generate_wav_bytes()`` for each utterance.
    """

    # Not referenced inside this class; generated audio is written at the
    # model's own ``sample_rate`` (see ``generate_wav_bytes``).
    DISCORD_SAMPLE_RATE = 48000

    def __init__(self, voice_wav_path: str, preprocess_audio: bool = True):
        """Store configuration; nothing is loaded until ``load()`` is called.

        Args:
            voice_wav_path: Path of the WAV file used as the voice prompt.
            preprocess_audio: When true, ``load()`` analyzes and preprocesses
                the voice file before handing it to the model.
        """
        self.voice_wav_path = voice_wav_path
        self.preprocess_audio = preprocess_audio
        self.model: TTSModel | None = None
        self.voice_state: Any = None
        # Set by load() to whatever AudioPreprocessor.preprocess_file returns.
        self._preprocessed_path: str | None = None

    def load(self) -> None:
        """Load the TTS model and voice state from the WAV file.

        Prints progress messages. When ``preprocess_audio`` is enabled the
        original file is analyzed and preprocessed first, and the value
        returned by the preprocessor is used as the audio prompt instead of
        the original path.
        """
        print("Loading Pocket TTS model...")
        self.model = TTSModel.load_model()

        voice_path = self.voice_wav_path

        # Analyze and preprocess the audio if enabled
        if self.preprocess_audio:
            print("\nAnalyzing original audio...")
            print_audio_analysis(self.voice_wav_path)

            print("Preprocessing audio for optimal voice cloning...")
            config = PreprocessingConfig(
                target_sample_rate=22050,
                normalize=True,
                trim_silence=True,
                trim_top_db=20,
                reduce_noise=True,
                target_length_seconds=15.0,  # Limit to 15 seconds for best results
            )
            preprocessor = AudioPreprocessor(config)
            # NOTE(review): preprocess_file presumably returns the path of the
            # processed file -- confirm against audio_preprocessor.
            voice_path = preprocessor.preprocess_file(self.voice_wav_path)
            self._preprocessed_path = voice_path
            print("")

        print(f"Loading voice state from: {voice_path}")
        self.voice_state = self.model.get_state_for_audio_prompt(voice_path)
        print("TTS handler ready!")

    def generate_wav_bytes(self, text: str) -> bytes:
        """Generate audio and return as WAV file bytes (for FFmpeg).

        The model output is peak-normalized, converted to 16-bit PCM and
        serialized as a WAV file at ``self.model.sample_rate``.

        Args:
            text: The text to synthesize.

        Returns:
            The complete WAV file contents. If the model produces no samples
            the result is a valid WAV file with an empty data chunk.

        Raises:
            RuntimeError: If ``load()`` has not been called successfully.
        """
        if self.model is None or self.voice_state is None:
            raise RuntimeError("TTS handler not loaded. Call load() first.")

        audio = self.model.generate_audio(self.voice_state, text)
        audio_np = audio.numpy()

        # Write 1-D output as a single-channel (n_samples, 1) array.
        if audio_np.ndim == 1:
            audio_np = audio_np.reshape(-1, 1)

        # np.max raises ValueError on a zero-size array, so only look for a
        # peak when there is at least one sample.
        max_val = np.max(np.abs(audio_np)) if audio_np.size else 0
        if max_val > 0:
            # Peak-normalize so the loudest sample maps to full scale.
            audio_np = audio_np / max_val
        audio_int16 = (audio_np * 32767).astype(np.int16)

        wav_buffer = io.BytesIO()
        wavfile.write(wav_buffer, self.model.sample_rate, audio_int16)
        # getvalue() returns the whole buffer regardless of stream position.
        return wav_buffer.getvalue()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user