chatterbox-tts-cli/tts_service.py

282 lines
8 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""
Chatterbox TTS – local HTTP service.

Start:
    uvicorn tts_service:app --host 127.0.0.1 --port 8000

Endpoints:
    POST /speak   – enqueue text for speech output
    POST /stop    – abort the running output and clear the queue
    GET  /health  – service status
    GET  /status  – current job + queue length
    GET  /voices  – supported languages
"""
from __future__ import annotations
import queue
import sys
import threading
import uuid
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Optional
# CLI-Modul aus demselben Verzeichnis laden
sys.path.insert(0, str(Path(__file__).parent))
import chatterbox_cli_v4 as tts # noqa: E402
import torch
import torchaudio as ta
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# ---------------------------------------------------------------------------
# Determine the device once, when the module is imported
# ---------------------------------------------------------------------------
_DEVICE = tts.get_device(None)
# ---------------------------------------------------------------------------
# Model cache: (lang, t3_model) -> (model, model_kind, sr)
# ---------------------------------------------------------------------------
_model_cache: dict[tuple, tuple] = {}
# Held by _get_or_load_model while it reads or fills _model_cache.
_model_lock = threading.Lock()
def _get_or_load_model(lang: str, t3_model: str) -> tuple:
    """Return the cached model entry for ``(lang, t3_model)``, loading it on first use.

    The lookup and the load both happen while ``_model_lock`` is held, so a
    given key is loaded at most once. The returned value is whatever
    ``tts.load_model`` produced; callers unpack it as
    ``(model, model_kind, sr)``.
    """
    cache_key = (lang, t3_model)
    with _model_lock:
        if cache_key in _model_cache:
            return _model_cache[cache_key]
        loaded = tts.load_model(lang, _DEVICE, t3_model=t3_model)
        _model_cache[cache_key] = loaded
        return loaded
# ---------------------------------------------------------------------------
# Job data model
# ---------------------------------------------------------------------------
class JobStatus(str, Enum):
    """Lifecycle state of a speak job.

    The ``str`` mixin makes every member compare equal to its plain string
    value.
    """

    pending = "pending"
    running = "running"
    done = "done"
    cancelled = "cancelled"
    error = "error"
@dataclass
class SpeakJob:
    """A single text-to-speech request plus its processing state."""

    # --- request parameters (required, supplied when the job is created) ---
    id: str
    text: str
    lang: str
    t3_model: str
    voice: Optional[str]
    speed: float
    audio_device: str
    max_len: int
    save_wav: bool
    output_path: Optional[str]
    pronunciation_dict: Optional[dict]
    session_id: Optional[str]

    # --- processing state (defaults describe a freshly created job) ---
    status: JobStatus = JobStatus.pending
    text_preview: str = ""
    chunks_total: int = 0
    chunks_done: int = 0
    error: Optional[str] = None
# ---------------------------------------------------------------------------
# Worker thread state
# ---------------------------------------------------------------------------
# Jobs waiting to be picked up by _worker.
_job_queue: queue.Queue[SpeakJob] = queue.Queue()
# Job the worker is currently processing, or None when it is idle.
_current_job: Optional[SpeakJob] = None
# Held while _current_job / _recent_jobs are read or modified.
_state_lock = threading.Lock()
# Finished jobs, oldest first; trimmed to at most _MAX_RECENT entries.
_recent_jobs: list[SpeakJob] = []
_MAX_RECENT = 20
def _prepare_chunks(job: SpeakJob) -> list:
    """Clean, split and preprocess the job text; return the non-blank chunks."""
    cleaned = tts.clean_raw_text(job.text)
    pieces = tts.split_into_sentences(cleaned, max_len=job.max_len)
    processed = [
        tts.preprocess_tts_text(
            piece, lang=job.lang, pronunciation_dict=job.pronunciation_dict
        )
        for piece in pieces
    ]
    return [chunk for chunk in processed if chunk.strip()]


def _save_wavs(wavs: list, output_path: str, sample_rate) -> None:
    """Join the generated chunks along the last dimension and write one WAV file."""
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    combined = wavs[0] if len(wavs) == 1 else torch.cat(wavs, dim=-1)
    ta.save(str(target), combined, sample_rate)


def _run_job(job: SpeakJob) -> None:
    """Synthesize and play one job chunk by chunk, then set its final status.

    Sets ``job.status`` to ``cancelled`` if a stop was requested, otherwise
    to ``done``. Any exception propagates to the caller.
    """
    model, model_kind, sr = _get_or_load_model(job.lang, job.t3_model)
    chunks = _prepare_chunks(job)
    job.chunks_total = len(chunks)

    playback = tts.PlaybackWorker(
        sample_rate=sr,
        device=job.audio_device,
        speed=job.speed,
        stop_event=tts.STOP_REQUESTED,
    )
    playback.start()

    # Only retain generated audio when it will actually be written to disk;
    # otherwise every chunk tensor would stay alive until the job ends.
    keep_audio = bool(job.save_wav and job.output_path)
    wavs: list = []
    try:
        for chunk in chunks:
            if tts.stop_requested():
                break
            wav = tts.generate_chunk(model, model_kind, chunk, job.lang, job.voice)
            if keep_audio:
                wavs.append(wav)
            playback.put(wav)
            job.chunks_done += 1
    finally:
        # Always shut the playback worker down, even if generation failed.
        playback.stop()

    if wavs:
        # A cancelled job still saves whatever was generated before the stop.
        _save_wavs(wavs, job.output_path, sr)

    job.status = JobStatus.cancelled if tts.stop_requested() else JobStatus.done


def _worker() -> None:
    """Process queued jobs one at a time, forever.

    For every job taken from ``_job_queue`` the worker publishes it as
    ``_current_job``, clears the stop flag, runs the job, and finally moves
    it into ``_recent_jobs`` (trimmed to ``_MAX_RECENT`` entries). A failing
    job is recorded with status ``error``; it never terminates the loop.
    """
    global _current_job
    while True:
        job = _job_queue.get()
        with _state_lock:
            _current_job = job
            job.status = JobStatus.running
            # Set the preview up front so it is available while the model is
            # loading and also when the job fails early.
            job.text_preview = job.text[:80]
        # NOTE(review): a stop request that arrives between get() above and
        # this call is discarded — confirm whether that window matters.
        tts.clear_stop()
        try:
            _run_job(job)
        except Exception as exc:  # noqa: BLE001 - worker boundary, error is recorded on the job
            job.status = JobStatus.error
            # Some exceptions have an empty message; fall back to repr so the
            # reported error is never blank.
            job.error = str(exc) or repr(exc)
        finally:
            with _state_lock:
                _current_job = None
                _recent_jobs.append(job)
                if len(_recent_jobs) > _MAX_RECENT:
                    _recent_jobs.pop(0)
            _job_queue.task_done()
# Start the background worker as soon as the module is imported. It is a
# daemon thread, so it does not keep the interpreter alive on shutdown.
_worker_thread = threading.Thread(target=_worker, daemon=True, name="tts-worker")
_worker_thread.start()
# ---------------------------------------------------------------------------
# API models
# ---------------------------------------------------------------------------
class SpeakRequest(BaseModel):
    """Request body of ``POST /speak``."""

    # Text to speak; required, 1 to 4000 characters.
    text: str = Field(..., min_length=1, max_length=4000)
    lang: str = "de"
    # Optional path to a voice file; /speak rejects paths that do not exist.
    voice: Optional[str] = None
    # When true, /speak requests a stop and empties the queue first.
    interrupt: bool = False
    speed: float = Field(default=1.0, ge=0.5, le=2.0)
    t3_model: str = "v3"
    audio_device: str = "pulse"
    # Passed to the sentence splitter as max_len.
    max_len: int = Field(default=400, ge=50, le=1000)
    # The worker writes a WAV only when save_wav is set AND output_path is given.
    save_wav: bool = False
    output_path: Optional[str] = None
    session_id: Optional[str] = None
    pronunciation_dict: Optional[dict] = None
def _job_to_dict(j: SpeakJob) -> dict:
    """Return the public view of a job: a dict with a fixed set of its attributes."""
    exposed = (
        "id",
        "status",
        "lang",
        "text_preview",
        "chunks_total",
        "chunks_done",
        "error",
    )
    return {attr: getattr(j, attr) for attr in exposed}
def _drain_queue() -> None:
    """Remove every waiting job from the queue and record it as cancelled.

    Each drained job gets status ``cancelled`` and is appended to
    ``_recent_jobs`` (trimmed to ``_MAX_RECENT``), so it stays visible via
    ``/status`` instead of disappearing while still marked ``pending``.
    The job currently being processed is not in the queue and is untouched.
    """
    dropped: list[SpeakJob] = []
    while True:
        try:
            job = _job_queue.get_nowait()
        except queue.Empty:
            break
        job.status = JobStatus.cancelled
        # The worker never saw this job, so fill in the preview here.
        job.text_preview = job.text[:80]
        dropped.append(job)
        # Balance the get so Queue.join() accounting stays correct.
        _job_queue.task_done()

    if not dropped:
        return
    with _state_lock:
        _recent_jobs.extend(dropped)
        overflow = len(_recent_jobs) - _MAX_RECENT
        if overflow > 0:
            # Drop the oldest entries in place (keeps the same list object).
            del _recent_jobs[:overflow]
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(title="Chatterbox TTS Service", version="1.0")
@app.get("/health")
def health():
    """Report that the service is up, together with the device chosen at import."""
    payload = {"status": "ok"}
    payload["device"] = _DEVICE
    return payload
@app.get("/voices")
def voices():
    """List the supported language codes (sorted) plus a hint on voice cloning."""
    return {
        "languages": sorted(tts.SUPPORTED_LANGS),
        # The range dash was missing ("1030s"); the hint means a recording of
        # 10 to 30 seconds.
        "note": "Voice cloning via 'voice' field (WAV-Pfad, 10–30s Aufnahme)",
    }
@app.post("/speak")
def speak(req: SpeakRequest):
    """Validate the request and enqueue it as a new speak job.

    With ``interrupt`` set, a stop is requested and all waiting jobs are
    removed before the new job is queued.

    Returns:
        A dict with the new ``job_id``, its ``status`` and the queue size
        right after insertion as ``queue_position``.

    Raises:
        HTTPException: 422 if the language is not supported or ``voice``
            does not point to an existing file.
    """
    if req.lang not in tts.SUPPORTED_LANGS:
        raise HTTPException(
            status_code=422,
            detail=f"Sprache nicht unterstützt: {req.lang}",
        )
    # is_file() rather than exists(): a directory is not a usable voice file.
    # NOTE(review): `voice` and `output_path` are client-supplied filesystem
    # paths used as-is; acceptable only while the service is bound to localhost.
    if req.voice and not Path(req.voice).is_file():
        raise HTTPException(
            status_code=422,
            detail=f"Voice-Datei nicht gefunden: {req.voice}",
        )

    if req.interrupt:
        tts.request_stop()
        _drain_queue()

    job = SpeakJob(
        id=str(uuid.uuid4()),
        text=req.text,
        lang=req.lang,
        t3_model=req.t3_model,
        voice=req.voice,
        speed=req.speed,
        audio_device=req.audio_device,
        max_len=req.max_len,
        save_wav=req.save_wav,
        output_path=req.output_path,
        pronunciation_dict=req.pronunciation_dict,
        session_id=req.session_id,
        # Fill the preview at creation so it exists before the worker runs.
        text_preview=req.text[:80],
    )
    _job_queue.put(job)
    return {
        "job_id": job.id,
        "status": job.status,
        "queue_position": _job_queue.qsize(),
    }
@app.post("/stop")
def stop():
    """Request a stop of the running output and empty the job queue."""
    tts.request_stop()
    _drain_queue()
    acknowledgement = {"stopped": True}
    return acknowledgement
@app.get("/status")
def status():
    """Return the current job, the queue length and the recent jobs, newest first."""
    # Take a consistent snapshot under the lock; build the response outside it.
    with _state_lock:
        active = _current_job
        newest_first = _recent_jobs[::-1]

    current = None if not active else _job_to_dict(active)
    return {
        "current_job": current,
        "queue_length": _job_queue.qsize(),
        "recent_jobs": [_job_to_dict(entry) for entry in newest_first],
    }