Skip to content

Commit

Permalink
[CI] Fix windows wheels
Browse files Browse the repository at this point in the history
ghstack-source-id: e5c9fd8a8534fef623982fe435cadaf0a9c4703a
Pull Request resolved: #1006
  • Loading branch information
vmoens committed Sep 23, 2024
1 parent 85b6b81 commit 6be092f
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/build-wheels-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,22 @@ jobs:
matrix:
include:
- repository: pytorch/tensordict
pre-script: ""
env-script: .github/scripts/version_script.bat
post-script: "python packaging/wheel/relocate.py"
smoke-test-script: test/smoke_test.py
package-name: tensordict
name: pytorch/tensordict
name: ${{ matrix.repository }}
uses: pytorch/test-infra/.github/workflows/build_wheels_windows.yml@main
with:
repository: ${{ matrix.repository }}
ref: ""
test-infra-repository: pytorch/test-infra
test-infra-ref: main
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
pre-script: ${{ matrix.pre-script }}
env-script: ${{ matrix.env-script }}
post-script: ${{ matrix.post-script }}
package-name: ${{ matrix.package-name }}
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
env-script: .github/scripts/version_script.bat
6 changes: 3 additions & 3 deletions .github/workflows/wheels-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ jobs:
python3 -mpip install wheel
TENSORDICT_BUILD_VERSION=0.5.0 python3 setup.py bdist_wheel
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: tensordict-win-${{ matrix.python_version[0] }}.whl
path: dist/tensordict-*.whl
- name: Upload wheel for download
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v3
with:
name: tensordict-batch.whl
path: dist/*.whl
Expand Down Expand Up @@ -72,7 +72,7 @@ jobs:
run: |
python3 -mpip install numpy pytest pytest-cov codecov unittest-xml-reporting pillow>=4.1.1 scipy av networkx expecttest pyyaml
- name: Download built wheels
uses: actions/download-artifact@v2
uses: actions/download-artifact@v3
with:
name: tensordict-win-${{ matrix.python_version }}.whl
path: wheels
Expand Down
10 changes: 10 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,6 +1495,16 @@ def grad(self):
"""Returns a tensordict containing the .grad attributes of the leaf tensors."""
return self._grad()

@grad.setter
def grad(self, grad):
def set_grad(x, grad):
if x.grad is None:
x.grad = grad
else:
x.grad.copy_(grad)

self._fast_apply(set_grad, grad)

def zero_grad(self, set_to_none: bool = True) -> T:
"""Zeros all the gradients of the TensorDict recursively.
Expand Down

0 comments on commit 6be092f

Please sign in to comment.