Skip to content

Commit

Permalink
polish reindex memory optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
cjld committed Apr 24, 2022
1 parent 0b13930 commit 9d899dc
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/jittor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

__version__ = '1.3.3.8'
__version__ = '1.3.3.9'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
Expand Down
1 change: 0 additions & 1 deletion python/jittor/src/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,6 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
root = fuse_ops[rr-1];
load_fused_op(fused_op, fuse_ops, ops, ll, rr, tt);
}
LOGvvv << "Run" << op;
for (auto* var : op->outputs()) {
var->alloc(allocator);
}
Expand Down
2 changes: 2 additions & 0 deletions python/jittor/src/ops/reindex_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ ReindexOp::ReindexOp(Var* x, NanoVector shape, vector<string>&& indexes, float64
flags.set(NodeFlags::_cuda);
set_type(OpType::broadcast);
flags.set(NodeFlags::_manual_set_vnbb);
for (auto& v : extras) v->flags.set(NodeFlags::_needed_by_backward);
y = create_output(nullptr, x->dtype());
}

Expand Down Expand Up @@ -64,6 +65,7 @@ ReindexOp::ReindexOp(Var* x, vector<Var*>&& indexes, float64 overflow_value, vec
extras = indexes;
for (uint i = 0; i < indexes.size(); ++i) {
indexes[i]->flags.set(NodeFlags::_force_fuse);
indexes[i]->flags.set(NodeFlags::_needed_by_backward);
}
}

Expand Down
2 changes: 2 additions & 0 deletions python/jittor/src/ops/reindex_reduce_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ ReindexReduceOp::ReindexReduceOp(Var* y, NanoString op, NanoVector shape, vector
if (e->shape != y->shape) {
e->flags.set(NodeFlags::_stop_fuse);
}
if (op.get(NanoString::_no_need_back_in))
e->flags.set(NodeFlags::_needed_by_backward);
}
}

Expand Down
10 changes: 10 additions & 0 deletions python/jittor/test/test_reindex_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,16 @@ def test_reindex_wrong_op(self):
b = jt.array([1])
c = a.reindex([8,8], ["@e0(0) // 1", "@e0(0)"], extras=[b, b])
expect_error(lambda: c.sync())

def test_reindex_memopt(self):
a = jt.zeros([10,10])
b = jt.array([1,2,3]).name("b")
c = a.reindex([8,8], ["@e0(0) / 1", "@e0(0)"], extras=[b, b])
del b
c.sync()
da = jt.grad(c, a)
da.sync()



@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
Expand Down

0 comments on commit 9d899dc

Please sign in to comment.