Created
February 5, 2018 03:47
-
-
Save swankjesse/d7427c1ff2891ff00c4661bedcb1f424 to your computer and use it in GitHub Desktop.
Hooking up Cipher to Okio’s new UnsafeCursor.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Copyright 2018 Square Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package okio; | |
import java.io.IOException; | |
import java.security.GeneralSecurityException; | |
import javax.crypto.Cipher; | |
public final class OkioCipher { | |
static final byte[] EMPTY_ARRAY = new byte[0]; | |
final Cipher cipher; | |
final Buffer.UnsafeCursor sourceCursor = new Buffer.UnsafeCursor(); | |
final Buffer.UnsafeCursor sinkCursor = new Buffer.UnsafeCursor(); | |
OkioCipher(Cipher cipher) { | |
this.cipher = cipher; | |
} | |
/** | |
* Encrypts or decrypts {@code byteCount} bytes from {@code source} to {@code sink}, ending with | |
* {@link Cipher#doFinal} or its overloads. Returns the total number of bytes emitted to {@code | |
* sink}. | |
*/ | |
public int cipherUntilDoFinal(BufferedSource source, BufferedSink sink, long byteCount) | |
throws IOException, GeneralSecurityException { | |
int result = 0; | |
boolean callDoFinal = false; | |
while (byteCount > 0) { | |
// Encrypt up to byteCount bytes from the source. | |
Buffer sourceBuffer = source.buffer(); | |
sourceBuffer.require(1L); | |
try (Buffer.UnsafeCursor ic = sourceBuffer.readUnsafe(sourceCursor)) { | |
ic.seek(0L); | |
int segmentByteCount = ic.end - ic.start; | |
int inputByteCount = (int) Math.min(byteCount, segmentByteCount); | |
callDoFinal = (inputByteCount == byteCount); | |
int outputByteCount = updateOrDoFinal(ic, sink, inputByteCount, callDoFinal); | |
source.skip(inputByteCount); | |
result += outputByteCount; | |
byteCount -= inputByteCount; | |
} | |
} | |
if (!callDoFinal) { | |
result += updateOrDoFinal(null, sink, 0, true); | |
} | |
return result; | |
} | |
/** | |
* Encrypts/decrypts {@code byteCount} bytes from {@code source} to {@code sink}. Returns the | |
* number of bytes written to {@code sink}. | |
*/ | |
private int updateOrDoFinal(Buffer.UnsafeCursor source, BufferedSink sink, | |
int byteCount, boolean doFinal) throws GeneralSecurityException, IOException { | |
int outputByteCount = cipher.getOutputSize(byteCount); | |
if (outputByteCount == 0) { | |
// When we're writing 0 bytes emit into an empty array. | |
int resultByteCount = updateOrDoFinalNoOutput(source, byteCount, doFinal); | |
if (resultByteCount != 0) throw new AssertionError(); | |
return 0; | |
} else if (outputByteCount <= Segment.SIZE) { | |
// When what we're writing fits in a segment write directly do the sink cursor. | |
Buffer sinkBuffer = sink.buffer(); | |
long sinkBufferSize = sinkBuffer.size(); | |
try (Buffer.UnsafeCursor oc = sinkBuffer.readAndWriteUnsafe(sinkCursor)) { | |
oc.expandBuffer(outputByteCount); | |
oc.seek(sinkBufferSize); | |
int resultByteCount = updateOrDoFinalToCursor(source, byteCount, doFinal, oc); | |
oc.resizeBuffer(sinkBufferSize + resultByteCount); | |
sink.emitCompleteSegments(); | |
return resultByteCount; | |
} | |
} else { | |
// Write to a temporary array and then write that to the sink. | |
byte[] bytes = updateOrDoFinalToByteArray(source, byteCount, doFinal); | |
sink.write(bytes); | |
return bytes.length; | |
} | |
} | |
/** Crypts {@code byteCount} bytes of {@code source} and produces no output. */ | |
private int updateOrDoFinalNoOutput(Buffer.UnsafeCursor source, int byteCount, boolean doFinal) | |
throws GeneralSecurityException { | |
if (doFinal) { | |
return byteCount > 0 | |
? cipher.doFinal(source.data, source.start, byteCount, EMPTY_ARRAY, 0) | |
: cipher.doFinal(EMPTY_ARRAY, 0); | |
} else { | |
return cipher.update(source.data, source.start, byteCount, EMPTY_ARRAY, 0); | |
} | |
} | |
/** Crypts {@code byteCount} bytes of {@code source} and writes to {@code sink}. */ | |
private int updateOrDoFinalToCursor(Buffer.UnsafeCursor source, int byteCount, boolean doFinal, | |
Buffer.UnsafeCursor sink) throws GeneralSecurityException { | |
if (doFinal) { | |
return byteCount > 0 | |
? cipher.doFinal(source.data, source.start, byteCount, sink.data, sink.start) | |
: cipher.doFinal(sink.data, sink.start); | |
} else { | |
return cipher.update(source.data, source.start, byteCount, sink.data, sink.start); | |
} | |
} | |
/** Crypts {@code byteCount} bytes of {@code source} and returns them. */ | |
private byte[] updateOrDoFinalToByteArray(Buffer.UnsafeCursor source, int byteCount, | |
boolean doFinal) throws GeneralSecurityException { | |
if (doFinal) { | |
return byteCount > 0 | |
? cipher.doFinal(source.data, source.start, byteCount) | |
: cipher.doFinal(); | |
} else { | |
return cipher.update(source.data, source.start, byteCount); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment