From 8f96635efa44b43fde37a6a4799c6631704c505f Mon Sep 17 00:00:00 2001 From: cboss6 Date: Mon, 15 Apr 2024 13:48:44 +0800 Subject: [PATCH] [GPU][Fix] Fix softmax acc issue. (#2674) --- itex/core/kernels/gpu/softmax_op_functor.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/itex/core/kernels/gpu/softmax_op_functor.h b/itex/core/kernels/gpu/softmax_op_functor.h index 07e091d46..7c15e466b 100644 --- a/itex/core/kernels/gpu/softmax_op_functor.h +++ b/itex/core/kernels/gpu/softmax_op_functor.h @@ -531,7 +531,8 @@ inline Status SoftmaxWorkgroupSMemImpl(const GPUDevice& device, const int num_packs = cols / pack_size; stream->submit([&](sycl::handler& h) { - __shared__ scratch(sycl::range<1>(workgroup_size), h); + __shared__ scratch( + sycl::range<1>(cols * sizeof(ComputeType)), h); SoftmaxWorkgroupSMemImplKernel task(scratch, rows, cols, workgroup_size, num_packs, device_load, @@ -552,7 +553,10 @@ inline Status LaunchSoftmaxWorkGroupSMemImpl(const GPUDevice& device, STORE device_store, const int32 rows, const int32 cols) { - int workgroup_size = 128; + int workgroup_size = + device.stream() + ->get_device() + .template get_info(); sycl::range<1> local_range(workgroup_size); int num_wg; GetNumWorkGroups(device.stream()->get_device(), workgroup_size, rows, 32,