Skip to content

Commit

Permalink
Merge pull request #3444 from levskaya:rc7.5
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 577359040
  • Loading branch information
Flax Authors committed Oct 28, 2023
2 parents 9d6e7d8 + 332e723 commit 791e7da
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
35 changes: 22 additions & 13 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,43 @@ Changelog
vNext
------
(Add your change to a random empty line to avoid merge conflicts)
- Report forward and backward pass FLOPs of modules and submodules in `linen.Module.tabulate` and `summary.tabulate` (in new `flops` and `vjp_flops` table columns). Pass `compute_flops=True` and/or `compute_vjp_flops=True` to include these columns.
-
-
-
-
- Re-factored `MultiHeadDotProductAttention`'s call method signature, by adding
`inputs_k` and `inputs_v` args and switching `inputs_kv`, `mask` and `determistic`
to keyword arguments. See more details in [#3389](https://github.com/google/flax/discussions/3389).
-
-
- Added `has_improved` field to EarlyStopping and changed the return signature of
`EarlyStopping.update` from returning a tuple to returning just the updated class.
See more details in [#3385](https://github.com/google/flax/pull/3385)
-
-
- Use new typed PRNG keys throughout flax: this essentially involved changing
uses of `jax.random.PRNGKey` to `jax.random.key`.
(See [JEP 9263](https://github.com/google/jax/pull/17297) for details).
If you notice dispatch performance regressions after this change, be sure
you update `jax` to version 0.4.16 or newer.
-
-
-
-
-
-
-
-
-
-
-
- NOTE: Remember to bump version number to 0.8.0

0.7.3
0.7.5
-----
- Report forward and backward pass FLOPs of modules and submodules in `linen.Module.tabulate` and `summary.tabulate` (in new `flops` and `vjp_flops` table columns). Pass `compute_flops=True` and/or `compute_vjp_flops=True` to include these columns.
- Re-factored `MultiHeadDotProductAttention`'s call method signature, by adding
`inputs_k` and `inputs_v` args and switching `inputs_kv`, `mask` and `determistic`
to keyword arguments. See more details in [#3389](https://github.com/google/flax/discussions/3389).
- Use new typed PRNG keys throughout flax: this essentially involved changing
uses of `jax.random.PRNGKey` to `jax.random.key`.
(See [JEP 9263](https://github.com/google/jax/pull/17297) for details).
If you notice dispatch performance regressions after this change, be sure
you update `jax` to version 0.4.16 or newer.
- Added `has_improved` field to EarlyStopping and changed the return signature of
`EarlyStopping.update` from returning a tuple to returning just the updated class.
See more details in [#3385](https://github.com/google/flax/pull/3385)

0.7.4
-----
New features:
- Add QK-normalization to MultiHeadDotProductAttention
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ To cite this repository:
author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
title = {{F}lax: A neural network library and ecosystem for {JAX}},
url = {http://github.com/google/flax},
version = {0.7.4},
version = {0.7.5},
year = {2023},
}
```
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dependencies = [
"numpy>=1.22",
"numpy>=1.23.2; python_version>='3.11'",
"numpy>=1.26.0; python_version>='3.12'",
"jax>=0.4.11",
"jax>=0.4.19",
"msgpack",
"optax",
"orbax-checkpoint",
Expand Down

0 comments on commit 791e7da

Please sign in to comment.