@@ -69,13 +69,12 @@ infiniopStatus_t aclnnCreateMatmulDescriptor(AscendHandle_t handle,
6969 // aclnnGemm support C = alpha * A @ B + beta * C
7070 // see https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC3alpha003/apiref/aolapi/context/aclnnGemm.md
7171 ret = aclnnGemmGetWorkspaceSize (ta, tb, tc, (*desc_ptr)->alpha , (*desc_ptr)->beta , transA, transB, tc,
72- (*desc_ptr)->mt , &workspaceSize, &executor);
72+ (*desc_ptr)->mt , &workspaceSize, &executor);
7373 CHECK_RET (ret == ACL_SUCCESS,
74- LOG_PRINT (" aclnnGemmGetWorkspaceSize failed. ERROR: %d\n " , ret);
75- return STATUS_EXECUTION_FAILED);
74+ LOG_PRINT (" aclnnGemmGetWorkspaceSize failed. ERROR: %d\n " , ret);
75+ return STATUS_EXECUTION_FAILED);
7676 aclSetAclOpExecutorRepeatable (executor);
7777
78-
7978 return STATUS_SUCCESS;
8079}
8180
@@ -109,14 +108,14 @@ infiniopStatus_t aclnnMatmul(MatmulAclnnDescriptor_t desc,
109108 aclrtSetDevice (desc->device_id );
110109
111110 for (int i = 0 ; i < batch; i++) {
112- AclSetTensorAddr (executor, 0 , ta, (char *)(a) + i * desc->info ->a_matrix .stride * desc->dtype .size );
113- AclSetTensorAddr (executor, 1 , tb, (char *)(b) + i * desc->info ->b_matrix .stride * desc->dtype .size );
114- AclSetTensorAddr (executor, 2 , tc, (char *)(c) + i * desc->info ->c_matrix .stride * desc->dtype .size );
115- AclSetTensorAddr (executor, 3 , tc, (char *)(c) + i * desc->info ->c_matrix .stride * desc->dtype .size );
111+ AclSetTensorAddr (executor, 0 , ta, (char *) (a) + i * desc->info ->a_matrix .stride * desc->dtype .size );
112+ AclSetTensorAddr (executor, 1 , tb, (char *) (b) + i * desc->info ->b_matrix .stride * desc->dtype .size );
113+ AclSetTensorAddr (executor, 2 , tc, (char *) (c) + i * desc->info ->c_matrix .stride * desc->dtype .size );
114+ AclSetTensorAddr (executor, 3 , tc, (char *) (c) + i * desc->info ->c_matrix .stride * desc->dtype .size );
116115 aclnnStatus ret = aclnnGemm (workspace,
117- workspaceSize,
118- executor,
119- stream);
116+ workspaceSize,
117+ executor,
118+ stream);
120119 CHECK_RET (ret == ACL_SUCCESS,
121120 LOG_PRINT (" aclnnGemm failed. ERROR: %d\n " , ret);
122121 return STATUS_EXECUTION_FAILED);
0 commit comments