!===============================================================================
! Copyright (C) 2020 Intel Corporation
!
! This software and the related documents are Intel copyrighted  materials,  and
! your use of  them is  governed by the  express license  under which  they were
! provided to you (License).  Unless the License provides otherwise, you may not
! use, modify, copy, publish, distribute,  disclose or transmit this software or
! the related documents without Intel's prior written permission.
!
! This software and the related documents  are provided as  is,  with no express
! or implied  warranties,  other  than those  that are  expressly stated  in the
! License.
!===============================================================================

! Content:
!       Example of using sfftw_plan_many_dft function on a
!       (GPU) device using the OpenMP target (offload) interface.
!
!*****************************************************************************

include "fftw/offload/fftw3_omp_offload.f90"

program sp_plan_many_dft_3d_outofplace

  ! Driver: runs M single-precision complex 3D out-of-place FFTs of size
  ! N1 x N2 x N3 forward and backward through the FFTW3 OpenMP-offload
  ! interface and checks both results against a known harmonic.
  ! Prints "TEST PASSED" and exits with code 0 on success, otherwise prints
  ! "TEST FAILED" and exits with code 1.

  use FFTW3_OMP_OFFLOAD
  use omp_lib, ONLY : omp_get_num_devices
  use, intrinsic :: ISO_C_BINDING

  implicit none

  include 'fftw/fftw3.f'

  ! Sizes of 3D transform and the number of them
  integer, parameter :: N1 = 8
  integer, parameter :: N2 = 16
  integer, parameter :: N3 = 32
  integer, parameter :: M = 4

  ! Arbitrary harmonic used to verify FFT
  integer, parameter :: H1 = 1
  integer, parameter :: H2 = -N2/2
  integer, parameter :: H3 = -2

  ! Need single precision
  integer, parameter :: WP = selected_real_kind(6,37)

  ! Execution status: allocation / overall, forward check, backward check
  integer :: statusf = 0, statusb = 0, status = 0

  ! Data arrays
  complex(WP), allocatable :: x(:)
  complex(WP), allocatable :: y(:)

  ! Distance, stride and embedding array
  integer :: dist, stride, nembed(3)

  ! FFTW plans; a handle that is still 0 after planning means failure
  integer*8 :: fwd = 0, bwd = 0

  ! True only when both plans were created successfully
  logical :: plans_ok

  stride = 1
  nembed(1) = stride * N1
  nembed(2) = N2
  nembed(3) = N3
  dist = nembed(1) * nembed(2) * nembed(3)

  print *,"Example sp_plan_many_dft_3d_outofplace"
  print *,"Forward and backward multiple 3D complex out-of-place FFT"
  print *,"Configuration parameters:"
  print '("  N = [",I0,",",I0,",",I0,"]")', N1, N2, N3
  print '("  H = [",I0,",",I0,",",I0,"]")', H1, H2, H3
  print '("  nembed = [",I0,",",I0,",",I0,"]")', nembed(1), nembed(2), nembed(3)
  print '("  M = ",I0)', M
  print '("  dist = ",I0)', dist
  print '("  stride = ",I0)', stride


  print *,"Allocate input array"
  allocate ( x(dist * M), STAT = status)

  if (0 == status) then
    print *,"Allocate output array"
    allocate ( y(dist * M), STAT = status)
  end if

  ! Everything below needs both arrays; on allocation failure fall through
  ! to the common reporting and cleanup code with status /= 0.
  if (0 == status) then

    print *,"Initialize data for forward transform"
    call init(x, N1, N2, N3, H1, H2, H3, M, dist)

    print *,"Create FFTW forward transform plan"
    !$omp target data map(tofrom:x,y)
    !$omp dispatch
    call sfftw_plan_many_dft(fwd, 3, [N1, N2, N3], M,   &
                             x, nembed, stride, dist,   &
                             y, nembed, stride, dist,   &
                             FFTW_FORWARD, FFTW_ESTIMATE)
    if (0 == fwd) print *, "Call to sfftw_plan_many_dft for forward transform &
                            &has failed"

    print *,"Create FFTW backward transform plan"
    !$omp dispatch
    call sfftw_plan_many_dft(bwd, 3, [N1, N2, N3], M,   &
                             y, nembed, stride, dist,   &
                             x, nembed, stride, dist,   &
                             FFTW_BACKWARD, FFTW_ESTIMATE)
    if (0 == bwd) print *, "Call to sfftw_plan_many_dft for backward transform &
                            &has failed"

    ! Never execute a null plan.  Branching out of the target data region
    ! is not permitted, so the transforms are skipped with a flag instead.
    plans_ok = (0 /= fwd) .and. (0 /= bwd)

    if (plans_ok) then
      print *,"Compute forward transform"
      !$omp dispatch
      call sfftw_execute_dft(fwd, x, y)

      ! Update the host with the results from forward FFT
      !$omp target update from(y)

      print *,"Verify the result of the forward transform"
      statusf = verify(y, N1, N2, N3, H1, H2, H3, M, dist)

      ! The backward input uses the negated harmonic so that the backward
      ! result is checked against the same peak position (H1, H2, H3).
      print *,"Initialize data for backward transform"
      call init(y, N1, N2, N3, -H1, -H2, -H3, M, dist)

      ! Update the device with input for backward FFT
      !$omp target update to(y)

      print *,"Compute backward transform"
      !$omp dispatch
      call sfftw_execute_dft(bwd, y, x)
    end if
    !$omp end target data

    if (plans_ok) then
      ! x was mapped tofrom, so the host copy is current once the target
      ! data region has ended.
      print *,"Verify the result of the backward transform"
      statusb = verify(x, N1, N2, N3, H1, H2, H3, M, dist)
    else
      status = 1
    end if

  end if

  if ((0 /= status) .or. (0 /= statusf) .or. (0 /= statusb)) then
    print '("  Error, status forward = ",I0)', statusf
    print '(" Error, status backward = ",I0)', statusb
    status = 1
  end if

  ! Cleanup is shared by the success and failure paths, so only release
  ! what was actually created.
  print *,"Destroy FFTW plans"
  if (0 /= fwd) call sfftw_destroy_plan(fwd)
  if (0 /= bwd) call sfftw_destroy_plan(bwd)

  print *,"Deallocate arrays"
  if (allocated(x)) deallocate(x)
  if (allocated(y)) deallocate(y)

  if (status == 0) then
    print *, "TEST PASSED"
    call exit(0)
  else
    print *, "TEST FAILED"
    call exit(1)
  end if

contains

  ! Compute mod(k*l, m) accurately.
  ! The product is formed in 8-byte integer arithmetic so that it cannot
  ! overflow the default integer kind.  As with the MOD intrinsic, the
  ! result carries the sign of the product k*l; m must be non-zero.
  pure integer*8 function moda(k,l,m)
    integer, intent(in) :: k,l,m
    integer*8 :: k8, m8
    k8 = k
    ! MOD requires both arguments to have the same kind, so promote m too
    ! instead of relying on a mixed-kind vendor extension.
    m8 = m
    moda = mod(k8*l, m8)
  end function moda

  ! Initialize array with harmonic /H1, H2, H3/
  !
  ! For each of the M data sets, element (k1,k2,k3) is set to
  !   exp(2*pi*i*((k1-1)*H1/N1 + (k2-1)*H2/N2 + (k3-1)*H3/N3)) / (N1*N2*N3)
  !
  ! Arguments:
  !   x           array to fill; data set mm starts at offset (mm-1)*dist and
  !               is stored with k1 varying fastest, then k2, then k3
  !   N1, N2, N3  transform sizes
  !   H1, H2, H3  harmonic indices
  !   M           number of data sets
  !   dist        distance in elements between consecutive data sets
  subroutine init(x, N1, N2, N3, H1, H2, H3, M, dist)
    integer, intent(in) :: N1, N2, N3, H1, H2, H3, M, dist
    ! inout rather than out: elements not addressed by the loops below
    ! (possible when dist > N1*N2*N3) keep their previous contents.
    complex(WP), intent(inout) :: x(:)

    integer :: mm, k1, k2, k3, base
    real(WP) :: phase2, phase3
    complex(WP), parameter :: I_TWOPI = (0.0_WP,6.2831853071795864769_WP)

    do mm = 1, M
      do k3 = 1, N3
        ! Phase contributions that do not depend on the inner indices are
        ! computed once per outer iteration.
        phase3 = real(moda(k3-1,H3,N3), WP)/N3
        do k2 = 1, N2
          phase2 = real(moda(k2-1,H2,N2), WP)/N2
          base = ((mm-1)*dist)+(((k3-1)*N1*N2)+(k2-1)*N1)
          do k1 = 1, N1
            x(base+k1) = &
            exp (I_TWOPI*&
            ( real(moda(k1-1,H1,N1), WP)/N1 &
            + phase2 &
            + phase3))/(N1*N2*N3)
          end do
        end do
      end do
    end do
  end subroutine init

  ! Verify that each of the M data sets in x, viewed as an N1 x N2 x N3
  ! array, is a unit peak at the position given by harmonic (H1,H2,H3)
  ! (taken modulo the sizes) and zero everywhere else.
  !
  ! Arguments:
  !   x           data to check; data set mm starts at offset (mm-1)*dist and
  !               is stored with k1 varying fastest, then k2, then k3
  !   N1, N2, N3  transform sizes
  !   H1, H2, H3  harmonic indices of the expected peak
  !   M           number of data sets
  !   dist        distance in elements between consecutive data sets
  !
  ! Returns 0 when every element is within the error threshold; returns 1
  ! at the first element that is not, after printing a diagnostic.
  integer function verify(x, N1, N2, N3, H1, H2, H3, M, dist)
    integer, intent(in) :: N1, N2, N3, H1, H2, H3, M, dist
    complex(WP), intent(in) :: x(:)

    integer :: mm, k1, k2, k3
    real(WP) :: err, errthr, maxerr
    complex(WP) :: res_exp, res_got

    ! Single-record diagnostic for a failing element: indices, expected
    ! value, obtained value and error.
    character(len=*), parameter :: FAIL_FMT =              &
        '("  x(",I0,",",I0,",",I0,"):",'               //  &
        '" expected (",G14.7,",",G14.7,"),",'          //  &
        '" got (",G14.7,",",G14.7,"),",'               //  &
        '" err ",G10.3)'

    ! Note, this simple error bound doesn't take into account error of
    ! input data
    errthr = 5.0_WP * log(real(N1*N2*N3,WP)) / log(2.0_WP) * EPSILON(1.0_WP)
    print '("  Check if err is below errthr ",G10.3)', errthr

    maxerr = 0.0_WP
    do mm = 1, M
      do k3 = 1, N3
        do k2 = 1, N2
          do k1 = 1, N1
            ! mod() is zero exactly when the zero-based index matches the
            ! harmonic modulo the size, also for negative harmonics.
            if (mod(k1-1-H1, N1)==0 .AND. &
                mod(k2-1-H2, N2)==0 .AND. &
                mod(k3-1-H3, N3)==0) then
              res_exp = 1.0_WP
            else
              res_exp = 0.0_WP
            end if
            res_got = x(((mm-1)*dist)+(k3-1)*N1*N2+(k2-1)*N1+k1)
            err = abs(res_got - res_exp)
            maxerr = max(err,maxerr)
            ! Written as .not.(err < errthr) so that a NaN error also fails.
            if (.not.(err < errthr)) then
              print FAIL_FMT, k1, k2, k3, res_exp, res_got, err
              print *," Verification FAILED"
              verify = 1
              return
            end if
          end do
        end do
      end do
    end do
    print '("  Verified,  maximum error was ",G10.3)', maxerr
    verify = 0
  end function verify
end program sp_plan_many_dft_3d_outofplace
