Skip to content

Commit

Permalink
Return of the threaded detection (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcMush committed Aug 4, 2024
1 parent e60297b commit cd0414f
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 15 deletions.
17 changes: 15 additions & 2 deletions src/ProgressMeter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ Base.@kwdef mutable struct ProgressCore
numprintedvalues::Int = 0 # num values printed below progress in last iteration
prev_update_count::Int = 1 # counter at last update
printed::Bool = false # true if we have issued at least one status update
safe_lock::Bool = Threads.nthreads() > 1 # set to false for non-threaded tight loops
safe_lock::Int = 2*(Threads.nthreads()>1) # 0: no lock, 1: lock, 2: detect
thread_id::Int = Threads.threadid() # id of the thread that created the progressmeter
tinit::Float64 = time() # time meter was initialized
tlast::Float64 = time() # time of last update
tsecond::Float64 = time() # ignore the first loop given usually uncharacteristically slow
Expand Down Expand Up @@ -448,8 +449,20 @@ end

predicted_updates_per_dt_have_passed(p::AbstractProgress) = p.counter - p.prev_update_count >= p.check_iterations

function is_threading(p::AbstractProgress)
p.safe_lock == 0 && return false
p.safe_lock == 1 && return true
if p.thread_id != Threads.threadid()
lock(p.lock) do
p.safe_lock = 1
end
return true
end
return false
end

function lock_if_threading(f::Function, p::AbstractProgress)
if p.safe_lock
if is_threading(p)
lock(p.lock) do
f()
end
Expand Down
90 changes: 77 additions & 13 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ for ns in [1, 9, 10, 99, 100, 999, 1_000, 9_999, 10_000, 99_000, 100_000, 999_99
end

# Performance test (from #171, #323)
function prog_perf(n; dt=0.1, enabled=true, force=false, safe_lock=false)
function prog_perf(n; dt=0.1, enabled=true, force=false, safe_lock=0)
prog = Progress(n; dt, enabled, safe_lock)
x = 0.0
for i in 1:n
Expand All @@ -43,38 +43,85 @@ function noprog_perf(n)
return x
end

function prog_threaded(n; dt=0.1, enabled=true, force=false, safe_lock=2)
prog = Progress(n; dt, enabled, safe_lock)
x = Threads.Atomic{Float64}(0.0)
Threads.@threads for i in 1:n
Threads.atomic_add!(x, rand())
next!(prog; force)
end
return x
end

function noprog_threaded(n)
x = Threads.Atomic{Float64}(0.0)
Threads.@threads for i in 1:n
Threads.atomic_add!(x, rand())
end
return x
end

println("Performance tests...")

#precompile
noprog_perf(10)
prog_perf(10)
prog_perf(10; safe_lock=true)
prog_perf(10; dt=9999)
prog_perf(10; safe_lock=1)
prog_perf(10; dt=9999.9)
prog_perf(10; enabled=false)
prog_perf(10; enabled=false, safe_lock=true)
prog_perf(10; enabled=false, safe_lock=1)
prog_perf(10; force=true)

t_noprog = (@elapsed noprog_perf(10^8))/10^8
t_prog = (@elapsed prog_perf(10^8))/10^8
t_lock = (@elapsed prog_perf(10^8; safe_lock=true))/10^8
t_noprint = (@elapsed prog_perf(10^8; dt=9999))/10^8
t_disabled = (@elapsed prog_perf(10^8; enabled=false))/10^8
t_disabled_lock = (@elapsed prog_perf(10^8; enabled=false, safe_lock=true))/10^8
t_force = (@elapsed prog_perf(10^2; force=true))/10^2
noprog_threaded(2*Threads.nthreads())
prog_threaded(2*Threads.nthreads())
prog_threaded(2*Threads.nthreads(); safe_lock=1)
prog_threaded(2*Threads.nthreads(); dt=9999)
prog_threaded(2*Threads.nthreads(); enabled=false)
prog_threaded(2*Threads.nthreads(); force=true)

N = 10^8
N_force = 1000
t_noprog = (@elapsed noprog_perf(N))/N
t_prog = (@elapsed prog_perf(N))/N
t_lock = (@elapsed prog_perf(N; safe_lock=1))/N
t_detect = (@elapsed prog_perf(N; safe_lock=2))/N
t_noprint = (@elapsed prog_perf(N; dt=9999.9))/N
t_disabled = (@elapsed prog_perf(N; enabled=false))/N
t_disabled_lock = (@elapsed prog_perf(N; enabled=false, safe_lock=1))/N
t_force = (@elapsed prog_perf(N_force; force=true))/N_force

Nth = Threads.nthreads() * 10^6
Nth_force = Threads.nthreads() * 100
th_noprog = (@elapsed noprog_threaded(Nth))/Nth
th_detect = (@elapsed prog_threaded(Nth))/Nth
th_lock = (@elapsed prog_threaded(Nth; safe_lock=1))/Nth
th_noprint = (@elapsed prog_threaded(Nth; dt=9999.9))/Nth
th_disabled = (@elapsed prog_threaded(Nth; enabled=false))/Nth
th_force = (@elapsed prog_threaded(Nth_force; force=true))/Nth_force

println("Performance results:")
println("without progress: ", ProgressMeter.speedstring(t_noprog))
println("with defaults: ", ProgressMeter.speedstring(t_prog))
println("with no lock: ", ProgressMeter.speedstring(t_prog))
println("with no printing: ", ProgressMeter.speedstring(t_noprint))
println("with disabled: ", ProgressMeter.speedstring(t_disabled))
println("with lock: ", ProgressMeter.speedstring(t_lock))
println("with automatic lock: ", ProgressMeter.speedstring(t_detect))
println("with lock, disabled: ", ProgressMeter.speedstring(t_disabled_lock))
println("with force: ", ProgressMeter.speedstring(t_force))
println()
println("Threaded performance results: ($(Threads.nthreads()) threads)")
println("without progress: ", ProgressMeter.speedstring(th_noprog))
println("with automatic lock: ", ProgressMeter.speedstring(th_detect))
println("with forced lock: ", ProgressMeter.speedstring(th_lock))
println("with no printing: ", ProgressMeter.speedstring(th_noprint))
println("with disabled: ", ProgressMeter.speedstring(th_disabled))
println("with force: ", ProgressMeter.speedstring(th_force))

if get(ENV, "CI", "false") == "false" # CI environment is too unreliable for performance tests
@test t_prog < 9*t_noprog
end


# Avoid a NaN due to the estimated print time compensation
# https://github.com/timholy/ProgressMeter.jl/issues/209
prog = Progress(10)
Expand Down Expand Up @@ -116,7 +163,24 @@ function simple_sum(n; safe_lock = true)
return s
end
p = Progress(10)
@test p.safe_lock == (Threads.nthreads() > 1)
@test (p.safe_lock) == 2*(Threads.nthreads() > 1)
p = Progress(10; safe_lock = false)
@test p.safe_lock == false
@test simple_sum(10; safe_lock = true) simple_sum(10; safe_lock = false)


# Brute-force thread safety

function test_thread(N)
p = Progress(N)
Threads.@threads for _ in 1:N
next!(p)
end
end

println("Brute-forcing thread safety... ($(Threads.nthreads()) threads)")
@time for i in 1:10^5
test_thread(Threads.nthreads())
end


0 comments on commit cd0414f

Please sign in to comment.