#
#  ml_predict_only.py
#  AI4EVER
#
#  Created by Meijing Liang on 12/1/25.
#


#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import json
import sys
from pathlib import Path
from datetime import datetime

import joblib
import numpy as np
import pandas as pd


def log(event: str, **kwargs):
    """Write one JSON record describing *event* to stdout.

    The record is a single-line JSON object whose first key is ``"event"``,
    followed by every keyword argument in the order given. Stdout is flushed
    after each record so a reader on the other end of a pipe sees it at once.
    """
    payload = {"event": event, **kwargs}
    line = json.dumps(payload)
    print(line, flush=True)


def main():
    """Run a saved model bundle on a new genotype file, without retraining.

    Command-line arguments:
        --model-bundle: path to the joblib bundle to load (required).
        --predict-file: delimited genotype file to predict on (required).
        --out-dir: root directory for run outputs (default ``./runs``).

    The bundle must provide ``model`` and ``feature_cols``; ``scaler``,
    ``task``, ``id_col`` and ``target`` are optional and fall back to
    ``None``, ``"reg"``, ``"Taxa"`` and ``"trait"`` respectively.

    Outputs are written to ``<out-dir>/predict_<target>_<timestamp>/``:
    ``predictions.csv``, ``viz_predict.csv`` (best effort) and
    ``summary.json``. Progress, warnings, errors and artifact paths are
    reported through ``log`` as JSON lines on stdout.

    Raises:
        SystemExit: with status 1 after logging an ``error`` event when the
            bundle cannot be loaded or is incomplete, the predict file cannot
            be read, or scaling/prediction fails.
    """
    parser = argparse.ArgumentParser(description="AI4EVER predict-only (no retraining)")

    parser.add_argument(
        "--model-bundle",
        required=True,
        help="Path to model.joblib saved by ml_tabular.py",
    )
    parser.add_argument(
        "--predict-file",
        required=True,
        help="Genotype file for prediction (same markers as training)",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        default="./runs",
        help="Output directory (same root as training runs)",
    )

    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Load bundle
    # ------------------------------------------------------------------
    log("status", msg="predict_only_start", model_bundle=args.model_bundle)

    # NOTE(review): joblib.load unpickles arbitrary objects, so only bundles
    # from a trusted source should ever be passed to --model-bundle.
    try:
        bundle = joblib.load(args.model_bundle)
    except Exception as e:
        log("error", msg=f"failed_to_load_model_bundle: {e}")
        sys.exit(1)

    # Everything below reads the bundle through .get(); reject anything that
    # cannot be queried that way instead of failing with an AttributeError.
    if not callable(getattr(bundle, "get", None)):
        log(
            "error",
            msg=f"model bundle is not a mapping (got {type(bundle).__name__})",
        )
        sys.exit(1)

    model = bundle.get("model", None)
    scaler = bundle.get("scaler", None)
    task = bundle.get("task", "reg")
    id_col = bundle.get("id_col", "Taxa")
    target = bundle.get("target", "trait")
    feature_cols = bundle.get("feature_cols", None)
    # NOTE(review): the bundle may also carry "class_mapping"; it is not
    # applied here, so predictions are written exactly as the model returns
    # them. Confirm whether classification runs need labels mapped back.

    if model is None or feature_cols is None:
        log("error", msg="model bundle missing 'model' or 'feature_cols'")
        sys.exit(1)

    feature_cols = list(feature_cols)

    # ------------------------------------------------------------------
    # Load predict genotype file
    # ------------------------------------------------------------------
    log("status", msg="loading_predict_file", file=args.predict_file)
    try:
        # sep=None with the python engine lets pandas sniff the delimiter.
        Gnew = pd.read_csv(args.predict_file, sep=None, engine="python")
    except Exception as e:
        log("error", msg=f"failed_to_load_predict_file: {e}")
        sys.exit(1)

    # Align markers with training features
    missing = [c for c in feature_cols if c not in Gnew.columns]
    if missing:
        log(
            "warn",
            msg=f"{len(missing)} markers missing in predict file; they will be filled with 0",
        )

    # Select and order columns to match training in a single pass: absent
    # markers are created filled with 0, extra columns are dropped, and the
    # ID column (when the file has one) is kept in front.
    if id_col in Gnew.columns:
        wanted_cols = [id_col] + feature_cols
    else:
        wanted_cols = feature_cols
    Gnew = Gnew.reindex(columns=wanted_cols, fill_value=0)

    # Non-numeric or empty genotype cells become NaN and are then set to 0.
    Gnew_num = Gnew.drop(columns=[id_col], errors="ignore")
    Gnew_num = Gnew_num.apply(pd.to_numeric, errors="coerce").fillna(0)

    # ------------------------------------------------------------------
    # Scale and predict
    # ------------------------------------------------------------------
    try:
        if scaler is not None:
            X_new = scaler.transform(Gnew_num.values)
        else:
            X_new = Gnew_num.values
    except Exception as e:
        log("error", msg=f"prediction_failed: {e}")
        sys.exit(1)

    log("status", msg="running_prediction", n_samples=len(Gnew_num))
    try:
        preds = np.asarray(model.predict(X_new)).ravel()
    except Exception as e:
        log("error", msg=f"prediction_failed: {e}")
        sys.exit(1)

    # ------------------------------------------------------------------
    # Create run directory
    # ------------------------------------------------------------------
    # The timestamp has minute resolution and the directory may already
    # exist, so two runs for the same target within one minute share a
    # directory and the later run overwrites the earlier files.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M")
    run_dir = out_dir / f"predict_{target}_{timestamp}"
    run_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------
    # Save predictions.csv
    # ------------------------------------------------------------------
    pred_col = f"pred_{target}"
    pred_df = pd.DataFrame({pred_col: preds})
    pred_path = run_dir / "predictions.csv"
    pred_df.to_csv(pred_path, index=False)
    log("artifact", path=str(pred_path), kind="predictions")

    # ------------------------------------------------------------------
    # Save viz_predict.csv (ID + predicted)
    # ------------------------------------------------------------------
    # Best effort: a failure here is reported as a warning and does not
    # stop the run, since predictions.csv has already been written.
    try:
        if id_col in Gnew.columns:
            viz_pred_df = pd.DataFrame({id_col: Gnew[id_col].values})
        else:
            # No ID column in the input: label rows sample_1, sample_2, ...
            viz_pred_df = pd.DataFrame(
                {"ID": [f"sample_{i+1}" for i in range(len(preds))]}
            )

        viz_pred_df["predicted"] = preds
        viz_pred_path = run_dir / "viz_predict.csv"
        viz_pred_df.to_csv(viz_pred_path, index=False)
        log("artifact", path=str(viz_pred_path), kind="viz_predict")
    except Exception as e:
        log("warn", msg=f"viz_predict_failed: {e}")

    # ------------------------------------------------------------------
    # Save a small summary (optional)
    # ------------------------------------------------------------------
    summary = {
        "mode": "predict_only",
        "task": task,
        "target": target,
        "model_bundle": args.model_bundle,
        "run_dir": str(run_dir),
        "n_samples": int(len(Gnew_num)),
    }
    summary_path = run_dir / "summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=4)
    log("artifact", path=str(summary_path), kind="summary")

    log("done", msg="predict_only_complete")


# Run the command-line entry point only when this file is executed as a
# script; importing it as a module does not trigger a prediction run.
if __name__ == "__main__":
    main()
