Skip to content

Instantly share code, notes, and snippets.

@koush
Created January 7, 2012 22:27
Show Gist options
  • Save koush/1576282 to your computer and use it in GitHub Desktop.
Save koush/1576282 to your computer and use it in GitHub Desktop.
package com.koushikdutta.nio;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import java.util.LinkedList;
import junit.framework.Assert;
/**
 * FIFO of {@link ByteBuffer}s holding bytes received from a channel but not yet consumed.
 * Every buffer in the list is kept in "read mode": the bytes between position and limit are
 * unread, and the space between limit and capacity of the last buffer is where the next
 * {@link #writeFromSocket} call appends. No synchronization is performed.
 */
public class ByteBufferList {
    /** Capacity of each buffer allocated to receive socket data (2 KiB). */
    private static final int CHUNK_SIZE = 2 << 10;

    LinkedList<ByteBuffer> mBuffers = new LinkedList<ByteBuffer>();

    public ByteBufferList() {
    }

    /** Returns the total number of unread bytes across all buffers. */
    public int getAvailable() {
        int ret = 0;
        for (ByteBuffer bb : mBuffers) {
            ret += bb.remaining();
        }
        return ret;
    }

    /**
     * Returns a buffer, positioned at the next unread byte, with at least {@code count} bytes
     * remaining. The returned buffer stays at the head of the list, so consuming bytes from it
     * (advancing its position) consumes them from this list. When the head buffer alone is too
     * short, the first {@code count} bytes are copied into a new buffer that replaces them at the
     * head.
     *
     * @param count number of bytes the caller intends to consume; must satisfy
     *     {@code 0 <= count <= getAvailable()}
     * @return a buffer with at least {@code count} bytes remaining (possibly more); a fresh empty
     *     buffer when {@code count} is 0 and nothing is buffered
     * @throws IllegalArgumentException if {@code count} is negative or exceeds {@link #getAvailable()}
     */
    public ByteBuffer read(int count) {
        int available = getAvailable();
        if (count < 0 || count > available) {
            throw new IllegalArgumentException(
                    "requested " + count + " bytes but only " + available + " are available");
        }

        // Drop fully consumed buffers from the head.
        while (!mBuffers.isEmpty() && !mBuffers.peek().hasRemaining()) {
            mBuffers.remove();
        }
        if (mBuffers.isEmpty()) {
            // Only reachable when count == 0: everything buffered has already been consumed.
            return ByteBuffer.allocate(0);
        }

        ByteBuffer first = mBuffers.peek();
        if (first.remaining() >= count) {
            return first;
        }

        // The request spans several buffers: coalesce the first `count` bytes into one array.
        byte[] bytes = new byte[count];
        int offset = 0;
        ByteBuffer last = null;
        while (offset < count) {
            last = mBuffers.remove();
            int toRead = Math.min(count - offset, last.remaining());
            last.get(bytes, offset, toRead);
            offset += toRead;
        }
        // If the last buffer we popped still holds data, put it back at the head.
        if (last.hasRemaining()) {
            mBuffers.addFirst(last);
        }
        ByteBuffer ret = ByteBuffer.wrap(bytes);
        mBuffers.addFirst(ret);
        return ret;
    }

    /**
     * Returns the buffer new socket data should be appended to: the last buffer if it still has
     * spare capacity, otherwise a newly allocated, empty (limit 0) buffer appended to the list.
     */
    private ByteBuffer tail(int alloc) {
        ByteBuffer bb = null;
        if (mBuffers.size() > 0) {
            bb = mBuffers.getLast();
            if (bb.limit() == bb.capacity()) {
                bb = null;
            }
        }
        if (bb == null) {
            bb = ByteBuffer.allocate(alloc);
            bb.limit(0);
            mBuffers.add(bb);
        }
        return bb;
    }

    /**
     * Performs a single read from {@code socket}, appending the bytes after the data already
     * buffered.
     *
     * @return the value returned by {@code socket.read} (-1 at end of stream), or -1 if the read
     *     threw. In every case the tail buffer is returned to read mode, so a failed read never
     *     exposes uninitialised bytes through {@link #getAvailable()} or {@link #read(int)}.
     */
    public int writeFromSocket(ReadableByteChannel socket) {
        ByteBuffer bb = tail(CHUNK_SIZE);
        // Remember the read position, then open the unused tail region for writing.
        int readPosition = bb.position();
        bb.position(bb.limit());
        bb.limit(bb.capacity());
        try {
            return socket.read(bb);
        }
        catch (Exception ex) {
            // Any failure is reported as end of stream; callers treat -1 as "close".
            return -1;
        }
        finally {
            // The new read limit is wherever the write stopped; restore the old read position.
            bb.limit(bb.position());
            bb.position(readPosition);
        }
    }

    /**
     * Removes and returns the head buffer.
     *
     * @throws java.util.NoSuchElementException if the list is empty
     */
    public ByteBuffer remove() {
        return mBuffers.remove();
    }

    /** Returns the number of buffers (not bytes) currently held. */
    public int size() {
        return mBuffers.size();
    }
}
package com.koushikdutta.nio;
import java.io.IOException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.DatagramChannel;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
public interface ChannelWrapper extends ReadableByteChannel {
public boolean isConnected();
public int write(ByteBuffer src) throws IOException;
public SelectionKey register(Selector sel, int ops) throws ClosedChannelException;
public void connect(SocketAddress remote) throws IOException;
public boolean isChunked();
}
package com.koushikdutta.nio;
import java.io.IOException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.DatagramChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
/**
 * {@link ChannelWrapper} over a UDP {@link DatagramChannel}. Reads, writes and lifecycle calls
 * forward to the wrapped channel; {@link #isChunked()} reports {@code true}.
 */
public class DatagramChannelWrapper implements ChannelWrapper {
    DatagramChannel mChannel;

    // Key returned by the most recent register() call; connect() notifies its attachment.
    SelectionKey mKey;

    /** Switches {@code channel} to non-blocking mode and wraps it. */
    public static DatagramChannelWrapper from(DatagramChannel channel) throws IOException {
        channel.configureBlocking(false);
        DatagramChannelWrapper wrapper = new DatagramChannelWrapper(channel);
        return wrapper;
    }

    private DatagramChannelWrapper(DatagramChannel channel) {
        this.mChannel = channel;
    }

    // ---- connection lifecycle -------------------------------------------------------------

    /**
     * Connects the datagram channel and then immediately calls {@code onConnected()} on the
     * {@link NonBlockingSocketHandler} attached to the registered key.
     * NOTE(review): requires register() to have been called and a handler to be attached to
     * the key; otherwise this throws NullPointerException after the channel is connected.
     */
    @Override
    public void connect(SocketAddress remote) throws IOException {
        mChannel.connect(remote);
        Object attachment = mKey.attachment();
        ((NonBlockingSocketHandler) attachment).onConnected();
    }

    @Override
    public boolean isConnected() {
        return mChannel.isConnected();
    }

    @Override
    public boolean isOpen() {
        return mChannel.isOpen();
    }

    @Override
    public void close() throws IOException {
        mChannel.close();
    }

    // ---- data transfer --------------------------------------------------------------------

    @Override
    public int read(ByteBuffer buffer) throws IOException {
        return mChannel.read(buffer);
    }

    @Override
    public int write(ByteBuffer src) throws IOException {
        return mChannel.write(src);
    }

    // ---- selector integration -------------------------------------------------------------

    /**
     * Registers for {@code OP_READ} regardless of {@code ops}: a datagram channel cannot be
     * registered for {@code OP_CONNECT}, so read interest is used instead.
     */
    @Override
    public SelectionKey register(Selector sel, int ops) throws ClosedChannelException {
        SelectionKey key = mChannel.register(sel, SelectionKey.OP_READ);
        mKey = key;
        return key;
    }

    /** Datagrams are handled one read at a time. */
    @Override
    public boolean isChunked() {
        return true;
    }
}
package com.koushikdutta.nio;
import java.nio.ByteBuffer;
/**
 * Completion callback for a pending non-blocking read: invoked once the requested number of
 * bytes is available.
 */
public interface NonBlockingReadCallback {
    /**
     * Delivers the buffered data.
     *
     * @param bb buffer positioned at the first requested byte; it may hold more bytes than were
     *     requested
     */
    void onDataAvailable(ByteBuffer bb);
}
package com.koushikdutta.nio;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ClosedSelectorException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.util.Set;
import com.clockworkmod.billing.UpdateTrialCallback;
import junit.framework.Assert;
import android.content.Intent;
import android.sax.StartElementListener;
import android.util.Log;
/**
 * Selector-driven, non-blocking TCP server. {@link #listen} binds a server socket and then runs
 * the dispatch loop on the calling thread until {@link #stop} is called. Each ready key is
 * routed to the {@link NonBlockingSocketHandler} attached to it; accepted connections obtain
 * their handler from {@link #onNewClient}.
 */
public class NonBlockingServer {
    private static final String LOGTAG = "Tether";

    ServerSocketChannel mServer;
    Selector mSelector;

    // Cleared by stop(), which is expected to run on a different thread than the one blocked in
    // listen() (it wakes the selector); volatile so the loop is guaranteed to observe the change.
    volatile boolean mRun = false;

    public NonBlockingServer() {
    }

    /** Returns the selector driving the loop, or null when not listening. */
    public Selector getSelector() {
        return mSelector;
    }

    /**
     * Called for each accepted connection; subclasses return the handler that will service it.
     * The default returns null, in which case the connection is closed immediately.
     */
    protected NonBlockingSocketHandler onNewClient(SocketChannel sc) throws IOException {
        return null;
    }

    /**
     * Registers {@code handler}'s channel with this server's selector (requesting
     * {@code OP_CONNECT}), attaches the handler to the key and stores the key in the handler.
     * NOTE(review): channel registration may block while another thread is inside
     * {@code select()}; callers on other threads may need {@link #wakeup()}.
     *
     * @return the new key, or null if the channel was already closed
     */
    protected SelectionKey handleSocket(NonBlockingSocketHandler handler) {
        try {
            ChannelWrapper sc = handler.getChannel();
            SelectionKey ckey = sc.register(mSelector, SelectionKey.OP_CONNECT);
            ckey.attach(handler);
            handler.mKey = ckey;
            return ckey;
        }
        catch (ClosedChannelException ex) {
            ex.printStackTrace();
            return null;
        }
    }

    /** Interrupts a blocked {@code select()}; a no-op when the server is not listening. */
    public void wakeup() {
        try {
            mSelector.wakeup();
        }
        catch (Exception ignored) {
            // no selector yet (or already torn down): nothing to wake
        }
    }

    /** Asks the {@link #listen} loop to exit and closes the server socket. Safe to call at any time. */
    public void stop() {
        mRun = false;
        Selector selector = mSelector;
        if (selector != null) {
            selector.wakeup();
        }
        closeServer();
    }

    /** Hook receiving the byte count returned by each {@code onReadable()} call. */
    protected void onDataTransmitted(int transmitted) {
    }

    /**
     * Binds to {@code host:port} and runs the dispatch loop on the calling thread until
     * {@link #stop} is called. Returns immediately (after logging) if the socket cannot be bound
     * or the selector cannot be opened. On exit, every channel registered with the selector is
     * closed and the selector is released.
     */
    public void listen(InetAddress host, int port) {
        mRun = true;
        if (!openServer(host, port)) {
            mRun = false;
            closeServer();
            closeSelector();
            return;
        }
        try {
            while (mRun) {
                try {
                    mSelector.select();
                    dispatchSelectedKeys(mSelector.selectedKeys());
                }
                catch (ClosedSelectorException ex) {
                    // The selector was closed underneath us; nothing more can be selected.
                    break;
                }
                catch (Exception e) {
                    Log.i(LOGTAG, "exception?");
                    e.printStackTrace();
                }
            }
        }
        finally {
            closeSelector();
            closeServer();
        }
    }

    /** Opens, binds and registers the server socket and selector; returns false on failure. */
    private boolean openServer(InetAddress host, int port) {
        try {
            mServer = ServerSocketChannel.open();
            mServer.configureBlocking(false);
            InetSocketAddress isa = new InetSocketAddress(host, port);
            mServer.socket().bind(isa);
            mSelector = SelectorProvider.provider().openSelector();
            mServer.register(mSelector, SelectionKey.OP_ACCEPT);
            return true;
        }
        catch (ClosedSelectorException ex) {
            Log.i(LOGTAG, "unable to listen on port " + port);
            ex.printStackTrace();
            return false;
        }
        catch (IOException ex) {
            Log.i(LOGTAG, "unable to listen on port " + port);
            ex.printStackTrace();
            return false;
        }
    }

    /**
     * Dispatches every selected key, isolating failures per key, and always empties the
     * selected-key set: the selector never removes keys from it itself, so leftovers would be
     * re-dispatched on the next pass.
     */
    private void dispatchSelectedKeys(Set<SelectionKey> readyKeys) {
        try {
            for (SelectionKey key : readyKeys) {
                try {
                    dispatch(key);
                }
                catch (CancelledKeyException ex) {
                    // The key was cancelled (e.g. its handler closed) before it was handled; skip it.
                }
                catch (Exception ex) {
                    Log.i(LOGTAG, "exception?");
                    ex.printStackTrace();
                }
            }
        }
        finally {
            readyKeys.clear();
        }
    }

    /** Routes one ready key to the matching accept/read/write/connect handling. */
    private void dispatch(SelectionKey key) throws IOException {
        if (key.isAcceptable()) {
            acceptClient((ServerSocketChannel) key.channel());
        }
        else if (key.isReadable()) {
            NonBlockingSocketHandler handler = (NonBlockingSocketHandler) key.attachment();
            int transmitted = handler.onReadable();
            // report data stats if not purchased
            onDataTransmitted(transmitted);
        }
        else if (key.isWritable()) {
            NonBlockingSocketHandler handler = (NonBlockingSocketHandler) key.attachment();
            handler.onDataWritable();
        }
        else if (key.isConnectable()) {
            finishConnect(key);
        }
        else {
            Log.i(LOGTAG, "wtf");
            Assert.fail();
        }
    }

    /**
     * Accepts one pending connection and registers it for reads with the handler from
     * {@link #onNewClient}. The socket is closed if no handler is provided or setup fails, so no
     * unserviced channel is left registered with the selector.
     */
    private void acceptClient(ServerSocketChannel server) throws IOException {
        SocketChannel sc = server.accept();
        if (sc == null) {
            return; // no connection was actually pending
        }
        boolean registered = false;
        try {
            sc.configureBlocking(false);
            NonBlockingSocketHandler handler = onNewClient(sc);
            if (handler == null) {
                return; // nobody to service this client; closed in finally
            }
            SelectionKey ckey = sc.register(mSelector, SelectionKey.OP_READ);
            ckey.attach(handler);
            handler.mKey = ckey;
            registered = true;
        }
        finally {
            if (!registered) {
                closeQuietly(sc);
            }
        }
    }

    /**
     * Completes an outgoing connection, switching the key to read interest. On failure the key
     * is cancelled, the channel closed and the handler (if any) told via onConnectFailed();
     * an exception thrown by onConnected() is treated the same way.
     */
    private void finishConnect(SelectionKey key) {
        NonBlockingSocketHandler handler = (NonBlockingSocketHandler) key.attachment();
        SocketChannel sc = (SocketChannel) key.channel();
        key.interestOps(SelectionKey.OP_READ);
        try {
            sc.finishConnect();
            if (handler != null) {
                handler.onConnected();
            }
        }
        catch (Exception ex) {
            key.cancel();
            closeQuietly(sc);
            if (handler != null) {
                handler.onConnectFailed();
            }
        }
    }

    /** Closes the server socket, if any, and forgets it. */
    private void closeServer() {
        ServerSocketChannel server = mServer;
        mServer = null;
        if (server != null) {
            closeQuietly(server);
        }
    }

    /** Closes every channel registered with the selector, then the selector itself. */
    private void closeSelector() {
        Selector selector = mSelector;
        if (selector == null) {
            return;
        }
        try {
            for (SelectionKey key : selector.keys()) {
                closeQuietly(key.channel());
            }
        }
        catch (ClosedSelectorException ex) {
            // already closed: its key set is no longer reachable
        }
        try {
            selector.close();
        }
        catch (IOException ignored) {
            // best effort during teardown
        }
        mSelector = null;
    }

    /** Closes {@code channel}, ignoring any IOException (teardown is best effort). */
    private static void closeQuietly(java.nio.channels.Channel channel) {
        try {
            channel.close();
        }
        catch (IOException ignored) {
            // best effort during teardown
        }
    }
}
package com.koushikdutta.nio;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.channels.DatagramChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.LinkedList;
import junit.framework.Assert;
import android.graphics.AvoidXfermode;
import android.util.Log;
/**
 * Per-connection handler driven by {@link NonBlockingServer}. Incoming bytes are buffered in a
 * {@link ByteBufferList} and delivered either to a pending {@link #read} request or, when none
 * is pending, to {@link #onDataAvailable}. Outgoing buffers are queued until the channel accepts
 * them, with {@code OP_WRITE} interest requested while data is left over. Subclasses override the
 * {@code on*} hooks.
 */
public class NonBlockingSocketHandler {
    private static final String LOGTAG = "Tether";

    private ChannelWrapper mChannel;
    SelectionKey mKey;
    LinkedList<ByteBuffer> mPendingWrites = new LinkedList<ByteBuffer>();

    public NonBlockingSocketHandler(ChannelWrapper channel) {
        mChannel = channel;
    }

    /** Wraps {@code channel} (switching it to non-blocking mode) and handles it. */
    public NonBlockingSocketHandler(SocketChannel channel) throws IOException {
        this(SocketChannelWrapper.from(channel));
    }

    /** Wraps {@code channel} (switching it to non-blocking mode) and handles it. */
    public NonBlockingSocketHandler(DatagramChannel channel) throws IOException {
        this(DatagramChannelWrapper.from(channel));
    }

    public ChannelWrapper getChannel() {
        return mChannel;
    }

    /** Hook: the channel finished connecting. */
    public void onConnected() {
    }

    /** Hook: the peer closed the connection (called from {@link #onReadable}). */
    public void onClosed() {
    }

    /** Called when the selector reports the channel writable; flushes queued writes. */
    public void onDataWritable() {
        write(null);
    }

    public void writeBytes(byte[] bytes) {
        write(ByteBuffer.wrap(bytes));
    }

    public void writeBytes(byte[] bytes, int offset, int count) {
        write(ByteBuffer.wrap(bytes, offset, count));
    }

    /**
     * Queues {@code add} (when non-null) and, if the channel is connected, writes queued buffers
     * until the queue is empty or the channel stops accepting data. In the latter case
     * {@code OP_WRITE} interest is requested so {@link #onDataWritable} resumes the flush.
     * The buffer is retained, not copied, so callers must not modify it afterwards.
     * If no selection key has been assigned yet, interest ops are left alone; any leftover data
     * is retried on the next write() call.
     */
    public void write(ByteBuffer add) {
        try {
            if (add != null)
                mPendingWrites.add(add);
            if (!mChannel.isConnected())
                return;
            // Start from read-only interest; OP_WRITE is re-armed below only if the socket fills up.
            setInterest(SelectionKey.OP_READ);
            while (!mPendingWrites.isEmpty()) {
                ByteBuffer bb = mPendingWrites.peek();
                mChannel.write(bb);
                if (bb.hasRemaining()) {
                    // The channel would not take everything: wait for a writable notification.
                    setInterest(SelectionKey.OP_READ | SelectionKey.OP_WRITE);
                    break;
                }
                mPendingWrites.remove();
            }
        }
        catch (IOException ex) {
            ex.printStackTrace();
        }
    }

    /** Sets the key's interest ops when a valid key has been assigned; otherwise does nothing. */
    private void setInterest(int ops) {
        SelectionKey key = mKey;
        if (key != null && key.isValid())
            key.interestOps(ops);
    }

    boolean hasPendingWrites() {
        return !mPendingWrites.isEmpty();
    }

    private OutputStream mStream = new OutputStream() {
        @Override
        public void write(int oneByte) throws IOException {
            NonBlockingSocketHandler.this.writeBytes(new byte[] { (byte)oneByte });
        }

        @Override
        public void write(byte[] buffer, int offset, int count) throws IOException {
            // Copy: the queued ByteBuffer outlives this call and the caller may reuse its array.
            byte[] bytes = new byte[count];
            System.arraycopy(buffer, offset, bytes, 0, count);
            NonBlockingSocketHandler.this.writeBytes(bytes);
        }
    };

    /** Returns an OutputStream whose writes are queued through {@link #write(ByteBuffer)}. */
    public OutputStream getOutputStream() {
        return mStream;
    }

    NonBlockingReadCallback mPendingRead;
    int mPendingReadLength;
    ByteBufferList mPendingData = new ByteBufferList();

    /**
     * Requests that {@code callback} be invoked once {@code count} bytes are buffered. Only one
     * read may be pending at a time (asserted). Pending data is examined when the socket next
     * becomes readable, or when this is called from within a read callback.
     */
    public void read(int count, NonBlockingReadCallback callback) {
        Assert.assertNull(mPendingRead);
        mPendingReadLength = count;
        mPendingRead = callback;
    }

    /**
     * Completes the pending read if enough data is buffered. The pending read is cleared before
     * the callback runs, so the callback may issue the next read.
     *
     * @return true if a callback was invoked
     */
    private boolean handlePendingData() {
        if (mPendingRead == null || mPendingReadLength > mPendingData.getAvailable())
            return false;
        ByteBuffer bb = mPendingData.read(mPendingReadLength);
        NonBlockingReadCallback pendingRead = mPendingRead;
        mPendingRead = null;
        pendingRead.onDataAvailable(bb);
        return true;
    }

    /**
     * Delivers buffered data until it is exhausted or no progress is made: to the pending read
     * when there is one, otherwise whole buffers to {@link #onDataAvailable}.
     */
    void handleAllPendingData() {
        int available;
        while ((available = mPendingData.getAvailable()) > 0) {
            if (mPendingRead != null) {
                if (!handlePendingData())
                    break;
                // if nothing was read, the socket may be waiting for more data?
                if (available == mPendingData.getAvailable())
                    break;
            }
            else {
                onDataAvailable(mPendingData.remove());
            }
        }
    }

    /**
     * Reads from the channel into the pending buffer list (once for chunked channels, otherwise
     * until a read returns 0 or less), delivers buffered data, and closes the handler when the
     * read reports end of stream or failure.
     *
     * @return total bytes read during this call
     */
    int onReadable() {
        int total = 0;
        Assert.assertTrue(mKey.isReadable());
        final boolean chunked = mChannel.isChunked();
        boolean closed = false;
        while (true) {
            int read = mPendingData.writeFromSocket(mChannel);
            if (read < 0) {
                close();
                closed = true;
            }
            else {
                total += read;
            }
            if (read <= 0)
                break;
            // handle udp in chunks
            if (chunked)
                break;
        }
        handleAllPendingData();
        if (closed)
            onClosed();
        return total;
    }

    /** Cancels the selection key (if one was assigned) and closes the channel. */
    public void close() {
        if (mKey != null)
            mKey.cancel();
        try {
            mChannel.close();
        }
        catch (IOException e) {
            // best effort: the channel is being discarded anyway
        }
    }

    /** Hook: receives buffered data when no read is pending; unhandled data is discarded. */
    public void onDataAvailable(ByteBuffer byteBuffer) {
        // if this isn't handled, the data gets tossed on the floor.
    }

    /** Hook: an outgoing connection attempt failed. */
    public void onConnectFailed() {
    }
}
package com.koushikdutta.nio;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.LinkedList;
import junit.framework.Assert;
/**
 * Declarative decoder for binary messages arriving on a NonBlockingSocketHandler.
 * Callers queue the fields they expect with the read* methods (each returns this for
 * chaining) and then call tap(callback). As bytes arrive, the fields are decoded in order;
 * once all are decoded, the values are passed as arguments to the callback's public "tap"
 * method, which is located by reflection (TapCallback.getTap()).
 * NOTE(review): multi-byte fields use the delivered ByteBuffer's byte order; nothing visible
 * here sets it, so presumably ByteBuffer's default big-endian order — confirm.
 */
public class PushParser {
// Fields still to decode, in order: a primitive class token (int/short/byte/long.class)
// or a BufferWaiter describing a byte[] field.
private LinkedList<Object> mWaiting = new LinkedList<Object>();
// Placeholder for a byte[] field. length == -1 means "use the int decoded just before this
// field as the length" (see readLenBuffer); it is overwritten with the real length once known.
static class BufferWaiter {
int length;
}
// Bytes still needed to decode every queued field whose size is currently known. A
// length-prefixed buffer contributes its length only once its prefix has been decoded.
int mNeeded = 0;
// Queues a 4-byte int field (passed to tap as an Integer).
public PushParser readInt() {
mNeeded += 4;
mWaiting.add(int.class);
return this;
}
// Queues a 1-byte field (passed to tap as a Byte).
public PushParser readByte() {
mNeeded += 1;
mWaiting.add(byte.class);
return this;
}
// Queues a 2-byte short field (passed to tap as a Short).
public PushParser readShort() {
mNeeded += 2;
mWaiting.add(short.class);
return this;
}
// Queues an 8-byte long field (passed to tap as a Long).
public PushParser readLong() {
mNeeded += 8;
mWaiting.add(long.class);
return this;
}
// Queues a byte[] field of `length` bytes; -1 defers the length (see readLenBuffer).
// A length of 0 is passed to tap as null, not as an empty array.
public PushParser readBuffer(int length) {
if (length != -1)
mNeeded += length;
BufferWaiter bw = new BufferWaiter();
bw.length = length;
mWaiting.add(bw);
return this;
}
// Queues an int length prefix followed by a byte[] of that many bytes. The prefix itself is
// consumed and not passed to tap; only the byte[] (null when the length is 0) is.
public PushParser readLenBuffer() {
readInt();
BufferWaiter bw = new BufferWaiter();
bw.length = -1;
mWaiting.add(bw);
return this;
}
// Handler that bytes are requested from (via read(count, callback)).
NonBlockingSocketHandler mSocket;
public PushParser(NonBlockingSocketHandler s) {
mSocket = s;
}
// Values decoded so far for the current tap() call.
private ArrayList<Object> mArgs = new ArrayList<Object>();
// Callback of the tap() in progress; null when idle.
private TapCallback mCallback;
// Returns a new Exception capturing the current call stack (debugging aid).
Exception stack() {
try {
throw new Exception();
}
catch (Exception e) {
return e;
}
}
// Starts decoding the queued fields and eventually calls callback's "tap" method with them.
// Only one tap() may be outstanding (asserted), and at least one field must be queued.
// The anonymous NonBlockingReadCallback below drives decoding. Its instance initializer
// calls onDataAvailable(null) straight away; decoding against the null buffer throws, which
// lands in the catch that asks mSocket for mNeeded bytes, with this same object as the
// callback. Each later delivery decodes as many fields as the buffer holds. Running short
// (BufferUnderflowException from get*(), or the explicit throw for byte[] fields) again
// re-requests the remaining mNeeded bytes.
public void tap(TapCallback callback) {
Assert.assertNull(mCallback);
Assert.assertTrue(mWaiting.size() > 0);
// debugging aid: only referenced by commented-out code below
final Exception e = stack();
mCallback = callback;
new NonBlockingReadCallback() {
{
// kick off: fails immediately and requests the first mNeeded bytes (see tap() comment)
onDataAvailable(null);
}
@Override
public void onDataAvailable(ByteBuffer bb) {
try {
while (mWaiting.size() > 0) {
Object waiting = mWaiting.peek();
if (waiting == null)
break;
// System.out.println("Remaining: " + bb.remaining());
if (waiting == int.class) {
mArgs.add(bb.getInt());
mNeeded -= 4;
}
else if (waiting == short.class){
mArgs.add(bb.getShort());
mNeeded -= 2;
}
else if (waiting == byte.class){
mArgs.add(bb.get());
mNeeded -= 1;
}
else if (waiting == long.class){
mArgs.add(bb.getLong());
mNeeded -= 8;
}
else {
BufferWaiter bw = (BufferWaiter)waiting;
int length = bw.length;
if (length == -1) {
// Deferred length: take it from the int decoded just before (the readLenBuffer
// prefix), drop that int from the args, and remember the length so a retry
// after a short buffer does not resolve it again.
length = (Integer)mArgs.get(mArgs.size() - 1);
mArgs.remove(mArgs.size() - 1);
bw.length = length;
mNeeded += length;
}
if (bb.remaining() < length) {
// Not enough bytes yet: jump to the catch, which requests mNeeded more.
// System.out.print("imminient feilure detected");
throw new Exception();
}
// e.printStackTrace();
// System.out.println("Buffer length: " + length);
byte[] bytes = null;
if (length > 0) {
bytes = new byte[length];
bb.get(bytes);
}
mNeeded -= length;
mArgs.add(bytes);
}
// System.out.println("Parsed: " + mArgs.get(0));
// The field is fully decoded; advance to the next one.
mWaiting.remove();
}
}
catch (Exception ex) {
// Short (or null) buffer: wait for the remaining bytes, then re-enter this method.
// NOTE(review): this catches every Exception, so a genuine decoding error is treated
// the same as "need more data". If mNeeded is 0 here (e.g. readBuffer(0) is the first
// queued field and bb is the initial null), the assertion below fails.
Assert.assertTrue(mNeeded != 0);
// ex.printStackTrace();
mSocket.read(mNeeded, this);
return;
}
// All fields decoded. mCallback is cleared before invoking so the callback may start
// another tap(). Reflection failures are only printed, including a missing tap method,
// for which getTap() returns null.
try {
Object[] args = mArgs.toArray();
mArgs.clear();
TapCallback callback = mCallback;
mCallback = null;
Method method = callback.getTap();
method.invoke(callback, args);
}
catch (Exception ex) {
ex.printStackTrace();
}
}
};
}
}
package com.koushikdutta.nio;
import java.io.IOException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
/**
 * {@link ChannelWrapper} over a TCP {@link SocketChannel}. Every operation forwards directly to
 * the wrapped channel, including {@link #register} with the caller's requested ops.
 */
public class SocketChannelWrapper implements ChannelWrapper {
    SocketChannel mChannel;

    /** Switches {@code channel} to non-blocking mode and wraps it. */
    public static SocketChannelWrapper from(SocketChannel channel) throws IOException {
        channel.configureBlocking(false);
        SocketChannelWrapper wrapper = new SocketChannelWrapper(channel);
        return wrapper;
    }

    private SocketChannelWrapper(SocketChannel channel) {
        this.mChannel = channel;
    }

    // ---- connection lifecycle -------------------------------------------------------------

    @Override
    public void connect(SocketAddress remote) throws IOException {
        mChannel.connect(remote);
    }

    @Override
    public boolean isConnected() {
        return mChannel.isConnected();
    }

    @Override
    public boolean isOpen() {
        return mChannel.isOpen();
    }

    @Override
    public void close() throws IOException {
        mChannel.close();
    }

    // ---- data transfer --------------------------------------------------------------------

    @Override
    public int read(ByteBuffer buffer) throws IOException {
        return mChannel.read(buffer);
    }

    @Override
    public int write(ByteBuffer src) throws IOException {
        return mChannel.write(src);
    }

    // ---- selector integration -------------------------------------------------------------

    @Override
    public SelectionKey register(Selector sel, int ops) throws ClosedChannelException {
        return mChannel.register(sel, ops);
    }

    /** Stream sockets are not read one message at a time, so this is always false. */
    @Override
    public boolean isChunked() {
        return false;
    }
}
package com.koushikdutta.nio;
import java.lang.reflect.Method;
import java.util.Hashtable;
/**
 * Base class for PushParser callbacks. Subclasses declare a public method named {@code tap}
 * whose parameters match the decoded fields. The method is located reflectively and cached per
 * concrete class.
 */
public class TapCallback {
    // Resolved "tap" method per concrete subclass. Hashtable is synchronized, so concurrent
    // lookups are safe; a racing duplicate lookup just stores an equivalent Method.
    static Hashtable<Class, Method> mTable = new Hashtable<Class, Method>();

    /**
     * Returns this object's public {@code tap} method, or null if there is none. When it is
     * overloaded, whichever one {@code getMethods()} lists first is used (that order is
     * unspecified). The method is made accessible so that it can be invoked even when it is
     * declared in a non-public class, such as an anonymous subclass in another package. Without
     * that, {@code Method.invoke} fails its access check.
     */
    Method getTap() {
        Class<?> clazz = getClass();
        Method cached = mTable.get(clazz);
        if (cached != null)
            return cached;
        for (Method candidate : clazz.getMethods()) {
            if (!"tap".equals(candidate.getName()))
                continue;
            makeAccessible(candidate);
            mTable.put(clazz, candidate);
            return candidate;
        }
        return null;
    }

    /** Best-effort {@code setAccessible(true)}; a refusal is ignored and invoke() may still work. */
    private static void makeAccessible(Method method) {
        try {
            method.setAccessible(true);
        }
        catch (RuntimeException ignored) {
            // e.g. SecurityException under a restrictive security policy: fall back to normal access
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment