Commit

Merge branch 'master' of github.com:Jittor/jittor
cjld committed Mar 21, 2022
2 parents e7bb254 + 1fa8977 commit 99d6d6b
Showing 3 changed files with 42 additions and 22 deletions.
30 changes: 20 additions & 10 deletions python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc
@@ -89,7 +89,7 @@ void CublasBatchedMatmulOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << a->dtype();
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D');
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
jk << ']';
}
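
This hunk (mirrored in cublas_matmul_op.cc below) teaches the JIT key a third case: two-byte dtypes now emit 'H', which the @op substitution later expands into the half-precision cuBLAS symbol. A minimal restatement of that selection as a standalone sketch (pick_op is a hypothetical name, not a Jittor function):

    // Maps element width in bytes to the cuBLAS name suffix used at JIT time.
    char pick_op(int dsize) {
        if (dsize == 2) return 'H';  // float16 -> cublasHgemm..., new in this commit
        if (dsize == 4) return 'S';  // float32 -> cublasSgemm...
        return 'D';                  // float64 -> cublasDgemm...
    }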

@@ -128,22 +128,32 @@ void CublasBatchedMatmulOp::jit_run() {
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
computeType = CUBLAS_COMPUTE_16F;
}
+ #else
+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+ cudaDataType_t computeType = CUDA_R_32F;
+ if (use_tensorcore) {
+ algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+ }
+ if (a->dtype() == ns_float16
+ || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
+ computeType = CUDA_R_16F;
+ algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+ }
+ #endif
checkCudaErrors(cublasGemmStridedBatchedEx(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(),get_dtype(b->dtype()), '@Trans_b' == 'N' ? k : m, k * m,
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, &beta,
c->ptr<T>(),get_dtype(c->dtype()), k, k * n,
batch_size,computeType,algo));
- #else
- checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
- CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
- k, n, m, &alpha,
- b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
- a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
- c->ptr<T>(), k, k * n,
- batch_size));
- #endif
+ // checkCudaErrors(cublas@op@@gemmStridedBatched(handle_,
+ // CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
+ // k, n, m, &alpha,
+ // b->ptr<T>(), '@Trans_b' == 'N' ? k : m, k * m,
+ // a->ptr<T>(), '@Trans_a' == 'N' ? m : n, n * m, &beta,
+ // c->ptr<T>(), k, k * n,
+ // batch_size));
}
#endif
#endif // JIT
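
The second hunk routes pre-CUDA-11 builds through the same cublasGemmStridedBatchedEx call as CUDA 11+, instead of the dtype-suffixed cublas@op@@gemmStridedBatched variant, which survives only as a comment. On old toolkits the compute type is a plain cudaDataType_t, and any fp16 operand selects CUDA_R_16F plus the tensor-op algorithm. A sketch of that selection as a free-standing helper (pick_legacy_gemm_config is hypothetical; the logic mirrors the added #else block):

    #include <cublas_v2.h>

    // Chooses the pre-CUDA-11 compute type and algorithm the way the new
    // #else branch above does.
    static void pick_legacy_gemm_config(bool any_half, bool use_tensorcore,
                                        cudaDataType_t* computeType,
                                        cublasGemmAlgo_t* algo) {
        *computeType = CUDA_R_32F;
        *algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT_TENSOR_OP
                               : CUBLAS_GEMM_DEFAULT;
        if (any_half) {                  // any of a, b, c is float16
            *computeType = CUDA_R_16F;
            *algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
        }
    }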
28 changes: 19 additions & 9 deletions python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc
@@ -50,7 +50,7 @@ void CublasMatmulOp::jit_prepare(JK& jk) {
jk << _CS("[T:") << a->dtype();
jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N');
jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N');
jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D');
jk << _CS("][op:") << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D'));
jk << ']';
}

@@ -85,22 +85,32 @@ void CublasMatmulOp::jit_run() {
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
computeType = CUBLAS_COMPUTE_16F;
}
+ #else
+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+ cudaDataType_t computeType = CUDA_R_32F;
+ if (use_tensorcore) {
+ algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+ }
+ if (a->dtype() == ns_float16
+ || b->dtype() == ns_float16 || c->dtype() == ns_float16) {
+ computeType = CUDA_R_16F;
+ algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+ }
+ #endif
checkCudaErrors(cublasGemmEx(handle_,
CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
k, n, m, &alpha,
b->ptr<T>(),get_dtype(b->dtype()), '@Trans_b' == 'N' ? k : m,
a->ptr<T>(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, &beta,
c->ptr<T>(),get_dtype(c->dtype()), k,
computeType, algo));
- #else
- checkCudaErrors(cublas@op@@gemm(handle_,
- CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
- k, n, m, &alpha,
- b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
- a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
- c->ptr<T>(), k));
+ // checkCudaErrors(cublas@op@@gemm(handle_,
+ // CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a,
+ // k, n, m, &alpha,
+ // b->ptr<T>(), '@Trans_b' == 'N' ? k : m,
+ // a->ptr<T>(), '@Trans_a' == 'N' ? m : n, &beta,
+ // c->ptr<T>(), k));

#endif

}
#endif // JIT
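
The single-matmul op receives the identical treatment, so every dtype now funnels into cublasGemmEx. The argument order in the call above is the usual column-major trick: cuBLAS stores matrices column-major, so a row-major C = A*B is computed by passing B first and the sizes as (k, n, m). A hedged sketch of that call for fp16 data with fp32 accumulation (gemm_fp16_row_major is hypothetical and simplified; the real op also handles transposes and other dtypes):

    #include <cuda_runtime.h>
    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // Row-major C[n,k] = A[n,m] * B[m,k] on a column-major BLAS: swap the
    // operand order and pass the sizes as (k, n, m), as the op above does.
    cublasStatus_t gemm_fp16_row_major(cublasHandle_t handle,
                                       const __half* A, const __half* B,
                                       __half* C, int n, int m, int k) {
        const float alpha = 1.0f, beta = 0.0f;  // fp32 scalars for fp32 accumulation
    #if CUDART_VERSION >= 11000
        cublasComputeType_t computeType = CUBLAS_COMPUTE_32F;
    #else
        cudaDataType_t computeType = CUDA_R_32F;  // pre-11 API takes cudaDataType_t
    #endif
        return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            k, n, m, &alpha,
            B, CUDA_R_16F, k,         // leading dim of row-major B[m,k] is k
            A, CUDA_R_16F, m, &beta,  // leading dim of row-major A[n,m] is m
            C, CUDA_R_16F, k,         // leading dim of row-major C[n,k] is k
            computeType, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
    }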
6 changes: 3 additions & 3 deletions python/jittor/other/code_softmax.py
@@ -100,7 +100,7 @@ def softmax_v1(a, log=False):
{for_loop}
#pragma unroll
for (int j=0; j<{ILP}; j++)
v1 += {"vy[i][j];" if log else "vx[i][j]*vy[i][j];"}
v1 += {"float(vy[i][j]);" if log else "float(vx[i][j]*vy[i][j]);"}
typedef cub::BlockReduce<float, {tnum}> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
@@ -114,8 +114,8 @@ def softmax_v1(a, log=False):
#pragma unroll
for (int j=0; j<{ILP}; j++)
vx[i][j] = {
"vy[i][j] - expf(vx[i][j]) * reduce_var;" if log
else "vx[i][j] * (vy[i][j] - reduce_var);"
"vy[i][j] - in0_type(expf(vx[i][j]) * reduce_var);" if log
else "vx[i][j] * (vy[i][j] - in0_type(reduce_var));"
}
{for_loop}
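
Both softmax fixes serve one goal: when in0_type is fp16, each term is promoted to float before it enters the block reduction, and the reduced value is cast back to in0_type only at the final element-wise update, so the accumulation itself stays in fp32. A minimal standalone CUDA sketch of the pattern (not Jittor's generated code; the kernel name and TNUM are assumptions):

    #include <cuda_fp16.h>
    #include <cub/cub.cuh>

    constexpr int TNUM = 256;  // threads per block; launch with <<<1, TNUM>>>

    // Dot product of fp16 vectors with an fp32 accumulator, analogous to the
    // float(vx[i][j]*vy[i][j]) promotion added above (here each term is
    // promoted before the multiply).
    __global__ void dot_fp16_fp32acc(const __half* x, const __half* y,
                                     float* out, int n) {
        float v1 = 0.f;                                     // fp32 accumulator
        for (int j = threadIdx.x; j < n; j += blockDim.x)
            v1 += __half2float(x[j]) * __half2float(y[j]);  // promote each term
        typedef cub::BlockReduce<float, TNUM> BlockReduce;  // as in the kernel above
        __shared__ typename BlockReduce::TempStorage temp_storage;
        float sum = BlockReduce(temp_storage).Sum(v1);
        if (threadIdx.x == 0) *out = sum;                   // still fp32 here
    }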
