/*******************************************************************************
* Copyright 2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*  Content:
*       This example demonstrates use of oneapi::mkl::lapack::gels_batch
*       to solve a batch of linear least squares problems.
*
*       The supported floating point data types for matrix data are:
*           float
*           double
*           std::complex<float>
*           std::complex<double>
*******************************************************************************/

#include <exception>
#include <memory>

#include <oneapi/mkl.hpp>
#include "common_for_examples.hpp"

//
// Solves a batch of `batch_size` least squares problems A_i * x_i ~= b_i with
// oneapi::mkl::lapack::gels_batch (strided batch layout) on device `dev`,
// prints the computed solutions and compares them with the exact solution X.
//
// Template parameters:
//   data_t  - matrix element type (float, double, std::complex<float>, std::complex<double>)
//   real_t  - real type underlying data_t (deduced via std::real)
//   is_real - true when data_t is itself a real type (deduced)
//
// Returns 0 when every computed solution component is within `bound` of X, and
// 1 on a mismatch, a failed device allocation, or a caught std::exception
// (e.g. a sycl::exception raised while running the computation).
//
template <typename data_t, typename real_t = decltype(std::real((data_t)0)), bool is_real = std::is_same_v<data_t,real_t>>
int run_gels_batch_example(sycl::device &dev)
{
    const oneapi::mkl::transpose nontrans = oneapi::mkl::transpose::nontrans;
    const int64_t m = 5;
    const int64_t n = 5;
    const int64_t nrhs = 1;
    const int64_t lda = m;
    const int64_t stride_a = n*lda;     // elements between consecutive A matrices
    const int64_t ldb = m;
    const int64_t stride_b = nrhs*ldb;  // elements between consecutive right hand sides
    const int64_t batch_size = 2;

    // Builds an input entry: the value itself for real types, and the purely
    // imaginary value i*arg for complex types. A and B share the factor i, so
    // the exact solution stays the real vector of ones in both cases.
    auto v = [] (real_t arg) { if constexpr (is_real) return arg; else return data_t{0, arg}; };

    // Two 5x5 matrices stored column by column (lda = m), one after another
    // (stride_a). For both systems, A_i * (1, 1, 1, 1, 1)^T == b_i exactly.
    data_t A[] = {
        v( 1.0), v( 0.0), v( 0.0), v( 0.0), v( 0.0),
        v( 1.0), v( 0.2), v(-0.4), v(-0.4), v(-0.8),
        v( 1.0), v( 0.6), v(-0.2), v( 0.4), v(-1.2),
        v( 1.0), v( 1.0), v(-1.0), v( 0.6), v(-0.8),
        v( 1.0), v( 1.8), v(-0.6), v( 0.2), v(-0.6)
                                                   ,
        v( 0.2), v(-0.4), v(-0.4), v(-0.8), v( 0.0),
        v( 0.4), v( 0.2), v( 0.8), v(-0.4), v( 0.0),
        v( 0.4), v(-0.8), v( 0.2), v( 0.4), v( 0.0),
        v( 0.8), v( 0.4), v(-0.4), v( 0.2), v( 0.0),
        v( 0.0), v( 0.0), v( 0.0), v( 0.0), v( 1.0)
    };

    // Right hand sides; overwritten with the solutions after gels_batch.
    data_t B[] = {
        v( 5.0), v( 3.6), v(-2.2), v( 0.8), v(-3.4),
        v( 1.8), v(-0.6), v( 0.2), v(-0.6), v( 1.0),
    };

    // Exact solutions, n entries per batch member.
    const data_t X[] = {
           1.0 ,    1.0 ,    1.0 ,    1.0 ,    1.0 ,
           1.0 ,    1.0 ,    1.0 ,    1.0 ,    1.0 ,
    };

    try {
        sycl::queue que { dev };

        // Device allocations are owned by unique_ptrs whose deleter calls
        // sycl::free, so they are released on every exit path from this scope,
        // including an exception thrown by gels_batch or a copy.
        auto device_free = [&que](data_t *ptr) { sycl::free(ptr, que); };
        using device_ptr = std::unique_ptr<data_t, decltype(device_free)>;
        const auto device_alloc = [&que, &device_free](int64_t count) {
            return device_ptr { sycl::aligned_alloc_device<data_t>(64, count, que), device_free };
        };

        device_ptr A_dev = device_alloc(stride_a*batch_size);
        device_ptr B_dev = device_alloc(stride_b*batch_size);
        if (!A_dev || !B_dev) {
            printf("ERROR: device memory allocation failed\n");
            return 1;
        }
        que.copy(A, A_dev.get(), stride_a*batch_size).wait();
        que.copy(B, B_dev.get(), stride_b*batch_size).wait();

        const int64_t scratchpad_size = oneapi::mkl::lapack::gels_batch_scratchpad_size<data_t>(que, nontrans, m, n, nrhs, lda, stride_a, ldb, stride_b, batch_size);
        device_ptr scratchpad = device_alloc(scratchpad_size);
        // A zero-sized request may legitimately yield a null pointer.
        if (scratchpad_size > 0 && !scratchpad) {
            printf("ERROR: device memory allocation failed\n");
            return 1;
        }

        oneapi::mkl::lapack::gels_batch(que, nontrans, m, n, nrhs, A_dev.get(), lda, stride_a, B_dev.get(), ldb, stride_b, batch_size, scratchpad.get(), scratchpad_size).wait_and_throw();

        que.copy(B_dev.get(), B, stride_b*batch_size).wait();
    } catch (const std::exception &e) {
        printf("ERROR: exception caught during gels_batch computation: %s\n", e.what());
        return 1;
    }

    const real_t bound = std::is_same_v<real_t, float> ? 1e-6 : 1e-8;
    bool passed = true;

    printf("Results:\n");
    auto print = [](const data_t &val) { if constexpr (is_real) printf("%6.2f", val); else printf("<%6.2f,%6.2f> ", val.real(), val.imag()); };
    for (int64_t i = 0; i < batch_size; i++) {
        for (int64_t j = 0; j < n; j++) {
            // The solution of member i occupies the first n entries of its
            // right hand side column.
            const data_t result = B[i*stride_b + j];
            // std::abs yields the magnitude |r| for both real and complex
            // residuals. A NaN result makes the comparison false and fails.
            const real_t error = std::abs(result - X[i*n + j]);
            passed = passed && (error < bound);
            print(result);
        }
        printf("\n");
    }

    if (passed) {
        printf("Calculations successfully finished\n");
        return 0;
    }

    printf("ERROR: results mismatch!\n");
    printf("Expected:\n");
    for (int64_t i = 0; i < batch_size; i++) {
        for (int64_t j = 0; j < n; j++) {
            print(X[i*n + j]);
        }
        printf("\n");
    }
    return 1;
}

//
// Prints a banner describing this example (batched strided GELS) and the
// floating point types it is run with, one line per entry of the table below.
//
void print_example_banner() {
    static const char *const banner_lines[] = {
        "",
        "########################################################################",
        "# Batched strided GELS example:",
        "# ",
        "# Computes least squares of a batch of matrices and right hand sides.",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "#   std::complex<float>",
        "#   std::complex<double>",
        "# ",
        "########################################################################",
        "",
    };
    for (const char *line : banner_lines) {
        std::cout << line << std::endl;
    }
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU device
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU device
// -DSYCL_DEVICES_all (default) -- runs on all: CPU and GPU devices
//
//  For each selected device, the gels_batch example runs with every supported
//  data type (double precision types only if the device supports them).
//
int main(int argc, char **argv) {

    print_example_banner();

    // Find list of devices
    std::list<my_sycl_device_types> listOfDevices;
    set_list_of_devices(listOfDevices);

    bool failed = false;

    for (auto &deviceType: listOfDevices) {
        sycl::device myDev;
        bool myDevIsFound = false;
        get_sycl_device(myDev, myDevIsFound, deviceType);

        if (myDevIsFound) {
          std::cout << std::endl << "Running gels_batch examples on " << sycl_device_names[deviceType] << "." << std::endl;

          std::cout << "Running with single precision real data type:" << std::endl;
          failed |= run_gels_batch_example<float>(myDev);

          std::cout << "Running with single precision complex data type:" << std::endl;
          failed |= run_gels_batch_example<std::complex<float>>(myDev);

          if (isDoubleSupported(myDev)) {
              std::cout << "Running with double precision real data type:" << std::endl;
              failed |= run_gels_batch_example<double>(myDev);

              std::cout << "Running with double precision complex data type:" << std::endl;
              failed |= run_gels_batch_example<std::complex<double>>(myDev);
          } else {
              std::cout << "Double precision not supported on this device " << std::endl;
              std::cout << std::endl;
          }

        } else {
#ifdef FAIL_ON_MISSING_DEVICES
          std::cout << "No " << sycl_device_names[deviceType] << " devices found; Fail on missing devices is enabled.\n";
          return 1;
#else
          std::cout << "No " << sycl_device_names[deviceType] << " devices found; skipping " << sycl_device_names[deviceType] << " tests.\n";
#endif
        }
    }
    return failed;
}
