Releases: google/flax
v0.9.0
What's Changed
- Add NNX surgery guide by @IvyZX in #4005
- Port gemma/transformer to NNX by @copybara-service in #4019
- upgrade python to 3.10 + use pyupgrade by @cgarciae in #4038
- [nnx] add Using Filters guide by @cgarciae in #4028
- v0.8.6 by @cgarciae in #4040
- allow imagenet training profiling to be disabled in config by @copybara-service in #4043
- [nnx] LoRAParam inherits from Param by @cgarciae in #3988
- [linen] allows multiple compact methods by @cgarciae in #3808
- Added support of NANOO fp8. by @wenchenvincent in #3993
- Add functools.wraps() annotation to flax.nn.jit. by @copybara-service in #4051
- Fix typo in `nnx_basics` doc by @rajasekharporeddy in #4047
- [nnx] fix Variable overloads and add shape/dtype properties by @cgarciae in #4049
- Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. by @copybara-service in #4039
- [nnx] stabilize unsafe_pytree by @cgarciae in #4030
- Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. by @copybara-service in #4055
- [NVIDIA] Rename fp8 custom dtype to `fp32_max_grad` by @kaixih in #3984
- [nnx] fix mnist_tutorial colab link by @cgarciae in #4063
- [nnx] fix Accuracy on eager mode by @cgarciae in #4065
- Update orbax_upgrade_guide.rst for async checkpointing usage examples by @kaushaladiti-2802 in #4036
- Re-enable some tests after Python 3.9 is dropped by @IvyZX in #4067
- Rename `nnx.compat` to `nnx.bridge` by @IvyZX in #4066
- [nnx] improve mnist tutorial by @cgarciae in #4070
- Modify Flax checkpointing in preparation for cl/650338576. by @copybara-service in #4072
- Remove some outdated backward-compatibility code. by @copybara-service in #4068
- [NVIDIA] Add a user guide for fp8 by @kaixih in #4076
- [nnx] add extract APIs by @cgarciae in #4078
- [example]: remove lm1b useless parallelism rules by @knightXun in #4077
- [nnx] improve filters guide by @cgarciae in #4059
- [nnx] add call by @cgarciae in #4004
- Ignore Orbax warning in deprecated `flax.training.checkpoints.py` to unbreak head doctest by @IvyZX in #4092
- fix mypy failures due to numpy update by @cgarciae in #4098
- [linen] generalize transform caching by @copybara-service in #4057
- [linen] fold rngs on jit to improve caching by @copybara-service in #4064
- Add shape-based lazy init to `LinenToNNX` (prev `LinenWrapper`) by @IvyZX in #4081
- [nnx] add reseed by @cgarciae in #4099
- [nnx] add split/merge_inputs by @cgarciae in #4084
- Perform shape checks for self.param AFTER unboxing by @danielwatson6 in #4079
- fix restore_checkpoint example in docstring by @copybara-service in #4101
- [numpy] Fix users of NumPy APIs that are removed in NumPy 2.0. by @copybara-service in #4104
- set profile_duration_ms = None. In periodic_actions both num_profile_steps and profile_duration_ms have default values, and profiling stops only when both are satisfied, so setting profile_duration_ms=None lets the passed num_profile_steps value take effect by @copybara-service in #4096
- [linen] add share_scope by @cgarciae in #4102
- Allow metadata pass-through in flax.struct.field by @cool-RR in #4056
- avoid mixing `einsum_dot_general` and `einsum` argument by specifying them explicitly in the caller. by @copybara-service in #4115
- Add logging to track deprecated codepaths. by @copybara-service in #4121
- [pmap no rank reduce cleanup]: When flipping the by @copybara-service in #4125
- Add NNXToLinen wrapper to `nnx.bridge` by @IvyZX in #4126
- Switch NNX to use Treescope instead of Penzai. by @copybara-service in #4132
- Add GroupNorm to NNX normalization layers by @treigerm in #4095
- [nnx] fix initializing propagation by @cgarciae in #4134
- add JAX-style NNX Transforms FLIP by @cgarciae in #4108
- Fix `_ParentType` annotation by @dcharatan in #4120
- add uv.lock file by @copybara-service in #4139
- use uv package manager by @cgarciae in #4136
- More testing and misc fixes on wrappers by @IvyZX in #4137
- Fix link to orbax documentation by @cool-RR in #4123
- [nnx] experimental transforms by @cgarciae in #3963
- [nnx] improve docs by @cgarciae in #4141
- remove repeated license headers by @cgarciae in #4148
- update Flax to version 0.9.0 by @copybara-service in #4147
New Contributors
- @wenchenvincent made their first contribution in #3993
- @rajasekharporeddy made their first contribution in #4047
- @kaushaladiti-2802 made their first contribution in #4036
- @knightXun made their first contribution in #4077
- @danielwatson6 made their first contribution in #4079
- @cool-RR made their first contribution in #4056
- @treigerm made their first contribution in #4095
- @dcharatan made their first contribution in #4120
Full Changelog: v0.8.5...v0.9.0
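The profiling entry in #4096 above hinges on a stopping rule that requires both limits to be met. A minimal sketch of that rule in plain Python follows. The helper `should_stop_profiling` and its arguments are hypothetical, written only from the wording of that entry, and are not the `periodic_actions` implementation.

```python
from typing import Optional


def should_stop_profiling(
    steps_done: int,
    elapsed_ms: float,
    num_profile_steps: Optional[int],
    profile_duration_ms: Optional[float],
) -> bool:
    """Hypothetical helper: stop only once every configured limit is met."""
    # Assumption: a limit set to None counts as "not configured",
    # so it is always treated as satisfied.
    steps_ok = num_profile_steps is None or steps_done >= num_profile_steps
    time_ok = profile_duration_ms is None or elapsed_ms >= profile_duration_ms
    return steps_ok and time_ok


# With both limits set, reaching the step count alone is not enough.
print(should_stop_profiling(5, 100.0, 5, 3000.0))  # False
# With profile_duration_ms=None, the step count alone decides.
print(should_stop_profiling(5, 100.0, 5, None))  # True
```

Under this reading, leaving a default duration in place can keep profiling running past the requested number of steps, which is the behavior the entry describes avoiding.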
v0.8.5
What's Changed
- v0.8.5 by @cgarciae in #3941
- [nnx] improve vmap axis size detection by @cgarciae in #3947
- Add direct penzai.treescope support for NNX objects. by @copybara-service in #3948
- [nnx] fix nnx_basics dependencies by @cgarciae in #3942
- Rename all the NNX tests to internal naming & build conventions. by @copybara-service in #3952
- updated rng guide by @chiamp in #3912
- upgraded haiku guide to include NNX by @chiamp in #3923
- parameterized NNX transforms tests by @chiamp in #3906
- Simplify extended dtypes rules part 1. Start by removing sharding specific rules from EDtypes. This is because we always want to replicate the trailing dims introduced by Edtypes. by @copybara-service in #3957
- fix HEAD by @chiamp in #3960
- Minor grammar fixes to NNX documentation. by @mcsmart76 in #3953
- Make FlatState a Mapping instead of a dict by @NeilGirdhar in #3928
- Adding Welford metric. by @copybara-service in #3959
- Modify Welford metric to return mean value. by @copybara-service in #3970
- [nnx] make State generic by @cgarciae in #3964
- updated NNX nn docstrings by @chiamp in #3972
- make flax work with upcoming JAX change to tree_map (being more careful about by @copybara-service in #3976
- updated `nnx.module` docstrings by @chiamp in #3966
- updated `nnx.Conv` and `nnx.ConvTranspose` by @chiamp in #3974
- updated `nnx.graph` docstrings by @chiamp in #3958
- Adds `pmap` and `Pmap`. `static_broadcasted_argnums`, `donate_argnums`, and `global_arg_shapes` are not yet supported. by @copybara-service in #3978
- Fixes for batch norm docs by @jkarwowski in #3982
- fix deprecation warning by @chiamp in #3981
- updated NNX `rnglib` docstring by @chiamp in #3980
- updated `nnx.training` by @chiamp in #3975
- updated `nnx.variables` docstrings by @chiamp in #3986
- [nnx] vectorize vmap split counts by @cgarciae in #3989
- added `wrt` option to `nnx.Optimizer` by @chiamp in #3983
- Added `nnx.graph.iter_children` by @chiamp in #3991
- [nnx] fix vmap by @copybara-service in #3995
- Fix head pytest breakage by @IvyZX in #4006
- Helper function for loading params from a linen module by @copybara-service in #4012
- Port gemma/layers to NNX by @copybara-service in #4013
- [nnx] fix grad by @cgarciae in #4007
- [nnx] add PathContains Filter by @cgarciae in #4011
- Support Python 3.9 by @copybara-service in #4018
- Port gemma/modules to NNX by @copybara-service in #4014
- Internal change to fix current head CI by @copybara-service in #4017
- Unpin the Orbax pip version. by @copybara-service in #4024
- Fix Gemma test to unbreak head by @IvyZX in #4025
- Fix pickling of exceptions by @sanderland in #4002
- Call user-defined variable transforms before determining axis size in nn.vmap. by @copybara-service in #4026
- CI: add test run against oldest supported jax version by @jakevdp in #3996
- Make `force_fp32_for_softmax` arg in `MultiHeadDotProductAttention` useful. by @copybara-service in #4029
New Contributors
- @mcsmart76 made their first contribution in #3953
- @jkarwowski made their first contribution in #3982
- @sanderland made their first contribution in #4002
Full Changelog: v0.8.4...v0.8.5
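One entry above makes FlatState a Mapping instead of a dict. The sketch below illustrates the general difference using the standard `collections.abc.Mapping` base class. The class `ReadOnlyFlatState` and its path-tuple keys are hypothetical and are not Flax's FlatState; the point is only that a Mapping offers lookup and iteration without item assignment.

```python
from collections.abc import Mapping


class ReadOnlyFlatState(Mapping):
    """Hypothetical stand-in: a read-only mapping keyed by path tuples."""

    def __init__(self, items):
        # Copy into a private dict so callers cannot mutate it through us.
        self._items = dict(items)

    def __getitem__(self, key):
        return self._items[key]

    def __iter__(self):
        return iter(self._items)

    def __len__(self):
        return len(self._items)


state = ReadOnlyFlatState({("dense", "kernel"): 1.0, ("dense", "bias"): 0.0})
print(state[("dense", "kernel")])  # 1.0
print(isinstance(state, Mapping))  # True
print(isinstance(state, dict))  # False
print(hasattr(state, "__setitem__"))  # False
```

Code that only reads from such an object keeps working, while code that relied on `isinstance(x, dict)` or on in-place assignment would need to change.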
v0.8.4
What's Changed
- fixed codecov by @chiamp in #3895
- Make FlatState a Mapping instead of a dict by @NeilGirdhar in #3880
- Share nnx node registry between threads by @NeilGirdhar in #3901
- fixed `jnp.clip` deprecation by @chiamp in #3905
- Added three tab option to sphinx directive `codediff` and added testing for first tab by @chiamp in #3847
- Add support for `jax.sharding.PartitionSpec.UNCONSTRAINED` in logical specification by @copybara-service in #3902
- [nnx] fix mypy and pytype by @cgarciae in #3894
- [nnx] fix iter_nodes by @cgarciae in #3889
- [nnx] Sequential uses regular list by @cgarciae in #3909
- [nnx] add ConvTranspose by @cgarciae in #3908
- [nnx] add Module pytree_experimental static test by @cgarciae in #3864
- Added docstring for `Module.scope.path` by @chiamp in #3913
- [linen] test jit caching with state updates by @cgarciae in #3900
- v0.8.4 by @cgarciae in #3891
- [linen] enable separate initializers for out layer in MultiHeadDotProductAttention by @cgarciae in #3835
- [nnx] cleanup graph by @cgarciae in #3915
- [nnx] fix bugs by @cgarciae in #3925
- Replace deprecated `jax.tree_*` functions with `jax.tree.*` by @copybara-service in #3926
- [nnx] Object refactor by @cgarciae in #3910
- [nnx] add iter_graph by @cgarciae in #3919
- [nnx] add compat by @cgarciae in #3921
- [nnx] transforms refactor by @cgarciae in #3927
- added equivalence test for `nnx.ConvTranspose` by @chiamp in #3934
- added equivalence test for `nnx.Sequential` by @chiamp in #3935
- [NNX] Add `LoRA` and `LoRALinear` to NNX by @IvyZX in #3929
- [nnx] fix substate mutability by @cgarciae in #3932
- [nnx] improve update context by @cgarciae in #3933
- [nnx] move out of experimental by @cgarciae in #3936
Full Changelog: v0.8.3...v0.8.4
v0.8.3
What's Changed
- Add git fetch upstream to contributing doc. by @carlosgmartin in #3757
- removed getattr/setattr unboxing magic from `nnx.Pytree` by @chiamp in #3743
- added Einsum layer to NNX by @chiamp in #3741
- Make `TrainState`'s `step` possibly jax.Array. This makes `replicate` valid for type checking. by @copybara-service in #3763
- v0.8.3 by @cgarciae in #3758
- [nnx] fix demo notebook by @cgarciae in #3744
- added nnx api reference by @chiamp in #3762
- updated rng docstring for init, apply and make_rng by @chiamp in #3765
- use note box in make_rng docstring by @cgarciae in #3767
- [nnx] improved graph update mechanism by @cgarciae in #3759
- use note box in docstrings by @chiamp in #3769
- Add reset_gate flag to MGUCell. by @carlosgmartin in #3760
- Access thread_resources via jax.interpreters.pxla instead of jax.experimental.maps by @copybara-service in #3775
- Minor doc improvements by @canyon289 in #3588
- added MGU `reset_gate` test by @chiamp in #3773
- [nnx] Pytrees are Trees by @cgarciae in #3768
- Use short-circuiting access to debug_key_reuse by @copybara-service in #3781
- fix tabulate on norm wrappers by @chiamp in #3772
- Add `kw_only` struct.dataclass test by @chiamp in #3651
- extended `PyTreeNode` to take dataclass kwargs by @chiamp in #3785
- [nnx] Arrays are state by @cgarciae in #3791
- [nnx] add GraphNode base class by @cgarciae in #3790
- [nnx] jit accepts many Modules by @cgarciae in #3783
- Exposing the experimental _split_transpose JAX scan parameter in Flax. by @copybara-service in #3795
- Expose `nnx.GraphNode` by @chiamp in #3796
- [nnx] Rngs and RngStream inherit from GraphNode by @cgarciae in #3793
- [nnx] TrainState uses struct by @cgarciae in #3788
- [nnx] split returns graphdef first by @cgarciae in #3794
- Remove the uninitialized field "embedding" in nn.Embed by @copybara-service in #3801
- Add `nnx.training` by @chiamp in #3782
- [nnx] non-str State keys by @cgarciae in #3802
- [nnx] allow all jit kwargs in nnx.jit by @cgarciae in #3809
- [nnx] simplify readme by @cgarciae in #3805
- [nnx] Fix nnx basics by @cgarciae in #3812
- [nnx] grad accepts argnums by @cgarciae in #3798
- [nnx] improve toy examples by @cgarciae in #3813
- [nnx] expose Sequential by @cgarciae in #3814
- [nnx] Rng Variable tags by @cgarciae in #3807
- [nnx] remove copy in graph unflatten by @cgarciae in #3804
- fixed optax guide links and docstring typos by @chiamp in #3789
- added dropout broadcast test by @chiamp in #3776
- relaxed `grads` kwarg for `Optimizer.update` by @chiamp in #3818
- added `tree_map` deprecation warning filter by @chiamp in #3828
- updated `tree_map` by @chiamp in #3823
- added NNX vs JAX transformations guide by @chiamp in #3819
- Updated NNX MNIST tutorial by @chiamp in #3810
- [nnx] add Dropout.rngs by @cgarciae in #3815
- removed autosummary from linen docs by @chiamp in #3792
- Fix cloudpickle sentinel cloning by @cgarciae in #3825
- [nnx] remove pytreelib by @cgarciae in #3816
- [nnx] fix nnx_basics by @cgarciae in #3839
- [linen] fix DenseGeneral init by @cgarciae in #3834
- [nnx] jit constrain object state by @cgarciae in #3817
- Copybara import of the project: by @copybara-service in #3857
- Add example of unbox() and replace_boxed() to the jit guide by @IvyZX in #3843
- RNNCellBase refactor FLIP by @cgarciae in #3099
- [nnx] Some small documentation suggestions. by @gnecula in #3861
- updated nnx dropout by @chiamp in #3841
- Fix LogicalRules type annotation. (Tuple[str] is a tuple with single element string, by @copybara-service in #3877
- Add option to skip float32 promotion when computing means and variances for normalization. by @copybara-service in #3873
- added nnx api reference link by @chiamp in #3871
- option of forcing the input of softmax to be fp32 for better numerical stability in mixed-precision training. by @copybara-service in #3874
- allow custom dot_general for einsum. by @copybara-service in #3884
- [NVIDIA] Extend the custom fp8 accumulate dtype in non-jit scenarios by @kaixih in #3827
- updated `robots.txt` by @chiamp in #3886
- fixed autosummary links by @chiamp in #3887
- Fix jax.tree_util.register_dataclass in older JAX versions. by @copybara-service in #3885
- [nnx] v0.1 by @cgarciae in #3876
Full Changelog: v0.8.2...v0.8.3
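The LogicalRules annotation fix listed above turns on a common typing pitfall: `Tuple[str]` describes a tuple holding exactly one string, not a tuple of any number of strings. A short illustration with the standard `typing` module follows. The alias names `OneName` and `ManyNames` are hypothetical and are not taken from Flax.

```python
from typing import Tuple, get_args

# Exactly one element, which must be a str.
OneName = Tuple[str]
# Any number of elements, each a str (note the Ellipsis).
ManyNames = Tuple[str, ...]

print(get_args(OneName))  # (<class 'str'>,)
print(get_args(ManyNames))  # (<class 'str'>, Ellipsis)

# At runtime both values below are ordinary tuples; only a static type
# checker distinguishes the two annotations.
single: OneName = ("batch",)
several: ManyNames = ("batch", "embed", "mlp")
print(len(single), len(several))  # 1 3
```

A type checker would reject `("batch", "embed")` for `OneName` but accept it for `ManyNames`, which is why the variadic form is the usual choice for variable-length tuples.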
v0.8.2
What's Changed
- Add +1 to version after 0.8.1 release by @IvyZX in #3684
- fixed rng guide outputs by @chiamp in #3685
- enforce mask kwarg in norm layers by @chiamp in #3663
- added kwargs to self.param and self.variable by @chiamp in #3675
- added nnx normalization tests by @chiamp in #3689
- added NNX init_cache docstring example by @chiamp in #3688
- added nnx attention equivalence test by @chiamp in #3687
- Fix bug that assumed frozen-dict keys were strings. by @copybara-service in #3692
- added nnx rmsnorm by @chiamp in #3691
- updated nnx compute_stats by @chiamp in #3693
- fixed intercept_methods docstring by @chiamp in #3694
- [nnx] Add Sphinx Docs by @cgarciae in #3678
- Fix pointless docstring example of nn.checkpoint / nn.remat. by @levskaya in #3703
- added default params rng to .apply by @chiamp in #3698
- [nnx] add partial_init by @cgarciae in #3674
- make make_rng default to 'params' by @chiamp in #3699
- Add SimpleCell. by @carlosgmartin in #3697
- fix Module.module_paths docstring by @cgarciae in #3709
- Guarantee the latest JAX version on CI by @cgarciae in #3705
- Replace deprecated API `jax.tree_map` by @copybara-service in #3715
- Use `jax.tree_util.tree_map` instead of deprecated `jax.tree_map`. by @copybara-service in #3714
- [nnx] simplify readme by @cgarciae in #3707
- [nnx] add demo.ipynb by @cgarciae in #3680
- Fix Tabulate's compute_flops by @cgarciae in #3721
- [nnx] simplify TraceState by @cgarciae in #3724
- Add broadcast of `strides` and `kernel_dilation` to `nn.ConvTranspose` by @IvyZX in #3731
- [nnx] Fix State.sub by @cgarciae in #3704
- [nnx] always fold_in on fork + new ForkedKeys return type by @cgarciae in #3722
- [nnx] explicit Variables by @cgarciae in #3720
- Improves fingerprint definition for Modules in nn.jit. by @copybara-service in #3736
- Flax: avoid key reuse in tests by @copybara-service in #3740
- added Einsum layer by @chiamp in #3710
- nn.jit: automatic fingerprint definition for dataclass attributes by @cgarciae in #3737
- [NVIDIA] Use custom grad accumulation for FP8 params by @kaixih in #3623
- removed nnx dataclass by @chiamp in #3742
- [nnx] cleanup graph_utils by @cgarciae in #3728
- Fix doctest and unbreak head by @IvyZX in #3753
- [nnx] add pytree support by @cgarciae in #3732
- fixed intercept_methods docstring by @chiamp in #3752
- Add ConvLSTMCell to docs. by @carlosgmartin in #3712
- [nnx] remove flagslib by @cgarciae in #3733
- Fix tests after applying JAX key-reuse checker. See: by @copybara-service in #3748
Full Changelog: v0.8.1...v0.8.2
Version 0.8.1
What's Changed
- bump version number to 0.8.1 by @chiamp in #3649
- Bump pillow from 10.0.1 to 10.2.0 in /examples/vae by @dependabot in #3641
- fixed docstring by @chiamp in #3643
- Add explicit control over frozen/slots setting in flax.struct.dataclass by @copybara-service in #3645
- make `Sequential.__call__` compact by @copybara-service in #3647
- add Module.module_paths by @cgarciae in #3654
- added rng_guide by @chiamp in #3497
- Replacing jax.tree_util.tree_map with mapping over leafs. by @copybara-service in #3658
- Copybara import of the project: by @copybara-service in #3659
- added InstanceNorm by @chiamp in #3652
- add Module.module_paths by @copybara-service in #3660
- added norm equivalence tests by @chiamp in #3662
- updated nowrap docstring by @chiamp in #3661
- Add module_paths method to docs by @cgarciae in #3657
- add default make_rng by @chiamp in #3669
- renamed channel_axes to feature_axes in InstanceNorm by @chiamp in #3667
- added flax.typing by @chiamp in #3624
- changed kwargs to actual key-word args by @chiamp in #3562
- updated docs and docstrings by @chiamp in #3670
- re-added linen_intro by @chiamp in #3672
- add compact_name_scope v3 by @cgarciae in #3646
- Release 0.8.1 by @IvyZX in #3682
Full Changelog: v0.8.0...v0.8.1
v0.8.0
What's Changed
- bump version number by @levskaya in #3446
- Add merge / finalize step when using OCDBT driver. Files will be first written to per-process subdirectories, which are later copied by reference to the main directory before the checkpoint is finalized. by @copybara-service in #3426
- fixed quickstart by @chiamp in #3451
- [NVIDIA] Update the algorithm to compute fp8 scales by @kaixih in #3441
- added pre-commit hook that sort imports and formats by @chiamp in #3455
- restructured doc folders by @chiamp in #3434
- Forked a subset of JAX configuration APIs by @superbobry in #3448
- Fix Module.clone in deepclone mode for internal usage. by @levskaya in #3459
- Add user-friendly module copy method. by @levskaya in #3461
- Add simple argument-only lifted nn.grad function. by @levskaya in #3463
- exempt a jax.config deprecation warning by @levskaya in #3465
- Clean up pyproject.toml. by @levskaya in #3468
- Allow for fast accumulation selection for FP8 GEMM by @wenscarl in #3416
- re-added quickstart guide by @chiamp in #3471
- fixed tabulate docstring by @chiamp in #3452
- Add NNX by @cgarciae in #3218
- Bump pillow from 9.5.0 to 10.0.1 in /examples/vae by @dependabot in #3390
- updated attention_test by @chiamp in #3454
- [nnx] Improve docs by @cgarciae in #3478
- added example docstrings by @chiamp in #3453
- fix nn.value_and_grad by implementing directly in core by @levskaya in #3479
- Add dataset loading guide (Issue #2116) by @VictorPrins in #3450
- [nnx] Add support for python container types by @cgarciae in #3486
- remove SelfAttention test and warning filter by @chiamp in #3470
- disabled ruff formatter by @chiamp in #3482
- adding doctest to .rst files by @chiamp in #3481
- changed pip installs to use quotes by @chiamp in #3477
- added enum support for tabulate by @chiamp in #3485
- fix bug in optimizer-api.md by @zhaoyang-0204 in #3462
- removed selfattention from doctest by @chiamp in #3489
- [nnx] Add missing import on why.ipynb by @cgarciae in #3503
- [nnx] switch to nested State representation by @cgarciae in #3502
- Improved Rigor of `PReLU` Test by @Micky774 in #3498
- added geglu activation and tests by @HMUNACHI in #3512
- [nnx] Add LinearGeneral and MultiHeadAttention by @cgarciae in #3487
- Add NNX/Linen consistency test for `Embed` layer by @Micky774 in #3513
- Add NNX/Linen API consistency test for `Conv` layer by @Micky774 in #3511
- Prevent crash in dataclasses with no-init params by @NeilGirdhar in #3514
- [nnx] Variable reference sharing by @cgarciae in #3516
- Added NNX/Linen API consistency test for `Linear/Dense` layer by @Micky774 in #3509
- Add missing mask argument to LayerNorm, RMSNorm, and GroupNorm. by @carlosgmartin in #3510
- [nnx] Fix graph_utils bug by @cgarciae in #3518
- remove deprecated normalize function by @chiamp in #3531
- Reduced number of parameterizations for `Conv` NNX/Linen consistency test by @Micky774 in #3526
- Ensure that `_hashable_filter` does not convert strings to a tuple of letters by @copybara-service in #3533
- added sow attention weights by @chiamp in #3529
- Fix scan out_axes by @cgarciae in #3540
- updated embed docstring by @chiamp in #3539
- add test_scan_negative_axes by @cgarciae in #3542
- add module methods to api docs by @chiamp in #3544
- fixed double backquote code font by @chiamp in #3545
- add nnx conv support for int kernel size by @chiamp in #3537
- added sow attention weights to NNX by @chiamp in #3548
- changed `return_weights` to `sow_weights` for attention layer by @chiamp in #3550
- format linen_linear_test.py by @chiamp in #3553
- re-factored features arg by @chiamp in #3554
- updated NNX readme by @chiamp in #3556
- Disable ruff sort imports by @cgarciae in #3560
- Add StateVariablesMapping by @cgarciae in #3523
- add kwargs support for nn.jit by @copybara-service in #3559
- [nnx] Fix readme install instruction by @cgarciae in #3565
- implement `Rng.__getattr__` by @cgarciae in #3547
- [nnx] add qkv_features back to MHA by @cgarciae in #3566
- updated readme by @chiamp in #3563
- fixed typo by @chiamp in #3561
- Raise an error for a bad key type by @NeilGirdhar in #3527
- re-factored nnx initializers by @chiamp in #3555
- [nnx] Add complex test with scan + batchnorm + dropout by @cgarciae in #3567
- [nnx] Add interacting with JAX section to README by @cgarciae in #3573
- expose ones and zeros initializers by @chiamp in #3574
- Fix promotion bug in MultiHeadDotProductAttention: by @giovannic in #3571
- fixed error doc formatting by @chiamp in #3587
- [nnx] Improve spmd by @cgarciae in #3580
- [nnx] improve graph_utils._set_key_tuple by @cgarciae in #3592
- [nnx] Fix variable unflatten by @cgarciae in #3578
- [nnx] add open in colab button to why nnx by @cgarciae in #3596
- [nnx] Export missing symbols by @cgarciae in #3583
- [nnx] flaglib add get overloads by @cgarciae in #3582
- Fix type in NNX readme by @shoyer in #3591
- [nnx] add submodule iterator by @cgarciae in #3581
- [nnx] delete flaglib duplicated copyright comment by @cgarciae in #3600
- fixed NNX decode and dynamic slicing by @chiamp in #3576
- [nnx] cleanup CallableProxy by @cgarciae in #3608
- [nnx] improve runtime flags by @cgarciae in #3607
- fixed broken links on quick start guide by @chiamp in #3610
- added multiheadattention alias by @chiamp in #3572
- Rollback of Copybara import of the project: by @copybara-service in #3612
- add missing docs for module functions by @cgarciae in #3619
- fix lm1b data sharding by @cgarciae in #3620
- improve embed by @jianyizh in #3590
- disable ruff linter by @chiamp in #3625
- Add compact_name_scope decorator by @cgarciae in #3621
- Copybara import of the project: by @copybara-service in #3638
- added BatchApply by @chiamp in #3634
- add compact_name_scope v2 by @copybara-service in #3640
- add compact_name_scope v2 by @copybara-service in #3642
- release 0.8.0 by @chiamp in #3644
New Contributors
- @superbobry made their first contribution in #3448
- @VictorPrins made their first contribution in #3450
- @zhaoyang-0204 made their first contribution in #3462
- @Micky774 made their first contribution in #3498
- @HMUNACHI made their first contribution in #3512
- @carlosgmartin made their first contribution in #3510
- @giovannic made their first contribution in https://gith...
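The `_hashable_filter` entry in this release refers to a general Python pitfall: a string is itself iterable, so passing it to `tuple()` splits it into characters. The sketch below shows the pitfall and one way to guard against it. The helper `as_hashable` is hypothetical and is not Flax's `_hashable_filter`.

```python
from collections.abc import Iterable


def as_hashable(value):
    """Hypothetical helper: make a filter value hashable without splitting strings."""
    # Strings are iterable, so they must be checked before the generic case.
    if isinstance(value, str):
        return value
    if isinstance(value, Iterable):
        return tuple(value)
    return value


# The pitfall: tuple() iterates over the characters of a string.
print(tuple("params"))  # ('p', 'a', 'r', 'a', 'm', 's')

# The guarded helper keeps strings whole and converts other iterables.
print(as_hashable("params"))  # params
print(as_hashable(["params", "batch_stats"]))  # ('params', 'batch_stats')
print(as_hashable(True))  # True
```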
v0.7.5
What's Changed
- Add method-to-model section to Haiku migration guide by @IvyZX in #3277
- updated haiku guide with new JAX RNG api by @chiamp in #3343
- changed resnet v1 to v1.5 by @chiamp in #3344
- updated flax_basics by @chiamp in #3342
- Add `find` methods and magic methods for Cursor API by @chiamp in #3306
- fix DeprecationWarnings by @chiamp in #3352
- Make checkpoint guide path absolute by @IvyZX in #3358
- use jax.Array type for rng keys by @chiamp in #3354
- Disable pyink by @cgarciae in #3356
- removed DeprecationWarning filter by @chiamp in #3359
- Don't propagate default args in Tabulate by @cgarciae in #3357
- added spectral norm by @chiamp in #3335
- fixed bind-unbind bug by @chiamp in #3365
- Add MaxText open source LLM to index.rst by @8bitmp3 in #3368
- fixed typo by @chiamp in #3367
- Add fp8 custom op and unit test by @wenscarl in #3322
- fixed scope typing by @chiamp in #3371
- Trailing whitespace fixes. by @levskaya in #3373
- Make fp8 ops use explicit broadcasting. by @levskaya in #3374
- Conditional constraints for numpy and clu by @cgarciae in #3394
- add truncated_normal to initializers by @levskaya in #3401
- fix HEAD by @chiamp in #3404
- updated imagenet readme by @chiamp in #3383
- Add Flax FAQ - how to search, how to take the derivative w.r.t. a hidden layer, remat_scan vs scan(remat), recommended training libraries/metrics by @8bitmp3 in #3267
- added rmsnorm to api docs by @chiamp in #3406
- updated docs by @chiamp in #3370
- fix HEAD by @chiamp in #3408
- added weightnorm layer by @chiamp in #3405
- added attention refactor to changelog by @chiamp in #3412
- added dropout arg to `MultiHeadDotProductAttention` by @chiamp in #3384
- fix spectralnorm layer by @chiamp in #3403
- remove pdf target by @cgarciae in #3415
- added import precommit hook by @chiamp in #3410
- fixed GRU docstring by @chiamp in #3419
- Replaces pjit with jit in spmd.py by @copybara-service in #3421
- Ignore transient chex deprecationwarning. by @levskaya in #3427
- Remove transformers dependency from docs by @cgarciae in #3431
- updated pytorch upgrade guide by @chiamp in #3432
- Simplify abstract rng creation in param shape check. by @levskaya in #3429
- [NVIDIA] Update the FP8 support by @kaixih in #3435
- fixed mypy errors at HEAD by @chiamp in #3440
- added `has_improved` field to EarlyStopping by @chiamp in #3385
- fully deprecated old RNN api by @chiamp in #3425
- updated lm1b example with jit by @chiamp in #3302
- added MGU class by @chiamp in #3418
- Make Flax Basics visible by @8bitmp3 in #3443
- update for release v0.7.5 by @levskaya in #3444
New Contributors
Full Changelog: v0.7.4...v0.7.5
v0.7.4
Version 0.7.3
What's Changed
- Fix Documentation Typo by @peterdavidfagan in #3265
- bump version number by @chiamp in #3273
- added cursor api by @chiamp in #3246
- Improve Typing support by @cgarciae in #3242
- feat: add configs for vae example by @SauravMaheshkar in #3254
- fix stackoverflow when loading pickled module by @cgarciae in #3286
- Remove string from checkpointing example and unbreak doctest in head by @IvyZX in #3288
- Improve kw_only_dataclass by @cgarciae in #3293
- updated haiku upgrade guide by @chiamp in #3292
- added logical partitioning to pjit guide by @chiamp in #3290
- Rolling back ba9e24a by @copybara-service in #3295
- Add RNN FLIP by @cgarciae in #2585
- [flax] Add module path to nn.module. by @copybara-service in #3300
- Minor fix to jax_utils.prefetch_to_device by @anuragarnab in #3308
- add tf-text to doc requirements by @cgarciae in #3317
- Allow apply's method argument to accept submodules by @cgarciae in #3281
- Update handling of typed PRNG keys by @jakevdp in #3314
- 0.7.3 release by @IvyZX in #3326
New Contributors
- @peterdavidfagan made their first contribution in #3265
- @SauravMaheshkar made their first contribution in #3254
Full Changelog: v0.7.2...v0.7.3