/*******************************************************************************
* Copyright 2024 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*   Content : Intel(R) oneAPI Math Kernel Library (oneMKL) Sparse BLAS C OpenMP
*             offload example for mkl_sparse_s_trsm
*
********************************************************************************
*
* Consider the matrix A (see 'Sparse Storage Formats for Sparse BLAS Level 2
* and Level 3 in the  Intel oneMKL Reference Manual')
*
*                 |   1       -1      0   -3     0   |
*                 |  -2        5      0    0     0   |
*   A    =        |   0        0      4    6     4   |,
*                 |  -4        0      2    7     0   |
*                 |   0        8      0    0    -5   |
*
*  The matrix A is represented in a zero-based compressed sparse row (CSR) storage
*  scheme with three arrays (see 'Sparse Matrix Storage Schemes' in the
*  Intel oneMKL Reference Manual) as follows:
*
*         values  =  ( 1 -1 -3 -2  5  4  6  4 -4  2  7  8 -5 )
*         columns =  ( 0  1  3  0  1  2  3  4  0  2  3  1  4 )
*         rowIndex = ( 0        3     5        8       11    13 )
*
*  The test performs the operation
*
*       Y = alpha * U^{-1} * X, using mkl_sparse_s_trsm
*
*  Here A is a general sparse matrix
*        U is the upper triangular part of matrix A with unit diagonal
*        X and Y are dense matrices (the offload results are stored in Z1 and Z2)
*
********************************************************************************
*/
#include <stdio.h>
#include <assert.h>
#include <math.h>

#include "common_for_sparse_examples.h"
#include "mkl.h"
#include "mkl_omp_offload.h"

/*
 * Print an nRows x nCols dense single-precision matrix to stdout, one matrix
 * row per output line, each entry formatted as "%7.1f, ".
 *
 *   nRows, nCols - logical dimensions of the matrix to print
 *   ld           - leading dimension: distance between consecutive rows for
 *                  SPARSE_LAYOUT_ROW_MAJOR, between consecutive columns for
 *                  any other layout value (treated as column major)
 *   layout       - storage layout of mat
 *   mat          - pointer to the matrix data
 */
static void print_dense_matrix(const MKL_INT nRows,
                               const MKL_INT nCols,
                               const MKL_INT ld,
                               sparse_layout_t layout,
                               const float *mat)
{
    // Element (row, col) lives at row * row_stride + col * col_stride.
    MKL_INT row_stride = 1;
    MKL_INT col_stride = ld;

    if (layout == SPARSE_LAYOUT_ROW_MAJOR) {
        row_stride = ld;
        col_stride = 1;
    }

    for (MKL_INT row = 0; row < nRows; ++row) {
        const float *row_start = mat + row * row_stride;

        for (MKL_INT col = 0; col < nCols; ++col) {
            printf("%7.1f, ", row_start[col * col_stride]);
        }
        printf("\n");
    }
}

int main() {
    //*******************************************************************************
    //     Declaration and initialization of parameters for sparse representation of
    //     the matrix A in the compressed sparse row format:
    //*******************************************************************************
#define M 5     // nRows of op(A) == nRows of Y
#define K 5     // nCols of op(A) == nRows of X
#define NNZ 13
#define NRHS 3  // nCols of X and Y

    // Descriptor of main sparse matrix properties
    struct matrix_descr descrA;

    // // Structure with sparse matrix stored in CSR format
    sparse_matrix_t csrA;
    //*******************************************************************************
    //    Sparse representation of the matrix A
    //*******************************************************************************
    float alpha = 2.0;
    MKL_INT i = 0;

    sparse_layout_t layout = SPARSE_LAYOUT_COLUMN_MAJOR;

    const MKL_INT ldx = K;
    const MKL_INT ldy = M;
    const MKL_INT x_size = ldx * NRHS;
    const MKL_INT y_size = ldy * NRHS;

    float *values      = (float *)mkl_malloc(sizeof(float) * NNZ, 64);
    MKL_INT *columns   = (MKL_INT *)mkl_malloc(sizeof(MKL_INT) * NNZ, 64);
    MKL_INT *row_index = (MKL_INT *)mkl_malloc(sizeof(MKL_INT) * (M + 1), 64);
    float *X           = (float *)mkl_malloc(sizeof(float) * x_size, 64);
    float *W           = (float *)mkl_malloc(sizeof(float) * x_size, 64);

    float *Y  = (float *)mkl_malloc(sizeof(float) * y_size, 64);
    float *Z1 = (float *)mkl_malloc(sizeof(float) * y_size, 64);
    float *Z2 = (float *)mkl_malloc(sizeof(float) * y_size, 64);

    const int num_pointers = 6;
    void *pointer_array[num_pointers];
    pointer_array[0] = values;
    pointer_array[1] = columns;
    pointer_array[2] = row_index;
    pointer_array[3] = X;
    pointer_array[4] = W;
    pointer_array[5] = Y;
    pointer_array[6] = Z1;
    pointer_array[7] = Z2;

    if (!values || !columns || !row_index || !X || !W || !Y || !Z1 || !Z2) {
        free_allocated_memories(pointer_array, num_pointers);
        return 1;
    }

    //*******************************************************************************
    //    Sparse representation of the matrix A
    //*******************************************************************************
    float init_values[NNZ] = {1.0f, -1.0f, -3.0f, -2.0f, 5.0f, 4.0f, 6.0f,
                              4.0f, -4.0f, 2.0f,  7.0f,  8.0f, -5.0f};
    MKL_INT init_columns[NNZ]     = {0, 1, 3, 0, 1, 2, 3, 4, 0, 2, 3, 1, 4};
    MKL_INT init_row_index[M + 1] = {0, 3, 5, 8, 11, 13};
    float init_x[K * NRHS]       = {1.0f, 5.0f, 1.0f, 4.0f, 1.0f,   // column-major layout.
                                    2.0f, 10.0f, 2.0f, 8.0f, 2.0f,
                                    3.0f, 15.0f, 3.0f, 12.0f, 3.0f};

    for (i = 0; i < NNZ; i++) {
        values[i]  = init_values[i];
        columns[i] = init_columns[i];
    }
    for (i = 0; i < M + 1; i++) {
        row_index[i] = init_row_index[i];
    }

    for (i   = 0; i < x_size; i++) {
        X[i] = init_x[i];
        W[i] = init_x[i];
    }

    for (i = 0; i < y_size; i++) {
        Y[i]  = 0.0f;
        Z1[i] = 0.0f;
        Z2[i] = 0.0f;
    }

    printf("\n EXAMPLE PROGRAM FOR mkl_sparse_s_trsm omp_offload \n");
    printf("-----------------------------------------------------\n");
    printf("\n" );
    printf("   INPUT DATA FOR mkl_sparse_s_trsm omp offload   \n");
    printf("   SPARSE_LAYOUT_COLUMN_MAJOR \n");
    printf("   UPPER TRIANGULAR SPARSE MATRIX \n");
    printf("   WITH UNIT DIAGONAL \n");
    printf("   ALPHA = %4.1f \n", alpha);
    printf("   SPARSE_OPERATION_NON_TRANSPOSE \n");
    printf("   Input matrix                   \n");
    print_dense_matrix(K, NRHS, ldx, layout, X);

    // Create matrix descriptor
    descrA.type = SPARSE_MATRIX_TYPE_TRIANGULAR;
    descrA.mode = SPARSE_FILL_MODE_UPPER;
    descrA.diag = SPARSE_DIAG_UNIT;

    sparse_status_t ie_status;

    // Create handle with matrix stored in CSR format
    ie_status = mkl_sparse_s_create_csr(&csrA, SPARSE_INDEX_BASE_ZERO,
                                        M, // number of rows
                                        K, // number of cols
                                        row_index, row_index + 1, columns, values);
    if (ie_status != SPARSE_STATUS_SUCCESS) {
        printf(" Error in mkl_sparse_s_create_csr: %d\n", ie_status);
        free_allocated_memories(pointer_array, num_pointers);
        return ie_status;
    }

    ie_status = mkl_sparse_set_sm_hint(csrA, SPARSE_OPERATION_NON_TRANSPOSE, descrA, layout, NRHS, 1);
    if (ie_status != SPARSE_STATUS_SUCCESS && ie_status != SPARSE_STATUS_NOT_SUPPORTED) {
        printf(" Error in set_sm_hint: %d\n", ie_status);
        free_allocated_memories(pointer_array, num_pointers);
        return ie_status;
    }

    ie_status = mkl_sparse_optimize(csrA);
    if (ie_status != SPARSE_STATUS_SUCCESS) {
        printf(" Error in mkl_sparse_optimize: %d\n", ie_status);
        free_allocated_memories(pointer_array, num_pointers);
        return ie_status;
    }

    // Compute Y = alpha * A^{-1} * X
    ie_status = mkl_sparse_s_trsm(SPARSE_OPERATION_NON_TRANSPOSE, alpha, csrA, descrA, layout, X,
                                  NRHS, ldx, Y, ldy);
    if (ie_status != SPARSE_STATUS_SUCCESS) {
        printf(" Error in mkl_sparse_s_trsm: %d\n", ie_status);
        free_allocated_memories(pointer_array, num_pointers);
        return ie_status;
    }

    // Release matrix handle and deallocate matrix
    ie_status = mkl_sparse_destroy(csrA);
    if (ie_status != SPARSE_STATUS_SUCCESS) {
        printf(" Error in mkl_sparse_destroy: %d\n", ie_status);
        free_allocated_memories(pointer_array, num_pointers);
        return ie_status;
    }

    printf("\n");
    printf("   OUTPUT DATA FOR mkl_sparse_s_trsm \n");

    // Y should be equal { 36.0, 72.0, 108.0,
    //                     10.0, 20.0, 30.0,
    //                     -54.0, -108.0, -162.0,
    //                     8.0, 16.0, 24.0,
    //                     2.0, 4.0, 6.0 }
    print_dense_matrix(M, NRHS, ldy, layout, Y);
    printf("---------------------------------------------------\n");
    fflush(stdout);

    const int devNum = 0;

    sparse_matrix_t csrA_gpu;

    sparse_status_t status_create;
    sparse_status_t status_hint;
    sparse_status_t status_opt;
    sparse_status_t status_trsm1;
    sparse_status_t status_trsm2;
    sparse_status_t status_destroy;

// call create_csr/set_sm_hint/optimize/trsm/destroy via omp_offload.
#pragma omp target data map(to:row_index[0:M+1],columns[0:NNZ],values[0:NNZ],X[0:x_size],W[0:x_size]) \
                        map(tofrom:Z1[0:y_size],Z2[0:y_size]) device(devNum)
    {
        printf("Create CSR matrix via omp_offload\n"); fflush(0);

#pragma omp dispatch device(devNum)
        status_create = mkl_sparse_s_create_csr(&csrA_gpu, SPARSE_INDEX_BASE_ZERO, M, K,
                                                row_index, row_index + 1, columns, values);

        printf("mkl_sparse_set_sm_hint() ... \n"); fflush(0);
#pragma omp dispatch device(devNum) nowait
        status_hint = mkl_sparse_set_sm_hint(csrA_gpu, SPARSE_OPERATION_NON_TRANSPOSE, descrA, layout, NRHS, 1);
#pragma omp taskwait

        printf("mkl_sparse_optimize() ... \n"); fflush(0);
#pragma omp dispatch device(devNum) nowait
        status_opt = mkl_sparse_optimize(csrA_gpu);
#pragma omp taskwait

        printf("Compute mkl_sparse_s_trsm via omp_offload\n"); fflush(0);

#pragma omp dispatch device(devNum) nowait
        status_trsm1 = mkl_sparse_s_trsm(SPARSE_OPERATION_NON_TRANSPOSE, alpha, csrA_gpu, descrA, layout, X,
                                        NRHS, ldx, Z1, ldy);

#pragma omp dispatch device(devNum) nowait
        status_trsm2 = mkl_sparse_s_trsm(SPARSE_OPERATION_NON_TRANSPOSE, alpha, csrA_gpu, descrA, layout, W,
                                        NRHS, ldx, Z2, ldy);
#pragma omp taskwait
        printf("Destroy the CSR matrix via omp_offload\n"); fflush(0);

#pragma omp dispatch device(devNum)
        status_destroy = mkl_sparse_destroy(csrA_gpu);

        printf("End of mkl_sparse_destroy omp offload call.\n");
    }
    printf("End all offload function calls.\n"); fflush(stdout);

    int flps_per_value = 2 * ((NNZ / M) + 1);
    int status1        = 0;
    int status2        = 0;

    int status_offload = status_create | status_hint | status_opt | status_trsm1 | status_trsm2 | status_destroy;
    if (status_offload != 0) {
        printf("\tERROR: status_create = %d, status_hint = %d, status_opt = %d, "
               "status_trsm1 = %d, status_trsm2 = %d, status_destroy = %d\n",
               status_create, status_hint, status_opt, status_trsm1, status_trsm2, status_destroy);
        fflush(stdout);

        goto cleanup;
    }

    printf("   OUTPUT DATA FOR mkl_sparse_s_trsm_omp_offload \n");
    // Z1 & Z2 should be equal { 36.0, 72.0, 108.0,
    //                           10.0, 20.0, 30.0,
    //                           -54.0, -108.0, -162.0,
    //                           8.0, 16.0, 24.0,
    //                           2.0, 4.0, 6.0 }
    print_dense_matrix(M, NRHS, ldy, layout, Z1);
    printf("---------------------------------------------------\n");
    fflush(stdout);

    status1 = validation_result_float(Y, Z1, y_size, flps_per_value);
    status2 = validation_result_float(Y, Z2, y_size, flps_per_value);

cleanup:
    free_allocated_memories(pointer_array, num_pointers);

    const int status_all = status1 | status2 | status_offload;
    printf("Test %s\n", status_all == 0 ? "PASSED" : "FAILED");
    fflush(stdout);

    return status_all;
}
