hicblas_mod.F90 Source File


This file depends on

Dependency graph: hicblas_mod.F90 -> growing_allocator_mod.F90

Files dependent on this one

sourcefile~~hicblas_mod.f90~~AfferentGraph sourcefile~hicblas_mod.f90 hicblas_mod.F90 sourcefile~ledir_mod.f90 ledir_mod.F90 sourcefile~ledir_mod.f90->sourcefile~hicblas_mod.f90 sourcefile~leinv_mod.f90 leinv_mod.F90 sourcefile~leinv_mod.f90->sourcefile~hicblas_mod.f90 sourcefile~ltdir_mod.f90 ltdir_mod.F90 sourcefile~ltdir_mod.f90->sourcefile~ledir_mod.f90 sourcefile~ltdir_mod.f90~2 ltdir_mod.F90 sourcefile~ltdir_mod.f90~2->sourcefile~ledir_mod.f90 sourcefile~ltinv_mod.f90 ltinv_mod.F90 sourcefile~ltinv_mod.f90->sourcefile~leinv_mod.f90 sourcefile~ltinv_mod.f90~2 ltinv_mod.F90 sourcefile~ltinv_mod.f90~2->sourcefile~leinv_mod.f90 sourcefile~trltom_pack_unpack.f90 trltom_pack_unpack.F90 sourcefile~trltom_pack_unpack.f90->sourcefile~ledir_mod.f90 sourcefile~trmtol_pack_unpack.f90 trmtol_pack_unpack.F90 sourcefile~trmtol_pack_unpack.f90->sourcefile~leinv_mod.f90 sourcefile~dir_trans_ctl_mod.f90 dir_trans_ctl_mod.F90 sourcefile~dir_trans_ctl_mod.f90->sourcefile~ltdir_mod.f90 sourcefile~dir_trans_ctl_mod.f90->sourcefile~trltom_pack_unpack.f90 sourcefile~inv_trans_ctl_mod.f90 inv_trans_ctl_mod.F90 sourcefile~inv_trans_ctl_mod.f90->sourcefile~ltinv_mod.f90 sourcefile~inv_trans_ctl_mod.f90->sourcefile~trmtol_pack_unpack.f90 sourcefile~ltdir_ctl_mod.f90 ltdir_ctl_mod.F90 sourcefile~ltdir_ctl_mod.f90->sourcefile~ltdir_mod.f90 sourcefile~ltinv_ctl_mod.f90 ltinv_ctl_mod.F90 sourcefile~ltinv_ctl_mod.f90->sourcefile~ltinv_mod.f90 sourcefile~dir_trans.f90 dir_trans.F90 sourcefile~dir_trans.f90->sourcefile~dir_trans_ctl_mod.f90 sourcefile~dir_trans.f90~2 dir_trans.F90 sourcefile~dir_trans.f90~2->sourcefile~dir_trans_ctl_mod.f90 sourcefile~dir_trans_ctl_mod.f90~2 dir_trans_ctl_mod.F90 sourcefile~dir_trans_ctl_mod.f90~2->sourcefile~ltdir_ctl_mod.f90 sourcefile~inv_trans.f90 inv_trans.F90 sourcefile~inv_trans.f90->sourcefile~inv_trans_ctl_mod.f90 sourcefile~inv_trans.f90~2 inv_trans.F90 sourcefile~inv_trans.f90~2->sourcefile~inv_trans_ctl_mod.f90 sourcefile~inv_trans_ctl_mod.f90~2 
inv_trans_ctl_mod.F90 sourcefile~inv_trans_ctl_mod.f90~2->sourcefile~ltinv_ctl_mod.f90

Source Code

! (C) Copyright 2000- ECMWF.
!
! This software is licensed under the terms of the Apache Licence Version 2.0
! which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
! In applying this licence, ECMWF does not waive the privileges and immunities
! granted to it by virtue of its status as an intergovernmental organisation
! nor does it submit to any jurisdiction.
!

! Backend selection. When CUDAGPU is defined, the HIP-flavoured names used in
! this file are remapped by the preprocessor:
!   - the OpenACC module name and the stream query function are replaced by
!     OPENACC and ACC_GET_CUDA_STREAM,
!   - the two GEMM symbol names are replaced by quoted cuBLAS symbol names.
! The macros are case-sensitive, so the remapped identifiers must keep their
! exact spelling wherever they are used below.
! NOTE(review): the BIND(C) names further down are already quoted string
! literals, so whether the two GEMM macros take effect there depends on the
! preprocessor expanding macros inside string literals -- confirm.
#if defined CUDAGPU
#define hipblasSgemm 'cublasSgemm'
#define hipblasDgemm 'cublasDgemm'
#define ACC_GET_HIP_STREAM ACC_GET_CUDA_STREAM
#define OPENACC_LIB OPENACC
#endif

MODULE HICBLAS_MOD

USE PARKIND1, ONLY : JPIM, JPRM, JPRD
USE GROWING_ALLOCATOR_MOD, ONLY: GROWING_ALLOCATION_TYPE
USE ISO_C_BINDING
USE OPENACC_LIB, ONLY: ACC_GET_HIP_STREAM

IMPLICIT NONE

  ! Generic entry point for batched and grouped GEMM. The four specifics
  ! differ in real kind (JPRD or JPRM) and in argument list: the grouped
  ! forms take a leading BLAS_ID and array-valued N, K and offsets.
  INTERFACE HIP_GEMM_BATCHED
    MODULE PROCEDURE HIP_DGEMM_BATCHED_OVERLOAD, &
      &              HIP_SGEMM_BATCHED_OVERLOAD, &
      &              HIP_DGEMM_GROUPED_OVERLOAD, &
      &              HIP_SGEMM_GROUPED_OVERLOAD
  END INTERFACE HIP_GEMM_BATCHED

!
! Generic HIP_GEMM: direct C bindings for single- and double-precision GEMM.
! The two specifics are identical apart from the real kind of ALPHA, BETA,
! A, B and C. Arguments carrying VALUE are passed by value; the matrices are
! explicit-shape arrays whose leading dimensions are LDA, LDB and LDC.
!
INTERFACE HIP_GEMM

  ! Single precision. C prototype as recorded by the original author:
  !   void hipblasSgemm (char transa, char transb, int m, int n,
  !                      int k, float alpha, const float *A, int lda,
  !                      const float *B, int ldb, float beta, float *C, int ldc)
  SUBROUTINE HIP_SGEMM(CTA, CTB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) &
    & BIND(C,NAME='hipblasSgemm')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_FLOAT
    CHARACTER(KIND=C_CHAR, LEN=1), VALUE :: CTA, CTB
    INTEGER(KIND=C_INT), VALUE :: M, N, K
    INTEGER(KIND=C_INT), VALUE :: LDA, LDB, LDC
    REAL(KIND=C_FLOAT), VALUE :: ALPHA, BETA
    REAL(KIND=C_FLOAT) :: A(LDA,*)
    REAL(KIND=C_FLOAT) :: B(LDB,*)
    REAL(KIND=C_FLOAT) :: C(LDC,*)
  END SUBROUTINE HIP_SGEMM

  ! Double precision. C prototype as recorded by the original author:
  !   void hipblasDgemm (char transa, char transb, int m, int n,
  !                      int k, double alpha, const double *A, int lda,
  !                      const double *B, int ldb, double beta, double *C, int ldc)
  SUBROUTINE HIP_DGEMM(CTA, CTB, M, N, K, ALPHA, A, LDA, B, LDB, BETA, C, LDC) &
    & BIND(C,NAME='hipblasDgemm')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_DOUBLE
    CHARACTER(KIND=C_CHAR, LEN=1), VALUE :: CTA, CTB
    INTEGER(KIND=C_INT), VALUE :: M, N, K
    INTEGER(KIND=C_INT), VALUE :: LDA, LDB, LDC
    REAL(KIND=C_DOUBLE), VALUE :: ALPHA, BETA
    REAL(KIND=C_DOUBLE) :: A(LDA,*)
    REAL(KIND=C_DOUBLE) :: B(LDB,*)
    REAL(KIND=C_DOUBLE) :: C(LDC,*)
  END SUBROUTINE HIP_DGEMM

END INTERFACE HIP_GEMM

! Binding to the C wrapper hipblas_dgemm_wrapper (double precision, batched).
! Arguments carrying VALUE are passed by value; A, B and C are explicit-shape
! arrays with leading dimensions LDA, LDB and LDC.
! STREAM has no VALUE attribute and is therefore passed by reference.
! NOTE(review): confirm that the C side expects STREAM as a pointer.
! ALLOC is an opaque C pointer passed by value; the wrappers in this module
! pass the address of a GROWING_ALLOCATION_TYPE object here.
INTERFACE
  SUBROUTINE HIP_DGEMM_BATCHED(CTA, CTB, M, N, K, ALPHA, &
    &                          A, LDA, TDA, B, LDB, TDB, BETA, &
    &                          C, LDC, TDC, BATCHCOUNT, STREAM, ALLOC) &
    & BIND(C, NAME='hipblas_dgemm_wrapper')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_DOUBLE, C_SIZE_T, C_PTR
    CHARACTER(KIND=C_CHAR, LEN=1), VALUE :: CTA, CTB
    INTEGER(KIND=C_INT), VALUE :: M, N, K
    INTEGER(KIND=C_INT), VALUE :: LDA, LDB, LDC
    INTEGER(KIND=C_INT), VALUE :: TDA, TDB, TDC
    INTEGER(KIND=C_INT), VALUE :: BATCHCOUNT
    REAL(KIND=C_DOUBLE), VALUE :: ALPHA, BETA
    REAL(KIND=C_DOUBLE) :: A(LDA,*)
    REAL(KIND=C_DOUBLE) :: B(LDB,*)
    REAL(KIND=C_DOUBLE) :: C(LDC,*)
    INTEGER(KIND=C_SIZE_T) :: STREAM
    TYPE(C_PTR), VALUE, INTENT(IN) :: ALLOC
  END SUBROUTINE HIP_DGEMM_BATCHED
END INTERFACE

! Binding to the C symbol hipblasDgemmStridedBatched_wrapper (double
! precision). Arguments carrying VALUE are passed by value; A, B and C are
! explicit-shape arrays with leading dimensions LDA, LDB and LDC.
! STREAM has no VALUE attribute and is therefore passed by reference.
! NOTE(review): confirm that the C side expects STREAM as a pointer.
INTERFACE
  SUBROUTINE HIP_DGEMM_STRIDED_BATCHED(CTA, CTB, M, N, K, ALPHA, &
    &                                  A, LDA, TDA, B, LDB, TDB, BETA, &
    &                                  C, LDC, TDC, BATCHCOUNT, STREAM) &
    & BIND(C, NAME='hipblasDgemmStridedBatched_wrapper')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_DOUBLE, C_SIZE_T
    CHARACTER(KIND=C_CHAR, LEN=1), VALUE :: CTA, CTB
    INTEGER(KIND=C_INT), VALUE :: M, N, K
    INTEGER(KIND=C_INT), VALUE :: LDA, LDB, LDC
    INTEGER(KIND=C_INT), VALUE :: TDA, TDB, TDC
    INTEGER(KIND=C_INT), VALUE :: BATCHCOUNT
    REAL(KIND=C_DOUBLE), VALUE :: ALPHA, BETA
    REAL(KIND=C_DOUBLE) :: A(LDA,*)
    REAL(KIND=C_DOUBLE) :: B(LDB,*)
    REAL(KIND=C_DOUBLE) :: C(LDC,*)
    INTEGER(KIND=C_SIZE_T) :: STREAM
  END SUBROUTINE HIP_DGEMM_STRIDED_BATCHED
END INTERFACE

! Argument-less binding to the C symbol hipblasDgemmBatched_finalize.
! NOTE(review): presumably releases state held by the batched DGEMM wrapper;
! confirm against the C source.
INTERFACE
  SUBROUTINE HIP_DGEMM_BATCHED_FINALIZE() &
    & BIND(C,NAME='hipblasDgemmBatched_finalize')
  END SUBROUTINE HIP_DGEMM_BATCHED_FINALIZE
END INTERFACE

! Binding to the C wrapper hipblas_sgemm_wrapper (single precision, batched).
! Arguments carrying VALUE are passed by value; A, B and C are explicit-shape
! arrays with leading dimensions LDA, LDB and LDC.
! STREAM has no VALUE attribute and is therefore passed by reference.
! NOTE(review): confirm that the C side expects STREAM as a pointer.
! ALLOC is an opaque C pointer passed by value; the wrappers in this module
! pass the address of a GROWING_ALLOCATION_TYPE object here.
INTERFACE
  SUBROUTINE HIP_SGEMM_BATCHED(CTA, CTB, M, N, K, ALPHA, &
    &                          A, LDA, TDA, B, LDB, TDB, BETA, &
    &                          C, LDC, TDC, BATCHCOUNT, STREAM, ALLOC) &
    & BIND(C, NAME='hipblas_sgemm_wrapper')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_FLOAT, C_SIZE_T, C_PTR
    CHARACTER(KIND=C_CHAR, LEN=1), VALUE :: CTA, CTB
    INTEGER(KIND=C_INT), VALUE :: M, N, K
    INTEGER(KIND=C_INT), VALUE :: LDA, LDB, LDC
    INTEGER(KIND=C_INT), VALUE :: TDA, TDB, TDC
    INTEGER(KIND=C_INT), VALUE :: BATCHCOUNT
    REAL(KIND=C_FLOAT), VALUE :: ALPHA, BETA
    REAL(KIND=C_FLOAT) :: A(LDA,*)
    REAL(KIND=C_FLOAT) :: B(LDB,*)
    REAL(KIND=C_FLOAT) :: C(LDC,*)
    INTEGER(KIND=C_SIZE_T) :: STREAM
    TYPE(C_PTR), VALUE, INTENT(IN) :: ALLOC
  END SUBROUTINE HIP_SGEMM_BATCHED
END INTERFACE

! Binding to the C symbol hipblasSgemmStridedBatched_wrapper (single
! precision). Arguments carrying VALUE are passed by value; A, B and C are
! explicit-shape arrays with leading dimensions LDA, LDB and LDC.
! STREAM has no VALUE attribute and is therefore passed by reference.
! NOTE(review): confirm that the C side expects STREAM as a pointer.
INTERFACE
  SUBROUTINE HIP_SGEMM_STRIDED_BATCHED(CTA, CTB, M, N, K, ALPHA, &
    &                                  A, LDA, TDA, B, LDB, TDB, BETA, &
    &                                  C, LDC, TDC, BATCHCOUNT, STREAM) &
    & BIND(C, NAME='hipblasSgemmStridedBatched_wrapper')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_FLOAT, C_SIZE_T
    CHARACTER(KIND=C_CHAR, LEN=1), VALUE :: CTA, CTB
    INTEGER(KIND=C_INT), VALUE :: M, N, K
    INTEGER(KIND=C_INT), VALUE :: LDA, LDB, LDC
    INTEGER(KIND=C_INT), VALUE :: TDA, TDB, TDC
    INTEGER(KIND=C_INT), VALUE :: BATCHCOUNT
    REAL(KIND=C_FLOAT), VALUE :: ALPHA, BETA
    REAL(KIND=C_FLOAT) :: A(LDA,*)
    REAL(KIND=C_FLOAT) :: B(LDB,*)
    REAL(KIND=C_FLOAT) :: C(LDC,*)
    INTEGER(KIND=C_SIZE_T) :: STREAM
  END SUBROUTINE HIP_SGEMM_STRIDED_BATCHED
END INTERFACE

! Argument-less binding to the C symbol hipblasSgemmBatched_finalize.
! NOTE(review): presumably releases state held by the batched SGEMM wrapper;
! confirm against the C source.
INTERFACE
  SUBROUTINE HIP_SGEMM_BATCHED_FINALIZE() &
    & BIND(C,NAME='hipblasSgemmBatched_finalize')
  END SUBROUTINE HIP_SGEMM_BATCHED_FINALIZE
END INTERFACE

! Bindings to the grouped GEMM C wrappers blas_dgemm_wrapper_grouped and
! blas_sgemm_wrapper_grouped. Compared with the batched bindings, N, K and
! the three OFFSET arguments are assumed-size integer arrays, the matrices
! are assumed-size rank-1 arrays, and a leading BLAS_ID is passed by value.
! STREAM has no VALUE attribute and is therefore passed by reference.
! NOTE(review): confirm that the C side expects STREAM as a pointer.
INTERFACE

  ! Double-precision variant.
  SUBROUTINE HIP_DGEMM_GROUPED(BLAS_ID, CTA, CTB, M, N, K, ALPHA, &
    &                          A, LDA, OFFSETA, B, LDB, OFFSETB, BETA, &
    &                          C, LDC, OFFSETC, BATCHCOUNT, STREAM, ALLOC) &
    & BIND(C, NAME='blas_dgemm_wrapper_grouped')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_DOUBLE, C_SIZE_T, C_PTR
    CHARACTER(KIND=C_CHAR, LEN=1), VALUE :: CTA, CTB
    INTEGER(KIND=C_INT), VALUE :: BLAS_ID, M
    INTEGER(KIND=C_INT), VALUE :: LDA, LDB, LDC
    INTEGER(KIND=C_INT), VALUE :: BATCHCOUNT
    INTEGER(KIND=C_INT), DIMENSION(*) :: N, K
    INTEGER(KIND=C_INT), DIMENSION(*) :: OFFSETA, OFFSETB, OFFSETC
    REAL(KIND=C_DOUBLE), VALUE :: ALPHA, BETA
    REAL(KIND=C_DOUBLE), DIMENSION(*) :: A, B, C
    INTEGER(KIND=C_SIZE_T) :: STREAM
    TYPE(C_PTR), VALUE, INTENT(IN) :: ALLOC
  END SUBROUTINE HIP_DGEMM_GROUPED

  ! Single-precision variant.
  SUBROUTINE HIP_SGEMM_GROUPED(BLAS_ID, CTA, CTB, M, N, K, ALPHA, &
    &                          A, LDA, OFFSETA, B, LDB, OFFSETB, BETA, &
    &                          C, LDC, OFFSETC, BATCHCOUNT, STREAM, ALLOC) &
    & BIND(C, NAME='blas_sgemm_wrapper_grouped')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_FLOAT, C_SIZE_T, C_PTR
    CHARACTER(KIND=C_CHAR, LEN=1), VALUE :: CTA, CTB
    INTEGER(KIND=C_INT), VALUE :: BLAS_ID, M
    INTEGER(KIND=C_INT), VALUE :: LDA, LDB, LDC
    INTEGER(KIND=C_INT), VALUE :: BATCHCOUNT
    INTEGER(KIND=C_INT), DIMENSION(*) :: N, K
    INTEGER(KIND=C_INT), DIMENSION(*) :: OFFSETA, OFFSETB, OFFSETC
    REAL(KIND=C_FLOAT), VALUE :: ALPHA, BETA
    REAL(KIND=C_FLOAT), DIMENSION(*) :: A, B, C
    INTEGER(KIND=C_SIZE_T) :: STREAM
    TYPE(C_PTR), VALUE, INTENT(IN) :: ALLOC
  END SUBROUTINE HIP_SGEMM_GROUPED

END INTERFACE

CONTAINS

!
! Double-precision batched specific of the generic HIP_GEMM_BATCHED.
!
! Obtains a stream handle from ACC_GET_HIP_STREAM(STREAM), converts it to an
! integer of kind C_SIZE_T and forwards all arguments, plus the C address of
! ALLOC, to the C binding HIP_DGEMM_BATCHED.
!
! Arguments
!   TRANSA, TRANSB   single characters, forwarded unchanged as CTA and CTB
!   M, N, K          integers forwarded by value
!   ALPHA, BETA      real scalars forwarded by value
!   AARRAY, CARRAY   rank-1 arrays forwarded as A and C
!   BARRAY           rank-2 array forwarded as B
!   LDA, LDB, LDC    leading dimensions declared for A, B, C in the binding
!   STRIDEA/B/C      forwarded by value as TDA, TDB, TDC
!   BATCHCOUNT       forwarded by value
!   STREAM           argument handed to ACC_GET_HIP_STREAM
!   ALLOC            object whose address is passed on as an opaque pointer
!
SUBROUTINE HIP_DGEMM_BATCHED_OVERLOAD( &
      & TRANSA, TRANSB, &
      & M, N, K, &
      & ALPHA, &
      & AARRAY, LDA, STRIDEA, &
      & BARRAY, LDB, STRIDEB, &
      & BETA, &
      & CARRAY, LDC, STRIDEC, &
      & BATCHCOUNT, STREAM, ALLOC)
    CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
    ! Everything below that is forwarded by value is read-only here.
    INTEGER(KIND=JPIM), INTENT(IN) :: M
    INTEGER(KIND=JPIM), INTENT(IN) :: N
    INTEGER(KIND=JPIM), INTENT(IN) :: K
    REAL(KIND=JPRD), INTENT(IN) :: ALPHA
    REAL(KIND=JPRD), DIMENSION(:) :: AARRAY
    INTEGER(KIND=JPIM), INTENT(IN) :: LDA
    INTEGER(KIND=JPIM), INTENT(IN) :: STRIDEA
    REAL(KIND=JPRD), DIMENSION(:,:) :: BARRAY
    INTEGER(KIND=JPIM), INTENT(IN) :: LDB
    INTEGER(KIND=JPIM), INTENT(IN) :: STRIDEB
    REAL(KIND=JPRD), INTENT(IN) :: BETA
    REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
    INTEGER(KIND=JPIM), INTENT(IN) :: LDC
    INTEGER(KIND=JPIM), INTENT(IN) :: STRIDEC
    INTEGER(KIND=JPIM), INTENT(IN) :: BATCHCOUNT
    INTEGER(KIND=C_INT) :: STREAM
    ! TARGET is required because C_LOC only accepts an argument that has the
    ! POINTER or TARGET attribute.
    TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), TARGET :: ALLOC

    ! Same kind as the STREAM dummy of the C binding, so the actual and dummy
    ! arguments agree on every platform.
    INTEGER(KIND=C_SIZE_T) :: HIP_STREAM

    HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_SIZE_T)

    ! With the Cray compiler the call is wrapped in a host_data region so
    ! that the device addresses of the three matrices are passed on.
#if defined(_CRAYFTN)
    !$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
#endif
    CALL HIP_DGEMM_BATCHED( &
      & TRANSA, TRANSB, &
      & M, N, K, &
      & ALPHA, &
      & AARRAY, LDA, STRIDEA, &
      & BARRAY, LDB, STRIDEB, &
      & BETA, &
      & CARRAY, LDC, STRIDEC, &
      & BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
#if defined(_CRAYFTN)
    !$ACC END HOST_DATA
#endif
  END SUBROUTINE HIP_DGEMM_BATCHED_OVERLOAD

  !
  ! Single-precision batched specific of the generic HIP_GEMM_BATCHED.
  !
  ! Obtains a stream handle from ACC_GET_HIP_STREAM(STREAM), converts it to an
  ! integer of kind C_SIZE_T and forwards all arguments, plus the C address of
  ! ALLOC, to the C binding HIP_SGEMM_BATCHED.
  !
  ! Arguments
  !   TRANSA, TRANSB   single characters, forwarded unchanged as CTA and CTB
  !   M, N, K          integers forwarded by value
  !   ALPHA, BETA      real scalars forwarded by value
  !   AARRAY, CARRAY   rank-1 arrays forwarded as A and C
  !   BARRAY           assumed-size array forwarded as B
  !   LDA, LDB, LDC    leading dimensions declared for A, B, C in the binding
  !   STRIDEA/B/C      forwarded by value as TDA, TDB, TDC
  !   BATCHCOUNT       forwarded by value
  !   STREAM           argument handed to ACC_GET_HIP_STREAM
  !   ALLOC            object whose address is passed on as an opaque pointer
  !
  ! NOTE(review): unlike the double-precision batched overload, no host_data
  ! region is opened here; presumably this is tied to BARRAY being
  ! assumed-size -- confirm with the callers.
  !
  SUBROUTINE HIP_SGEMM_BATCHED_OVERLOAD( &
      & TRANSA, TRANSB, &
      & M, N, K, &
      & ALPHA, &
      & AARRAY, LDA, STRIDEA, &
      & BARRAY, LDB, STRIDEB, &
      & BETA, &
      & CARRAY, LDC, STRIDEC, &
      & BATCHCOUNT, STREAM, ALLOC)
    CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
    ! Everything below that is forwarded by value is read-only here.
    INTEGER(KIND=JPIM), INTENT(IN) :: M
    INTEGER(KIND=JPIM), INTENT(IN) :: N
    INTEGER(KIND=JPIM), INTENT(IN) :: K
    REAL(KIND=JPRM), INTENT(IN) :: ALPHA
    REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
    INTEGER(KIND=JPIM), INTENT(IN) :: LDA
    INTEGER(KIND=JPIM), INTENT(IN) :: STRIDEA
    REAL(KIND=JPRM), DIMENSION(*) :: BARRAY
    INTEGER(KIND=JPIM), INTENT(IN) :: LDB
    INTEGER(KIND=JPIM), INTENT(IN) :: STRIDEB
    REAL(KIND=JPRM), INTENT(IN) :: BETA
    REAL(KIND=JPRM), DIMENSION(:) :: CARRAY
    INTEGER(KIND=JPIM), INTENT(IN) :: LDC
    INTEGER(KIND=JPIM), INTENT(IN) :: STRIDEC
    INTEGER(KIND=JPIM), INTENT(IN) :: BATCHCOUNT
    INTEGER(KIND=C_INT) :: STREAM
    ! TARGET is required because C_LOC only accepts an argument that has the
    ! POINTER or TARGET attribute.
    TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), TARGET :: ALLOC

    ! Same kind as the STREAM dummy of the C binding, so the actual and dummy
    ! arguments agree on every platform.
    INTEGER(KIND=C_SIZE_T) :: HIP_STREAM

    HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_SIZE_T)

    CALL HIP_SGEMM_BATCHED( &
      & TRANSA, TRANSB, &
      & M, N, K, &
      & ALPHA, &
      & AARRAY, LDA, STRIDEA, &
      & BARRAY, LDB, STRIDEB, &
      & BETA, &
      & CARRAY, LDC, STRIDEC, &
      & BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
  END SUBROUTINE HIP_SGEMM_BATCHED_OVERLOAD


  !
  ! Double-precision grouped specific of the generic HIP_GEMM_BATCHED.
  !
  ! Obtains a stream handle from ACC_GET_HIP_STREAM(STREAM), converts it to an
  ! integer of kind C_SIZE_T and forwards all arguments, plus the C address of
  ! ALLOC, to the C binding HIP_DGEMM_GROUPED.
  !
  ! Arguments
  !   BLAS_ID          integer forwarded by value
  !   TRANSA, TRANSB   single characters, forwarded unchanged as CTA and CTB
  !   M                integer forwarded by value
  !   N, K             rank-1 integer arrays forwarded as assumed-size arrays
  !   ALPHA, BETA      real scalars forwarded by value
  !   AARRAY, CARRAY   rank-1 arrays forwarded as A and C
  !   BARRAY           assumed-size array forwarded as B
  !   LDA, LDB, LDC    integers forwarded by value
  !   OFFSETA/B/C      rank-1 integer arrays forwarded as assumed-size arrays
  !   BATCHCOUNT       forwarded by value
  !   STREAM           argument handed to ACC_GET_HIP_STREAM
  !   ALLOC            object whose address is passed on as an opaque pointer
  !
  ! NOTE(review): unlike the single-precision grouped overload, no host_data
  ! region is opened here; presumably this is tied to BARRAY being
  ! assumed-size -- confirm with the callers.
  !
  SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD( &
      & BLAS_ID, TRANSA, TRANSB, &
      & M, N, K, &
      & ALPHA, &
      & AARRAY, LDA, OFFSETA, &
      & BARRAY, LDB, OFFSETB, &
      & BETA, &
      & CARRAY, LDC, OFFSETC, &
      & BATCHCOUNT, STREAM, ALLOC)
    INTEGER(KIND=C_INT), INTENT(IN) :: BLAS_ID
    CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
    ! Scalars that are forwarded by value are read-only here. The integer
    ! arrays are passed by reference to C, so no intent is asserted for them.
    INTEGER(KIND=JPIM), INTENT(IN) :: M
    INTEGER(KIND=JPIM) :: N(:)
    INTEGER(KIND=JPIM) :: K(:)
    REAL(KIND=JPRD), INTENT(IN) :: ALPHA
    REAL(KIND=JPRD), DIMENSION(:) :: AARRAY
    INTEGER(KIND=JPIM), INTENT(IN) :: LDA
    INTEGER(KIND=JPIM) :: OFFSETA(:)
    REAL(KIND=JPRD), DIMENSION(*) :: BARRAY
    INTEGER(KIND=JPIM), INTENT(IN) :: LDB
    INTEGER(KIND=JPIM) :: OFFSETB(:)
    REAL(KIND=JPRD), INTENT(IN) :: BETA
    REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
    INTEGER(KIND=JPIM), INTENT(IN) :: LDC
    INTEGER(KIND=JPIM) :: OFFSETC(:)
    INTEGER(KIND=JPIM), INTENT(IN) :: BATCHCOUNT
    INTEGER(KIND=C_INT) :: STREAM
    ! TARGET is required because C_LOC only accepts an argument that has the
    ! POINTER or TARGET attribute.
    TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), TARGET :: ALLOC

    ! Same kind as the STREAM dummy of the C binding, so the actual and dummy
    ! arguments agree on every platform.
    INTEGER(KIND=C_SIZE_T) :: HIP_STREAM

    HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_SIZE_T)

    CALL HIP_DGEMM_GROUPED( &
      & BLAS_ID, TRANSA, TRANSB, &
      & M, N, K, &
      & ALPHA, &
      & AARRAY, LDA, OFFSETA, &
      & BARRAY, LDB, OFFSETB, &
      & BETA, &
      & CARRAY, LDC, OFFSETC, &
      & BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))

  END SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD

  !
  ! Single-precision grouped specific of the generic HIP_GEMM_BATCHED.
  !
  ! Obtains a stream handle from ACC_GET_HIP_STREAM(STREAM), converts it to an
  ! integer of kind C_SIZE_T and forwards all arguments, plus the C address of
  ! ALLOC, to the C binding HIP_SGEMM_GROUPED.
  !
  ! Arguments
  !   BLAS_ID          integer forwarded by value
  !   TRANSA, TRANSB   single characters, forwarded unchanged as CTA and CTB
  !   M                integer forwarded by value
  !   N, K             rank-1 integer arrays forwarded as assumed-size arrays
  !   ALPHA, BETA      real scalars forwarded by value
  !   AARRAY, CARRAY   rank-1 arrays forwarded as A and C
  !   BARRAY           rank-3 array forwarded as B
  !   LDA, LDB, LDC    integers forwarded by value
  !   OFFSETA/B/C      rank-1 integer arrays forwarded as assumed-size arrays
  !   BATCHCOUNT       forwarded by value
  !   STREAM           argument handed to ACC_GET_HIP_STREAM
  !   ALLOC            object whose address is passed on as an opaque pointer
  !
  SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD(&
      & BLAS_ID, TRANSA, TRANSB, &
      & M, N, K, &
      & ALPHA, &
      & AARRAY, LDA, OFFSETA, &
      & BARRAY, LDB, OFFSETB, &
      & BETA, &
      & CARRAY, LDC, OFFSETC, &
      & BATCHCOUNT, STREAM, ALLOC)
    INTEGER(KIND=C_INT), INTENT(IN) :: BLAS_ID
    CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
    ! Scalars that are forwarded by value are read-only here. The integer
    ! arrays are passed by reference to C, so no intent is asserted for them.
    INTEGER(KIND=JPIM), INTENT(IN) :: M
    INTEGER(KIND=JPIM) :: N(:)
    INTEGER(KIND=JPIM) :: K(:)
    REAL(KIND=JPRM), INTENT(IN) :: ALPHA
    REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
    INTEGER(KIND=JPIM), INTENT(IN) :: LDA
    INTEGER(KIND=JPIM) :: OFFSETA(:)
    REAL(KIND=JPRM), DIMENSION(:,:,:) :: BARRAY
    INTEGER(KIND=JPIM), INTENT(IN) :: LDB
    INTEGER(KIND=JPIM) :: OFFSETB(:)
    REAL(KIND=JPRM), INTENT(IN) :: BETA
    REAL(KIND=JPRM), DIMENSION(:) :: CARRAY
    INTEGER(KIND=JPIM), INTENT(IN) :: LDC
    INTEGER(KIND=JPIM) :: OFFSETC(:)
    INTEGER(KIND=JPIM), INTENT(IN) :: BATCHCOUNT
    INTEGER(KIND=C_INT) :: STREAM
    ! TARGET is required because C_LOC only accepts an argument that has the
    ! POINTER or TARGET attribute.
    TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), TARGET :: ALLOC

    ! Same kind as the STREAM dummy of the C binding, so the actual and dummy
    ! arguments agree on every platform.
    INTEGER(KIND=C_SIZE_T) :: HIP_STREAM

    HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_SIZE_T)

    ! With the Cray compiler the call is wrapped in a host_data region so
    ! that the device addresses of the three matrices are passed on.
#if defined(_CRAYFTN)
    !$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
#endif
    CALL HIP_SGEMM_GROUPED( &
      & BLAS_ID, TRANSA, TRANSB, &
      & M, N, K, &
      & ALPHA, &
      & AARRAY, LDA, OFFSETA, &
      & BARRAY, LDB, OFFSETB, &
      & BETA, &
      & CARRAY, LDC, OFFSETC, &
      & BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
#if defined(_CRAYFTN)
    !$ACC END HOST_DATA
#endif

  END SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD

END MODULE HICBLAS_MOD