#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""AI4EVER Visualization Backend

This script generates plots (scatter, histogram, manhattan) from a CSV table.
It is intentionally separate from ml_tabular.py so training/prediction stays stable.

SwiftUI should call this script and then render the saved PNG.

Outputs one-line JSON events to stdout so the Swift progress parser can pick them up.
"""

from __future__ import annotations

import argparse
import json
import os
from dataclasses import dataclass
from typing import Optional

import numpy as np
import pandas as pd

import matplotlib
matplotlib.use("Agg")  # headless
import matplotlib.pyplot as plt


def _emit(event: str, **kwargs):
    print(json.dumps({"event": event, **kwargs}), flush=True)


def _ensure_outdir(out_dir: str) -> str:
    os.makedirs(out_dir, exist_ok=True)
    return out_dir


def _read_csv(path: str) -> pd.DataFrame:
    # Robust-ish defaults for your typical GWAS/tabular exports
    return pd.read_csv(path, sep=None, engine="python")


def _to_numeric(series: pd.Series) -> pd.Series:
    return pd.to_numeric(series, errors="coerce")


def plot_scatter(
        df: pd.DataFrame,
        x_col: str,
        y_col: str,
        out_path: str,
        title: Optional[str] = None
):
    """Save a scatter plot of ``y_col`` against ``x_col`` to ``out_path``.

    Both columns are coerced with ``_to_numeric``; a row is drawn only when
    both of its values are non-NaN after coercion.

    Args:
        df: Source table.
        x_col: Column for the horizontal axis (also used as the x label).
        y_col: Column for the vertical axis (also used as the y label).
        out_path: Destination image path handed to ``savefig``.
        title: Optional plot title; omitted when falsy.
    """
    x_num = _to_numeric(df[x_col])
    y_num = _to_numeric(df[y_col])
    paired = x_num.notna() & y_num.notna()
    xs = x_num[paired].to_numpy()
    ys = y_num[paired].to_numpy()

    fig, ax = plt.subplots(figsize=(7.5, 5.0), dpi=160)
    ax.scatter(xs, ys, s=14, alpha=0.8)
    ax.set_xlabel(x_col)
    ax.set_ylabel(y_col)
    if title:
        ax.set_title(title)
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


def plot_hist(
        df: pd.DataFrame,
        col: str,
        out_path: str,
        bins: int = 30,
        title: Optional[str] = None
):
    """Save a histogram of one column to ``out_path``.

    The column is coerced with ``_to_numeric`` and NaN values (including
    entries that failed to parse) are dropped before binning.

    Args:
        df: Source table.
        col: Column to histogram (also used as the x label).
        out_path: Destination image path handed to ``savefig``.
        bins: Bin count forwarded to ``hist``.
        title: Optional plot title; omitted when falsy.
    """
    values = _to_numeric(df[col]).dropna().to_numpy()

    fig, ax = plt.subplots(figsize=(7.5, 5.0), dpi=160)
    ax.hist(values, bins=bins)
    ax.set_xlabel(col)
    ax.set_ylabel("Count")
    if title:
        ax.set_title(title)
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)


@dataclass
class ManhattanColumns:
    """Names of the table columns consumed by ``plot_manhattan``."""

    chr_col: str  # column holding chromosome labels
    pos_col: str  # column holding within-chromosome positions (coerced to numeric)
    p_col: str  # column holding the values mapped onto the y-axis via --transform


def _chr_sort_key(label: str) -> tuple:
    """Sort key that orders chromosome labels naturally.

    The label is stripped, lower-cased and has every ``"chr"`` substring
    removed. Labels that are then made only of decimal digits sort first, by
    integer value ("2" before "10"); every other label sorts afterwards as a
    string.
    """
    core = label.strip().lower().replace("chr", "")
    # isdecimal rather than isdigit: isdigit also accepts characters such as
    # "²" that int() cannot parse.
    if core.isdecimal():
        return (0, int(core))
    return (1, core)


def _chr_labels(series: pd.Series) -> np.ndarray:
    """Convert an NA-free chromosome column to an array of string labels.

    pandas reads an integer column that contains blanks as float64, which
    would stringify to "1.0", "10.0", ... and defeat natural ordering. A float
    column whose values are all finite whole numbers is therefore cast to
    int64 before conversion to ``str``.
    """
    if pd.api.types.is_float_dtype(series):
        as_float = series.to_numpy(dtype=float)
        if np.all(np.isfinite(as_float)) and np.all(as_float == np.floor(as_float)):
            series = series.astype(np.int64)
    return series.astype(str).to_numpy()


def _manhattan_y(
        p: np.ndarray,
        transform: str,
        sig_line: Optional[float],
        p_col: str,
) -> tuple[np.ndarray, str, Optional[float]]:
    """Map p-values onto the Manhattan y-axis.

    Returns ``(y, y_label, sig_y)`` where ``sig_y`` is ``sig_line`` passed
    through the same transform, or None when no line should be drawn.

    Raises:
        SystemExit: for an unknown ``transform``.
    """
    if transform == "raw":
        return p, p_col, sig_line
    if transform == "neg":
        sig_y = (-sig_line) if sig_line is not None else None
        return -p, f"-({p_col})", sig_y
    if transform == "inv":
        # Clip from below so p <= 0 never divides by zero.
        y = 1.0 / np.clip(p, 1e-300, np.inf)
        sig_y = (1.0 / sig_line) if (sig_line is not None and sig_line > 0) else None
        return y, f"1/({p_col})", sig_y
    if transform == "neglog10":
        y = -np.log10(np.clip(p, 1e-300, 1.0))
        sig_y = (-np.log10(sig_line)) if (sig_line is not None and sig_line > 0) else None
        return y, f"-log10({p_col})", sig_y
    raise SystemExit(f"Unknown transform: {transform}")


def plot_manhattan(
        df: pd.DataFrame,
        cols: ManhattanColumns,
        out_path: str,
        title: Optional[str] = None,
        sig_line: Optional[float] = None,
        transform: str = "raw",
        by: str = "chr",
        trait_col: Optional[str] = None
):
    """Manhattan plot.

    Rows with a missing chromosome, or a position or p-value that is missing
    or non-numeric, are dropped. Chromosomes are laid out left-to-right in
    natural order: numeric labels by value, then the rest alphabetically,
    ignoring case and a "chr" prefix. Each point is drawn at its chromosome's
    offset plus its position. A chromosome spans its largest observed
    position, and adjacent chromosomes are separated by 1.

    transform:
      - raw      : y = p
      - neg      : y = -p
      - inv      : y = 1/p
      - neglog10 : y = -log10(p)  (common Manhattan)

    by:
      - chr      : one color per chromosome.
      - trait    : same x-axis and chromosome colors, plus one marker shape
                   per distinct value of ``trait_col`` (legend maps trait ->
                   marker). Requires ``trait_col``.

    ``sig_line``, if given, is transformed like p and drawn as a dashed
    horizontal line when the result is finite. The figure is written to
    ``out_path``.

    Raises:
        SystemExit: if ``by == "trait"`` without ``trait_col``, or for an
            unknown ``transform``.
        KeyError: if a referenced column is absent from ``df``.
    """
    if by == "trait" and not trait_col:
        raise SystemExit("manhattan --by trait requires --trait-col")

    chr_series = df[cols.chr_col]
    pos_series = _to_numeric(df[cols.pos_col])
    p_series = _to_numeric(df[cols.p_col])

    # Test the raw chromosome column for NA. After astype(str), a missing
    # value would become the literal "nan" and be plotted as its own
    # chromosome.
    keep = pos_series.notna() & p_series.notna() & chr_series.notna()
    chr_raw = _chr_labels(chr_series[keep])
    pos = pos_series[keep].to_numpy(dtype=float)
    p = p_series[keep].to_numpy(dtype=float)

    trait_raw = None
    if by == "trait":
        trait_raw = df[trait_col].astype(str)[keep].to_numpy()

    # Stable sort of rows by chromosome. Each distinct label's key is
    # computed once and ranked, so ordering is a vectorized argsort rather
    # than a Python sort over every row. kind="stable" keeps rows with equal
    # keys in their original relative order. argsort also yields an integer
    # index array when there are no rows.
    label_key = {c: _chr_sort_key(c) for c in pd.unique(chr_raw)}
    key_rank = {k: r for r, k in enumerate(sorted(set(label_key.values())))}
    row_rank = pd.Series(chr_raw, dtype=object).map(
        {c: key_rank[k] for c, k in label_key.items()}
    ).to_numpy(dtype=np.int64)
    order = np.argsort(row_rank, kind="stable")
    chr_raw = chr_raw[order]
    pos = pos[order]
    p = p[order]
    if trait_raw is not None:
        trait_raw = trait_raw[order]

    # codes[i] is the index into `uniq` of row i's chromosome. factorize
    # lists uniques in order of first appearance, which is sorted order here.
    codes, uniq_arr = pd.factorize(chr_raw)
    uniq = [str(c) for c in uniq_arr]

    # Cumulative x offsets per chromosome. The groupby result is ordered by
    # code 0..k-1, so chr_max[i] belongs to uniq[i].
    chr_max = pd.Series(pos).groupby(codes).max().to_numpy(dtype=float)
    offsets = np.zeros(len(uniq))
    tick_positions: list[float] = []
    current = 0.0
    for i, max_pos in enumerate(chr_max):
        offsets[i] = current
        tick_positions.append(current + float(max_pos) / 2.0)
        current += float(max_pos) + 1.0  # +1 between chromosomes
    x = offsets[codes] + pos

    y, y_label, sig_y = _manhattan_y(p, transform, sig_line, cols.p_col)

    fig = plt.figure(figsize=(11.5, 4.8), dpi=160)
    try:
        if by == "trait":
            # Traits get different marker shapes; chromosomes keep their colors.
            markers = ["s", "^", "o", "D", "P", "X", "v", "<", ">"]
            uniq_traits = list(pd.unique(trait_raw))  # first-appearance order

            # Stable colors per chromosome from the active matplotlib cycle.
            cycle = plt.rcParams.get("axes.prop_cycle", None)
            cycle_colors = cycle.by_key().get("color", None) if cycle is not None else None
            if not cycle_colors:
                cycle_colors = [f"C{i}" for i in range(10)]

            for t_idx, t in enumerate(uniq_traits):
                is_trait = trait_raw == t
                marker = markers[t_idx % len(markers)]
                # Plot chromosome-by-chromosome so each keeps its own color.
                for c_idx in range(len(uniq)):
                    sel = is_trait & (codes == c_idx)
                    if np.any(sel):
                        plt.scatter(
                            x[sel], y[sel],
                            s=10, alpha=0.85,
                            marker=marker,
                            c=cycle_colors[c_idx % len(cycle_colors)],
                        )

            # Legend maps trait -> marker; colors are left out for readability.
            from matplotlib.lines import Line2D
            handles = [
                Line2D([0], [0], marker=markers[i % len(markers)], linestyle="None",
                       color="black", markersize=6, label=t)
                for i, t in enumerate(uniq_traits)
            ]
            plt.legend(handles=handles, frameon=False, fontsize=8, title="Trait", title_fontsize=8)
        else:
            # One scatter call per chromosome, so colors follow the default cycle.
            for c_idx in range(len(uniq)):
                sel = codes == c_idx
                plt.scatter(x[sel], y[sel], s=10, alpha=0.85)

        if sig_y is not None and np.isfinite(sig_y):
            plt.axhline(sig_y, linestyle="--", linewidth=1.2)

        plt.xticks(tick_positions, uniq, rotation=0, fontsize=8)
        plt.xlabel("Genomic position (cumulative)")
        plt.ylabel(y_label)
        if title:
            plt.title(title)

        plt.tight_layout()
        plt.savefig(out_path)
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)


def main():
    """Command-line entry point: parse options, render one plot, report progress.

    Progress updates and the final output path are printed to stdout as
    one-line JSON events via ``_emit``. Invalid option combinations and
    missing table columns terminate with ``SystemExit`` and a readable
    message. Option checks run before the table is loaded, so a bad
    invocation fails fast.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--table-file", required=True, help="CSV file to plot")
    ap.add_argument("--plot", required=True, choices=["scatter", "hist", "manhattan"], help="plot type")
    ap.add_argument("--out-dir", required=True)
    ap.add_argument("--out-name", default=None, help="optional fixed filename")

    # scatter
    ap.add_argument("--x-col", default=None)
    ap.add_argument("--y-col", default=None)

    # hist
    ap.add_argument("--col", default=None)
    ap.add_argument("--bins", type=int, default=30)

    # manhattan
    ap.add_argument("--chr-col", default=None)
    ap.add_argument("--pos-col", default=None)
    ap.add_argument("--p-col", default=None)
    ap.add_argument("--sig", type=float, default=None)
    ap.add_argument(
        "--transform",
        default="raw",
        choices=["raw", "neg", "inv", "neglog10"],
        help=(
            "Transform applied to p-values for Manhattan y-axis: "
            "raw=p, neg=-p, inv=1/p, neglog10=-log10(p)."
        ),
    )
    ap.add_argument(
        "--by",
        default="chr",
        choices=["chr", "trait"],
        help=(
            "Manhattan grouping. 'chr' colors points by chromosome on a cumulative "
            "genomic x-axis; 'trait' additionally draws each --trait-col value "
            "with its own marker shape."
        ),
    )
    ap.add_argument(
        "--trait-col",
        default=None,
        help="Trait column name (required when --by trait)",
    )

    ap.add_argument("--title", default=None)

    args = ap.parse_args()

    # Validate option combinations before touching the filesystem or reading
    # a potentially large table, and collect the columns the plot needs.
    if args.plot == "scatter":
        if not args.x_col or not args.y_col:
            raise SystemExit("scatter requires --x-col and --y-col")
        required_cols = [args.x_col, args.y_col]
    elif args.plot == "hist":
        if not args.col:
            raise SystemExit("hist requires --col")
        required_cols = [args.col]
    else:  # "manhattan"; argparse `choices` rules out anything else
        if not (args.chr_col and args.pos_col and args.p_col):
            raise SystemExit("manhattan requires --chr-col --pos-col --p-col")
        if args.by == "trait" and not args.trait_col:
            raise SystemExit("manhattan --by trait requires --trait-col")
        required_cols = [args.chr_col, args.pos_col, args.p_col]
        if args.by == "trait":
            required_cols.append(args.trait_col)

    out_dir = _ensure_outdir(args.out_dir)
    if args.out_name:
        out_path = os.path.join(out_dir, args.out_name)
    else:
        base = os.path.splitext(os.path.basename(args.table_file))[0]
        out_path = os.path.join(out_dir, f"{base}_{args.plot}.png")

    _emit("progress", pct=0.05, phase="Loading table")
    df = _read_csv(args.table_file)

    # A clear message instead of a bare KeyError traceback from df[col].
    missing = [c for c in required_cols if c not in df.columns]
    if missing:
        raise SystemExit(f"Column(s) not found in table: {', '.join(missing)}")

    _emit("progress", pct=0.25, phase="Preparing plot")

    if args.plot == "scatter":
        plot_scatter(df, args.x_col, args.y_col, out_path, title=args.title)
    elif args.plot == "hist":
        plot_hist(df, args.col, out_path, bins=args.bins, title=args.title)
    else:
        plot_manhattan(
            df,
            ManhattanColumns(args.chr_col, args.pos_col, args.p_col),
            out_path,
            title=args.title,
            sig_line=args.sig,
            transform=args.transform,
            by=args.by,
            trait_col=args.trait_col,
        )

    _emit("progress", pct=1.0, phase="Done")
    _emit("plot", type=args.plot, path=os.path.abspath(out_path))


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
