diff --git a/lib/apartment/tenant.rb b/lib/apartment/tenant.rb index abbe87d5..0124704b 100644 --- a/lib/apartment/tenant.rb +++ b/lib/apartment/tenant.rb @@ -19,36 +19,37 @@ module Tenant # @return {subclass of Apartment::AbstractAdapter} # def adapter - Thread.current[:apartment_adapter] ||= begin - adapter_method = "#{config[:adapter]}_adapter" - - if defined?(JRUBY_VERSION) - case config[:adapter] - when /mysql/ - adapter_method = 'jdbc_mysql_adapter' - when /postgresql/ - adapter_method = 'jdbc_postgresql_adapter' - end - end + current_adapter = Thread.current.thread_variable_get(:apartment_adapter) + return current_adapter if current_adapter - begin - require "apartment/adapters/#{adapter_method}" - rescue LoadError - raise "The adapter `#{adapter_method}` is not yet supported" - end + adapter_method = "#{config[:adapter]}_adapter" - unless respond_to?(adapter_method) - raise AdapterNotFound, "database configuration specifies nonexistent #{config[:adapter]} adapter" + if defined?(JRUBY_VERSION) + case config[:adapter] + when /mysql/ + adapter_method = 'jdbc_mysql_adapter' + when /postgresql/ + adapter_method = 'jdbc_postgresql_adapter' end + end - send(adapter_method, config) + begin + require "apartment/adapters/#{adapter_method}" + rescue LoadError + raise "The adapter `#{adapter_method}` is not yet supported" end + + unless respond_to?(adapter_method) + raise AdapterNotFound, "database configuration specifies nonexistent #{config[:adapter]} adapter" + end + + Thread.current.thread_variable_set(:apartment_adapter, send(adapter_method, config)) end # Reset config and adapter so they are regenerated # def reload!(config = nil) - Thread.current[:apartment_adapter] = nil + Thread.current.thread_variable_set(:apartment_adapter, nil) @config = config end diff --git a/spec/tenant_spec.rb b/spec/tenant_spec.rb index 13363b9a..467c494c 100644 --- a/spec/tenant_spec.rb +++ b/spec/tenant_spec.rb @@ -68,6 +68,14 @@ thread.join expect(subject.current).to eq(db1) end + + it 'maintains the current tenant across fibers within a thread' do + subject.switch!(db1) + expect(subject.current).to eq(db1) + fiber = Fiber.new { expect(subject.current).to eq(db1) } + fiber.resume + expect(subject.current).to eq(db1) + end end end