import asyncio
import io
import logging
import re
import uuid
from datetime import datetime
from pathlib import Path
from typing import Optional

import httpx
from dotenv import load_dotenv
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import Response
from pydantic import BaseModel
from pydub import AudioSegment

# Load variables from a .env file (if present) into os.environ so the
# os.getenv call below can see them.
load_dotenv()

import os

# Log to a fixed file and to stderr (StreamHandler's default stream).
# NOTE(review): FileHandler opens its file on construction, so importing this
# module fails if /var/www/tts-service does not exist.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    handlers=[
        logging.FileHandler("/var/www/tts-service/tts.log"),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# ElevenLabs credential; when empty, both POST endpoints answer 500.
API_KEY = os.getenv("ELEVENLABS_API_KEY", "")
# TTS endpoint; the voice ID is appended as the final path segment.
ELEVENLABS_BASE = "https://api.elevenlabs.io/v1/text-to-speech"
# Separator audio files, looked up as <name>.mp3 for ---TRENNER:<name>--- markers.
SEPARATORS_DIR = Path(__file__).parent / "separators"
# One subdirectory per async job, holding result.mp3 or error.txt.
JOBS_DIR = Path("/tmp/tts-jobs")
JOBS_DIR.mkdir(exist_ok=True)
# Every generated MP3 is also written here, named by timestamp.
# NOTE(review): mkdir without parents=True fails at import if /var/www is missing.
ARCHIVE_DIR = Path("/var/www/tts-archive")
ARCHIVE_DIR.mkdir(exist_ok=True)
# Maximum characters per TTS request produced by chunk_text.
# NOTE(review): presumably chosen to stay under the ElevenLabs per-request
# character limit of the models in use — confirm against the API docs.
MAX_CHUNK_SIZE = 4500
# Audio tags that clean_text strips for non-v3 models (v3 keeps them).
AUDIO_TAG_PATTERN = re.compile(r"\[(?:pause|sighs|excited|clears throat)\]")
# Separator marker "---TRENNER:<name>---" ("Trenner" is German for
# "separator"); group 1 captures the separator name.
TRENNER_PATTERN = re.compile(r"---TRENNER:(.*?)---")

app = FastAPI(title="TTS Service")


class TTSRequest(BaseModel):
    """Request body for POST /api/tts and POST /api/tts/async."""

    # Text to synthesize; may contain ---TRENNER:<name>--- markers and audio tags.
    script: str
    # ElevenLabs voice ID, appended to the TTS endpoint URL.
    voice_id: str
    # Model IDs starting with "eleven_v3" take the v3 path (see is_v3).
    model_id: str = "eleven_v3"
    # Forwarded as "speed" in each ElevenLabs request body.
    speed: float = 1.0
    # Sent in voice_settings for every model.
    stability: float = 0.5
    # The next three are sent in voice_settings for non-v3 models only.
    similarity_boost: Optional[float] = 0.75
    style: Optional[float] = 0.5
    use_speaker_boost: Optional[bool] = True
    # Optional text prepended (joined with a single space) to the script.
    greeting: Optional[str] = None
    # Async endpoint only: receives the MP3, or a JSON error payload, via POST.
    callback_url: Optional[str] = None


def is_v3(model_id: str) -> bool:
    """Tell whether ``model_id`` belongs to the ElevenLabs v3 model family.

    Any ID carrying the ``eleven_v3`` prefix counts, so suffixed variants are
    handled exactly like the base v3 model.
    """
    family_prefix = "eleven_v3"
    return model_id[: len(family_prefix)] == family_prefix


def split_script(script: str) -> list[tuple[str, Optional[str]]]:
    """Cut a script into (text, separator_name) pairs at ``---TRENNER:<name>---`` markers.

    Each returned text is whitespace-stripped and non-empty, in script order.
    Its separator_name is the stripped name of the marker directly after it,
    or None when no marker follows. That is normally only the last text; if
    the script ends with a marker, the last text keeps that marker's name.

    Markers without text in front of them never produce empty segments. Among
    several markers separated only by whitespace, the last one replaces the
    preceding text's separator. A marker at the very start of the script
    (nothing before it) is dropped.
    """
    # With one capture group, split() alternates text and marker name:
    # [text0, name0, text1, name1, ..., textN] — always an odd length.
    pieces = TRENNER_PATTERN.split(script)
    texts = pieces[0::2]
    # Pad with None so the final text (no marker after it) has a partner.
    names: list[Optional[str]] = [name.strip() for name in pieces[1::2]]
    names.append(None)

    segments: list[tuple[str, Optional[str]]] = []
    for raw, name in zip(texts, names):
        body = raw.strip()
        if body:
            segments.append((body, name))
        elif name is not None and segments:
            # Blank gap between two markers: the later marker wins.
            segments[-1] = (segments[-1][0], name)
    return segments


def clean_text(text: str, v3: bool) -> str:
    """Prepare one script part for synthesis.

    Audio tags matched by AUDIO_TAG_PATTERN are removed unless ``v3`` is true.
    Runs of two or more spaces are then collapsed to a single space, and
    leading/trailing whitespace is trimmed.
    """
    without_tags = text if v3 else AUDIO_TAG_PATTERN.sub("", text)
    single_spaced = re.sub(r" {2,}", " ", without_tags)
    return single_spaced.strip()


def _hard_split(text: str, limit: int) -> list[str]:
    """Break ``text`` into pieces of at most ``limit`` chars, preferring whitespace cuts.

    Used for sentences that have no sentence boundary to split on. Each cut is
    made at the last space/newline/tab within the limit. Only when there is
    none does it slice mid-word. Whitespace at a cut is dropped. The result is
    never empty; only its final element may be an empty string.
    """
    pieces: list[str] = []
    rest = text
    while len(rest) > limit:
        window = rest[: limit + 1]
        cut = max(window.rfind(ws) for ws in (" ", "\n", "\t"))
        head = rest[:cut].rstrip() if cut > 0 else ""
        if head:
            rest = rest[cut:].lstrip()
        else:
            # No usable whitespace within the limit: cut mid-word.
            head, rest = rest[:limit], rest[limit:]
        pieces.append(head)
    pieces.append(rest)
    return pieces


def chunk_text(text: str, max_size: Optional[int] = None) -> list[str]:
    """Split text into chunks of at most ``max_size`` characters.

    Packing is greedy:
      1. Paragraphs (separated by a blank line) are joined with "\\n\\n" while
         the running chunk stays within the limit.
      2. A paragraph that is too long on its own is split at sentence
         boundaries (whitespace after ``.``, ``!`` or ``?``), and its
         sentences are joined with single spaces.
      3. A sentence that is still too long is hard-wrapped at whitespace, or
         sliced mid-word as a last resort. Without this step, such a
         sentence used to become one oversized chunk.

    Args:
        text: Text to split.
        max_size: Maximum chunk length in characters. Defaults to
            MAX_CHUNK_SIZE.

    Returns:
        The chunks in order. Text already within the limit (including "") is
        returned unchanged as a one-element list.

    Raises:
        ValueError: if ``max_size`` is smaller than 1.
    """
    limit = MAX_CHUNK_SIZE if max_size is None else max_size
    if limit < 1:
        raise ValueError(f"max_size must be >= 1, got {limit}")
    if len(text) <= limit:
        return [text]

    chunks: list[str] = []
    current = ""

    for para in text.split("\n\n"):
        # +2 accounts for the "\n\n" joiner.
        if len(current) + len(para) + 2 <= limit:
            current = current + "\n\n" + para if current else para
            continue
        if current:
            chunks.append(current)
        if len(para) <= limit:
            current = para
            continue

        # Paragraph alone exceeds the limit: pack its sentences instead.
        current = ""
        for sentence in re.split(r"(?<=[.!?])\s+", para):
            # +1 accounts for the " " joiner.
            if len(current) + len(sentence) + 1 <= limit:
                current = current + " " + sentence if current else sentence
                continue
            if current:
                chunks.append(current)
            if len(sentence) <= limit:
                current = sentence
            else:
                # Emit all full pieces; keep the tail open for further packing.
                *full_pieces, current = _hard_split(sentence, limit)
                chunks.extend(full_pieces)

    if current:
        chunks.append(current)

    return chunks


def build_request_body(
    text: str, req: TTSRequest, prev_text: str = "", next_text: str = ""
) -> dict:
    """Build the JSON body for one ElevenLabs text-to-speech request.

    Args:
        text: The chunk to synthesize.
        req: Client request; supplies the model ID and voice settings.
        prev_text: Text preceding this chunk. Non-v3 models only: its last
            200 characters are sent as ``previous_text``.
        next_text: Text following this chunk. Non-v3 models only: its first
            200 characters are sent as ``next_text``.

    Returns:
        The request body as a plain dict (the caller serializes it as JSON).
    """
    # ``speed`` belongs inside voice_settings in the ElevenLabs request schema.
    # A top-level "speed" key is not a documented body field, so the client's
    # speed setting would not take effect there.
    voice_settings: dict = {"stability": req.stability, "speed": req.speed}
    body: dict = {
        "text": text,
        "model_id": req.model_id,
        "voice_settings": voice_settings,
    }

    if not is_v3(req.model_id):
        # Extra voice settings and neighbouring-text context are only sent for
        # non-v3 models; v3 requests carry stability and speed only.
        voice_settings["similarity_boost"] = req.similarity_boost
        voice_settings["style"] = req.style
        voice_settings["use_speaker_boost"] = req.use_speaker_boost
        if prev_text:
            body["previous_text"] = prev_text[-200:]
        if next_text:
            body["next_text"] = next_text[:200]

    return body


async def call_elevenlabs(client: httpx.AsyncClient, voice_id: str, body: dict) -> bytes:
    """Send one text-to-speech request to ElevenLabs and return the MP3 bytes.

    Args:
        client: HTTP client shared by all chunks of a request.
        voice_id: ElevenLabs voice ID. It is percent-encoded into a single
            path segment, so a client-supplied value cannot redirect the call
            (and the service's API key) to a different API path.
        body: JSON request body, as built by ``build_request_body``.

    Returns:
        The raw response body (requested as ``audio/mpeg``).

    Raises:
        HTTPException: 502 when ElevenLabs answers with a non-200 status (the
            first 500 characters of its reply go into the detail), or when
            the request fails at the transport level (timeout, connection
            error, ...).
    """
    from urllib.parse import quote

    url = f"{ELEVENLABS_BASE}/{quote(voice_id, safe='')}"
    headers = {
        "Accept": "audio/mpeg",
        "Content-Type": "application/json",
        "xi-api-key": API_KEY,
    }

    try:
        response = await client.post(url, json=body, headers=headers, timeout=120.0)
    except httpx.HTTPError as exc:
        # Otherwise transport failures escape as raw httpx errors: a generic
        # 500 from /api/tts, and a possibly empty message in a job's error.txt.
        raise HTTPException(
            status_code=502,
            detail=f"ElevenLabs request failed: {type(exc).__name__}: {exc}",
        ) from exc

    if response.status_code != 200:
        raise HTTPException(
            status_code=502,
            detail=f"ElevenLabs API error {response.status_code}: {response.text[:500]}",
        )

    return response.content


async def tts_chunks_v3(
    client: httpx.AsyncClient, chunks: list[str], req: TTSRequest
) -> list[bytes]:
    """Synthesize every chunk concurrently (v3 models), preserving chunk order.

    v3 request bodies carry no previous/next text, so the chunks are
    independent and all requests are started together. ``asyncio.gather``
    returns the results in input order. The first request that raises
    propagates its exception to the caller.

    NOTE(review): there is no concurrency cap — a long part fans out one
    request per chunk at once, which may exceed the account's concurrent
    request limit.
    """
    pending = (
        call_elevenlabs(client, req.voice_id, build_request_body(piece, req))
        for piece in chunks
    )
    return await asyncio.gather(*pending)


async def tts_chunks_legacy(
    client: httpx.AsyncClient, chunks: list[str], req: TTSRequest
) -> list[bytes]:
    """Synthesize chunks one after another for non-v3 models.

    Each request is given its neighbouring chunks as context: the previous
    chunk (empty for the first) and the next chunk (empty for the last).
    ``build_request_body`` trims these to ``previous_text``/``next_text``.
    Results are returned in chunk order.
    """
    before = [""] + chunks[:-1]
    after = chunks[1:] + [""]

    audio_parts: list[bytes] = []
    for piece, preceding, following in zip(chunks, before, after):
        request_body = build_request_body(piece, req, preceding, following)
        audio_parts.append(await call_elevenlabs(client, req.voice_id, request_body))
    return audio_parts


def _separator_path(name: str) -> Optional[Path]:
    """Return the resolved path of ``<name>.mp3`` if it lies inside SEPARATORS_DIR, else None."""
    try:
        base = SEPARATORS_DIR.resolve()
        candidate = (SEPARATORS_DIR / f"{name}.mp3").resolve()
    except (OSError, RuntimeError, ValueError):
        # e.g. an embedded NUL byte or a symlink loop: treat as unusable.
        return None
    return candidate if candidate.is_relative_to(base) else None


def load_separator(name: str) -> AudioSegment:
    """Load the separator ``SEPARATORS_DIR/<name>.mp3``; fall back to 1 s of silence.

    ``name`` comes straight from ``---TRENNER:<name>---`` markers in the
    client-supplied script, so it is untrusted. Without a containment check,
    names like ``../../x`` or absolute paths would load arbitrary ``.mp3``
    files from disk (e.g. other users' archived audio) into the response.
    Any name that resolves outside SEPARATORS_DIR is logged and treated like
    a missing file. Symlinks pointing out of the directory are rejected as
    well.
    """
    filepath = _separator_path(name)
    if filepath is None:
        logger.warning(f"Ignoring separator name outside {SEPARATORS_DIR}: {name!r}")
    elif filepath.is_file():
        return AudioSegment.from_mp3(filepath)
    # Fallback: 1 second silence
    return AudioSegment.silent(duration=1000)


def combine_audio(segments: list[tuple[bytes, Optional[str]]]) -> bytes:
    """Concatenate TTS segments with separator audio and export as one MP3.

    Args:
        segments: (mp3_bytes, separator_name) pairs in playback order. Each
            segment's audio is followed by ``load_separator(separator_name)``
            unless the name is None.

    Returns:
        The combined audio encoded as MP3 at 192 kbit/s.
    """
    combined = AudioSegment.empty()
    # Scripts often reuse the same separator; read and decode each file once.
    separators: dict[str, AudioSegment] = {}

    for audio_bytes, sep_name in segments:
        combined += AudioSegment.from_mp3(io.BytesIO(audio_bytes))
        if sep_name is None:
            continue
        if sep_name not in separators:
            separators[sep_name] = load_separator(sep_name)
        combined += separators[sep_name]

    # Export to MP3
    output = io.BytesIO()
    combined.export(output, format="mp3", bitrate="192k")
    return output.getvalue()


@app.post("/api/tts")
async def generate_tts(req: TTSRequest):
    """Synthesize ``req.script`` and return the finished MP3 in the response.

    Steps:
      1. Prepend the optional greeting and split at ``---TRENNER:<name>---``
         markers.
      2. Clean and chunk each part.
      3. Synthesize the chunks: concurrently for v3 models, sequentially
         with neighbouring-text context otherwise.
      4. Join each part's chunk audio and mix in the separator audio between
         parts.
      5. Archive a copy under ARCHIVE_DIR and return the MP3.

    Raises:
        HTTPException: 500 when ELEVENLABS_API_KEY is not configured; 400 when
            the script is blank or yields no audio. Errors raised by
            ``call_elevenlabs`` (e.g. its 502s) propagate unchanged.
    """
    logger.info(f"Request received: voice_id={req.voice_id}, model={req.model_id}, script_len={len(req.script)}")
    logger.info(f"Script preview: {req.script[:200]!r}")

    if not API_KEY:
        raise HTTPException(status_code=500, detail="ELEVENLABS_API_KEY not configured")

    if not req.script.strip():
        raise HTTPException(status_code=400, detail="Script is empty")

    # 1. Greeting
    script = req.script
    if req.greeting:
        script = req.greeting + " " + script

    # 2. Split at TRENNER tags
    parts = split_script(script)
    logger.info(f"Split into {len(parts)} parts: {[(len(t), s) for t, s in parts]}")
    if not parts:
        raise HTTPException(status_code=400, detail="Script has no content after splitting")

    v3 = is_v3(req.model_id)

    # 3-6. Process each part through TTS
    audio_segments: list[tuple[bytes, Optional[str]]] = []

    async with httpx.AsyncClient() as client:
        for text, sep_name in parts:
            cleaned = clean_text(text, v3)
            if not cleaned:
                continue

            chunks = chunk_text(cleaned)

            if v3:
                chunk_audio_list = await tts_chunks_v3(client, chunks, req)
            else:
                chunk_audio_list = await tts_chunks_legacy(client, chunks, req)

            # Concatenate all chunks of this part into one
            part_audio = b"".join(chunk_audio_list)
            audio_segments.append((part_audio, sep_name))

    if not audio_segments:
        raise HTTPException(status_code=400, detail="No audio generated")

    logger.info(f"Audio segments: {len(audio_segments)}, sizes: {[len(a) for a, _ in audio_segments]}")

    # 7-8. Combine with separators and export. pydub decodes and encodes via
    # ffmpeg synchronously. Running it on the event loop would stall every
    # other request (including job polls) for the whole encode.
    final_mp3 = await asyncio.to_thread(combine_audio, audio_segments)
    logger.info(f"Final MP3 size: {len(final_mp3)} bytes")

    # The random suffix stops requests that finish within the same second
    # from overwriting each other's archive file.
    stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    archive_path = ARCHIVE_DIR / f"{stamp}_{uuid.uuid4().hex[:8]}.mp3"
    try:
        await asyncio.to_thread(archive_path.write_bytes, final_mp3)
    except OSError:
        # Archiving is secondary; don't throw away an already generated result.
        logger.exception(f"Failed to archive {archive_path}")

    return Response(content=final_mp3, media_type="audio/mpeg")


def _write_bytes_atomic(path: Path, data: bytes) -> None:
    """Write ``data`` to ``path`` so that concurrent readers never see a partial file.

    The bytes go to a sibling ``.part`` file, which is then renamed over
    ``path`` with ``Path.replace`` (``os.replace``; atomic within one POSIX
    filesystem). A poller checking ``path.exists()`` sees either nothing or
    the complete file.
    """
    tmp_path = path.with_name(path.name + ".part")
    tmp_path.write_bytes(data)
    tmp_path.replace(path)


async def _run_tts_job(job_id: str, req: TTSRequest):
    """Background task: generate TTS for ``req`` and record the outcome in JOBS_DIR.

    Outcome files in ``JOBS_DIR/<job_id>`` (read by the job-status endpoint):
      * ``result.mp3`` on success. It is written atomically, so pollers never
        read a truncated file. A copy is also archived under ARCHIVE_DIR on a
        best-effort basis.
      * ``error.txt`` containing the error message on any pipeline failure,
        including a script with no content and a run that yields no audio.

    If ``req.callback_url`` is set, the MP3 (success) or a JSON
    ``{"error", "job_id"}`` payload (failure) is POSTed to it. Webhook
    failures, including non-2xx replies, are logged and otherwise ignored.

    NOTE(review): callback_url is client-supplied and POSTed to without any
    validation (SSRF risk: internal hosts are reachable) — consider an
    allowlist of hosts/schemes.
    """
    job_dir = JOBS_DIR / job_id
    job_dir.mkdir(exist_ok=True)
    try:
        script = req.script
        if req.greeting:
            script = req.greeting + " " + script

        parts = split_script(script)
        if not parts:
            # Raised rather than returned, so the error webhook fires too.
            raise ValueError("Script has no content after splitting")

        v3 = is_v3(req.model_id)
        audio_segments: list[tuple[bytes, Optional[str]]] = []

        async with httpx.AsyncClient() as client:
            for text, sep_name in parts:
                cleaned = clean_text(text, v3)
                if not cleaned:
                    continue
                chunks = chunk_text(cleaned)
                if v3:
                    chunk_audio_list = await tts_chunks_v3(client, chunks, req)
                else:
                    chunk_audio_list = await tts_chunks_legacy(client, chunks, req)
                audio_segments.append((b"".join(chunk_audio_list), sep_name))

        if not audio_segments:
            raise ValueError("No audio generated")

        # pydub/ffmpeg encoding and file writes block; keep them off the event loop.
        final_mp3 = await asyncio.to_thread(combine_audio, audio_segments)
        await asyncio.to_thread(_write_bytes_atomic, job_dir / "result.mp3", final_mp3)
    except Exception as e:
        # Some exceptions stringify to ""; keep the stored error non-empty.
        message = str(e) or type(e).__name__
        logger.exception(f"Async job {job_id} failed: {message}")
        (job_dir / "error.txt").write_text(message)
        if req.callback_url:
            try:
                async with httpx.AsyncClient() as cb_client:
                    await cb_client.post(
                        req.callback_url,
                        json={"error": message, "job_id": job_id},
                        headers={"Content-Type": "application/json", "X-Job-Id": job_id},
                        timeout=10.0,
                    )
            except Exception as cb_err:
                logger.warning(f"Error webhook delivery failed for job {job_id}: {cb_err}")
        return

    # The result already exists; an archive failure must not mark the job failed.
    stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    archive_path = ARCHIVE_DIR / f"{stamp}_{uuid.uuid4().hex[:8]}.mp3"
    try:
        await asyncio.to_thread(archive_path.write_bytes, final_mp3)
    except OSError:
        logger.exception(f"Async job {job_id}: failed to archive {archive_path}")
    logger.info(f"Async job {job_id} done, MP3 size: {len(final_mp3)} bytes")

    if req.callback_url:
        try:
            async with httpx.AsyncClient() as cb_client:
                cb_response = await cb_client.post(
                    req.callback_url,
                    content=final_mp3,
                    headers={"Content-Type": "audio/mpeg", "X-Job-Id": job_id},
                    timeout=30.0,
                )
                # A 4xx/5xx reply is a failed delivery, not a success.
                cb_response.raise_for_status()
            logger.info(f"Webhook delivered for job {job_id} to {req.callback_url}")
        except Exception as cb_err:
            logger.error(f"Webhook delivery failed for job {job_id}: {cb_err}")


@app.post("/api/tts/async")
async def generate_tts_async(req: TTSRequest, background_tasks: BackgroundTasks):
    """Accept a TTS job for background processing and return its ID at once.

    Fails fast on the same preconditions as the synchronous endpoint: 500
    without an API key, 400 for a blank script. The synthesis runs in
    ``_run_tts_job`` as a background task after the response is sent.
    Clients poll ``GET /api/tts/jobs/{job_id}`` or rely on ``callback_url``.
    """
    if not API_KEY:
        raise HTTPException(status_code=500, detail="ELEVENLABS_API_KEY not configured")
    if req.script.strip() == "":
        raise HTTPException(status_code=400, detail="Script is empty")

    job_id = f"{uuid.uuid4()}"
    script_len = len(req.script)
    logger.info(f"Async job {job_id} queued: voice_id={req.voice_id}, script_len={script_len}")
    background_tasks.add_task(_run_tts_job, job_id, req)

    accepted = {"job_id": job_id, "status": "pending"}
    return accepted


@app.get("/api/tts/jobs/{job_id}")
async def get_tts_job(job_id: str):
    """Report the state of an async TTS job.

    Returns:
        The MP3 (``audio/mpeg``) once the job has written ``result.mp3``.
        While neither a result nor an error exists, returns
        ``{"job_id": ..., "status": "pending"}``.

    Raises:
        HTTPException: 404 when no such job exists. This includes any
            ``job_id`` that is not a canonical UUID string (the form in which
            IDs are issued), which also stops values such as ``..`` from
            addressing directories outside JOBS_DIR. 500 with the stored
            message when the job failed.
    """
    try:
        well_formed = str(uuid.UUID(job_id)) == job_id
    except ValueError:
        well_formed = False
    if not well_formed:
        raise HTTPException(status_code=404, detail="Job not found")

    job_dir = JOBS_DIR / job_id
    if not job_dir.exists():
        raise HTTPException(status_code=404, detail="Job not found")

    result_file = job_dir / "result.mp3"
    error_file = job_dir / "error.txt"

    if result_file.exists():
        mp3 = result_file.read_bytes()
        return Response(content=mp3, media_type="audio/mpeg")
    elif error_file.exists():
        raise HTTPException(status_code=500, detail=error_file.read_text())
    else:
        return {"job_id": job_id, "status": "pending"}
