Skip to content

WaitFreeQueueSlow

Pslydhh edited this page Oct 5, 2017 · 1 revision

package org.psly.concurrent;

import java.lang.reflect.Field;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;

import sun.misc.Unsafe;

public class WaitFreeQueueScan {
	/**
	 * Creates an empty queue.
	 *
	 * <p>{@code head} and {@code tail} both start at a single node holding
	 * {@code null}, and the descriptor table is allocated with
	 * {@code Short.MAX_VALUE} slots.
	 */
	public WaitFreeQueueScan() {
		head = tail = new Node(null);
		// The former array-creation expression was not valid Java. A plain
		// OpDesc[] needs neither a cast nor an unchecked-warning suppression.
		state = new OpDesc[Short.MAX_VALUE];
	}

	/**
	 * Walks the descriptor table once and helps every published operation that
	 * is still pending and whose phase is not newer than {@code phase}.
	 *
	 * <p>The walk starts at a slot derived from the calling thread's random
	 * value. An odd start slot walks upwards and wraps around; an even start
	 * slot walks downwards and wraps around. Either way every slot in
	 * {@code [0, end)} is visited exactly once.
	 *
	 * @param phase operations with a phase less than or equal to this value are helped
	 */
	private void help(long phase) {
		int end = atoIntger.get();
		if (end <= 0) {
			// No thread id has been handed out yet, so there is nothing to help
			// (and the modulo below would divide by zero).
			return;
		}
		int random = threadLocalRandom.get();
		// Math.abs(Integer.MIN_VALUE) is still negative. Left unhandled, that
		// produced a negative start slot and the downward walk then read
		// negative table indices through Unsafe.
		random = (random == Integer.MIN_VALUE) ? 0 : Math.abs(random);
		int index = random % end;

		if ((index & 1) == 1) {
			// Upward rotation: index .. end-1, then 0 .. index-1.
			for (int i = index; i < end; ++i) {
				helpSlot(i, phase);
			}
			for (int i = 0; i < index; ++i) {
				helpSlot(i, phase);
			}
		} else {
			// Downward rotation: index .. 0, then end-1 .. index+1.
			for (int i = index; i >= 0; --i) {
				helpSlot(i, phase);
			}
			for (int i = end - 1; i > index; --i) {
				helpSlot(i, phase);
			}
		}
	}

	/**
	 * Helps the operation published in one table slot, if there is one, it is
	 * still pending, and its phase is not newer than {@code phase}.
	 */
	private void helpSlot(int slot, long phase) {
		OpDesc opDesc = getArrayAt(slot);
		if (opDesc == null) {
			return;
		}
		if (opDesc instanceof OpDescAdd) {
			OpDescAdd op = (OpDescAdd) opDesc;
			if (op.pending == 1 && op.phase <= phase) {
				helpEnq(op);
			}
		} else {
			OpDescPoll op = (OpDescPoll) opDesc;
			if (op.pNode.pending == 1 && op.phase <= phase) {
				helpDeq(op);
			}
		}
	}

	/**
	 * Returns the largest phase found in the leading run of occupied
	 * descriptor slots.
	 *
	 * <p>Slots are read from index 0 upwards; the walk ends at the first empty
	 * slot or after {@code Short.MAX_VALUE} slots, whichever comes first.
	 *
	 * @return the highest phase seen, or 0 when no examined phase exceeds 0
	 */
	short maxPhase() {
		short highest = 0;
		int slot = 0;
		OpDesc current;
		while (slot < Short.MAX_VALUE && (current = getArrayAt(slot)) != null) {
			short seen = current.phase;
			if (seen > highest) {
				highest = seen;
			}
			++slot;
		}
		return highest;
	}

	/**
	 * Appends {@code item} to the queue.
	 *
	 * <p>A pending enqueue descriptor carrying a fresh node is published in the
	 * calling thread's table slot, with a phase one above the current maximum
	 * (narrowed to {@code short}). The thread then runs the helping pass for
	 * that phase.
	 *
	 * @param item the value to enqueue
	 * @return always {@code true}
	 */
	public boolean add(E item) {
		final int slot = threadLocal.get();
		final short myPhase = (short) (maxPhase() + 1);
		final OpDescAdd request = new OpDescAdd(myPhase, 1, true, new Node(item));
		setArrayAt(slot, request);
		help(myPhase);
		return true;
	}

	/**
	 * Drives the given enqueue until its descriptor is no longer pending.
	 *
	 * <p>When the descriptor's phase lies in the upper half of the
	 * {@code short} range, at most {@code NUMOFHELP} attempts are made first.
	 * If the operation is still pending after that, its phase is reset to 0
	 * and helping continues without a bound. This mirrors {@code helpDeq}.
	 *
	 * @param opDescAdd the enqueue descriptor to help
	 */
	public void helpEnq(OpDescAdd opDescAdd) {
		// phase is a short, so the threshold must be in short range. The former
		// comparison against Long.MAX_VALUE / 2 could never be true, which left
		// this whole branch dead.
		if (opDescAdd.phase > Short.MAX_VALUE / 2) {
			int countHelp = 0;
			for (;;) {
				if (opDescAdd.pending != 1) {
					return;
				}
				if (++countHelp > NUMOFHELP) {
					opDescAdd.phase = 0;
					break;
				}
				helpEnqReal(opDescAdd);
			}
		}

		while (opDescAdd.pending == 1) {
			helpEnqReal(opDescAdd);
		}
	}

	/**
	 * Makes one attempt to link the descriptor's node behind the current tail.
	 *
	 * <p>If the tail snapshot is stale nothing happens. If the tail already has
	 * a successor, the in-flight enqueue is completed instead. Otherwise, while
	 * the descriptor is still pending, the node is CAS-linked as the tail's
	 * successor and the enqueue is completed.
	 *
	 * @param opDescAdd the enqueue descriptor whose node should be linked
	 */
	private void helpEnqReal(OpDescAdd opDescAdd) {
		Node observedTail = tail;
		Node successor = observedTail.next;
		if (observedTail != tail) {
			// Snapshot is out of date; the caller's loop will retry.
			return;
		}
		if (successor != null) {
			// Another node is already linked but the tail lags behind it.
			helpFinishEnq();
			return;
		}
		if (opDescAdd.pending == 1 && observedTail.casNext(null, opDescAdd.node)) {
			enqNext.getAndIncrement();
			helpFinishEnq();
		}
	}

	/**
	 * Completes an enqueue whose node is already linked behind the tail: clears
	 * the pending flag of that node's enqueue descriptor, then swings
	 * {@code tail} forward to the node.
	 *
	 * <p>Does nothing when the tail has no successor.
	 */
	private void helpFinishEnq() {
		Node last = tail;
		Node next = last.next;
		if (next == null) {
			return;
		}
		// helpDeqReal replaces a node's opDesc with an OpDescPoll through
		// casOpDesc, so a stale snapshot can see a poll descriptor here. An
		// unconditional cast to OpDescAdd would then throw ClassCastException.
		// In that case there is no enqueue flag left to clear, and the tail
		// re-check below decides whether the CAS is still attempted.
		OpDesc desc = next.opDesc;
		if (desc instanceof OpDescAdd) {
			OpDescAdd curDesc = (OpDescAdd) desc;
			if (curDesc.pending == 1 && curDesc.casPending(1, 0)) {
				enqPend.getAndIncrement();
			}
		}
		if (last == tail && casTail(last, next)) {
			enqTail.getAndIncrement();
		}
	}

	/**
	 * Removes and returns the value at the front of the queue.
	 *
	 * <p>A pending dequeue descriptor is published in the calling thread's
	 * table slot, with a phase one above the current maximum (narrowed to
	 * {@code short}). The thread then runs the helping pass for that phase and
	 * reads the outcome from the descriptor.
	 *
	 * @return the value stored in the successor of the node recorded in the
	 *         descriptor, or {@code null} when no node was recorded
	 */
	public E poll() {
		final int slot = threadLocal.get();
		final short myPhase = (short) (maxPhase() + 1);
		final OpDescPoll request = new OpDescPoll(myPhase, 1, false, null);
		setArrayAt(slot, request);
		help(myPhase);
		final Node claimed = request.pNode.node;
		return claimed == null ? null : claimed.next.value;
	}

	/**
	 * Counts nodes by walking the list from {@code head}.
	 *
	 * <p>A node is counted when it has a successor and its descriptor is an
	 * {@code OpDescAdd}. The final node of the chain is never counted. The walk
	 * takes no snapshot, so the result reflects whatever links are observed
	 * while traversing.
	 *
	 * @return the number of counted nodes
	 */
	public int size() {
		int total = 0;
		Node cursor = this.head;
		for (Node successor = cursor.next; successor != null; successor = cursor.next) {
			if (cursor.opDesc instanceof OpDescAdd) {
				++total;
			}
			cursor = successor;
		}
		return total;
	}

	/**
	 * Drives the given dequeue until its descriptor is no longer pending.
	 *
	 * <p>When the descriptor's phase exceeds {@code Short.MAX_VALUE / 2}, at
	 * most {@code NUMOFHELP} attempts are made first. If the operation is still
	 * pending after that, its phase is reset to 0 and helping continues without
	 * a bound.
	 *
	 * @param opDescPoll the dequeue descriptor to help
	 */
	private void helpDeq(OpDescPoll opDescPoll) {
		if (opDescPoll.phase > Short.MAX_VALUE / 2) {
			int attempts = 0;
			while (true) {
				if (opDescPoll.pNode.pending != 1) {
					return;
				}
				if (++attempts > NUMOFHELP) {
					opDescPoll.phase = 0;
					break;
				}
				helpDeqReal(opDescPoll);
			}
		}
		while (opDescPoll.pNode.pending == 1) {
			helpDeqReal(opDescPoll);
		}
	}

	/**
	 * Makes one attempt to advance the dequeue described by {@code opDescPoll}.
	 *
	 * <p>Depending on the observed head/tail snapshot this either records an
	 * empty result, helps a lagging enqueue, or tries to claim the current head
	 * node for this dequeue and then finish it.
	 *
	 * @param opDescPoll the dequeue descriptor being helped
	 */
	private void helpDeqReal(OpDescPoll opDescPoll) {
		// Snapshot head, tail and head's successor, in that order.
		Node first = head;
		Node last = tail;
		Node next = first.next;
		// Only act when head has not moved since the snapshot.
		if (first == head) {
			if (first == last) {
				if (next == null) {
					// Head and tail coincide and nothing is linked behind them.
					OpDescPoll.PNode curPNode = opDescPoll.pNode;
					if (last == tail && curPNode.pending == 1) {
						// Record "no node" and clear pending in a single CAS.
						if (opDescPoll.casPNode(curPNode, new OpDescPoll.PNode<>(null, 0))){
							deqNull.getAndIncrement();
						}
					}
				} else {
					// A node is linked behind the tail but the tail has not
					// been advanced yet; finish that enqueue first.
					helpFinishEnq();
				}
			} else {
				OpDescPoll.PNode curPNode = opDescPoll.pNode;
				// Nothing left to do once the dequeue is no longer pending.
				if (curPNode.pending == 0)
					return;
				if (first == head) {
					OpDesc op = first.opDesc;
					// Proceed only while the head node still carries an
					// enqueue descriptor.
					if (op instanceof OpDescAdd) {
						if (first != curPNode.node) {
							// Try to record the head node in the descriptor, keeping
							// it pending. The CAS result is ignored: the node is
							// re-read from the descriptor whether or not it succeeded.
							opDescPoll.casPNode(curPNode, new OpDescPoll.PNode<>(first, 1));
							first = opDescPoll.pNode.node;
							if (first == null)
								return;
							op = first.opDesc;
						}
						if (op instanceof OpDescAdd){
							// Claim the node by swapping its descriptor for this dequeue's.
							if (first.casOpDesc(op, opDescPoll)){
								deqNode.getAndIncrement();
							}
						}
					}
					// Clear pending and advance head for whichever dequeue
					// currently owns the head node.
					helpFinishDeq();
				}
			}
		}
	}

	/**
	 * Completes the dequeue that has claimed the current head node, if any.
	 *
	 * <p>When the head node's descriptor is an {@code OpDescPoll}, its pending
	 * flag is cleared (keeping the recorded node) and {@code head} is advanced
	 * to the successor. Does nothing when the head node carries any other
	 * descriptor.
	 */
	private void helpFinishDeq() {
		// Snapshot head, its successor and its descriptor, in that order.
		Node first = head;
		Node next = first.next;
		OpDesc opDesc = first.opDesc;
		if (opDesc != null && opDesc instanceof OpDescPoll) {
			OpDescPoll opDescPoll = (OpDescPoll) opDesc;
			OpDescPoll.PNode curPNode = opDescPoll.pNode;

			// Replacement state: same node, pending cleared.
			OpDescPoll.PNode newAgg = new OpDescPoll.PNode<>(curPNode.node, 0);
			if (curPNode.pending == 1 /* && curPNode.node == first */){
				if (opDescPoll.casPNode(curPNode, newAgg)){
					deqPend.getAndIncrement();
				}
			}
			// Advance head only if it has not moved and a successor exists.
			if (first == head && next != null) {
				if(casHead(first, next)){
					deqHead.getAndIncrement();
				}
			}
		}
	}

	/**
	 * Atomically sets {@code tail} to {@code val} if it currently equals {@code cmp}.
	 *
	 * @param cmp the expected current tail
	 * @param val the new tail
	 * @return {@code true} if the swap happened
	 */
	boolean casTail(Node cmp, Node val) {
		return UNSAFE.compareAndSwapObject(this, tailOffset, cmp, val);
	}

	/**
	 * Atomically sets {@code head} to {@code val} if it currently equals {@code cmp}.
	 *
	 * @param cmp the expected current head
	 * @param val the new head
	 * @return {@code true} if the swap happened
	 */
	boolean casHead(Node cmp, Node val) {
		return UNSAFE.compareAndSwapObject(this, headOffset, cmp, val);
	}

	/**
	 * Stores {@code v} into slot {@code i} of the descriptor table with
	 * volatile-write semantics.
	 *
	 * <p>The element address is computed directly from {@code ABASE} and
	 * {@code ASHIFT}; no bounds check is performed here.
	 *
	 * @param i the slot index
	 * @param v the descriptor to store
	 */
	final void setArrayAt(int i, OpDesc v) {
		UNSAFE.putObjectVolatile(state, ((long) i << ASHIFT) + ABASE, v);
	}

	/**
	 * Reads slot {@code i} of the descriptor table with volatile-read semantics.
	 *
	 * <p>The element address is computed directly from {@code ABASE} and
	 * {@code ASHIFT}; no bounds check is performed here.
	 *
	 * @param i the slot index
	 * @return the descriptor stored in that slot, or {@code null} if it is empty
	 */
	@SuppressWarnings("unchecked")
	final OpDesc getArrayAt(int i) {
		return (OpDesc) UNSAFE.getObjectVolatile(state, ((long) i << ASHIFT) + ABASE);
	}

	/** Ends of the linked list; both are updated through CAS (casHead / casTail). */
	public volatile Node head, tail;
	/** Cap on helping attempts in helpEnq / helpDeq before a descriptor's phase is reset to 0. */
	private static final int NUMOFHELP = 8;

	/** Descriptor table, indexed by the per-thread id handed out by {@code threadLocal}. */
	final OpDesc[] state;
	/** Per-thread id: the next value of {@code atoIntger}, taken on the thread's first access. */
	final ThreadLocal threadLocal = new ThreadLocal() {
		protected Integer initialValue() {
			return atoIntger.getAndIncrement();
		}
	};
	/** Per-thread random int, drawn once on first access; help() derives its start slot from it. */
	final ThreadLocal threadLocalRandom = new ThreadLocal() {
		protected Integer initialValue() {
			return new Random().nextInt();
		}
	};

	/** Number of thread ids handed out so far; help() uses it as the scan bound. */
	private final AtomicInteger atoIntger = new AtomicInteger();

	// Counters of successful CAS steps on the enqueue path:
	// enqPend - pending flag cleared, enqNext - node linked, enqTail - tail advanced.
	public AtomicInteger enqPend = new AtomicInteger();
	public AtomicInteger enqNext = new AtomicInteger();
	public AtomicInteger enqTail = new AtomicInteger();
	
	// Counters of successful CAS steps on the dequeue path:
	// deqPend - pending flag cleared, deqNode - node claimed, deqHead - head advanced.
	public AtomicInteger deqPend = new AtomicInteger();
	public AtomicInteger deqNode = new AtomicInteger();
	public AtomicInteger deqHead = new AtomicInteger();
	
	/** Counts dequeues that recorded an empty result. */
	public AtomicInteger deqNull = new AtomicInteger();

	/**
	 * Linked-list node: an immutable value, a CAS-updated link to the next node,
	 * and a CAS-updated reference to an operation descriptor.
	 */
	private static class Node {
		final E value;
		volatile Node next;
		// Starts as a placeholder enqueue descriptor; the OpDescAdd constructor
		// overwrites it for a node passed to it, and helpDeqReal may later CAS
		// it to a dequeue descriptor.
		volatile OpDesc opDesc;

		/**
		 * Creates an unlinked node whose descriptor is a fresh no-argument
		 * {@code OpDescAdd} (phase 0, not pending).
		 */
		Node(E val) {
			value = val;
			next = null;
			opDesc = new OpDescAdd();
		}

		/** Atomically sets {@code next} to {@code val} if it currently equals {@code cmp}. */
		boolean casNext(Node cmp, Node val) {
			return UNSAFE.compareAndSwapObject(this, nextOffset, cmp, val);
		}

		/** Atomically sets {@code opDesc} to {@code val} if it currently equals {@code cmp}. */
		boolean casOpDesc(OpDesc cmp, OpDesc val) {
			return UNSAFE.compareAndSwapObject(this, opDescOffset, cmp, val);
		}

		private static final sun.misc.Unsafe UNSAFE;
		private static final long nextOffset;
		private static final long opDescOffset;
		// Resolve the field offsets used by the CAS helpers above; the field
		// names are looked up reflectively and must match the declarations.
		static {
			try {
				UNSAFE = UtilUnsafe.getUnsafe();
				nextOffset = UNSAFE.objectFieldOffset(Node.class.getDeclaredField("next"));
				opDescOffset = UNSAFE.objectFieldOffset(Node.class.getDeclaredField("opDesc"));
			} catch (Exception e) {
				throw new Error(e);
			}
		}
	}

	/**
	 * Base operation descriptor: a mutable phase plus a flag set at construction.
	 */
	static class OpDesc {
		// Written by helpEnq / helpDeq (reset to 0) in addition to the constructor.
		volatile short phase;
		// add() passes true and poll() passes false; the no-argument OpDescAdd
		// constructor also passes false. Not read anywhere in the visible source.
		final boolean enqueue;

		OpDesc(short ph, boolean enq) {
			phase = ph;
			enqueue = enq;
		}
	}

	/**
	 * Descriptor of an enqueue: the node to link and a pending flag
	 * (1 = still pending, 0 = done) that is cleared through CAS.
	 */
	static final class OpDescAdd extends OpDesc {
		final Node node;
		volatile int pending;

		/** Creates a placeholder descriptor: phase 0, not pending, no node. */
		OpDescAdd() {
			this((short) 0, 0, false, null);
		}

		/**
		 * Creates a descriptor and, when {@code n} is non-null, installs it as
		 * that node's {@code opDesc}.
		 *
		 * @param ph   the operation's phase
		 * @param pend initial pending flag
		 * @param enq  value for the {@code enqueue} flag
		 * @param n    the node to enqueue, may be {@code null}
		 */
		OpDescAdd(short ph, int pend, boolean enq, Node n) {
			super(ph, enq);
			pending = pend;
			node = n;
			// Back-link: replaces the placeholder descriptor the Node
			// constructor installed.
			if (n != null)
				n.opDesc = this;
		}

		/** Atomically sets {@code pending} to {@code val} if it currently equals {@code cmp}. */
		boolean casPending(int cmp, int val) {
			return UNSAFE.compareAndSwapInt(this, pendingOffsetAdd, cmp, val);
		}

		private static final sun.misc.Unsafe UNSAFE;
		private static final long pendingOffsetAdd;
		// Resolve the offset of "pending" for casPending.
		static {
			try {
				UNSAFE = UtilUnsafe.getUnsafe();
				pendingOffsetAdd = UNSAFE.objectFieldOffset(OpDescAdd.class.getDeclaredField("pending"));
			} catch (Exception e) {
				throw new Error(e);
			}
		}
	}

	/**
	 * Descriptor of a dequeue. Its state (recorded node plus pending flag) lives
	 * in an immutable {@code PNode} that is replaced as a whole through CAS.
	 */
	static final class OpDescPoll extends OpDesc {
		volatile PNode pNode;

		/**
		 * @param ph   the operation's phase
		 * @param pend initial pending flag
		 * @param enq  value for the {@code enqueue} flag
		 * @param n    initially recorded node, may be {@code null}
		 */
		OpDescPoll(short ph, int pend, boolean enq, Node n) {
			super(ph, enq);
			pNode = new PNode(n, pend);
		}

		/** Immutable (node, pending) pair, so both values change together in one CAS. */
		static final class PNode {
			public PNode(Node node, int pending) {
				this.node = node;
				this.pending = pending;
			}

			final Node node;
			final int pending;
		}

		/** Atomically sets {@code pNode} to {@code val} if it currently equals {@code cmp}. */
		boolean casPNode(PNode cmp, PNode val) {
			return UNSAFE.compareAndSwapObject(this, pNodeOffset, cmp, val);
		}

		private static final sun.misc.Unsafe UNSAFE;
		private static final long pNodeOffset;
		// Resolve the offset of "pNode" for casPNode.
		static {
			try {
				UNSAFE = UtilUnsafe.getUnsafe();
				pNodeOffset = UNSAFE.objectFieldOffset(OpDescPoll.class.getDeclaredField("pNode"));
			} catch (Exception e) {
				throw new Error(e);
			}
		}
	}

	/** Static holder for obtaining the {@code sun.misc.Unsafe} singleton. */
	private static class UtilUnsafe {
		private UtilUnsafe() {
		}

		/**
		 * Returns the {@code Unsafe} instance.
		 *
		 * <p>When this class has no class loader, the official accessor is
		 * used. Otherwise the private static {@code theUnsafe} field is read
		 * reflectively.
		 *
		 * @return the {@code Unsafe} singleton
		 * @throws RuntimeException if the reflective lookup fails
		 */
		public static Unsafe getUnsafe() {
			final boolean bootLoaded = UtilUnsafe.class.getClassLoader() == null;
			if (!bootLoaded) {
				try {
					final Field holder = Unsafe.class.getDeclaredField("theUnsafe");
					holder.setAccessible(true);
					return (Unsafe) holder.get(UtilUnsafe.class);
				} catch (Exception e) {
					throw new RuntimeException("Could not obtain access to sun.misc.Unsafe", e);
				}
			}
			return Unsafe.getUnsafe();
		}
	}

	// Unsafe handle plus the raw offsets used by casHead / casTail and by the
	// descriptor-table accessors setArrayAt / getArrayAt.
	private static final sun.misc.Unsafe UNSAFE;
	private static final long headOffset;
	private static final long tailOffset;
	// Base offset and per-element scale of OpDesc[] as reported by Unsafe.
	private static final int _Obase;
	private static final int _Oscale;

	// Element address = (index << ASHIFT) + ABASE.
	private static final long ABASE;
	private static final int ASHIFT;

	static {
		try {
			UNSAFE = UtilUnsafe.getUnsafe();
			headOffset = UNSAFE.objectFieldOffset(WaitFreeQueueScan.class.getDeclaredField("head"));
			tailOffset = UNSAFE.objectFieldOffset(WaitFreeQueueScan.class.getDeclaredField("tail"));
			_Obase = UNSAFE.arrayBaseOffset(OpDesc[].class);
			_Oscale = UNSAFE.arrayIndexScale(OpDesc[].class);

			ABASE = _Obase;
			// The shift-based addressing only works for a power-of-two scale.
			if ((_Oscale & (_Oscale - 1)) != 0)
				throw new Error("data type scale not a power of two");
			// log2 of the scale.
			ASHIFT = 31 - Integer.numberOfLeadingZeros(_Oscale);
		} catch (Exception e) {
			throw new Error(e);
		}
	}
}
Clone this wiki locally