From b9ff208910bcfe7d2857ba393c09e9f3effe710e Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Sat, 21 Sep 2024 07:47:46 -0700 Subject: [PATCH] fix tests for numpy 2.0 compatibility PiperOrigin-RevId: 677208314 --- tests/jax_utils_test.py | 20 ++++++++++++-------- tests/serialization_test.py | 19 +++---------------- 2 files changed, 15 insertions(+), 24 deletions(-) diff --git a/tests/jax_utils_test.py b/tests/jax_utils_test.py index c46e91cbf4..c9cd9b3095 100644 --- a/tests/jax_utils_test.py +++ b/tests/jax_utils_test.py @@ -17,12 +17,12 @@ from functools import partial from absl.testing import absltest +from absl.testing import parameterized import chex +from flax import jax_utils import jax +import jax.numpy as jnp import numpy as np -from absl.testing import parameterized - -from flax import jax_utils NDEV = 4 @@ -44,6 +44,7 @@ def test_basics(self, dtype, bs): # Just tests that basic calling works without exploring caveats. @partial(jax_utils.pad_shard_unpad, static_argnums=()) def add(a, b): + b = jnp.asarray(b, dtype=dtype) return a + b x = np.arange(bs, dtype=dtype) @@ -58,7 +59,7 @@ def test_trees(self, dtype, bs): def add(a, b): return a['a'] + b[0] - x = np.arange(bs, dtype=dtype) + x = jnp.arange(bs, dtype=dtype) y = add(dict(a=x), (10 * x,)) chex.assert_type(y.dtype, x.dtype) np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x)) @@ -69,12 +70,13 @@ def test_min_device_batch_avoids_recompile(self, dtype): @jax.jit @chex.assert_max_traces(n=1) def add(a, b): + b = jnp.asarray(b, dtype=dtype) return a + b chex.clear_trace_counter() for bs in self.BATCH_SIZES: - x = np.arange(bs, dtype=dtype) + x = jnp.arange(bs, dtype=dtype) y = add(x, 10 * x, min_device_batch=9) # pylint: disable=unexpected-keyword-arg chex.assert_type(y.dtype, x.dtype) np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x)) @@ -83,9 +85,9 @@ def add(a, b): def test_static_argnum(self, dtype, bs): @partial(jax_utils.pad_shard_unpad, static_argnums=(1,)) def add(a, b): - return a + b + return a + jnp.asarray(b, dtype=dtype) - x = np.arange(bs, dtype=dtype) + x = jnp.arange(bs, dtype=dtype) y = add(x, 10) chex.assert_type(y.dtype, x.dtype) np.testing.assert_allclose(np.float64(y), np.float64(x + 10)) @@ -96,9 +98,11 @@ def test_static_argnames(self, dtype, bs): # test the default/most canonical path where `params` are the first arg. @partial(jax_utils.pad_shard_unpad, static_argnames=('b',)) def add(params, a, *, b): + params = jnp.asarray(params, dtype=dtype) + b = jnp.asarray(b, dtype=dtype) return params * a + b - x = np.arange(bs, dtype=dtype) + x = jnp.arange(bs, dtype=dtype) y = add(5, x, b=10) chex.assert_type(y.dtype, x.dtype) np.testing.assert_allclose(np.float64(y), np.float64(5 * x + 10)) diff --git a/tests/serialization_test.py b/tests/serialization_test.py index 40e1a1100d..712358f7ed 100644 --- a/tests/serialization_test.py +++ b/tests/serialization_test.py @@ -197,8 +197,7 @@ def __call__(self): ) self.assertEqual(variables, deserialized_state) - @parameterized.parameters( - [ + @parameterized.parameters([ 'byte', 'b', 'ubyte', @@ -222,11 +221,9 @@ def __call__(self): 'd', 'longdouble', 'g', - 'cfloat', 'cdouble', 'clongdouble', 'm', - 'bool8', 'b1', 'int64', 'i8', @@ -259,26 +256,16 @@ def __call__(self): 'i1', 'uint8', 'u1', - 'complex_', - 'int0', - 'uint0', + 'uint', 'single', 'csingle', - 'singlecomplex', - 'float_', 'intc', 'uintc', - 'int_', - 'longfloat', - 'clongfloat', - 'longcomplex', - 'bool_', 'int', 'float', 'complex', 'bool', - ] - ) + ]) def test_numpy_serialization(self, dtype): np.random.seed(0) if (