diff --git a/pub.go b/pub.go index a65ac4d..7dffb59 100644 --- a/pub.go +++ b/pub.go @@ -9,6 +9,11 @@ import ( "fmt" "net" "sync" + "sync/atomic" +) + +const ( + DefaultSendHwm = 1000 ) // Topics is an interface that wraps the basic Topics method. @@ -103,9 +108,7 @@ func (pub *pubSocket) SetOption(name string, value interface{}) error { } w := pub.sck.w.(*pubMWriter) - w.qmu.Lock() - w.hwm = hwm - w.qmu.Unlock() + w.hwm.Store(int64(hwm)) return nil } @@ -217,108 +220,78 @@ func (q *pubQReader) topic(msg Msg) bool { } type pubMWriter struct { - ctx context.Context - mu sync.Mutex - ws []*Conn - - qmu sync.Mutex - qcond *sync.Cond - q *Queue - hwm int - closed bool + ctx context.Context + mu sync.RWMutex + subscribers map[*Conn]chan Msg + + hwm atomic.Int64 } func newPubMWriter(ctx context.Context) *pubMWriter { p := &pubMWriter{ - ctx: ctx, - q: NewQueue(), + ctx: ctx, + subscribers: map[*Conn]chan Msg{}, } - p.qcond = sync.NewCond(&p.qmu) - go p.run() + p.hwm.Store(DefaultSendHwm) return p } -func (w *pubMWriter) run() { - for { - w.qmu.Lock() - for w.q.Len() == 0 { - w.qcond.Wait() - if w.closed { - return - } - } - msg, _ := w.q.Peek() - w.q.Pop() - w.qmu.Unlock() - w.sendMsg(msg) - } -} - func (w *pubMWriter) Close() error { - w.qmu.Lock() - w.closed = true - w.qcond.Signal() - w.qmu.Unlock() - w.mu.Lock() - var err error - for _, ww := range w.ws { - e := ww.Close() - if e != nil && err == nil { - err = e - } + defer w.mu.Unlock() + + for conn, channel := range w.subscribers { + _ = conn.Close() + close(channel) } - w.ws = nil - w.mu.Unlock() - return err + w.subscribers = nil + return nil } func (mw *pubMWriter) addConn(w *Conn) { mw.mu.Lock() - mw.ws = append(mw.ws, w) - mw.mu.Unlock() + defer mw.mu.Unlock() + + c := make(chan Msg, mw.hwm.Load()) + mw.subscribers[w] = c + go func() { + for { + msg, ok := <-c + if !ok { + break + } + topic := string(msg.Frames[0]) + if w.subscribed(topic) { + _ = w.SendMsg(msg) + } + } + }() } func (mw *pubMWriter) rmConn(w *Conn) { mw.mu.Lock() defer mw.mu.Unlock() - cur := -1 - for i := range mw.ws { - if mw.ws[i] == w { - cur = i - mw.ws[i].Close() - break - } - } - if cur >= 0 { - mw.ws = append(mw.ws[:cur], mw.ws[cur+1:]...) + if channel, ok := mw.subscribers[w]; ok { + _ = w.Close() + delete(mw.subscribers, w) + close(channel) } } func (w *pubMWriter) write(ctx context.Context, msg Msg) error { - w.qmu.Lock() - defer w.qmu.Unlock() - if w.hwm != 0 && w.q.Len() >= w.hwm { - //TODO(inphi): per subscriber hwm - return nil - } - w.q.Push(msg) - w.qcond.Signal() - return nil -} + w.mu.RLock() + defer w.mu.RUnlock() -func (w *pubMWriter) sendMsg(msg Msg) { - topic := string(msg.Frames[0]) - w.mu.Lock() - defer w.mu.Unlock() - // TODO(inphi): distribute messages across subscribers at once - for i := range w.ws { - ww := w.ws[i] - if ww.subscribed(topic) { - _ = ww.SendMsg(msg) + for _, channel := range w.subscribers { + select { + case <-ctx.Done(): + return ctx.Err() + case channel <- msg: // proceeds to default case if the channel is full (msg will be discarded) + default: } } + return nil } var (