#!/usr/bin/env python3
# Joe wrote this :P

import argparse
import http.client
import json
import os
import re
import subprocess
import sys
import urllib.error
import urllib.parse
import urllib.request

try:
    from rich.console import Console
    from rich.markdown import Markdown

    RICH = True
except ImportError:
    RICH = False

# Model used by default, and the one selected with --think.
FAST_MODEL, THINK_MODEL = "gemma3:4b", "qwen3:4b"

# ANSI/VT100 escape sequences: ESC followed either by a single byte in the
# 0x40-0x5F range (except '[') or by a CSI sequence: '[' + parameter bytes
# (0x30-0x3F) + intermediate bytes (0x20-0x2F) + one final byte (0x40-0x7E).
ANSI_ESCAPE = re.compile(
    r"\x1B"
    r"(?:[@-Z\\-_]"
    r"|\[[0-?]*[ -/]*[@-~])"
)

# Instructions sent with every query. main() concatenates this text into the
# prompt itself (no separate system field is sent), so both the HTTP API and
# the `ollama run` fallback receive exactly the same input. The wording and
# whitespace are model input; .strip() only removes the leading/trailing
# newlines introduced by the triple-quoted literal.
SYSTEM_PROMPT = """
You are Computer.

Rules:
- Respond in under 100 words unless asked for detail.
- Prefer direct answers.
- Use technical language.
- Do not apologize.
- For Linux questions assume Arch Linux.
- For electronics questions assume the user is an engineer.
- If providing a command, put the command first.
- Do not explain your reasoning unless explicitly asked.
- Do not show chain-of-thought.
- Give final answers only.
""".strip()


def ollama_base_url():
    """Return the base URL of the Ollama HTTP API, without a trailing slash.

    Reads ``OLLAMA_HOST``, mirroring the Ollama CLI's defaults:

    * unset or blank -> ``http://127.0.0.1:11434``
    * bare host without scheme or port (``localhost``, ``0.0.0.0``,
      ``[::1]``) -> ``http://<host>:11434``
    * ``host:port`` without scheme -> ``http://host:port``
    * a value that already contains a scheme (``https://box``) is used as
      given, so the scheme's own default port applies.
    """
    default_port = 11434
    # `or` also catches OLLAMA_HOST="" (set but empty), which .get() returns as-is.
    host = os.environ.get("OLLAMA_HOST", "").strip() or f"127.0.0.1:{default_port}"
    host = host.rstrip("/")
    if "://" in host:
        return host

    url = f"http://{host}"
    try:
        port = urllib.parse.urlsplit(url).port
    except ValueError:
        # Unparseable or out-of-range port (e.g. an unbracketed IPv6
        # address): leave the value untouched rather than guess.
        return url
    if port is None:
        url = f"{url}:{default_port}"
    return url


def strip_ansi(text):
    """Return *text* with every sequence matched by ANSI_ESCAPE deleted."""
    without_escapes = ANSI_ESCAPE.sub("", text)
    return without_escapes


def copy_to_clipboard(text):
    """Copy *text* to the system clipboard, best effort.

    Tries ``wl-copy``, then ``xclip``, then ``xsel``, stopping at the first
    tool that can be started and exits with status 0.

    Returns:
        True if one of the tools accepted the text; False if none is
        installed or every one failed.
    """
    candidates = [
        ["wl-copy"],
        ["xclip", "-selection", "clipboard"],
        ["xsel", "--clipboard", "--input"],
    ]

    for cmd in candidates:
        try:
            subprocess.run(
                cmd,
                input=text,
                # Encode explicitly instead of using the locale encoding so
                # non-ASCII answers survive under non-UTF-8 locales.
                encoding="utf-8",
                errors="replace",
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                check=True,
            )
        except (OSError, subprocess.CalledProcessError):
            # Tool not installed (FileNotFoundError) or it failed (non-zero
            # exit): try the next candidate.
            continue
        return True

    return False


def run_ollama_api(model, prompt):
    """Send *prompt* to *model* through Ollama's ``/api/generate`` endpoint.

    The request is non-streaming, so this blocks (for at most 600 s) until
    the whole completion has been produced.

    Returns:
        The reply's ``response`` field ("" if absent), with ANSI escapes and
        surrounding whitespace removed.

    Raises:
        urllib.error.URLError (including HTTPError), TimeoutError,
        json.JSONDecodeError and other errors from the request or from
        parsing the reply are propagated to the caller.
    """
    body = {"model": model, "prompt": prompt, "stream": False}
    request = urllib.request.Request(
        url=ollama_base_url() + "/api/generate",
        data=json.dumps(body).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )

    with urllib.request.urlopen(request, timeout=600) as reply:
        decoded = json.load(reply)

    answer = decoded.get("response", "")
    return strip_ansi(answer).strip()


def run_ollama_cli(model, prompt):
    """Run *prompt* through ``ollama run <model>``; fallback for the HTTP API.

    The prompt is written to the child's stdin rather than passed as a
    command-line argument. On Linux a single argv string is limited to
    MAX_ARG_STRLEN (typically 128 KiB), so large piped inputs would
    otherwise fail with "Argument list too long". ``ollama run`` reads its
    prompt from stdin when stdin is not a terminal.

    Returns:
        The command's stdout with ANSI escapes and surrounding whitespace
        removed.

    Raises:
        RuntimeError: if ``ollama`` cannot be executed (e.g. not on PATH) or
        exits with a non-zero status; the message carries its stderr/stdout.
    """
    try:
        result = subprocess.run(
            ["ollama", "run", model],
            input=prompt,
            capture_output=True,
            # Model output is UTF-8; don't depend on the user's locale.
            encoding="utf-8",
            errors="replace",
        )
    except OSError as exc:
        # Surface as RuntimeError so main() reports it instead of crashing.
        raise RuntimeError(f"cannot run ollama: {exc}") from exc

    if result.returncode != 0:
        err = (result.stderr or result.stdout or "ollama run failed").strip()
        raise RuntimeError(err)
    return strip_ansi(result.stdout).strip()


def run_ollama(model, prompt):
    """Query *model* with *prompt*, preferring the HTTP API over the CLI.

    If the API call fails at the transport level, the failure is reported on
    stderr and the request is retried once through ``ollama run``. Such
    failures include connection refused or reset, an HTTP error status, a
    timeout, or a truncated or undecodable reply.

    Returns:
        The model's answer with any ``<think>...</think>`` sections removed
        and surrounding whitespace stripped.

    Raises:
        Whatever the CLI fallback raises when it also fails (RuntimeError on
        a non-zero exit).
    """
    try:
        response = run_ollama_api(model, prompt)
    except (OSError, ValueError, http.client.HTTPException) as exc:
        # OSError: URLError/HTTPError, TimeoutError, connection resets such as
        # http.client.RemoteDisconnected. ValueError: json.JSONDecodeError and
        # UnicodeDecodeError. HTTPException: e.g. http.client.IncompleteRead.
        print(f"computer: API unavailable ({exc}), using CLI", file=sys.stderr)
        response = run_ollama_cli(model, prompt)

    # Reasoning models (such as the --think model) may wrap chain-of-thought
    # in <think> tags; SYSTEM_PROMPT asks for final answers only.
    return re.sub(r"<think>.*?</think>", "", response, flags=re.DOTALL).strip()


def print_response(response, raw):
    """Write the model's answer to stdout.

    The answer is rendered as Markdown with rich only when rich is
    installed, *raw* is false, and stdout is a terminal. When output is piped
    or redirected, the text is printed verbatim. A suggested command can
    then be fed to another program without rich's reflowing and decoration.
    """
    if raw or not sys.stdout.isatty() or not RICH:
        print(response)
        return
    Console().print(Markdown(response))


def main():
    """Command-line entry point for ``computer``.

    Builds the prompt from SYSTEM_PROMPT, the positional words and any text
    piped on stdin, then sends it to FAST_MODEL (or THINK_MODEL with
    --think). Prints the answer, copying it to the clipboard first when
    --copy is given.

    Exit status: 2 if no prompt was given at all (argparse usage error),
    1 if the model could not be queried, 130 if interrupted with Ctrl-C.
    """
    parser = argparse.ArgumentParser(
        prog="computer",
        description="Tiny local terminal AI",
    )

    parser.add_argument(
        "--think",
        action="store_true",
        # Both default models are 4b, so name them rather than claim "larger".
        help=f"Use the reasoning model ({THINK_MODEL}) instead of {FAST_MODEL}",
    )

    parser.add_argument(
        "--copy",
        action="store_true",
        help="Copy response to clipboard",
    )

    parser.add_argument(
        "--raw",
        action="store_true",
        help="Disable rich formatting",
    )

    parser.add_argument(
        "prompt",
        nargs="*",
        help="Prompt text",
    )

    args = parser.parse_args()

    model = THINK_MODEL if args.think else FAST_MODEL
    user_prompt = " ".join(args.prompt)

    # Piped text (e.g. `dmesg | computer "what failed?"`) is attached to the
    # prompt as extra input. sys.stdin is None when launched without one.
    stdin_text = ""
    if sys.stdin is not None and not sys.stdin.isatty():
        stdin_text = sys.stdin.read()

    if not user_prompt.strip() and not stdin_text.strip():
        # Nothing to ask: fail fast instead of querying the model with only
        # the system prompt. parser.error() exits with status 2.
        parser.error("no prompt given; pass it as arguments or pipe it on stdin")

    prompt = f"{SYSTEM_PROMPT}\n\nUser:\n{user_prompt}"
    if stdin_text.strip():
        prompt += f"\n\nInput:\n{stdin_text}"

    try:
        response = run_ollama(model, prompt)
    except RuntimeError as exc:
        print(f"computer: {exc}", file=sys.stderr)
        sys.exit(1)
    except KeyboardInterrupt:
        # Ctrl-C during a long generation: exit quietly with 128 + SIGINT.
        sys.exit(130)

    if args.copy and not copy_to_clipboard(response):
        print(
            "computer: could not copy to clipboard (tried wl-copy, xclip, xsel)",
            file=sys.stderr,
        )

    print_response(response, args.raw)


if __name__ == "__main__":
    # main() returns None on success, which SystemExit maps to status 0.
    raise SystemExit(main())
