- SpeakRequest: keep_audio=true speichert WAV in ~/.cache/chatterbox-tts/
- SpeakJob: audio_path-Feld für gespeicherte WAV-Datei
- GET /audio/{job_id}: liefert WAV als FileResponse, löscht Datei danach
- mcp_adapter: keep_audio-Parameter in speak() weitergereicht
- Docstring: neuen Endpunkt dokumentiert
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
378 lines
11 KiB
Python
378 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Chatterbox TTS – lokaler HTTP-Service
|
||
|
||
Start:
|
||
uvicorn tts_service:app --host 0.0.0.0 --port 9999
|
||
|
||
Endpunkte:
|
||
POST /speak – Text in Warteschlange einreihen
|
||
POST /stop – laufende Ausgabe abbrechen, Queue leeren
|
||
POST /pause – Ausgabe pausieren (ohne Datenverlust)
|
||
POST /resume – pausierte Ausgabe fortsetzen
|
||
GET /audio/{job_id} – fertige WAV herunterladen (nur wenn keep_audio=true)
|
||
GET /health – Service-Status
|
||
GET /status – aktueller Job + Queue-Länge
|
||
GET /voices – unterstützte Sprachen
|
||
"""
|
||
from __future__ import annotations

import os
import queue
import sys
import threading
import traceback
import uuid
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Optional

# CLI-Modul aus demselben Verzeichnis laden
sys.path.insert(0, str(Path(__file__).parent))
import chatterbox_cli_v4 as tts  # noqa: E402

import torch
import torchaudio as ta
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel, Field
from starlette.background import BackgroundTask
|
||
|
||
# Cache directory for WAV files kept for GET /audio/{job_id} (keep_audio=True).
# It is created eagerly at import time so the worker can write into it directly.
_AUDIO_CACHE_DIR = Path.home().joinpath(".cache", "chatterbox-tts")
_AUDIO_CACHE_DIR.mkdir(exist_ok=True, parents=True)
|
||
|
||
# ---------------------------------------------------------------------------
# Determine the compute device once, at import time
# ---------------------------------------------------------------------------
# Used for every model load and reported by GET /health. Passing None
# presumably lets the CLI module auto-detect the device — confirm in
# chatterbox_cli_v4.get_device.
_DEVICE = tts.get_device(None)
|
||
|
||
# ---------------------------------------------------------------------------
# Model cache: (lang, t3_model) -> (model, model_kind, sr)
# ---------------------------------------------------------------------------
_model_cache: dict[tuple, tuple] = {}
_model_lock = threading.Lock()


def _get_or_load_model(lang: str, t3_model: str) -> tuple:
    """Return the cached ``(model, model_kind, sr)`` entry for *lang*/*t3_model*.

    The first request for a key loads the model via ``tts.load_model`` on
    ``_DEVICE`` and stores it. The lock is held while loading, so concurrent
    callers wait for the single load instead of loading the same model twice.
    Loads for different keys are also serialized.
    """
    cache_key = (lang, t3_model)
    with _model_lock:
        missing = cache_key not in _model_cache
        if missing:
            loaded = tts.load_model(lang, _DEVICE, t3_model=t3_model)
            _model_cache[cache_key] = loaded
        return _model_cache[cache_key]
|
||
|
||
|
||
# Optional warm-up: TTS_PRELOAD_LANG=de loads that model at service start so
# the first request does not pay the model-loading cost. TTS_PRELOAD_T3
# selects the T3 model variant and defaults to "v3".
_PRELOAD_LANG = os.environ.get("TTS_PRELOAD_LANG")
if _PRELOAD_LANG:
    _preload_t3 = os.environ.get("TTS_PRELOAD_T3", "v3")
    try:
        _get_or_load_model(_PRELOAD_LANG, _preload_t3)
        print(f"[chatterbox-tts] Modell vorgeladen: lang={_PRELOAD_LANG}, t3={_preload_t3}")
    except Exception as _e:  # noqa: BLE001 - warm-up is best effort; the service must still start
        # Failures are diagnostics, so they go to stderr rather than stdout.
        print(f"[chatterbox-tts] Warnung: Warmup fehlgeschlagen: {_e}", file=sys.stderr)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Job data model
# ---------------------------------------------------------------------------
class JobStatus(str, Enum):
    """Lifecycle state of a SpeakJob; members are also plain ``str`` values."""

    pending = "pending"      # queued, not yet picked up by the worker
    running = "running"      # picked up by the worker
    done = "done"            # finished without a stop request
    cancelled = "cancelled"  # a stop was requested while the job ran
    error = "error"          # aborted by an exception (message in SpeakJob.error)


@dataclass
class SpeakJob:
    """A queued /speak request plus the progress/result state the worker fills in."""

    # --- request parameters ---
    id: str
    text: str
    lang: str
    t3_model: str
    voice: Optional[str]          # optional reference WAV path for voice cloning
    speed: float
    audio_device: str             # may be None from the request; worker falls back to "pulse"
    max_len: int                  # passed to tts.split_into_sentences as max_len
    save_wav: bool
    output_path: Optional[str]
    pronunciation_dict: Optional[dict]
    session_id: Optional[str]
    keep_audio: bool = False      # keep the final WAV for GET /audio/{job_id}

    # --- runtime state, written by the worker thread ---
    status: JobStatus = JobStatus.pending
    text_preview: str = ""
    chunks_total: int = 0
    chunks_done: int = 0
    error: Optional[str] = None
    audio_path: Optional[str] = None  # cached WAV path once a keep_audio job produced audio
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Worker thread
# ---------------------------------------------------------------------------
# Jobs submitted via POST /speak; a single worker thread processes them in order.
_job_queue: queue.Queue[SpeakJob] = queue.Queue()
# Job currently being processed (None while idle); guarded by _state_lock.
_current_job: Optional[SpeakJob] = None
_state_lock = threading.Lock()
# Most recently finished jobs, oldest first, capped at _MAX_RECENT entries.
_recent_jobs: list[SpeakJob] = []
_MAX_RECENT = 20


def _discard_cached_audio(job: SpeakJob) -> None:
    """Delete the WAV kept for *job* (keep_audio=True), if any.

    This is best effort. A file that is already gone is ignored, and other OS
    errors are reported on stderr but never propagate into the worker loop.
    """
    path, job.audio_path = job.audio_path, None
    if not path:
        return
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass
    except OSError as exc:
        print(f"[chatterbox-tts] Warnung: Audio-Cache nicht gelöscht: {exc}", file=sys.stderr)


def _run_job(job: SpeakJob) -> None:
    """Synthesize and play *job*, then write the requested WAV output.

    On return, job.status is ``done``, or ``cancelled`` if a stop was
    requested. Exceptions propagate to the caller.
    """
    model, model_kind, sr = _get_or_load_model(job.lang, job.t3_model)

    raw = tts.clean_raw_text(job.text)
    chunks = [
        tts.preprocess_tts_text(c, lang=job.lang, pronunciation_dict=job.pronunciation_dict)
        for c in tts.split_into_sentences(raw, max_len=job.max_len)
    ]
    chunks = [c for c in chunks if c.strip()]
    job.chunks_total = len(chunks)

    playback = tts.PlaybackWorker(
        sample_rate=sr,
        device=job.audio_device or "pulse",
        speed=job.speed,
        stop_event=tts.STOP_REQUESTED,
    )
    playback.start()

    wavs: list[torch.Tensor] = []
    try:
        for chunk in chunks:
            if tts.stop_requested():
                break
            wav = tts.generate_chunk(model, model_kind, chunk, job.lang, job.voice)
            wavs.append(wav)
            playback.put(wav)
            job.chunks_done += 1
    finally:
        # Always release the playback worker, even if generation failed.
        playback.stop()

    if wavs:
        final = wavs[0] if len(wavs) == 1 else torch.cat(wavs, dim=-1)

        if job.save_wav and job.output_path:
            out = Path(job.output_path)
            out.parent.mkdir(parents=True, exist_ok=True)
            ta.save(str(out), final, sr)

        if job.keep_audio:
            cache_path = _AUDIO_CACHE_DIR / f"{job.id}.wav"
            ta.save(str(cache_path), final, sr)
            job.audio_path = str(cache_path)

    job.status = JobStatus.cancelled if tts.stop_requested() else JobStatus.done


def _worker() -> None:
    """Consume _job_queue forever, running exactly one job at a time."""
    global _current_job

    while True:
        job = _job_queue.get()
        # Set the preview first so /status shows it during model loading, and
        # also when the job fails early.
        job.text_preview = job.text[:80]

        with _state_lock:
            _current_job = job
            job.status = JobStatus.running

        # Reset any stop flag left over from a previous job or /stop call.
        tts.clear_stop()

        try:
            _run_job(job)
        except Exception as exc:  # noqa: BLE001 - every failure is reported on the job
            job.status = JobStatus.error
            job.error = str(exc)
            traceback.print_exc()
        finally:
            evicted: list[SpeakJob] = []
            with _state_lock:
                _current_job = None
                _recent_jobs.append(job)
                while len(_recent_jobs) > _MAX_RECENT:
                    evicted.append(_recent_jobs.pop(0))
            # Evicted jobs can no longer be found by GET /audio/{job_id}, so
            # their kept WAV files would otherwise stay in the cache forever.
            for old_job in evicted:
                _discard_cached_audio(old_job)
            _job_queue.task_done()


_worker_thread = threading.Thread(target=_worker, daemon=True, name="tts-worker")
_worker_thread.start()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# API models
# ---------------------------------------------------------------------------
class SpeakRequest(BaseModel):
    # Request body of POST /speak. Comments are used instead of a class
    # docstring, because pydantic would publish a docstring in the schema.
    # NOTE(review): voice and output_path are filesystem paths taken verbatim
    # from the client and used server-side. Restrict them if the service is
    # reachable from untrusted hosts (the module docstring starts it on 0.0.0.0).
    text: str = Field(..., min_length=1, max_length=4000)
    lang: str = Field(default="de")
    voice: Optional[str] = Field(default=None)
    interrupt: bool = Field(default=False)
    speed: float = Field(default=1.0, ge=0.5, le=2.0)
    t3_model: str = Field(default="v3")
    audio_device: Optional[str] = Field(default=None)
    max_len: int = Field(default=400, ge=100, le=1000)
    save_wav: bool = Field(default=False)
    output_path: Optional[str] = Field(default=None)
    session_id: Optional[str] = Field(default=None)
    pronunciation_dict: Optional[dict] = Field(default=None)
    # Keep the final WAV in the cache so it can be fetched via GET /audio/{job_id}.
    keep_audio: bool = Field(default=False)
|
||
|
||
|
||
def _job_to_dict(j: SpeakJob) -> dict:
    """Serialize a job into the JSON summary used by GET /status.

    ``keep_audio`` echoes the request flag. ``audio_available`` is true while
    a kept WAV path is recorded for the job, i.e. GET /audio/{job_id} can
    serve it. It turns false after download or cache cleanup.
    """
    return {
        "id": j.id,
        "status": j.status,
        "lang": j.lang,
        "text_preview": j.text_preview,
        "chunks_total": j.chunks_total,
        "chunks_done": j.chunks_done,
        "error": j.error,
        "keep_audio": j.keep_audio,
        "audio_available": j.audio_path is not None,
    }
|
||
|
||
|
||
def _drain_queue() -> None:
    """Discard every job still waiting in _job_queue without blocking.

    Each removed job is marked as processed via task_done(). The job that is
    currently running is not affected.
    """
    while True:
        try:
            _job_queue.get_nowait()
        except queue.Empty:
            return
        _job_queue.task_done()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# FastAPI application
# ---------------------------------------------------------------------------
app = FastAPI(version="1.0", title="Chatterbox TTS Service")
|
||
|
||
|
||
@app.get("/health")
def health():
    # Liveness probe; also reports the compute device chosen at startup.
    payload = dict(status="ok", device=_DEVICE)
    return payload
|
||
|
||
|
||
@app.get("/voices")
def voices():
    # Languages come from the CLI module. Voice cloning is requested per job
    # through SpeakRequest.voice.
    languages = sorted(tts.SUPPORTED_LANGS)
    cloning_hint = "Voice cloning via 'voice' field (WAV-Pfad, 10–30s Aufnahme)"
    return {"languages": languages, "note": cloning_hint}
|
||
|
||
|
||
@app.post("/speak")
def speak(req: SpeakRequest):
    # Validate first, so a rejected request never interrupts current playback.
    if req.lang not in tts.SUPPORTED_LANGS:
        raise HTTPException(status_code=422, detail=f"Sprache nicht unterstützt: {req.lang}")
    # NOTE(review): this reveals whether an arbitrary server path exists.
    if req.voice and not Path(req.voice).exists():
        raise HTTPException(status_code=422, detail=f"Voice-Datei nicht gefunden: {req.voice}")

    if req.interrupt:
        # Abort the running job and drop everything still waiting in the queue.
        tts.request_stop()
        _drain_queue()

    new_job = SpeakJob(
        id=str(uuid.uuid4()),
        text=req.text,
        lang=req.lang,
        t3_model=req.t3_model,
        voice=req.voice,
        speed=req.speed,
        audio_device=req.audio_device,
        max_len=req.max_len,
        # NOTE(review): output_path lets the client choose where the server
        # writes a WAV file; restrict it for untrusted networks.
        save_wav=req.save_wav,
        output_path=req.output_path,
        pronunciation_dict=req.pronunciation_dict,
        session_id=req.session_id,
        keep_audio=req.keep_audio,
    )
    _job_queue.put(new_job)

    # The queue size is read after enqueueing. If the worker is idle it may
    # already have taken the new job.
    position = _job_queue.qsize()
    return {
        "job_id": new_job.id,
        "status": new_job.status,
        "queue_position": position,
    }
|
||
|
||
|
||
@app.post("/stop")
def stop():
    # Signal the running job to stop, then discard every queued job.
    tts.request_stop()
    _drain_queue()
    result = {"stopped": True}
    return result
|
||
|
||
|
||
@app.post("/pause")
def pause():
    # Delegates to the CLI module's pause flag; the queue is left untouched.
    tts.request_pause()
    result = {"paused": True}
    return result
|
||
|
||
|
||
@app.post("/resume")
def resume():
    # Counterpart of /pause: delegates to the CLI module's resume request.
    tts.request_resume()
    result = {"resumed": True}
    return result
|
||
|
||
|
||
@app.get("/audio/{job_id}")
def download_audio(job_id: str):
    """Download the finished WAV of a job submitted with keep_audio=true.

    The cached file is deleted once the response has been sent, so each
    recording can be fetched once. While the job is still queued or running,
    202 is returned with a ``{"detail": ...}`` body. Unknown jobs, and jobs
    without an audio file or whose file was already removed, yield 404.
    """
    busy_detail = "Job läuft noch — bitte später erneut abrufen."

    with _state_lock:
        cur = _current_job
        recent = list(_recent_jobs)
    # Queued jobs are neither current nor recent, so the queue must be checked
    # too; otherwise they would get 404. Queue.mutex / Queue.queue are CPython
    # queue.Queue implementation attributes, read here only as a snapshot
    # under the queue's own lock.
    with _job_queue.mutex:
        queued_ids = {j.id for j in _job_queue.queue}

    if (cur is not None and cur.id == job_id) or job_id in queued_ids:
        return JSONResponse(status_code=202, content={"detail": busy_detail})

    job = next((j for j in recent if j.id == job_id), None)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job nicht gefunden: {job_id}")

    # Defensive: jobs in _recent_jobs are normally finished already.
    if job.status in (JobStatus.pending, JobStatus.running):
        return JSONResponse(status_code=202, content={"detail": busy_detail})

    cached = job.audio_path
    if not cached or not Path(cached).exists():
        if not job.keep_audio:
            raise HTTPException(
                status_code=404,
                detail="Keine Audio-Datei vorhanden. Bitte /speak mit keep_audio=true aufrufen.",
            )
        raise HTTPException(status_code=404, detail="Audio-Datei nicht mehr vorhanden.")

    audio_path = Path(cached)

    def cleanup_after_send() -> None:
        """Remove the delivered WAV so the cache does not grow (best effort)."""
        try:
            os.unlink(audio_path)
        except FileNotFoundError:
            pass  # already gone, e.g. removed by a concurrent download
        except OSError:
            return  # keep the reference; a later download will try again
        job.audio_path = None

    return FileResponse(
        path=str(audio_path),
        media_type="audio/wav",
        filename=f"tts_{job_id[:8]}.wav",
        # Background tasks run after the response has been sent.
        background=BackgroundTask(cleanup_after_send),
    )
|
||
|
||
|
||
@app.get("/status")
def status():
    # Take a consistent snapshot of the shared state, then serialize it
    # outside the lock.
    with _state_lock:
        running, finished = _current_job, _recent_jobs[:]

    newest_first = [_job_to_dict(j) for j in finished[::-1]]
    return {
        "current_job": None if running is None else _job_to_dict(running),
        "queue_length": _job_queue.qsize(),
        "recent_jobs": newest_first,
    }
|