Skip to content

Commit

Permalink
add slice broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
cjld committed Apr 23, 2022
1 parent ab30a15 commit 0b13930
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 12 deletions.
2 changes: 1 addition & 1 deletion python/jittor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

__version__ = '1.3.3.7'
__version__ = '1.3.3.8'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
Expand Down
4 changes: 1 addition & 3 deletions python/jittor/src/ops/broadcast_to_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x),
bool BroadcastToOp::need_broadcast(const Var* x, const NanoVector& shape) {
if (x->shape.size() < shape.size()) return true;
for (uint i=shape.size()-1, j=x->shape.size()-1; i<shape.size(); i--,j--)
if (x->shape[j]< 0 || (x->shape[j] != shape[i] && shape[i] != 1)) return true;
if ((x->shape[j] != shape[i] && shape[i] != 1)) return true;
return false;
}

Expand Down Expand Up @@ -154,8 +154,6 @@ void BroadcastToOp::infer_shape() {
int64 zs;
if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
zs = xshape * yshape;
} else if (xshape < 0 || yshape < 0) {
zs = std::min(xshape, yshape);
} else {
CHECKop(xshape,==,yshape) << "Shape not match" << x->shape << yshapes << bcast_mask;
zs = xshape;
Expand Down
6 changes: 3 additions & 3 deletions python/jittor/src/ops/getitem_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,11 @@ void GetitemOp::infer_slices(
for (int j=0; j<niv; j++) {
auto iv_shape_j = iv_shape[niv-j-1];
auto& out_shape_j = out_shape[first_oid_of_var+var_dim-j-1];
CHECK(out_shape_j == iv_shape_j || out_shape_j == 1 || iv_shape_j == 1) << "Shape not match " >> out_shape_j >> "!="
>> iv_shape_j << "data shape:" << in_shape <<
"slice shape:" << iv_shape;
if (out_shape_j == 1)
out_shape_j = iv_shape_j;
else
ASSERT(out_shape_j == iv_shape_j || out_shape_j < 0 || iv_shape_j < 0)
<< out_shape_j << iv_shape_j << out_shape;
}
} else
if (s.is_ellipsis()) {
Expand Down
4 changes: 1 addition & 3 deletions python/jittor/src/ops/reshape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ void ReshapeOp::infer_shape() {
CHECK(uncertain_dim <= 1) << "max number of -1 is 1, but get" << uncertain_dim << ".";
int64_t x_items = x->num;
auto yshape = shape;
if (x_items < 0) {
// pass if input is uncertain
} else if (uncertain_dim == 0) {
if (uncertain_dim == 0) {
CHECKop(x_items,==,y_items) << "reshape shape is invalid for input of size";
} else {
CHECK(y_items != 0 && x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items;
Expand Down
1 change: 0 additions & 1 deletion python/jittor/src/ops/ternary_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ void TernaryOp::infer_shape() {
auto shape = std::min(xshape, std::min(yshape, cshape));
auto shape2 = std::max(xshape, std::max(yshape, cshape));
zshape.push_back(shape2);
if (shape < 0) continue;
CHECK(shape==shape2) << "Shape not match" << x->shape << y->shape << cond->shape;
}
z->set_shape(zshape);
Expand Down
34 changes: 33 additions & 1 deletion python/jittor/test/test_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,39 @@ def test_dfs_memopt(self):
jt.get_max_memory_treemap()



def test_setitem_bc(self):
a = jt.random([10,11,12])
b = a[jt.arange(3)[:,None],
jt.arange(4)[None,:]]
b.sync()
assert (a[:3, :4] == b).all()

a = jt.random([10,11,12])
b = a[jt.arange(3)[:,None],
jt.arange(4)[None,:],
jt.arange(4)[None,:]]
nb = a.data[np.arange(3)[:,None],
np.arange(4)[None,:],
np.arange(4)[None,:]]
np.testing.assert_allclose(nb, b.data)

a = jt.random([10,11,12])
b = a[jt.arange(3)[::-1,None],
jt.arange(4)[None,:],
jt.arange(4)[None,:]]
nb = a.data[np.arange(3)[::-1,None],
np.arange(4)[None,:],
np.arange(4)[None,:]]
np.testing.assert_allclose(nb, b.data)

a = jt.random([10,11,12])
b = a[jt.arange(3)[::-1,None],
jt.arange(4)[None,:],
jt.arange(4)[None,::-1]]
nb = a.data[np.arange(3)[::-1,None],
np.arange(4)[None,:],
np.arange(4)[None,::-1]]
np.testing.assert_allclose(nb, b.data)



Expand Down

0 comments on commit 0b13930

Please sign in to comment.