#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
AI4EVER Tabular Genomic Selection Module
-----------------------------------------
Unified script for training and evaluating ML models on genotype-phenotype data.

Supports:
    • ridge, rf, gbdt, mlp (scikit-learn)
    • keras (deep neural network)
Auto-detects regression vs classification.
Exports: model.joblib / .keras, summary.json, feature_importance.csv, predictions.csv
"""

import argparse, json, sys, os, joblib
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, KFold, StratifiedKFold, cross_validate
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.linear_model import Ridge
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier, GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.neural_network import MLPRegressor, MLPClassifier
from sklearn.metrics import r2_score, mean_squared_error, accuracy_score, f1_score
from sklearn.inspection import permutation_importance
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout

def log(event, **payload):
    """Emit one structured log record as a single JSON line on stdout.

    The record always starts with the ``event`` key, followed by the keyword
    arguments in the order given.  Output is flushed immediately so that a
    consumer reading the stream (SwiftUI front end or a terminal) sees each
    record as soon as it is produced.
    """
    record = {"event": event}
    record.update(payload)
    line = json.dumps(record)
    print(line, flush=True)

def is_numeric_series(s: pd.Series) -> bool:
    """Return whether pandas classifies the dtype of *s* as numeric."""
    dtype_is_numeric = pd.api.types.is_numeric_dtype
    return dtype_is_numeric(s)

# ---------- Argument parsing ----------
# Command-line options as (flag, add_argument keyword options) pairs.
# The order of this table is the order in which options are registered and
# therefore the order shown by --help.
_CLI_OPTIONS = [
    ("--genotype", {"required": True, "help": "Path to genotype file (CSV/TXT)"}),
    ("--phenotype", {"required": True, "help": "Path to phenotype file (CSV/TXT)"}),
    ("--id-col", {"required": True, "help": "Column name for sample IDs"}),
    ("--target", {"required": True, "help": "Trait column name to predict"}),
    ("--model", {
        "choices": ["ridge", "rf", "gbdt", "mlp", "keras"],
        "default": "ridge",
        "help": "Model type",
    }),
    ("--standardize", {"action": "store_true", "help": "Apply feature scaling"}),
    ("--cv", {"type": int, "default": 0, "help": "Number of CV folds (0 = skip)"}),
    ("--validation-mode", {
        "choices": ["instant", "hold", "corrected"],
        "default": "hold",
        "help": "Validation accuracy calculation method: instant = mean of per-fold correlations, hold = overall correlation, corrected = Olkin & Pratt correction of instant.",
    }),
    ("--epochs", {"type": int, "default": 100, "help": "Epochs for neural network models"}),
    ("--batch-size", {"type": int, "default": 32, "help": "Batch size for neural network models"}),
    ("--dropout", {"type": float, "default": 0.3, "help": "Dropout rate for neural network models"}),
    ("--alpha", {"type": float, "default": 1.0, "help": "Regularization strength for ridge"}),
    ("--n-estimators", {"type": int, "default": 300, "help": "Number of trees for ensemble models"}),
    ("--hidden", {"type": int, "default": 128, "help": "Hidden units for MLP/Keras"}),
    ("--out-dir", {"default": ".", "help": "Output directory for model & results"}),
    ("--predict-file", {"help": "Optional: Predict on new genotype file using trained model"}),
]

parser = argparse.ArgumentParser(
    description="Train genomic selection model on tabular genotype data."
)
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()

# Make sure the output directory exists before any artifact is written to it.
os.makedirs(args.out_dir, exist_ok=True)

# ---------- Load & merge ----------
# Reads the genotype and phenotype tables, validates the requested columns and
# joins them on the sample-ID column.  Produces:
#   data : merged table (ID, target, then all genotype columns), rows with any
#          missing value removed
#   y    : the target column of `data`
#   X    : every remaining column of `data` (the genotype features)
# Exits with status 1 (after emitting an "error" log event) when a required
# column is missing or no sample survives the merge.
log("status", msg="loading_data")
# sep=None with the python engine makes pandas sniff the delimiter, so both
# comma- and tab-separated files are accepted.
geno = pd.read_csv(args.genotype, sep=None, engine="python")
pheno = pd.read_csv(args.phenotype, sep=None, engine="python")

# Strip stray whitespace from header cells so column lookups match the
# names given on the command line.
geno.columns = geno.columns.str.strip()
pheno.columns = pheno.columns.str.strip()
if args.id_col not in geno.columns or args.id_col not in pheno.columns:
    log("error", msg=f"id_col '{args.id_col}' not found in both files")
    sys.exit(1)
if args.target not in pheno.columns:
    log("error", msg=f"target '{args.target}' not found in phenotype file")
    sys.exit(1)

# If the genotype file also carries a column named like the target, the merge
# would rename both copies with _x/_y suffixes and the target lookup below
# would fail.  The phenotype file is the authority for the trait, so the
# genotype copy is dropped (this also keeps the trait out of the features).
if args.target != args.id_col and args.target in geno.columns:
    log("warn", msg=f"target '{args.target}' also present in genotype file; ignoring genotype copy")
    geno = geno.drop(columns=[args.target])

data = pd.merge(pheno[[args.id_col, args.target]], geno, on=args.id_col, how="inner").dropna()
if data.empty:
    # Without this guard the script fails much later, inside train_test_split,
    # with an error that does not point at the real cause.
    log("error", msg="no samples left after merging genotype and phenotype files and dropping rows with missing values")
    sys.exit(1)

y = data[args.target]
X = data.drop(columns=[args.id_col, args.target])
log("status", msg=f"merged_data shape {data.shape}")

# ---------- Task detection ----------
# Heuristic: the target is treated as a regression trait ("reg") when more
# than 95% of its values parse as numbers AND it has more than 10 distinct
# values; anything else is treated as classification ("cls").
numeric_y = pd.to_numeric(y, errors="coerce")
non_null_ratio = numeric_y.notnull().mean()
unique_count = y.nunique(dropna=False)
task = "reg" if (non_null_ratio > 0.95 and unique_count > 10) else "cls"
log("task", kind=task)

# ---------- Preprocessing ----------
# Genotype cells that cannot be parsed as numbers become 0.
X = X.apply(pd.to_numeric, errors="coerce").fillna(0)
if task == "cls":
    # Encode class labels as integer category codes.
    y = y.astype("category").cat.codes.astype("int32")
else:
    # Up to 5% of the target values may have failed numeric parsing.  Those
    # samples are dropped rather than assigned a made-up phenotype of 0,
    # which would bias both the fitted model and the validation metrics.
    parseable = numeric_y.notna()
    n_unparseable = int((~parseable).sum())
    if n_unparseable:
        log("warn", msg=f"dropping {n_unparseable} samples with non-numeric target values")
        X = X.loc[parseable]
    y = numeric_y.loc[parseable].astype("float32")

# Initial 80/20 split (stratified by class for classification).  The
# resulting X_train / y_train are module-level names that exist before the
# first estimator is built.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42,
    stratify=(y if task == "cls" else None)
)

# ---------- Build estimator ----------
def build_estimator(task, model):
    """Create a fresh, unfitted estimator for the requested model type.

    Hyper-parameters come from the module-level ``args`` namespace; the Keras
    network is sized from the module-level ``X`` (number of features) and
    ``y`` (number of classes).

    Args:
        task: ``"reg"`` for regression; any other value (the script uses
            ``"cls"``) selects the classification variant.
        model: One of ``"ridge"``, ``"rf"``, ``"gbdt"``, ``"mlp"``, ``"keras"``.

    Returns:
        For ``"keras"``: a compiled ``Sequential`` network (no scaler is
        attached, so ``--standardize`` has no effect on it).
        For every other model: a ``Pipeline`` whose last step is named
        ``"est"``, preceded by a ``"scaler"`` step when ``--standardize``
        was given.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    is_regression = task == "reg"

    if model == "ridge":
        if is_regression:
            est = Ridge(alpha=args.alpha)
        else:
            # A Ridge regressor yields continuous outputs that the
            # classification metrics reject; use the classifier variant.
            from sklearn.linear_model import RidgeClassifier
            est = RidgeClassifier(alpha=args.alpha)
    elif model == "rf":
        forest_cls = RandomForestRegressor if is_regression else RandomForestClassifier
        est = forest_cls(n_estimators=args.n_estimators, random_state=42, n_jobs=-1)
    elif model == "gbdt":
        # --n-estimators is documented as applying to ensemble models, so it
        # is honoured here as well as for the random forest.
        boosting_cls = GradientBoostingRegressor if is_regression else GradientBoostingClassifier
        est = boosting_cls(n_estimators=args.n_estimators, random_state=42)
    elif model == "mlp":
        mlp_cls = MLPRegressor if is_regression else MLPClassifier
        est = mlp_cls(hidden_layer_sizes=(args.hidden,), max_iter=args.epochs, random_state=42)
    elif model == "keras":
        # Size the network from the full data set rather than from whichever
        # train split happens to be bound at call time, so every fold gets an
        # output layer that covers all classes.
        input_dim = X.shape[1]
        if is_regression:
            est = Sequential([
                Dense(args.hidden, activation="relu", input_shape=(input_dim,)),
                Dropout(args.dropout),
                Dense(64, activation="relu"),
                Dense(1)
            ])
            est.compile(optimizer="adam", loss="mse")
        else:
            num_classes = len(np.unique(y))
            multiclass = num_classes > 2
            # Binary problems use a single sigmoid unit so the output shape
            # matches the 1-D integer labels expected by binary_crossentropy;
            # multi-class problems use one softmax unit per class.
            est = Sequential([
                Dense(args.hidden, activation="relu", input_shape=(input_dim,)),
                Dropout(args.dropout),
                Dense(64, activation="relu"),
                Dense(num_classes if multiclass else 1,
                      activation="softmax" if multiclass else "sigmoid")
            ])
            est.compile(optimizer="adam",
                        loss="sparse_categorical_crossentropy" if multiclass else "binary_crossentropy",
                        metrics=["accuracy"])
        # Keras models are returned bare; they are not wrapped in a Pipeline.
        return est
    else:
        raise ValueError(f"Unknown model type: {model}")

    steps = []
    if args.standardize:
        steps.append(("scaler", StandardScaler()))
    steps.append(("est", est))
    return Pipeline(steps)

model = build_estimator(task, args.model)


def _fit_and_predict(estimator, X_fit, y_fit, X_eval):
    """Fit *estimator* on one split and return 1-D predictions for *X_eval*.

    scikit-learn pipelines are fitted and queried directly.  Keras networks
    are trained with the --epochs / --batch-size settings and with progress
    output disabled (it would be interleaved with the JSON log lines on
    stdout); their outputs are converted to class indices for classification.
    """
    if args.model == "keras":
        estimator.fit(
            np.asarray(X_fit, dtype="float32"), np.asarray(y_fit),
            epochs=args.epochs, batch_size=args.batch_size, verbose=0,
        )
        preds = np.asarray(estimator.predict(np.asarray(X_eval, dtype="float32"), verbose=0))
        if task == "cls":
            if preds.ndim > 1 and preds.shape[1] > 1:
                # One column per class: pick the most probable class.
                return preds.argmax(axis=1)
            # Single sigmoid unit: threshold the probability of class 1.
            return (preds.ravel() > 0.5).astype("int32")
    else:
        estimator.fit(X_fit, y_fit)
        preds = np.asarray(estimator.predict(X_eval))

    # Collapse column-vector outputs of shape (n, 1) to shape (n,).
    if preds.ndim > 1 and preds.shape[1] == 1:
        preds = preds[:, 0]
    return preds


def _regression_scores(y_true, preds):
    """Return ``(corr, r2, rmse)`` as floats for one evaluation split.

    ``corr`` is the Pearson correlation, or NaN when either side is constant.
    In "corrected" validation mode with more than 4 samples it is adjusted
    with the Olkin & Pratt approximation.
    """
    if np.std(y_true) > 0 and np.std(preds) > 0:
        r = np.corrcoef(y_true, preds)[0, 1]
    else:
        r = np.nan

    n = len(y_true)
    if args.validation_mode == "corrected" and n > 4 and not np.isnan(r):
        r = r * (1 + (1 - r**2) / (2 * (n - 4)))  # Olkin & Pratt

    r2 = r2_score(y_true, preds)
    rmse = np.sqrt(mean_squared_error(y_true, preds))
    return float(r), float(r2), float(rmse)


def _finite_or_none(value):
    """Map NaN/Inf to None so the emitted log line stays valid JSON."""
    return float(value) if np.isfinite(value) else None


# ---------- Unified Cross-validation + Full retrain ----------
# Defines mean_corr (accuracy for classification) and, for regression only,
# mean_r2 and mean_rmse.  These module-level names are read when the run
# summary is written.
if args.cv and args.cv > 1:
    log("status", msg=f"{args.validation_mode}_validation_{args.cv}fold")
    if task == "cls":
        splitter = StratifiedKFold(n_splits=args.cv, shuffle=True, random_state=42)
    else:
        splitter = KFold(n_splits=args.cv, shuffle=True, random_state=42)

    corr_list, r2_list, rmse_list = [], [], []
    y_all, y_pred_all = [], []

    for train_idx, test_idx in splitter.split(X, y):
        y_fold_eval = y.iloc[test_idx]
        # A fresh estimator per fold so no state leaks between folds.
        preds = _fit_and_predict(
            build_estimator(task, args.model),
            X.iloc[train_idx], y.iloc[train_idx], X.iloc[test_idx],
        )

        if task == "cls":
            corr_list.append(accuracy_score(y_fold_eval, preds))
        else:
            r, r2, rmse = _regression_scores(y_fold_eval, preds)
            corr_list.append(r)
            r2_list.append(r2)
            rmse_list.append(rmse)

        # Pool out-of-fold observations for the "hold" correlation.
        y_all.extend(y_fold_eval)
        y_pred_all.extend(preds)

    # ---------- Aggregate results ----------
    if task == "cls":
        mean_corr = float(np.nanmean(corr_list))
        log("cv_metrics", accuracy=_finite_or_none(mean_corr))
    else:
        if args.validation_mode == "hold" and np.std(y_all) > 0 and np.std(y_pred_all) > 0:
            # "hold": one correlation over all pooled out-of-fold predictions.
            mean_corr = float(np.corrcoef(y_all, y_pred_all)[0, 1])
        else:
            # "instant"/"corrected" (or degenerate pooled data): mean of the
            # per-fold correlations, ignoring folds where it was undefined.
            mean_corr = float(np.nanmean(corr_list))
        mean_r2 = float(np.nanmean(r2_list))
        mean_rmse = float(np.nanmean(rmse_list))
        log("cv_metrics", corr=_finite_or_none(mean_corr),
            r2=_finite_or_none(mean_r2), rmse=_finite_or_none(mean_rmse))

else:
    # ---------- Single holdout evaluation (cv <= 1) ----------
    log("status", msg="single_holdout_validation")

    # Split once (80% train, 20% test)
    X_fit, X_eval, y_fit, y_eval = train_test_split(
        X, y, test_size=0.2, random_state=42,
        stratify=(y if task == "cls" else None)
    )
    preds = _fit_and_predict(build_estimator(task, args.model), X_fit, y_fit, X_eval)

    if task == "cls":
        # mean_r2 / mean_rmse are deliberately left undefined here, exactly as
        # in the k-fold classification branch, so the summary reports them as
        # missing instead of trying to convert None to float.
        mean_corr = float(accuracy_score(y_eval, preds))
        log("cv_metrics", accuracy=mean_corr)
    else:
        mean_corr, mean_r2, mean_rmse = _regression_scores(y_eval, preds)
        log("cv_metrics", corr=_finite_or_none(mean_corr),
            r2=_finite_or_none(mean_r2), rmse=_finite_or_none(mean_rmse))


# ---------- Retrain on full dataset ----------
# The validation estimators above are throw-away; the model that is saved is
# a fresh estimator fitted on every available sample.
log("status", msg=f"training_full_{args.model}")

final_model = build_estimator(task, args.model)
if args.model == "keras":
    final_model.fit(X, y, epochs=args.epochs, batch_size=args.batch_size, verbose=0)
else:
    final_model.fit(X, y)

# ---------- Save final model ----------
from datetime import datetime

# Minute-resolution timestamp for the run directory.  Two runs of the same
# model/target started within the same minute share (and overwrite) one
# directory.
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M")

# e.g., ridge_EarHT_2025-11-06_15-42
model_dir = Path(args.out_dir) / f"{args.model}_{args.target}_{timestamp}"
model_dir.mkdir(parents=True, exist_ok=True)

if args.model == "keras":
    model_path = model_dir / f"model_{args.target}_{task}.keras"
    final_model.save(model_path)
else:
    model_path = model_dir / "model.joblib"
    # The bundle records the training feature order so that a consumer of the
    # file can align new genotype tables before calling predict().
    joblib.dump(
        {
            "model": final_model,
            "task": task,
            "id_col": args.id_col,
            "target": args.target,
            "features": list(X.columns),
        },
        model_path,
    )

log("artifact", path=str(model_path), kind="model")

# Describe the validation that actually ran: k-fold CV only happens for
# --cv > 1, otherwise a single 80/20 holdout was evaluated.
if args.cv and args.cv > 1:
    validation_desc = f"{args.cv}-fold {args.validation_mode} CV"
else:
    validation_desc = "single holdout validation"
# NOTE: further "artifact" events are emitted after this "done" event by the
# later sections of the script.
log("done", msg=f"✅ {validation_desc} complete; model retrained on all {len(y)} samples")



# ---------- Feature importance ----------
# Writes feature_importance.csv (columns: feature, importance; sorted by
# descending importance) into the run directory for the scikit-learn models.
# Best effort: any failure is reported as a "warn" log event and the run
# continues.
if args.model in ["ridge", "rf", "gbdt", "mlp"]:
    try:
        # final_model is a Pipeline, which does not expose the fitted
        # attributes itself; they live on its last step, named "est".
        fitted_est = final_model.named_steps["est"]

        # tree ensembles (RF, GBDT) expose feature_importances_
        if hasattr(fitted_est, "feature_importances_"):
            imp = np.asarray(fitted_est.feature_importances_)

        # linear models (ridge) expose coef_
        elif hasattr(fitted_est, "coef_"):
            abs_coef = np.abs(np.asarray(fitted_est.coef_))
            # A 2-D coef_ holds one row per class/output; average the
            # magnitudes so there is exactly one score per feature.
            imp = abs_coef.mean(axis=0) if abs_coef.ndim > 1 else abs_coef

        # otherwise (MLP) fall back to permutation importance, computed on
        # the whole pipeline so any scaler step is applied consistently
        else:
            pi = permutation_importance(final_model, X, y, n_repeats=5, random_state=42, n_jobs=-1)
            imp = pi.importances_mean

        imp_df = pd.DataFrame({
            "feature": X.columns,
            "importance": imp
        }).sort_values("importance", ascending=False)

        imp_path = model_dir / "feature_importance.csv"
        imp_df.to_csv(imp_path, index=False)
        log("artifact", path=str(imp_path), kind="feature_importance")

    except Exception as e:
        log("warn", msg=f"feature_importance_failed: {e}")


# ---------- Prediction mode ----------
# When --predict-file is given, scores a new genotype table with the model
# that was just retrained on all samples and writes predictions.csv (single
# column "pred_<target>") into --out-dir.  For classification the values are
# the integer class codes used during training.
if args.predict_file:
    log("status", msg="predict_mode")
    Gnew = pd.read_csv(args.predict_file, sep=None, engine="python")
    Gnew.columns = Gnew.columns.str.strip()

    # Align the new table to the training features: same columns, same order.
    # Extra columns (e.g. the sample-ID column) are dropped; markers absent
    # from the new file are filled with 0, matching how unparseable genotype
    # cells were treated during training.
    available = set(Gnew.columns)
    n_missing = sum(1 for col in X.columns if col not in available)
    if n_missing:
        log("warn", msg=f"predict_file lacks {n_missing} training features; filling them with 0")
    Gnew = Gnew.reindex(columns=X.columns).apply(pd.to_numeric, errors="coerce").fillna(0)

    if args.model == "keras":
        # Use the fitted final_model (not the unfitted estimator built before
        # validation) and silence Keras progress output so stdout remains a
        # clean stream of JSON log lines.
        raw = np.asarray(final_model.predict(np.asarray(Gnew, dtype="float32"), verbose=0))
        if task == "cls":
            if raw.ndim > 1 and raw.shape[1] > 1:
                # One column per class: pick the most probable class.
                preds = raw.argmax(axis=1)
            else:
                # Single sigmoid unit: threshold the probability of class 1.
                preds = (raw.ravel() > 0.5).astype("int32")
        else:
            preds = raw.ravel()
    else:
        # The fitted pipeline is still in memory; no need to reload the
        # joblib bundle from disk.
        preds = final_model.predict(Gnew)

    pred_df = pd.DataFrame({"pred_"+args.target: preds})
    pred_path = Path(args.out_dir) / "predictions.csv"
    pred_df.to_csv(pred_path, index=False)
    log("artifact", path=str(pred_path), kind="predictions")

# ---------- Save summary ----------
def _summary_metric(name):
    """Return the module-level metric *name* as a JSON-safe float, or None.

    None is returned when the metric was never computed, was explicitly set
    to None, or is NaN/Inf (neither of which is valid JSON).
    """
    value = globals().get(name)
    if value is None:
        return None
    value = float(value)
    return value if np.isfinite(value) else None


# For classification runs "cv_corr" carries the accuracy, and cv_r2/cv_rmse
# are null.
summary = {
    "task": task,
    "model": args.model,
    "target": args.target,
    "cv_folds": args.cv,
    "validation_mode": args.validation_mode,
    "cv_corr": _summary_metric("mean_corr"),
    "cv_r2": _summary_metric("mean_r2"),
    "cv_rmse": _summary_metric("mean_rmse"),
}

summary_path = model_dir / "summary.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=4)

log("artifact", path=str(summary_path), kind="summary")



# You can run this from the command line without the UI:
#python3 ml_tabular.py \
#  --genotype G.csv --phenotype P.csv \
#  --id-col SampleID --target EarHT \
#  --model ridge --standardize --cv 5 \
#  --out-dir ./runs
