/*******************************************************************************
* Copyright 2020-2023 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       DPCPP API oneapi::mkl::sparse::gemm to perform general sparse matrix-dense
*       matrix multiplication on a SYCL device (CPU, GPU). This example uses
*       row-major layout for the dense matrices B and C, and CSR format for
*       sparse matrix A.
*
*       C = alpha * op(A) * op(B) + beta * C
*
*       where op() is defined by one of
*           oneapi::mkl::transpose::{nontrans,trans,conjtrans}
*
*       The supported floating point data types for gemm are:
*           float
*           double
*           std::complex<float>
*           std::complex<double>
*
*       The supported transpose operations op() for sparse matrix A are:
*           oneapi::mkl::transpose::nontrans
*           oneapi::mkl::transpose::trans
*           oneapi::mkl::transpose::conjtrans
*
*       The only supported transpose operation op() for dense matrix B is:
*           oneapi::mkl::transpose::nontrans
*
*       The supported matrix formats for gemm are:
*           CSR
*           COO (currently only on CPU device)
*
*******************************************************************************/

// stl includes
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <vector>

#include "mkl.h"
#include "oneapi/mkl.hpp"
#include <sycl/sycl.hpp>

// local includes
#include "common_for_examples.hpp"
#include "./include/common_for_sparse_examples.hpp"
#include "./include/reference_impls.hpp"

//
// Main example for Sparse Matrix-Dense Matrix Multiply consisting of
// initialization of the sparse matrix A (CSR format), the dense matrices
// B and C (row-major layout) as well as scalars alpha and beta.  Then the
// product
//
// C = alpha * op(A) * op(B) + beta * C
//
// is performed and finally the results are post processed: every entry of C
// is checked against a host reference and the ratio
// sum((c_ref - c)^2) / sum(c_ref^2) is printed.
//
// Template parameters:
//   fp      - floating point type of the matrix values and of alpha/beta
//   intType - integer type of the CSR row-pointer / column-index arrays
//
// Parameters:
//   dev - SYCL device used to create the execution queue
//   opA - transpose operation applied to the sparse matrix A
//         (op(B) is fixed to nontrans in this example)
//
// Returns 0 on success; 1 if an exception is caught around the sparse::gemm
// calls or if check_result() rejects an entry of C.
//
template <typename fp, typename intType>
int run_sparse_matrix_dense_matrix_multiply_example(const sycl::device &dev,
                                                    oneapi::mkl::transpose opA)
{
    // Initialize data for Sparse Matrix-Dense Matrix Multiply
    oneapi::mkl::transpose opB              = oneapi::mkl::transpose::nontrans;
    oneapi::mkl::layout dense_matrix_layout = oneapi::mkl::layout::row_major;
    oneapi::mkl::index_base index_base_val  = oneapi::mkl::index_base::zero;

    // Numeric offset matching index_base_val (0 for zero-based, 1 for one-based)
    intType int_index = (index_base_val == oneapi::mkl::index_base::zero ? 0 : 1);

    // Matrix data size
    intType nrows_a      = 64;
    intType ncols_a      = 128;
    std::int64_t columns = 256;
    std::int64_t ldb     = columns;
    std::int64_t ldc     = columns;

    double density_val = 0.05;

    // Input matrix in CSR format
    std::vector<intType, mkl_allocator<intType, 64>> ia;
    std::vector<intType, mkl_allocator<intType, 64>> ja;
    std::vector<fp, mkl_allocator<fp, 64>> a;

    generate_random_sparse_matrix<fp, intType>(nrows_a, ncols_a, density_val, ia, ja, a, int_index);
    // Last row pointer minus the index base gives the number of stored entries
    intType nnz_a = ia[nrows_a] - int_index;

    // Matrices b and c, plus c_ref for the host reference result
    std::vector<fp, mkl_allocator<fp, 64>> b;
    std::vector<fp, mkl_allocator<fp, 64>> c;
    std::vector<fp, mkl_allocator<fp, 64>> c_ref;

    // Dimensions of op(A): swapped with respect to A unless opA is nontrans
    intType nrows_opa = (oneapi::mkl::transpose::nontrans == opA) ? nrows_a : ncols_a;
    intType ncols_opa = (oneapi::mkl::transpose::nontrans == opA) ? ncols_a : nrows_a;

    intType nrows_b = ncols_opa;
    intType ncols_b = columns;
    intType nrows_c = nrows_opa;
    intType ncols_c = columns;

    rand_matrix<std::vector<fp, mkl_allocator<fp, 64>>>(b, dense_matrix_layout,
                                                        nrows_b, ncols_b, ldb);
    b.resize(nrows_b * ldb);
    c.resize(nrows_c * ldc);
    c_ref.resize(nrows_c * ldc);

    // Init matrices c and c_ref to zero (both have nrows_c * ldc entries)
    const fp zero = set_fp_value(fp(0.0), fp(0.0));
    std::fill(c.begin(), c.end(), zero);
    std::fill(c_ref.begin(), c_ref.end(), zero);

    // Set scalar fp values
    fp alpha = set_fp_value(fp(2.0), fp(0.0));
    fp beta  = set_fp_value(fp(1.0), fp(0.0));

    // Catch asynchronous exceptions
    auto exception_handler = [](sycl::exception_list exceptions) {
        for (std::exception_ptr const &eptr : exceptions) {
            try {
                std::rethrow_exception(eptr);
            }
            catch (sycl::exception const &e) {
                std::cout << "Caught asynchronous SYCL "
                             "exception during sparse::gemm:\n"
                          << e.what() << std::endl;
            }
        }
    };

    //
    // Execute Matrix Multiply
    //

    std::cout << "\n\t\tsparse::gemm parameters:\n";
    std::cout << "\t\t\tdense_matrix_layout = " << dense_matrix_layout << std::endl;
    std::cout << "\t\t\topA                 = " << opA << std::endl;
    std::cout << "\t\t\topB                 = " << opB << std::endl;
    std::cout << "\t\t\tnrows               = " << nrows_a << std::endl;
    std::cout << "\t\t\tncols               = " << ncols_a << std::endl;
    std::cout << "\t\t\tnnz                 = " << nnz_a << std::endl;
    std::cout << "\t\t\tcolumns             = " << columns << std::endl;
    std::cout << "\t\t\tldb                 = " << ldb << std::endl;
    std::cout << "\t\t\tldc                 = " << ldc << std::endl;
    std::cout << "\t\t\talpha               = " << alpha << std::endl;
    std::cout << "\t\t\tbeta                = " << beta << std::endl;

    // create execution queue and buffers of matrix data
    sycl::queue main_queue(dev, exception_handler);

    sycl::buffer<intType, 1> ia_buffer(ia.data(), (nrows_a + 1));
    sycl::buffer<intType, 1> ja_buffer(ja.data(), nnz_a);
    sycl::buffer<fp, 1> a_buffer(a.data(), nnz_a);
    sycl::buffer<fp, 1> b_buffer(b.data(), b.size());
    sycl::buffer<fp, 1> c_buffer(c.data(), c.size());

    // create and initialize handle for a Sparse Matrix in CSR format
    oneapi::mkl::sparse::matrix_handle_t handle = nullptr;

    // Shared cleanup for the failure paths: wait for the queue, then release
    // the matrix handle.
    auto release_after_failure = [&]() {
        main_queue.wait();
        oneapi::mkl::sparse::release_matrix_handle(main_queue, &handle);
    };

    try {
        oneapi::mkl::sparse::init_matrix_handle(&handle);

        oneapi::mkl::sparse::set_csr_data(main_queue, handle, nrows_a, ncols_a, index_base_val,
                                          ia_buffer, ja_buffer, a_buffer);

        // add oneapi::mkl::sparse::gemm to execution queue
        oneapi::mkl::sparse::gemm(main_queue, dense_matrix_layout, opA, opB, alpha,
                                  handle, b_buffer, columns, ldb, beta, c_buffer, ldc);

        oneapi::mkl::sparse::release_matrix_handle(main_queue, &handle);

    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception:\n" << e.what() << std::endl;

        release_after_failure();
        return 1;
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception:\n" << e.what() << std::endl;

        release_after_failure();
        return 1;
    }

    main_queue.wait();

    //
    // Post Processing
    //

    // Compute reference gemm result.
    // NOTE: Now we support only opB == nontrans case, so we don't pass it as argument.
    prepare_reference_gemm_data(ia.data(), ja.data(), a.data(), nrows_a, ncols_a, nnz_a, int_index,
            opA, alpha, beta, dense_matrix_layout, b.data(), ncols_b, ldb, c_ref.data(), ldc);

    fp diff  = set_fp_value(fp(0.0), fp(0.0));
    fp diff2 = set_fp_value(fp(0.0), fp(0.0));
    auto res = c_buffer.get_host_access(sycl::read_only);
    for (intType i = 0; i < nrows_c; i++) {
        // Operation-count estimate handed to check_result() for this row of C.
        // Default: based on the average number of stored entries per row of C.
        intType flops_for_val = 2 * (ceil_div(nnz_a, nrows_c) + 2);
        if (opA == oneapi::mkl::transpose::nontrans) {
            // For nontrans, row i of C comes from row i of A, so the actual
            // number of stored entries in that row is used instead.
            flops_for_val = 2*(ia[i+1] - ia[i] + 2);
        }

        for (intType j = 0; j < ncols_c; j++) {
            // Row-major position of C(i, j)
            intType index = i * ldc + j;
            if (!check_result(res[index], c_ref[index], flops_for_val, index))
                return 1;
            diff += (c_ref[index] - res[index]) * (c_ref[index] - res[index]);
            diff2 += c_ref[index] * c_ref[index];
        }
    }

    std::cout << "\n\t\t sparse::gemm residual:\n"
              << "\t\t\t" << diff / diff2 << "\n\tFinished" << std::endl;

    return 0;
}

//
// Prints the example banner to std::cout: the operation being demonstrated,
// the storage format of the operands, the API used and the supported
// floating point type precisions.
//
void print_example_banner()
{
    // Horizontal rule that frames the banner at the top and bottom.
    static const char *const rule =
            "###############################################################"
            "#########";

    // One entry per output line; the empty entries are the blank lines that
    // precede and follow the banner.
    const char *const banner_lines[] = {
        "",
        rule,
        "# Sparse Matrix-Dense Matrix Multiply Example: ",
        "# ",
        "# C = alpha * op(A) * op(B) + beta * C",
        "# ",
        "# where A is a sparse matrix in CSR format, B and C are "
        "dense matrices",
        "# and alpha, beta are floating point type precision scalars.",
        "# ",
        "# Using apis:",
        "#   sparse::gemm",
        "# ",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "#   std::complex<float>",
        "#   std::complex<double>",
        "# ",
        rule,
        "",
    };

    // Every line ends with std::endl, so the stream is flushed line by line.
    for (const char *text : banner_lines) {
        std::cout << text << std::endl;
    }
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//
//  For each device selected and each supported data type, MatrixMultiplyExample
//  runs is with all supported data types
//

int main(int argc, char **argv)
{

    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int status = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {

        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);

        if (my_dev_is_found) {
            std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

            oneapi::mkl::transpose opA = oneapi::mkl::transpose::nontrans;

            std::cout << "\tRunning with single precision real data type:" << std::endl;
            status = run_sparse_matrix_dense_matrix_multiply_example<float, std::int32_t>(my_dev, opA);
            if (status != 0)
                return status;

            if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                opA = oneapi::mkl::transpose::trans;
                std::cout << "\tRunning with double precision real data type:" << std::endl;
                status = run_sparse_matrix_dense_matrix_multiply_example<double, std::int32_t>(my_dev, opA);
                if (status != 0)
                    return status;
            }

            opA = oneapi::mkl::transpose::conjtrans;
            std::cout << "\tRunning with single precision complex data type:" << std::endl;
            status = run_sparse_matrix_dense_matrix_multiply_example<std::complex<float>, std::int32_t>(
                    my_dev, opA);
            if (status != 0)
                return status;

            if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                opA = oneapi::mkl::transpose::nontrans;
                std::cout << "\tRunning with double precision complex data type:" << std::endl;
                status = run_sparse_matrix_dense_matrix_multiply_example<std::complex<double>, std::int32_t>(
                        my_dev, opA);
                if (status != 0)
                    return status;
            }
        }
        else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it]
                      << " devices found; Fail on missing devices "
                         "is enabled.\n";
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                      << sycl_device_names[*it] << " tests.\n";
#endif
        }
    }

    mkl_free_buffers();
    return status;
}
