Created
November 9, 2023 15:33
-
-
Save miguelff/0e49db62c80780f0a22d77f86f7d70fc to your computer and use it in GitHub Desktop.
Channel-based linearization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/main.go b/main.go | |
index fdb5ccf..8a68f58 100644 | |
--- a/main.go | |
+++ b/main.go | |
@@ -34,9 +34,57 @@ type mysqlConnKey struct { | |
username, pass, session string | |
} | |
+// request is one unit of work handed to a connection's worker goroutine: | |
+// query is the SQL text passed to ExecuteFetch, and session is copied | |
+// unchanged into the Session field of the resulting ExecuteResponse. | |
+type request struct { | |
+ query string | |
+ session string | |
+} | |
+ | |
+// timedConn pairs a MySQL connection with the time it was last used and the | |
+// two channels through which queries are handed to, and results returned | |
+// from, the goroutine started by newTimedConn. | |
type timedConn struct { | |
*mysql.Conn | |
lastUsed time.Time | |
+ // reqs carries queries to the worker goroutine; res carries back the | |
+ // response for each one. Both are created unbuffered by newTimedConn. | |
+ reqs chan (*request) | |
+ res chan (*psdbv1alpha1.ExecuteResponse) | |
+} | |
+ | |
+// close closes the reqs and res channels and then the underlying MySQL | |
+// connection. | |
+// | |
+// NOTE(review): Execute sends on reqs; if that can happen after close, the | |
+// send panics (send on a closed channel), and calling close a second time | |
+// panics as well. Confirm that a closed timedConn can no longer be reached | |
+// through the pool. | |
+func (c *timedConn) close() { | |
+ close(c.reqs) | |
+ close(c.res) | |
+ c.Conn.Close() | |
+} | |
+ | |
+// newTimedConn wraps my in a timedConn and starts the goroutine that | |
+// linearizes queries on it: each request received on reqs is executed with | |
+// ExecuteFetch, one at a time, and its response is sent on res. | |
+// | |
+// The goroutine stops when reqs is closed, or when the connection has gone | |
+// unused for *flagMySQLIdleTimeout, in which case it also closes the | |
+// connection. | |
+func newTimedConn(my *mysql.Conn) *timedConn { | |
+	conn := timedConn{ | |
+		Conn:     my, | |
+		lastUsed: time.Now(), | |
+		reqs:     make(chan (*request)), | |
+		res:      make(chan (*psdbv1alpha1.ExecuteResponse)), | |
+	} | |
+ | |
+	go func() { | |
+		for { | |
+			select { | |
+			case req, ok := <-conn.reqs: | |
+				if !ok { | |
+					// reqs was closed: nothing more to serve. | |
+					return | |
+				} | |
+				conn.lastUsed = time.Now() | |
+				qr, err := conn.ExecuteFetch(req.query, int(*flagMySQLMaxRows), true) | |
+				conn.res <- &psdbv1alpha1.ExecuteResponse{ | |
+					Session: req.session, | |
+					Result:  sqltypes.ResultToProto3(qr), | |
+					Error:   vterrors.ToVTRPC(err), | |
+				} | |
+			case <-time.After(*flagMySQLIdleTimeout): | |
+				// A fresh timer is armed on every iteration, so this fires | |
+				// only after a full idle period without a request. getConn | |
+				// also refreshes lastUsed when it hands the connection out, | |
+				// so a caller may be about to send: stop only when the | |
+				// connection has really expired. Returning unconditionally | |
+				// here would leave that caller blocked on reqs forever. | |
+				// NOTE(review): lastUsed is written by getConn without | |
+				// synchronization with this goroutine. | |
+				expiration := time.Now().Add(-*flagMySQLIdleTimeout) | |
+				if conn.lastUsed.Before(expiration) { | |
+					conn.close() | |
+					return | |
+				} | |
+			} | |
+		} | |
+	}() | |
+ | |
+	return &conn | |
} | |
var ( | |
@@ -57,7 +105,7 @@ var ( | |
// since this isn't meant to truly represent reality, it's possible you | |
// can do things with connections locally by munging session ids or auth | |
// that aren't allowed on PlanetScale. This is meant to just mimic the public API. | |
-func getConn(ctx context.Context, uname, pass, session string) (*mysql.Conn, error) { | |
+func getConn(ctx context.Context, uname, pass, session string) (*timedConn, error) { | |
key := mysqlConnKey{uname, pass, session} | |
// check first if there's already a connection | |
@@ -65,7 +113,7 @@ func getConn(ctx context.Context, uname, pass, session string) (*mysql.Conn, err | |
if conn, ok := connPool[key]; ok { | |
connMu.RUnlock() | |
conn.lastUsed = time.Now() | |
- return conn.Conn, nil | |
+ return conn, nil | |
} | |
connMu.RUnlock() | |
@@ -79,15 +127,12 @@ func getConn(ctx context.Context, uname, pass, session string) (*mysql.Conn, err | |
// lock to write to map | |
connMu.Lock() | |
- connPool[key] = &timedConn{rawConn, time.Now()} | |
+ connPool[key] = newTimedConn(rawConn) | |
connMu.Unlock() | |
// since it was parallel, the last one would have won and been written | |
// so re-read back so we use the conn that was actually stored in the pool | |
- connMu.RLock() | |
- conn := connPool[key] | |
- connMu.RUnlock() | |
- return conn.Conn, nil | |
+ return getConn(ctx, uname, pass, session) | |
} | |
// dial connects to the underlying MySQL server, and switches to the underlying | |
@@ -187,13 +232,9 @@ func (s *server) Execute( | |
return nil, err | |
} | |
- // This is a gross simplificiation, but is likely sufficient | |
- qr, err := conn.ExecuteFetch(query, int(*flagMySQLMaxRows), true) | |
- return connect.NewResponse(&psdbv1alpha1.ExecuteResponse{ | |
- Session: session, | |
- Result: sqltypes.ResultToProto3(qr), | |
- Error: vterrors.ToVTRPC(err), | |
- }), nil | |
+ conn.reqs <- &request{query, session} | |
+ res := <-conn.res | |
+ return connect.NewResponse(res), nil | |
} | |
func initConnPool() { |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment