#!/usr/bin/env python3
"""SIP client with Piper TTS, real-time Moonshine transcription, and local LLM.

Registers with a SIP server, waits for incoming calls, answers with a
spoken welcome message via Piper TTS, transcribes the caller's speech
in real time with Moonshine, feeds it to a local LLM which decides when
it has enough context to respond, then speaks the response back via TTS.
"""

import os
import sys
import time
import threading
import argparse
import signal
import subprocess
import numpy as np
import pjsua2 as pj

# Add moonshine python source to path
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.join(SCRIPT_DIR, "moonshine", "python", "src"))

from llama_cpp import Llama
from moonshine_voice import (
    TranscriptEventListener,
    TranscriptLine,
    get_model_for_language,
    Transcriber,
)

# SIP settings: registrar host[:port] and the digest credentials used to
# build the account URI and AuthCredInfo in main().
# NOTE(review): credentials are defined inline here; these look like
# placeholders -- confirm real values are not meant to be committed.

SIP_DOMAIN = "sip.de.anveo.com:5010"
SIP_USER = "user"
SIP_PASS = "secret"

# Piper TTS settings: binary and voice model are resolved relative to this
# script's directory.
PIPER_BIN = os.path.join(SCRIPT_DIR, "piper", "piper")
PIPER_MODEL = os.path.join(SCRIPT_DIR, "piper", "en_US-lessac-medium.onnx")
# Sample rate generate_tts() assumes for Piper's raw output before resampling.
PIPER_SAMPLE_RATE = 22050

# Audio settings: clock rate and frame duration used for the PJSIP endpoint
# and for the custom audio media ports.
PJSIP_CLOCK_RATE = 16000
PJSIP_FRAME_TIME_MS = 20

# LLM settings
LLM_MODEL_PATH = os.path.join(SCRIPT_DIR, "TinyLlama-1.1B-Chat-v1.0.Q4_K_M.gguf")
LLM_MAX_TOKENS = 200
LLM_TEMPERATURE = 0.7
# History is trimmed to the last LLM_HISTORY_TURNS * 2 entries
# (one user + one assistant entry per exchange).
LLM_HISTORY_TURNS = 8  # keep last N exchanges

# System prompt: instructs the model to either answer or reply with the
# single word WAIT, which LLMListener treats as "need more input".
SYSTEM_PROMPT = (
    "You are a helpful phone assistant. You receive live speech transcription "
    "from a caller. The transcription may be incomplete or still arriving. "
    "If the caller's message is complete enough to respond to meaningfully, "
    "provide a concise, helpful spoken response. "
    "If the message seems incomplete, cut off, or you need more context, "
    "respond with exactly the single word WAIT and nothing else."
)

# Welcome message (default for the --welcome command-line option)
WELCOME_TEXT = "Hello! Thank you for calling. Please go ahead and speak."

# Debounce interval for LLM queries on partial text (seconds); default for
# the --debounce command-line option.
DEBOUNCE_SECONDS = 1.0


def generate_tts(text):
    """Run Piper TTS and return audio as int16 numpy array at 16kHz.

    The text is piped to the Piper binary (``PIPER_BIN``) which is asked for
    raw PCM output; that output is interpreted as signed 16-bit samples at
    ``PIPER_SAMPLE_RATE`` and linearly resampled to ``PJSIP_CLOCK_RATE``.

    Args:
        text: Text to synthesise.

    Returns:
        1-D ``np.int16`` array at ``PJSIP_CLOCK_RATE``.  Empty when Piper
        produced fewer than two samples.

    Raises:
        RuntimeError: If Piper exits with a non-zero status (its stderr is
            included in the message).
        OSError: If the Piper binary cannot be started.
    """
    env = os.environ.copy()
    piper_dir = os.path.dirname(PIPER_BIN)
    # Prepend Piper's directory to the library search path.  Only append the
    # inherited value when there is one, so the variable never ends in a bare
    # ":" (an empty entry in the search path).
    inherited = env.get("LD_LIBRARY_PATH", "")
    env["LD_LIBRARY_PATH"] = f"{piper_dir}:{inherited}" if inherited else piper_dir

    proc = subprocess.Popen(
        [PIPER_BIN, "--model", PIPER_MODEL, "--output_raw", "-q"],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=env,
    )
    raw_audio, err_output = proc.communicate(text.encode())

    # Surface synthesis failures instead of silently returning no audio.
    if proc.returncode != 0:
        detail = err_output.decode(errors="replace").strip()
        raise RuntimeError(
            f"piper exited with status {proc.returncode}: {detail}")

    # Piper outputs signed 16-bit PCM at 22050 Hz.  Drop a dangling odd byte
    # so np.frombuffer does not reject a truncated stream.
    usable = len(raw_audio) - (len(raw_audio) % 2)
    samples_22k = np.frombuffer(
        raw_audio[:usable], dtype=np.int16).astype(np.float32)

    # Interpolation needs at least two input samples.
    if len(samples_22k) < 2:
        return np.zeros(0, dtype=np.int16)

    # Resample 22050 -> 16000 via linear interpolation
    ratio = PJSIP_CLOCK_RATE / PIPER_SAMPLE_RATE
    n_out = int(len(samples_22k) * ratio)
    indices = np.arange(n_out) / ratio
    # Clamp so that idx_floor + 1 is always a valid index.
    idx_floor = np.minimum(indices.astype(np.intp), len(samples_22k) - 2)
    frac = indices - idx_floor
    resampled = (samples_22k[idx_floor] * (1.0 - frac)
                 + samples_22k[idx_floor + 1] * frac)

    return resampled.astype(np.int16)


def build_prompt(history, current_input):
    """Assemble the text prompt sent to the LLM.

    The prompt consists of the system prompt, every ``(role, content)`` pair
    in ``history``, and the new user input, each rendered as a
    ``### <Role>:`` section, followed by an open ``### Assistant:`` header
    for the model to complete.

    Args:
        history: List of ``(role, content)`` tuples.  Any role other than
            ``"system"`` or ``"user"`` is rendered as an assistant section.
        current_input: The caller text for the new user turn.

    Returns:
        The prompt string.
    """

    def render(role, content):
        # One section per message; unknown roles fall through to Assistant.
        if role == "system":
            return f"### System:\n{content}\n\n"
        if role == "user":
            return f"### User:\n{content}\n\n"
        return f"### Assistant:\n{content}\n\n"

    turns = [("system", SYSTEM_PROMPT)] + history + [("user", current_input)]
    sections = [render(role, content) for role, content in turns]
    # Trailing open header that the model is expected to continue.
    sections.append("### Assistant:\n")
    return "".join(sections)


class PlayerPort(pj.AudioMediaPort):
    """Audio media port that feeds a pre-rendered int16 buffer into a call.

    Each frame request is served with the next slice of the buffer.  The
    final partial slice is zero-padded to the requested frame length.  When
    the buffer is exhausted the port marks itself ``finished`` and invokes
    the optional ``on_finished`` callback.
    """

    def __init__(self, audio_int16, on_finished=None):
        """Store the sample buffer and the optional completion callback.

        Args:
            audio_int16: Sequence of int16 samples (sliced and measured with
                ``len``; concatenated with a numpy array for padding).
            on_finished: Optional zero-argument callable invoked when
                playback reaches the end of the buffer.
        """
        super().__init__()
        self.audio = audio_int16
        self.pos = 0
        self.finished = False
        self.on_finished = on_finished

    def _mark_finished(self):
        """Flag the port as finished and fire the completion callback."""
        self.finished = True
        if self.on_finished:
            self.on_finished()

    def onFrameRequested(self, frame):
        """Fill ``frame`` with the next chunk of audio, or mark it empty."""
        wanted = frame.size // 2  # frame.size is in bytes, samples are 2 bytes
        left = len(self.audio) - self.pos

        if left <= 0:
            # Nothing left to play: hand back an empty frame and notify once.
            frame.type = pj.PJMEDIA_FRAME_TYPE_NONE
            if not self.finished:
                self._mark_finished()
            return

        take = min(wanted, left)
        start = self.pos
        samples = self.audio[start:start + take]
        self.pos = start + take

        if take < wanted:
            # Last, short chunk: pad with silence to a full frame.
            padding = np.zeros(wanted - take, dtype=np.int16)
            samples = np.concatenate([samples, padding])
            self._mark_finished()

        frame.type = pj.PJMEDIA_FRAME_TYPE_AUDIO
        frame.buf = pj.ByteVector(samples.tobytes())


class TranscriberPort(pj.AudioMediaPort):
    """Audio media port that forwards received call audio to a Moonshine stream."""

    def __init__(self, moonshine_stream):
        """Remember the stream that received audio is pushed into.

        Args:
            moonshine_stream: Object exposing ``add_audio(samples, rate)``.
        """
        super().__init__()
        self.moonshine_stream = moonshine_stream

    def onFrameReceived(self, frame):
        """Convert an incoming frame to float32 and hand it to the stream."""
        if frame.size == 0:
            # Empty frame: nothing to transcribe.
            return
        pcm = np.frombuffer(bytes(frame.buf), dtype=np.int16)
        # Scale int16 samples by 1/32768 into float32.
        normalized = pcm.astype(np.float32) / 32768.0
        self.moonshine_stream.add_audio(normalized, PJSIP_CLOCK_RATE)


class LLMListener(TranscriptEventListener):
    """Receives transcription events and queries the LLM when appropriate.

    Caller text is accumulated from completed lines plus the in-progress
    line.  A debounce timer (on partial text) or a completed line starts an
    LLM generation on a background thread.  The model either produces a
    response, which is synthesised with ``generate_tts`` and played through
    ``call_ref.play_response``, or replies WAIT to ask for more input.
    Transcript events that arrive while a response is playing stop the
    playback (barge-in).

    Args:
        llm: Callable model; invoked as ``llm(prompt, ..., stream=True)`` and
            iterated for chunks indexed as ``chunk["choices"][0]["text"]``.
        call_ref: Object providing ``stop_playback()`` and
            ``play_response(audio, on_finished=...)`` (a ``MyCall``).
        debounce_seconds: Delay before partial text triggers the LLM.
    """

    def __init__(self, llm, call_ref, debounce_seconds=DEBOUNCE_SECONDS):
        self.llm = llm
        self.call_ref = call_ref  # MyCall instance
        self.debounce_seconds = debounce_seconds
        self.lock = threading.Lock()  # guards transcript text, history, state flags
        self.llm_lock = threading.Lock()  # serialize LLM access (not thread-safe)
        self.completed_lines = []  # lines completed since last LLM response
        self.current_line_text = ""  # in-progress line text
        self.debounce_timer = None
        self.gen_thread = None
        self.cancel = threading.Event()
        self.speaking = False
        self.generating = False
        self.history = []  # conversation history: [(role, text), ...]
        self.stopped = False
        # snapshot of text sent to current LLM generation
        self._generating_text = ""

    def stop(self):
        """Shut the listener down: cancel the timer and any running generation."""
        self.stopped = True
        if self.debounce_timer:
            self.debounce_timer.cancel()
        self.cancel.set()

    def _get_accumulated_text(self):
        """Get all text since last LLM response.

        Does not take ``self.lock``; callers hold it.
        """
        parts = list(self.completed_lines)
        if self.current_line_text.strip():
            parts.append(self.current_line_text.strip())
        return " ".join(parts)

    def _has_new_input(self, caller_text):
        """Return True if the accumulated text differs from ``caller_text``."""
        with self.lock:
            return self._get_accumulated_text() != caller_text

    @staticmethod
    def _is_wait(text):
        """Return True if ``text`` is (so far) just the WAIT sentinel.

        Case-insensitive; allows one trailing character (e.g. punctuation)
        after the four letters.
        """
        stripped = text.strip()
        return stripped.upper().startswith("WAIT") and len(stripped) <= 5

    def _interrupt_playback(self):
        """Stop TTS playback (barge-in) and trigger LLM with new input."""
        if not self.speaking:
            return
        print("\n[BARGE-IN] caller interrupted, stopping playback")
        self.speaking = False
        self.call_ref.stop_playback()
        # Trigger LLM with whatever has accumulated
        self._schedule_trigger()

    def _schedule_trigger(self):
        """Reset the debounce timer."""
        if self.debounce_timer:
            self.debounce_timer.cancel()
        self.debounce_timer = threading.Timer(self.debounce_seconds, self._trigger_llm)
        self.debounce_timer.start()

    def on_line_started(self, event):
        """Transcriber callback: a new line began; clear the in-progress text."""
        with self.lock:
            self.current_line_text = ""

    def on_line_text_changed(self, event):
        """Transcriber callback: the in-progress line's text was updated."""
        with self.lock:
            self.current_line_text = event.line.text
            text = self._get_accumulated_text()

        # Print to terminal for debugging
        print(f"\r>> {text}   ", end="", flush=True)

        if self.stopped:
            return

        # Barge-in: caller is speaking while TTS is playing
        if self.speaking:
            self._interrupt_playback()
            return

        # A generation is in flight; _generate re-checks for new input when
        # it finishes, so nothing needs to be scheduled here.
        if self.generating:
            return

        self._schedule_trigger()

    def on_line_completed(self, event):
        """Transcriber callback: a line was finalised; query the LLM now."""
        with self.lock:
            line_text = event.line.text.strip()
            if line_text:
                self.completed_lines.append(line_text)
            self.current_line_text = ""

        print(f"\r>> {line_text}")

        if self.stopped:
            return

        if self.speaking:
            self._interrupt_playback()
            return

        if self.generating:
            return

        # Cancel debounce and trigger immediately on completed line
        if self.debounce_timer:
            self.debounce_timer.cancel()
            self.debounce_timer = None
        self._trigger_llm()

    def _trigger_llm(self):
        """Start LLM generation in background thread.

        No-op while speaking, generating or stopped, or when there is no
        accumulated text.
        """
        # Check and claim the "generating" slot atomically: this method is
        # reached both from the debounce timer thread and from transcript
        # callbacks, and only one generation may be started.
        with self.lock:
            if self.speaking or self.generating or self.stopped:
                return
            text = self._get_accumulated_text()
            if not text.strip():
                return
            self.generating = True
            self._generating_text = text

        self.gen_thread = threading.Thread(
            target=self._generate, args=(text,), daemon=True)
        self.gen_thread.start()

    def _generate(self, caller_text):
        """Run LLM generation. Called from background thread.

        Streams the model output, stops early on WAIT or cancellation, and
        otherwise records the exchange in ``history`` and plays the response.
        If the caller's text changed while the model was running, the result
        is not used and a new generation is scheduled instead.
        """
        print(f"\n[LLM] thinking... (input: {caller_text!r})")

        # Snapshot history under the lock; it is mutated under the same lock.
        with self.lock:
            history = list(self.history)
        prompt = build_prompt(history, caller_text)

        response = ""
        wait_requested = False
        try:
            with self.llm_lock:
                for chunk in self.llm(
                    prompt,
                    max_tokens=LLM_MAX_TOKENS,
                    temperature=LLM_TEMPERATURE,
                    stop=["### User:", "### System:"],
                    stream=True,
                ):
                    if self.cancel.is_set() or self.stopped:
                        print("[LLM] cancelled")
                        return

                    token = chunk["choices"][0]["text"]
                    response += token

                    # Early WAIT detection: stop streaming as soon as the
                    # output is just the sentinel.
                    if self._is_wait(response):
                        wait_requested = True
                        break
        except Exception as e:
            print(f"[LLM] error: {e}")
            return
        finally:
            self.generating = False

        response = response.strip()
        if wait_requested or not response or response.upper() == "WAIT":
            print("[LLM] WAIT - need more input")
            # Transcript events are ignored while generating, so speech that
            # arrived during this run has not been evaluated yet.  Re-arm the
            # debounce timer so it is not left unanswered.
            if not self.stopped and self._has_new_input(caller_text):
                self._schedule_trigger()
            return

        # Check if new text arrived while we were generating
        if self._has_new_input(caller_text):
            # More speech arrived -- don't play this response, retrigger
            print("[LLM] discarding, new input arrived")
            self._schedule_trigger()
            return

        print(f"[LLM] response: {response}")

        # Update conversation history
        with self.lock:
            self.history.append(("user", caller_text))
            self.history.append(("assistant", response))
            self.history = self.history[-(LLM_HISTORY_TURNS * 2):]
            self.completed_lines.clear()
            self.current_line_text = ""

        # Generate TTS and play response
        self.speaking = True
        try:
            print("[TTS] generating audio...")
            audio = generate_tts(response)
            print(f"[TTS] playing ({len(audio) / PJSIP_CLOCK_RATE:.1f}s)")
            self.call_ref.play_response(audio, on_finished=self._on_playback_done)
        except Exception as e:
            print(f"[TTS] error: {e}")
            self.speaking = False

    def _on_playback_done(self):
        """Called when TTS playback finishes."""
        print("[TTS] playback done, listening again")
        self.speaking = False
        # Check if speech arrived during playback
        with self.lock:
            text = self._get_accumulated_text()
        if text.strip():
            self._schedule_trigger()


class MyCall(pj.Call):
    """A single SIP call wired to TTS playback, transcription and the LLM.

    When the call's audio becomes active the welcome message is played into
    the call and the caller's audio is routed into a Moonshine stream whose
    events drive an ``LLMListener``.  LLM responses are played back through
    ``play_response``.

    Args:
        acc: Owning account.
        transcriber: Object providing ``create_stream(...)``.
        welcome_audio: int16 samples played when media becomes active.
        llm: Model object handed to ``LLMListener``.
        debounce_seconds: Debounce interval handed to ``LLMListener``.
        call_id: PJSUA call id of the incoming call.
    """

    def __init__(self, acc, transcriber, welcome_audio, llm,
                 debounce_seconds=DEBOUNCE_SECONDS,
                 call_id=pj.PJSUA_INVALID_ID):
        super().__init__(acc, call_id)
        self.transcriber = transcriber
        self.welcome_audio = welcome_audio
        self.llm = llm
        self.debounce_seconds = debounce_seconds
        self.moon_stream = None
        self.transcriber_port = None
        self.player_port = None
        self.llm_listener = None
        self.answer_at = None  # wall-clock time at which main() answers with 200
        self.call_audio = None
        self.audio_fmt = None

    @staticmethod
    def _ensure_pj_thread():
        """Register the calling thread with PJSIP if it is not yet known.

        PJSIP requires threads it did not create to be registered before
        they call into the library; playback is started and stopped from
        the LLM worker and transcript-callback threads.
        """
        ep = pj.Endpoint.instance()
        if not ep.libIsThreadRegistered():
            ep.libRegisterThread(threading.current_thread().name)

    def onCallState(self, prm):
        """Log call state changes and clean up on disconnect."""
        try:
            ci = self.getInfo()
            print(f"[CALL] {ci.stateText}")
            if ci.state == pj.PJSIP_INV_STATE_DISCONNECTED:
                self._on_hangup()
        except Exception as e:
            print(f"[CALL] error: {e}")

    def onCallMediaState(self, prm):
        """Start (or re-attach) the audio pipeline on the first active audio media."""
        try:
            ci = self.getInfo()
            for i, mi in enumerate(ci.media):
                if (mi.type == pj.PJMEDIA_TYPE_AUDIO and
                        mi.status == pj.PJSUA_CALL_MEDIA_ACTIVE):
                    self._start_media(i)
                    return
        except Exception as e:
            # Same policy as onCallState: never let an exception escape the callback.
            print(f"[CALL] media state error: {e}")

    def _start_media(self, media_idx):
        """Connect welcome playback and transcription to the call audio.

        If the pipeline already exists (this callback can fire again during a
        call), only the transcriber port is re-attached to the current call
        audio; the welcome message is not replayed and no second stream or
        listener is created.
        """
        try:
            self.call_audio = pj.AudioMedia.typecastFromMedia(
                self.getMedia(media_idx))

            if self.transcriber_port is not None:
                # NOTE(review): assumes re-issuing startTransmit on the
                # refreshed media is sufficient after a media update -- confirm.
                self.call_audio.startTransmit(self.transcriber_port)
                print("[CALL] media re-activated, transcription re-attached")
                return

            self.audio_fmt = pj.MediaFormatAudio()
            self.audio_fmt.type = pj.PJMEDIA_TYPE_AUDIO
            self.audio_fmt.clockRate = PJSIP_CLOCK_RATE
            self.audio_fmt.channelCount = 1
            self.audio_fmt.bitsPerSample = 16
            self.audio_fmt.frameTimeUsec = PJSIP_FRAME_TIME_MS * 1000

            # Play welcome message into the call
            self.player_port = PlayerPort(self.welcome_audio)
            self.player_port.createPort("piper_tts", self.audio_fmt)
            self.player_port.startTransmit(self.call_audio)
            print("[CALL] playing welcome message...")

            # Start transcription of caller's audio
            self.moon_stream = self.transcriber.create_stream(0.5)
            self.llm_listener = LLMListener(self.llm, self, self.debounce_seconds)
            self.moon_stream.add_listener(self.llm_listener)
            self.moon_stream.start()

            self.transcriber_port = TranscriberPort(self.moon_stream)
            self.transcriber_port.createPort("moonshine", self.audio_fmt)
            self.call_audio.startTransmit(self.transcriber_port)
            print("[CALL] transcribing + LLM active...")

        except Exception as e:
            print(f"[CALL] media setup failed: {e}")
            import traceback
            traceback.print_exc()

    def stop_playback(self):
        """Stop current TTS playback (barge-in).

        Best-effort: a failure to disconnect the port is logged, and the
        port is marked finished regardless.
        """
        port = self.player_port
        if port is None or port.finished:
            return
        try:
            # May be invoked from a non-PJSIP thread (transcript callbacks).
            self._ensure_pj_thread()
            port.stopTransmit(self.call_audio)
        except Exception as e:
            print(f"[CALL] stopping playback failed: {e}")
        port.finished = True
        print("[CALL] playback stopped")

    def play_response(self, audio_int16, on_finished=None):
        """Play a TTS response into the call. Called from LLM thread.

        Args:
            audio_int16: int16 samples at ``PJSIP_CLOCK_RATE``.
            on_finished: Optional zero-argument callback invoked when
                playback ends, or immediately if there is no active audio.
        """
        if self.call_audio is None or self.audio_fmt is None:
            print("[CALL] no active audio channel for playback")
            if on_finished:
                on_finished()
            return

        # The LLM worker thread is not created by PJSIP; register it before
        # touching the media API.
        self._ensure_pj_thread()

        # Stop previous player if still active
        if self.player_port and not self.player_port.finished:
            try:
                self.player_port.stopTransmit(self.call_audio)
            except Exception as e:
                print(f"[CALL] stopping previous playback failed: {e}")

        self.player_port = PlayerPort(audio_int16, on_finished=on_finished)
        self.player_port.createPort("llm_response", self.audio_fmt)
        self.player_port.startTransmit(self.call_audio)

    def _on_hangup(self):
        """Tear down the listener, transcription stream and media references."""
        if self.llm_listener:
            self.llm_listener.stop()
            self.llm_listener = None
        if self.moon_stream:
            try:
                self.moon_stream.stop()
            except Exception as e:
                print(f"[CALL] stopping transcription failed: {e}")
            self.moon_stream = None
        self.transcriber_port = None
        self.player_port = None
        self.call_audio = None
        print("[CALL] ended")


class MyAccount(pj.Account):
    """SIP account that accepts one call at a time.

    An incoming call is wrapped in a ``MyCall``, sent a 180 Ringing, and
    scheduled (via ``answer_at``) to be answered by the main loop two
    seconds later.  While a call is active, further incoming calls are
    rejected with 486 Busy Here.

    Args:
        transcriber: Passed through to each ``MyCall``.
        welcome_audio: Passed through to each ``MyCall``.
        llm: Passed through to each ``MyCall``.
        debounce_seconds: Passed through to each ``MyCall``.
    """

    def __init__(self, transcriber, welcome_audio, llm, debounce_seconds=DEBOUNCE_SECONDS):
        super().__init__()
        self.transcriber = transcriber
        self.welcome_audio = welcome_audio
        self.llm = llm
        self.debounce_seconds = debounce_seconds
        self.current_call = None

    def onIncomingCall(self, prm):
        """Ring and schedule the answer, or reject if a call is in progress."""
        print("[SIP] incoming call")

        # Replacing current_call while it is still active would drop the only
        # reference to the ongoing call, so decline the newcomer instead.
        active = self.current_call
        if active is not None and active.isActive():
            print("[SIP] busy, rejecting call")
            try:
                rejected = pj.Call(self, prm.callId)
                busy = pj.CallOpParam()
                busy.statusCode = 486
                rejected.hangup(busy)
            except Exception as e:
                print(f"[SIP] reject failed: {e}")
            return

        self.current_call = MyCall(
            self, self.transcriber, self.welcome_audio, self.llm,
            self.debounce_seconds, prm.callId)

        ring = pj.CallOpParam()
        ring.statusCode = 180
        self.current_call.answer(ring)

        # main() polls answer_at and sends the 200 once this time is reached.
        self.current_call.answer_at = time.time() + 2.0


# Module-level handle to the MyAccount instance; assigned in main() (declared
# ``global`` there) and polled by main()'s answer loop.
g_acc = None


def handle_sigint(sig, frame):
    """SIGINT handler: print a shutdown notice and exit the process.

    Args:
        sig: Signal number (unused).
        frame: Current stack frame (unused).
    """
    print("\n[SIP] shutting down")
    # os._exit terminates immediately with status 0; no further Python-level
    # cleanup code in this file runs after this point.
    os._exit(0)


def main():
    """Parse options, load the models, start the SIP stack and serve calls.

    Steps, in order: parse the command line, install the SIGINT handler,
    load the LLM, pre-render the welcome audio, load the Moonshine model,
    bring up the PJSIP endpoint and transport, create the account, then loop
    forever answering the current call once its ``answer_at`` time arrives.
    """
    global g_acc

    cli = argparse.ArgumentParser(
        description="SIP phone with LLM-powered conversational response")
    cli.add_argument("--transport", choices=["udp", "tcp"], default="udp")
    cli.add_argument("--language", type=str, default="en")
    cli.add_argument("--model-arch", type=int, default=None,
                     help="Moonshine model arch: 0=tiny, 1=base, "
                          "2=tiny-streaming, 4=small-streaming, "
                          "5=medium-streaming")
    cli.add_argument("--welcome", type=str, default=WELCOME_TEXT,
                     help="Welcome message text")
    cli.add_argument("--llm-model", type=str, default=LLM_MODEL_PATH,
                     help="Path to GGUF LLM model")
    cli.add_argument("--debounce", type=float, default=DEBOUNCE_SECONDS,
                     help="Debounce interval in seconds for LLM queries")
    args = cli.parse_args()

    use_tcp = args.transport == "tcp"

    signal.signal(signal.SIGINT, handle_sigint)

    # --- LLM ---
    print(f"[LLM] loading {os.path.basename(args.llm_model)}...")
    llm = Llama(
        model_path=args.llm_model,
        n_ctx=2048,
        n_threads=4,
        n_batch=256,
        f16_kv=True,
        use_mmap=True,
    )
    print("[LLM] ready")

    # --- Welcome audio, rendered once up front ---
    print("[PIPER] generating welcome message...")
    welcome_audio = generate_tts(args.welcome)
    welcome_secs = len(welcome_audio) / PJSIP_CLOCK_RATE
    print(f"[PIPER] ready ({welcome_secs:.1f}s of audio)")

    # --- Moonshine transcriber ---
    print("[MOONSHINE] loading model...")
    # Pass the architecture only when the user chose one explicitly.
    if args.model_arch is None:
        lookup_args = (args.language,)
    else:
        lookup_args = (args.language, args.model_arch)
    model_path, model_arch = get_model_for_language(*lookup_args)

    print(f"[MOONSHINE] model: {model_path} (arch={model_arch})")
    transcriber = Transcriber(model_path, model_arch,
                              options={"identify_speakers": "false"})
    print("[MOONSHINE] ready")

    # --- SIP endpoint ---
    ep = pj.Endpoint()
    ep.libCreate()

    ep_cfg = pj.EpConfig()
    ep_cfg.logConfig.level = 3
    ep_cfg.logConfig.consoleLevel = 3
    ep_cfg.medConfig.clockRate = PJSIP_CLOCK_RATE
    ep_cfg.medConfig.sndClockRate = PJSIP_CLOCK_RATE

    ep.libInit(ep_cfg)

    tcfg = pj.TransportConfig()
    tcfg.port = 5060
    transport_type = (pj.PJSIP_TRANSPORT_TCP if use_tcp
                      else pj.PJSIP_TRANSPORT_UDP)
    ep.transportCreate(transport_type, tcfg)

    ep.libStart()
    ep.audDevManager().setNullDev()
    print(f"[SIP] started ({args.transport})")

    # --- Account ---
    acc_cfg = pj.AccountConfig()
    if not use_tcp:
        acc_cfg.idUri = f"sip:{SIP_USER}@{SIP_DOMAIN}"
        acc_cfg.regConfig.registrarUri = f"sip:{SIP_DOMAIN}"
    else:
        acc_cfg.idUri = f"sip:{SIP_USER}@{SIP_DOMAIN};transport=tcp"
        acc_cfg.regConfig.registrarUri = f"sip:{SIP_DOMAIN};transport=tcp"

    acc_cfg.sipConfig.authCreds.append(
        pj.AuthCredInfo("digest", "*", SIP_USER, 0, SIP_PASS)
    )

    g_acc = MyAccount(transcriber, welcome_audio, llm, args.debounce)
    g_acc.create(acc_cfg)
    print("[SIP] registered, waiting for calls...")

    # --- Answer loop: send the 200 once the ringing delay has elapsed ---
    while True:
        call = g_acc.current_call
        if call and call.answer_at and time.time() >= call.answer_at:
            try:
                print("[SIP] answering")
                ans = pj.CallOpParam()
                ans.statusCode = 200
                call.answer(ans)
            except Exception as e:
                print(f"[SIP] answer failed: {e}")
            # Clear the deadline whether or not answering succeeded.
            call.answer_at = None
        time.sleep(0.1)


# Run the client only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

