chatterbox-tts-cli/tts_service.py
dschlueter fe74b84360 Audio-Download-Endpunkt GET /audio/{job_id} hinzufügen
- SpeakRequest: keep_audio=true speichert WAV in ~/.cache/chatterbox-tts/
- SpeakJob: audio_path-Feld für gespeicherte WAV-Datei
- GET /audio/{job_id}: liefert WAV als FileResponse, löscht Datei danach
- mcp_adapter: keep_audio-Parameter in speak() weitergereicht
- Docstring: neuen Endpunkt dokumentiert

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-03 21:09:06 +02:00

378 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Chatterbox TTS lokaler HTTP-Service
Start:
uvicorn tts_service:app --host 0.0.0.0 --port 9999
Endpunkte:
POST /speak Text in Warteschlange einreihen
POST /stop laufende Ausgabe abbrechen, Queue leeren
POST /pause Ausgabe pausieren (ohne Datenverlust)
POST /resume pausierte Ausgabe fortsetzen
GET /audio/{job_id} fertige WAV herunterladen (nur wenn keep_audio=true)
GET /health Service-Status
GET /status aktueller Job + Queue-Länge
GET /voices unterstützte Sprachen
"""
from __future__ import annotations
import os
import queue
import sys
import threading
import uuid
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Optional
# CLI-Modul aus demselben Verzeichnis laden
sys.path.insert(0, str(Path(__file__).parent))
import chatterbox_cli_v4 as tts # noqa: E402
import torch
import torchaudio as ta
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
# Cache directory for WAVs kept via keep_audio=True and served by
# GET /audio/{job_id}; created at import time.
_AUDIO_CACHE_DIR = Path.home().joinpath(".cache", "chatterbox-tts")
_AUDIO_CACHE_DIR.mkdir(exist_ok=True, parents=True)

# ---------------------------------------------------------------------------
# Compute device, resolved once at import time
# ---------------------------------------------------------------------------
_DEVICE = tts.get_device(None)

# ---------------------------------------------------------------------------
# Model cache: (lang, t3_model) -> (model, model_kind, sr), guarded by _model_lock
# ---------------------------------------------------------------------------
_model_lock = threading.Lock()
_model_cache: dict[tuple, tuple] = {}
def _get_or_load_model(lang: str, t3_model: str) -> tuple:
    """Return the cached ``(model, model_kind, sr)`` for *lang*/*t3_model*.

    The model is loaded via ``tts.load_model`` on first use. The check and the
    load both run under ``_model_lock``, so concurrent callers never load the
    same model twice. A load that raises caches nothing, so the next call
    retries.
    """
    cache_key = (lang, t3_model)
    with _model_lock:
        if cache_key in _model_cache:
            return _model_cache[cache_key]
        loaded = tts.load_model(lang, _DEVICE, t3_model=t3_model)
        _model_cache[cache_key] = loaded
        return loaded
# Optional warm-up: TTS_PRELOAD_LANG=de loads the model at service start, so
# the first request does not pay the model loading time. TTS_PRELOAD_T3 selects
# the T3 variant (default "v3", the same default as SpeakRequest.t3_model).
_PRELOAD_LANG = os.environ.get("TTS_PRELOAD_LANG")
if _PRELOAD_LANG:
    _preload_t3 = os.environ.get("TTS_PRELOAD_T3", "v3")
    try:
        _get_or_load_model(_PRELOAD_LANG, _preload_t3)
        print(f"[chatterbox-tts] Modell vorgeladen: lang={_PRELOAD_LANG}, t3={_preload_t3}")
    except Exception as _e:  # noqa: BLE001
        # Best effort only: a failed load is not cached, so the worker simply
        # loads the model again on the first matching request.
        print(f"[chatterbox-tts] Warnung: Warmup fehlgeschlagen: {_e}", file=sys.stderr)
# ---------------------------------------------------------------------------
# Job-Datenmodell
# ---------------------------------------------------------------------------
class JobStatus(str, Enum):
    """Lifecycle state of a SpeakJob.

    Mixes in ``str`` so members compare equal to their plain string values.
    """
    pending = "pending"  # created by /speak, still waiting in the queue
    running = "running"  # taken from the queue by the worker thread
    done = "done"  # finished without a stop request
    cancelled = "cancelled"  # a stop was requested while the job was running
    error = "error"  # an exception occurred; its message is in SpeakJob.error
@dataclass
class SpeakJob:
    """One /speak request together with its progress and result.

    Created by the /speak handler and updated by the worker thread while the
    job runs.
    """
    id: str  # uuid4 string assigned by /speak
    text: str
    lang: str
    t3_model: str
    voice: Optional[str]  # server-side reference WAV path for voice cloning, or None
    speed: float
    audio_device: Optional[str]  # may be None; the worker then falls back to "pulse"
    max_len: int
    save_wav: bool
    output_path: Optional[str]
    pronunciation_dict: Optional[dict]
    session_id: Optional[str]
    keep_audio: bool = False  # keep the WAV in _AUDIO_CACHE_DIR for GET /audio/{job_id}
    status: JobStatus = field(default=JobStatus.pending)
    text_preview: str = field(default="")  # first 80 characters of text, set by the worker
    chunks_total: int = 0  # number of non-empty chunks after preprocessing
    chunks_done: int = 0  # chunks generated and handed to playback so far
    error: Optional[str] = None  # str(exception) when status is error
    audio_path: Optional[str] = None  # set when keep_audio=True and the job finished; reset after download
# ---------------------------------------------------------------------------
# Worker-Thread
# ---------------------------------------------------------------------------
# Jobs waiting to be synthesized; consumed by the single worker thread.
_job_queue: queue.Queue[SpeakJob] = queue.Queue()
# Guards _current_job and _recent_jobs, which the worker and the request
# handlers both access.
_state_lock = threading.Lock()
_current_job: Optional[SpeakJob] = None
# Finished jobs, oldest first, kept to at most _MAX_RECENT entries.
_MAX_RECENT = 20
_recent_jobs: list[SpeakJob] = []
def _worker() -> None:
    """Process jobs from ``_job_queue`` forever, one at a time.

    For each job the worker:

    * loads or reuses the model;
    * cleans, splits and preprocesses the text into chunks;
    * generates audio chunk by chunk, handing each chunk to a
      ``tts.PlaybackWorker``;
    * optionally saves the concatenated audio to ``output_path`` and/or the
      keep_audio cache.

    Exceptions are recorded on the job (status ``error``) instead of stopping
    the thread. Finished jobs are appended to ``_recent_jobs``. When that list
    overflows, the oldest job is evicted and its cached WAV is deleted,
    because GET /audio/{job_id} can no longer find it.
    """
    global _current_job
    while True:
        job = _job_queue.get()
        with _state_lock:
            _current_job = job
            job.status = JobStatus.running
            # Set before any work so /status shows a preview even when model
            # loading or preprocessing fails.
            job.text_preview = job.text[:80]
        tts.clear_stop()
        evicted: Optional[SpeakJob] = None
        try:
            model, model_kind, sr = _get_or_load_model(job.lang, job.t3_model)
            raw = tts.clean_raw_text(job.text)
            raw_chunks = tts.split_into_sentences(raw, max_len=job.max_len)
            chunks = [
                tts.preprocess_tts_text(c, lang=job.lang,
                                        pronunciation_dict=job.pronunciation_dict)
                for c in raw_chunks
            ]
            # Drop chunks that became empty after preprocessing.
            chunks = [c for c in chunks if c.strip()]
            job.chunks_total = len(chunks)
            playback = tts.PlaybackWorker(
                sample_rate=sr,
                device=job.audio_device or "pulse",
                speed=job.speed,
                stop_event=tts.STOP_REQUESTED,
            )
            playback.start()
            wavs: list[torch.Tensor] = []
            try:
                for chunk in chunks:
                    # Checked between chunks so /stop (or interrupt=true) ends the job early.
                    if tts.stop_requested():
                        break
                    wav = tts.generate_chunk(model, model_kind, chunk, job.lang, job.voice)
                    wavs.append(wav)
                    playback.put(wav)
                    job.chunks_done += 1
            finally:
                playback.stop()
            if wavs:
                final = wavs[0] if len(wavs) == 1 else torch.cat(wavs, dim=-1)
                if job.save_wav and job.output_path:
                    out = Path(job.output_path)
                    out.parent.mkdir(parents=True, exist_ok=True)
                    ta.save(str(out), final, sr)
                if job.keep_audio:
                    cache_path = _AUDIO_CACHE_DIR / f"{job.id}.wav"
                    ta.save(str(cache_path), final, sr)
                    job.audio_path = str(cache_path)
            job.status = (
                JobStatus.cancelled if tts.stop_requested() else JobStatus.done
            )
        except Exception as exc:  # noqa: BLE001
            job.status = JobStatus.error
            job.error = str(exc)
        finally:
            with _state_lock:
                _current_job = None
                _recent_jobs.append(job)
                if len(_recent_jobs) > _MAX_RECENT:
                    evicted = _recent_jobs.pop(0)
            _job_queue.task_done()
        # GET /audio only looks in _recent_jobs, so an evicted job's cached WAV
        # could never be downloaded again; delete it (outside the lock) instead
        # of leaving it on disk forever.
        if evicted is not None and evicted.audio_path:
            try:
                os.unlink(evicted.audio_path)
            except OSError:
                pass
            evicted.audio_path = None
# A single worker thread consumes _job_queue, so jobs run one after another;
# daemon=True means the thread does not keep the process alive on shutdown.
_worker_thread = threading.Thread(
    name="tts-worker",
    target=_worker,
    daemon=True,
)
_worker_thread.start()
# ---------------------------------------------------------------------------
# API-Modelle
# ---------------------------------------------------------------------------
class SpeakRequest(BaseModel):
    """Request body of POST /speak; field constraints are enforced by pydantic."""
    # NOTE(review): voice and output_path are server-side filesystem paths taken
    # from the client. With the documented `--host 0.0.0.0` any network client
    # can probe file existence (voice) and choose where WAVs are written
    # (output_path). Consider restricting them to an allowed directory.
    text: str = Field(min_length=1, max_length=4000)  # text to speak, 1-4000 characters
    lang: str = "de"  # must be in tts.SUPPORTED_LANGS (checked by /speak)
    voice: Optional[str] = None  # path to a reference WAV for voice cloning
    interrupt: bool = False  # stop the running job and drop queued jobs first
    speed: float = Field(default=1.0, ge=0.5, le=2.0)  # forwarded to tts.PlaybackWorker
    t3_model: str = "v3"  # model variant passed to tts.load_model
    audio_device: Optional[str] = None  # playback device; the worker falls back to "pulse"
    max_len: int = Field(default=400, ge=100, le=1000)  # forwarded to tts.split_into_sentences (presumably max chunk length in characters)
    save_wav: bool = False  # also write the WAV to output_path (skipped if output_path is unset)
    output_path: Optional[str] = None  # destination file for save_wav
    session_id: Optional[str] = None  # stored on the job; not read elsewhere in the visible code
    pronunciation_dict: Optional[dict] = None  # forwarded to tts.preprocess_tts_text
    keep_audio: bool = False  # keep the WAV in the cache for GET /audio/{job_id}
def _job_to_dict(j: SpeakJob) -> dict:
    """Project a job onto the fields exposed by the JSON API.

    Internal fields (full text, voice, paths, ...) are not included.
    """
    exposed = ("id", "status", "lang", "text_preview",
               "chunks_total", "chunks_done", "error")
    return {name: getattr(j, name) for name in exposed}
def _drain_queue() -> None:
    """Discard every job still waiting in ``_job_queue``.

    The job the worker is currently running is not affected. Each removed item
    is acknowledged with ``task_done()`` so the queue's unfinished-task count
    stays balanced.
    """
    while True:
        if _job_queue.empty():
            return
        try:
            _job_queue.get_nowait()
        except queue.Empty:
            # Another consumer emptied the queue between the check and the get.
            return
        _job_queue.task_done()
# ---------------------------------------------------------------------------
# FastAPI-App
# ---------------------------------------------------------------------------
# ASGI application served by uvicorn (see the start command in the module docstring).
app = FastAPI(version="1.0", title="Chatterbox TTS Service")
@app.get("/health")
def health():
    """Report that the service is up, plus the compute device it uses."""
    return dict(status="ok", device=_DEVICE)
@app.get("/voices")
def voices():
    """List the supported languages and explain how voice cloning is requested.

    Returns:
        dict with the sorted ``languages`` and a human-readable ``note``.
    """
    return {
        "languages": sorted(tts.SUPPORTED_LANGS),
        # The range separator was missing ("1030s"), which read as a
        # 1030-second recording; the hint is a 10-30 s reference WAV.
        "note": "Voice cloning via 'voice' field (WAV-Pfad, 10-30s Aufnahme)",
    }
@app.post("/speak")
def speak(req: SpeakRequest):
    """Validate a speak request and enqueue it as a new job.

    With ``interrupt=true`` the running job is signalled to stop and all
    still-queued jobs are discarded before the new job is enqueued.

    Returns:
        dict with ``job_id``, the initial ``status`` and ``queue_position``
        (queue size right after enqueueing; the worker may already have taken
        the job, so this is only a snapshot).

    Raises:
        HTTPException(422): unsupported language, or ``voice`` is not an
            existing regular file.
    """
    if req.lang not in tts.SUPPORTED_LANGS:
        raise HTTPException(status_code=422,
                            detail=f"Sprache nicht unterstützt: {req.lang}")
    # is_file() rather than exists(): a directory passed exists() and would
    # presumably only fail later inside the worker (job status "error").
    # NOTE(review): voice/output_path are client-chosen server paths; when the
    # service is bound to 0.0.0.0 any network client can probe file existence
    # and pick where WAVs are written. Consider restricting them.
    if req.voice and not Path(req.voice).is_file():
        raise HTTPException(status_code=422,
                            detail=f"Voice-Datei nicht gefunden: {req.voice}")
    if req.interrupt:
        tts.request_stop()
        _drain_queue()
    job = SpeakJob(
        id=str(uuid.uuid4()),
        text=req.text,
        lang=req.lang,
        t3_model=req.t3_model,
        voice=req.voice,
        speed=req.speed,
        audio_device=req.audio_device,
        max_len=req.max_len,
        save_wav=req.save_wav,
        output_path=req.output_path,
        pronunciation_dict=req.pronunciation_dict,
        session_id=req.session_id,
        keep_audio=req.keep_audio,
    )
    _job_queue.put(job)
    return {
        "job_id": job.id,
        "status": job.status,
        "queue_position": _job_queue.qsize(),
    }
@app.post("/stop")
def stop():
    """Signal the running job to stop and discard every queued job."""
    tts.request_stop()
    _drain_queue()
    return dict(stopped=True)
@app.post("/pause")
def pause():
    """Pause audio output via tts.request_pause (no queued work is dropped)."""
    tts.request_pause()
    return dict(paused=True)
@app.post("/resume")
def resume():
    """Resume audio output previously paused via /pause."""
    tts.request_resume()
    return dict(resumed=True)
@app.get("/audio/{job_id}")
def download_audio(job_id: str):
    """Download the finished WAV of a job created with ``keep_audio=true``.

    The file is deleted after the response has been sent, so each WAV can be
    downloaded once.

    Returns:
        FileResponse (audio/wav) named ``tts_<first 8 id chars>.wav``.

    Raises:
        HTTPException(202): the job is still queued or running; retry later.
        HTTPException(404): unknown job id (or no longer among the recent
            jobs), audio was not requested, or the file is already gone.
    """
    from starlette.background import BackgroundTask

    still_running = "Job läuft noch — bitte später erneut abrufen."

    # Jobs still waiting in the queue are neither current nor recent; without
    # this check they were reported as 404 instead of 202. The queue is
    # snapshotted first because a job only moves queue -> current -> recent,
    # so it cannot slip between the two snapshots.
    # NOTE: Queue.mutex / Queue.queue are CPython implementation attributes.
    with _job_queue.mutex:
        queued_ids = {j.id for j in _job_queue.queue}
    with _state_lock:
        cur = _current_job
        recent = list(_recent_jobs)

    if job_id in queued_ids or (cur is not None and cur.id == job_id):
        raise HTTPException(status_code=202, detail=still_running)

    job = next((j for j in recent if j.id == job_id), None)
    if job is None:
        raise HTTPException(status_code=404, detail=f"Job nicht gefunden: {job_id}")
    if job.status in (JobStatus.pending, JobStatus.running):
        raise HTTPException(status_code=202, detail=still_running)

    if not job.audio_path or not Path(job.audio_path).exists():
        if not job.keep_audio:
            raise HTTPException(
                status_code=404,
                detail="Keine Audio-Datei vorhanden. Bitte /speak mit keep_audio=true aufrufen.",
            )
        raise HTTPException(status_code=404, detail="Audio-Datei nicht mehr vorhanden.")

    audio_path = Path(job.audio_path)

    def cleanup_after_send() -> None:
        # Best effort: the file may already be gone (e.g. concurrent download).
        try:
            os.unlink(audio_path)
            job.audio_path = None
        except OSError:
            pass

    # Runs cleanup_after_send once the file has been sent.
    return FileResponse(
        path=str(audio_path),
        media_type="audio/wav",
        filename=f"tts_{job_id[:8]}.wav",
        background=BackgroundTask(cleanup_after_send),
    )
@app.get("/status")
def status():
    """Report the running job, the queue length and recent jobs (newest first)."""
    with _state_lock:
        running = _current_job
        history = _recent_jobs[::-1]
    return {
        "current_job": None if running is None else _job_to_dict(running),
        "queue_length": _job_queue.qsize(),
        "recent_jobs": [_job_to_dict(j) for j in history],
    }