Created February 26, 2022 10:38
-
-
Save shenfeng/b08b69b705ac0dca44bcabb09b8460c7 to your computer and use it in GitHub Desktop.
Lock-free concurrent queue in Java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.lang.invoke.MethodHandles; | |
import java.lang.invoke.VarHandle; | |
/**
 * An unbounded, lock-free, multi-producer/multi-consumer FIFO queue based on the
 * Michael-Scott algorithm.
 *
 * <p>The queue is a singly linked list that always begins with a sentinel node: {@code head}
 * points at the sentinel and the oldest element lives in {@code head.next}. {@code tail}
 * points at the last node or, transiently, at its predecessor while an {@link #offer} has
 * linked its node but not yet swung {@code tail}; any thread that observes such a lagging
 * tail helps advance it.
 *
 * <p>{@code null} elements are rejected because {@link #poll} returns {@code null} to signal
 * an empty queue.
 *
 * @param <E> the element type
 */
public class LockFreeQueue<E> {
    private static final VarHandle NEXT;
    private static final VarHandle HEAD;
    private static final VarHandle TAIL;

    static {
        try {
            MethodHandles.Lookup l = MethodHandles.lookup();
            NEXT = l.findVarHandle(Node.class, "next", Node.class);
            HEAD = l.findVarHandle(LockFreeQueue.class, "head", Node.class);
            TAIL = l.findVarHandle(LockFreeQueue.class, "tail", Node.class);
        } catch (ReflectiveOperationException e) {
            throw new ExceptionInInitializerError(e);
        }
    }

    /** A list node; {@code next} only ever changes from {@code null} to a successor via CAS. */
    static final class Node<E> {
        // Plain field: written in the constructor before the node is published by the
        // volatile CAS on its predecessor's `next`, so readers that reach it via `next`
        // see the value. poll() clears it once the node becomes the sentinel.
        E item;
        volatile Node<E> next;

        Node() {
        }

        Node(E item) {
            this.item = item;
        }

        boolean casNext(Node<E> o, Node<E> n) {
            return NEXT.compareAndSet(this, o, n);
        }
    }

    private volatile Node<E> head;
    private volatile Node<E> tail;

    /** Creates an empty queue consisting only of the sentinel node. */
    public LockFreeQueue() {
        head = tail = new Node<>();
    }

    /**
     * Appends {@code e} at the tail of the queue.
     *
     * @param e the element to add
     * @throws NullPointerException if {@code e} is {@code null}
     */
    public void offer(E e) {
        if (e == null) {
            // poll() uses null to mean "empty", so a null element would be indistinguishable.
            throw new NullPointerException("LockFreeQueue does not permit null elements");
        }
        Node<E> node = new Node<>(e);
        while (true) {
            Node<E> t = tail;
            Node<E> n = t.next;
            if (n == null) {
                // t appears to be the last node: link after it. Success makes the element visible.
                if (t.casNext(null, node)) {
                    // Best effort: if this fails, another thread has already advanced tail.
                    TAIL.compareAndSet(this, t, node);
                    return;
                }
            } else {
                // tail lags behind a node another offer linked; help advance it, then retry.
                TAIL.compareAndSet(this, t, n);
            }
        }
    }

    /**
     * Removes and returns the oldest element.
     *
     * @return the head element, or {@code null} if the queue is empty
     */
    public E poll() {
        while (true) {
            Node<E> h = head;
            Node<E> t = tail;
            Node<E> n = h.next;
            if (h == t) {
                if (n == null) {
                    return null; // only the sentinel remains
                }
                // A node is already linked but tail has not been swung yet, so the queue is
                // NOT empty. Help advance tail (keeping head from overtaking it) and retry.
                TAIL.compareAndSet(this, t, n);
            } else {
                // h != t means tail has moved past h, so h's successor n is already linked.
                E e = n.item;
                if (HEAD.compareAndSet(this, h, n)) {
                    // n is now the sentinel: drop its element reference so it can be collected.
                    // Any racing poller that read this item holds a stale h and its CAS fails.
                    n.item = null;
                    return e;
                }
            }
        }
    }
}
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Assert;
import org.junit.Test;
public class LockFreeQueueTest { | |
final static int N = 200000; | |
@Test | |
public void testSingleThread() { | |
LockFreeQueue<Integer> q = new LockFreeQueue<>(); | |
for (int i = 0; i < N; i++) { | |
q.offer(i); | |
Assert.assertEquals(i, q.poll().intValue()); | |
Assert.assertNull(q.poll()); | |
} | |
} | |
@Test | |
public void TestJDK() { | |
ConcurrentLinkedQueue<Integer> q = new ConcurrentLinkedQueue<>(); | |
q.offer(1); | |
q.offer(2); | |
q.offer(3); | |
q.offer(4); | |
} | |
@Test | |
public void testTwoThread() throws InterruptedException { | |
LockFreeQueue<Integer> q = new LockFreeQueue<>(); | |
// ConcurrentLinkedQueue<Integer> q = new ConcurrentLinkedQueue<>(); | |
int N = 1000000; | |
for (int w = 0; w < 10; w++) { | |
Thread t = new Thread(() -> { | |
for (int i = 0; i < N; i++) { | |
while (true) { | |
Integer v = q.poll(); | |
if (v != null) { | |
Assert.assertEquals(v.intValue(), i); | |
break; | |
} | |
} | |
} | |
}); | |
t.start(); | |
new Thread(() -> { | |
for (int i = 0; i < N; i++) { | |
q.offer(i); | |
} | |
}).start(); | |
t.join(); | |
System.out.println(); | |
} | |
} | |
@Test | |
public void testNThread() throws InterruptedException { | |
AtomicIntegerArray array = new AtomicIntegerArray(N); | |
for (int i = 0; i < 8; i++) { | |
testMPMC(i + 1, array); | |
} | |
} | |
private void testMPMC(int n, AtomicIntegerArray array) throws InterruptedException { | |
for (int i = 0; i < N; i++) { | |
array.set(i, 0); | |
} | |
List<Thread> threads = new ArrayList<>(); | |
LockFreeQueue<Integer> q = new LockFreeQueue<>(); | |
// ConcurrentLinkedQueue<Integer> q = new ConcurrentLinkedQueue<>(); | |
for (int i = 0; i < n; i++) { | |
Thread c = new Thread(() -> { // consumer | |
for (int j = 0; j < N; j++) { | |
while (true) { | |
Integer v = q.poll(); | |
if (v != null) { | |
array.incrementAndGet(v); | |
break; | |
} | |
} | |
} | |
}); | |
threads.add(c); | |
Thread t = new Thread(() -> { // producer | |
for (int j = 0; j < N; j++) { | |
q.offer(j); | |
} | |
}); | |
threads.add(t); | |
t.start(); | |
c.start(); | |
} | |
for (Thread t : threads) { | |
t.join(); | |
} | |
for (int i = 0; i < N; i++) { | |
if (n != array.getPlain(i)) { | |
Assert.fail("idx: " + i + ", expect: " + n + ", get: " + array.getPlain(i)); | |
} | |
} | |
} | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.