Created
March 15, 2024 13:41
-
-
Save igrishaev/caf60cca70a507c43766f14b30e0387a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.pg; | |
import clojure.lang.Agent; | |
import clojure.lang.IFn; | |
import clojure.lang.IPersistentMap; | |
import clojure.lang.PersistentHashMap; | |
import org.pg.auth.MD5; | |
import org.pg.auth.ScramSha256; | |
import org.pg.clojure.LazyMap; | |
import org.pg.codec.EncoderBin; | |
import org.pg.codec.CodecParams; | |
import org.pg.codec.EncoderTxt; | |
import org.pg.copy.Copy; | |
import org.pg.enums.*; | |
import org.pg.error.PGError; | |
import org.pg.msg.*; | |
import org.pg.type.OIDHint; | |
import org.pg.util.*; | |
import javax.net.ssl.SSLContext; | |
import javax.net.ssl.SSLSocket; | |
import java.io.*; | |
import java.net.InetSocketAddress; | |
import java.nio.channels.CompletionHandler; | |
import java.nio.charset.Charset; | |
import java.nio.charset.StandardCharsets; | |
import java.security.NoSuchAlgorithmException; | |
import java.time.ZoneId; | |
import java.util.*; | |
import java.net.Socket; | |
import java.nio.ByteBuffer; | |
import java.nio.channels.AsynchronousSocketChannel; | |
import java.util.concurrent.CompletableFuture; | |
import java.util.concurrent.CompletionStage; | |
import java.util.concurrent.ExecutionException; | |
import java.util.function.Function; | |
public final class Connection implements AutoCloseable { | |
private static final boolean isDebug = | |
System.getenv() | |
.getOrDefault("PG_DEBUG", "") | |
.equals("1"); | |
private final ConnConfig config; | |
private final UUID id; | |
private final long createdAt; | |
private int counter = 0; | |
private int pid; | |
private int secretKey; | |
private TXStatus txStatus; | |
private Socket socket; | |
private InputStream inStream; | |
private OutputStream outStream; | |
private final Map<String, String> params; | |
private final CodecParams codecParams; | |
private boolean isSSL = false; | |
private final System.Logger logger = System.getLogger(Connection.class.getCanonicalName()); | |
private final TryLock lock = new TryLock(); | |
private boolean isClosed = false; | |
private AsynchronousSocketChannel channel; | |
/**
 * Convenience constructor: builds a {@link ConnConfig} from the individual
 * settings and delegates to {@link #Connection(ConnConfig)}.
 */
public Connection(final String host, final int port, final String user,
                  final String password, final String database) {
    this(ConnConfig.builder(user, database)
            .host(host)
            .port(port)
            .password(password)
            .build());
}
/**
 * Creates a connection object and opens the asynchronous channel to the server.
 *
 * @param config      connection settings
 * @param sendStartup meant to control whether the startup/authentication exchange
 *                    runs. NOTE(review): currently ignored — the synchronous setup
 *                    steps are disabled while the async transport is developed.
 */
public Connection(final ConnConfig config, final boolean sendStartup) {
    this.config = config;
    this.params = new HashMap<>();
    this.codecParams = CodecParams.standard();
    this.id = UUID.randomUUID();
    this.createdAt = System.currentTimeMillis();
    // TODO: re-enable setSocketOptions(), preSSLStage() and (when sendStartup
    // is true) authenticate() once they work over the asynchronous channel.
    connect();
}
/** Creates a connection with {@code sendStartup = true}. */
public Connection(final ConnConfig config) {
    this(config, true);
}
public void close () { | |
try (TryLock ignored = lock.get()) { | |
if (!isClosed) { | |
sendTerminate(); | |
flush(); | |
IOTool.close(socket); | |
isClosed = true; | |
} | |
} | |
} | |
private void setSocketOptions () { | |
try { | |
socket.setTcpNoDelay(config.SOTCPnoDelay()); | |
socket.setSoTimeout(config.SOTimeout()); | |
socket.setKeepAlive(config.SOKeepAlive()); | |
socket.setReceiveBufferSize(config.SOReceiveBufSize()); | |
socket.setSendBufferSize(config.SOSendBufSize()); | |
} | |
catch (IOException e) { | |
throw new PGError(e, "couldn't set socket options"); | |
} | |
} | |
private int nextInt() { | |
try (TryLock ignored = lock.get()) { | |
return ++counter; | |
} | |
} | |
public int getPid () { | |
try (TryLock ignored = lock.get()) { | |
return pid; | |
} | |
} | |
public UUID getId() { | |
return id; | |
} | |
@SuppressWarnings("unused") | |
public long getCreatedAt() { | |
return createdAt; | |
} | |
public Boolean isClosed () { | |
try (TryLock ignored = lock.get()) { | |
return isClosed; | |
} | |
} | |
@SuppressWarnings("unused") | |
public TXStatus getTxStatus () { | |
try (TryLock ignored = lock.get()) { | |
return txStatus; | |
} | |
} | |
@SuppressWarnings("unused") | |
public boolean isSSL () { | |
try (TryLock ignored = lock.get()) { | |
return isSSL; | |
} | |
} | |
@SuppressWarnings("unused") | |
public String getParam (final String param) { | |
try (TryLock ignored = lock.get()) { | |
return params.get(param); | |
} | |
} | |
@SuppressWarnings("unused") | |
public IPersistentMap getParams () { | |
try (TryLock ignored = lock.get()) { | |
return PersistentHashMap.create(params); | |
} | |
} | |
private void setParam (final String param, final String value) { | |
params.put(param, value); | |
switch (param) { | |
case "client_encoding" -> | |
codecParams.clientCharset = Charset.forName(value); | |
case "server_encoding" -> | |
codecParams.serverCharset = Charset.forName(value); | |
case "DateStyle" -> | |
codecParams.dateStyle = value; | |
case "TimeZone" -> | |
codecParams.timeZone = ZoneId.of(value); | |
case "integer_datetimes" -> | |
codecParams.integerDatetime = value.equals("on"); | |
} | |
} | |
public ConnConfig getConfig () { | |
return config; | |
} | |
public Integer getPort () { | |
return config.port(); | |
} | |
public String getHost () { | |
return config.host(); | |
} | |
public String getUser () { | |
return config.user(); | |
} | |
@SuppressWarnings("unused") | |
private Map<String, String> getPgParams () { | |
return config.pgParams(); | |
} | |
public String getDatabase () { | |
return config.database(); | |
} | |
public String toString () { | |
return String.format("<PG connection %s@%s:%s/%s>", | |
getUser(), | |
getHost(), | |
getPort(), | |
getDatabase()); | |
} | |
public void authenticate () { | |
sendStartupMessage(); | |
interact(Phase.AUTH); | |
} | |
private boolean readSSLResponse () { | |
final char c = (char) IOTool.read(inStream); | |
return switch (c) { | |
case 'N' -> false; | |
case 'S' -> true; | |
default -> throw new PGError("wrong SSL response: %s", c); | |
}; | |
} | |
private static final String[] SSLProtocols = new String[] { | |
"TLSv1.2", | |
"TLSv1.1", | |
"TLSv1" | |
}; | |
private SSLContext getSSLContext () throws NoSuchAlgorithmException { | |
final SSLContext configContext = config.sslContext(); | |
if (configContext == null) { | |
return SSLContext.getDefault(); | |
} | |
else { | |
return configContext; | |
} | |
} | |
private void upgradeToSSL () throws NoSuchAlgorithmException, IOException { | |
final SSLContext sslContext = getSSLContext(); | |
final SSLSocket sslSocket = (SSLSocket) sslContext.getSocketFactory().createSocket( | |
socket, | |
config.host(), | |
config.port(), | |
true | |
); | |
final InputStream sslInStream = new BufferedInputStream( | |
IOTool.getInputStream(sslSocket), | |
config.inStreamBufSize() | |
); | |
final OutputStream sslOutStream = new BufferedOutputStream( | |
IOTool.getOutputStream(sslSocket), | |
config.outStreamBufSize() | |
); | |
sslSocket.setUseClientMode(true); | |
sslSocket.setEnabledProtocols(SSLProtocols); | |
sslSocket.startHandshake(); | |
socket = sslSocket; | |
inStream = sslInStream; | |
outStream = sslOutStream; | |
isSSL = true; | |
} | |
private void preSSLStage () { | |
if (config.useSSL()) { | |
final SSLRequest msg = new SSLRequest(Const.SSL_CODE); | |
sendMessage(msg); | |
flush(); | |
final boolean ssl = readSSLResponse(); | |
if (ssl) { | |
try { | |
upgradeToSSL(); | |
} | |
catch (Throwable e) { | |
close(); | |
throw new PGError( | |
e, | |
"could not upgrade to SSL due to an exception: %s", | |
e.getMessage() | |
); | |
} | |
} | |
else { | |
close(); | |
throw new PGError("the server is configured to not use SSL"); | |
} | |
} | |
} | |
private void connect () { | |
try (TryLock ignored = lock.get()) { | |
_connect_unlocked(); | |
} | |
} | |
private void _connect_unlocked () { | |
final int port = getPort(); | |
final String host = getHost(); | |
try { | |
channel = AsynchronousSocketChannel.open(); | |
channel.connect(new InetSocketAddress(host, port)).get(); | |
} | |
catch (IOException e) { | |
throw new Error("aaa"); | |
} catch (ExecutionException | InterruptedException e) { | |
throw new RuntimeException(e); | |
} | |
// socket = IOTool.socket(host, port); | |
// inStream = new BufferedInputStream( | |
// IOTool.getInputStream(socket), | |
// config.inStreamBufSize() | |
// ); | |
// outStream = new BufferedOutputStream( | |
// IOTool.getOutputStream(socket), | |
// config.outStreamBufSize() | |
// ); | |
} | |
// Send bytes into the output stream. Do not flush the buffer, | |
// must be done manually. | |
private void sendBytes (final byte[] buf) { | |
if (isDebug) { | |
logger.log(config.logLevel()," <- {0}", Arrays.toString(buf)); | |
} | |
IOTool.write(outStream, buf); | |
} | |
// Like sendBytes above but taking boundaries into account. | |
private void sendBytes (final byte[] buf, final int offset, final int len) { | |
IOTool.write(outStream, buf, offset, len); | |
} | |
private void sendBytesCopy(final byte[] bytes) { | |
final ByteBuffer bb = ByteBuffer.allocate(5); | |
bb.put((byte)'d'); | |
bb.putInt(4 + bytes.length); | |
sendBytes(bb.array()); | |
sendBytes(bytes); | |
} | |
public CompletableFuture<Object> sendMessageAsync (final IMessage msg) { | |
if (isDebug) { | |
logger.log(config.logLevel(), " <- {0}", msg); | |
} | |
final ByteBuffer buf = msg.encode(codecParams.clientCharset); | |
return sendByteBufferAsync(buf); | |
} | |
private void sendMessage (final IMessage msg) { | |
if (isDebug) { | |
logger.log(config.logLevel(), " <- {0}", msg); | |
} | |
final ByteBuffer buf = msg.encode(codecParams.clientCharset); | |
IOTool.write(outStream, buf.array()); | |
} | |
private String generateStatement () { | |
return String.format("s%d", nextInt()); | |
} | |
private String generatePortal () { | |
return String.format("p%d", nextInt()); | |
} | |
public CompletableFuture<Object> sendStartupMessageAsync () { | |
final StartupMessage msg = | |
new StartupMessage( | |
config.protocolVersion(), | |
config.user(), | |
config.database(), | |
config.pgParams() | |
); | |
return sendMessageAsync(msg); | |
} | |
private void sendStartupMessage () { | |
final StartupMessage msg = | |
new StartupMessage( | |
config.protocolVersion(), | |
config.user(), | |
config.database(), | |
config.pgParams() | |
); | |
sendMessage(msg); | |
} | |
private void sendCopyData (final byte[] buf) { | |
sendMessage(new CopyData(ByteBuffer.wrap(buf))); | |
} | |
private void sendCopyDone () { | |
sendMessage(CopyDone.INSTANCE); | |
} | |
private void sendCopyFail (final String errorMessage) { | |
sendMessage(new CopyFail(errorMessage)); | |
} | |
private void sendQuery (final String query) { | |
sendMessage(new Query(query)); | |
} | |
private void sendPassword (final String password) { | |
sendMessage(new PasswordMessage(password)); | |
} | |
private void sendSync () { | |
sendMessage(Sync.INSTANCE); | |
} | |
private void sendFlush () { | |
sendMessage(Flush.INSTANCE); | |
} | |
private void sendTerminate () { | |
sendMessage(Terminate.INSTANCE); | |
} | |
@SuppressWarnings("unused") | |
private void sendSSLRequest () { | |
sendMessage(new SSLRequest(Const.SSL_CODE)); | |
} | |
private record AsyncAttachment (ByteBuffer bb, AsynchronousSocketChannel channel) {} | |
private CompletableFuture<ByteBuffer> readByteBufferAsync (final int size) { | |
final CompletableFuture<ByteBuffer> fut = new CompletableFuture<>(); | |
final ByteBuffer bb = ByteBuffer.allocate(size); | |
final AsyncAttachment attachment = new AsyncAttachment(bb, channel); | |
final CompletionHandler<Integer, AsyncAttachment> handler = new CompletionHandler<>() { | |
@Override | |
public void completed(Integer result, AsyncAttachment attachment) { | |
if (attachment.bb.remaining() == 0) { | |
fut.complete(attachment.bb.rewind()); | |
} | |
else { | |
attachment.channel.read(attachment.bb, attachment, this); | |
} | |
} | |
@Override | |
public void failed(Throwable exc, AsyncAttachment attachment) { | |
fut.completeExceptionally(exc); | |
} | |
}; | |
channel.read(bb, attachment, handler); | |
return fut; | |
} | |
private CompletableFuture<Object> sendByteBufferAsync (ByteBuffer bb) { | |
bb.rewind(); | |
final CompletableFuture<Object> fut = new CompletableFuture<>(); | |
final AsyncAttachment attachment = new AsyncAttachment(bb, channel); | |
final CompletionHandler<Integer, AsyncAttachment> handler = new CompletionHandler<>() { | |
@Override | |
public void completed(Integer result, AsyncAttachment attachment) { | |
if (attachment.bb.remaining() == 0) { | |
fut.complete(true); | |
} | |
else { | |
attachment.channel.write(attachment.bb, attachment, this); | |
} | |
} | |
@Override | |
public void failed(Throwable exc, AsyncAttachment attachment) { | |
fut.completeExceptionally(exc); | |
} | |
}; | |
channel.write(bb, attachment, handler); | |
return fut; | |
} | |
public CompletableFuture<Object> readMessageAsync() { | |
return readByteBufferAsync(5).thenComposeAsync((ByteBuffer bbHeader) -> { | |
final char tag = (char) bbHeader.get(); | |
final int bodySize = bbHeader.getInt() - 4; | |
return readByteBufferAsync(bodySize).thenApplyAsync((ByteBuffer bbBody) -> switch (tag) { | |
case 'R' -> AuthenticationResponse.fromByteBuffer(bbBody).parseResponse(bbBody, codecParams.serverCharset); | |
case 'S' -> ParameterStatus.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 'Z' -> ReadyForQuery.fromByteBuffer(bbBody); | |
case 'C' -> CommandComplete.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 'T' -> RowDescription.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 'D' -> DataRow.fromByteBuffer(bbBody); | |
case 'E' -> ErrorResponse.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 'K' -> BackendKeyData.fromByteBuffer(bbBody); | |
case '1' -> ParseComplete.INSTANCE; | |
case '2' -> BindComplete.INSTANCE; | |
case '3' -> CloseComplete.INSTANCE; | |
case 't' -> ParameterDescription.fromByteBuffer(bbBody); | |
case 'H' -> CopyOutResponse.fromByteBuffer(bbBody); | |
case 'd' -> CopyData.fromByteBuffer(bbBody); | |
case 'c' -> CopyDone.INSTANCE; | |
case 'I' -> EmptyQueryResponse.INSTANCE; | |
case 'n' -> NoData.INSTANCE; | |
case 'v' -> NegotiateProtocolVersion.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 'A' -> NotificationResponse.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 'N' -> NoticeResponse.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 's' -> PortalSuspended.INSTANCE; | |
case 'G' -> CopyInResponse.fromByteBuffer(bbBody); | |
default -> throw new PGError("Unknown message: %s", tag); | |
}); | |
}); | |
} | |
private Object readMessage (final boolean skipMode) { | |
final byte[] bufHeader = IOTool.readNBytes(inStream, 5); | |
final ByteBuffer bbHeader = ByteBuffer.wrap(bufHeader); | |
final char tag = (char) bbHeader.get(); | |
final int bodySize = bbHeader.getInt() - 4; | |
// skipMode means there has been an exception before. There is no need | |
// to parse data-heavy messages as we're going to throw an exception | |
// at the end anyway. If there is a DataRow or a CopyData message, | |
// just skip it. | |
if (skipMode) { | |
if (tag == 'D' || tag == 'd') { | |
IOTool.skip(inStream, bodySize); | |
return SkippedMessage.INSTANCE; | |
} | |
} | |
byte[] bufBody = IOTool.readNBytes(inStream, bodySize); | |
ByteBuffer bbBody = ByteBuffer.wrap(bufBody); | |
return switch (tag) { | |
case 'R' -> AuthenticationResponse.fromByteBuffer(bbBody).parseResponse(bbBody, codecParams.serverCharset); | |
case 'S' -> ParameterStatus.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 'Z' -> ReadyForQuery.fromByteBuffer(bbBody); | |
case 'C' -> CommandComplete.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 'T' -> RowDescription.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 'D' -> DataRow.fromByteBuffer(bbBody); | |
case 'E' -> ErrorResponse.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 'K' -> BackendKeyData.fromByteBuffer(bbBody); | |
case '1' -> ParseComplete.INSTANCE; | |
case '2' -> BindComplete.INSTANCE; | |
case '3' -> CloseComplete.INSTANCE; | |
case 't' -> ParameterDescription.fromByteBuffer(bbBody); | |
case 'H' -> CopyOutResponse.fromByteBuffer(bbBody); | |
case 'd' -> CopyData.fromByteBuffer(bbBody); | |
case 'c' -> CopyDone.INSTANCE; | |
case 'I' -> EmptyQueryResponse.INSTANCE; | |
case 'n' -> NoData.INSTANCE; | |
case 'v' -> NegotiateProtocolVersion.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 'A' -> NotificationResponse.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 'N' -> NoticeResponse.fromByteBuffer(bbBody, codecParams.serverCharset); | |
case 's' -> PortalSuspended.INSTANCE; | |
case 'G' -> CopyInResponse.fromByteBuffer(bbBody); | |
default -> throw new PGError("Unknown message: %s", tag); | |
}; | |
} | |
private void sendDescribeStatement (final String statement) { | |
final Describe msg = new Describe(SourceType.STATEMENT, statement); | |
sendMessage(msg); | |
} | |
private void sendDescribePortal (final String portal) { | |
final Describe msg = new Describe(SourceType.PORTAL, portal); | |
sendMessage(msg); | |
} | |
private void sendExecute (final String portal, final long maxRows) { | |
final Execute msg = new Execute(portal, maxRows); | |
sendMessage(msg); | |
} | |
public Object query(final String sql) { | |
return query(sql, ExecuteParams.INSTANCE); | |
} | |
public Object query(final String sql, final ExecuteParams executeParams) { | |
try (TryLock ignored = lock.get()) { | |
sendQuery(sql); | |
return interact(Phase.QUERY, executeParams).getResult(); | |
} | |
} | |
public PreparedStatement prepare (final String sql) { | |
return prepare(sql, ExecuteParams.INSTANCE); | |
} | |
public PreparedStatement prepare (final String sql, final ExecuteParams executeParams) { | |
try (TryLock ignored = lock.get()) { | |
return _prepare_unlocked(sql, executeParams); | |
} | |
} | |
private PreparedStatement _prepare_unlocked ( | |
final String sql, | |
final ExecuteParams executeParams | |
) { | |
final String statement = generateStatement(); | |
final List<OID> OIDsProvided = executeParams.OIDs(); | |
final int OIDsProvidedCount = OIDsProvided.size(); | |
final List<Object> params = executeParams.params(); | |
final int paramCount = params.size(); | |
final int maxIndex = Math.max(OIDsProvidedCount, paramCount); | |
final OID[] OIDs = new OID[maxIndex]; | |
for (int i = 0; i < maxIndex; i++) { | |
if (i < OIDsProvidedCount) { | |
OIDs[i] = OIDsProvided.get(i); | |
} | |
else { | |
Object param = params.get(i); | |
OIDs[i] = OIDHint.guessOID(param); | |
} | |
} | |
final Parse parse = new Parse(statement, sql, OIDs); | |
sendMessage(parse); | |
sendDescribeStatement(statement); | |
sendSync(); | |
sendFlush(); | |
final Accum acc = interact(Phase.PREPARE); | |
final ParameterDescription paramDesc = acc.getParameterDescription(); | |
final RowDescription rowDescription = acc.getRowDescription(); | |
return new PreparedStatement(parse, paramDesc, rowDescription); | |
} | |
private void sendBind (final String portal, | |
final PreparedStatement stmt, | |
final ExecuteParams executeParams | |
) { | |
final List<Object> params = executeParams.params(); | |
final OID[] OIDs = stmt.parameterDescription().OIDs(); | |
final int size = params.size(); | |
if (size != OIDs.length) { | |
throw new PGError( | |
"Wrong parameters count: %s (must be %s)", | |
size, OIDs.length | |
); | |
} | |
final Format paramsFormat = (executeParams.binaryEncode() || config.binaryEncode()) ? Format.BIN : Format.TXT; | |
final Format columnFormat = (executeParams.binaryDecode() || config.binaryDecode()) ? Format.BIN : Format.TXT; | |
final byte[][] bytes = new byte[size][]; | |
String statement = stmt.parse().statement(); | |
int i = -1; | |
for (final Object param: params) { | |
i++; | |
if (param == null) { | |
bytes[i] = null; | |
continue; | |
} | |
OID oid = OIDs[i]; | |
switch (paramsFormat) { | |
case BIN -> { | |
ByteBuffer buf = EncoderBin.encode(param, oid, codecParams); | |
bytes[i] = buf.array(); | |
} | |
case TXT -> { | |
String value = EncoderTxt.encode(param, oid, codecParams); | |
bytes[i] = value.getBytes(codecParams.clientCharset); | |
} | |
default -> | |
throw new PGError("unknown format: %s", paramsFormat); | |
} | |
} | |
final Bind msg = new Bind( | |
portal, | |
statement, | |
bytes, | |
paramsFormat, | |
columnFormat | |
); | |
for (byte[] buf: msg.toByteArrays()) { | |
sendBytes(buf); | |
} | |
} | |
private void flush () { | |
IOTool.flush(outStream); | |
} | |
public Object executeStatement(final PreparedStatement stmt) { | |
return executeStatement(stmt, ExecuteParams.INSTANCE); | |
} | |
public Object executeStatement ( | |
final PreparedStatement stmt, | |
final ExecuteParams executeParams | |
) { | |
try (TryLock ignored = lock.get()) { | |
final String portal = generatePortal(); | |
sendBind(portal, stmt, executeParams); | |
sendDescribePortal(portal); | |
sendExecute(portal, executeParams.maxRows()); | |
sendClosePortal(portal); | |
sendSync(); | |
sendFlush(); | |
return interact(Phase.EXECUTE, executeParams).getResult(); | |
} | |
} | |
public Object execute (final String sql) { | |
return execute(sql, ExecuteParams.INSTANCE); | |
} | |
public Object execute (final String sql, final List<Object> params) { | |
return execute(sql, ExecuteParams.builder().params(params).build()); | |
} | |
public Object execute (final String sql, final ExecuteParams executeParams) { | |
try (final TryLock ignored = lock.get()) { | |
final PreparedStatement stmt = prepare(sql, executeParams); | |
final String portal = generatePortal(); | |
sendBind(portal, stmt, executeParams); | |
sendDescribePortal(portal); | |
sendExecute(portal, executeParams.maxRows()); | |
sendClosePortal(portal); | |
sendCloseStatement(stmt); | |
sendSync(); | |
sendFlush(); | |
return interact(Phase.EXECUTE, executeParams).getResult(); | |
} | |
} | |
private void sendCloseStatement (final PreparedStatement stmt) { | |
final Close msg = new Close(SourceType.STATEMENT, stmt.parse().statement()); | |
sendMessage(msg); | |
} | |
private void sendCloseStatement (final String statement) { | |
final Close msg = new Close(SourceType.STATEMENT, statement); | |
sendMessage(msg); | |
} | |
private void sendClosePortal (final String portal) { | |
final Close msg = new Close(SourceType.PORTAL, portal); | |
sendMessage(msg); | |
} | |
public void closeStatement (final PreparedStatement statement) { | |
closeStatement(statement.parse().statement()); | |
} | |
public void closeStatement (final String statement) { | |
try (TryLock ignored = lock.get()) { | |
sendCloseStatement(statement); | |
sendSync(); | |
sendFlush(); | |
interact(Phase.CLOSE); | |
} | |
} | |
public CompletableFuture<Accum> handleAuthenticationSASLAsync(final AuthenticationSASL msg, final Accum acc) { | |
acc.scramPipeline = ScramSha256.pipeline(); | |
// if (msg.isScramSha256()) { | |
final ScramSha256.Step1 step1 = ScramSha256.step1_clientFirstMessage( | |
config.user(), config.password() | |
); | |
final SASLInitialResponse msgSASL = new SASLInitialResponse( | |
SASL.SCRAM_SHA_256, | |
step1.clientFirstMessage() | |
); | |
acc.scramPipeline.step1 = step1; | |
return sendMessageAsync(msgSASL).thenComposeAsync( | |
(Object ignored) -> CompletableFuture.completedFuture(acc) | |
); | |
// } | |
// if (msg.isScramSha256Plus()) { | |
// throw new PGError("SASL SCRAM SHA 256 PLUS method is not implemented yet"); | |
// } | |
} | |
public CompletableFuture<Accum> handleMessageAsync(final Object msg, final Accum acc) { | |
return switch (msg.getClass().getName()) { | |
case "AuthenticationSASL" -> handleAuthenticationSASLAsync((AuthenticationSASL)msg, acc); | |
default -> throw new RuntimeException("aaa"); | |
}; | |
// return CompletableFuture.completedFuture(acc); | |
} | |
public CompletableFuture<Accum> interactAsyncNext (final Accum acc) { | |
return readMessageAsync().thenComposeAsync((Object msg) -> { | |
if (isEnough(msg, acc.phase)) { | |
return CompletableFuture.completedFuture(acc); | |
} | |
else { | |
return handleMessageAsync(msg, acc).thenComposeAsync(this::interactAsyncNext); | |
} | |
}); | |
} | |
public CompletableFuture<Accum> interactAsync (final Phase phase, final ExecuteParams executeParams) { | |
final Accum acc = new Accum(phase, executeParams); | |
return interactAsyncNext(acc); | |
} | |
private Accum interact(final Phase phase, final ExecuteParams executeParams) { | |
flush(); | |
final Accum acc = new Accum(phase, executeParams); | |
while (true) { | |
final Object msg = readMessage(acc.hasException()); | |
if (isDebug) { | |
logger.log(config.logLevel(), " -> {0}", msg); | |
} | |
handleMessage(msg, acc); | |
if (isEnough(msg, phase)) { | |
break; | |
} | |
} | |
acc.maybeThrowError(); | |
return acc; | |
} | |
private Accum interact(final Phase phase) { | |
return interact(phase, ExecuteParams.INSTANCE); | |
} | |
private void handleMessage(final Object msg, final Accum acc) { | |
switch (msg.getClass().getSimpleName()) { | |
case | |
"NotificationResponse" -> | |
handleNotificationResponse((NotificationResponse)msg); | |
case | |
"NoData", | |
"EmptyQueryResponse", | |
"CloseComplete", | |
"BindComplete", | |
"AuthenticationOk", | |
"CopyDone", | |
"SkippedMessage"-> {} | |
case | |
"AuthenticationCleartextPassword" -> | |
handleAuthenticationCleartextPassword(); | |
case | |
"AuthenticationSASL" -> | |
handleAuthenticationSASL((AuthenticationSASL)msg, acc); | |
case | |
"AuthenticationSASLContinue" -> | |
handleAuthenticationSASLContinue((AuthenticationSASLContinue)msg, acc); | |
case | |
"AuthenticationSASLFinal" -> | |
handleAuthenticationSASLFinal((AuthenticationSASLFinal)msg, acc); | |
case | |
"NoticeResponse" -> | |
handleNoticeResponse((NoticeResponse)msg); | |
case | |
"ParameterStatus" -> | |
handleParameterStatus((ParameterStatus)msg); | |
case | |
"RowDescription" -> | |
handleRowDescription((RowDescription)msg, acc); | |
case | |
"DataRow" -> | |
handleDataRow((DataRow)msg, acc); | |
case | |
"ReadyForQuery" -> | |
handleReadyForQuery((ReadyForQuery)msg); | |
case | |
"PortalSuspended" -> | |
handlePortalSuspended((PortalSuspended)msg, acc); | |
case | |
"AuthenticationMD5Password" -> | |
handleAuthenticationMD5Password((AuthenticationMD5Password)msg); | |
case | |
"NegotiateProtocolVersion" -> | |
handleNegotiateProtocolVersion((NegotiateProtocolVersion)msg); | |
case | |
"CommandComplete" -> | |
handleCommandComplete((CommandComplete)msg, acc); | |
case | |
"ErrorResponse" -> | |
handleErrorResponse((ErrorResponse)msg, acc); | |
case | |
"BackendKeyData" -> | |
handleBackendKeyData((BackendKeyData)msg); | |
case | |
"ParameterDescription" -> | |
handleParameterDescription((ParameterDescription)msg, acc); | |
case | |
"ParseComplete" -> | |
handleParseComplete((ParseComplete)msg, acc); | |
case | |
"CopyOutResponse" -> | |
handleCopyOutResponse((CopyOutResponse)msg, acc); | |
case | |
"CopyData" -> | |
handleCopyData((CopyData)msg, acc); | |
case | |
"CopyInResponse" -> | |
handleCopyInResponse(acc); | |
default -> | |
throw new PGError("Cannot handle this message: %s", msg); | |
} | |
} | |
private void handleAuthenticationSASL(final AuthenticationSASL msg, final Accum acc) { | |
acc.scramPipeline = ScramSha256.pipeline(); | |
if (msg.isScramSha256()) { | |
final ScramSha256.Step1 step1 = ScramSha256.step1_clientFirstMessage( | |
config.user(), config.password() | |
); | |
final SASLInitialResponse msgSASL = new SASLInitialResponse( | |
SASL.SCRAM_SHA_256, | |
step1.clientFirstMessage() | |
); | |
acc.scramPipeline.step1 = step1; | |
sendMessage(msgSASL); | |
flush(); | |
} | |
if (msg.isScramSha256Plus()) { | |
throw new PGError("SASL SCRAM SHA 256 PLUS method is not implemented yet"); | |
} | |
} | |
private void handleAuthenticationSASLContinue(final AuthenticationSASLContinue msg, final Accum acc) { | |
final ScramSha256.Step1 step1 = acc.scramPipeline.step1; | |
final String serverFirstMessage = msg.serverFirstMessage(); | |
final ScramSha256.Step2 step2 = ScramSha256.step2_serverFirstMessage(serverFirstMessage); | |
final ScramSha256.Step3 step3 = ScramSha256.step3_clientFinalMessage(step1, step2); | |
acc.scramPipeline.step2 = step2; | |
acc.scramPipeline.step3 = step3; | |
final SASLResponse msgSASL = new SASLResponse(step3.clientFinalMessage()); | |
sendMessage(msgSASL); | |
flush(); | |
} | |
private void handleAuthenticationSASLFinal(final AuthenticationSASLFinal msg, final Accum acc) { | |
final String serverFinalMessage = msg.serverFinalMessage(); | |
final ScramSha256.Step4 step4 = ScramSha256.step4_serverFinalMessage(serverFinalMessage); | |
acc.scramPipeline.step4 = step4; | |
final ScramSha256.Step3 step3 = acc.scramPipeline.step3; | |
ScramSha256.step5_verifyServerSignature(step3, step4); | |
} | |
private void handleCopyInResponseStream(Accum acc) { | |
final int bufSize = acc.executeParams.copyBufSize(); | |
final byte[] buf = new byte[bufSize]; | |
final ByteBuffer bbLead = ByteBuffer.allocate(5); | |
bbLead.put((byte)'d'); | |
InputStream inputStream = acc.executeParams.inputStream(); | |
Throwable e = null; | |
int read; | |
while (true) { | |
try { | |
read = inputStream.read(buf); | |
} | |
catch (Throwable caught) { | |
e = caught; | |
break; | |
} | |
if (read == -1) { | |
break; | |
} | |
bbLead.position(1); | |
bbLead.putInt(4 + read); | |
sendBytes(bbLead.array()); | |
sendBytes(buf, 0, read); | |
} | |
if (e == null) { | |
sendCopyDone(); | |
} | |
else { | |
acc.setException(e); | |
sendCopyFail(Const.COPY_FAIL_EXCEPTION_MSG); | |
} | |
} | |
private void handleCopyInResponseData (final Accum acc, final Iterator<List<Object>> rows) { | |
final ExecuteParams executeParams = acc.executeParams; | |
final CopyFormat format = executeParams.copyFormat(); | |
Throwable e = null; | |
switch (format) { | |
case CSV: | |
String line; | |
while (rows.hasNext()) { | |
try { | |
line = Copy.encodeRowCSV(rows.next(), executeParams, codecParams); | |
} | |
catch (Throwable caught) { | |
e = caught; | |
break; | |
} | |
final byte[] bytes = line.getBytes(StandardCharsets.UTF_8); | |
sendBytesCopy(bytes); | |
} | |
break; | |
case BIN: | |
ByteBuffer buf; | |
// TODO: use sendBytes | |
sendCopyData(Copy.COPY_BIN_HEADER); | |
while (rows.hasNext()) { | |
try { | |
buf = Copy.encodeRowBin(rows.next(), executeParams, codecParams); | |
} | |
catch (Throwable caught) { | |
e = caught; | |
break; | |
} | |
sendBytesCopy(buf.array()); | |
} | |
if (e == null) { | |
sendBytes(Copy.MSG_COPY_BIN_TERM); | |
} | |
break; | |
case TAB: | |
e = new PGError("TAB COPY format is not implemented"); | |
break; | |
} | |
if (e == null) { | |
sendCopyDone(); | |
} | |
else { | |
acc.setException(e); | |
sendCopyFail(Const.COPY_FAIL_EXCEPTION_MSG); | |
} | |
} | |
private void handleCopyInResponseRows (final Accum acc) { | |
final Iterator<List<Object>> iterator = acc.executeParams.copyInRows() | |
.stream() | |
.filter(Objects::nonNull) | |
.iterator(); | |
handleCopyInResponseData(acc, iterator); | |
} | |
private void handleCopyInResponseMaps(final Accum acc) { | |
final List<Object> keys = acc.executeParams.copyInKeys(); | |
final Iterator<List<Object>> iterator = acc.executeParams.copyInMaps() | |
.stream() | |
.filter(Objects::nonNull) | |
.map(map -> mapToRow(map, keys)) | |
.iterator(); | |
handleCopyInResponseData(acc, iterator); | |
} | |
private void handleCopyInResponse(Accum acc) { | |
// These three methods only send data but do not read. | |
// Thus, we rely on sendBytes which doesn't trigger flushing | |
// the output stream. Flushing is expensive and thus must be called | |
// manually when all the data has been sent. | |
if (acc.executeParams.isCopyInRows()) { | |
handleCopyInResponseRows(acc); | |
} | |
else if (acc.executeParams.isCopyInMaps()) { | |
handleCopyInResponseMaps(acc); | |
} else { | |
handleCopyInResponseStream(acc); | |
} | |
// Finally, we flush the output stream so all unsent bytes get sent. | |
flush(); | |
} | |
private void handlePortalSuspended(final PortalSuspended msg, final Accum acc) { | |
acc.handlePortalSuspended(msg); | |
} | |
private void handlerCall(final IFn f, final Object arg) { | |
if (f == null) { | |
logger.log(config.logLevel(), arg); | |
} | |
else { | |
Agent.soloExecutor.submit(() -> { | |
f.invoke(arg); | |
}); | |
} | |
} | |
private void handleNotificationResponse(final NotificationResponse msg) { | |
handlerCall(config.fnNotification(), msg.toClojure()); | |
} | |
private void handleNoticeResponse(final NoticeResponse msg) { | |
handlerCall(config.fnNotice(), msg.toClojure()); | |
} | |
private void handleNegotiateProtocolVersion(final NegotiateProtocolVersion msg) { | |
handlerCall(config.fnProtocolVersion(), msg.toClojure()); | |
} | |
private void handleAuthenticationMD5Password(final AuthenticationMD5Password msg) { | |
final String hashed = MD5.hashPassword( | |
config.user(), | |
config.password(), | |
msg.salt() | |
); | |
sendPassword(hashed); | |
flush(); | |
} | |
private void handleCopyOutResponse(final CopyOutResponse msg, final Accum acc) { | |
acc.handleCopyOutResponse(msg); | |
} | |
private void handleCopyData(final CopyData msg, final Accum acc) { | |
try { | |
handleCopyDataUnsafe(msg, acc); | |
} catch (Throwable e) { | |
acc.setException(e); | |
} | |
} | |
private void handleCopyDataUnsafe(final CopyData msg, final Accum acc) throws IOException { | |
final OutputStream outputStream = acc.executeParams.outputStream(); | |
final byte[] bytes = msg.buf().array(); | |
outputStream.write(bytes); | |
} | |
@SuppressWarnings("unused") | |
public AutoCloseable getLock() { | |
return lock.get(); | |
} | |
@SuppressWarnings("unused") | |
public Object copy (final String sql, final ExecuteParams executeParams) { | |
try (TryLock ignored = lock.get()) { | |
sendQuery(sql); | |
final Accum acc = interact(Phase.COPY, executeParams); | |
return acc.getResult(); | |
} | |
} | |
private static List<Object> mapToRow(final Map<?,?> map, final List<Object> keys) { | |
final List<Object> row = new ArrayList<>(keys.size()); | |
for (final Object key: keys) { | |
row.add(map.get(key)); | |
} | |
return row; | |
} | |
private void handleParseComplete(final ParseComplete msg, final Accum acc) { | |
acc.handleParseComplete(msg); | |
} | |
private void handleParameterDescription (final ParameterDescription msg, final Accum acc) { | |
acc.handleParameterDescription(msg); | |
} | |
private void handleAuthenticationCleartextPassword() { | |
sendPassword(config.password()); | |
flush(); | |
} | |
private void handleParameterStatus(final ParameterStatus msg) { | |
setParam(msg.param(), msg.value()); | |
} | |
private static void handleRowDescription(final RowDescription msg, final Accum acc) { | |
acc.handleRowDescription(msg); | |
} | |
private void handleDataRowUnsafe(final DataRow msg, final Accum acc) { | |
final RowDescription rowDescription = acc.getRowDescription(); | |
final Map<Object, Short> keysIndex = acc.getCurrentKeysIndex(); | |
final LazyMap lazyMap = new LazyMap(msg, rowDescription, keysIndex, codecParams); | |
acc.addClojureRow(lazyMap); | |
} | |
private void handleDataRow(final DataRow msg, final Accum acc) { | |
try { | |
handleDataRowUnsafe(msg, acc); | |
} | |
catch (Throwable e) { | |
acc.setException(e); | |
} | |
} | |
private void handleReadyForQuery(final ReadyForQuery msg) { | |
txStatus = msg.txStatus(); | |
} | |
private static void handleCommandComplete(final CommandComplete msg, final Accum acc) { | |
acc.handleCommandComplete(msg); | |
} | |
private static void handleErrorResponse(final ErrorResponse msg, final Accum acc) { | |
acc.addErrorResponse(msg); | |
} | |
private void handleBackendKeyData(final BackendKeyData msg) { | |
pid = msg.pid(); | |
secretKey = msg.secretKey(); | |
} | |
private static Boolean isEnough (final Object msg, final Phase phase) { | |
return switch (msg.getClass().getSimpleName()) { | |
case "ReadyForQuery" -> true; | |
case "ErrorResponse" -> phase == Phase.AUTH; | |
default -> false; | |
}; | |
} | |
@SuppressWarnings("unused") | |
public static Connection clone (final Connection conn) { | |
return new Connection(conn.config); | |
} | |
@SuppressWarnings("unused") | |
public static void cancelRequest(final Connection conn) { | |
final CancelRequest msg = new CancelRequest(Const.CANCEL_CODE, conn.pid, conn.secretKey); | |
final Connection temp = new Connection(conn.config, false); | |
temp.sendMessage(msg); | |
temp.close(); | |
} | |
@SuppressWarnings("unused") | |
public void begin () { | |
try (TryLock ignored = lock.get()) { | |
sendQuery("BEGIN"); | |
interact(Phase.QUERY); | |
} | |
} | |
@SuppressWarnings("unused") | |
public void commit () { | |
try (TryLock ignored = lock.get()) { | |
sendQuery("COMMIT"); | |
interact(Phase.QUERY); | |
} | |
} | |
@SuppressWarnings("unused") | |
public void rollback () { | |
try (TryLock ignored = lock.get()) { | |
sendQuery("ROLLBACK"); | |
interact(Phase.QUERY); | |
} | |
} | |
@SuppressWarnings("unused") | |
public boolean isIdle () { | |
try (TryLock ignored = lock.get()) { | |
return txStatus == TXStatus.IDLE; | |
} | |
} | |
@SuppressWarnings("unused") | |
public boolean isTxError () { | |
try (TryLock ignored = lock.get()) { | |
return txStatus == TXStatus.ERROR; | |
} | |
} | |
@SuppressWarnings("unused") | |
public boolean isTransaction () { | |
try (TryLock ignored = lock.get()) { | |
return txStatus == TXStatus.TRANSACTION; | |
} | |
} | |
@SuppressWarnings("unused") | |
public void setTxLevel (final TxLevel level) { | |
try (TryLock ignored = lock.get()) { | |
sendQuery(SQL.SQLSetTxLevel(level)); | |
interact(Phase.QUERY); | |
} | |
} | |
@SuppressWarnings("unused") | |
public void setTxReadOnly () { | |
try (TryLock ignored = lock.get()) { | |
sendQuery(SQL.SQLSetTxReadOnly); | |
interact(Phase.QUERY); | |
} | |
} | |
@SuppressWarnings("unused") | |
public void listen (final String channel) { | |
try (TryLock ignored = lock.get()) { | |
query(String.format("listen %s", SQL.quoteChannel(channel))); | |
} | |
} | |
@SuppressWarnings("unused") | |
public void unlisten (final String channel) { | |
try (TryLock ignored = lock.get()) { | |
query(String.format("unlisten %s", SQL.quoteChannel(channel))); | |
} | |
} | |
@SuppressWarnings("unused") | |
public void notify (final String channel, final String message) { | |
try (TryLock ignored = lock.get()) { | |
final List<Object> params = List.of(channel, message); | |
execute("select pg_notify($1, $2)", params); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment