/*******************************************************************************
* Copyright 2023 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
! An example of SINGLE-precision batch out-of-place 2D FFT, real-to-complex
! (forward) followed by complex-to-real (backward), on a (GPU) device using
! the OpenMP target (offload) interface of oneMKL DFTI
!******************************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <float.h>
#include <omp.h>

#include "mkl_dfti_omp_offload.h"

// Forward declarations of the helpers defined after main().
// init_r / init_c fill the input of the real-to-complex / complex-to-real
// transform; verify_c / verify_r check the corresponding output and return
// 0 on success, 1 on failure.
static void init_r(float *data, MKL_LONG M, MKL_LONG N1, MKL_LONG N2, MKL_LONG REAL_DIST, MKL_LONG H1, MKL_LONG H2);
static int verify_c(MKL_Complex8 *data, MKL_LONG M, MKL_LONG N1, MKL_LONG N2, MKL_LONG CMPLX_DIST, MKL_LONG H1, MKL_LONG H2);
static void init_c(MKL_Complex8 *data, MKL_LONG M, MKL_LONG N1, MKL_LONG N2, MKL_LONG CMPLX_DIST, MKL_LONG H1, MKL_LONG H2);
static int verify_r(float *data, MKL_LONG M, MKL_LONG N1, MKL_LONG N2, MKL_LONG REAL_DIST, MKL_LONG H1, MKL_LONG H2);

// LI is the printf conversion specification used for MKL_LONG values:
// "%lli" when MKL_ILP64 is defined, "%li" otherwise. It is spliced into
// format strings by string-literal concatenation.
#if !defined(MKL_ILP64)
#define LI "%li"
#else
#define LI "%lli"
#endif

// Entry point of the example.
//
// Configures one DFTI descriptor for a batch of out-of-place, single-precision
// 2D real transforms with CCE (complex-complex) storage, then:
//   1. runs the forward (real-to-complex) transform through the OpenMP
//      dispatch interface and verifies the spectrum with verify_c();
//   2. swaps the input/output strides and distances, recommits the descriptor,
//      runs the backward (complex-to-real) transform and verifies the result
//      with verify_r().
//
// Returns 0 when every step succeeds; otherwise returns the non-zero status of
// the first failing step (a DFTI status, or 1 for an allocation or
// verification failure).
int main(void)
{
    const int devNum = 0;

    // Size of 2D FFT
    MKL_LONG N1 = 64, N2 = 32;
    MKL_LONG N[2] = {N2, N1};

    // Distance (in elements) between consecutive transforms of the batch
    const MKL_LONG REAL_DIST = N2*N1;
    const MKL_LONG CMPLX_DIST = N2*(N1/2+1);

    // Number of transforms
    const MKL_LONG BATCH = 2;

    // Arbitrary harmonic used to verify FFT
    MKL_LONG H1 = -1, H2 = 2;

    MKL_LONG status = 0;

    // Pointers to input and output data
    float        *data_real     = NULL;
    MKL_Complex8 *data_cmplx    = NULL;

    DFTI_DESCRIPTOR_HANDLE descHandle = NULL;

    // Strides describe data layout in real and conjugate-even domain
    MKL_LONG rs[3] = {0, N1, 1};
    MKL_LONG cs[3] = {0, N1/2+1, 1};

    printf("DFTI_LENGTHS                  = {" LI ", " LI "}\n", N2, N1);
    printf("DFTI_PLACEMENT                = DFTI_NOT_INPLACE\n");
    printf("DFTI_CONJUGATE_EVEN_STORAGE   = DFTI_COMPLEX_COMPLEX\n");
    printf("DFTI_NUMBER_OF_TRANSFORMS     = {" LI "}\n", BATCH);

    printf("Create DFTI descriptor\n");
    status = DftiCreateDescriptor(&descHandle, DFTI_SINGLE, DFTI_REAL, 2, N);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Set configuration: out-of-place\n");
    status = DftiSetValue(descHandle, DFTI_PLACEMENT, DFTI_NOT_INPLACE);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Set configuration: CCE storage\n");
    status = DftiSetValue(descHandle,
                          DFTI_CONJUGATE_EVEN_STORAGE, DFTI_COMPLEX_COMPLEX);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Set input strides = ");
    printf("{" LI ", " LI ", " LI "}\n", rs[0], rs[1], rs[2]);
    status = DftiSetValue(descHandle, DFTI_INPUT_STRIDES, rs);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Set output strides = ");
    printf("{" LI ", " LI ", " LI "}\n", cs[0], cs[1], cs[2]);
    status = DftiSetValue(descHandle, DFTI_OUTPUT_STRIDES, cs);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Set input distance = " LI "\n", REAL_DIST);
    status = DftiSetValue(descHandle, DFTI_INPUT_DISTANCE, REAL_DIST);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Set output distance = " LI "\n", CMPLX_DIST);
    status = DftiSetValue(descHandle, DFTI_OUTPUT_DISTANCE, CMPLX_DIST);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Set number of transforms = " LI "\n", BATCH);
    status = DftiSetValue(descHandle, DFTI_NUMBER_OF_TRANSFORMS, BATCH);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Commit descriptor\n");
#pragma omp dispatch device(devNum)
    status = DftiCommitDescriptor(descHandle);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Allocate data arrays\n");

    data_real = (float*)mkl_malloc(BATCH*REAL_DIST*sizeof(float), 64);
    data_cmplx = (MKL_Complex8*)mkl_malloc(BATCH*CMPLX_DIST*sizeof(MKL_Complex8), 64);
    if (data_real == NULL || data_cmplx == NULL) {
        // status still holds DFTI_NO_ERROR from the commit above; make it
        // non-zero so the allocation failure is not reported as "PASSED".
        status = 1;
        goto failed;
    }

    printf("Initialize data for r2c FFT\n");
    init_r(data_real, BATCH, N1, N2, REAL_DIST, H1, H2);

    printf("Compute real-to-complex FFT\n");
#pragma omp target data map(to:data_real[0:BATCH*REAL_DIST]) \
                        map(from:data_cmplx[0:BATCH*CMPLX_DIST]) device(devNum)
    {
// Use need_device_ptr clause for out of place computation because
// DftiComputeForward is a variadic function where the out of place
// output is not explicit in the function declaration.
// The argument to the need_device_ptr clause is the one-based index
// of the pointer in the dispatched function's argument list.
// The input pointer is explicit in the function declaration, so the
// need_device_ptr clause is optional for it. That is, either
// need_device_ptr(2,3), referencing data_real and data_cmplx, or
// need_device_ptr(3), referencing just data_cmplx, will work.
#pragma omp dispatch device(devNum) need_device_ptr(2,3)
        status = DftiComputeForward(descHandle, data_real, data_cmplx);
    }
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Verify the result\n");
    status = verify_c(data_cmplx, BATCH, N1, N2, CMPLX_DIST, H1, H2);
    if (status != 0) goto failed;

    printf("Initialize data for c2r FFT\n");
    init_c(data_cmplx, BATCH, N1, N2, CMPLX_DIST, H1, H2);

    // For the backward transform the complex layout becomes the input and the
    // real layout becomes the output, so strides and distances are swapped.
    printf("Set input strides = ");
    printf("{" LI ", " LI ", " LI "}\n", cs[0], cs[1], cs[2]);
    status = DftiSetValue(descHandle, DFTI_INPUT_STRIDES, cs);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Set output strides = ");
    printf("{" LI ", " LI ", " LI "}\n", rs[0], rs[1], rs[2]);
    status = DftiSetValue(descHandle, DFTI_OUTPUT_STRIDES, rs);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Set input distance = " LI "\n", CMPLX_DIST);
    status = DftiSetValue(descHandle, DFTI_INPUT_DISTANCE, CMPLX_DIST);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Set output distance = " LI "\n", REAL_DIST);
    status = DftiSetValue(descHandle, DFTI_OUTPUT_DISTANCE, REAL_DIST);
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Commit descriptor\n");
#pragma omp dispatch device(devNum)
    status = DftiCommitDescriptor(descHandle);
    if (status != DFTI_NO_ERROR) goto failed;


    printf("Compute complex-to-real FFT\n");
#pragma omp target data map(to:data_cmplx[0:BATCH*CMPLX_DIST]) \
                        map(from:data_real[0:BATCH*REAL_DIST]) device(devNum)
    {
// Use need_device_ptr clause for out of place computation because
// DftiComputeBackward is a variadic function where the out of place
// output is not explicit in the function declaration.
// The argument to the need_device_ptr clause is the one-based index
// of the pointer in the dispatched function's argument list.
// The input pointer is explicit in the function declaration, so the
// need_device_ptr clause is optional for it. That is, either
// need_device_ptr(2,3), referencing data_cmplx and data_real, or
// need_device_ptr(3), referencing just data_real, will work.
#pragma omp dispatch device(devNum) need_device_ptr(2,3)
        status = DftiComputeBackward(descHandle, data_cmplx, data_real);
    }
    if (status != DFTI_NO_ERROR) goto failed;

    printf("Verify the result\n");
    status = verify_r(data_real, BATCH, N1, N2, REAL_DIST, H1, H2);
    if (status != 0) goto failed;


 cleanup:

    // Reached on both the success path and (via "failed") every error path.
    printf("Free DFTI descriptor\n");
    DftiFreeDescriptor(&descHandle);

    printf("Free data arrays\n");
    mkl_free(data_real);
    mkl_free(data_cmplx);

    {
        printf("TEST %s\n", status == 0 ? "PASSED" : "FAILED");
        return status;
    }

 failed:
    printf(" ERROR, status = " LI "\n", status);
    goto cleanup;
}

// Return (K*L) % M converted to float.
// The product is formed in long long before the remainder is taken; the
// remainder keeps the sign of the product.
static float moda(MKL_LONG K, MKL_LONG L, MKL_LONG M)
{
    const long long product = (long long)K * L;
    const long long remainder = product % M;

    return (float)remainder;
}

// Fill M real N2-by-N1 arrays (row-major, consecutive arrays REAL_DIST
// elements apart) with a cosine of harmonic (H2, H1), scaled by 1/(N2*N1).
// The amplitude factor is 1 when 2*(N1-H1) is a multiple of N1 and
// 2*(N2-H2) is a multiple of N2, and 2 otherwise.
static void init_r(float *data, MKL_LONG M, MKL_LONG N1, MKL_LONG N2, 
                   MKL_LONG REAL_DIST, MKL_LONG H1, MKL_LONG H2)
{
    const float two_pi = 6.2831853071795864769f;
    const MKL_LONG row_stride = N1;   // elements between consecutive rows
    float amplitude;

    if (2*(N1-H1)%N1 == 0 && 2*(N2-H2)%N2 == 0) {
        amplitude = 1.0f;
    } else {
        amplitude = 2.0f;
    }

    for (MKL_LONG batch = 0; batch < M; batch++) {
        float *image = data + batch*REAL_DIST;

        for (MKL_LONG row = 0; row < N2; row++) {
            for (MKL_LONG col = 0; col < N1; col++) {
                float angle = moda(col, H1, N1) / N1;

                angle += moda(row, H2, N2) / N2;
                image[row*row_stride + col] =
                    amplitude * cosf(two_pi * angle) / (N2*N1);
            }
        }
    }
}

// Check M complex N2-by-(N1/2+1) arrays (row-major, consecutive arrays
// CMPLX_DIST elements apart) against the expected spectrum: (1,0) where the
// index matches harmonic (H2, H1) or its negation modulo (N2, N1), and (0,0)
// everywhere else.
// The per-element error is |re - re_exp| + |im - im_exp| and must stay below
// 2.5 * log2(N2*N1) * FLT_EPSILON.
// Returns 0 when all elements pass; prints the first offending element and
// returns 1 otherwise.
static int verify_c(MKL_Complex8 *data, MKL_LONG M, MKL_LONG N1, MKL_LONG N2, 
                    MKL_LONG CMPLX_DIST, MKL_LONG H1, MKL_LONG H2)
{
    const MKL_LONG row_len = N1/2+1;   // stored columns, also the row stride
    const float tolerance = 2.5f * logf((float) N2*N1) / logf(2.0f) * FLT_EPSILON;
    float worst = 0.0f;

    printf(" Check if err is below errthr %.3lg\n", tolerance);

    for (MKL_LONG batch = 0; batch < M; batch++) {
        for (MKL_LONG row = 0; row < N2; row++) {
            for (MKL_LONG col = 0; col < row_len; col++) {
                const MKL_LONG pos = batch*CMPLX_DIST + row*row_len + col;
                const int on_peak =
                    (( col-H1)%N1==0 && ( row-H2)%N2==0) ||
                    ((-col-H1)%N1==0 && (-row-H2)%N2==0);
                const float want_re = on_peak ? 1.0f : 0.0f;
                const float want_im = 0.0f;
                const float have_re = data[pos].real;
                const float have_im = data[pos].imag;
                const float diff = fabsf(have_re - want_re) + fabsf(have_im - want_im);

                if (diff > worst) worst = diff;

                // Negated comparison so that a NaN error is also rejected.
                if (diff < tolerance) continue;

                printf(" Batch # " LI " data[" LI "][" LI "]: ", batch, row, col);
                printf(" expected (%.7g,%.7g), ", want_re, want_im);
                printf(" got (%.7g,%.7g), ", have_re, have_im);
                printf(" err %.3lg\n", diff);
                printf(" Verification FAILED\n");
                return 1;
            }
        }
    }

    printf(" Verified,  maximum error was %.3lg\n", worst);
    return 0;
}

// Fill M complex N2-by-(N1/2+1) arrays (row-major, consecutive arrays
// CMPLX_DIST elements apart) with
//   (cos(2*pi*phase), -sin(2*pi*phase)) / (N2*N1),
// where phase = ((n1*H1) % N1)/N1 + ((n2*H2) % N2)/N2 for column n1, row n2.
static void init_c(MKL_Complex8 *data, MKL_LONG M, MKL_LONG N1, MKL_LONG N2,
                   MKL_LONG CMPLX_DIST, MKL_LONG H1, MKL_LONG H2)
{
    const float two_pi = 6.2831853071795864769f;
    const MKL_LONG row_len = N1/2+1;   // stored columns, also the row stride

    for (MKL_LONG batch = 0; batch < M; batch++) {
        MKL_Complex8 *spectrum = data + batch*CMPLX_DIST;

        for (MKL_LONG row = 0; row < N2; row++) {
            MKL_Complex8 *line = spectrum + row*row_len;

            for (MKL_LONG col = 0; col < row_len; col++) {
                const float angle = moda(col, H1, N1) / N1 + moda(row, H2, N2) / N2;

                line[col].real =  cosf(two_pi * angle) / (N2*N1);
                line[col].imag = -sinf(two_pi * angle) / (N2*N1);
            }
        }
    }
}

// Check M real N2-by-N1 arrays (row-major, consecutive arrays REAL_DIST
// elements apart) against the expected result: 1 where the index matches
// (H2, H1) modulo (N2, N1), and 0 everywhere else.
// The per-element error |got - expected| must stay below
// 2.5 * log2(N1*N2) * FLT_EPSILON.
// Returns 0 when all elements pass; prints the first offending element and
// returns 1 otherwise.
static int verify_r(float *data, MKL_LONG M, MKL_LONG N1, MKL_LONG N2, 
                    MKL_LONG REAL_DIST, MKL_LONG H1, MKL_LONG H2)
{
    const MKL_LONG row_stride = N1;   // elements between consecutive rows
    const float tolerance = 2.5f * logf((float) N1*N2) / logf(2.0f) * FLT_EPSILON;
    float worst = 0.0f;

    printf(" Check if err is below errthr %.3lg\n", tolerance);

    for (MKL_LONG batch = 0; batch < M; batch++) {
        for (MKL_LONG row = 0; row < N2; row++) {
            for (MKL_LONG col = 0; col < N1; col++) {
                const MKL_LONG pos = batch*REAL_DIST + row*row_stride + col;
                const int on_peak = ((col - H1) % N1 == 0) && ((row - H2) % N2 == 0);
                const float want = on_peak ? 1.0f : 0.0f;
                const float have = data[pos];
                const float diff = fabsf(have - want);

                if (diff > worst) worst = diff;

                // Negated comparison so that a NaN error is also rejected.
                if (diff < tolerance) continue;

                printf(" Batch #" LI " data[" LI "][" LI "]: ", batch, row, col);
                printf(" expected %.7g, ", want);
                printf(" got %.7g, ", have);
                printf(" err %.3lg\n", diff);
                printf(" Verification FAILED\n");
                return 1;
            }
        }
    }

    printf(" Verified,  maximum error was %.3lg\n", worst);
    return 0;
}
