#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AI4EVER Tabular Genomic Selection Module
-----------------------------------------
Unified script for training and evaluating ML models on genotype-phenotype data.

Supports:
    • ridge, rf, gbdt, mlp (scikit-learn; "ridge" uses logistic regression for classification)
    • keras (deep neural network)

Features:
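    • Auto-detects regression vs classification from the target column
      (regression when >95% of values are numeric and there are more than
      10 distinct values; classification otherwise)
    • K-fold cross-validation (--cv > 1) or a single 80/20 hold-out split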
    • Saves:
        - Trained model (.joblib or .keras)
        - summary.json
        - feature_importance.csv        (when available)
        - pcs.csv                       (PCs of genotype matrix)
        - viz_train.csv                 (ID + phenotype + predicted + PCs)
        - predictions.csv               (predictions for new genotypes)
        - viz_predict.csv               (ID + predicted + PCs for prediction set)
"""

import argparse
import json
import sys
import os
from pathlib import Path
from datetime import datetime

import joblib
import numpy as np
import pandas as pd

from sklearn.model_selection import KFold, train_test_split
from sklearn.metrics import r2_score, mean_squared_error, accuracy_score
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

from sklearn.linear_model import Ridge, LogisticRegression
from sklearn.ensemble import (
    RandomForestRegressor,
    RandomForestClassifier,
    GradientBoostingRegressor,
    GradientBoostingClassifier,
)
from sklearn.neural_network import MLPRegressor, MLPClassifier

# Keras / TensorFlow (only imported if used)
try:
    from tensorflow import keras
except Exception:
    keras = None


# ----------------------------------------------------------------------
# Logging helper: print JSON lines to stdout so the UI can parse them
# ----------------------------------------------------------------------
def log(event: str, **kwargs):
    """Print one event record to stdout as a single JSON line.

    The record is ``{"event": event, **kwargs}`` and is flushed immediately
    so a consumer reading stdout sees it right away.

    Values are made JSON-safe first:
      * NumPy scalars (e.g. ``np.int64``) are converted to native Python
        values, since ``json.dumps`` raises ``TypeError`` on NumPy integers.
      * Non-finite floats (NaN, +/-inf, e.g. an undefined correlation) become
        ``null``, because ``json.dumps`` would otherwise emit the non-standard
        tokens ``NaN`` / ``Infinity`` that strict JSON parsers reject.
    """
    rec = {"event": event}
    for key, value in kwargs.items():
        if isinstance(value, np.generic):
            value = value.item()
        if isinstance(value, float) and not np.isfinite(value):
            value = None
        rec[key] = value
    print(json.dumps(rec), flush=True)


# ----------------------------------------------------------------------
# Model factory
# ----------------------------------------------------------------------
def build_estimator(task: str, model_name: str, args):
    """Create an unfitted model for the requested task and model family.

    Args:
        task: "reg" selects regressors; any other value selects classifiers.
        model_name: one of "ridge", "rf", "gbdt", "mlp", "keras".
        args: parsed CLI namespace. Depending on the model, reads ``alpha``,
            ``n_estimators``, ``hidden``, ``epochs`` and ``dropout``.

    Returns:
        A scikit-learn estimator, or, for "keras", a factory
        ``make_keras_model(input_dim, n_classes=1)`` that builds and compiles
        a fresh ``keras.Sequential`` network on each call.

    Raises:
        ValueError: *model_name* is not recognised.
        RuntimeError: "keras" was requested but TensorFlow failed to import.
    """
    regression = task == "reg"

    if model_name == "ridge":
        if regression:
            # alpha is the L2 penalty strength
            return Ridge(alpha=float(args.alpha), random_state=42)
        # Classification keeps the "ridge" label in the UI but uses logistic
        # regression; C is the inverse of alpha, floored to avoid 1/0.
        inverse_strength = 1.0 / max(float(args.alpha), 1e-6)
        return LogisticRegression(C=inverse_strength, max_iter=200, n_jobs=-1)

    if model_name == "rf":
        forest_cls = RandomForestRegressor if regression else RandomForestClassifier
        return forest_cls(
            n_estimators=int(args.n_estimators),
            n_jobs=-1,
            random_state=42,
        )

    if model_name == "gbdt":
        boosting_cls = GradientBoostingRegressor if regression else GradientBoostingClassifier
        return boosting_cls(
            n_estimators=int(args.n_estimators),
            random_state=42,
        )

    if model_name == "mlp":
        mlp_cls = MLPRegressor if regression else MLPClassifier
        return mlp_cls(
            hidden_layer_sizes=(int(args.hidden),),
            activation="relu",
            max_iter=int(args.epochs),
            random_state=42,
        )

    if model_name != "keras":
        raise ValueError(f"Unknown model: {model_name}")

    if keras is None:
        raise RuntimeError("Keras/TensorFlow is not available in this environment.")

    def make_keras_model(input_dim: int, n_classes: int = 1):
        """Build and compile a two-hidden-layer dense network.

        Hyperparameters are read from ``args`` at call time. The output head
        is linear/MSE for regression, a single sigmoid unit for <=2 classes,
        and an ``n_classes``-way softmax with sparse integer labels otherwise.
        """
        width = int(args.hidden)
        net = keras.Sequential()
        net.add(keras.layers.Input(shape=(input_dim,)))
        net.add(keras.layers.Dense(width, activation="relu"))
        if args.dropout > 0:
            net.add(keras.layers.Dropout(float(args.dropout)))
        net.add(keras.layers.Dense(width, activation="relu"))

        if regression:
            head = keras.layers.Dense(1, activation="linear")
            compile_kwargs = {"loss": "mse"}
        elif n_classes <= 2:
            head = keras.layers.Dense(1, activation="sigmoid")
            compile_kwargs = {"loss": "binary_crossentropy", "metrics": ["accuracy"]}
        else:
            head = keras.layers.Dense(n_classes, activation="softmax")
            compile_kwargs = {
                "loss": "sparse_categorical_crossentropy",
                "metrics": ["accuracy"],
            }

        net.add(head)
        net.compile(optimizer="adam", **compile_kwargs)
        return net

    return make_keras_model


# ----------------------------------------------------------------------
# Metrics
# ----------------------------------------------------------------------
def regression_metrics(y_true, y_pred):
    """Score regression predictions.

    Both inputs are converted to float arrays.

    Returns:
        A ``(corr, r2, rmse)`` tuple of Python floats: the Pearson correlation
        (NaN when either input has zero standard deviation, where it is
        undefined), the coefficient of determination and the root mean
        squared error.
    """
    truth = np.asarray(y_true, dtype=float)
    preds = np.asarray(y_pred, dtype=float)

    constant_input = np.std(truth) == 0 or np.std(preds) == 0
    corr = float("nan") if constant_input else float(np.corrcoef(truth, preds)[0, 1])

    return (
        corr,
        float(r2_score(truth, preds)),
        float(np.sqrt(mean_squared_error(truth, preds))),
    )


def classification_metrics(y_true, y_pred):
    """Return the fraction of exactly-matching labels as a Python float."""
    return float(accuracy_score(y_true, y_pred))


# ----------------------------------------------------------------------
# Main
# ----------------------------------------------------------------------
def _keras_predict(model, X, task: str, n_classes: int):
    """Run ``model.predict`` on *X* and return one value per sample (1-D).

    For regression this is the flattened network output. For classification
    it is integer class codes: a 0.5 threshold on the single sigmoid unit
    when ``n_classes <= 2``, otherwise the argmax over the softmax columns.
    Keras results therefore match the shape and meaning of a scikit-learn
    ``predict`` (flattening an (n, k) softmax output would give n*k values).
    """
    raw = np.asarray(model.predict(X, verbose=0))
    if task == "reg":
        return raw.ravel()
    if n_classes <= 2:
        return (raw.ravel() >= 0.5).astype("int32")
    return raw.argmax(axis=1)


def _feature_importance_frame(model, feature_cols):
    """Return a ``marker``/``importance`` DataFrame for *model*, or None.

    Uses ``feature_importances_`` when the model has it, otherwise ``coef_``.
    For a 2-D ``coef_`` with a single row (e.g. binary logistic regression),
    that row is used as-is. With several rows (one per class), the mean
    absolute coefficient is used, because signed per-class coefficients
    largely cancel out when averaged.
    """
    if hasattr(model, "feature_importances_"):
        importances = np.asarray(model.feature_importances_)
    elif hasattr(model, "coef_"):
        importances = np.asarray(model.coef_)
        if importances.ndim == 2:
            if importances.shape[0] == 1:
                importances = importances[0]
            else:
                importances = np.abs(importances).mean(axis=0)
    else:
        return None
    return pd.DataFrame({"marker": feature_cols, "importance": importances})


def _finite_or_none(value):
    """Return *value* as a float, or None when it is None, NaN or infinite.

    Keeps ``summary.json`` valid JSON (``json.dump`` writes NaN as ``NaN``).
    """
    if value is None:
        return None
    value = float(value)
    return value if np.isfinite(value) else None


def main():
    """Command-line entry point: train, evaluate, save and optionally predict.

    Steps:
      1. Parse CLI arguments.
      2. Inner-merge the phenotype target with the genotype table on the ID
         column.
      3. Detect the task (regression vs classification).
      4. Compute up to 10 genotype PCs for visualisation.
      5. Evaluate with K-fold CV (``--cv > 1``) or an 80/20 hold-out split.
      6. Retrain on all samples.
      7. Write the model, PCs, feature importances, visualisation tables,
         optional predictions and ``summary.json`` into a timestamped run
         folder under ``--out-dir``.

    Progress and artifact paths are reported as JSON lines via ``log``.
    Exits with status 1 when the ID/target columns are missing or no
    samples overlap.
    """
    parser = argparse.ArgumentParser(description="AI4EVER Tabular Genomic Selection")

    parser.add_argument("--genotype", required=True, help="Genotype table (samples x markers)")
    parser.add_argument("--phenotype", required=True, help="Phenotype table with target column")
    parser.add_argument("--id-col", required=True, help="ID column present in both genotype and phenotype")
    parser.add_argument("--target", required=True, help="Phenotype column to predict")

    parser.add_argument("--model", choices=["ridge", "rf", "gbdt", "mlp", "keras"], default="ridge")
    parser.add_argument("--standardize", action="store_true", help="Standardize features")
    parser.add_argument("--cv", type=int, default=5, help="Number of CV folds (<=1 for no CV)")

    parser.add_argument(
        "--validation-mode",
        choices=["instant", "hold", "corrected"],
        default="hold",
        help="Validation strategy label (UI only; behavior is controlled by --cv)",
    )

    # Hyperparameters
    parser.add_argument("--alpha", type=float, default=1.0, help="Ridge / logistic regularization strength")
    parser.add_argument("--n-estimators", type=int, default=300, help="Number of trees (RF / GBDT)")
    parser.add_argument("--epochs", type=int, default=100, help="Epochs for MLP / Keras")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for Keras")
    parser.add_argument("--dropout", type=float, default=0.3, help="Dropout for Keras hidden layers")
    parser.add_argument("--hidden", type=int, default=128, help="Hidden units for MLP / Keras")

    parser.add_argument(
        "--predict-file",
        type=str,
        default=None,
        help="Optional genotype file for prediction (same markers)",
    )

    parser.add_argument("--out-dir", type=str, default="./runs", help="Output directory")

    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    log("status", msg="starting_tabular_gs", model=args.model, target=args.target)

    # State reused later for visualization / prediction
    pcs_df = None
    pca_model = None
    class_mapping = None

    # ------------------------------------------------------------------
    # Load & merge data
    # ------------------------------------------------------------------
    log("status", msg="loading_data")
    geno = pd.read_csv(args.genotype, sep=None, engine="python")
    pheno = pd.read_csv(args.phenotype, sep=None, engine="python")

    geno.columns = geno.columns.str.strip()
    pheno.columns = pheno.columns.str.strip()

    if args.id_col not in geno.columns or args.id_col not in pheno.columns:
        log("error", msg=f"id_col '{args.id_col}' not found in both files")
        sys.exit(1)
    if args.target not in pheno.columns:
        log("error", msg=f"target '{args.target}' not found in phenotype file")
        sys.exit(1)

    # Only the ID and target must be present. Missing genotype calls are
    # imputed with 0 below, so they must not discard whole samples here.
    data = pd.merge(
        pheno[[args.id_col, args.target]],
        geno,
        on=args.id_col,
        how="inner",
    ).dropna(subset=[args.id_col, args.target])

    if data.empty:
        log("error", msg="no overlapping samples after merge")
        sys.exit(1)

    log("status", msg=f"merged_data shape {data.shape}")

    # ------------------------------------------------------------------
    # Task detection
    # ------------------------------------------------------------------
    y_raw_col = data[args.target]
    numeric_y = pd.to_numeric(y_raw_col, errors="coerce")
    non_null_ratio = numeric_y.notnull().mean()
    unique_count = len(y_raw_col.unique())
    task = "reg" if (non_null_ratio > 0.95 and unique_count > 10) else "cls"
    log("task", kind=task, unique_count=unique_count, non_null_ratio=float(non_null_ratio))

    if task == "reg" and numeric_y.isnull().any():
        # Regression tolerates up to 5% non-numeric targets. Drop those
        # samples rather than training on a fabricated phenotype of 0.
        keep = numeric_y.notnull()
        log("warn", msg=f"dropping {int((~keep).sum())} samples with non-numeric target")
        data = data.loc[keep]
        numeric_y = numeric_y.loc[keep]
        y_raw_col = data[args.target]

    X = data.drop(columns=[args.id_col, args.target])
    feature_cols = X.columns

    # ------------------------------------------------------------------
    # Preprocess X, handle y and class mapping
    # ------------------------------------------------------------------
    X = X.apply(pd.to_numeric, errors="coerce").fillna(0)

    if task == "cls":
        y_cat = y_raw_col.astype("category")
        class_mapping = {int(code): str(cat) for code, cat in enumerate(y_cat.cat.categories)}
        y_raw = y_cat.astype(str)
        y = y_cat.cat.codes.astype("int32")
    else:
        y_raw = numeric_y
        y = numeric_y.astype("float32")

    n_classes = len(class_mapping) if task == "cls" else 1
    use_scaler = args.standardize or args.model in {"mlp", "keras"}

    # ------------------------------------------------------------------
    # PCA for genotype (for visualization)
    # ------------------------------------------------------------------
    try:
        # PCA requires n_components <= min(n_samples, n_features)
        n_components = min(10, X.shape[0], X.shape[1])
        if n_components > 0:
            pca_model = PCA(n_components=n_components, random_state=42)
            pcs = pca_model.fit_transform(X.values.astype("float32"))
            pc_cols = [f"PC{i+1}" for i in range(pcs.shape[1])]
            pcs_df = pd.DataFrame(pcs, columns=pc_cols)
            pcs_df.insert(0, args.id_col, data[args.id_col].values)
            log("status", msg=f"computed_pcs n_components={pcs.shape[1]}")
    except Exception as e:
        pcs_df = None
        pca_model = None
        log("warn", msg=f"pcs_failed: {e}")

    # ------------------------------------------------------------------
    # Cross-validation / hold-out evaluation
    # ------------------------------------------------------------------
    mean_corr = mean_r2 = mean_rmse = None
    mean_acc = None

    def run_single_fit(X_train, X_test, y_train, y_test):
        """Fit a fresh model on one split and return (corr, r2, rmse, acc).

        Regression fills corr/r2/rmse and leaves acc None; classification
        does the reverse. The scaler (if any) is fit on the train part only.
        """
        if use_scaler:
            scaler = StandardScaler()
            X_train_s = scaler.fit_transform(X_train)
            X_test_s = scaler.transform(X_test)
        else:
            X_train_s, X_test_s = X_train, X_test

        est = build_estimator(task, args.model, args)
        if args.model == "keras":
            # For keras, build_estimator returns a model factory
            model = est(input_dim=X_train_s.shape[1], n_classes=n_classes)
            callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]
            model.fit(
                X_train_s,
                y_train,
                epochs=int(args.epochs),
                batch_size=int(args.batch_size),
                verbose=0,
                callbacks=callbacks,
                validation_split=0.2,
            )
            y_pred = _keras_predict(model, X_test_s, task, n_classes)
        else:
            est.fit(X_train_s, y_train)
            y_pred = est.predict(X_test_s)

        if task == "reg":
            corr, r2, rmse = regression_metrics(y_test, y_pred)
            return corr, r2, rmse, None
        return None, None, None, classification_metrics(y_test, y_pred)

    if args.cv > 1:
        # K-fold cross-validation
        log("status", msg=f"starting_{args.cv}fold_cv", mode=args.validation_mode)
        kf = KFold(n_splits=args.cv, shuffle=True, random_state=42)

        cv_corrs, cv_r2s, cv_rmses, cv_accs = [], [], [], []

        for fold, (tr_idx, te_idx) in enumerate(kf.split(X), start=1):
            corr, r2, rmse, acc = run_single_fit(
                X.iloc[tr_idx].values,
                X.iloc[te_idx].values,
                y.iloc[tr_idx].values,
                y.iloc[te_idx].values,
            )

            if task == "reg":
                cv_corrs.append(corr)
                cv_r2s.append(r2)
                cv_rmses.append(rmse)
                log("metric", fold=fold, corr=corr, r2=r2, rmse=rmse)
            else:
                cv_accs.append(acc)
                log("metric", fold=fold, acc=acc)

        if task == "reg" and cv_corrs:
            mean_corr = float(np.nanmean(cv_corrs))
            mean_r2 = float(np.nanmean(cv_r2s))
            mean_rmse = float(np.nanmean(cv_rmses))
            log("metric_summary", cv_corr=mean_corr, cv_r2=mean_r2, cv_rmse=mean_rmse)
        elif task == "cls" and cv_accs:
            mean_acc = float(np.nanmean(cv_accs))
            log("metric_summary", cv_acc=mean_acc)

    else:
        # Simple hold-out validation
        log("status", msg="starting_holdout_validation")
        stratify = y.values if task == "cls" else None
        try:
            X_tr, X_te, y_tr, y_te = train_test_split(
                X.values, y.values, test_size=0.2, random_state=42, stratify=stratify
            )
        except ValueError as e:
            # Stratification fails when a class has fewer than 2 samples;
            # fall back to an unstratified split instead of aborting.
            if stratify is None:
                raise
            log("warn", msg=f"stratified_split_failed: {e}; using random split")
            X_tr, X_te, y_tr, y_te = train_test_split(
                X.values, y.values, test_size=0.2, random_state=42
            )
        corr, r2, rmse, acc = run_single_fit(X_tr, X_te, y_tr, y_te)

        if task == "reg":
            mean_corr, mean_r2, mean_rmse = corr, r2, rmse
            log("metric_summary", corr=corr, r2=r2, rmse=rmse)
        else:
            mean_acc = acc
            log("metric_summary", acc=acc)

    # ------------------------------------------------------------------
    # Retrain on full dataset
    # ------------------------------------------------------------------
    log("status", msg=f"training_full_{args.model}")

    if use_scaler:
        full_scaler = StandardScaler()
        X_full = full_scaler.fit_transform(X.values)
    else:
        full_scaler = None
        X_full = X.values

    if args.model == "keras":
        est_factory = build_estimator(task, args.model, args)
        final_model = est_factory(input_dim=X_full.shape[1], n_classes=n_classes)
        callbacks = [keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)]
        final_model.fit(
            X_full,
            y.values,
            epochs=int(args.epochs),
            batch_size=int(args.batch_size),
            verbose=0,
            callbacks=callbacks,
            validation_split=0.1,
        )
    else:
        final_model = build_estimator(task, args.model, args)
        final_model.fit(X_full, y.values)

    # ------------------------------------------------------------------
    # Save model bundle in a run-specific folder
    # ------------------------------------------------------------------
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M")
    model_dir = out_dir / f"{args.model}_{args.target}_{timestamp}"
    model_dir.mkdir(parents=True, exist_ok=True)

    if args.model == "keras":
        # Save keras model + scaler as separate artifacts
        model_path = model_dir / f"model_{args.target}_{task}.keras"
        final_model.save(model_path)
        log("artifact", path=str(model_path), kind="model")
        if full_scaler is not None:
            scaler_path = model_dir / "scaler.joblib"
            joblib.dump(full_scaler, scaler_path)
            log("artifact", path=str(scaler_path), kind="scaler")
    else:
        bundle = {
            "model": final_model,
            "scaler": full_scaler,
            "task": task,
            "id_col": args.id_col,
            "target": args.target,
            "class_mapping": class_mapping,
            "feature_cols": list(feature_cols),
        }
        model_path = model_dir / "model.joblib"
        joblib.dump(bundle, model_path)
        log("artifact", path=str(model_path), kind="model")

    log("done", msg=f"model_trained_on_all_{len(y)}samples")

    # ------------------------------------------------------------------
    # Save PCs
    # ------------------------------------------------------------------
    if pcs_df is not None:
        pcs_path = model_dir / "pcs.csv"
        pcs_df.to_csv(pcs_path, index=False)
        log("artifact", path=str(pcs_path), kind="pcs")

    # ------------------------------------------------------------------
    # Feature importance (where supported)
    # ------------------------------------------------------------------
    try:
        fi_df = _feature_importance_frame(final_model, feature_cols)
        if fi_df is not None:
            fi_path = model_dir / "feature_importance.csv"
            fi_df.to_csv(fi_path, index=False)
            log("artifact", path=str(fi_path), kind="feature_importance")
    except Exception as e:
        log("warn", msg=f"feature_importance_failed: {e}")

    # ------------------------------------------------------------------
    # Training visualization table
    # ------------------------------------------------------------------
    try:
        # X_full is already scaled with full_scaler when scaling is enabled
        if args.model == "keras":
            y_pred_full = _keras_predict(final_model, X_full, task, n_classes)
        else:
            y_pred_full = np.asarray(final_model.predict(X_full)).ravel()

        viz_df = pd.DataFrame({args.id_col: data[args.id_col].values})
        viz_df["phenotype"] = np.asarray(y_raw)
        if task == "cls":
            viz_df["phenotype_code"] = np.asarray(y.values)
        viz_df["predicted"] = y_pred_full

        if pcs_df is not None:
            # pcs_df rows are aligned with data. Join positionally so that
            # duplicated IDs do not multiply rows as a key merge would.
            viz_df = pd.concat([viz_df, pcs_df.drop(columns=[args.id_col])], axis=1)

        viz_path = model_dir / "viz_train.csv"
        viz_df.to_csv(viz_path, index=False)
        log("artifact", path=str(viz_path), kind="viz_train")
    except Exception as e:
        log("warn", msg=f"viz_train_failed: {e}")

    # ------------------------------------------------------------------
    # Prediction mode: new genotype file
    # ------------------------------------------------------------------
    if args.predict_file:
        log("status", msg="predict_mode", file=args.predict_file)
        Gnew = pd.read_csv(args.predict_file, sep=None, engine="python")
        # Same header normalisation as the training tables
        Gnew.columns = Gnew.columns.str.strip()

        available = set(Gnew.columns)
        missing = [c for c in feature_cols if c not in available]
        if missing:
            log("warn", msg=f"{len(missing)} markers missing in predict file; they will be filled with 0")

        new_ids = Gnew[args.id_col].values if args.id_col in available else None

        # Training markers in training order; absent markers become 0
        Gnew_num = Gnew.reindex(columns=list(feature_cols), fill_value=0)
        Gnew_num = Gnew_num.apply(pd.to_numeric, errors="coerce").fillna(0)

        if full_scaler is not None:
            X_new = full_scaler.transform(Gnew_num.values)
        else:
            X_new = Gnew_num.values

        if args.model == "keras":
            preds = _keras_predict(final_model, X_new, task, n_classes)
        else:
            preds = np.asarray(final_model.predict(X_new)).ravel()

        # Basic predictions file
        pred_df = pd.DataFrame({"pred_" + args.target: preds})
        pred_path = model_dir / "predictions.csv"
        pred_df.to_csv(pred_path, index=False)
        log("artifact", path=str(pred_path), kind="predictions")

        # Visualization table for prediction
        try:
            if new_ids is not None:
                viz_pred_df = pd.DataFrame({args.id_col: new_ids})
            else:
                viz_pred_df = pd.DataFrame(index=np.arange(len(preds)))

            viz_pred_df["predicted"] = preds

            if pca_model is not None:
                # Project onto the PCs fitted on the (unscaled) training genotypes
                pcs_new = pca_model.transform(Gnew_num.values.astype("float32"))
                for i in range(pcs_new.shape[1]):
                    viz_pred_df[f"PC{i+1}"] = pcs_new[:, i]

            viz_pred_path = model_dir / "viz_predict.csv"
            viz_pred_df.to_csv(viz_pred_path, index=False)
            log("artifact", path=str(viz_pred_path), kind="viz_predict")
        except Exception as e:
            log("warn", msg=f"viz_predict_failed: {e}")

    # ------------------------------------------------------------------
    # Summary JSON
    # ------------------------------------------------------------------
    summary = {
        "task": task,
        "model": args.model,
        "target": args.target,
        "cv_folds": int(args.cv),
        "validation_mode": args.validation_mode,
        "cv_corr": _finite_or_none(mean_corr),
        "cv_r2": _finite_or_none(mean_r2),
        "cv_rmse": _finite_or_none(mean_rmse),
        "cv_acc": _finite_or_none(mean_acc),
        "run_dir": str(model_dir),
        "class_mapping": class_mapping,
    }

    summary_path = model_dir / "summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=4)
    log("artifact", path=str(summary_path), kind="summary")
    log("done", msg="tabular_pipeline_complete")


if __name__ == "__main__":
    # Run the CLI pipeline only when executed as a script, not on import.
    main()
