Skip to content

Instantly share code, notes, and snippets.

@swankjesse
Created February 5, 2018 03:47
Show Gist options
  • Save swankjesse/d7427c1ff2891ff00c4661bedcb1f424 to your computer and use it in GitHub Desktop.
Save swankjesse/d7427c1ff2891ff00c4661bedcb1f424 to your computer and use it in GitHub Desktop.
Hooking up Cipher to Okio’s new UnsafeCursor.
/*
* Copyright 2018 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package okio;
import java.io.IOException;
import java.security.GeneralSecurityException;
import javax.crypto.Cipher;
public final class OkioCipher {
  static final byte[] EMPTY_ARRAY = new byte[0];

  final Cipher cipher;
  // Reused across calls to avoid allocation; this makes instances unsafe to share between threads.
  final Buffer.UnsafeCursor sourceCursor = new Buffer.UnsafeCursor();
  final Buffer.UnsafeCursor sinkCursor = new Buffer.UnsafeCursor();

  /**
   * Wraps {@code cipher}. The cipher is used as-is, so it presumably must already be initialized
   * for encryption or decryption by the caller.
   */
  OkioCipher(Cipher cipher) {
    if (cipher == null) throw new NullPointerException("cipher == null");
    this.cipher = cipher;
  }

  /**
   * Encrypts or decrypts {@code byteCount} bytes from {@code source} to {@code sink}, ending with
   * {@link Cipher#doFinal} or one of its overloads. Input is consumed one buffered segment at a
   * time, reading more from {@code source}'s upstream as needed.
   *
   * @param byteCount the number of input bytes to consume. Zero finalizes the cipher without
   *     consuming any input.
   * @return the total number of bytes emitted to {@code sink}. This is accumulated in an
   *     {@code int}, so it is only meaningful when the total output is below 2 GiB.
   * @throws IllegalArgumentException if {@code byteCount} is negative.
   * @throws java.io.EOFException if {@code source} is exhausted before {@code byteCount} bytes have
   *     been read. In that case doFinal has not been called.
   */
  public int cipherUntilDoFinal(BufferedSource source, BufferedSink sink, long byteCount)
      throws IOException, GeneralSecurityException {
    if (byteCount < 0L) throw new IllegalArgumentException("byteCount < 0: " + byteCount);

    int result = 0;
    boolean callDoFinal = false;
    while (byteCount > 0L) {
      // Call require() on the BufferedSource, not its Buffer. Only the source can fetch more data
      // from upstream; Buffer.require() fails as soon as the buffer is empty.
      source.require(1L);
      Buffer sourceBuffer = source.buffer();

      int inputByteCount;
      try (Buffer.UnsafeCursor ic = sourceBuffer.readUnsafe(sourceCursor)) {
        ic.seek(0L);
        int segmentByteCount = ic.end - ic.start;
        inputByteCount = (int) Math.min(byteCount, segmentByteCount);
        // The chunk that completes byteCount is the one that finalizes the cipher.
        callDoFinal = (inputByteCount == byteCount);
        result += updateOrDoFinal(ic, sink, inputByteCount, callDoFinal);
      }

      // Consume the input only after the cursor is closed, because skip() may recycle the segment
      // the cursor was pointing at.
      source.skip(inputByteCount);
      byteCount -= inputByteCount;
    }

    if (!callDoFinal) {
      // Nothing to read (byteCount was 0), so finalize with no input.
      result += updateOrDoFinal(null, sink, 0, true);
    }
    return result;
  }

  /**
   * Encrypts or decrypts {@code byteCount} bytes from {@code source} to {@code sink}. Returns the
   * number of bytes written to {@code sink}.
   *
   * @param source a cursor whose {@code data}/{@code start} locate the input bytes; may be null
   *     only when {@code byteCount} is 0 and {@code doFinal} is true.
   * @param doFinal true to call {@link Cipher#doFinal}, false to call {@link Cipher#update}.
   */
  private int updateOrDoFinal(Buffer.UnsafeCursor source, BufferedSink sink,
      int byteCount, boolean doFinal) throws GeneralSecurityException, IOException {
    // An upper bound on what the next update()/doFinal() call can produce.
    int outputByteCount = cipher.getOutputSize(byteCount);

    if (outputByteCount == 0) {
      // No output is possible, so emit into an empty array.
      int resultByteCount = updateOrDoFinalNoOutput(source, byteCount, doFinal);
      if (resultByteCount != 0) throw new AssertionError();
      return 0;
    }

    if (outputByteCount > Segment.SIZE) {
      // Too large for a single segment: write to a temporary array, then write that to the sink.
      byte[] bytes = updateOrDoFinalToByteArray(source, byteCount, doFinal);
      sink.write(bytes);
      return bytes.length;
    }

    // The output fits in one segment, so write directly into the sink's buffer via a cursor.
    Buffer sinkBuffer = sink.buffer();
    long sinkBufferSize = sinkBuffer.size();
    int resultByteCount = 0;
    try (Buffer.UnsafeCursor oc = sinkBuffer.readAndWriteUnsafe(sinkCursor)) {
      oc.expandBuffer(outputByteCount);
      oc.seek(sinkBufferSize);
      try {
        resultByteCount = updateOrDoFinalToCursor(source, byteCount, doFinal, oc);
      } finally {
        // Trim the capacity the cipher didn't fill. If the cipher threw, resultByteCount is still
        // 0, so this removes every uninitialized byte expandBuffer() added.
        oc.resizeBuffer(sinkBufferSize + resultByteCount);
      }
    }
    // Emit only after the cursor is closed; emitting moves segments out of the sink's buffer.
    sink.emitCompleteSegments();
    return resultByteCount;
  }

  /**
   * Crypts {@code byteCount} bytes of {@code source} into an empty output array. Callers use this
   * only when {@link Cipher#getOutputSize} reports zero output; returns the cipher's output count.
   */
  private int updateOrDoFinalNoOutput(Buffer.UnsafeCursor source, int byteCount, boolean doFinal)
      throws GeneralSecurityException {
    if (!doFinal) {
      return cipher.update(source.data, source.start, byteCount, EMPTY_ARRAY, 0);
    }
    if (byteCount > 0) {
      return cipher.doFinal(source.data, source.start, byteCount, EMPTY_ARRAY, 0);
    }
    return cipher.doFinal(EMPTY_ARRAY, 0);
  }

  /**
   * Crypts {@code byteCount} bytes of {@code source} into {@code sink}'s data array starting at
   * {@code sink.start}. Returns the number of bytes written.
   */
  private int updateOrDoFinalToCursor(Buffer.UnsafeCursor source, int byteCount, boolean doFinal,
      Buffer.UnsafeCursor sink) throws GeneralSecurityException {
    if (!doFinal) {
      return cipher.update(source.data, source.start, byteCount, sink.data, sink.start);
    }
    if (byteCount > 0) {
      return cipher.doFinal(source.data, source.start, byteCount, sink.data, sink.start);
    }
    return cipher.doFinal(sink.data, sink.start);
  }

  /** Crypts {@code byteCount} bytes of {@code source} and returns them; never returns null. */
  private byte[] updateOrDoFinalToByteArray(Buffer.UnsafeCursor source, int byteCount,
      boolean doFinal) throws GeneralSecurityException {
    if (doFinal) {
      return byteCount > 0
          ? cipher.doFinal(source.data, source.start, byteCount)
          : cipher.doFinal();
    }
    // Cipher.update() may return null when it produces no output (for example, when the input is
    // too short to complete a block). Normalize that to an empty array.
    byte[] result = cipher.update(source.data, source.start, byteCount);
    return result != null ? result : EMPTY_ARRAY;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment