/*******************************************************************************
* Copyright 2021-2022 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*  Content:
*      POTRF/POTRS OpenMP GPU Offload Example
*******************************************************************************/

#include <stdio.h>
#include <math.h>
#include <omp.h>
#include "mkl.h"
#include "mkl_omp_offload.h"


/*
 * Factor a small symmetric positive-definite matrix A with dpotrf and solve
 * A * X = B with dpotrs, both dispatched to the OpenMP offload device.
 * B is chosen equal to A, so the expected solution X is the identity matrix.
 * Returns 0 on success, 1 if LAPACK reports an error or the result mismatches.
 */
int main()
{
    char uplo = 'L';             /* use the lower triangle of A */
    MKL_INT n    = 3;            /* order of A and number of right-hand sides */
    MKL_INT lda  = 4;            /* leading dimension; element 4 of each column is padding */
    MKL_INT info = 0;
    MKL_INT matrix_size = lda * n;
    const double tolerance = 1e-7;

    /* Column-major storage; the trailing -1 in each column is unused padding. */
    double matrix[] = {  4,  12, -16,  -1,
                        12,  37, -43,  -1,
                       -16, -43,  98,  -1 };

    /* Right-hand sides B == A, so the solution X = A^{-1} * A = I. */
    double rhs[] =    {  4,  12, -16,  -1,
                        12,  37, -43,  -1,
                       -16, -43,  98,  -1 };

    double result[] = {  1,   0,   0,  -1,
                         0,   1,   0,  -1,
                         0,   0,   1,  -1 };

    printf("Input:\n");
    for (MKL_INT i = 0; i < n; i++) {
        for (MKL_INT j = 0; j <= i; j++) {
            printf("%6.2f ", matrix[i + j * lda]);
        }
        printf("\n");
    }

    MKL_INT *info_ptr  = &info;
    double *matrix_ptr = &matrix[0];
    double *rhs_ptr    = &rhs[0];

    #pragma omp target data map(matrix_ptr[0:matrix_size], rhs_ptr[0:matrix_size], info_ptr[0:1])
    {
        #pragma omp dispatch
        dpotrf(&uplo, &n, matrix_ptr, &lda, info_ptr);
        /* Copy the pointee (info) back to the host, not the pointer variable
         * itself, so the branch below sees the factorization status. */
        #pragma omp target update from(info_ptr[0:1])
        if (info == 0) {
            #pragma omp dispatch
            dpotrs(&uplo, &n, &n, matrix_ptr, &lda, rhs_ptr, &lda, info_ptr);
        }
    }

    if (info != 0) {
        printf("ERROR: Calculations failed with info = %d!\n", (int)info);
        return 1;
    }

    /* X is a general n x n matrix, so check (and show) every element,
     * not just the lower triangle. */
    int num_errors = 0;
    printf("Result:\n");
    for (MKL_INT i = 0; i < n; i++) {
        for (MKL_INT j = 0; j < n; j++) {
            double value = rhs[i + j * lda];
            printf("%6.2f ", value);
            num_errors += fabs(value - result[i + j * lda]) > tolerance;
        }
        printf("\n");
    }
    if (num_errors != 0) {
        printf("ERROR: result mismatches!\n");
        return 1;
    }

    return 0;
}
