! (C) Copyright 2000- ECMWF.
!
! This software is licensed under the terms of the Apache Licence Version 2.0
! which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
! In applying this licence, ECMWF does not waive the privileges and immunities
! granted to it by virtue of its status as an intergovernmental organisation
! nor does it submit to any jurisdiction.
!

#if defined CUDAGPU
#define ACC_GET_HIP_STREAM ACC_GET_CUDA_STREAM
#define OPENACC_LIB OPENACC
#endif

MODULE HICBLAS_MOD
!
! Fortran-side bindings for a set of C GEMM wrapper routines, plus thin
! "overload" subroutines that adapt Fortran arguments to those bindings.
!
! Contents:
!   * BIND(C) interfaces HIP_DGEMM_BATCHED / HIP_SGEMM_BATCHED
!     (C symbols hipblas_dgemm_wrapper / hipblas_sgemm_wrapper),
!   * BIND(C) interface CLEAN_GEMM (C symbol clean_gemm),
!   * BIND(C) interfaces HIP_DGEMM_GROUPED / HIP_SGEMM_GROUPED
!     (C symbols hipblas_dgemm_wrapper_grouped / hipblas_sgemm_wrapper_grouped),
!   * four *_OVERLOAD subroutines which translate the integer STREAM argument
!     into a stream handle (only when ACCGPU is defined) and pass the
!     address of the ALLOC object to the C side via C_LOC.
!
! When CUDAGPU is defined, the preprocessor aliases at the top of this file
! redirect the stream query function and the OpenACC module name to their
! CUDA-flavoured spellings.
!
! NOTE(review): the argument names follow the usual BLAS GEMM convention
! (transpose flags, M/N/K sizes, ALPHA/BETA scalars, leading dimensions);
! the exact semantics are defined by the C wrappers, which are not part of
! this file - confirm against the C sources.
!

USE EC_PARKIND, ONLY: JPIM, JPRM, JPRD, JPIB
USE GROWING_ALLOCATOR_MOD, ONLY: GROWING_ALLOCATION_TYPE
#ifdef ACCGPU
USE OPENACC_LIB, ONLY: ACC_GET_HIP_STREAM
#endif
#ifdef OMPGPU
#endif

IMPLICIT NONE

! ---------------------------------------------------------------------------
! Batched GEMM wrappers. Scalars are passed by value; the matrices and the
! STREAM argument are passed by reference (no VALUE attribute).
! ---------------------------------------------------------------------------
INTERFACE
  SUBROUTINE HIP_DGEMM_BATCHED( &
      & CTA, CTB,               &
      & M, N, K,                &
      & ALPHA,                  &
      & A, LDA, TDA,            &
      & B, LDB, TDB,            &
      & BETA,                   &
      & C, LDC, TDC,            &
      & BATCHCOUNT, STREAM, ALLOC &
      &) BIND(C, NAME='hipblas_dgemm_wrapper')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_DOUBLE, C_SIZE_T, C_PTR
    CHARACTER(1,C_CHAR), VALUE :: CTA, CTB
    INTEGER(C_INT), VALUE :: M, N, K, LDA, LDB, LDC, TDA, TDB, TDC, BATCHCOUNT
    REAL(C_DOUBLE), VALUE :: ALPHA,BETA
    REAL(C_DOUBLE), DIMENSION(LDA,*) :: A
    REAL(C_DOUBLE), DIMENSION(LDB,*) :: B
    REAL(C_DOUBLE), DIMENSION(LDC,*) :: C
    INTEGER(KIND=C_SIZE_T) :: STREAM
    TYPE(C_PTR), INTENT(IN), VALUE :: ALLOC
  END SUBROUTINE HIP_DGEMM_BATCHED

  SUBROUTINE HIP_SGEMM_BATCHED( &
      & CTA, CTB,               &
      & M, N, K,                &
      & ALPHA,                  &
      & A, LDA, TDA,            &
      & B, LDB, TDB,            &
      & BETA,                   &
      & C, LDC, TDC,            &
      & BATCHCOUNT, STREAM, ALLOC &
      &) BIND(C, NAME='hipblas_sgemm_wrapper')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_FLOAT, C_SIZE_T, C_PTR
    CHARACTER(1,C_CHAR), VALUE :: CTA, CTB
    INTEGER(C_INT), VALUE :: M, N, K, LDA, LDB, LDC, TDA, TDB, TDC, BATCHCOUNT
    REAL(C_FLOAT), VALUE :: ALPHA, BETA
    REAL(C_FLOAT), DIMENSION(LDA,*) :: A
    REAL(C_FLOAT), DIMENSION(LDB,*) :: B
    REAL(C_FLOAT), DIMENSION(LDC,*) :: C
    INTEGER(KIND=C_SIZE_T) :: STREAM
    TYPE(C_PTR), INTENT(IN), VALUE :: ALLOC
  END SUBROUTINE HIP_SGEMM_BATCHED
END INTERFACE

! ---------------------------------------------------------------------------
! Clean-up entry point taking a resolution identifier by value.
! NOTE(review): presumably releases GEMM state cached per RESOL_ID on the
! C side - confirm against the C implementation.
! ---------------------------------------------------------------------------
INTERFACE
  SUBROUTINE CLEAN_GEMM(RESOL_ID) BIND(C, NAME="clean_gemm")
    USE ISO_C_BINDING, ONLY: C_INT
    INTEGER(KIND=C_INT), INTENT(IN), VALUE :: RESOL_ID
  END SUBROUTINE CLEAN_GEMM
END INTERFACE

! ---------------------------------------------------------------------------
! Grouped GEMM wrappers. Compared with the batched variants, N, K and LDB
! are integer arrays and the three OFFSET arguments are 64-bit integer
! arrays; M, LDA and LDC remain scalars passed by value.
! ---------------------------------------------------------------------------
INTERFACE
  SUBROUTINE HIP_DGEMM_GROUPED( &
      & RESOL_ID, BLAS_ID, CTA, CTB, &
      & M, N, K,                &
      & ALPHA,                  &
      & A, LDA, OFFSETA,        &
      & B, LDB, OFFSETB,        &
      & BETA,                   &
      & C, LDC, OFFSETC,        &
      & BATCHCOUNT, STREAM, ALLOC &
      &) BIND(C, NAME='hipblas_dgemm_wrapper_grouped')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_DOUBLE, C_SIZE_T, C_PTR, C_INT64_T
    CHARACTER(1,C_CHAR), VALUE :: CTA, CTB
    INTEGER(C_INT), VALUE :: RESOL_ID, BLAS_ID, M, LDA, LDC, BATCHCOUNT
    INTEGER(C_INT) :: N(*), K(*), LDB(*)
    INTEGER(C_INT64_T) :: OFFSETA(*), OFFSETB(*), OFFSETC(*)
    REAL(C_DOUBLE), VALUE :: ALPHA,BETA
    REAL(C_DOUBLE) :: A(*), B(*), C(*)
    INTEGER(KIND=C_SIZE_T) :: STREAM
    TYPE(C_PTR), INTENT(IN), VALUE :: ALLOC
  END SUBROUTINE HIP_DGEMM_GROUPED

  SUBROUTINE HIP_SGEMM_GROUPED( &
      & RESOL_ID, BLAS_ID, CTA, CTB, &
      & M, N, K,                &
      & ALPHA,                  &
      & A, LDA, OFFSETA,        &
      & B, LDB, OFFSETB,        &
      & BETA,                   &
      & C, LDC, OFFSETC,        &
      & BATCHCOUNT, STREAM, ALLOC &
      &) BIND(C, NAME='hipblas_sgemm_wrapper_grouped')
    USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_FLOAT, C_SIZE_T, C_PTR, C_INT64_T
    CHARACTER(1,C_CHAR), VALUE :: CTA, CTB
    INTEGER(C_INT), VALUE :: RESOL_ID, BLAS_ID, M, LDA, LDC, BATCHCOUNT
    INTEGER(C_INT) :: N(*), K(*), LDB(*)
    INTEGER(C_INT64_T) :: OFFSETA(*), OFFSETB(*), OFFSETC(*)
    REAL(C_FLOAT), VALUE :: ALPHA,BETA
    REAL(C_FLOAT) :: A(*), B(*), C(*)
    INTEGER(KIND=C_SIZE_T) :: STREAM
    TYPE(C_PTR), INTENT(IN), VALUE :: ALLOC
  END SUBROUTINE HIP_SGEMM_GROUPED
END INTERFACE

CONTAINS

! ---------------------------------------------------------------------------
! Double-precision batched GEMM entry point.
!
! Forwards every argument unchanged to HIP_DGEMM_BATCHED, except that
!   * STREAM is replaced by a stream handle (see HIP_STREAM below), and
!   * ALLOC is replaced by its address, C_LOC(ALLOC).
! With the Cray compiler and ACCGPU the call is enclosed in an OpenACC
! HOST_DATA USE_DEVICE region for the three matrices.
! ---------------------------------------------------------------------------
SUBROUTINE HIP_DGEMM_BATCHED_OVERLOAD( &
    & TRANSA, TRANSB,           &
    & M, N, K,                  &
    & ALPHA,                    &
    & AARRAY, LDA, STRIDEA,     &
    & BARRAY, LDB, STRIDEB,     &
    & BETA,                     &
    & CARRAY, LDC, STRIDEC,     &
    & BATCHCOUNT, STREAM, ALLOC)
  USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_SIZE_T, C_LOC

  CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
  INTEGER(KIND=JPIM) :: M
  INTEGER(KIND=JPIM) :: N
  INTEGER(KIND=JPIM) :: K
  REAL(KIND=JPRD) :: ALPHA
  REAL(KIND=JPRD), DIMENSION(:) :: AARRAY
  INTEGER(KIND=JPIM) :: LDA
  INTEGER(KIND=JPIM) :: STRIDEA
  REAL(KIND=JPRD), DIMENSION(:,:) :: BARRAY
  INTEGER(KIND=JPIM) :: LDB
  INTEGER(KIND=JPIM) :: STRIDEB
  REAL(KIND=JPRD) :: BETA
  REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
  INTEGER(KIND=JPIM) :: LDC
  INTEGER(KIND=JPIM) :: STRIDEC
  INTEGER(KIND=JPIM) :: BATCHCOUNT
  INTEGER(KIND=C_INT) :: STREAM
  TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC

  ! Same kind as the STREAM dummy of the BIND(C) interface, so the actual
  ! and dummy arguments agree on every platform.
  INTEGER(KIND=C_SIZE_T) :: HIP_STREAM

  ! Give the handle a defined value for builds where ACCGPU is not set.
  ! NOTE(review): zero is assumed to denote the default stream on the
  ! C side - confirm.
  HIP_STREAM = 0_C_SIZE_T
#ifdef ACCGPU
  HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_SIZE_T)
#endif
#ifdef OMPGPU
#endif

#if defined(_CRAYFTN)
#ifdef ACCGPU
  !$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
#endif
#endif
  CALL HIP_DGEMM_BATCHED(       &
      & TRANSA, TRANSB,         &
      & M, N, K,                &
      & ALPHA,                  &
      & AARRAY, LDA, STRIDEA,   &
      & BARRAY, LDB, STRIDEB,   &
      & BETA,                   &
      & CARRAY, LDC, STRIDEC,   &
      & BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
#if defined(_CRAYFTN)
#ifdef ACCGPU
  !$ACC END HOST_DATA
#endif
#endif
END SUBROUTINE HIP_DGEMM_BATCHED_OVERLOAD

! ---------------------------------------------------------------------------
! Single-precision batched GEMM entry point.
!
! Same argument handling as the double-precision batched overload, but
! forwarding to HIP_SGEMM_BATCHED. BARRAY is declared assumed-size here.
! ---------------------------------------------------------------------------
SUBROUTINE HIP_SGEMM_BATCHED_OVERLOAD( &
    & TRANSA, TRANSB,           &
    & M, N, K,                  &
    & ALPHA,                    &
    & AARRAY, LDA, STRIDEA,     &
    & BARRAY, LDB, STRIDEB,     &
    & BETA,                     &
    & CARRAY, LDC, STRIDEC,     &
    & BATCHCOUNT, STREAM, ALLOC)
  USE ISO_C_BINDING, ONLY: C_CHAR, C_INT, C_SIZE_T, C_LOC

  CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
  INTEGER(KIND=JPIM) :: M
  INTEGER(KIND=JPIM) :: N
  INTEGER(KIND=JPIM) :: K
  REAL(KIND=JPRM) :: ALPHA
  REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
  INTEGER(KIND=JPIM) :: LDA
  INTEGER(KIND=JPIM) :: STRIDEA
  REAL(KIND=JPRM), DIMENSION(*) :: BARRAY
  INTEGER(KIND=JPIM) :: LDB
  INTEGER(KIND=JPIM) :: STRIDEB
  REAL(KIND=JPRM) :: BETA
  REAL(KIND=JPRM), DIMENSION(:) :: CARRAY
  INTEGER(KIND=JPIM) :: LDC
  INTEGER(KIND=JPIM) :: STRIDEC
  INTEGER(KIND=JPIM) :: BATCHCOUNT
  INTEGER(KIND=C_INT) :: STREAM
  TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC

  ! Same kind as the STREAM dummy of the BIND(C) interface.
  INTEGER(KIND=C_SIZE_T) :: HIP_STREAM

  ! Defined value for builds where ACCGPU is not set.
  ! NOTE(review): zero is assumed to denote the default stream - confirm.
  HIP_STREAM = 0_C_SIZE_T
#ifdef ACCGPU
  HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_SIZE_T)
#endif
#ifdef OMPGPU
#endif

  ! Cray + OpenACC: same HOST_DATA treatment as the double-precision
  ! batched overload, so both precisions handle their arrays identically.
#if defined(_CRAYFTN)
#ifdef ACCGPU
  !$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
#endif
#endif
  CALL HIP_SGEMM_BATCHED(       &
      & TRANSA, TRANSB,         &
      & M, N, K,                &
      & ALPHA,                  &
      & AARRAY, LDA, STRIDEA,   &
      & BARRAY, LDB, STRIDEB,   &
      & BETA,                   &
      & CARRAY, LDC, STRIDEC,   &
      & BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
#if defined(_CRAYFTN)
#ifdef ACCGPU
  !$ACC END HOST_DATA
#endif
#endif
END SUBROUTINE HIP_SGEMM_BATCHED_OVERLOAD

! ---------------------------------------------------------------------------
! Double-precision grouped GEMM entry point.
!
! Forwards RESOL_ID, BLAS_ID and all GEMM arguments to HIP_DGEMM_GROUPED,
! substituting a stream handle for STREAM and C_LOC(ALLOC) for ALLOC.
! N, K and LDB are arrays; OFFSETA/B/C are arrays of kind JPIB.
! ---------------------------------------------------------------------------
SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD( &
    & RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
    & M, N, K,                  &
    & ALPHA,                    &
    & AARRAY, LDA, OFFSETA,     &
    & BARRAY, LDB, OFFSETB,     &
    & BETA,                     &
    & CARRAY, LDC, OFFSETC,     &
    & BATCHCOUNT, STREAM, ALLOC)
  USE ISO_C_BINDING, ONLY: C_INT, C_CHAR, C_SIZE_T, C_LOC

  INTEGER(KIND=C_INT), INTENT(IN) :: RESOL_ID
  INTEGER(KIND=C_INT), INTENT(IN) :: BLAS_ID
  CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
  INTEGER(KIND=JPIM) :: M
  INTEGER(KIND=JPIM) :: N(:)
  INTEGER(KIND=JPIM) :: K(:)
  REAL(KIND=JPRD) :: ALPHA
  REAL(KIND=JPRD), DIMENSION(:) :: AARRAY
  INTEGER(KIND=JPIM) :: LDA
  INTEGER(KIND=JPIB) :: OFFSETA(:)
  REAL(KIND=JPRD), DIMENSION(*) :: BARRAY
  INTEGER(KIND=JPIM) :: LDB(:)
  INTEGER(KIND=JPIB) :: OFFSETB(:)
  REAL(KIND=JPRD) :: BETA
  REAL(KIND=JPRD), DIMENSION(:) :: CARRAY
  INTEGER(KIND=JPIM) :: LDC
  INTEGER(KIND=JPIB) :: OFFSETC(:)
  INTEGER(KIND=JPIM) :: BATCHCOUNT
  INTEGER(KIND=C_INT) :: STREAM
  TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC

  ! Same kind as the STREAM dummy of the BIND(C) interface.
  INTEGER(KIND=C_SIZE_T) :: HIP_STREAM

  ! Defined value for builds where ACCGPU is not set.
  ! NOTE(review): zero is assumed to denote the default stream - confirm.
  HIP_STREAM = 0_C_SIZE_T
#ifdef ACCGPU
  HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_SIZE_T)
#endif
#ifdef OMPGPU
#endif

  ! Cray + OpenACC: same HOST_DATA treatment as the single-precision
  ! grouped overload, so both precisions handle their arrays identically.
#if defined(_CRAYFTN)
#ifdef ACCGPU
  !$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
#endif
#endif
  CALL HIP_DGEMM_GROUPED(       &
      & RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
      & M, N, K,                &
      & ALPHA,                  &
      & AARRAY, LDA, OFFSETA,   &
      & BARRAY, LDB, OFFSETB,   &
      & BETA,                   &
      & CARRAY, LDC, OFFSETC,   &
      & BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
#if defined(_CRAYFTN)
#ifdef ACCGPU
  !$ACC END HOST_DATA
#endif
#endif
END SUBROUTINE HIP_DGEMM_GROUPED_OVERLOAD

! ---------------------------------------------------------------------------
! Single-precision grouped GEMM entry point.
!
! Same argument handling as the double-precision grouped overload, but
! forwarding to HIP_SGEMM_GROUPED.
! ---------------------------------------------------------------------------
SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD( &
    & RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
    & M, N, K,                  &
    & ALPHA,                    &
    & AARRAY, LDA, OFFSETA,     &
    & BARRAY, LDB, OFFSETB,     &
    & BETA,                     &
    & CARRAY, LDC, OFFSETC,     &
    & BATCHCOUNT, STREAM, ALLOC)
  USE ISO_C_BINDING, ONLY: C_INT, C_CHAR, C_SIZE_T, C_LOC

  INTEGER(KIND=C_INT), INTENT(IN) :: RESOL_ID
  INTEGER(KIND=C_INT), INTENT(IN) :: BLAS_ID
  CHARACTER(1,C_CHAR), VALUE :: TRANSA, TRANSB
  INTEGER(KIND=JPIM) :: M
  INTEGER(KIND=JPIM) :: N(:)
  INTEGER(KIND=JPIM) :: K(:)
  REAL(KIND=JPRM) :: ALPHA
  REAL(KIND=JPRM), DIMENSION(:) :: AARRAY
  INTEGER(KIND=JPIM) :: LDA
  INTEGER(KIND=JPIB) :: OFFSETA(:)
  REAL(KIND=JPRM), DIMENSION(*) :: BARRAY
  INTEGER(KIND=JPIM) :: LDB(:)
  INTEGER(KIND=JPIB) :: OFFSETB(:)
  REAL(KIND=JPRM) :: BETA
  REAL(KIND=JPRM), DIMENSION(:) :: CARRAY
  INTEGER(KIND=JPIM) :: LDC
  INTEGER(KIND=JPIB) :: OFFSETC(:)
  INTEGER(KIND=JPIM) :: BATCHCOUNT
  INTEGER(KIND=C_INT) :: STREAM
  TYPE(GROWING_ALLOCATION_TYPE), INTENT(IN), POINTER :: ALLOC

  ! Same kind as the STREAM dummy of the BIND(C) interface.
  INTEGER(KIND=C_SIZE_T) :: HIP_STREAM

  ! Defined value for builds where ACCGPU is not set.
  ! NOTE(review): zero is assumed to denote the default stream - confirm.
  HIP_STREAM = 0_C_SIZE_T
#ifdef ACCGPU
  HIP_STREAM = INT(ACC_GET_HIP_STREAM(STREAM), C_SIZE_T)
#endif
#ifdef OMPGPU
#endif

#if defined(_CRAYFTN)
#ifdef ACCGPU
  !$ACC HOST_DATA USE_DEVICE(AARRAY,BARRAY,CARRAY)
#endif
#endif
  CALL HIP_SGEMM_GROUPED(       &
      & RESOL_ID, BLAS_ID, TRANSA, TRANSB, &
      & M, N, K,                &
      & ALPHA,                  &
      & AARRAY, LDA, OFFSETA,   &
      & BARRAY, LDB, OFFSETB,   &
      & BETA,                   &
      & CARRAY, LDC, OFFSETC,   &
      & BATCHCOUNT, HIP_STREAM, C_LOC(ALLOC))
#if defined(_CRAYFTN)
#ifdef ACCGPU
  !$ACC END HOST_DATA
#endif
#endif
END SUBROUTINE HIP_SGEMM_GROUPED_OVERLOAD

END MODULE HICBLAS_MOD