/*******************************************************************************
* Copyright (C) 2020 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates usage of DPC++ buffer-based API for oneMKL RNG
*       continuous distributions with oneapi::mkl::rng::default_engine
*       random number generator
*
*       Continuous distribution list:
*           oneapi::mkl::rng::beta                     (available for CPU and GPU)
*           oneapi::mkl::rng::cauchy                   (available for CPU and GPU)
*           oneapi::mkl::rng::chi_square               (available for CPU and GPU)
*           oneapi::mkl::rng::exponential              (available for CPU and GPU)
*           oneapi::mkl::rng::gamma                    (available for CPU and GPU)
*           oneapi::mkl::rng::gaussian                 (available for CPU and GPU)
*           oneapi::mkl::rng::gaussian_mv              (available for CPU and GPU)
*           oneapi::mkl::rng::gumbel                   (available for CPU and GPU)
*           oneapi::mkl::rng::laplace                  (available for CPU and GPU)
*           oneapi::mkl::rng::lognormal                (available for CPU and GPU)
*           oneapi::mkl::rng::rayleigh                 (available for CPU and GPU)
*           oneapi::mkl::rng::weibull                  (available for CPU and GPU)
*
*       The supported floating point data types for random numbers are:
*           float
*           double
*
*
*******************************************************************************/

// stl includes
#include <iostream>
#include <vector>
#include <string>

#include <sycl/sycl.hpp>
#include "oneapi/mkl.hpp"

// local includes
#include "../include/common_for_rng_examples.hpp"

// Example parameters.
// Number of random values generated for each distribution (default buffer size).
constexpr std::size_t n{2000};
// Number of leading generated values printed for each distribution.
constexpr std::size_t n_print{10};

// Generates `buf_size` random numbers from `distr` with a oneMKL default_engine created on
// `queue`, prints the first n_print of them and, if `is_validation_needed` is true, runs
// check_statistics() over the whole buffer.
//
// Template parameters:
//   RealType             - element type of the output buffer
//   Distribution         - oneMKL RNG distribution type of `distr`
//   is_validation_needed - compile-time switch for the statistics check
// Parameters:
//   queue      - queue the engine is created on; waited on after generation
//   distr      - distribution to sample from
//   distr_name - name shown in the printed output header
//   buf_size   - number of values generated, stored and validated (defaults to n)
// Returns:
//   false if a sycl::exception or oneapi::mkl::exception is caught during generation,
//   otherwise the result of check_statistics() (or true when validation is disabled).
template <typename RealType, typename Distribution, bool is_validation_needed = true>
bool perform_generation_validation(sycl::queue& queue, const Distribution& distr,
                                   const std::string& distr_name, std::size_t buf_size = n) {
    // create a default engine
    oneapi::mkl::rng::default_engine engine(queue);

    // prepare array for random numbers
    sycl::buffer<RealType> r_buffer(buf_size);

    try {
        // Generate exactly buf_size values. The whole buffer is later validated, so
        // generating only the global `n` would leave the tail of a larger buffer
        // (e.g. the n * dimen buffer used for gaussian_mv) uninitialized.
        oneapi::mkl::rng::generate(distr, engine, buf_size, r_buffer);
        queue.wait_and_throw();
    }
    catch (sycl::exception const& e) {
        std::cout << "\t\tSYCL exception during oneapi::mkl::rng::generate() call\n"
                  << e.what() << std::endl
                  << "Error code: " << e.code().value() << std::endl;
        return false;
    }
    catch (oneapi::mkl::exception const& e) {
        std::cout << "\toneMKL exception during oneapi::mkl::rng::generate() call\n"
                  << e.what() << std::endl;
        return false;
    }

    // print output
    std::cout << "\n\t\tOutput of generator with the " << distr_name <<" distribution:" << std::endl;
    print_output(r_buffer, n_print);

    // validation
    if constexpr (is_validation_needed) {
        auto r_acc = sycl::host_accessor(r_buffer, sycl::read_only);
        return check_statistics(r_acc.get_pointer(), buf_size, distr);
    }
    else
        return true;
}

// Runs the beta distribution example (shapes p = 2, q = 10, displacement 0, scalefactor 1)
// and returns the generation/validation result.
template <typename RealType>
bool run_beta_example(sycl::queue& queue) {
    const RealType shape_p{2.0};
    const RealType shape_q{10.0};
    const RealType displacement{0.0};
    const RealType scalefactor{1.0};

    const oneapi::mkl::rng::beta<RealType> beta_distr(shape_p, shape_q, displacement, scalefactor);
    return perform_generation_validation<RealType>(queue, beta_distr, "beta");
}

// Runs the Cauchy distribution example (displacement 0, scalefactor 1).
// Statistical validation is disabled for this distribution; only generation and printing
// are exercised. NOTE(review): presumably because Cauchy has no finite mean/variance.
template <typename RealType>
bool run_cauchy_example(sycl::queue& queue) {
    using distr_type = oneapi::mkl::rng::cauchy<RealType>;

    const RealType displacement{0.0};
    const RealType scalefactor{1.0};
    const distr_type cauchy_distr(displacement, scalefactor);

    constexpr bool validate = false;
    return perform_generation_validation<RealType, distr_type, validate>(queue, cauchy_distr,
                                                                        "cauchy");
}

// Runs the chi-square distribution example with 10 degrees of freedom.
// The local is deliberately not called `n`, so it does not shadow the global sample count.
template <typename RealType>
bool run_chi_square_example(sycl::queue& queue) {
    const std::int32_t degrees_of_freedom{10};

    const oneapi::mkl::rng::chi_square<RealType> chi_square_distr(degrees_of_freedom);
    return perform_generation_validation<RealType>(queue, chi_square_distr, "chi square");
}

// Runs the exponential distribution example (displacement 0, scalefactor 1).
template <typename RealType>
bool run_exponential_example(sycl::queue& queue) {
    const RealType displacement{0.0};
    const RealType scalefactor{1.0};

    const oneapi::mkl::rng::exponential<RealType> exponential_distr(displacement, scalefactor);
    return perform_generation_validation<RealType>(queue, exponential_distr, "exponential");
}

// Runs the gamma distribution example (shape 1, displacement 0, scalefactor 1).
template <typename RealType>
bool run_gamma_example(sycl::queue& queue) {
    const RealType shape{1.0};
    const RealType displacement{0.0};
    const RealType scalefactor{1.0};

    const oneapi::mkl::rng::gamma<RealType> gamma_distr(shape, displacement, scalefactor);
    return perform_generation_validation<RealType>(queue, gamma_distr, "gamma");
}

// Runs the gaussian distribution example (mean 0, standard deviation 1).
template <typename RealType>
bool run_gaussian_example(sycl::queue& queue) {
    const RealType mean_value{0.0};
    const RealType std_deviation{1.0};

    const oneapi::mkl::rng::gaussian<RealType> gaussian_distr(mean_value, std_deviation);
    return perform_generation_validation<RealType>(queue, gaussian_distr, "gaussian");
}

// Runs the multivariate gaussian example in 3 dimensions.
// The output buffer holds n * dimension values (n vectors of `dimension` components).
template <typename RealType>
bool run_gaussian_mv_example(sycl::queue& queue) {
    const std::uint32_t dimension = 3;

    // Mean value of each dimension.
    std::vector<RealType> mean_values{ 3.0, 5.0, 2.0 };

    // Lower triangular factor T in packed layout (the default). T can be obtained by Cholesky
    // factorization of the covariance matrix C = T * T'. Here:
    //
    //          | 4  0  0 |                 | 16   8   4 |
    //     T =  | 2  3  0 |     C = T*T' =  |  8  13  17 |
    //          | 1  5  6 |                 |  4  17  62 |
    //
    // The diagonal of C holds the variances; the off-diagonal elements are covariances.
    std::vector<RealType> packed_factor{ 4.0, 2.0, 1.0, 3.0, 5.0, 6.0 };

    sycl::span mean_span{ mean_values.data(), mean_values.size() };
    sycl::span factor_span{ packed_factor.data(), packed_factor.size() };
    oneapi::mkl::rng::gaussian_mv<RealType> gaussian_mv_distr(dimension, mean_span, factor_span);

    const std::size_t total_values = n * dimension;
    return perform_generation_validation<RealType>(queue, gaussian_mv_distr, "gaussian_mv",
                                                   total_values);
}

// Runs the Gumbel distribution example (displacement 0, scalefactor 1).
template <typename RealType>
bool run_gumbel_example(sycl::queue& queue) {
    const RealType displacement{0.0};
    const RealType scalefactor{1.0};

    const oneapi::mkl::rng::gumbel<RealType> gumbel_distr(displacement, scalefactor);
    return perform_generation_validation<RealType>(queue, gumbel_distr, "gumbel");
}

// Runs the Laplace distribution example (mean 0, scalefactor 1).
template <typename RealType>
bool run_laplace_example(sycl::queue& queue) {
    const RealType mean_value{0.0};
    const RealType scalefactor{1.0};

    const oneapi::mkl::rng::laplace<RealType> laplace_distr(mean_value, scalefactor);
    return perform_generation_validation<RealType>(queue, laplace_distr, "laplace");
}

// Runs the lognormal distribution example. The underlying normal distribution has
// mean 0 and standard deviation 1; the result uses displacement 0 and scalefactor 1.
template <typename RealType>
bool run_lognormal_example(sycl::queue& queue) {
    const RealType normal_mean{0.0};
    const RealType normal_std_deviation{1.0};
    const RealType displacement{0.0};
    const RealType scalefactor{1.0};

    const oneapi::mkl::rng::lognormal<RealType> lognormal_distr(normal_mean, normal_std_deviation,
                                                                displacement, scalefactor);
    return perform_generation_validation<RealType>(queue, lognormal_distr, "lognormal");
}

// Runs the Rayleigh distribution example (displacement 0, scalefactor 1).
template <typename RealType>
bool run_rayleigh_example(sycl::queue& queue) {
    const RealType displacement{0.0};
    const RealType scalefactor{1.0};

    const oneapi::mkl::rng::rayleigh<RealType> rayleigh_distr(displacement, scalefactor);
    return perform_generation_validation<RealType>(queue, rayleigh_distr, "rayleigh");
}

// Runs the Weibull distribution example (shape 1, displacement 0, scalefactor 1).
template <typename RealType>
bool run_weibull_example(sycl::queue& queue) {
    const RealType shape{1.0};
    const RealType displacement{0.0};
    const RealType scalefactor{1.0};

    const oneapi::mkl::rng::weibull<RealType> weibull_distr(shape, displacement, scalefactor);
    return perform_generation_validation<RealType>(queue, weibull_distr, "weibull");
}

// Prints the example banner to std::cout. The banner lists every oneMKL RNG continuous
// distribution API this example exercises (including gaussian_mv, which was previously
// missing) and the supported floating point precisions.
void print_example_banner() {
    std::cout << "" << std::endl;
    std::cout << "########################################################################"
              << std::endl;
    std::cout << "# Generate random numbers with continuous rng distributions:" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "# Using apis:" << std::endl;
    std::cout << "#   oneapi::mkl::rng::beta" << std::endl;
    std::cout << "#   oneapi::mkl::rng::cauchy" << std::endl;
    std::cout << "#   oneapi::mkl::rng::chi_square" << std::endl;
    std::cout << "#   oneapi::mkl::rng::exponential" << std::endl;
    std::cout << "#   oneapi::mkl::rng::gamma" << std::endl;
    std::cout << "#   oneapi::mkl::rng::gaussian" << std::endl;
    std::cout << "#   oneapi::mkl::rng::gaussian_mv" << std::endl;
    std::cout << "#   oneapi::mkl::rng::gumbel" << std::endl;
    std::cout << "#   oneapi::mkl::rng::laplace" << std::endl;
    std::cout << "#   oneapi::mkl::rng::lognormal" << std::endl;
    std::cout << "#   oneapi::mkl::rng::rayleigh" << std::endl;
    std::cout << "#   oneapi::mkl::rng::weibull" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "# Supported floating point type precisions:" << std::endl;
    std::cout << "#   float double" << std::endl;
    std::cout << "# " << std::endl;
    std::cout << "########################################################################"
              << std::endl;
    std::cout << std::endl;
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//

int main(int argc, char** argv) {
    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);

        if (my_dev_is_found) {
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";
            sycl::queue queue(my_dev, exception_handler);
            std::cout << "\tRunning with single precision real data type:" << std::endl;
            if (!run_beta_example<float>(queue) || !run_cauchy_example<float>(queue) ||
                !run_chi_square_example<float>(queue) || !run_exponential_example<float>(queue) ||
                !run_gamma_example<float>(queue) || !run_gaussian_example<float>(queue) ||
                !run_gaussian_mv_example<float>(queue) || !run_gumbel_example<float>(queue) ||
                !run_laplace_example<float>(queue) || !run_lognormal_example<float>(queue) ||
                !run_rayleigh_example<float>(queue) || !run_weibull_example<float>(queue)) {
                std::cout << "FAILED" << std::endl;
                return 1;
            }
        }
        else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it]
                      << " devices found; Fail on missing devices is enabled.\n";
            std::cout << "FAILED" << std::endl;
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                      << sycl_device_names[*it] << " tests.\n";
#endif
        }
    }
    std::cout << "PASSED" << std::endl;
    return 0;
}
