Last active: November 7, 2023 08:34
-
-
Save jkuipers/ea27406b3bd2b84eab65f197366cfe8e to your computer and use it in GitHub Desktop.
Configuration and code to add tracing support to Spring Cloud AWS's message listeners
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@AutoConfiguration(before = io.awspring.cloud.autoconfigure.sqs.SqsAutoConfiguration.class, | |
afterName = "org.springframework.boot.actuate.autoconfigure.tracing.BraveAutoConfiguration") | |
@ConditionalOnBean(Tracing.class) | |
public class SqsTracingAutoConfiguration { | |
@Bean(name = SqsBeanNames.SQS_LISTENER_ANNOTATION_BEAN_POST_PROCESSOR_BEAN_NAME) | |
TracingSqsListenerAnnotationBeanPostProcessor tracingSLABPP(Tracing tracing) { | |
return new TracingSqsListenerAnnotationBeanPostProcessor(tracing); | |
} | |
static class TracingSqsListenerAnnotationBeanPostProcessor extends SqsListenerAnnotationBeanPostProcessor { | |
private final Tracing tracing; | |
public TracingSqsListenerAnnotationBeanPostProcessor(Tracing tracing) { | |
this.tracing = tracing; | |
} | |
/** | |
* Overrides parent method to ensure that our custom endpoint with tracing support is returned. | |
*/ | |
@Override | |
protected Endpoint createEndpoint(SqsListener sqsListenerAnnotation) { | |
return new TracingSqsEndpoint.TracingSqsEndpointBuilder(tracing) | |
.queueNames(resolveEndpointNames(sqsListenerAnnotation.value())) | |
.factoryBeanName(resolveAsString(sqsListenerAnnotation.factory(), "factory")) | |
.id(getEndpointId(sqsListenerAnnotation.id())) | |
.pollTimeoutSeconds(resolveAsInteger(sqsListenerAnnotation.pollTimeoutSeconds(), "pollTimeoutSeconds")) | |
.maxMessagesPerPoll(resolveAsInteger(sqsListenerAnnotation.maxMessagesPerPoll(), "maxMessagesPerPoll")) | |
.maxConcurrentMessages( | |
resolveAsInteger(sqsListenerAnnotation.maxConcurrentMessages(), "maxConcurrentMessages")) | |
.messageVisibility( | |
resolveAsInteger(sqsListenerAnnotation.messageVisibilitySeconds(), "messageVisibility")) | |
.build(); | |
} | |
} | |
static class TracingSqsEndpoint extends SqsEndpoint { | |
private final Tracing tracing; | |
protected TracingSqsEndpoint(SqsEndpoint.SqsEndpointBuilder builder, Tracing tracing) { | |
super(builder); | |
this.tracing = tracing; | |
} | |
@Override | |
protected <T> MessageListener<T> createMessageListenerInstance(InvocableHandlerMethod handlerMethod) { | |
return new TracingWrappers.MessageListenerWrapper<>(super.createMessageListenerInstance(handlerMethod), tracing); | |
} | |
@Override | |
protected <T> AsyncMessageListener<T> createAsyncMessageListenerInstance(InvocableHandlerMethod handlerMethod) { | |
return new TracingWrappers.AsyncMessageListenerWrapper<>(super.createAsyncMessageListenerInstance(handlerMethod), tracing); | |
} | |
static class TracingSqsEndpointBuilder extends SqsEndpoint.SqsEndpointBuilder { | |
private final Tracing tracing; | |
public TracingSqsEndpointBuilder(Tracing tracing) { | |
this.tracing = tracing; | |
} | |
@Override | |
public SqsEndpoint build() { | |
return new TracingSqsEndpoint(this, tracing); | |
} | |
} | |
} | |
static abstract class TracingWrappers<D> { | |
private static final Propagation.Getter<MessageHeaders, String> GETTER = | |
(headers, key) -> (String) headers.get(key); | |
protected D delegate; | |
private final TraceContext.Extractor<MessageHeaders> extractor; | |
private final Tracer tracer; | |
private final Logger errorLogger = LoggerFactory.getLogger("nl.trifork.sqs.listener"); | |
TracingWrappers(D delegate, Tracing tracing) { | |
this.delegate = delegate; | |
this.extractor = tracing.propagation().extractor(GETTER); | |
this.tracer = tracing.tracer(); | |
} | |
CompletableFuture<Void> doInSpan(Function<Message, CompletableFuture<Void>> caller, Message<?> message) { | |
TraceContextOrSamplingFlags extracted = extractor.extract(message.getHeaders()); | |
Span span = tracer.nextSpan(extracted) | |
.kind(CONSUMER) | |
.name("on-message") | |
.remoteServiceName("sqs") | |
.start(); | |
try (Tracer.SpanInScope ws = tracer.withSpanInScope(span)) { | |
return caller.apply(message); | |
} catch (Throwable t) { | |
span.error(t); | |
logError(message.getHeaders(), t); | |
throw t; | |
} finally { | |
span.finish(); | |
} | |
} | |
private void logError(MessageHeaders headers, Throwable t) { | |
Integer dlqDequeues = (Integer) headers.get("DlqDequeues"); | |
errorLogger.warn("Error processing messageId={} with receiveCount={} and dlqDequeues={} of type {}", | |
headers.getId(), | |
headers.get(SqsHeaders.MessageSystemAttributes.SQS_APPROXIMATE_RECEIVE_COUNT), | |
dlqDequeues != null ? dlqDequeues : 0, | |
headers.get(SqsHeaders.SQS_DEFAULT_TYPE_HEADER), | |
t); | |
} | |
static class AsyncMessageListenerWrapper<T> extends TracingWrappers<AsyncMessageListener<T>> implements AsyncMessageListener<T> { | |
public AsyncMessageListenerWrapper(AsyncMessageListener<T> delegate, Tracing tracing) { | |
super(delegate, tracing); | |
} | |
@Override | |
public CompletableFuture<Void> onMessage(Message<T> message) { | |
return doInSpan(delegate::onMessage, message); | |
} | |
@Override | |
public CompletableFuture<Void> onMessage(Collection<Message<T>> messages) { | |
return delegate.onMessage(messages); | |
} | |
} | |
static class MessageListenerWrapper<T> extends TracingWrappers<MessageListener<T>> implements MessageListener<T> { | |
MessageListenerWrapper(MessageListener<T> delegate, Tracing tracing) { | |
super(delegate, tracing); | |
} | |
@Override | |
public void onMessage(Message<T> message) { | |
doInSpan(msg -> { | |
delegate.onMessage(msg); | |
return null; | |
}, message); | |
} | |
@Override | |
public void onMessage(Collection<Message<T>> messages) { | |
delegate.onMessage(messages); | |
} | |
} | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.