#!/usr/bin/env python3 """ Chatterbox TTS – lokaler HTTP-Service Start: uvicorn tts_service:app --host 0.0.0.0 --port 9999 Endpunkte: POST /speak – Text in Warteschlange einreihen POST /stop – laufende Ausgabe abbrechen, Queue leeren GET /health – Service-Status GET /status – aktueller Job + Queue-Länge GET /voices – unterstützte Sprachen """ from __future__ import annotations import queue import sys import threading import uuid from dataclasses import dataclass, field from enum import Enum from pathlib import Path from typing import Optional # CLI-Modul aus demselben Verzeichnis laden sys.path.insert(0, str(Path(__file__).parent)) import chatterbox_cli_v4 as tts # noqa: E402 import torch import torchaudio as ta from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field # --------------------------------------------------------------------------- # Gerät einmalig bestimmen # --------------------------------------------------------------------------- _DEVICE = tts.get_device(None) # --------------------------------------------------------------------------- # Modell-Cache (lang, t3_model) → (model, model_kind, sr) # --------------------------------------------------------------------------- _model_cache: dict[tuple, tuple] = {} _model_lock = threading.Lock() def _get_or_load_model(lang: str, t3_model: str) -> tuple: key = (lang, t3_model) with _model_lock: if key not in _model_cache: _model_cache[key] = tts.load_model(lang, _DEVICE, t3_model=t3_model) return _model_cache[key] # Optionaler Warmup: TTS_PRELOAD_LANG=de lädt das Modell beim Service-Start, # damit der erste Request keine Modell-Ladezeit hat. 
import os  # noqa: E402  (only needed for the preload switch right below)

_PRELOAD_LANG = os.environ.get("TTS_PRELOAD_LANG")
if _PRELOAD_LANG:
    _preload_t3 = os.environ.get("TTS_PRELOAD_T3", "v3")
    try:
        _get_or_load_model(_PRELOAD_LANG, _preload_t3)
        print(f"[chatterbox-tts] Modell vorgeladen: lang={_PRELOAD_LANG}, t3={_preload_t3}")
    except Exception as _e:  # noqa: BLE001 - warmup is best-effort, never fatal
        print(f"[chatterbox-tts] Warnung: Warmup fehlgeschlagen: {_e}")


# ---------------------------------------------------------------------------
# Job data model
# ---------------------------------------------------------------------------
class JobStatus(str, Enum):
    """Lifecycle states of a speak job (string-valued so they serialise as-is)."""

    pending = "pending"
    running = "running"
    done = "done"
    cancelled = "cancelled"
    error = "error"


@dataclass
class SpeakJob:
    """One queued speech request plus its progress/result bookkeeping."""

    id: str
    text: str
    lang: str
    t3_model: str
    voice: Optional[str]
    speed: float
    # May be None; the worker then falls back to "pulse".
    audio_device: Optional[str]
    max_len: int
    save_wav: bool
    output_path: Optional[str]
    pronunciation_dict: Optional[dict]
    session_id: Optional[str]
    status: JobStatus = JobStatus.pending
    text_preview: str = ""
    chunks_total: int = 0
    chunks_done: int = 0
    error: Optional[str] = None

    def __post_init__(self) -> None:
        # Fill the preview up front so that jobs which fail early or are
        # dropped from the queue still show what they were about.
        if not self.text_preview:
            self.text_preview = self.text[:80]


# ---------------------------------------------------------------------------
# Worker thread
# ---------------------------------------------------------------------------
_job_queue: queue.Queue[SpeakJob] = queue.Queue()
_current_job: Optional[SpeakJob] = None
_state_lock = threading.Lock()  # guards _current_job and _recent_jobs
_recent_jobs: list[SpeakJob] = []
_MAX_RECENT = 20


def _record_finished(job: SpeakJob) -> None:
    """Append *job* to the history, keeping at most ``_MAX_RECENT`` entries.

    The caller must hold ``_state_lock``.
    """
    _recent_jobs.append(job)
    overflow = len(_recent_jobs) - _MAX_RECENT
    if overflow > 0:
        del _recent_jobs[:overflow]


def _run_job(job: SpeakJob) -> None:
    """Synthesize and play *job* chunk by chunk, then set its final status.

    Any exception propagates to the caller, which records it on the job.
    """
    model, model_kind, sr = _get_or_load_model(job.lang, job.t3_model)

    cleaned = tts.clean_raw_text(job.text)
    sentences = tts.split_into_sentences(cleaned, max_len=job.max_len)
    prepared = (
        tts.preprocess_tts_text(
            s, lang=job.lang, pronunciation_dict=job.pronunciation_dict
        )
        for s in sentences
    )
    # Preprocessing can reduce a chunk to whitespace; skip those.
    chunks = [c for c in prepared if c.strip()]
    job.chunks_total = len(chunks)

    # Only keep generated audio around when it will actually be written out;
    # otherwise long texts would pin every tensor until the job ends.
    keep_audio = bool(job.save_wav and job.output_path)
    wavs: list[torch.Tensor] = []

    playback = tts.PlaybackWorker(
        sample_rate=sr,
        device=job.audio_device or "pulse",
        speed=job.speed,
        stop_event=tts.STOP_REQUESTED,
    )
    playback.start()
    try:
        for chunk in chunks:
            if tts.stop_requested():
                break
            wav = tts.generate_chunk(model, model_kind, chunk, job.lang, job.voice)
            if keep_audio:
                wavs.append(wav)
            playback.put(wav)
            job.chunks_done += 1
    finally:
        # Always shut the playback worker down, even if generation failed.
        playback.stop()

    if keep_audio and wavs:
        # Note: this also runs after a stop request, so a cancelled job
        # leaves a WAV with the chunks generated so far.
        out = Path(job.output_path)
        out.parent.mkdir(parents=True, exist_ok=True)
        final = wavs[0] if len(wavs) == 1 else torch.cat(wavs, dim=-1)
        ta.save(str(out), final, sr)

    job.status = JobStatus.cancelled if tts.stop_requested() else JobStatus.done


def _worker() -> None:
    """Process queued jobs one at a time, forever (runs in a daemon thread)."""
    global _current_job
    while True:
        job = _job_queue.get()
        with _state_lock:
            _current_job = job
            job.status = JobStatus.running
        # NOTE(review): presumably resets the flag set by tts.request_stop();
        # a stop arriving between get() above and this call would be lost —
        # confirm against the CLI module.
        tts.clear_stop()
        try:
            _run_job(job)
        except Exception as exc:  # noqa: BLE001 - the worker must survive any job
            job.status = JobStatus.error
            # Some exceptions stringify to "", fall back to repr for those.
            job.error = str(exc) or repr(exc)
        finally:
            # One critical section so /status never sees the job in neither
            # "current" nor "recent".
            with _state_lock:
                _current_job = None
                _record_finished(job)
            _job_queue.task_done()


_worker_thread = threading.Thread(target=_worker, daemon=True, name="tts-worker")
_worker_thread.start()


# ---------------------------------------------------------------------------
# API models
# ---------------------------------------------------------------------------
class SpeakRequest(BaseModel):
    """Request body of ``POST /speak``."""

    text: str = Field(min_length=1, max_length=4000)
    lang: str = "de"
    voice: Optional[str] = None
    interrupt: bool = False
    speed: float = Field(default=1.0, ge=0.5, le=2.0)
    t3_model: str = "v3"
    audio_device: Optional[str] = None
    max_len: int = Field(default=400, ge=100, le=1000)
    save_wav: bool = False
    output_path: Optional[str] = None
    session_id: Optional[str] = None
    pronunciation_dict: Optional[dict] = None


def _job_to_dict(j: SpeakJob) -> dict:
    """Return the public, JSON-friendly view of a job used by ``/status``."""
    return {
        "id": j.id,
        "status": j.status,
        "lang": j.lang,
        "text_preview": j.text_preview,
        "chunks_total": j.chunks_total,
        "chunks_done": j.chunks_done,
        "error": j.error,
    }


def _drain_queue() -> None:
    """Remove every job still waiting in the queue.

    Dropped jobs are marked ``cancelled`` and recorded in the recent-jobs
    history, so a client holding their job id can still see what happened to
    them. The job currently being processed is not affected. Must not be
    called while holding ``_state_lock``.
    """
    while True:
        try:
            dropped = _job_queue.get_nowait()
        except queue.Empty:
            return
        dropped.status = JobStatus.cancelled
        with _state_lock:
            _record_finished(dropped)
        # Balance the get() so Queue.join() bookkeeping stays consistent.
        _job_queue.task_done()
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(title="Chatterbox TTS Service", version="1.0")


@app.get("/health")
def health():
    """Report that the service is up and which device was selected at start."""
    return {"status": "ok", "device": _DEVICE}


@app.get("/voices")
def voices():
    """List the language codes accepted by ``/speak``."""
    return {
        "languages": sorted(tts.SUPPORTED_LANGS),
        "note": "Voice cloning via 'voice' field (WAV-Pfad, 10–30s Aufnahme)",
    }


@app.post("/speak")
def speak(req: SpeakRequest):
    """Validate the request and enqueue it as a new job.

    With ``interrupt`` set, a stop is requested and the waiting queue is
    emptied before the new job is added. Responds with HTTP 422 when the
    language is not supported or ``voice`` does not point to a file.
    """
    if req.lang not in tts.SUPPORTED_LANGS:
        raise HTTPException(status_code=422, detail=f"Sprache nicht unterstützt: {req.lang}")
    # is_file() rather than exists(): a directory is not a usable voice
    # sample and would otherwise only fail later inside the worker.
    if req.voice and not Path(req.voice).is_file():
        raise HTTPException(status_code=422, detail=f"Voice-Datei nicht gefunden: {req.voice}")

    # NOTE(security): `voice` and `output_path` are client-supplied filesystem
    # paths that are used as-is; restrict them if this service is reachable
    # by untrusted clients.

    if req.interrupt:
        tts.request_stop()
        _drain_queue()

    job = SpeakJob(
        id=str(uuid.uuid4()),
        text=req.text,
        lang=req.lang,
        t3_model=req.t3_model,
        voice=req.voice,
        speed=req.speed,
        audio_device=req.audio_device,
        max_len=req.max_len,
        save_wav=req.save_wav,
        output_path=req.output_path,
        pronunciation_dict=req.pronunciation_dict,
        session_id=req.session_id,
    )
    _job_queue.put(job)

    # qsize() is read after put(), so it is only a snapshot: the worker may
    # already have taken the job off the queue.
    return {
        "job_id": job.id,
        "status": job.status,
        "queue_position": _job_queue.qsize(),
    }


@app.post("/stop")
def stop():
    """Request a stop of the running output and empty the waiting queue."""
    tts.request_stop()
    _drain_queue()
    return {"stopped": True}


@app.post("/pause")
def pause():
    """Forward a pause request to the CLI module."""
    tts.request_pause()
    return {"paused": True}


@app.post("/resume")
def resume():
    """Forward a resume request to the CLI module."""
    tts.request_resume()
    return {"resumed": True}


@app.get("/status")
def status():
    """Return the current job, the queue length and recent jobs, newest first."""
    # Copy under the lock, serialise outside it to keep the critical
    # section short.
    with _state_lock:
        current = _current_job
        history = list(_recent_jobs)
    return {
        "current_job": _job_to_dict(current) if current else None,
        "queue_length": _job_queue.qsize(),
        "recent_jobs": [_job_to_dict(j) for j in reversed(history)],
    }