diff --git a/zuul-core/src/main/java/com/netflix/zuul/netty/connectionpool/ClientTimeoutHandler.java b/zuul-core/src/main/java/com/netflix/zuul/netty/connectionpool/ClientTimeoutHandler.java index 725587efc6..157602b7f6 100644 --- a/zuul-core/src/main/java/com/netflix/zuul/netty/connectionpool/ClientTimeoutHandler.java +++ b/zuul-core/src/main/java/com/netflix/zuul/netty/connectionpool/ClientTimeoutHandler.java @@ -22,10 +22,11 @@ import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.LastHttpContent; import io.netty.util.AttributeKey; -import java.time.Duration; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.time.Duration; + /** * Client Timeout Handler * @@ -57,15 +58,21 @@ public static final class OutboundHandler extends ChannelOutboundHandlerAdapter @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { try { + if (!(msg instanceof LastHttpContent)) { + return; + } + final Duration timeout = ctx.channel().attr(ORIGIN_RESPONSE_READ_TIMEOUT).get(); - if (timeout != null && msg instanceof LastHttpContent) { + if (timeout != null) { promise.addListener(e -> { - LOG.debug( - "[{}] Adding read timeout handler: {}", - ctx.channel().id(), - timeout.toMillis()); - PooledConnection.getFromChannel(ctx.channel()).startReadTimeoutHandler(timeout); + if (e.isSuccess()) { + LOG.debug( + "[{}] Adding read timeout handler: {}", + ctx.channel().id(), + timeout.toMillis()); + PooledConnection.getFromChannel(ctx.channel()).startReadTimeoutHandler(timeout); + } }); } } finally { diff --git a/zuul-core/src/test/java/com/netflix/zuul/netty/connectionpool/ClientTimeoutHandlerTest.java b/zuul-core/src/test/java/com/netflix/zuul/netty/connectionpool/ClientTimeoutHandlerTest.java new file mode 100644 index 0000000000..bff5a94719 --- /dev/null +++ b/zuul-core/src/test/java/com/netflix/zuul/netty/connectionpool/ClientTimeoutHandlerTest.java @@ -0,0 +1,131 @@ +/* + * Copyright 2024 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.netflix.zuul.netty.connectionpool; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultLastHttpContent; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.time.Duration; +import java.time.temporal.ChronoUnit; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +/** + * @author Justin Guerra + * @since 7/30/24 + */ +@ExtendWith(MockitoExtension.class) +class ClientTimeoutHandlerTest { + + @Mock + private PooledConnection pooledConnection; + + private EmbeddedChannel channel; + private WriteVerifyingHandler verifier; + + @BeforeEach + public void setup() { + channel = new EmbeddedChannel(); + channel.attr(PooledConnection.CHANNEL_ATTR).set(pooledConnection); + verifier = new WriteVerifyingHandler(); + channel.pipeline().addLast(verifier); + channel.pipeline().addLast(new ClientTimeoutHandler.OutboundHandler()); + } + + @AfterEach + public void cleanup() { + channel.finishAndReleaseAll(); + } + + @Test + public void dontStartReadTimeoutHandlerIfNotLastContent() { + addTimeoutToChannel(); + channel.writeOutbound(new DefaultHttpContent(Unpooled.wrappedBuffer("yo".getBytes()))); + verify(pooledConnection, never()).startReadTimeoutHandler(any()); + verifyWrite(); + } + + @Test + public void dontStartReadTimeoutHandlerIfNoTimeout() { + channel.writeOutbound(new DefaultLastHttpContent()); + verify(pooledConnection, never()).startReadTimeoutHandler(any()); + verifyWrite(); + } + + @Test + public void dontStartReadTimeoutHandlerOnFailedPromise() { + addTimeoutToChannel(); + + channel.pipeline().addFirst(new ChannelDuplexHandler() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + ReferenceCountUtil.safeRelease(msg); + promise.setFailure(new RuntimeException()); + } + }); + try { + channel.writeOutbound(new DefaultLastHttpContent()); + } catch (RuntimeException e) { + // expected + } + verify(pooledConnection, never()).startReadTimeoutHandler(any()); + verifyWrite(); + } + + @Test + public void startReadTimeoutHandlerOnSuccessfulPromise() { + Duration timeout = addTimeoutToChannel(); + channel.writeOutbound(new DefaultLastHttpContent()); + verify(pooledConnection).startReadTimeoutHandler(timeout); + verifyWrite(); + } + + private Duration addTimeoutToChannel() { + Duration timeout = Duration.of(5, ChronoUnit.SECONDS); + channel.attr(ClientTimeoutHandler.ORIGIN_RESPONSE_READ_TIMEOUT).set(timeout); + return timeout; + } + + private void verifyWrite() { + Assertions.assertTrue(verifier.seenWrite); + } + + private static class WriteVerifyingHandler extends ChannelDuplexHandler { + boolean seenWrite; + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + seenWrite = true; + super.write(ctx, msg, promise); + } + } +}