Releases: google/flax

v0.9.0

27 Aug 17:51

What's Changed

  • Add NNX surgery guide by @IvyZX in #4005
  • Port gemma/transformer to NNX by @copybara-service in #4019
  • upgrade python to 3.10 + use pyupgrade by @cgarciae in #4038
  • [nnx] add Using Filters guide by @cgarciae in #4028
  • v0.8.6 by @cgarciae in #4040
  • allow imagenet training profiling to be disabled in config by @copybara-service in #4043
  • [nnx] LoRAParam inherits from Param by @cgarciae in #3988
  • [linen] allows multiple compact methods by @cgarciae in #3808
  • Added support for NANOO fp8. by @wenchenvincent in #3993
  • Add functools.wraps() annotation to flax.nn.jit. by @copybara-service in #4051
  • Fix typo in nnx_basics doc by @rajasekharporeddy in #4047
  • [nnx] fix Variable overloads and add shape/dtype properties by @cgarciae in #4049
  • Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. by @copybara-service in #4039
  • [nnx] stabilize unsafe_pytree by @cgarciae in #4030
  • Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. by @copybara-service in #4055
  • [NVIDIA] Rename fp8 custom dtype to fp32_max_grad by @kaixih in #3984
  • [nnx] fix mnist_tutorial colab link by @cgarciae in #4063
  • [nnx] fix Accuracy on eager mode by @cgarciae in #4065
  • Update orbax_upgrade_guide.rst for async checkpointing usage examples by @kaushaladiti-2802 in #4036
  • Re-enable some tests after Python 3.9 is dropped by @IvyZX in #4067
  • Rename nnx.compat to nnx.bridge by @IvyZX in #4066
  • [nnx] improve mnist tutorial by @cgarciae in #4070
  • Modify Flax checkpointing in preparation for cl/650338576. by @copybara-service in #4072
  • Remove some outdated backward-compatibility code. by @copybara-service in #4068
  • [NVIDIA] Add a user guide for fp8 by @kaixih in #4076
  • [nnx] add extract APIs by @cgarciae in #4078
  • [example]: remove unused lm1b parallelism rules by @knightXun in #4077
  • [nnx] improve filters guide by @cgarciae in #4059
  • [nnx] add call by @cgarciae in #4004
  • Ignore Orbax warning in deprecated flax.training.checkpoints.py to unbreak head doctest by @IvyZX in #4092
  • fix mypy failures due to numpy update by @cgarciae in #4098
  • [linen] generalize transform caching by @copybara-service in #4057
  • [linen] fold rngs on jit to improve caching by @copybara-service in #4064
  • Add shape-based lazy init to LinenToNNX (prev LinenWrapper) by @IvyZX in #4081
  • [nnx] add reseed by @cgarciae in #4099
  • [nnx] add split/merge_inputs by @cgarciae in #4084
  • Perform shape checks for self.param AFTER unboxing by @danielwatson6 in #4079
  • fix restore_checkpoint example in docstring by @copybara-service in #4101
  • [numpy] Fix users of NumPy APIs that are removed in NumPy 2.0. by @copybara-service in #4104
  • set profile_duration_ms = None so that the passed num_profile_steps value takes effect: periodic_actions has default values for both num_profile_steps and profile_duration_ms, and profiling stops only when both conditions are satisfied by @copybara-service in #4096
  • [linen] add share_scope by @cgarciae in #4102
  • Allow metadata pass-through in flax.struct.field by @cool-RR in #4056
  • avoid mixing the einsum_dot_general and einsum arguments by specifying them explicitly in the caller. by @copybara-service in #4115
  • Add logging to track deprecated codepaths. by @copybara-service in #4121
  • [pmap no rank reduce cleanup]: When flipping the by @copybara-service in #4125
  • Add NNXToLinen wrapper to nnx.bridge by @IvyZX in #4126
  • Switch NNX to use Treescope instead of Penzai. by @copybara-service in #4132
  • Add GroupNorm to NNX normalization layers by @treigerm in #4095
  • [nnx] fix initializing propagation by @cgarciae in #4134
  • add JAX-style NNX Transforms FLIP by @cgarciae in #4108
  • Fix _ParentType annotation by @dcharatan in #4120
  • add uv.lock file by @copybara-service in #4139
  • use uv package manager by @cgarciae in #4136
  • More testing and misc fixes on wrappers by @IvyZX in #4137
  • Fix link to orbax documentation by @cool-RR in #4123
  • [nnx] experimental transforms by @cgarciae in #3963
  • [nnx] improve docs by @cgarciae in #4141
  • remove repeated license headers by @cgarciae in #4148
  • update Flax to version 0.9.0 by @copybara-service in #4147
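
Several entries above touch decorator plumbing, e.g. the `functools.wraps()` annotation added to `flax.nn.jit` in #4051. A minimal, library-agnostic sketch of why that annotation matters; `my_jit` here is a hypothetical stand-in decorator, not Flax's actual implementation:

```python
import functools

def my_jit(fn):
    """A stand-in for a tracing decorator such as flax.nn.jit."""
    @functools.wraps(fn)  # copy __name__, __doc__, etc. onto the wrapper
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

@my_jit
def forward(x):
    """Apply the model to a batch."""
    return x * 2

# Without functools.wraps, forward.__name__ would be "wrapper" and the
# docstring would be lost, which confuses debuggers and doc tooling.
print(forward.__name__)  # → forward
print(forward(21))       # → 42
```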

Full Changelog: v0.8.5...v0.9.0

v0.8.5

26 Jun 09:27

What's Changed

  • v0.8.5 by @cgarciae in #3941
  • [nnx] improve vmap axis size detection by @cgarciae in #3947
  • Add direct penzai.treescope support for NNX objects. by @copybara-service in #3948
  • [nnx] fix nnx_basics dependencies by @cgarciae in #3942
  • Rename all the NNX tests to internal naming & build conventions. by @copybara-service in #3952
  • updated rng guide by @chiamp in #3912
  • upgraded haiku guide to include NNX by @chiamp in #3923
  • parameterized NNX transforms tests by @chiamp in #3906
  • Simplify extended dtypes rules, part 1: start by removing sharding-specific rules from EDtypes, since the trailing dims introduced by EDtypes should always be replicated. by @copybara-service in #3957
  • fix HEAD by @chiamp in #3960
  • Minor grammar fixes to NNX documentation. by @mcsmart76 in #3953
  • Make FlatState a Mapping instead of a dict by @NeilGirdhar in #3928
  • Adding Welford metric. by @copybara-service in #3959
  • Modify Welford metric to return mean value. by @copybara-service in #3970
  • [nnx] make State generic by @cgarciae in #3964
  • updated NNX nn docstrings by @chiamp in #3972
  • make flax work with upcoming JAX change to tree_map (being more careful about by @copybara-service in #3976
  • updated nnx.module docstrings by @chiamp in #3966
  • updated nnx.Conv and nnx.ConvTranspose by @chiamp in #3974
  • updated nnx.graph docstrings by @chiamp in #3958
  • Adds pmap and Pmap. static_broadcasted_argnums, donate_argnums, and global_arg_shapes are not yet supported. by @copybara-service in #3978
  • Fixes for batch norm docs by @jkarwowski in #3982
  • fix deprecation warning by @chiamp in #3981
  • updated NNX rnglib docstring by @chiamp in #3980
  • updated nnx.training by @chiamp in #3975
  • updated nnx.variables docstrings by @chiamp in #3986
  • [nnx] vectorize vmap split counts by @cgarciae in #3989
  • added wrt option to nnx.Optimizer by @chiamp in #3983
  • Added nnx.graph.iter_children by @chiamp in #3991
  • [nnx] fix vmap by @copybara-service in #3995
  • Fix head pytest breakage by @IvyZX in #4006
  • Helper function for loading params from a linen module by @copybara-service in #4012
  • Port gemma/layers to NNX by @copybara-service in #4013
  • [nnx] fix grad by @cgarciae in #4007
  • [nnx] add PathContains Filter by @cgarciae in #4011
  • Support Python 3.9 by @copybara-service in #4018
  • Port gemma/modules to NNX by @copybara-service in #4014
  • Internal change to fix current head CI by @copybara-service in #4017
  • Unpin the Orbax pip version. by @copybara-service in #4024
  • Fix Gemma test to unbreak head by @IvyZX in #4025
  • Fix pickling of exceptions by @sanderland in #4002
  • Call user-defined variable transforms before determining axis size in nn.vmap. by @copybara-service in #4026
  • CI: add test run against oldest supported jax version by @jakevdp in #3996
  • Make force_fp32_for_softmax arg in MultiHeadDotProductAttention useful. by @copybara-service in #4029
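
Among the items above, #3959 and #3970 add a Welford metric. The underlying algorithm is standard; below is a self-contained sketch of Welford's running mean and variance (an illustration of the technique, not Flax's actual implementation):

```python
class Welford:
    """Numerically stable running mean/variance (Welford's algorithm)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean

    def update(self, x):
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)

    def variance(self):
        # Population variance; use (count - 1) for the sample variance.
        return self.m2 / self.count if self.count else 0.0

w = Welford()
for x in [1.0, 2.0, 3.0, 4.0]:
    w.update(x)
print(w.mean)        # → 2.5
print(w.variance())  # → 1.25
```

Single-pass updates like this avoid the catastrophic cancellation of the naive `E[x^2] - E[x]^2` formula, which is why it suits streaming training metrics.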

Full Changelog: v0.8.4...v0.8.5

v0.8.4

24 May 17:09

What's Changed

Full Changelog: v0.8.3...v0.8.4

v0.8.3

30 Apr 09:56

What's Changed

  • Add git fetch upstream to contributing doc. by @carlosgmartin in #3757
  • removed getattr/setattr unboxing magic from nnx.Pytree by @chiamp in #3743
  • added Einsum layer to NNX by @chiamp in #3741
  • Make TrainState's step possibly jax.Array. This makes replicate valid for type checking. by @copybara-service in #3763
  • v0.8.3 by @cgarciae in #3758
  • [nnx] fix demo notebook by @cgarciae in #3744
  • added nnx api reference by @chiamp in #3762
  • updated rng docstring for init, apply and make_rng by @chiamp in #3765
  • use note box in make_rng docstring by @cgarciae in #3767
  • [nnx] improved graph update mechanism by @cgarciae in #3759
  • use note box in docstrings by @chiamp in #3769
  • Add reset_gate flag to MGUCell. by @carlosgmartin in #3760
  • Access thread_resources via jax.interpreters.pxla instead of jax.experimental.maps by @copybara-service in #3775
  • Minor doc improvements by @canyon289 in #3588
  • added MGU reset_gate test by @chiamp in #3773
  • [nnx] Pytrees are Trees by @cgarciae in #3768
  • Use short-circuiting access to debug_key_reuse by @copybara-service in #3781
  • fix tabulate on norm wrappers by @chiamp in #3772
  • Add kw_only struct.dataclass test by @chiamp in #3651
  • extended PyTreeNode to take dataclass kwargs by @chiamp in #3785
  • [nnx] Arrays are state by @cgarciae in #3791
  • [nnx] add GraphNode base class by @cgarciae in #3790
  • [nnx] jit accepts many Modules by @cgarciae in #3783
  • Exposing the experimental _split_transpose JAX scan parameter in Flax. by @copybara-service in #3795
  • Expose nnx.GraphNode by @chiamp in #3796
  • [nnx] Rngs and RngStream inherit from GraphNode by @cgarciae in #3793
  • [nnx] TrainState uses struct by @cgarciae in #3788
  • [nnx] split returns graphdef first by @cgarciae in #3794
  • Remove the uninitialized field "embedding" in nn.Embed by @copybara-service in #3801
  • Add nnx.training by @chiamp in #3782
  • [nnx] non-str State keys by @cgarciae in #3802
  • [nnx] allow all jit kwargs in nnx.jit by @cgarciae in #3809
  • [nnx] simplify readme by @cgarciae in #3805
  • [nnx] Fix nnx basics by @cgarciae in #3812
  • [nnx] grad accepts argnums by @cgarciae in #3798
  • [nnx] improve toy examples by @cgarciae in #3813
  • [nnx] expose Sequential by @cgarciae in #3814
  • [nnx] Rng Variable tags by @cgarciae in #3807
  • [nnx] remove copy in graph unflatten by @cgarciae in #3804
  • fixed optax guide links and docstring typos by @chiamp in #3789
  • added dropout broadcast test by @chiamp in #3776
  • relaxed grads kwarg for Optimizer.update by @chiamp in #3818
  • added tree_map deprecation warning filter by @chiamp in #3828
  • updated tree_map by @chiamp in #3823
  • added NNX vs JAX transformations guide by @chiamp in #3819
  • Updated NNX MNIST tutorial by @chiamp in #3810
  • [nnx] add Dropout.rngs by @cgarciae in #3815
  • removed autosummary from linen docs by @chiamp in #3792
  • Fix cloudpickle sentinel cloning by @cgarciae in #3825
  • [nnx] remove pytreelib by @cgarciae in #3816
  • [nnx] fix nnx_basics by @cgarciae in #3839
  • [linen] fix DenseGeneral init by @cgarciae in #3834
  • [nnx] jit constrain object state by @cgarciae in #3817
  • Copybara import of the project: by @copybara-service in #3857
  • Add example of unbox() and replace_boxed() to the jit guide by @IvyZX in #3843
  • RNNCellBase refactor FLIP by @cgarciae in #3099
  • [nnx] Some small documentation suggestions. by @gnecula in #3861
  • updated nnx dropout by @chiamp in #3841
  • Fix LogicalRules type annotation. (Tuple[str] is a tuple with single element string, by @copybara-service in #3877
  • Add option to skip float32 promotion when computing means and variances for normalization. by @copybara-service in #3873
  • added nnx api reference link by @chiamp in #3871
  • Add option to force the input of softmax to be fp32 for better numerical stability in mixed-precision training. by @copybara-service in #3874
  • allow custom dot_general for einsum. by @copybara-service in #3884
  • [NVIDIA] Extend the custom fp8 accumulate dtype in non-jit scenarios by @kaixih in #3827
  • updated robots.txt by @chiamp in #3886
  • fixed autosummary links by @chiamp in #3887
  • Fix jax.tree_util.register_dataclass in older JAX versions. by @copybara-service in #3885
  • [nnx] v0.1 by @cgarciae in #3876

Full Changelog: v0.8.2...v0.8.3

v0.8.2

14 Mar 11:34

What's Changed

Full Changelog: v0.8.1...v0.8.2

v0.8.1

07 Feb 21:52

What's Changed

Full Changelog: v0.8.0...v0.8.1

v0.8.0

23 Jan 23:16

What's Changed

v0.7.5

28 Oct 02:07

What's Changed

Full Changelog: v0.7.4...v0.7.5

v0.7.4

13 Sep 18:37

What's Changed

Added Python version constraint >=3.9.

  • Bump version no by 1 post-release. by @copybara-service in #3328
  • Update minimum python version info. by @levskaya in #3331
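
A minimum-version constraint like the one above is typically declared in package metadata. As a hedged sketch, this is how such a constraint looks in a PEP 621-style `pyproject.toml`; the actual Flax build configuration may declare it differently (e.g. via `python_requires` in `setup.py`):

```toml
[project]
name = "flax"
requires-python = ">=3.9"
```

Installers such as pip read this field and refuse to install the package on older interpreters.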

Full Changelog: v0.7.3...v0.7.4

v0.7.3

13 Sep 01:12

What's Changed

Full Changelog: v0.7.2...v0.7.3