#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Detección de fallos con Autoencoder distribuido usando BICs
- Calcula T2/Q por grupo
- Combina con calcula_BICs(...)
- Calcula falsas alarmas, alarmas detectadas y tiempo de detección (10 consecutivas)
- Guarda resultados en Excel
@author: (generado)
"""
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from keras.models import load_model
import pickle
from bic2 import calcula_BICs

# -------------------------
# Ajusta estas rutas si hace falta
# -------------------------
output_dir = '/Users/pablo/Documents/IEIA/TFG/Resultados/D2_LSTM'
os.makedirs(output_dir, exist_ok=True)
graficas_dir = '/Users/pablo/Documents/IEIA/TFG/Resultados/Graphics_BICs_D2'
data_path = '/Users/pablo/Documents/IEIA/TFG/_Autoencoders/Distributed2_Data.pkl'
model_dir = '/Users/pablo/Documents/IEIA/TFG/MODELOS/Distributed2_Models'
archivo_npz_path = '/Users/pablo/Documents/IEIA/TFG/_Autoencoders/AutoencoderLSTM_Data.npz'
# -------------------------

# Cargar datos precomputados por grupos
with open(data_path, 'rb') as f:
    results = pickle.load(f)


bic_data = results.get('BICs', None)
if bic_data is None:
    raise ValueError("No se encontró la clave 'BICs' en el archivo .pkl")

# Umbrales BIC precalculados MODELO SIN FALLOS
uBIC_T2 = bic_data['umbral_bic_t2']
uBIC_Q  = bic_data['umbral_bic_q']


grupo_nombres = sorted(
    [k for k in results.keys() if k.startswith("X") and k.endswith("n")],
    key=lambda x: int(x[1:-1])
)
n_grupos = len(grupo_nombres)
print("Grupos detectados:", grupo_nombres)

# cargar min/max globales usadas para normalizar
archivo_npz = np.load(archivo_npz_path)
Xmin = archivo_npz['Xmin']
Xmax = archivo_npz['Xmax']

# Matriz que indica si cada grupo detecta cada fallo (SI/NO)
matriz_detect = pd.DataFrame("NO",
    index=[f"d{str(i).zfill(2)}" for i in range(1,22)],
    columns=grupo_nombres)

# Resultados por fallo: DataFrame con las 6 métricas requeridas
cols = ['Falsas_T2 (%)', 'Detectadas_T2 (%)', 'TiempoDet_T2',
        'Falsas_Q (%)', 'Detectadas_Q (%)', 'TiempoDet_Q']
df_stats = pd.DataFrame(index=[f"d{str(i).zfill(2)}" for i in range(1,22)], columns=cols)

# crear secuencias para LSTM
def create_sequences(X, time_steps):
    return np.array([X[i:i + time_steps] for i in range(len(X) - time_steps)])

# Parámetro para la regla "10 consecutivas"
N_CONSEC = 10
# Número de muestras en la fase conocida como "normal" (tal como usas): 160
NORMAL_WINDOW = 160

# Recorremos los 21 fallos
for idx_fallo in range(1,22):
    tag = f"d{str(idx_fallo).zfill(2)}"
    print("\n" + "="*60)
    print("Procesando fallo:", tag)
    print("="*60)

    # cargar fichero de fallo
    X_test_path = f"/Users/pablo/Documents/IEIA/TFG/datos_csv/{tag}_te.csv"
    X_test = pd.read_csv(X_test_path)

    # normalizar con los min/max guardados
    Xn_test = (X_test - Xmin) / (Xmax - Xmin)
    Xn_test = pd.DataFrame(Xn_test)

    # listas para apilar T2/Q por grupo
    T2_all = []
    Q_all = []
    umbral_T2_list = []
    umbral_Q_list = []

    # procesar cada grupo
    for nombre in grupo_nombres:
        group_data = results[nombre]
        umbral_T2 = group_data['uT2']
        umbral_Q = group_data['uQ']
        hm = group_data['hm']
        hdesv = group_data['hdesv']
        time_steps = int(group_data['time_steps'])
        columnas = group_data['columnas']

        # extraer las columnas pertenecientes al grupo
        X_test_group = Xn_test.iloc[:, columnas]
        if len(X_test_group) <= time_steps:
            # no hay suficientes muestras para crear secuencias (salvar con arrays vacíos)
            Xn_seq = np.empty((0, time_steps, X_test_group.shape[1]))
        else:
            Xn_seq = create_sequences(X_test_group.values, time_steps)

        # cargar modelos del grupo
        autoencoder = load_model(os.path.join(model_dir, f"{nombre}_autoencoder.keras"))
        encoder = load_model(os.path.join(model_dir, f"{nombre}_encoder.keras"))

        # predicción y representación latente
        if Xn_seq.size == 0:
            # no hay datos: vectores vacíos
            T2 = np.array([])
            Q = np.array([])
        else:
            X_pred = autoencoder.predict(Xn_seq)
            h = encoder.predict(Xn_seq)

            # T2 (en espacio latente)
            covin = np.linalg.inv(hdesv)
            T2 = np.array([np.dot(np.dot((h[i] - hm), covin), (h[i] - hm).T) for i in range(h.shape[0])])

            # Q (residuo)
            res = Xn_seq - X_pred
            residuo = res.reshape(res.shape[0], -1)
            Q = np.array([np.dot(residuo[i], residuo[i].T) for i in range(len(residuo))])

        # almacenar
        T2_all.append(T2)
        Q_all.append(Q)
        umbral_T2_list.append(umbral_T2)
        umbral_Q_list.append(umbral_Q)

        # comprobar si este grupo por sí solo detecta el fallo (10 consecutivas)
        detect_T2_group = False
        detect_Q_group = False
        if T2.size > 0:
            idx_T2 = next((i for i in range(len(T2) - N_CONSEC + 1) if all(T2[i + j] > umbral_T2 for j in range(N_CONSEC))), None)
            if idx_T2 is not None:
                detect_T2_group = True
        if Q.size > 0:
            idx_Q = next((i for i in range(len(Q) - N_CONSEC + 1) if all(Q[i + j] > umbral_Q for j in range(N_CONSEC))), None)
            if idx_Q is not None:
                detect_Q_group = True

        if detect_T2_group or detect_Q_group:
            matriz_detect.loc[tag, nombre] = "SI"

    # Verificamos que todos los grupos devolvieron vectores de la misma longitud (necesario para apilar)
    # elegir la longitud mínima entre los T2s generados (para evitar problemas si algún grupo devolvió algo distinto)
    lengths_T2 = [len(v) for v in T2_all if v is not None]
    lengths_Q = [len(v) for v in Q_all if v is not None]
    if len(lengths_T2) == 0 or len(lengths_Q) == 0:
        print(f"⚠️ Fallo {tag}: no hay datos útiles en alguno de los grupos (longitudes T2/Q). Se saltará.")
        # rellenar NaNs en df_stats
        df_stats.loc[tag] = [np.nan]*6
        continue

    n_obs = min(lengths_T2 + lengths_Q)  # tomar la mínima longitud común
    # recortar cada vector al tamaño n_obs
    T2_aligned = [t2[:n_obs] if len(t2) >= n_obs else np.pad(t2, (0, n_obs-len(t2)), constant_values=np.nan) for t2 in T2_all]
    Q_aligned = [q[:n_obs] if len(q) >= n_obs else np.pad(q, (0, n_obs-len(q)), constant_values=np.nan) for q in Q_all]

    # apilar columnas: shape -> (n_obs, n_grupos)
    T2_grupos = np.column_stack(T2_aligned)
    Q_grupos = np.column_stack(Q_aligned)

    umbral_T2_grupos = np.array(umbral_T2_list)
    umbral_Q_grupos = np.array(umbral_Q_list)

    # calcular BICs 
    bic_res = calcula_BICs(T2_grupos, umbral_T2_grupos, Q_grupos, umbral_Q_grupos)
    # acepta tanto (bic_t2, bic_q, alpha) como (bic_t2, bic_q, uBIC_t2, uBIC_q)
    if len(bic_res) == 2:
        BIC_T2, BIC_Q = bic_res
        x = np.percentile(BIC_T2, 90)
        y = np.percentile(BIC_Q, 90)
    elif len(bic_res) == 3:
        BIC_T2, BIC_Q, _alpha = bic_res
        x = np.percentile(BIC_T2, 50)
        y = np.percentile(BIC_Q, 70)
    else:
        # si devuelve 4 items (por si lo modificaste antes)
        BIC_T2, BIC_Q, x, y = bic_res

    os.makedirs(graficas_dir, exist_ok=True)

    # -----------------------------
    # GUARDAR BIC_T2
    # -----------------------------
    plt.figure(figsize=(10,4))
    plt.plot(range(1, len(BIC_T2)+1), BIC_T2, label='BIC_T2')
    plt.axhline(y=uBIC_T2, color='red', linestyle='--', label='Umbral BIC_T2')
    plt.title(f"BIC_T2 - {tag}")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    plt.savefig(os.path.join(graficas_dir, f"{tag}_BIC_T2.png"), dpi=300)
    plt.close()

    # -----------------------------
    # GUARDAR BIC_Q
    # -----------------------------
    plt.figure(figsize=(10,4))
    plt.plot(range(1, len(BIC_Q)+1), BIC_Q, label='BIC_Q')
    plt.axhline(y=uBIC_Q, color='red', linestyle='--', label='Umbral BIC_Q')
    plt.title(f"BIC_Q - {tag}")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    plt.savefig(os.path.join(graficas_dir, f"{tag}_BIC_Q.png"), dpi=300)
    plt.close()

    # -----------------------------
    # Ahora aplicamos las mismas métricas de detección que con LSTM pero sobre BIC_T2 y BIC_Q
    # -----------------------------
    # Falsas alarmas: contar en primeras NORMAL_WINDOW muestras
    # Nota: si n_obs < NORMAL_WINDOW, adaptamos al número disponible
    reference_window = min(NORMAL_WINDOW, len(BIC_T2))

    # T2
    false_count_T2 = np.sum(BIC_T2[:reference_window] > uBIC_T2)
    falsas_T2_pct = (false_count_T2 / reference_window) * 100

    detect_count_T2 = np.sum(BIC_T2[reference_window:] > uBIC_T2) if reference_window < len(BIC_T2) else 0
    detect_pct_T2 = (detect_count_T2 / max(1, (len(BIC_T2)-reference_window))) * 100

    # tiempo de detección: primer bloque de N_CONSEC consecutivas por encima de umbral
    idx_det_T2 = next((i for i in range(len(BIC_T2) - N_CONSEC + 1) if all(BIC_T2[i + j] > uBIC_T2 for j in range(N_CONSEC))), None)

    # Q
    false_count_Q = np.sum(BIC_Q[:reference_window] > uBIC_Q)
    falsas_Q_pct = (false_count_Q / reference_window) * 100

    detect_count_Q = np.sum(BIC_Q[reference_window:] > uBIC_Q) if reference_window < len(BIC_Q) else 0
    detect_pct_Q = (detect_count_Q / max(1, (len(BIC_Q)-reference_window))) * 100

    idx_det_Q = next((i for i in range(len(BIC_Q) - N_CONSEC + 1) if all(BIC_Q[i + j] > uBIC_Q for j in range(N_CONSEC))), None)

    # escribir resultados en tabla
    df_stats.loc[tag, 'Falsas_T2 (%)'] = float(f'{falsas_T2_pct:.3f}')
    df_stats.loc[tag, 'Detectadas_T2 (%)'] = float(f'{detect_pct_T2:.3f}')
    df_stats.loc[tag, 'TiempoDet_T2'] = int(idx_det_T2) if idx_det_T2 is not None else np.nan

    df_stats.loc[tag, 'Falsas_Q (%)'] = float(f'{falsas_Q_pct:.3f}')
    df_stats.loc[tag, 'Detectadas_Q (%)'] = float(f'{detect_pct_Q:.3f}')
    df_stats.loc[tag, 'TiempoDet_Q'] = int(idx_det_Q) if idx_det_Q is not None else np.nan

    print(f"-> {tag} | Falsas T2: {falsas_T2_pct:.2f}%  Detectadas T2: {detect_pct_T2:.2f}%  TiempoDet_T2: {idx_det_T2}")
    print(f"-> {tag} | Falsas Q:  {falsas_Q_pct:.2f}%  Detectadas Q:  {detect_pct_Q:.2f}%  TiempoDet_Q:  {idx_det_Q}")

# Guardar resultados
excel_stats = os.path.join(output_dir, "Stats_BIC_Distribuido.xlsx")
df_stats.to_excel(excel_stats)
print("\n✅ Estadísticas guardadas en:", excel_stats)

excel_mat = os.path.join(output_dir, "MatrizDetect_Grupos.xlsx")
matriz_detect.to_excel(excel_mat)
print("✅ Matriz detección grupos guardada en:", excel_mat)




