Enable cross entropy loss for xla autocast with FP32 precision (#7992) #8094

Merged
merged 1 commit into master from enable_autocast_cel on Sep 30, 2024

Conversation

avizon-aws
Collaborator

Several operators are commented out in the XLA autocast policy even though they are cast on GPU; to keep the two backends consistent, we need to support these operators as well. cross_entropy_loss is one of them: because it is currently commented out in the XLA autocast policy, no casting occurs for it, and it simply executes in whatever dtype its inputs have.

In the experiment below, the linear layer's output dtype is bf16, which is expected because linear is listed in the XLA autocast policy. The loss dtype is fp32, which looks correct, but there is a catch: no autocasting was actually applied to CrossEntropyLoss. The loss ends up in fp32 only because the target's dtype is fp32. CrossEntropyLoss finishes with a multiplication between the computed log-softmax and the target; all of the exponentiation and log operations run in bf16, and only that final multiplication produces an fp32 result, because the multiply promotes to the higher precision (fp32). This is not the expected behavior: all of the CrossEntropyLoss ops (exponentiation, logs, etc.) should execute in FP32. They do not, because cross_entropy_loss is not listed in the XLA autocast policy. This finding is based on a detailed analysis of the HLO outputs attached below.
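
The promotion behavior described above can be reproduced with plain eager tensors; this is a minimal illustration (not part of the PR, and it does not need an XLA device) of how a bf16 * fp32 multiply promotes to fp32:

import torch

# Type promotion: multiplying a bf16 tensor by an fp32 tensor yields fp32,
# which is why the loss dtype looked "correct" even though CrossEntropyLoss
# itself was never autocast.
a = torch.randn(4, dtype=torch.bfloat16)
b = torch.randn(4, dtype=torch.float32)
print((a * b).dtype)  # torch.float32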

Before this change:
Exp1.

device = 'xla' # Get the XLA device (e.g., TPU or GPU)
model = torch.nn.Linear(10, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
data = torch.randn(16, 10).to(torch.bfloat16).to(device)
target = torch.randn(16, 10).to(device)
print(device, torch.__version__)
for epoch in range(1):
    optimizer.zero_grad()
    # debugpy.breakpoint() 
    with torch.autocast('xla'):
        output = model(data)
        loss = torch.nn.CrossEntropyLoss()(output, target)
        print(output.dtype, loss.dtype, target.dtype)
    # loss.backward()
    optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item()}")

HLO:

ENTRY %SyncTensorsGraph.62 (p0.1: f32[], p1.2: f32[16,10], p2.3: f32[10], p3.12: f32[10,10], p4.22: f32[16,10]) -> (f32[]) {
  %p4.22 = f32[16,10]{1,0} parameter(4), frontend_attributes={neff_input_names="input4"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.23 = bf16[16,10]{1,0} convert(f32[16,10]{1,0} %p4.22), metadata={op_type="xla__cast" op_name="xla__cast"}
  %p3.12 = f32[10,10]{1,0} parameter(3), frontend_attributes={neff_input_names="input3"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %custom-call.2 = f32[10,10]{1,0} custom-call(f32[10,10]{1,0} %p3.12), custom_call_target="AwsNeuronTransferWithStaticRing", api_version=API_VERSION_UNSPECIFIED, metadata={op_type="xla___op_TransferWithStaticRingTransfer" op_name="xla___op_TransferWithStaticRingTransfer"}
  %convert.20 = bf16[10,10]{1,0} convert(f32[10,10]{1,0} %custom-call.2), metadata={op_type="xla__cast" op_name="xla__cast"}
  %transpose.21 = bf16[10,10]{0,1} transpose(bf16[10,10]{1,0} %convert.20), dimensions={1,0}, metadata={op_type="aten__permute" op_name="aten__permute"}
  %dot.24 = bf16[16,10]{1,0} dot(bf16[16,10]{1,0} %convert.23, bf16[10,10]{0,1} %transpose.21), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %p2.3 = f32[10]{0} parameter(2), frontend_attributes={neff_input_names="input2"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %custom-call.3 = f32[10]{0} custom-call(f32[10]{0} %p2.3), custom_call_target="AwsNeuronTransferWithStaticRing", api_version=API_VERSION_UNSPECIFIED, metadata={op_type="xla___op_TransferWithStaticRingTransfer" op_name="xla___op_TransferWithStaticRingTransfer"}
  %convert.11 = bf16[10]{0} convert(f32[10]{0} %custom-call.3), metadata={op_type="xla__cast" op_name="xla__cast"}
  %broadcast.28 = bf16[16,10]{1,0} broadcast(bf16[10]{0} %convert.11), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %add.29 = bf16[16,10]{1,0} add(bf16[16,10]{1,0} %dot.24, bf16[16,10]{1,0} %broadcast.28), metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %constant.32 = bf16[] constant(-inf), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.37 = bf16[16]{0} reduce(bf16[16,10]{1,0} %add.29, bf16[] %constant.32), dimensions={1}, to_apply=%MaxComputation.33, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.38 = bf16[16,10]{1,0} broadcast(bf16[16]{0} %reduce.37), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.39 = bf16[16,10]{1,0} subtract(bf16[16,10]{1,0} %add.29, bf16[16,10]{1,0} %broadcast.38), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %exponential.40 = bf16[16,10]{1,0} exponential(bf16[16,10]{1,0} %subtract.39), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %constant.41 = bf16[] constant(0), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.46 = bf16[16]{0} reduce(bf16[16,10]{1,0} %exponential.40, bf16[] %constant.41), dimensions={1}, to_apply=%AddComputation.42, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %log.47 = bf16[16]{0} log(bf16[16]{0} %reduce.46), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.48 = bf16[16,10]{1,0} broadcast(bf16[16]{0} %log.47), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.49 = bf16[16,10]{1,0} subtract(bf16[16,10]{1,0} %subtract.39, bf16[16,10]{1,0} %broadcast.48), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %convert.50 = f32[16,10]{1,0} convert(bf16[16,10]{1,0} %subtract.49), metadata={op_type="aten__mul" op_name="aten__mul"}
  %p1.2 = f32[16,10]{1,0} parameter(1), frontend_attributes={neff_input_names="input1"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %multiply.51 = f32[16,10]{1,0} multiply(f32[16,10]{1,0} %convert.50, f32[16,10]{1,0} %p1.2), metadata={op_type="aten__mul" op_name="aten__mul"}
  %constant.52 = f32[] constant(0), metadata={op_type="aten__sum" op_name="aten__sum"}
  %reduce.58 = f32[] reduce(f32[16,10]{1,0} %multiply.51, f32[] %constant.52), dimensions={0,1}, to_apply=%AddComputation.54, metadata={op_type="aten__sum" op_name="aten__sum"}
  %negate.59 = f32[] negate(f32[] %reduce.58), metadata={op_type="aten__neg" op_name="aten__neg"}
  %p0.1 = f32[] parameter(0), frontend_attributes={neff_input_names="input0"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %divide.60 = f32[] divide(f32[] %negate.59, f32[] %p0.1), metadata={op_type="aten__div" op_name="aten__div"}
  ROOT %tuple.61 = (f32[]) tuple(f32[] %divide.60), frontend_attributes={neff_output_names="output0"}
}

Exp2.
The target dtype is also bf16 in this case. This experiment was done to prove that the target's dtype was the true cause of the FP32 loss: as the HLO below shows, with a bf16 target the entire loss computation, including the final reduction, stays in bf16.

Code change from previous experiment:
target = torch.randn(16, 10).to(torch.bfloat16).to(device)

HLO:


ENTRY %SyncTensorsGraph.61 (p0.1: bf16[], p1.2: bf16[16,10], p2.3: f32[10], p3.12: f32[10,10], p4.22: f32[16,10]) -> (bf16[]) {
  %p4.22 = f32[16,10]{1,0} parameter(4), frontend_attributes={neff_input_names="input4"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.23 = bf16[16,10]{1,0} convert(f32[16,10]{1,0} %p4.22), metadata={op_type="xla__cast" op_name="xla__cast"}
  %p3.12 = f32[10,10]{1,0} parameter(3), frontend_attributes={neff_input_names="input3"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %custom-call.2 = f32[10,10]{1,0} custom-call(f32[10,10]{1,0} %p3.12), custom_call_target="AwsNeuronTransferWithStaticRing", api_version=API_VERSION_UNSPECIFIED, metadata={op_type="xla___op_TransferWithStaticRingTransfer" op_name="xla___op_TransferWithStaticRingTransfer"}
  %convert.20 = bf16[10,10]{1,0} convert(f32[10,10]{1,0} %custom-call.2), metadata={op_type="xla__cast" op_name="xla__cast"}
  %transpose.21 = bf16[10,10]{0,1} transpose(bf16[10,10]{1,0} %convert.20), dimensions={1,0}, metadata={op_type="aten__permute" op_name="aten__permute"}
  %dot.24 = bf16[16,10]{1,0} dot(bf16[16,10]{1,0} %convert.23, bf16[10,10]{0,1} %transpose.21), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %p2.3 = f32[10]{0} parameter(2), frontend_attributes={neff_input_names="input2"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %custom-call.3 = f32[10]{0} custom-call(f32[10]{0} %p2.3), custom_call_target="AwsNeuronTransferWithStaticRing", api_version=API_VERSION_UNSPECIFIED, metadata={op_type="xla___op_TransferWithStaticRingTransfer" op_name="xla___op_TransferWithStaticRingTransfer"}
  %convert.11 = bf16[10]{0} convert(f32[10]{0} %custom-call.3), metadata={op_type="xla__cast" op_name="xla__cast"}
  %broadcast.28 = bf16[16,10]{1,0} broadcast(bf16[10]{0} %convert.11), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %add.29 = bf16[16,10]{1,0} add(bf16[16,10]{1,0} %dot.24, bf16[16,10]{1,0} %broadcast.28), metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %constant.32 = bf16[] constant(-inf), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.37 = bf16[16]{0} reduce(bf16[16,10]{1,0} %add.29, bf16[] %constant.32), dimensions={1}, to_apply=%MaxComputation.33, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.38 = bf16[16,10]{1,0} broadcast(bf16[16]{0} %reduce.37), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.39 = bf16[16,10]{1,0} subtract(bf16[16,10]{1,0} %add.29, bf16[16,10]{1,0} %broadcast.38), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %exponential.40 = bf16[16,10]{1,0} exponential(bf16[16,10]{1,0} %subtract.39), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %constant.41 = bf16[] constant(0), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.46 = bf16[16]{0} reduce(bf16[16,10]{1,0} %exponential.40, bf16[] %constant.41), dimensions={1}, to_apply=%AddComputation.42, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %log.47 = bf16[16]{0} log(bf16[16]{0} %reduce.46), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.48 = bf16[16,10]{1,0} broadcast(bf16[16]{0} %log.47), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.49 = bf16[16,10]{1,0} subtract(bf16[16,10]{1,0} %subtract.39, bf16[16,10]{1,0} %broadcast.48), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %p1.2 = bf16[16,10]{1,0} parameter(1), frontend_attributes={neff_input_names="input1"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %multiply.50 = bf16[16,10]{1,0} multiply(bf16[16,10]{1,0} %subtract.49, bf16[16,10]{1,0} %p1.2), metadata={op_type="aten__mul" op_name="aten__mul"}
  %constant.51 = bf16[] constant(0), metadata={op_type="aten__sum" op_name="aten__sum"}
  %reduce.57 = bf16[] reduce(bf16[16,10]{1,0} %multiply.50, bf16[] %constant.51), dimensions={0,1}, to_apply=%AddComputation.53, metadata={op_type="aten__sum" op_name="aten__sum"}
  %negate.58 = bf16[] negate(bf16[] %reduce.57), metadata={op_type="aten__neg" op_name="aten__neg"}
  %p0.1 = bf16[] parameter(0), frontend_attributes={neff_input_names="input0"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %divide.59 = bf16[] divide(bf16[] %negate.58, bf16[] %p0.1), metadata={op_type="aten__div" op_name="aten__div"}
  ROOT %tuple.60 = (bf16[]) tuple(bf16[] %divide.59), frontend_attributes={neff_output_names="output0"}
}

After uncommenting cross_entropy_loss in the XLA autocast policy, as done in this PR:

Exp3:
The input and target are in FP32. The output of the linear layer is in BF16 (linear is autocast to BF16), and it is then upcast back to FP32 for CrossEntropyLoss, as seen in the HLO below.

HLO:

  %p4.8 = f32[16,10]{1,0} parameter(4), frontend_attributes={neff_input_names="input4"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.9 = bf16[16,10]{1,0} convert(f32[16,10]{1,0} %p4.8), metadata={op_type="xla__cast" op_name="xla__cast"}
  %p3.5 = f32[10,10]{1,0} parameter(3), frontend_attributes={neff_input_names="input3"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.6 = bf16[10,10]{1,0} convert(f32[10,10]{1,0} %p3.5), metadata={op_type="xla__cast" op_name="xla__cast"}
  %transpose.7 = bf16[10,10]{0,1} transpose(bf16[10,10]{1,0} %convert.6), dimensions={1,0}, metadata={op_type="aten__permute" op_name="aten__permute"}
  %dot.10 = bf16[16,10]{1,0} dot(bf16[16,10]{1,0} %convert.9, bf16[10,10]{0,1} %transpose.7), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %p2.3 = f32[10]{0} parameter(2), frontend_attributes={neff_input_names="input2"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %convert.4 = bf16[10]{0} convert(f32[10]{0} %p2.3), metadata={op_type="xla__cast" op_name="xla__cast"}
  %broadcast.14 = bf16[16,10]{1,0} broadcast(bf16[10]{0} %convert.4), dimensions={1}, metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %add.15 = bf16[16,10]{1,0} add(bf16[16,10]{1,0} %dot.10, bf16[16,10]{1,0} %broadcast.14), metadata={op_type="aten__addmm" op_name="aten__addmm"}
  %convert.16 = f32[16,10]{1,0} convert(bf16[16,10]{1,0} %add.15), metadata={op_type="xla__cast" op_name="xla__cast"}
  %constant.17 = f32[] constant(-inf), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.22 = f32[16]{0} reduce(f32[16,10]{1,0} %convert.16, f32[] %constant.17), dimensions={1}, to_apply=%MaxComputation.18, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.23 = f32[16,10]{1,0} broadcast(f32[16]{0} %reduce.22), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.24 = f32[16,10]{1,0} subtract(f32[16,10]{1,0} %convert.16, f32[16,10]{1,0} %broadcast.23), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %exponential.25 = f32[16,10]{1,0} exponential(f32[16,10]{1,0} %subtract.24), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %constant.26 = f32[] constant(0), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %reduce.31 = f32[16]{0} reduce(f32[16,10]{1,0} %exponential.25, f32[] %constant.26), dimensions={1}, to_apply=%AddComputation.27, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %log.32 = f32[16]{0} log(f32[16]{0} %reduce.31), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %broadcast.33 = f32[16,10]{1,0} broadcast(f32[16]{0} %log.32), dimensions={0}, metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %subtract.34 = f32[16,10]{1,0} subtract(f32[16,10]{1,0} %subtract.24, f32[16,10]{1,0} %broadcast.33), metadata={op_type="aten__log_softmax" op_name="aten__log_softmax"}
  %p1.2 = f32[16,10]{1,0} parameter(1), frontend_attributes={neff_input_names="input1"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %multiply.35 = f32[16,10]{1,0} multiply(f32[16,10]{1,0} %subtract.34, f32[16,10]{1,0} %p1.2), metadata={op_type="aten__mul" op_name="aten__mul"}
  %reduce.42 = f32[] reduce(f32[16,10]{1,0} %multiply.35, f32[] %constant.26), dimensions={0,1}, to_apply=%AddComputation.38, metadata={op_type="aten__sum" op_name="aten__sum"}
  %negate.43 = f32[] negate(f32[] %reduce.42), metadata={op_type="aten__neg" op_name="aten__neg"}
  %p0.1 = f32[] parameter(0), frontend_attributes={neff_input_names="input0"}, metadata={op_type="xla__device_data" op_name="xla__device_data"}
  %divide.44 = f32[] divide(f32[] %negate.43, f32[] %p0.1), metadata={op_type="aten__div" op_name="aten__div"}
  ROOT %tuple.45 = (f32[]) tuple(f32[] %divide.44), frontend_attributes={neff_output_names="output0"}
}
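
As a quick behavioral check (a hypothetical sketch, not taken from the PR's tests): with cross_entropy_loss registered under the FP32 autocast policy, autocast upcasts the op's floating-point inputs, so the loss should come out in fp32 even when both the logits and the target are bf16.

with torch.autocast('xla'):
    output = model(data)  # bf16, because linear is autocast to bf16
    loss = torch.nn.CrossEntropyLoss()(output, target.to(torch.bfloat16))
print(output.dtype, loss.dtype)  # expected: torch.bfloat16 torch.float32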

@lsy323 lsy323 merged commit 940bee4 into master Sep 30, 2024
23 checks passed
@lsy323 lsy323 deleted the enable_autocast_cel branch September 30, 2024 20:41
@lsy323 lsy323 restored the enable_autocast_cel branch September 30, 2024 20:41