/*******************************************************************************
* Copyright (C) 2024 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file BenchCustomKernels.cpp

 HPCG routine: performance benchmarks for the custom SYCL kernels
 */


#include <sycl/sycl.hpp>
#include <iomanip>
#include <fstream>
#include <iostream>
#include <vector>
#include "hpcg.hpp"
#include "WriteProblem.hpp"
#include "mytimer.hpp"

#include "BenchCustomKernels.hpp"
#include "kernels/axpby_kernel.hpp"

#ifndef HPCG_NO_MPI
#include "ExchangeHalo.hpp"
#include <mpi.h>
#include <cstdlib>
#endif

#include <cmath>  // for convert_to_gflops()
#include <cstdio> // for convert_to_gflops()

#include "CustomKernels.hpp"
#include "ComputeDotProduct.hpp"
#include "ComputeSPMV.hpp"
#include "ComputeSYMGS.hpp"
#include "PrefixSum.hpp"

#include "VeryBasicProfiler.hpp"

// Optional lightweight profiling hooks.
//
// With BASIC_PROFILING defined the macros forward to begin(n) / end(n) on
// `optData->profiler`, so a variable named `optData` must be in scope wherever they
// are used. Without BASIC_PROFILING they expand to nothing.
//
// Caution: the enabled expansions already end in ';', and END_PROFILE_WAIT expands to
// TWO statements (event.wait() followed by end(n)). Do not use these macros as the sole
// body of an unbraced if/else/loop.
#ifdef BASIC_PROFILING
#define BEGIN_PROFILE(n) optData->profiler->begin(n);
#define END_PROFILE(n) optData->profiler->end(n);
#define END_PROFILE_WAIT(n, event) event.wait(); optData->profiler->end(n);
#else
#define BEGIN_PROFILE(n)
#define END_PROFILE(n)
#define END_PROFILE_WAIT(n, event)
#endif


// =====================================================================
// ===========  Master Switches for Performance Test Runs  =============
// =====================================================================

// Number of launches per suite in BenchCustomKernels():
//   WARMUP_RUNS launches are issued first and are not timed,
//   TIMED_RUNS  launches are then timed together; the reported ave_time is the
//               elapsed wall time divided by TIMED_RUNS.
#define WARMUP_RUNS 2
#define TIMED_RUNS 10

// Select which suites are compiled into BenchCustomKernels(): each suite is wrapped in
// `#ifdef TEST_<NAME>`, so commenting out a line below removes that suite from the run.

#define TEST_AXPBY
#define TEST_DOT
#define TEST_PREFIX_SUM
#define TEST_SPGEMV
#define TEST_SPGEMV_DOT
#define TEST_SPTRMVL
#define TEST_SPTRMVU
#define TEST_SPTRSVL
#define TEST_SPTRSVU
#define TEST_SPTRSVL_FUSED
#define TEST_SPTRSVU_FUSED
#define TEST_SYMGS
#define TEST_SYMGS_MV

// =====================================================================
// =====================================================================


/*!
  Writes the column titles and the separator line of the results table.

  Only rank 0 prints. When MPI is enabled, every rank synchronizes on MPI_COMM_WORLD
  both before and after the header is written.

  @param[in] rank  Rank of the calling process.
*/
void print_header(const int rank)
{
    const bool is_root = (rank == 0);

#ifndef HPCG_NO_MPI
    MPI_Barrier(MPI_COMM_WORLD);
#endif

    if (is_root) {
        std::printf("\t%4s %-4s%-13s%-6s %-10s %-7s %-5s\n",
                    "rank", "", "Functionality", "", "ave_time", "GFLOPs", "GiB/s");
        std::fputs("\t=========================================================\n", stdout);
    }

#ifndef HPCG_NO_MPI
    MPI_Barrier(MPI_COMM_WORLD);
#endif
}

/*!
  Prints one results-table row (rank, label, average time, GFLOP/s) for the calling rank.

  Ranks take turns: the routine steps through turn = 0 .. size-1 and the rank whose
  number equals the current turn prints its row. When MPI is enabled there is a barrier
  on MPI_COMM_WORLD before the first turn and after every turn.

  @param[in] rank      Rank of the calling process.
  @param[in] size      Number of turns to step through (number of ranks).
  @param[in] func      Label printed in the functionality column.
  @param[in] ave_time  Average time per run, printed with 8 decimals.
  @param[in] gflops    Rate printed in the GFLOPs column.
*/
void print_gflops(const int rank, const int size, const std::string &func, const double ave_time, const double gflops)
{
#ifndef HPCG_NO_MPI
    MPI_Barrier(MPI_COMM_WORLD);
#endif

    int turn = 0;
    while (turn < size) {
        const bool my_turn = (turn == rank);
        if (my_turn) {
            std::printf("\t %2d: %-22s %0.8f, %3.2f\n", rank, func.c_str(), ave_time, gflops);
        }
#ifndef HPCG_NO_MPI
        MPI_Barrier(MPI_COMM_WORLD);
#endif
        ++turn;
    }
}

/*!
  Prints one results-table row that carries a second rate column in addition to GFLOP/s.

  Ranks take turns: the routine steps through turn = 0 .. size-1 and the rank whose
  number equals the current turn prints its row. When MPI is enabled there is a barrier
  on MPI_COMM_WORLD before the first turn and after every turn.

  @param[in] rank      Rank of the calling process.
  @param[in] size      Number of turns to step through (number of ranks).
  @param[in] func      Label printed in the functionality column.
  @param[in] ave_time  Average time per run, printed with 8 decimals.
  @param[in] gflops    Rate printed in the GFLOPs column.
  @param[in] gibs      Rate printed in the last column.
*/
void print_gflops_gibs(const int rank, const int size, const std::string &func, const double ave_time, const double gflops, const double gibs)
{
#ifndef HPCG_NO_MPI
    MPI_Barrier(MPI_COMM_WORLD);
#endif

    int turn = 0;
    while (turn < size) {
        const bool my_turn = (turn == rank);
        if (my_turn) {
            std::printf("\t %2d: %-22s %0.8f, %3.2f, %3.2f\n", rank, func.c_str(), ave_time, gflops, gibs);
        }
#ifndef HPCG_NO_MPI
        MPI_Barrier(MPI_COMM_WORLD);
#endif
        ++turn;
    }
}


/*!
  Times the custom SYCL kernels of this HPCG build and prints one results row per rank
  and per suite.

  Each suite enabled through a TEST_* switch is launched WARMUP_RUNS times untimed and
  then TIMED_RUNS times between two mytimer() readings. Consecutive launches are chained
  through the event returned by the previous launch. The mean time per timed launch and
  the rate derived from the suite's flop count are printed with print_gflops() /
  print_gflops_gibs().

  @param[in] A      Local sparse matrix. A.optimizationData is read as a struct optData
                    whose esbM member is passed to the custom kernels as custom::sparseMatrix.
  @param[in] b      Not referenced by this routine.
  @param[in] x      Not referenced by this routine.
  @param[in] rank   Rank of the calling process (used for messages and ordered printing).
  @param[in] size   Number of ranks taking part in the ordered printing.
  @param[in] queue  SYCL queue used for every allocation, fill, copy and kernel launch.

  @return 0 on success; 1 if a scratch allocation failed. In that case everything that
          was allocated up to that point is released before returning.
*/
int BenchCustomKernels(SparseMatrix &A, Vector &b, Vector &x, const int rank, const int size, sycl::queue &queue)
{
#ifdef USE_PRINTF
#ifndef HPCG_NO_MPI
    MPI_Barrier(MPI_COMM_WORLD);
#endif
    printf("[rank %d]: bench_custom_kernels begin \n", rank);fflush(0);
#endif

    const local_int_t nRows = A.localNumberOfRows;
    const local_int_t nCols = A.localNumberOfColumns;
    struct optData *optData = (struct optData *)A.optimizationData;
    custom::sparseMatrix *sparseM = (custom::sparseMatrix *)optData->esbM;

    // Work vectors. Every vector is sized nCols; Ay only needs nRows entries, but nCols is
    // used because nCols >= nRows, which guarantees enough room.
    Vector r, w, y, y1, z, Ay;
    InitializeVectorDevice(r, nCols, queue);
    InitializeVectorDevice(w, nCols, queue);
    InitializeVectorDevice(y, nCols, queue);
    InitializeVectorDevice(y1, nCols, queue);
    InitializeVectorDevice(z, nCols, queue);
    InitializeVectorDevice(Ay, nCols, queue);

    double *rv = r.values;
    double *wv = w.values;
    double *yv = y.values;
    double *y1v = y1.values;
    double *zv = z.values;
    double *Ayv = Ay.values;

    // Scratch USM buffers. They start out null so that the shared cleanup below can be run
    // from any exit path, no matter how many of them were actually allocated.
    // NOTE(review): fp_host, tmp_dev, tmp_host, tmp2_dev, z and int_nrows_dev are allocated
    // and initialised but not consumed by any suite below; kept to preserve the original
    // memory footprint of the benchmark.
    local_int_t *int_nrows_dev = nullptr;
    local_int_t *int_nrowsp1_dev = nullptr;
    double *fp_dev = nullptr;
    double *fp_host = nullptr;
    local_int_t *tmp_dev = nullptr;
    local_int_t *tmp_host = nullptr;
    double *tmp2_dev = nullptr;

    // Frees one USM pointer, tolerating null.
    auto release_usm = [&queue](void *ptr) {
        if (ptr != nullptr) {
            sycl::free(ptr, queue);
        }
    };

    // Drains the queue, then releases every vector and scratch buffer owned by this routine.
    auto release_all = [&]() {
        queue.wait();
        DeleteVector(r, queue);
        DeleteVector(w, queue);
        DeleteVector(y, queue);
        DeleteVector(y1, queue);
        DeleteVector(z, queue);
        DeleteVector(Ay, queue);

        release_usm(int_nrows_dev);
        release_usm(int_nrowsp1_dev);
        release_usm(fp_dev);
        release_usm(fp_host);
        release_usm(tmp_dev);
        release_usm(tmp_host);
        release_usm(tmp2_dev);
    };

    // NOTE(review): the early returns below leave this rank only; in an MPI run the other
    // ranks still enter the barriers further down - confirm the caller aborts on non-zero.
#ifdef TEST_PREFIX_SUM
    // for prefix_sum
    int_nrows_dev = (local_int_t *)sparse_malloc_device(nRows * sizeof(local_int_t), queue);
    int_nrowsp1_dev = (local_int_t *)sparse_malloc_device((nRows + 1) * sizeof(local_int_t), queue);
    if (int_nrows_dev == nullptr || int_nrowsp1_dev == nullptr) {
        std::cerr << "rank " << rank << ": error in TestCustomKernels local_int_t array allocation" << std::endl;
        release_all();
        return 1;
    }
#endif

    int ierr = 0; // status output of ComputeSPMV_DOT; initialised so it never holds garbage
    fp_dev = (double *)sparse_malloc_device(1 * sizeof(double), queue);
    fp_host = (double *)sparse_malloc_host(1 * sizeof(double), queue);
    tmp_dev = (local_int_t *)sparse_malloc_device(1 * sizeof(local_int_t), queue);
    tmp_host = (local_int_t *)sparse_malloc_host(1 * sizeof(local_int_t), queue);
    tmp2_dev = (double *)sparse_malloc_device(nRows * sizeof(double), queue);
    if (fp_dev == nullptr || fp_host == nullptr || tmp_dev == nullptr || tmp_host == nullptr || tmp2_dev == nullptr) {
        std::cerr << "rank " << rank << ": error in TestCustomKernels allocation" << std::endl;
        release_all();
        return 1;
    }

    queue.fill(rv, 0.01, nCols).wait();
    queue.fill(wv, 0.01, nCols).wait();
    queue.fill(yv, 0.0, nCols).wait();
    queue.fill(y1v, 0.0, nCols).wait();
    queue.fill(zv, 0.0, nCols).wait();
    queue.fill(Ayv, 0.01, nCols).wait();

    // queue.fill() writes `count` copies of the pattern *type*, so the pattern is spelled as
    // local_int_t to match the element type of the destination arrays.
#ifdef TEST_PREFIX_SUM
    queue.fill(int_nrows_dev, local_int_t(1), nRows).wait();
    queue.fill(int_nrowsp1_dev, local_int_t(1), nRows + 1).wait(); // all nRows + 1 entries
#endif

    queue.fill(tmp_dev, local_int_t(0), 1).wait();

    // base nnz size for perf measurements: a 27 point stencil is assumed on every row even
    // though boundary rows have fewer entries
    const double nnz = 27.0 * nRows;

    // generic scalars for the AXPBY suite
    double alpha = 1.23, beta = 4.56;


    // ============================================================================================================
    //  Custom Kernel Performance suites
    // ============================================================================================================

    double n_flops = 0;

    // Converts an operation count and a time in seconds into 1e9 operations per second.
    // The two hexadecimal constants are a leading value and a correction term for 1e-9,
    // combined with fma to divide by 1e9 more accurately than a single multiplication.
    auto convert_to_gflops = [](double op_count, double ave_time_in_seconds) {
        const double ticks = op_count / ave_time_in_seconds;
        return std::fma(0x1.12e0be826d695p-30, ticks, std::fma(-0x1.34674bfabb83bp-84, ticks, 0));
    };

    // Runs one suite: WARMUP_RUNS untimed launches, a wait, `reset_before_timing()`, then
    // TIMED_RUNS timed launches and a final wait. `launch(dep)` must submit one launch that
    // depends on `dep` and return its event. Returns the mean seconds per timed launch.
    auto time_suite_with_reset = [](auto &&launch, auto &&reset_before_timing) -> double {
        sycl::event ev_run;
        for (int run = 0; run < WARMUP_RUNS; ++run) {
            ev_run = launch(ev_run);
        }
        ev_run.wait();

        reset_before_timing();

        const double start_time = mytimer();
        for (int run = 0; run < TIMED_RUNS; ++run) {
            ev_run = launch(ev_run);
        }
        ev_run.wait();
        const double wall_time = mytimer() - start_time;

        return wall_time / TIMED_RUNS;
    };

    // Same as above for suites that need nothing done between warmup and timing.
    auto time_suite = [&](auto &&launch) -> double {
        return time_suite_with_reset(launch, [] {});
    };

    // Prints the row of a suite and keeps the ranks in step before the next suite starts.
    auto report = [&](const std::string &label, const double flop_count, const double seconds_per_run) {
        print_gflops(rank, size, label, seconds_per_run, convert_to_gflops(flop_count, seconds_per_run));
#ifndef HPCG_NO_MPI
        MPI_Barrier(MPI_COMM_WORLD);
#endif
    };

    if (rank == 0) {
        std::cout << "Starting Performance Suites: reporting ave_time (sec) for " << TIMED_RUNS << " runs after " << WARMUP_RUNS << " warmup runs" << std::endl;
    }

    print_header(rank);

    //
    // ----------------------------   AXPBY performance ---------------------------------------------
    //

#ifdef TEST_AXPBY
    {
        // Test AXPBY custom:  yv = alpha * wv + beta * yv
        // n_flops for AXPBY (3 * nRows flops)
        // NOTE(review): the kernel is launched over nCols entries while flops are counted
        // over nRows - confirm which length is intended for the reported rate.
        n_flops = 3.0 * nRows;

        // One kernel serves both the warmup and the timed launches, so the warmup exercises
        // exactly the kernel that is measured afterwards.
        auto launch_axpby = [&](const sycl::event &dep) {
            return queue.submit([&](sycl::handler &cgh) {
                cgh.depends_on(dep);
                // AXPBY esimd kernel parameters
                constexpr local_int_t block_size = HPCG_BLOCK_SIZE;
                const local_int_t nWG = 8;
                constexpr local_int_t uroll = 4;
                local_int_t nBlocks = ceil_div(nCols, uroll * block_size);

                auto kernel = [=] (sycl::nd_item<1> item) SYCL_ESIMD_KERNEL {
                    axpby_body<block_size, uroll>(item, wv, yv, alpha, beta, nCols, nBlocks);
                };
                cgh.parallel_for<class axbpy_esimd_run>(sycl::nd_range<1>(ceil_div(nBlocks, nWG) * nWG, nWG), kernel);
            });
        };

        const double ave_time = time_suite(launch_axpby);
        report("AXPBY", n_flops, ave_time);
    }
#endif // TEST_AXPBY


    //
    // ----------------------------   DOT performance ---------------------------------------------
    //

#ifdef TEST_DOT
    {
        // Test Dot custom
        // n_flops for Dot (2 * nRows - 1 flops)
        n_flops = 2.0 * nRows - 1.0;

        auto launch_dot = [&](const sycl::event &dep) {
            return ComputeDotProductLocal(nRows, r, w, fp_dev, queue, {dep});
        };

        const double ave_time = time_suite(launch_dot);
        report("DotProduct", n_flops, ave_time);
    }
#endif // TEST_DOT


    //
    // ----------------------------  Prefix Sum performance ---------------------------------------
    //

#ifdef TEST_PREFIX_SUM
    {
        // Test Prefix Sum custom

        // set array values to 0 so we don't get overflow on repeated calls, kernel can't
        // exploit 0 values as it doesn't know that we are passing in 0's
        queue.fill(int_nrowsp1_dev, local_int_t(0), nRows + 1).wait();

        // n_memops for Prefix Sum of length N is (2 * N - 1 memops)
        const double n_memops = 2.0 * (nRows + 1) - 1.0;

        // n_flops for Prefix Sum of length N is ( N - 1 flops)
        n_flops = nRows;

        auto launch_prefix_sum = [&](const sycl::event &dep) {
            return prefix_sum(queue, nRows+1, int_nrowsp1_dev, {dep});
        };

        const double ave_time = time_suite(launch_prefix_sum);

        // NOTE(review): the last column is headed "GiB/s" but the value printed here is
        // 1e-9 * (memory operations per second), not bytes - confirm the intended unit.
        const double gflops = convert_to_gflops(n_flops, ave_time);
        const double gmemops = convert_to_gflops(n_memops, ave_time);
        print_gflops_gibs(rank, size, "PrefixSum", ave_time, gflops, gmemops);

#ifndef HPCG_NO_MPI
        MPI_Barrier(MPI_COMM_WORLD);
#endif

        // clean up here to avoid excess memory usage; the pointers are nulled so the final
        // cleanup does not free them a second time
        release_usm(int_nrows_dev);
        int_nrows_dev = nullptr;
        release_usm(int_nrowsp1_dev);
        int_nrowsp1_dev = nullptr;
    }
#endif // TEST_PREFIX_SUM


    //
    // ----------------------------   SpGEMV performance ---------------------------------------------
    //

#ifdef TEST_SPGEMV
    {
        // Test SpGEMV on A + B custom  y = M * w
        // (ComputeSPMV(A, w, y, queue, ierr, deps) was the alternative entry point tried here.)
        // n_flops for SpGEMV
        n_flops = 2 * nnz;

        auto launch_spgemv = [&](const sycl::event &dep) {
            return custom::SpGEMV(queue, sparseM, wv, yv, {dep});
        };

        const double ave_time = time_suite(launch_spgemv);
        report("SpGEMV A+B", n_flops, ave_time);
    }
#endif // TEST_SPGEMV

    //
    // ----------------------------   SpGEMV+DOT performance ---------------------------------------------
    //

#ifdef TEST_SPGEMV_DOT
    {
        // Test SpGEMV Dot custom
        // n_flops for SpGEMV (2 * nnz) + Dot (2 * nRows - 1 flops)
        n_flops = 2 * nnz + 2 * nRows - 1;

        auto launch_spgemv_dot = [&](const sycl::event &dep) {
            return ComputeSPMV_DOT(A, w, y, fp_dev, queue, ierr, {dep});
        };

        const double ave_time = time_suite(launch_spgemv_dot);
        report("SpGEMV Dot", n_flops, ave_time);
    }
#endif // TEST_SPGEMV_DOT

    //
    // ----------------------------   SpTRMV performance ---------------------------------------------
    //

#ifdef TEST_SPTRMVL
    {
        // Test custom::SpTRMV() lower_update
        // n_flops for trmv lower diagonal
        n_flops = 2.0 * (nnz - nRows) / 2.0 + nRows;  // y = (L+B)*x+y

        // y is zeroed before the warmup and again before the timed launches.
        auto zero_y = [&]() { ZeroVector(y, queue, {}).wait(); };
        auto launch_trmv_lower = [&](const sycl::event &dep) {
            return custom::SpTRMV(queue, sparseM, custom::uplo::lower_update, wv, rv, yv, y1v, {dep});
        };

        zero_y();
        const double ave_time = time_suite_with_reset(launch_trmv_lower, zero_y);
        report("SpTRMV L+B update", n_flops, ave_time);
    }
#endif // TEST_SPTRMVL


#ifdef TEST_SPTRMVU
    {
        // Test custom::SpTRMV() upper_nonlocal  fused  y=r-(U+B)*w   fused with y1=U*w
        // n_flops for trmv upper no diagonal
        n_flops = 2.0 * (nnz - nRows) / 2.0  + nRows;   // y=r-(U+B)*w fused with y1=U*w

        auto launch_trmv_upper = [&](const sycl::event &dep) {
            return custom::SpTRMV(queue, sparseM, custom::uplo::upper_nonlocal, wv, rv, yv, y1v, {dep});
        };

        const double ave_time = time_suite(launch_trmv_upper);
        report("SpTRMV r-(U+B)", n_flops, ave_time);
    }
#endif // TEST_SPTRMVU


    //
    // ----------------------------   SpTRSV performance ---------------------------------------------
    //

    // n_flops for trsv, shared by the forward and backward suites
    n_flops = 2.0 * (nnz + nRows) / 2.0 ; // (L+D)*y = w

#ifdef TEST_SPTRSVL
    {
        // Test SpTRSV Lower_diagonal on A
        auto launch_trsv_lower = [&](const sycl::event &dep) {
            return custom::SpTRSV(queue, sparseM, custom::uplo::lower_diagonal, wv, yv, {dep});
        };

        const double ave_time = time_suite(launch_trsv_lower);
        report("SpTRSV FWD", n_flops, ave_time);
    }
#endif // TEST_SPTRSVL


#ifdef TEST_SPTRSVU
    {
        // Test SpTRSV Upper_diagonal
        auto launch_trsv_upper = [&](const sycl::event &dep) {
            return custom::SpTRSV(queue, sparseM, custom::uplo::upper_diagonal, wv, yv, {dep});
        };

        const double ave_time = time_suite(launch_trsv_upper);
        report("SpTRSV BWD", n_flops, ave_time);
    }
#endif // TEST_SPTRSVU


#ifdef TEST_SPTRSVL_FUSED
    {
        // Test TRSV Lower Fused   (L+D)*y_out = w_in;   fused with w_out = y_in+D*y_out
        // n_flops for trsv fused
        n_flops = 2.0 * (nnz + nRows) / 2.0 + nRows; // (L+D)*y = w

        // Ayv starts as a copy of wv (the fused kernel receives Ayv in place of wv);
        // yv starts at 0.1.
        queue.memcpy(Ayv, wv, sizeof(double) * nRows).wait();
        queue.fill(yv, 0.1, nCols).wait();

        auto launch_trsv_fused_lower = [&](const sycl::event &dep) {
            return custom::SpTRSV_FUSED(queue, sparseM, custom::uplo::lower_diagonal, Ayv, yv, {dep});
        };

        const double ave_time = time_suite(launch_trsv_fused_lower);
        report("SpTRSV_FUSED FWD", n_flops, ave_time);
    }
#endif // TEST_SPTRSVL_FUSED


#ifdef TEST_SPTRSVU_FUSED
    {
        // Test TRSV Upper Fused:
        // solve (D+U) * yv_out = Ayv_in; and  Ayv_out = D*Ayv_in
        // n_flops for trsv fused
        n_flops = 2.0 * (nnz + nRows) / 2.0 + nRows;

        // Ayv = wv
        queue.memcpy(Ayv, wv, sizeof(double) * nRows).wait();

        auto launch_trsv_fused_upper = [&](const sycl::event &dep) {
            return custom::SpTRSV_FUSED(queue, sparseM, custom::uplo::upper_diagonal, Ayv, yv, {dep});
        };

        const double ave_time = time_suite(launch_trsv_fused_upper);
        report("SpTRSV_FUSED BWD", n_flops, ave_time);
    }
#endif // TEST_SPTRSVU_FUSED


    //
    // ----------------------------   SYMGS performance ---------------------------------------------
    //

    // n_flops for SYMGS
    // optimized n_flops would be (3 * nnz + 5 * nRows); the reference count is reported
    n_flops = 4 * nnz + 2 * nRows;

#ifdef TEST_SYMGS
    {
        // Test SYMGS with custom kernels
        auto launch_symgs = [&](const sycl::event &dep) {
            return run_SYMGS_custom(queue, A, optData, sparseM, w, y, {dep});
        };

        const double ave_time = time_suite(launch_symgs);
        report("SYMGS", n_flops, ave_time);
    }
#endif // TEST_SYMGS

    //
    // ----------------------------   SYMGS_MV performance ---------------------------------------------
    //

    // n_flops for SYMGS_MV
    // optimized n_flops would be (3 * nnz + 6 * nRows); the reference count is reported
    n_flops = 6 * nnz + 2 * nRows;

#ifdef TEST_SYMGS_MV
    {
        // Test SYMGS_MV with PTAP using custom kernels
        auto launch_symgs_mv = [&](const sycl::event &dep) {
            return run_SYMGS_MV_custom(queue, A, optData, sparseM, w, y, Ay, {dep});
        };

        const double ave_time = time_suite(launch_symgs_mv);
        report("SYMGS_MV", n_flops, ave_time);
    }
#endif // TEST_SYMGS_MV

    // cleanup
    release_all();

#ifdef USE_PRINTF
#ifndef HPCG_NO_MPI
    MPI_Barrier(MPI_COMM_WORLD);
#endif
    printf("[rank %d]: bench_custom_kernels end \n", rank);fflush(0);
#endif

    return 0;
}
