Skip to content

Commit

Permalink
[Doc] export tutorial, TDM tuto refactoring
Browse files Browse the repository at this point in the history
ghstack-source-id: b464acffd2fed4c483dcf163ef53a0dc4807bb91
Pull Request resolved: #994
  • Loading branch information
vmoens committed Sep 19, 2024
1 parent 63c9982 commit 6cad75a
Show file tree
Hide file tree
Showing 16 changed files with 564 additions and 942 deletions.
15 changes: 9 additions & 6 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ on:
- v[0-9]+.[0-9]+.[0-9]
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
pull_request:
branches:
- "*"
workflow_dispatch:

concurrency:
Expand Down Expand Up @@ -90,13 +88,13 @@ jobs:
# 10. Build doc
cd ./docs
make docs
sphinx-build ./source _local_build
cd ..
cp -r docs/build/html/* "${RUNNER_ARTIFACT_DIR}"
cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}"
echo $(ls "${RUNNER_ARTIFACT_DIR}")
if [[ ${{ github.event_name == 'pull_request' }} ]]; then
cp -r docs/build/html/* "${RUNNER_DOCS_DIR}"
cp -r docs/_local_build/* "${RUNNER_DOCS_DIR}"
fi
upload:
Expand All @@ -118,7 +116,12 @@ jobs:
REF_NAME=${{ github.ref_name }}
if [[ "${REF_TYPE}" == branch ]]; then
TARGET_FOLDER="${REF_NAME}"
if [[ "${REF_NAME}" == main ]]; then
TARGET_FOLDER="${REF_NAME}"
# Bebug:
# else
# TARGET_FOLDER="release-doc"
fi
elif [[ "${REF_TYPE}" == tag ]]; then
case "${REF_NAME}" in
*-rc*)
Expand Down
4 changes: 2 additions & 2 deletions docs/source/fx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ We'll illustrate with an example from the overview. We create a :obj:`TensorDict
We can check that a forward pass with each module results in the same outputs.

>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> module_out = module(tensordict, tensordict_out=TensorDict({}, []))
>>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict({}, []))
>>> module_out = module(tensordict, tensordict_out=TensorDict())
>>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict())
>>> assert (
... module_out["outputs", "logits"] == graph_module_out["outputs", "logits"]
... ).all()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ tensordict.nn
:maxdepth: 1

tutorials/tensordict_module
tutorials/tensordict_module_functional
tutorials/export

Dataloading
-----------
Expand Down
2 changes: 1 addition & 1 deletion docs/source/saving.rst
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ Here is a full example:
>>> snapshot = torchsnapshot.Snapshot.take(app_state=state, path=path)
>>> # later
>>> snapshot = torchsnapshot.Snapshot(path=path)
>>> tensordict2 = TensorDict({}, [])
>>> tensordict2 = TensorDict()
>>> target_state = {
>>> "state": tensordict2
>>> }
Expand Down
6 changes: 5 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5409,6 +5409,10 @@ def keys(
):
"""Returns a generator of tensordict keys.
.. warning:: TensorDict ``keys()`` method returns a lazy view of the keys. If the ``keys``
are queried but not iterated over and then the tensordict is modified, iterating over
the keys later will return the new configuration of the keys.
Args:
include_nested (bool, optional): if ``True``, nested values will be returned.
Defaults to ``False``.
Expand Down Expand Up @@ -10276,7 +10280,7 @@ def is_tensor_collection(datatype: type | Any) -> bool:
Examples:
>>> is_tensor_collection(TensorDictBase) # True
>>> is_tensor_collection(TensorDict({}, [])) # True
>>> is_tensor_collection(TensorDict()) # True
>>> @tensorclass
... class MyClass:
... pass
Expand Down
Loading

0 comments on commit 6cad75a

Please sign in to comment.