diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
index 47874ef3..6f93e343 100644
--- a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
+++ b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
@@ -89,7 +89,7 @@ void CublasBatchedMatmulOp::jit_prepare(JK& jk) {
     jk << _CS("[T:") << a->dtype();
     jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
     jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
-    jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D');
+    jk << _CS("][op:") << (a->dtype().dsize() == 2 ? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
     jk << ']';
 }
 
@@ -128,6 +128,18 @@ void CublasBatchedMatmulOp::jit_run() {
         || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
         computeType = CUBLAS_COMPUTE_16F;
     }
+    #else
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+    cudaDataType_t computeType = CUDA_R_32F;
+    if (use_tensorcore) {
+        algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    }
+    if (a->dtype() == ns_float16
+        || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
+        computeType = CUDA_R_16F;
+        algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    }
+    #endif
     checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
     CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
     k, n, m, &alpha,
@@ -135,15 +147,13 @@ void CublasBatchedMatmulOp::jit_run() {
     a->ptr(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
     c->ptr(),get_dtype(c->dtype()), k, k * n,
     batch_size,computeType,algo));
-    #else
-    checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
-    CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
-    k, n, m, &alpha,
-    b->ptr(), '@Trans_b' == 'N' ? k : m, k * m,
-    a->ptr(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
-    c->ptr(), k, k * n,
-    batch_size));
-    #endif
+    // checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
+    // CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
+    // k, n, m, &alpha,
+    // b->ptr(), '@Trans_b' == 'N' ? k : m, k * m,
+    // a->ptr(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
+    // c->ptr(), k, k * n,
+    // batch_size));
 }
 #endif
 #endif // JIT
diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc
index 95de20f7..a6708225 100644
--- a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc
+++ b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc
@@ -50,7 +50,7 @@ void CublasMatmulOp::jit_prepare(JK& jk) {
     jk << _CS("[T:") << a->dtype();
     jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
     jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
-    jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D');
+    jk << _CS("][op:") << (a->dtype().dsize() == 2 ? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
     jk << ']';
 }
 
@@ -85,6 +85,18 @@ void CublasMatmulOp::jit_run() {
         || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
         computeType = CUBLAS_COMPUTE_16F;
     }
+    #else
+    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+    cudaDataType_t computeType = CUDA_R_32F;
+    if (use_tensorcore) {
+        algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    }
+    if (a->dtype() == ns_float16
+        || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
+        computeType = CUDA_R_16F;
+        algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    }
+    #endif
     checkCudaErrors(cublasGemmEx(handle_,
     CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
     k, n, m, &alpha,
@@ -92,15 +104,13 @@ void CublasMatmulOp::jit_run() {
     a->ptr(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
     c->ptr(),get_dtype(c->dtype()), k,
     computeType, algo));
-    #else
-    checkCudaErrors(cublas@op@@gemm(handle_,
-    CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
-    k, n, m, &alpha,
-    b->ptr(), '@Trans_b' == 'N' ? k : m,
-    a->ptr(), '@Trans_a' == 'N' ? m : n, &beta,
-    c->ptr(), k));
+    // checkCudaErrors(cublas@op@@gemm(handle_,
+    // CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
+    // k, n, m, &alpha,
+    // b->ptr(), '@Trans_b' == 'N' ? k : m,
+    // a->ptr(), '@Trans_a' == 'N' ? m : n, &beta,
+    // c->ptr(), k));
-    #endif
 }
 
 #endif // JIT
diff --git a/python/jittor/other/code_softmax.py b/python/jittor/other/code_softmax.py
index 8534f0cb..837bd648 100644
--- a/python/jittor/other/code_softmax.py
+++ b/python/jittor/other/code_softmax.py
@@ -100,7 +100,7 @@ def softmax_v1(a, log=False):
         {for_loop}
             #pragma unroll
            for (int j=0; j<{ILP}; j++)
-                v1 += {"vy[i][j];" if log else "vx[i][j]*vy[i][j];"}
+                v1 += {"float(vy[i][j]);" if log else "float(vx[i][j]*vy[i][j]);"}
 
         typedef cub::BlockReduce<float, {tnum}> BlockReduce;
         __shared__ typename BlockReduce::TempStorage temp_storage;
@@ -114,8 +114,8 @@ def softmax_v1(a, log=False):
             #pragma unroll
            for (int j=0; j<{ILP}; j++)
                 vx[i][j] = {
-                    "vy[i][j] - expf(vx[i][j]) * reduce_var;" if log
-                    else "vx[i][j] * (vy[i][j] - reduce_var);"
+                    "vy[i][j] - in0_type(expf(vx[i][j]) * reduce_var);" if log
+                    else "vx[i][j] * (vy[i][j] - in0_type(reduce_var));"
                 }
 
         {for_loop}
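
Notes on the changes above. The `jit_prepare` edit in both cuBLAS ops extends the `op` JIT key, which selects the cuBLAS routine suffix, to emit `'H'` for 2-byte (float16) dtypes alongside `'S'` (float32) and `'D'` (float64). A minimal standalone sketch of that mapping; `op_suffix` and its `dsize` argument are hypothetical stand-ins for `a->dtype().dsize()`:

```cpp
#include <cstdio>
#include <initializer_list>

// Hypothetical helper mirroring the jit_prepare change:
// element size in bytes -> cuBLAS routine suffix
// (2 -> Hgemm, 4 -> Sgemm, otherwise Dgemm).
char op_suffix(int dsize) {
    if (dsize == 2) return 'H';  // float16
    if (dsize == 4) return 'S';  // float32
    return 'D';                  // float64
}

int main() {
    for (int d : {2, 4, 8})
        std::printf("dsize=%d -> op '%c'\n", d, op_suffix(d));
    return 0;
}
```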
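The new `#else` branches in `jit_run` cover toolkits without `cublasComputeType_t` (pre-CUDA 11): every dtype is routed through `cublasGemmEx` / `cublasGemmStridedBatchedEx` with a `cudaDataType_t` compute type, and any float16 operand forces `CUDA_R_16F` plus `CUBLAS_GEMM_DEFAULT_TENSOR_OP`, which is why the old typed `cublas@op@@gemm*` calls are left commented out. A minimal sketch of that fallback call in isolation, assuming the pre-CUDA-11 `cublasGemmEx` signature; `gemm_fp16_rowmajor` and the `d_*` device pointers are hypothetical, and error checking is omitted:

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>

// Sketch: fp16 GEMM through cublasGemmEx on a pre-CUDA-11 toolkit,
// mirroring the #else branch (computeType CUDA_R_16F + tensor-op algo).
// Computes C = A * B for row-major A (m x k), B (k x n), C (m x n) by
// asking cuBLAS for the transposed product C^T = B^T * A^T -- the same
// trick the op uses when it passes b before a and k, n, m.
void gemm_fp16_rowmajor(cublasHandle_t handle, const __half* d_a,
                        const __half* d_b, __half* d_c,
                        int m, int n, int k) {
    __half alpha = __float2half(1.f), beta = __float2half(0.f);
    cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                 n, m, k, &alpha,
                 d_b, CUDA_R_16F, n,   // B^T viewed column-major, ld = n
                 d_a, CUDA_R_16F, k,   // A^T viewed column-major, ld = k
                 &beta,
                 d_c, CUDA_R_16F, n,   // C^T, ld = n
                 CUDA_R_16F,           // computeType, as in the diff
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```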
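The `code_softmax.py` changes keep the generated kernel numerically sane when `in0_type` is half: the per-thread accumulator `v1` is float, so each half-precision term is widened with `float(...)` before accumulation, and the float `reduce_var` produced by the `cub::BlockReduce` is narrowed back with `in0_type(...)` before it mixes with half operands. A small sketch of that widen-then-narrow pattern in isolation; `dot_ilp` and `update_grad` are hypothetical device helpers, not the generated kernel (half arithmetic needs sm_53+ when `in0_type` is `__half`):

```cpp
#include <cuda_fp16.h>

// Widen each product to float before accumulating, as the
// "v1 += float(vx[i][j]*vy[i][j]);" change does.
template <typename in0_type, int ILP>
__device__ float dot_ilp(const in0_type (&vx)[ILP], const in0_type (&vy)[ILP]) {
    float v1 = 0.f;
    #pragma unroll
    for (int j = 0; j < ILP; j++)
        v1 += float(vx[j] * vy[j]);
    return v1;
}

// Narrow the float reduction result back to the storage type before the
// element-wise update, as "vy[i][j] - in0_type(... * reduce_var)" does.
template <typename in0_type, int ILP>
__device__ void update_grad(in0_type (&vx)[ILP], const in0_type (&vy)[ILP],
                            float reduce_var) {
    #pragma unroll
    for (int j = 0; j < ILP; j++)
        vx[j] = vx[j] * (vy[j] - in0_type(reduce_var));
}
```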