import os
import json
import asyncio

from flask import Flask, jsonify, request
from dotenv import load_dotenv

# ——————————————————————————————————————————————————————————————————
# Use aiohttp for the async login & conversation calls
# ——————————————————————————————————————————————————————————————————
import aiohttp

# Use websocket-client for a persistent WebSocket in synchronous code
from websocket import create_connection, WebSocketException, WebSocketConnectionClosedException
from websocket import create_connection, WebSocketException

# ——————————————————————————————————————————————————————————————————
# 1) Load environment variables from .env
# ——————————————————————————————————————————————————————————————————
load_dotenv()

# Connection settings, read from the environment after .env has been loaded.
# Only CTRL_BASE has a built-in fallback; every other value is None when unset.
CTRL_BASE = os.getenv("CTRL_BASE", "http://44.204.7.58")
CTRL_EMAIL = os.getenv("CTRL_EMAIL")
CTRL_PASSWORD = os.getenv("CTRL_PASSWORD")
ACCOUNT_ID = os.getenv("ACCOUNT_ID")
AGENT_ID = os.getenv("AGENT_ID")

# Fail fast at import time if any required setting is missing or empty.
if any(not setting for setting in (CTRL_BASE, CTRL_EMAIL, CTRL_PASSWORD, ACCOUNT_ID, AGENT_ID)):
    raise RuntimeError(
        "Please set CTRL_BASE, CTRL_EMAIL, CTRL_PASSWORD, ACCOUNT_ID, AGENT_ID in your .env"
    )

app = Flask(__name__)

# ——————————————————————————————————————————————————————————————————
# 2) In-memory store of active streams: stream_id → { access_token, conversation_id, ws_url, ws }
#    Once initialized, each stream_id keeps one WebSocket open until it is disconnected.
# ——————————————————————————————————————————————————————————————————
#
# connections = {
#   "some_stream": {
#       "access_token": "...",
#       "conversation_id": "...",
#       "ws_url": "...",
#       "ws": <websocket-client.WebSocket>
#   },
#   ...
# }
#
# Registry of live streams keyed by stream_id; each value holds that stream's
# "access_token", "conversation_id", "ws_url" and "ws" (the open socket).
connections: dict[str, dict] = dict()


# ——————————————————————————————————————————————————————————————————
# 3) Async helper: log in to CtrlAgent, return access_token
# ——————————————————————————————————————————————————————————————————
async def login_ctrlagent(email: str, password: str) -> str:
    """
    Log in to CtrlAgent and return the bearer access token.

    POST to {CTRL_BASE}/ctrlagent/auth/login
    with { "email": ..., "password": ... }
    → expects { "access_token": "..." }

    Raises:
        aiohttp.ClientResponseError: the server answered with a non-2xx status.
        ValueError: the JSON body is not an object or has no "access_token".
    """
    login_url = f"{CTRL_BASE}/ctrlagent/auth/login"
    async with aiohttp.ClientSession() as sess:
        # Using the response as a context manager releases the connection
        # even when raise_for_status() or json() raises.
        async with sess.post(
            login_url,
            json={"email": email, "password": password},
            ssl=False,  # skips TLS certificate verification (only matters for https bases)
        ) as resp:
            resp.raise_for_status()
            payload = await resp.json()

    if not isinstance(payload, dict) or "access_token" not in payload:
        # Give callers a readable reason instead of a bare KeyError/TypeError.
        raise ValueError("Login response did not contain an 'access_token' field")
    return payload["access_token"]


# ——————————————————————————————————————————————————————————————————
# 4) Async helper: create a new conversation, return conversation_id
# ——————————————————————————————————————————————————————————————————
async def new_conversation(token: str) -> str:
    """
    Create a new CtrlAgent conversation and return its id.

    POST to {CTRL_BASE}/api/v1/conversation/new
    Header: Authorization: Bearer {token}
    Body: { "account_id": ACCOUNT_ID, "agent_id": AGENT_ID, "initial_variables": {} }

    The id is read from "conversation_id", falling back to "conversationId"
    and then "id", because the response field name varies.

    Raises:
        aiohttp.ClientResponseError: the server answered with a non-2xx status.
        ValueError: the response is not a JSON object, or none of the id
            fields holds a value. This is raised rather than returning None,
            which would yield a WebSocket URL ending in ".../conversation/None".
    """
    conv_url = f"{CTRL_BASE}/api/v1/conversation/new"
    headers = {"Authorization": f"Bearer {token}"}
    body = {
        "account_id": ACCOUNT_ID,
        "agent_id": AGENT_ID,
        "initial_variables": {},
    }

    async with aiohttp.ClientSession(headers=headers) as sess:
        # The response context manager releases the connection even when
        # raise_for_status() or json() raises.
        async with sess.post(
            conv_url,
            json=body,
            ssl=False,  # skips TLS certificate verification (only matters for https bases)
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()

    conv_id = None
    if isinstance(data, dict):
        # Some responses use "conversation_id", others "conversationId" or "id"
        conv_id = data.get("conversation_id") or data.get("conversationId") or data.get("id")
    if not conv_id:
        raise ValueError(f"Conversation response did not contain a conversation id: {data!r}")
    return conv_id


# ——————————————————————————————————————————————————————————————————
# 5) Helper (sync): wrap the async functions above via asyncio.run()
# ——————————————————————————————————————————————————————————————————
def login_sync() -> str:
    """
    Blocking wrapper around login_ctrlagent().

    Logs in with the configured CTRL_EMAIL / CTRL_PASSWORD on a fresh event
    loop (asyncio.run) and returns the access token.
    """
    login_coro = login_ctrlagent(CTRL_EMAIL, CTRL_PASSWORD)
    return asyncio.run(login_coro)


def new_conversation_sync(token: str) -> str:
    """
    Blocking wrapper around new_conversation().

    Runs the conversation-creation coroutine for ``token`` on a fresh event
    loop (asyncio.run) and returns the conversation id it produces.
    """
    create_coro = new_conversation(token)
    return asyncio.run(create_coro)


# ——————————————————————————————————————————————————————————————————
# 6) Build the raw WebSocket URL for a given token + conv_id
#    (You could also inline this, but we keep it separate for clarity.)
# ——————————————————————————————————————————————————————————————————
def build_ws_url(token: str, conv_id: str, base: "str | None" = None) -> str:
    """
    Build the CtrlAgent conversation WebSocket URL for ``token`` + ``conv_id``.

    Shape:
      {ws|wss}://{host[:port][/prefix]}/api/v1/ws/conversation/{conv_id}?token={token}&channel=voice

    Args:
        token: bearer access token, sent as the ``token`` query parameter.
        conv_id: conversation id, used as the final path segment.
        base: HTTP(S) base URL of the CtrlAgent server; defaults to CTRL_BASE.
            ``https://`` (or ``wss://``) maps to ``wss://``; anything else,
            including a bare ``host[:port]`` with no scheme, maps to ``ws://``.
            A trailing slash on the base is ignored.

    ``conv_id`` and the query values are percent-encoded so that reserved
    characters cannot corrupt the URL. Typical ids and JWT tokens
    (letters, digits, ``-``, ``_``, ``.``) pass through unchanged.
    """
    from urllib.parse import quote, urlencode

    if base is None:
        base = CTRL_BASE

    scheme, sep, rest = base.partition("://")
    if not sep:
        # No scheme given: the whole value is host[:port][/prefix].
        scheme, rest = "http", base
    # A TLS-protected HTTP endpoint must be reached over TLS WebSocket too.
    ws_scheme = "wss" if scheme.lower() in ("https", "wss") else "ws"
    host_and_prefix = rest.rstrip("/")

    query = urlencode({"token": token, "channel": "voice"})
    return (
        f"{ws_scheme}://{host_and_prefix}/api/v1/ws/conversation/"
        f"{quote(str(conv_id), safe='')}?{query}"
    )


# ——————————————————————————————————————————————————————————————————
# 7) Route: Initialize a WebSocket session for a given stream_id
#     - If not already present, perform:
#         a) login → token
#         b) new_conversation(token) → conv_id
#         c) build ws_url
#         d) open a persistent WebSocket and store it in‐memory
# ——————————————————————————————————————————————————————————————————
@app.route("/initialize/<stream_id>", methods=["POST"])
def initialize_stream(stream_id):
    """
    Create (or reuse) a CtrlAgent conversation and WebSocket for ``stream_id``.

    1) If stream_id already has a stored WebSocket that has not been closed,
       return its ws_url (200) without opening anything new. An entry whose
       socket is missing or reports itself disconnected is discarded and the
       stream is rebuilt from scratch.
    2) Otherwise: login → new_conversation → build ws_url → open ws → store
       in ``connections`` and return it (201).

    Upstream failures (login, conversation creation, a missing conversation
    id, or the WebSocket handshake / network connect) return 502 with the
    error text in "details".
    """
    existing = connections.get(stream_id)
    if existing is not None:
        old_ws = existing.get("ws")
        # websocket-client's WebSocket clears its ``connected`` flag on
        # close()/shutdown(); if the attribute is absent, assume it is usable.
        if old_ws is not None and getattr(old_ws, "connected", True):
            return jsonify(
                {
                    "message": "Already initialized (WebSocket is open)",
                    "stream_id": stream_id,
                    "ws_url": existing["ws_url"],
                }
            ), 200
        # Stale entry: forget it, release the old socket best-effort, rebuild.
        connections.pop(stream_id, None)
        if old_ws is not None:
            try:
                old_ws.close()
            except Exception:
                pass  # already closed / broken; nothing left to release

    # 2.a) Log in (sync wrapper around the async helper)
    try:
        token = login_sync()
    except Exception as e:
        return jsonify({"error": "Login to CtrlAgent failed", "details": str(e)}), 502

    # 2.b) Create new conversation
    try:
        conv_id = new_conversation_sync(token)
    except Exception as e:
        return jsonify({"error": "Creating conversation failed", "details": str(e)}), 502
    if not conv_id:
        # Without an id the WS URL would end in ".../conversation/None".
        return jsonify(
            {
                "error": "Creating conversation failed",
                "details": "response contained no conversation id",
            }
        ), 502

    # 2.c) Build WS URL
    ws_url = build_ws_url(token, conv_id)

    # 2.d) Open a persistent WebSocket (blocking call).
    # Network-level failures (connection refused, DNS failure, socket timeout)
    # surface as OSError and an unparsable URL as ValueError rather than
    # WebSocketException, so all three map to 502 instead of an unhandled 500.
    # NOTE(review): websocket-client keeps timeout=5 as the socket timeout, so
    # later recv() calls in /send also give up after 5 s.
    try:
        ws = create_connection(ws_url, timeout=5)
    except (WebSocketException, OSError, ValueError) as e:
        return jsonify({"error": "Failed to open WebSocket", "details": str(e)}), 502

    # 2.e) Store everything in-memory
    connections[stream_id] = {
        "access_token": token,
        "conversation_id": conv_id,
        "ws_url": ws_url,
        "ws": ws,
    }

    return jsonify({"stream_id": stream_id, "ws_url": ws_url}), 201


# ——————————————————————————————————————————————————————————————————
# 8) Route: Retrieve the WS URL (for debugging / clients)
# ——————————————————————————————————————————————————————————————————
@app.route("/ws_url/<stream_id>", methods=["GET"])
def get_ws_url(stream_id):
    """Return the stored WebSocket URL for ``stream_id`` (200), or 404 if it was never initialized."""
    stream = connections.get(stream_id)
    if stream:
        return jsonify({"stream_id": stream_id, "ws_url": stream["ws_url"]}), 200
    return jsonify({"error": "Stream ID not found / not initialized"}), 404


# ——————————————————————————————————————————————————————————————————
# 9) Route: Send a JSON‐payload over the already‐open WebSocket
#    - Expect a POST body: { "text": "<some‐string>" }
#    - Use the same ws for that stream_id (no new connection)
# ——————————————————————————————————————————————————————————————————
def _collect_agent_outputs(ws, max_attempts):
    """
    Read frames from ``ws`` until both agent outputs are captured.

    Stops when both the "ithelpdesk" and the "disfluency" outputs have arrived,
    when ``max_attempts`` frames have been read, or when the server closes the
    connection. A frame whose JSON object has ``from_agent == "ithelpdesk"`` is
    taken as the ithelpdesk output. Otherwise, a frame with
    ``message_type == "disfluency"`` is taken as the disfluency output.
    Frames that are not JSON, or whose JSON is not an object, are logged and
    skipped.

    Returns:
        (ithelpdesk_output, disfluency_output, attempts, closed), where
        ``closed`` is True if the server closed the connection mid-wait.

    Raises:
        Any other WebSocketException / OSError raised by ``ws.recv()``.
    """
    ithelpdesk_output = None
    disfluency_output = None
    attempts = 0
    closed = False

    while attempts < max_attempts and not (ithelpdesk_output and disfluency_output):
        attempts += 1

        # recv() itself raises when the server closes the socket, so it must
        # sit inside this try for the closed-connection branch to be reachable.
        try:
            raw_response = ws.recv()
        except WebSocketConnectionClosedException:
            print(f"WebSocket connection closed by server (attempt {attempts}/{max_attempts})")
            closed = True
            break
        print(f"Raw response (attempt {attempts}/{max_attempts}): {raw_response}")

        # ValueError covers json.JSONDecodeError and the UnicodeDecodeError
        # raised for binary frames that are not valid UTF-8.
        try:
            parsed = json.loads(raw_response)
        except ValueError:
            print(f"Received non-JSON response (attempt {attempts}/{max_attempts}): {raw_response}")
            continue
        if not isinstance(parsed, dict):
            # Valid JSON but not an object (e.g. a list or string): nothing to route on.
            print(f"Received non-object JSON response (attempt {attempts}/{max_attempts}): {raw_response}")
            continue

        from_agent = parsed.get("from_agent")
        output_value = parsed.get("output")
        print(f"Received agent type: {from_agent}, output: {output_value} (attempt {attempts}/{max_attempts})")

        if from_agent == "ithelpdesk":
            ithelpdesk_output = output_value
            print(f"Captured ithelpdesk_output: {ithelpdesk_output}")
        elif parsed.get("message_type") == "disfluency":
            disfluency_output = output_value
            print(f"Captured disfluency_output (from_agent: {from_agent}, message_type: disfluency): {disfluency_output}")

    if ithelpdesk_output and disfluency_output:
        print("Both ithelpdesk and disfluency outputs captured.")
    return ithelpdesk_output, disfluency_output, attempts, closed


def _forget_stream(stream_id, entry):
    """Remove ``entry`` from ``connections`` if it is still the one registered for ``stream_id``."""
    if connections.get(stream_id) is entry:
        connections.pop(stream_id, None)


@app.route("/send/<stream_id>", methods=["POST"])
def send_over_websocket(stream_id):
    """
    Send ``{"text": ...}`` over the stream's open WebSocket and wait for replies.

    1) Look up connections[stream_id]["ws"]. Returns 404 if the stream is
       unknown and 500 if the entry has no socket.
    2) Read the request body as JSON. It must be an object with a "text"
       field; otherwise (including malformed JSON) return 400.
    3) ws.send(json.dumps({"text": ...})), then read up to 50 frames until
       both an "ithelpdesk" and a "disfluency" output have arrived.
    4) Return both outputs (200). If they did not both arrive, return 408
       listing what is missing; "connection_closed" says whether the server
       hung up before they arrived. Send/receive errors return 500.

    When the server has closed the socket, the stream is removed from
    ``connections`` so that /initialize/<stream_id> can open a fresh one.
    """
    entry = connections.get(stream_id)
    if not entry:
        return jsonify({"error": "Stream ID not found / not initialized"}), 404

    ws = entry.get("ws")
    if ws is None:
        return jsonify({"error": "WebSocket object missing for this stream_id"}), 500

    # silent=True: malformed JSON yields None (→ our JSON 400) instead of
    # Flask's HTML error page. isinstance() also rejects arrays and strings,
    # for which `"text" in payload` would be a list/substring test.
    payload = request.get_json(force=True, silent=True)
    if not isinstance(payload, dict) or "text" not in payload:
        return jsonify({"error": "Request JSON must include a 'text' field"}), 400

    # Upper bound on frames read per request, to avoid waiting indefinitely.
    max_attempts = 50
    try:
        ws.send(json.dumps({"text": payload["text"]}))
        ithelpdesk_output, disfluency_output, attempts, closed = _collect_agent_outputs(ws, max_attempts)
    except (WebSocketException, OSError) as e:
        if isinstance(e, WebSocketConnectionClosedException):
            _forget_stream(stream_id, entry)
        return jsonify({"error": "WebSocket send/recv failed", "details": str(e)}), 500

    if closed:
        # The socket is dead; keeping it would make every later /send fail.
        _forget_stream(stream_id, entry)

    if ithelpdesk_output and disfluency_output:
        return jsonify({
            "ithelpdesk_response": ithelpdesk_output,
            "disfluency_response": disfluency_output
        }), 200

    missing = []
    if not ithelpdesk_output:
        missing.append("ithelpdesk")
    if not disfluency_output:
        missing.append("disfluency")
    return jsonify({
        "error": f"Timeout waiting for: {', '.join(missing)} response(s)",
        "attempts": attempts,
        "ithelpdesk_response_received": bool(ithelpdesk_output),
        "disfluency_response_received": bool(disfluency_output),
        "connection_closed": closed,
    }), 408


# ——————————————————————————————————————————————————————————————————
# 10) Route: Disconnect (close) the WS for this stream_id and clean up
# ——————————————————————————————————————————————————————————————————
@app.route("/disconnect/<stream_id>", methods=["POST"])
def disconnect_stream(stream_id):
    """
    Tear down the stream registered under ``stream_id``.

    The entry is removed from ``connections`` first, then its WebSocket (if
    any) is closed on a best-effort basis; any error from close() is ignored
    because the socket may already be closed. The CtrlAgent conversation is
    not ended here (an "end conversation" call could be added later); only the
    local socket is dropped.

    Returns 404 if no such stream is registered, otherwise 200.
    """
    removed = connections.pop(stream_id, None)
    if not removed:
        return jsonify({"error": "Stream ID not found / already disconnected"}), 404

    socket_obj = removed.get("ws")
    if socket_obj:
        try:
            socket_obj.close()
        except Exception:
            # Best effort: the socket may already be closed.
            pass

    message = f"Stream '{stream_id}' disconnected and WS closed."
    return jsonify({"message": message}), 200


# ——————————————————————————————————————————————————————————————————
# 11) Run the Flask app
# ——————————————————————————————————————————————————————————————————
if __name__ == "__main__":
    # Listens on all interfaces, port 5000 by default (http://0.0.0.0:5000).
    # Override with FLASK_RUN_HOST / FLASK_RUN_PORT.
    #
    # Debug mode is opt-in: the Werkzeug debugger allows arbitrary code
    # execution from the browser and must not be reachable on 0.0.0.0.
    # Set FLASK_DEBUG=1 to enable it; app.run() reads that variable itself
    # when no explicit ``debug`` argument is passed.
    app.run(
        host=os.getenv("FLASK_RUN_HOST", "0.0.0.0"),
        port=int(os.getenv("FLASK_RUN_PORT", "5000")),
    )
