#include <QString>
#include <er_ECLES_least_squares.h>

#include <iostream>
using namespace std;

#define DEBUG_LS TRUE



/** *****************************************************************************************************************/
void determineLSSystem(int n, int r, fmpz_mat_t M, fmpz_mat_t U, int Pc[], int fac[])
{
    double f;
    fmpz_t aux, aux_f;
    fmpz_mat_t W, U_pc, U_t, U_pci, Ut_pci;
    int rn = r+n;
    int Pc_inv[n];

    // inicializa fmpz_t
    fmpz_init(aux);
    fmpz_init(aux_f);

    // cria matriz M
    fmpz_mat_init(M, rn, rn);
    fmpz_mat_init(U_t, n, r);

    // cria matriz de pesos e inicializa com uma matriz identidade    
    determineWeightMatrix(n, W);

    // preenche o primeiro bloco nxn da matriz M (I)
    for(int i=0; i<n; i++){
        for(int j=0; j<n; j++){
            fmpz_set(fmpz_mat_entry(M, i, j), fmpz_mat_entry(W, i, j));
        }
    }

    if(r > 0){

        // calcula Pc * U;
        permuteColumns(r, n, Pc, U, U_pc);

        // calcula U_transposta
        fmpz_mat_transpose(U_t, U_pc);

        // preenche o segundo bloco nxm da matriz M (U transposta) e o terceiro bloco mxn (U)
        for(int i=0; i<n; i++){
            for(int j=n; j<rn; j++){

                f = 1; //fac[j-n];
                fmpz_set_d(aux_f, f);

                fmpz_divexact(aux, fmpz_mat_entry(U_pc, j-n, i), aux_f);
                fmpz_set(fmpz_mat_entry(M, j, i), aux); // bloco mxn (U)

                fmpz_divexact(aux, fmpz_mat_entry(U_t, i, j-n), aux_f);
                fmpz_set(fmpz_mat_entry(M, i, j), aux); // bloco nxm (U transposta)
            }
        }
    }


    if(DEBUG_LS){
        cout << "\n\n:::::: DETERMINE_LSS ======================================" << endl;
        flint_printf("\nLSM = \n"); fmpz_mat_print_pretty(M);
    }

}

/** *******************************************************************************************************************/
void determineWeightMatrix(int n, fmpz_mat_t W)
{
    fmpz_mat_init(W, n, n);
    fmpz_mat_one(W);
    fmpz_mat_scalar_mul_si(W, W, 2);

    if(DEBUG_LS){
        cout << "\n\n:::::: DETERMINE_WEIGHT_MATRIX ======================================" << endl;
        flint_printf("\n\nW = \n"); fmpz_mat_print_pretty(W);
    }
}


/** *******************************************************************************************************************/
void determineLSMatrix(int m, int r, int n, fmpz_mat_t W, fmpz_mat_t L, fmpz_mat_t DG, fmpz_mat_t U, int Pc[], fmpq_mat_t M1, fmpq_mat_t M2, fmpq_mat_t M3, fmpq_mat_t M4)
{
    fmpz_mat_t aux_22, M22, M2_aux;
    fmpq_mat_t DG_aux, DG_inv, aux_1, aux_2, aux_3, M1_inv, L_aux, U_aux;

    // define matriz M1 = 2 * W (quadrante superior esquerdo)
    fmpq_mat_init(M1, n, n);
    fmpq_mat_set_fmpz_mat(M1, W);

    // calcula inversa da Diagonal
    fmpq_mat_init(DG_aux, r, r);
    fmpq_mat_set_fmpz_mat(DG_aux, DG);
    fmpq_mat_init(DG_inv, r, r);

    // calculando M2 = L * D-1 * U * Pc
    if(fmpq_mat_inv(DG_inv, DG_aux) != 0){

        fmpq_mat_init(L_aux, m, r);
        fmpq_mat_set_fmpz_mat(L_aux, L);

        // L * D-1
        fmpq_mat_init(aux_1, m, r);
        fmpq_mat_mul(aux_1, L_aux, DG_inv);

        fmpq_mat_init(U_aux, r, n);
        fmpq_mat_set_fmpz_mat(U_aux, U);

        // L * D-1 * U
        fmpq_mat_init(aux_2, m, n);
        fmpq_mat_mul(aux_2, aux_1, U_aux);

        fmpz_mat_init(aux_22, m, n);
        fmpq_mat_get_fmpz_mat(aux_22, aux_2);

        // L * D-1 * U * Pc
        fmpz_mat_init(M22, m, n);
        permuteColumns(m, n, Pc, aux_22, M22);

        // M2 são as primeiras r linhas da matrix M22
        fmpz_mat_init(M2_aux, r, n);
        for(int i=0; i<r; i++){
            for(int j=0; j<n; j++){
                fmpz_set(fmpz_mat_entry(M2_aux, i, j), fmpz_mat_entry(M22, i, j));
            }
        }

        fmpq_mat_init(M2, r, n);
        fmpq_mat_set_fmpz_mat(M2, M2_aux);

    } else{
        cout << "\n\n >>> Determinando matriz diagonal :: DG não tem inversa! " << endl;
    }


    // define matriz M3 = M2_transposta(quadrante superior direito)
    fmpq_mat_init(M3, n, r);
    fmpq_mat_transpose(M3, M2);

    // define matriz M4 = M2 * M1-1 * M3
    fmpq_mat_init(M1_inv, n, n);
    if(fmpq_mat_inv(M1_inv, M1) != 0){

        // M2 * M1-1
        fmpq_mat_init(aux_3, r, n);
        fmpq_mat_mul(aux_3, M2, M1_inv);

        // M2 * M1-1 * M3
        fmpq_mat_init(M4, r, r);
        fmpq_mat_mul(M4, aux_3, M3);

    }else{
        cout << "\n\n >>> Determinando matriz M1 :: M1 não tem inversa! " << endl;
    }

    if(DEBUG_LS){
        cout << "\n\n:::::: DETERMINE_LSM ======================================" << endl;
        flint_printf("\nM1 = \n"); fmpq_mat_print(M1);
        flint_printf("\nM1_inv = \n"); fmpq_mat_print(M1_inv);
        flint_printf("\n\nM2 = \n"); fmpq_mat_print(M2);
        flint_printf("\n\nM3 = \n"); fmpq_mat_print(M3);
        flint_printf("\n\nM4 = \n"); fmpq_mat_print(M4);
    }

}


/** *******************************************************************************************************************/
void solveLSSystem(int r, int n, fmpq_mat_t M1, fmpq_mat_t M2, fmpq_mat_t M3, fmpq_mat_t M4, fmpz_mat_t B_pr, int p, fmpq_mat_t P1, fmpq_mat_t P2)
{
    fmpq_mat_t G;

    // inicializa fmpz_t
    fmpq_mat_init(G, r, p);

    // resolve primeira parte do sistema
    solveFirstPart(r, M2, M4, B_pr, p, P1, G);

    // resolve segunda parte do sistema    
    solveSecondPart(n, M1, M3, P1, G, p, P2);

}


/** *******************************************************************************************************************/
bool verifySolution(int m, int n, int p, fmpz_mat_t A, fmpz_mat_t B, fmpq_mat_t P2, double *Xd)
{
    fmpq_mat_t aux, Aaux, Baux;

    // converte matriz exata A para racional
    fmpq_mat_init(Aaux, m, n);
    fmpq_mat_set_fmpz_mat(Aaux, A);

    // multiplica A por P2
    fmpq_mat_init(aux, m, p);
    fmpq_mat_mul(aux, Aaux, P2);

    // converte matriz exata B para racional
    fmpq_mat_init(Baux, m, p);
    fmpq_mat_set_fmpz_mat(Baux, B);

    if(DEBUG_LS){
        cout << "\n\n:::::: VERIFICANDO SE A SOLUÇÃO SATIZFAZ O SISTEMA ======================================" << endl;
        flint_printf("\nA = \n"); fmpz_mat_print_pretty(A);
        flint_printf("\n\nP2 = \n"); fmpq_mat_print(P2);
        flint_printf("\n\nA * P2 = \n"); fmpq_mat_print(aux);
        flint_printf("\n\nB = \n"); fmpq_mat_print(Baux);
    }

    if(fmpq_mat_equal(aux, Baux)){
        cout << "\n\nSolução satisfaz o sistema!" << endl;
        convertFmpqToDoubleMatrix(n, p, P2, Xd);
        return true;
    }

    return false;
}


/** *******************************************************************************************************************/
void solveFirstPart(int r, fmpq_mat_t M2, fmpq_mat_t M4, fmpz_mat_t B_pr, int p, fmpq_mat_t P1, fmpq_mat_t G)
{
    fmpq_mat_t M4_inv, Baux, aux1, aux2;

    // multiplica M2 = L * Dg-1 * U * Pc por P1 = P'_d
    fmpq_mat_init(aux1, r, p);
    fmpq_mat_mul(aux1, M2, P1);

    // subtrai B = Pr-1 * b de aux1 = M2 * P'_d
    fmpq_mat_init(Baux, r, p);
    fmpq_mat_set_fmpz_mat(Baux, B_pr);

    fmpq_mat_init(aux2, r, p);
    fmpq_mat_sub(aux2, aux1, Baux);

    // calcula valor de gama
    fmpq_mat_init(M4_inv, r, r);
    if(fmpq_mat_inv(M4_inv, M4) != 0){
        fmpq_mat_init(G, r, p);
        fmpq_mat_mul(G, M4_inv, aux2);
    }else{
        cout << "\n\n >>> Resolve parte 1 do sistema LS :: M4 não tem inversa! " << endl;
    }

    if(DEBUG_LS){
        cout << "\n\n:::::: SOLVE_FIRST_PART_SYSTEM ======================================" << endl;
        flint_printf("\nP'_d = \n"); fmpq_mat_print(P1);
        flint_printf("\nM2 * P'_d = \n"); fmpq_mat_print(aux1);
        flint_printf("\n\nM2 * P'd - B_pr = \n"); fmpq_mat_print(aux2);
        flint_printf("\n\nM4_(-1) = \n"); fmpq_mat_print(M4_inv);
        flint_printf("\n\nGama = \n"); fmpq_mat_print(G);
    }
}


/** *******************************************************************************************************************/
void solveSecondPart(int n, fmpq_mat_t M1, fmpq_mat_t M3, fmpq_mat_t P1, fmpq_mat_t G, int p, fmpq_mat_t P2)
{
    fmpq_mat_t aux1, aux2, aux3, M1_inv;

    flint_printf("\n\M3= \n"); fmpq_mat_print(M3);

    // multiplica M1 por P1 = P'_d
    fmpq_mat_init(aux1, n, p);
    fmpq_mat_mul(aux1, M1, P1);

    // multiplica M3 por Gama(G)
    fmpq_mat_init(aux2, n, p);
    fmpq_mat_mul(aux2, M3, G);

    // subtrai aux2 = M3 * G de aux1 = M1 * P'_d
    fmpq_mat_init(aux3, n, p);
    fmpq_mat_sub(aux3, aux1, aux2);

    // calcula valor de P2 = P''_d
    fmpq_mat_init(M1_inv, n, n);
    if(fmpq_mat_inv(M1_inv, M1) != 0){
        fmpq_mat_init(P2, n, p);
        fmpq_mat_mul(P2, M1_inv, aux3);

    }else{
        cout << "\n\n >>> Resolve parte 2 do sistema LS :: M1 não tem inversa! " << endl;
    }

    if(DEBUG_LS){
        cout << "\n\n:::::: SOLVE_SECOND_PART_SYSTEM ======================================" << endl;
        flint_printf("\nP'_d = \n"); fmpq_mat_print(P1);
        flint_printf("\nM1 * P'_d = \n"); fmpq_mat_print(aux1);
        flint_printf("\n\nM3 * G = \n"); fmpq_mat_print(aux2);
        flint_printf("\n\nM1 * P'_d - M3 * G = \n"); fmpq_mat_print(aux3);
        flint_printf("\n\nP2 = \n"); fmpq_mat_print(P2);
    }
}
