Skip to content

Commit

Permalink
Normalize states for mcsolve solution (#213)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertomercurio committed Sep 5, 2024
1 parent 7ba8bbe commit e76d77d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/time_evolution/mcsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ end
function _mcsolve_generate_statistics(sol, i, times, states, expvals_all, jump_times, jump_which)
sol_i = sol[:, i]
!isempty(sol_i.prob.kwargs[:saveat]) ?
states[i] = [QuantumObject(sol_i.u[i], dims = sol_i.prob.p.Hdims) for i in 1:length(sol_i.u)] : nothing
states[i] = [QuantumObject(normalize!(sol_i.u[i]), dims = sol_i.prob.p.Hdims) for i in 1:length(sol_i.u)] : nothing

copyto!(view(expvals_all, i, :, :), sol_i.prob.p.expvals)
times[i] = sol_i.t
Expand Down
7 changes: 7 additions & 0 deletions test/time_evolution_and_partial_trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,16 @@
sol_me2 = mesolve(H, psi0, t_l, c_ops, progress_bar = Val(false))
sol_me3 = mesolve(H, psi0, t_l, c_ops, e_ops = e_ops, saveat = t_l, progress_bar = Val(false))
sol_mc = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, e_ops = e_ops, progress_bar = Val(false))
sol_mc_states = mcsolve(H, psi0, t_l, c_ops, n_traj = 500, saveat = t_l, progress_bar = Val(false))

ρt_mc = [ket2dm.(normalize.(states)) for states in sol_mc_states.states]
expect_mc_states = mapreduce(states -> expect.(Ref(e_ops[1]), states), hcat, ρt_mc)
expect_mc_states_mean = sum(expect_mc_states, dims = 2) / size(expect_mc_states, 2)

sol_me_string = sprint((t, s) -> show(t, "text/plain", s), sol_me)
sol_mc_string = sprint((t, s) -> show(t, "text/plain", s), sol_mc)
@test sum(abs.(sol_mc.expect .- sol_me.expect)) / length(t_l) < 0.1
@test sum(abs.(vec(expect_mc_states_mean) .- vec(sol_me.expect))) / length(t_l) < 0.1
@test length(sol_me.states) == 1
@test size(sol_me.expect) == (length(e_ops), length(t_l))
@test length(sol_me2.states) == length(t_l)
Expand Down

0 comments on commit e76d77d

Please sign in to comment.