Last active
December 22, 2020 10:05
-
-
Save padurean/3c88c6c7054522d16fb7d6683a1be928 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package websocket

import (
	"fmt"
	"io"
	"strings"
	"sync"
	"time"

	logger "github.com/rs/zerolog/log"
	"golang.org/x/net/websocket"
)
// WebSocket holds the configuration for a handler that pushes a value to a
// websocket client and listens for the client's replies (see
// SendAndWaitForAcknowledgement). Set the duration fields (for example via
// SetDefaultDurations) and provide GetValue and NotEqual before use.
type WebSocket struct {
	// Name identifies this handler in log messages.
	Name string
	// KeepAliveInterval is the time since the last send after which the
	// value is sent again even if NotEqual reports no change.
	KeepAliveInterval time.Duration
	// MaxClientIdleTime is how long the client may stay silent before the
	// connection is closed.
	MaxClientIdleTime time.Duration
	// SendDeadline is the write deadline applied to each individual send.
	SendDeadline time.Duration
	// PollValueInterval is the period of the ticker that drives the idle
	// check and the send decision. It must be positive: time.NewTicker
	// panics otherwise.
	PollValueInterval time.Duration
	// GetValue supplies the value to send; it is formatted with %v.
	GetValue func() interface{}
	// NotEqual reports whether two values differ; a send happens when it
	// returns true for the current value and the last value sent.
	NotEqual func(a, b interface{}) bool
}
// SetDefaultDurations sets defaults durations; user code still needs to set the other fields | |
func (ws *WebSocket) SetDefaultDurations() { | |
ws.KeepAliveInterval = 20 * time.Second | |
ws.MaxClientIdleTime = 60 * time.Second | |
ws.SendDeadline = 5 * time.Second | |
ws.PollValueInterval = 3 * time.Second | |
} | |
// SendAndWaitForAcknowledgement ... | |
func (ws *WebSocket) SendAndWaitForAcknowledgement(conn *websocket.Conn) { | |
logPrefix := fmt.Sprintf("WebSocket %s -", ws.Name) | |
logger.Info().Msgf("%s START connection", logPrefix) | |
pollValueChangedTicker := time.NewTicker(ws.PollValueInterval) | |
defer func() { | |
pollValueChangedTicker.Stop() | |
conn.Close() | |
logger.Info().Msgf("%s END connection", logPrefix) | |
}() | |
var lastValue, currValue interface{} | |
lastSentAt := time.Now() | |
clientLastSeenAt := time.Now() | |
send := func() error { | |
now := time.Now() | |
if err := conn.SetWriteDeadline(now.Add(ws.SendDeadline)); err != nil { | |
logger.Err(err).Msgf( | |
"%s ABORT send: error setting write deadline", logPrefix) | |
return err | |
} | |
msg := fmt.Sprintf("%v", currValue) | |
if err := websocket.Message.Send(conn, msg); err != nil { | |
logger.Err(err).Msgf("%s ABORT send: error sending", logPrefix) | |
return err | |
} | |
logger.Info().Msgf("%s SEND: %v", logPrefix, msg) | |
lastValue = currValue | |
lastSentAt = now | |
return nil | |
} | |
receive := func() error { | |
var received string | |
err := websocket.Message.Receive(conn, &received) | |
if err != nil { | |
switch { | |
case err == io.EOF: | |
logger.Info().Msgf( | |
"%s ABORT receive: client closed the connection: %v", | |
logPrefix, err) | |
case strings.Contains(err.Error(), "use of closed network connection"): | |
logger.Info().Msgf( | |
"%s ABORT receive: connection was meanwhile closed: %v", | |
logPrefix, err) | |
default: | |
logger.Err(err).Msgf( | |
"%s ABORT receive: error receiving from client", logPrefix) | |
} | |
return err | |
} | |
logger.Info().Msgf( | |
"%s RECEIVE client response: %s", logPrefix, received) | |
clientLastSeenAt = time.Now() | |
return nil | |
} | |
go func() { | |
currValue = ws.GetValue() | |
// execute 1st send with no delay | |
if err := send(); err != nil { | |
conn.Close() | |
return | |
} | |
for range pollValueChangedTicker.C { | |
// close connection if client did not respond since too long | |
clientIdleTime := time.Now().Sub(clientLastSeenAt) | |
if clientIdleTime >= ws.MaxClientIdleTime { | |
logger.Info().Msgf( | |
"%s ABORT send and close connection: did not hear from client since "+ | |
"%s (more than max idle time %s)", | |
logPrefix, clientIdleTime, ws.MaxClientIdleTime) | |
conn.Close() | |
return | |
} | |
// execute subsequent sends only if the value changed or too much time passed | |
if ws.NotEqual(currValue, lastValue) || time.Now().Sub(lastSentAt) >= ws.KeepAliveInterval { | |
if err := send(); err != nil { | |
conn.Close() | |
return | |
} | |
} | |
} | |
}() | |
for { | |
if err := receive(); err != nil { | |
conn.Close() | |
return | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment