Skip to content

Commit

Permalink
[Doc] export tutorial, TDM tuto refactoring
Browse files Browse the repository at this point in the history
ghstack-source-id: f6e0b2b6779c63948084cb607f45b64f7555c274
Pull Request resolved: #994
  • Loading branch information
vmoens committed Sep 19, 2024
1 parent 63c9982 commit 4f2a099
Show file tree
Hide file tree
Showing 16 changed files with 589 additions and 943 deletions.
15 changes: 9 additions & 6 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ on:
- v[0-9]+.[0-9]+.[0-9]
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
pull_request:
branches:
- "*"
workflow_dispatch:

concurrency:
Expand Down Expand Up @@ -90,13 +88,13 @@ jobs:
# 10. Build doc
cd ./docs
make docs
sphinx-build ./source _local_build
cd ..
cp -r docs/build/html/* "${RUNNER_ARTIFACT_DIR}"
cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}"
echo $(ls "${RUNNER_ARTIFACT_DIR}")
if [[ ${{ github.event_name == 'pull_request' }} ]]; then
cp -r docs/build/html/* "${RUNNER_DOCS_DIR}"
cp -r docs/_local_build/* "${RUNNER_DOCS_DIR}"
fi
upload:
Expand All @@ -118,7 +116,12 @@ jobs:
REF_NAME=${{ github.ref_name }}
if [[ "${REF_TYPE}" == branch ]]; then
TARGET_FOLDER="${REF_NAME}"
if [[ "${REF_NAME}" == main ]]; then
TARGET_FOLDER="${REF_NAME}"
# Bebug:
# else
# TARGET_FOLDER="release-doc"
fi
elif [[ "${REF_TYPE}" == tag ]]; then
case "${REF_NAME}" in
*-rc*)
Expand Down
4 changes: 2 additions & 2 deletions docs/source/fx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ We'll illustrate with an example from the overview. We create a :obj:`TensorDict
We can check that a forward pass with each module results in the same outputs.

>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> module_out = module(tensordict, tensordict_out=TensorDict({}, []))
>>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict({}, []))
>>> module_out = module(tensordict, tensordict_out=TensorDict())
>>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict())
>>> assert (
... module_out["outputs", "logits"] == graph_module_out["outputs", "logits"]
... ).all()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ tensordict.nn
:maxdepth: 1

tutorials/tensordict_module
tutorials/tensordict_module_functional
tutorials/export

Dataloading
-----------
Expand Down
2 changes: 1 addition & 1 deletion docs/source/saving.rst
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ Here is a full example:
>>> snapshot = torchsnapshot.Snapshot.take(app_state=state, path=path)
>>> # later
>>> snapshot = torchsnapshot.Snapshot(path=path)
>>> tensordict2 = TensorDict({}, [])
>>> tensordict2 = TensorDict()
>>> target_state = {
>>> "state": tensordict2
>>> }
Expand Down
6 changes: 5 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5409,6 +5409,10 @@ def keys(
):
"""Returns a generator of tensordict keys.
.. warning:: TensorDict ``keys()`` method returns a lazy view of the keys. If the ``keys``
are queried but not iterated over and then the tensordict is modified, iterating over
the keys later will return the new configuration of the keys.
Args:
include_nested (bool, optional): if ``True``, nested values will be returned.
Defaults to ``False``.
Expand Down Expand Up @@ -10276,7 +10280,7 @@ def is_tensor_collection(datatype: type | Any) -> bool:
Examples:
>>> is_tensor_collection(TensorDictBase) # True
>>> is_tensor_collection(TensorDict({}, [])) # True
>>> is_tensor_collection(TensorDict()) # True
>>> @tensorclass
... class MyClass:
... pass
Expand Down
Loading

1 comment on commit 4f2a099

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 4f2a099 Previous: 63c9982 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 76346.60716505983 iter/sec (stddev: 0.0000010065129023831037) 252820.60869254422 iter/sec (stddev: 3.700220021873688e-7) 3.31
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 76738.28975625627 iter/sec (stddev: 7.828788107185204e-7) 248184.40403706365 iter/sec (stddev: 5.509418481343613e-7) 3.23
benchmarks/common/memmap_benchmarks_test.py::test_serialize_weights_pickle 1.1945353341693157 iter/sec (stddev: 0.3069413093573667) 2.464808774067019 iter/sec (stddev: 0.049812919910213554) 2.06

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.