Skip to content

Commit

Permalink
[GPU][Fix] Fix softmax acc issue. (#2674)
Browse files Browse the repository at this point in the history
  • Loading branch information
cboss6 committed Apr 15, 2024
1 parent a709138 commit 8f96635
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions itex/core/kernels/gpu/softmax_op_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,8 @@ inline Status SoftmaxWorkgroupSMemImpl(const GPUDevice& device,
const int num_packs = cols / pack_size;

stream->submit([&](sycl::handler& h) {
__shared__<unsigned char> scratch(sycl::range<1>(workgroup_size), h);
__shared__<unsigned char> scratch(
sycl::range<1>(cols * sizeof(ComputeType)), h);
SoftmaxWorkgroupSMemImplKernel<LOAD, STORE, ComputeType, pack_size,
algorithm>
task(scratch, rows, cols, workgroup_size, num_packs, device_load,
Expand All @@ -552,7 +553,10 @@ inline Status LaunchSoftmaxWorkGroupSMemImpl(const GPUDevice& device,
STORE device_store,
const int32 rows,
const int32 cols) {
int workgroup_size = 128;
int workgroup_size =
device.stream()
->get_device()
.template get_info<sycl::info::device::max_work_group_size>();
sycl::range<1> local_range(workgroup_size);
int num_wg;
GetNumWorkGroups(device.stream()->get_device(), workgroup_size, rows, 32,
Expand Down

0 comments on commit 8f96635

Please sign in to comment.