From fffe72e40a09a59c98d6975a3e6708cc6515efe7 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Wed, 6 Jul 2022 18:14:52 +0530 Subject: [PATCH 1/4] Add method for `chunk` with size of chunks as kwarg --- src/utils.jl | 23 +++++++++++------------ test/utils.jl | 8 ++++---- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index f64cbdd..de5659a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -158,28 +158,27 @@ julia> xs[1] 3 8 13 18 ``` """ -chunk(x, n::Int) = collect(Iterators.partition(x, ceil(Int, length(x) / n))) +chunk(x; size::Int) = collect(Iterators.partition(x, size)) +chunk(x, n::Int) = chunk(x; size = ceil(Int, length(x) / n)) -function chunk(x::AbstractArray, n::Int; dims::Int=ndims(x)) - idxs = _partition_idxs(x, n, dims) +function chunk(x::AbstractArray; size::Int, dims::Int=ndims(x)) + idxs = _partition_idxs(x, size, dims) [selectdim(x, dims, i) for i in idxs] end +chunk(x::AbstractArray, n::Int; dims::Int=ndims(x)) = chunk(x; size = ceil(Int, size(x, dims) / n), dims) -function _partition_idxs(x, n, dims) - bs = ceil(Int, size(x, dims) / n) - Iterators.partition(axes(x, dims), bs) -end - -function rrule(::typeof(chunk), x::AbstractArray, n::Int; dims::Int=ndims(x)) +function rrule(::typeof(chunk), x::AbstractArray; size::Int, dims::Int=ndims(x)) # this is the implementation of chunk - idxs = _partition_idxs(x, n, dims) + idxs = _partition_idxs(x, size, dims) y = [selectdim(x, dims, i) for i in idxs] valdims = Val(dims) - chunk_pullback(dy) = (NoTangent(), ∇chunk(unthunk(dy), x, idxs, valdims), NoTangent()) - + chunk_pullback(dy) = (NoTangent(), ∇chunk(unthunk(dy), x, idxs, valdims)) + return y, chunk_pullback end +_partition_idxs(x, size, dims) = Iterators.partition(axes(x, dims), size) + # Similar to ∇eachslice https://github.com/JuliaDiff/ChainRules.jl/blob/8108a77a96af5d4b0c460aac393e44f8943f3c5e/src/rulesets/Base/indexing.jl#L77 function ∇chunk(dys, x::AbstractArray, idxs, vd::Val{dim}) where {dim} i1 = findfirst(dy -> !(dy isa AbstractZero), dys) diff --git a/test/utils.jl b/test/utils.jl index cbc3888..679b46e 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -111,10 +111,10 @@ end n = 2 dims = 2 x = rand(4, 5) - y = chunk(x, 2) - dy = randn!.(collect.(y)) - idxs = MLUtils._partition_idxs(x, n, dims) - test_zygote(MLUtils.∇chunk, dy, x, idxs, Val(dims), check_inferred=false) + l = chunk(x, 2) + dl = randn!.(collect.(l)) + idxs = MLUtils._partition_idxs(x, ceil(Int, size(x, dims) / n), dims) + test_zygote(MLUtils.∇chunk, dl, x, idxs, Val(dims), check_inferred=false) end @testset "group_counts" begin From cd96085cc9c7132536a86546dd1fb5fafb32b5f9 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Wed, 6 Jul 2022 18:25:22 +0530 Subject: [PATCH 2/4] Update docs for `chunk` --- src/utils.jl | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index de5659a..1582f64 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -122,9 +122,10 @@ unstack(xs; dims::Int) = [copy(selectdim(xs, dims, i)) for i in 1:size(xs, dims) """ chunk(x, n; [dims]) + chunk(x; [size, dims]) -Split `x` into `n` parts. The parts contain the same number of elements -except possibly for the last one that can be smaller. +Split `x` into `n` parts or alternatively, into equal chunks of size `size`. The parts contain +the same number of elements except possibly for the last one that can be smaller. If `x` is an array, `dims` can be used to specify along which dimension to split (defaults to the last dimension). @@ -138,6 +139,14 @@ julia> chunk(1:10, 3) 5:8 9:10 +julia> chunk(1:10; size = 2) +5-element Vector{UnitRange{Int64}}: + 1:2 + 3:4 + 5:6 + 7:8 + 9:10 + julia> x = reshape(collect(1:20), (5, 4)) 5×4 Matrix{Int64}: 1 6 11 16 @@ -156,6 +165,19 @@ julia> xs[1] 1 6 11 16 2 7 12 17 3 8 13 18 + +julia> xes = chunk(x; size = 2, dims = 2) +2-element Vector{SubArray{Int64, 2, Matrix{Int64}, Tuple{Base.Slice{Base.OneTo{Int64}}, UnitRange{Int64}}, true}}: + [1 6; 2 7; … ; 4 9; 5 10] + [11 16; 12 17; … ; 14 19; 15 20] + +julia> xes[2] +5×2 view(::Matrix{Int64}, :, 3:4) with eltype Int64: + 11 16 + 12 17 + 13 18 + 14 19 + 15 20 ``` """ chunk(x; size::Int) = collect(Iterators.partition(x, size)) From b2e514e1a519d9c8b5f843ac84fb95ef94084874 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 7 Jul 2022 09:20:25 +0530 Subject: [PATCH 3/4] Use `cld` and `fld` Add tests for `chunk` with `size` --- src/batchview.jl | 2 +- src/utils.jl | 4 ++-- test/utils.jl | 15 +++++++++++---- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/batchview.jl b/src/batchview.jl index 4b5f252..d5a1467 100644 --- a/src/batchview.jl +++ b/src/batchview.jl @@ -100,7 +100,7 @@ function BatchView(data::T; batchsize::Int=1, partial::Bool=true, collate=Val(no throw(ArgumentError("`collate` must be one of `nothing`, `true` or `false`.")) end E = _batchviewelemtype(data, collate) - count = partial ? ceil(Int, n / batchsize) : floor(Int, n / batchsize) + count = partial ? cld(n, batchsize) : fld(n, batchsize) BatchView{E,T,typeof(collate)}(data, batchsize, count, partial) end diff --git a/src/utils.jl b/src/utils.jl index 1582f64..b534292 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -181,13 +181,13 @@ julia> xes[2] ``` """ chunk(x; size::Int) = collect(Iterators.partition(x, size)) -chunk(x, n::Int) = chunk(x; size = ceil(Int, length(x) / n)) +chunk(x, n::Int) = chunk(x; size = cld(length(x), n)) function chunk(x::AbstractArray; size::Int, dims::Int=ndims(x)) idxs = _partition_idxs(x, size, dims) [selectdim(x, dims, i) for i in idxs] end -chunk(x::AbstractArray, n::Int; dims::Int=ndims(x)) = chunk(x; size = ceil(Int, size(x, dims) / n), dims) +chunk(x::AbstractArray, n::Int; dims::Int=ndims(x)) = chunk(x; size = cld(size(x, dims), n), dims) function rrule(::typeof(chunk), x::AbstractArray; size::Int, dims::Int=ndims(x)) # this is the implementation of chunk diff --git a/test/utils.jl b/test/utils.jl index 679b46e..85ec41c 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -101,9 +101,16 @@ end x = reshape(collect(1:20), (5, 4)) cs = chunk(x, 2) @test length(cs) == 2 - cs[1] == [1 6; 2 7; 3 8; 4 9; 5 10] - cs[2] == [11 16; 12 17; 13 18; 14 19; 15 20] - + @test cs[1] == [1 6; 2 7; 3 8; 4 9; 5 10] + @test cs[2] == [11 16; 12 17; 13 18; 14 19; 15 20] + + x = permutedims(reshape(collect(1:10), (2, 5))) + cs = chunk(x; size = 2, dims = 1) + @test length(cs) == 3 + @test cs[1] == [1 2; 3 4] + @test cs[2] == [5 6; 7 8] + @test cs[3] == [9 10] + # test gradient test_zygote(chunk, rand(10), 3, check_inferred=false) @@ -113,7 +120,7 @@ end x = rand(4, 5) l = chunk(x, 2) dl = randn!.(collect.(l)) - idxs = MLUtils._partition_idxs(x, ceil(Int, size(x, dims) / n), dims) + idxs = MLUtils._partition_idxs(x, cld(size(x, dims), n), dims) test_zygote(MLUtils.∇chunk, dl, x, idxs, Val(dims), check_inferred=false) end From 1128724941aa084ea7b5839a450610f54574cd0d Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Thu, 7 Jul 2022 14:14:21 +0530 Subject: [PATCH 4/4] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ea76060..8accba0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLUtils" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" authors = ["Carlo Lucibello and contributors"] -version = "0.2.9" +version = "0.2.10" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"