Skip to content

Commit

Permalink
[Doc] export tutorial, TDM tuto refactoring
Browse files Browse the repository at this point in the history
ghstack-source-id: 695b419d6a7a7b504462152a29b2d5e1b9e60843
Pull Request resolved: #994
  • Loading branch information
vmoens committed Sep 19, 2024
1 parent 63c9982 commit c7fb85b
Show file tree
Hide file tree
Showing 16 changed files with 556 additions and 942 deletions.
15 changes: 9 additions & 6 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ on:
- v[0-9]+.[0-9]+.[0-9]
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
pull_request:
branches:
- "*"
workflow_dispatch:

concurrency:
Expand Down Expand Up @@ -90,13 +88,13 @@ jobs:
# 10. Build doc
cd ./docs
make docs
sphinx-build ./source _local_build
cd ..
cp -r docs/build/html/* "${RUNNER_ARTIFACT_DIR}"
cp -r docs/_local_build/* "${RUNNER_ARTIFACT_DIR}"
echo $(ls "${RUNNER_ARTIFACT_DIR}")
if [[ ${{ github.event_name == 'pull_request' }} ]]; then
cp -r docs/build/html/* "${RUNNER_DOCS_DIR}"
cp -r docs/_local_build/* "${RUNNER_DOCS_DIR}"
fi
upload:
Expand All @@ -118,7 +116,12 @@ jobs:
REF_NAME=${{ github.ref_name }}
if [[ "${REF_TYPE}" == branch ]]; then
TARGET_FOLDER="${REF_NAME}"
if [[ "${REF_NAME}" == main ]]; then
TARGET_FOLDER="${REF_NAME}"
# Bebug:
# else
# TARGET_FOLDER="release-doc"
fi
elif [[ "${REF_TYPE}" == tag ]]; then
case "${REF_NAME}" in
*-rc*)
Expand Down
4 changes: 2 additions & 2 deletions docs/source/fx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ We'll illustrate with an example from the overview. We create a :obj:`TensorDict
We can check that a forward pass with each module results in the same outputs.

>>> tensordict = TensorDict({"input": torch.randn(32, 100)}, [32])
>>> module_out = module(tensordict, tensordict_out=TensorDict({}, []))
>>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict({}, []))
>>> module_out = module(tensordict, tensordict_out=TensorDict())
>>> graph_module_out = graph_module(tensordict, tensordict_out=TensorDict())
>>> assert (
... module_out["outputs", "logits"] == graph_module_out["outputs", "logits"]
... ).all()
Expand Down
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ tensordict.nn
:maxdepth: 1

tutorials/tensordict_module
tutorials/tensordict_module_functional
tutorials/export

Dataloading
-----------
Expand Down
2 changes: 1 addition & 1 deletion docs/source/saving.rst
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ Here is a full example:
>>> snapshot = torchsnapshot.Snapshot.take(app_state=state, path=path)
>>> # later
>>> snapshot = torchsnapshot.Snapshot(path=path)
>>> tensordict2 = TensorDict({}, [])
>>> tensordict2 = TensorDict()
>>> target_state = {
>>> "state": tensordict2
>>> }
Expand Down
2 changes: 1 addition & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10276,7 +10276,7 @@ def is_tensor_collection(datatype: type | Any) -> bool:
Examples:
>>> is_tensor_collection(TensorDictBase) # True
>>> is_tensor_collection(TensorDict({}, [])) # True
>>> is_tensor_collection(TensorDict()) # True
>>> @tensorclass
... class MyClass:
... pass
Expand Down
Loading

0 comments on commit c7fb85b

Please sign in to comment.