/*******************************************************************************
* Copyright 2024 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates the use of oneapi::mkl::lapack::syevd
*       to solve a symmetric eigenvalue problem on a SYCL device (CPU, GPU).
*
*
*       The supported floating point data types for matrix data are:
*           float
*           double
*
*
*******************************************************************************/

// stl includes
#include <iostream>
#include <complex>
#include <vector>

#include <sycl/sycl.hpp>
#include "oneapi/mkl/lapack.hpp"
#include "mkl.h"

// local includes
#include "common_for_examples.hpp"


//
// SEP example consists of
// 1) initialization of symmetric dense A matrix
// 2) workspace query
// 3) problem solve: A = Q * Lambda * Q'
// 4) copy of the results back to the host and release of device memory
// Finally the status of the run ("ran OK" / "FAILED") is printed to stdout.
//
// Template parameter:
//   data_t - element type of the matrix data (instantiated by main with
//            float and double)
// Parameter:
//   dev    - SYCL device on which the execution queue is created
//
template <typename data_t>
void run_syevd_example(const sycl::device& dev) {
    // Matrix data sizes and leading dimension
    oneapi::mkl::job jobz = oneapi::mkl::job::vec;
    oneapi::mkl::uplo uplo = oneapi::mkl::uplo::lower;
    std::int64_t n   = 24;
    std::int64_t lda = 33;

    // Variable holding status of calculations (0 means success)
    std::int64_t info = 0;

    // Asynchronous error handler; records the failure in info so that the
    // final status line reflects errors reported asynchronously
    auto error_handler = [&] (sycl::exception_list exceptions) {
        for (auto const& eptr : exceptions) {
            try {
                std::rethrow_exception(eptr);
            } catch(oneapi::mkl::lapack::exception const& e) {
                // Handle LAPACK related exceptions happened during asynchronous call
                info = e.info();
                std::cout << "Unexpected exception caught during asynchronous LAPACK operation:\n" << e.what() << "\ninfo: " << e.info() << std::endl;
            } catch(sycl::exception const& e) {
                // Handle not LAPACK related exceptions happened during asynchronous call
                std::cout << "Unexpected exception caught during asynchronous operation:\n" << e.what() << std::endl;
                info = -1;
            }
        }
    };

    // Create execution queue for selected device
    sycl::queue queue(dev, error_handler);

    // Allocate matrices on host
    std::int64_t A_size = lda * n;
    std::vector<data_t> A(A_size);
    std::int64_t w_size = n;
    std::vector<data_t> w(w_size);

    // Initialize matrix A
    rand_matrix(A, oneapi::mkl::transpose::nontrans, n, n, lda);

    // Device pointers are declared outside of the try block so that they can
    // be released on every path, including when an exception is thrown
    data_t* A_dev = nullptr;
    data_t* w_dev = nullptr;
    data_t* syevd_scratchpad = nullptr;

    try {
        // Get sizes of scratchpad for calculations
        std::int64_t syevd_scratchpad_size = oneapi::mkl::lapack::syevd_scratchpad_size<data_t>(queue, jobz, uplo, n, lda);

        // Allocate memory on device
        A_dev = sycl::aligned_alloc_device<data_t>(64, A_size, queue);
        w_dev = sycl::aligned_alloc_device<data_t>(64, w_size, queue);
        syevd_scratchpad = sycl::aligned_alloc_device<data_t>(64, syevd_scratchpad_size, queue);

        // A null scratchpad is only treated as a failure when a non-empty
        // scratchpad was requested
        const bool alloc_failed = (A_dev == nullptr) || (w_dev == nullptr)
                               || (syevd_scratchpad_size > 0 && syevd_scratchpad == nullptr);

        if (alloc_failed) {
            std::cout << "Failed to allocate device memory" << std::endl;
            info = -1;
        } else {
            // Copy input data on device
            queue.copy(A.data(), A_dev, A_size).wait();

            // Solve SEP
            oneapi::mkl::lapack::syevd(queue, jobz, uplo, n, A_dev, lda, w_dev, syevd_scratchpad, syevd_scratchpad_size).wait_and_throw();

            // Copy output data on host
            queue.copy(A_dev, A.data(), A_size).wait();
            queue.copy(w_dev, w.data(), w_size).wait();
        }
    } catch(oneapi::mkl::lapack::exception const& e) {
        // Handle LAPACK related exceptions happened during synchronous call
        std::cout << "Unexpected exception caught during synchronous call to LAPACK API:\nreason: " << e.what() << "\ninfo: " << e.info() << std::endl;
        info = e.info();
    } catch(sycl::exception const& e) {
        // Handle not LAPACK related exceptions happened during synchronous call
        std::cout << "Unexpected exception caught during synchronous call to SYCL API:\n" << e.what() << std::endl;
        info = -1;
    }

    // Release device memory; pointers that were never allocated are skipped
    if (syevd_scratchpad != nullptr) sycl::free(syevd_scratchpad, queue);
    if (w_dev != nullptr) sycl::free(w_dev, queue);
    if (A_dev != nullptr) sycl::free(A_dev, queue);

    std::cout << "syevd " << ((info == 0) ? "ran OK" : "FAILED") << std::endl;

    return;
}


//
// Prints the example banner to stdout: a short description of the problem
// being solved, the api used and the supported floating point precisions.
//
void print_example_banner() {

    // One entry per output line; the empty entries produce the blank lines
    // printed before and after the banner.
    static const char* const banner_lines[] = {
        "",
        "########################################################################",
        "# Symmetric Eigenvalue Problem (Divide-and-Conquer) Example: ",
        "# ",
        "# Computes Eigenvalues and Eigenvectors: A = Q * Lambda * Q'",
        "# ",
        "# where A is a symmetric dense matrix.",
        "# ",
        "# Using apis:",
        "#   syevd",
        "# ",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "# ",
        "########################################################################",
        "",
    };

    // Every line is terminated (and flushed) individually with std::endl
    for (const char* text : banner_lines) {
        std::cout << text << std::endl;
    }

}


//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU device
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU device
// -DSYCL_DEVICES_all (default) -- runs on all: CPU and GPU devices
//
//  For each device selected, SEP example runs with all data types
//  supported by that device
//
int main(int argc, char **argv) {

    print_example_banner();

    // find list of devices
    std::list<my_sycl_device_types> listOfDevices;
    set_list_of_devices(listOfDevices);

    for(auto &deviceType : listOfDevices) {
        sycl::device myDev;
        bool myDevIsFound = false;
        get_sycl_device(myDev, myDevIsFound, deviceType);

        if(myDevIsFound) {
            std::cout << std::endl << "Running syevd examples on " << sycl_device_names[deviceType] << "." << std::endl;

            std::cout << "\tRunning with single precision real data type:" << std::endl;
            run_syevd_example<float>(myDev);

            if (isDoubleSupported(myDev)) {
                std::cout << "\tRunning with double precision real data type:" << std::endl;
                run_syevd_example<double>(myDev);
            } else {
                std::cout << "\tDouble precision not supported on this device " << std::endl;
                std::cout << std::endl;
            }
        }
        else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[deviceType] << " devices found; Fail on missing devices is enabled.\n";
            return 1;
#else
            std::cout << "No " << sycl_device_names[deviceType] << " devices found; skipping " << sycl_device_names[deviceType] << " tests.\n";
#endif
        }
    }
    mkl_free_buffers();
    return 0;
}
