Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

R2.5 jax stable version #8060

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .torch_pin
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
release/2.5
6 changes: 3 additions & 3 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ def test_tpu_custom_call_pallas_flash_attention(self):
# This payload is generated by the following Pallas code:
# https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
# To be noted, set `jax.config.update('jax_default_matmul_precision', 'highest')`` before generating the payload.
payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTkuMC4wZ2l0AAFBDQEDBQcJCwEDDQMFAw8FAxEHBxMVFwkLGRsdHyELAyMDrgI+AhsB8wcTCwsPEwsPDxMLCwsLkwsTCw8TDwsLCwsPCwsLDw8LCw8LDw8PDxcTE0MLGwvFC5MLCwsLGxsLGwsbCxsLGxsbGw8PDw8XDwsXDw8LFw8PCxcPDwsXDwsTCw8PFxMfCw8PFyMPEx8LDxcbDw8LDxcLDwsTHwsPFxsFCY15kWEHA1kJBV1JAR8PCxMTFxMTFxcfCxMXIwsBGw8HKx8bBxcjDwsbLy8CYg0fAwMNhwUlBScVj5UdOgJTBSkdI4kdI7UdIxYCBSsFLQUvBTEjEQlBAQAAAAAAAAABAAAAAAAAAIAAAAAAAAAABAAAAAAAAAANGQMDDYUFMxETAAMD4fsREQEFNQU3BTkFOx2/wQU9BT8FQR3PPRXRCQVDBUUBA9cFRx3bSRXdCR3rTRXtCR0GAgoCHSoCUxUuAgkDD1dZFVtfYWMpZSkXZ2lrBUkBCfPz8/cNF2FmZmluZV9tYXA8KGQwLCBkMSwgZDIsIGQzKSAtPiAoZDAsIGQxLCBkMiwgZDMpPgAFSyMRCUEDAAAAAAAAAAIAAAAAAAAAAQAAAAAAAAABAAAAAAAAAAVNBU8FUQVTAQltcXV5AwUZbxsdCSsDBRlzGx0JLQMFGXcbHQkvAwUZexsdCTEDBRUfFysDBRUfFy0DBRUfFy8DBRUfFzERAQERAwEViwkdB40XBRoIAR2RkwVVFwVKBQEVl50dmZsFVxcFqgsBFZ+lHaGjBVkXBWIDARWnrR2pqwVbFwUaAwEdr7EFXRezZQEFXxW3CR0HuRcFHggBAwMNvSUHCQAAAAAFYRXDCR0HxRcFIggBAwc19TclOckREwEDAw3NJQ0JAACA/wVjHQfTFwW2CAEDBT/9QUMREQUdRT0FZR0H3xcFuggBBWcd5UkFaQMDDeklDQkAAAAABWsdB+8XBb4IAQMFP/9BQyN0cHUuZGltZW5zaW9uX3NlbWFudGljczxwYXJhbGxlbD4AI3RwdS5jb250cmFjdF9wcmVjaXNpb248ZnAzMj4AI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPGFyYml0cmFyeT4AI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AI2FyaXRoLmZhc3RtYXRoPG5vbmU+ACN2ZWN0b3Iua2luZDxtYXhpbXVtZj4AI3ZlY3Rvci5raW5kPGFkZD4AHUVNBW0VDgIJHQcSAhcFwggBFRoCCR0HHgIXBd4IAQMDDSYCJQkJAAAAAAVvHQcyAhcF4ggBAwc19TclOSUFcQECAgMX+QkFBQIEEQtdJwUCBAIECycFAgQRCwsnAwIECycJBQUCBBELAQIEAQknBQIEBQsFEQEBAQEFBQUFAQUJAQEBAQkBAQEBBEIHBQEQAQcDARUDEQFVBwNhqxEBAQEBAQEBAQUBBQEFAQUBCQMPAwMDCQMPAwMDCQMPAwMDCQMPAwMDEQYPAw8LCRETFRcPBg8DCQMZCQMRAwMDCQMRAwMDCQMRAwMDCQMRAwMDEQYRAw8LCx0fISMPBhEDCQMlCQMzuwMHBwczxwMHBxsnKQkDO8sDDRMHO9UDDQUrLQ8G2QMVAy8VBkcDBwMxCwdHJwMHBSszGQfjJwMHAzUJA0vnAw0TB0vxAw0FNzkPBgICAxUDOxUGTwMHAz0NB08nAwcFNz8JAxMDAwMJAxMDAwMJAxMDAwMJAxMDAwMRBhMDDwsNQ0VHSQ8GEwMJA0sJA1EiAgMJBwdRNgIDCQdBTU8JAwsDAwMJAwsDAwMJAwsDAwMJAwsDAwMRBgsDDwsPU1VXWQ8GCwMJA1sPBgsDDwNRFwQLDV8PU1VXWQUAAQMRAX0HAwsLCQEBAQEBAQEBCQMBIQMBBQQBCQEDBQkDEQF/BwMLCwkBAQEBAQEBAQkDASEDAQUEAQkBAwcJAxEBgQcDCwsJAQEBAQEBAQEJAwEhAwEFBAEJAQMHCQMRAYMHAwsLCQEBAQEBAQEBCQMBIQMBBQQBCQEDBQkGAwEFAQDuFnOGAk4CCy8LEwsvTgJTEyEjLTEdCyMhIyl5HwsdHRUZGRkZggIdJRMdDWPHCQ0VIQsXCwsTDw8PCw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbWF0aABtb2R1bGUAcmV0dXJuAG1hdG11bABjb25zdGFudABzdWJmAGRpdmYAc2hhcGVfY2FzdABsb2FkAG11bHRpX3JlZHVjdGlvbgBicm9hZGNhc3QAc3RvcmUAZXhwAC9ob21lL2p3dGFuLy5sb2NhbC9saWIvcHl0aG9uMy4xMC9zaXRlLXBhY2thZ2VzL2pheC9leHBlcmltZW50YWwvcGFsbGFzL29wcy90cHUvZmxhc2hfYXR0ZW50aW9uLnB5AF9mbGFzaF9hdHRlbnRpb25fa2VybmVsX3NpbmdsZV9iYXRjaF9zaW5nbGVfc3RlcAB2YWx1ZQBmdW5jdGlvbl90eXBlAHN5bV9uYW1lAHRyYW5zZm9ybV9pbmRpY2VzAHdpbmRvd19ib3VuZHMAL2dldFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoKiwgKiwgQ3VzdG9tTm9kZShTbGljZVsoMCwgMTI4KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCldLCBbXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAHRyYW5zZm9ybV8wAHRyYW5zZm9ybV8xAHRyYW5zZm9ybV8yAHRyYW5zZm9ybV8zAHByZWNpc2lvbgB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAa2luZAByZWR1Y3Rpb25fZGltcwAvYnJvYWRjYXN0X2luX2RpbVtzaGFwZT0oMTI4LCAxKSBicm9hZGNhc3RfZGltZW5zaW9ucz0oMCwpXQBkaW1lbnNpb25fc2VtYW50aWNzAGl0ZXJhdGlvbl9ib3VuZHMAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAbWFpbgB3aW5kb3dfcGFyYW1zAF9mbGFzaF9hdHRlbnRpb25fa2VybmVsAF9mbGFzaF9hdHRlbnRpb25faW1wbABfZmxhc2hfYXR0ZW50aW9uAGZsYXNoX2F0dGVudGlvbgA8bW9kdWxlPgAvbW50L2Rpc2tzL3NzZC93b3JrL3BhbGxhcy9wYWxsYXNfYWRkLnB5AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgxLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPSg8UHJlY2lzaW9uLkhJR0hFU1Q6IDI+LCA8UHJlY2lzaW9uLkhJR0hFU1Q6IDI+KSBwcmVmZXJyZWRfZWxlbWVudF90eXBlPWZsb2F0MzJdAC9yZWR1Y2VfbWF4W2F4ZXM9KDEsKV0AL3N1YgBmYXN0bWF0aAAvZXhwAC9yZWR1Y2Vfc3VtW2F4ZXM9KDEsKV0AL2RpdgAvZG90X2dlbmVyYWxbZGltZW5zaW9uX251bWJlcnM9KCgoMSwpLCAoMCwpKSwgKCgpLCAoKSkpIHByZWNpc2lvbj0oPFByZWNpc2lvbi5ISUdIRVNUOiAyPiwgPFByZWNpc2lvbi5ISUdIRVNUOiAyPikgcHJlZmVycmVkX2VsZW1lbnRfdHlwZT1mbG9hdDMyXQAvc3dhcFt0cmVlPVB5VHJlZURlZigoQ3VzdG9tTm9kZShOREluZGV4ZXJbKFB5VHJlZURlZigoKiwgKiwgQ3VzdG9tTm9kZShTbGljZVsoMCwgMTI4KV0sIFtdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCldLCBbXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAA==\", \"needs_layout_passes\": true}}"

payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMjAuMC4wZ2l0AAEvCwEDBQcJAQMLAxkNDxETFRcZGx0fISMD0gJaAhsB9QcTCwsPExMLDw8TCwsLC5MLCw8TDwsLCwsLDwsLCw8PCwsPCw8PExMXExMTCw9DCxsLxQuTCwsLCxsbCxsLGwsbCxsbGxsPDw8PFw8LFw8PCxcPDwsXDw8LFw8PCxcPCxcPDxcTHwsPDxcjDxMfCw8XGw8PCw8XCw8LBQmNeZFhBwNdCQNZASsXHwsTFx8PCxMTFxMTFxcfCxMXIwsHA0kBGw8HKx8bBxcPIwsbLy8C0g0fAwMPjwUlBScVl50DAw+NHVICVQUpHSORHSPDHSMuAgUrBS0FLwUxIw8JQQEAAAAAAAAAAQAAAAAAAACAAAAAAAAAAAQAAAAAAAAADRkFMxETAAMD7/8RDwEFNQU3BTkFOwU9Hc3PBT8FQQVDHd0/Fd8JBUUFRwED5QVJHelLFesJHQoCTxUOAgkdHgIiAh1CAlUVRgIJAwNZWwVLEQ8JAw9fYRdjZ2lrKW0pGW9xcwVNAQn19fX5DRdhZmZpbmVfbWFwPChkMCwgZDEsIGQyLCBkMykgLT4gKGQwLCBkMSwgZDIsIGQzKT4ABU8jDwlBAwAAAAAAAAACAAAAAAAAAAEAAAAAAAAAAQAAAAAAAAAFUQVTBVUFVwEJdXl9gQMFG3cdHwkrAwUbex0fCS0DBRt/HR8JLwMFG4MdHwkxAwUXIRkrAwUXIRktAwUXIRkvAwUXIRkxEQEBEQMBFZMJHQeVFwUGCAEdmZsFWRcFSgUBFZ+lHaGjBVsXBYYLARWnrR2pqwVdFwViAwEVr7UdsbMFXxcFGgMBFbe9Hbm7BWEXM14DAR2/wQVjFzM2EAEVxQkdB8cXBQoIAQMDD8slBwkAAAAABWUV0QkdB9MXBQ4IAQMHN/c5JTvXERMBAwMP2yUNCQAAgP8FZx0H4RcFoggBAwVB/UNFEQ8FHUc/BWkdB+0XBaYIAQVrHfNLBW0jdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8cGFyYWxsZWw+ACN0cHUuY29udHJhY3RfcHJlY2lzaW9uPGZwMzI+ACN0cHUuZGltZW5zaW9uX3NlbWFudGljczxhcmJpdHJhcnk+ACN0cHUubWVtb3J5X3NwYWNlPHZtZW0+ACN2ZWN0b3Iua2luZDxtYXhpbXVtZj4AI2FyaXRoLmZhc3RtYXRoPG5vbmU+AAMDDwYCJQ0JAAAAAAVvHQcSAhcFqggBAwVBVgJDRR1HTwVxFSYCCR0HKgIXBa4IARUyAgkdBzYCFwXKCAEDAw8+AiUJCQAAAAAFcx0HSgIXBc4IAQMHN/c5JTslBXUjdmVjdG9yLmtpbmQ8YWRkPgABAgIDF/sJBQUCBBELZScFAgQCBAsnBQIEEQsLJwMCBAsBAgQnCQUFAgQRCwEJJwUCBAULBREBAQEBBQUFBQEFCQEBAQEJAQEBAQSuBwUBEQFXBwMBFQcRAV0HA2GrEQEBAQEBAQEBBQEFAQUBBQEDAxEDAwMDAxEDAwMDAxEDAwMDAxEDAwMLBhEDEQsJERMVFwUGEQMJAxkDAxMDAwMDAxMDAwMDAxMDAwMDAxMDAwMLBhMDEQsLHR8hIwUGEwMJAyUDAzXJAwcNBzXVAwcHGycpAwM92QMNDwc94wMNBSstBQbnAxUDLxEGSQMHAzETB0knAwcFKzMVB/EnAwcDNQMDTQICAw0PB00WAgMNBTc5BQYaAgMVAzsRBlEDBwM9FwdRJwMHBTc/AwMVAwMDAwMVAwMDAwMVAwMDAwMVAwMDCwYVAxELDUNFR0kFBhUDCQNLAwNTOgIDCQ0HU04CAwkHQU1PAwMNAwMDAwMNAwMDAwMNAwMDAwMNAwMDCwYNAxELD1NVV1kFBg0DCQNbBQYNAxEDURkEDQ1fD1NVV1kJAAEHEQGFBwMNDwkBAQEBAQEBAQMDAQsDAQMDAQsDAQkEAQkBAwUJBxEBhwcDDQ8JAQEBAQEBAQEDAwELAwEDAwELAwEJBAEJAQMHCQcRAYkHAw0PCQEBAQEBAQEBAwMBCwMBAwMBCwMBCQQBCQEDBwkHEQGLBwMNDwkBAQEBAQEBAQMDAQsDAQMDAQsDAQkEAQkBAwUJBgMBBQEAMhp37gImAgsvCxMLLyYCE2MhIy0xHQsjISMpLXkfCx0dFVkZGRkZ6gIdJRMdDWPvGxcTFyMvFxkZFSUfDw0PCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQB2ZWN0b3IAYXJpdGgAbW9kdWxlAGFyaXRoLmNvbnN0YW50AHZlY3Rvci5zaGFwZV9jYXN0AGZ1bmMuZnVuYwBmdW5jLnJldHVybgB2ZWN0b3IubG9hZAB0cHUubWF0bXVsAHZlY3Rvci5tdWx0aV9yZWR1Y3Rpb24AdmVjdG9yLmJyb2FkY2FzdABhcml0aC5zdWJmAG1hdGguZXhwAGFyaXRoLmRpdmYAdmVjdG9yLnN0b3JlAC9ob21lL2JiYWhsL21pbmljb25kYTMvZW52cy90b3JjaHNlcDEwL2xpYi9weXRob24zLjEwL3NpdGUtcGFja2FnZXMvamF4L2V4cGVyaW1lbnRhbC9wYWxsYXMvb3BzL3RwdS9mbGFzaF9hdHRlbnRpb24ucHkAX2ZsYXNoX2F0dGVudGlvbl9rZXJuZWxfc2luZ2xlX2JhdGNoX3NpbmdsZV9zdGVwAHZhbHVlAGZ1bmN0aW9uX3R5cGUAc3ltX25hbWUAdHJhbnNmb3JtX2luZGljZXMAd2luZG93X2JvdW5kcwAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKCgqLCAqLCBDdXN0b21Ob2RlKFNsaWNlWygwLCAxMjgsIDEpXSwgW05vbmUsIE5vbmVdKSwgQ3VzdG9tTm9kZShTbGljZVsoMCwgNCwgMSldLCBbTm9uZSwgTm9uZV0pKSksICgxLCAxLCAxMjgsIDQpLCAoKSldLCBbKiwgKl0pLCkpXQB0cmFuc2Zvcm1fMAB0cmFuc2Zvcm1fMQB0cmFuc2Zvcm1fMgB0cmFuc2Zvcm1fMwAvaG9tZS9iYmFobC9weXRvcmNoL3hsYS90ZXN0L3Rlc3RfcGFsbGFzLnB5AHByZWNpc2lvbgB0cmFuc3Bvc2VfbGhzAHRyYW5zcG9zZV9yaHMAa2luZAByZWR1Y3Rpb25fZGltcwAvYnJvYWRjYXN0X2luX2RpbVtzaGFwZT0oMTI4LCAxKSBicm9hZGNhc3RfZGltZW5zaW9ucz0oMCwpXQBzdGFibGVfbW9zYWljLnZlcnNpb24AZGltZW5zaW9uX3NlbWFudGljcwBpdGVyYXRpb25fYm91bmRzAHNjYWxhcl9wcmVmZXRjaABzY3JhdGNoX29wZXJhbmRzAG1haW4Ad2luZG93X3BhcmFtcwBfZmxhc2hfYXR0ZW50aW9uX2tlcm5lbABfZmxhc2hfYXR0ZW50aW9uX2ltcGwAX2ZsYXNoX2F0dGVudGlvbgBmbGFzaF9hdHRlbnRpb24AdGVzdF90cHVfY3VzdG9tX2NhbGxfcGFsbGFzX3dyYXBfZmxhc2hfYXR0ZW50aW9uADxtb2R1bGU+AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgxLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPShQcmVjaXNpb24uSElHSEVTVCwgUHJlY2lzaW9uLkhJR0hFU1QpIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0AL3JlZHVjZV9tYXhbYXhlcz0oMSwpXQAvc3ViAGZhc3RtYXRoAC9leHAAL3JlZHVjZV9zdW1bYXhlcz0oMSwpXQAvZGl2AC9kb3RfZ2VuZXJhbFtkaW1lbnNpb25fbnVtYmVycz0oKCgxLCksICgwLCkpLCAoKCksICgpKSkgcHJlY2lzaW9uPShQcmVjaXNpb24uSElHSEVTVCwgUHJlY2lzaW9uLkhJR0hFU1QpIHByZWZlcnJlZF9lbGVtZW50X3R5cGU9ZmxvYXQzMl0AL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKCosICosIEN1c3RvbU5vZGUoU2xpY2VbKDAsIDEyOCwgMSldLCBbTm9uZSwgTm9uZV0pLCBDdXN0b21Ob2RlKFNsaWNlWygwLCA0LCAxKV0sIFtOb25lLCBOb25lXSkpKSwgKDEsIDEsIDEyOCwgNCksICgpKV0sIFsqLCAqXSksKSldAA==\", \"serialization_format\": 1, \"needs_layout_passes\": true}, \"implicit_sharding\": {\"type\": \"MANUAL\"}}"
# The division is to cause potential precision issue on TPU.
q_mini = torch.arange(128 * 4, dtype=torch.float32).reshape(128, 4) / 13
k_mini = torch.arange(
Expand Down Expand Up @@ -209,7 +208,8 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self):

o = flash_attention_kernel(q, k, v)
expected_o = self._attention(q, k, v)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
torch.testing.assert_close(o.cpu(), expected_o.cpu())
# self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
Expand Down
8 changes: 8 additions & 0 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ def _setup_libtpu_flags():
# improves device memory usage.
flags = _set_missing_flags(
flags, (('xla_tpu_prefer_async_allgather_to_allreduce', 'true'),))

# This flag enables FlashAttention HLO pass that pattern matches attention
# and rewrites it as flash attention. This pattern matching is causing
# issues for our standard dot product attention. Turning it off till
# we fix the issue with pattern matching.
flags = _set_missing_flags(
flags, (('xla_tpu_enable_flash_attention', 'false'),)
)

if tpu.version() == 5:
default_v5_flags = {
Expand Down
6 changes: 3 additions & 3 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@
from . import xla_model as this_module

xrt_world_size = deprecated(this_module, torch_xla.runtime.world_size,
'xrt_world_size() will be removed in release 2.6.')
'xrt_world_size() will be removed in release 2.7.')
get_ordinal = deprecated(
this_module, torch_xla.runtime.global_ordinal,
'xla_model.get_ordinal() will be removed in release 2.6.')
'xla_model.get_ordinal() will be removed in release 2.7.')
parse_xla_device = deprecated(
this_module, _utils.parse_xla_device,
'xla_model.parse_xla_device() will be removed in release 2.6.')
'xla_model.parse_xla_device() will be removed in release 2.7.')


class DeviceContext(object):
Expand Down
Loading