/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.transport;

import java.io.EOFException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.BytesRef;
import org.opensearch.Version;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.common.io.stream.ByteBufferStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.telemetry.tracing.Span;
import org.opensearch.telemetry.tracing.SpanBuilder;
import org.opensearch.telemetry.tracing.SpanScope;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.Header;
import org.opensearch.transport.InboundMessage;
import org.opensearch.transport.OutboundHandler;
import org.opensearch.transport.ProtocolInboundMessage;
import org.opensearch.transport.ProtocolMessageHandler;
import org.opensearch.transport.ProtocolOutboundHandler;
import org.opensearch.transport.RemoteTransportException;
import org.opensearch.transport.RequestHandlerRegistry;
import org.opensearch.transport.ResponseHandlerFailureTransportException;
import org.opensearch.transport.StatsTracker;
import org.opensearch.transport.TcpChannel;
import org.opensearch.transport.TcpTransportChannel;
import org.opensearch.transport.Transport;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportHandshaker;
import org.opensearch.transport.TransportKeepAlive;
import org.opensearch.transport.TransportLogger;
import org.opensearch.transport.TransportMessageListener;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportSerializationException;
import org.opensearch.transport.nativeprotocol.NativeOutboundHandler;

/**
 * Handles inbound messages that use the native transport protocol.
 * <p>
 * Keep-alive pings are handed to the {@link TransportKeepAlive}. Requests are dispatched to the
 * {@link RequestHandlerRegistry} registered for their action. Responses (including error responses)
 * are delivered to the {@link TransportResponseHandler} registered for their request id, or to the
 * handshaker's handler for handshake responses.
 */
@ExperimentalApi
public class NativeMessageHandler implements ProtocolMessageHandler {

    private static final Logger logger = LogManager.getLogger(NativeMessageHandler.class);

    /** Shared empty input used for responses that carry no body (see {@code handleResponseMessage}). */
    private static final StreamInput EMPTY_STREAM_INPUT = new ByteBufferStreamInput(ByteBuffer.wrap(BytesRef.EMPTY_BYTES));

    private final ThreadPool threadPool;
    private final ProtocolOutboundHandler outboundHandler;
    private final NamedWriteableRegistry namedWriteableRegistry;
    private final TransportHandshaker handshaker;
    private final TransportKeepAlive keepAlive;
    private final Transport.ResponseHandlers responseHandlers;
    private final Transport.RequestHandlers requestHandlers;
    private final Tracer tracer;

    /**
     * Creates a handler for the native protocol.
     * <p>
     * {@code nodeName}, {@code version}, {@code features}, {@code statsTracker}, {@code threadPool},
     * {@code bigArrays} and {@code outboundHandler} are forwarded to
     * {@link #createNativeOutboundHandler}. That method is invoked from this constructor, so a
     * subclass override runs before the subclass's own fields are initialised.
     *
     * @param threadPool             supplies the thread context, clock and executors used for dispatch
     * @param namedWriteableRegistry registry used to decode named writeables in message bodies
     * @param handshaker             handles handshake requests and owns pending handshake response handlers
     * @param requestHandlers        looks up the request handler registered for an action
     * @param responseHandlers       looks up the response handler registered for a request id
     * @param tracer                 creates the span that wraps each inbound request
     * @param keepAlive              receives inbound keep-alive pings
     */
    public NativeMessageHandler(
        String nodeName,
        Version version,
        String[] features,
        StatsTracker statsTracker,
        ThreadPool threadPool,
        BigArrays bigArrays,
        OutboundHandler outboundHandler,
        NamedWriteableRegistry namedWriteableRegistry,
        TransportHandshaker handshaker,
        Transport.RequestHandlers requestHandlers,
        Transport.ResponseHandlers responseHandlers,
        Tracer tracer,
        TransportKeepAlive keepAlive
    ) {
        this.threadPool = threadPool;
        this.outboundHandler = createNativeOutboundHandler(nodeName, version, features, statsTracker, threadPool, bigArrays, outboundHandler);
        this.namedWriteableRegistry = namedWriteableRegistry;
        this.handshaker = handshaker;
        this.requestHandlers = requestHandlers;
        this.responseHandlers = responseHandlers;
        this.tracer = tracer;
        this.keepAlive = keepAlive;
    }

    /**
     * Creates the outbound handler used to reply to inbound requests. Subclasses may override it to
     * substitute their own implementation; note that it is called from the constructor.
     */
    protected ProtocolOutboundHandler createNativeOutboundHandler(
        String nodeName,
        Version version,
        String[] features,
        StatsTracker statsTracker,
        ThreadPool threadPool,
        BigArrays bigArrays,
        OutboundHandler outboundHandler
    ) {
        return new NativeOutboundHandler(nodeName, version, features, statsTracker, threadPool, bigArrays, outboundHandler);
    }

    /**
     * Entry point for a decoded inbound message. Pings go to the keep-alive tracker; everything else
     * is handled as a request or response.
     *
     * @param message expected to be an {@link InboundMessage}; it is cast unconditionally
     */
    @Override
    public void messageReceived(
        TcpChannel channel,
        ProtocolInboundMessage message,
        long startTime,
        long slowLogThresholdMs,
        TransportMessageListener messageListener
    ) throws IOException {
        final InboundMessage inboundMessage = (InboundMessage) message;
        TransportLogger.logInboundMessage(channel, inboundMessage);
        if (inboundMessage.isPing()) {
            keepAlive.receiveKeepAlive(channel);
        } else {
            handleMessage(channel, inboundMessage, startTime, slowLogThresholdMs, messageListener);
        }
    }

    /**
     * Handles a non-ping message inside a fresh thread context populated from the message headers.
     * A warning is logged when handling took longer than {@code slowLogThresholdMs}, measured from
     * {@code startTime}; a non-positive threshold disables the warning.
     */
    private void handleMessage(
        TcpChannel channel,
        InboundMessage message,
        long startTime,
        long slowLogThresholdMs,
        TransportMessageListener messageListener
    ) throws IOException {
        final InetSocketAddress remoteAddress = channel.getRemoteAddress();
        final Header header = message.getHeader();
        assert header.needsToReadVariableHeader() == false;

        final ThreadContext threadContext = threadPool.getThreadContext();
        try (ThreadContext.StoredContext existing = threadContext.stashContext()) {
            // Handlers run with the headers carried by the message and the sender's address.
            threadContext.setHeaders(header.getHeaders());
            threadContext.putTransient("_remote_address", remoteAddress);
            if (header.isRequest()) {
                handleRequest(channel, header, message, messageListener);
            } else {
                handleResponseMessage(remoteAddress, header, message, messageListener);
            }
        } finally {
            final long took = threadPool.relativeTimeInMillis() - startTime;
            if (slowLogThresholdMs > 0L && took > slowLogThresholdMs) {
                logger.warn(
                    "handling inbound transport message [{}] took [{}ms] which is above the warn threshold of [{}ms]",
                    message,
                    took,
                    slowLogThresholdMs
                );
            }
        }
    }

    /** Routes a response (or error response) to the handler waiting for its request id, if any. */
    private void handleResponseMessage(
        InetSocketAddress remoteAddress,
        Header header,
        InboundMessage message,
        TransportMessageListener messageListener
    ) throws IOException {
        // Responses are never short-circuited.
        assert message.isShortCircuit() == false;
        final long requestId = header.getRequestId();
        final TransportResponseHandler<?> handler = resolveResponseHandler(header, requestId, messageListener);
        if (handler == null) {
            // No handler is waiting for this request id; the response is dropped here.
            return;
        }
        if (message.getContentLength() > 0 || header.getVersion().equals(Version.CURRENT) == false) {
            final StreamInput streamInput = namedWriteableStream(message.openOrGetStreamInput());
            assertRemoteVersion(streamInput, header.getVersion());
            if (header.isError()) {
                handlerResponseError(requestId, streamInput, handler);
            } else {
                handleResponse(requestId, remoteAddress, streamInput, handler);
            }
        } else {
            // Empty body from a peer on the current version: decode from the shared empty input.
            assert header.isError() == false;
            handleResponse(requestId, remoteAddress, EMPTY_STREAM_INPUT, handler);
        }
    }

    /**
     * Removes and returns the handler waiting for {@code requestId}, or {@code null} if none.
     * Handshake responses are looked up in the handshaker. An error response with no regular
     * handler is also looked up there (presumably an error reply to a handshake — TODO confirm).
     */
    private TransportResponseHandler<?> resolveResponseHandler(Header header, long requestId, TransportMessageListener messageListener) {
        if (header.isHandshake()) {
            return handshaker.removeHandlerForHandshake(requestId);
        }
        final TransportResponseHandler<? extends TransportResponse> handler = responseHandlers.onResponseReceived(requestId, messageListener);
        if (handler == null && header.isError()) {
            return handshaker.removeHandlerForHandshake(requestId);
        }
        return handler;
    }

    /** Wraps each header value in a singleton collection, the shape expected by the tracer. */
    private Map<String, Collection<String>> extractHeaders(Map<String, String> headers) {
        return headers.entrySet().stream().collect(Collectors.toMap(e -> e.getKey(), e -> Collections.singleton(e.getValue())));
    }

    /**
     * Handles an inbound request inside a tracing span. Failures that escape the per-request error
     * handling are recorded on the span, the span is ended, and the failure is rethrown.
     */
    private void handleRequest(TcpChannel channel, Header header, InboundMessage message, TransportMessageListener messageListener)
        throws IOException {
        final String action = header.getActionName();
        final Map<String, Collection<String>> headers = extractHeaders(header.getHeaders().v1());
        final Span span = tracer.startSpan(SpanBuilder.from(action, channel), headers);
        try (SpanScope spanScope = tracer.withSpanInScope(span)) {
            if (header.isHandshake()) {
                handleHandshakeRequest(channel, header, message, messageListener, span);
            } else {
                handleActionRequest(channel, header, message, messageListener, span);
            }
        } catch (Exception e) {
            span.setError(e);
            span.endSpan();
            throw e;
        }
    }

    /**
     * Passes a handshake request to the handshaker. If the handshaker fails, an error response is
     * sent when the sender's version is compatible; otherwise the channel is closed.
     */
    private void handleHandshakeRequest(
        TcpChannel channel,
        Header header,
        InboundMessage message,
        TransportMessageListener messageListener,
        Span span
    ) throws IOException {
        final String action = header.getActionName();
        final long requestId = header.getRequestId();
        messageListener.onRequestReceived(requestId, action);
        // Handshakes cannot be short-circuited.
        assert message.isShortCircuit() == false;
        final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput());
        assertRemoteVersion(stream, header.getVersion());
        final TransportChannel transportChannel = createTraceableChannel(channel, header, message, span);
        try {
            handshaker.handleHandshake(transportChannel, requestId, stream);
        } catch (Exception e) {
            if (Version.CURRENT.isCompatible(header.getVersion())) {
                sendErrorResponse(action, transportChannel, e);
            } else {
                logger.warn(
                    new ParameterizedMessage(
                        "could not send error response to handshake received on [{}] using wire format version [{}], closing channel",
                        channel,
                        header.getVersion()
                    ),
                    e
                );
                channel.close();
            }
        }
    }

    /**
     * Decodes a regular (non-handshake) request and runs its registered handler, either inline for
     * the {@code SAME} executor or on the handler's executor. Any failure, including a
     * short-circuited message, is answered with an error response.
     */
    private <T extends TransportRequest> void handleActionRequest(
        TcpChannel channel,
        Header header,
        InboundMessage message,
        TransportMessageListener messageListener,
        Span span
    ) {
        final String action = header.getActionName();
        final long requestId = header.getRequestId();
        final TransportChannel transportChannel = createTraceableChannel(channel, header, message, span);
        try {
            messageListener.onRequestReceived(requestId, action);
            if (message.isShortCircuit()) {
                sendErrorResponse(action, transportChannel, message.getException());
                return;
            }
            final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput());
            assertRemoteVersion(stream, header.getVersion());
            final RequestHandlerRegistry<T> reg = requestHandlers.getHandler(action);
            assert reg != null;
            final T request = newRequest(requestId, action, stream, reg);
            request.remoteAddress(new TransportAddress(channel.getRemoteAddress()));
            checkStreamIsFullyConsumed(requestId, action, stream);

            final String executor = reg.getExecutor();
            if (ThreadPool.Names.SAME.equals(executor)) {
                try {
                    reg.processMessageReceived(request, transportChannel);
                } catch (Exception e) {
                    sendErrorResponse(reg.getAction(), transportChannel, e);
                }
            } else {
                threadPool.executor(executor).execute(new RequestHandler<>(reg, request, transportChannel));
            }
        } catch (Exception e) {
            sendErrorResponse(action, transportChannel, e);
        }
    }

    /**
     * Creates the reply channel for a request and wraps it so that it is bound to {@code span}.
     * Takes ownership of the message's breaker release control.
     */
    private TransportChannel createTraceableChannel(TcpChannel channel, Header header, InboundMessage message, Span span) {
        final TcpTransportChannel transportChannel = createTcpTransportChannel(
            outboundHandler,
            channel,
            header.getActionName(),
            header.getRequestId(),
            header.getVersion(),
            header,
            message.takeBreakerReleaseControl()
        );
        return TraceableTcpTransportChannel.create(transportChannel, span, tracer);
    }

    /** Creates the channel used to reply to an inbound request; protected so subclasses can customise it. */
    protected TcpTransportChannel createTcpTransportChannel(
        ProtocolOutboundHandler outboundHandler,
        TcpChannel channel,
        String action,
        long requestId,
        Version version,
        Header header,
        Releasable breakerRelease
    ) {
        return new TcpTransportChannel(
            outboundHandler,
            channel,
            action,
            requestId,
            version,
            header.getFeatures(),
            header.isCompressed(),
            header.isHandshake(),
            breakerRelease
        );
    }

    /**
     * Reads a request with the registry's reader.
     *
     * @throws IllegalStateException if the body ends before the request is fully read
     */
    private <T extends TransportRequest> T newRequest(long requestId, String action, StreamInput stream, RequestHandlerRegistry<T> reg)
        throws IOException {
        try {
            return reg.newRequest(stream);
        } catch (EOFException e) {
            throw new IllegalStateException(
                "Message fully read (request) but more data is expected for requestId [" + requestId + "], action [" + action + "]; resetting",
                e
            );
        }
    }

    /** @throws IllegalStateException if bytes remain in the request body after decoding */
    private void checkStreamIsFullyConsumed(long requestId, String action, StreamInput stream) throws IOException {
        if (stream.read() != -1) {
            throw new IllegalStateException(
                "Message not fully read (request) for requestId ["
                    + requestId
                    + "], action ["
                    + action
                    + "], available ["
                    + stream.available()
                    + "]; resetting"
            );
        }
    }

    /**
     * Checks that a response body was fully consumed. The shared empty input is never read.
     *
     * @throws IllegalStateException if bytes remain in the response body after decoding
     */
    private void checkStreamIsFullyConsumed(long requestId, TransportResponseHandler<?> handler, StreamInput stream, boolean error)
        throws IOException {
        if (stream != EMPTY_STREAM_INPUT && stream.read() != -1) {
            throw new IllegalStateException(
                "Message not fully read (response) for requestId [" + requestId + "], handler [" + handler + "], error [" + error + "]; resetting"
            );
        }
    }

    /** Sends {@code e} back over the channel; a failure to do so is logged with {@code e} suppressed. */
    private static void sendErrorResponse(String actionName, TransportChannel transportChannel, Exception e) {
        try {
            transportChannel.sendResponse(e);
        } catch (Exception inner) {
            inner.addSuppressed(e);
            logger.warn(() -> new ParameterizedMessage("Failed to send error message back to client for action [{}]", actionName), inner);
        }
    }

    /**
     * Decodes a response and hands it to {@code handler} on the handler's executor (inline for
     * {@code SAME}). A decoding failure is logged and reported to the handler as a
     * {@link TransportSerializationException}.
     */
    private <T extends TransportResponse> void handleResponse(
        long requestId,
        InetSocketAddress remoteAddress,
        StreamInput stream,
        TransportResponseHandler<T> handler
    ) {
        final T response;
        try {
            response = handler.read(stream);
            response.remoteAddress(new TransportAddress(remoteAddress));
            checkStreamIsFullyConsumed(requestId, handler, stream, false);
        } catch (Exception e) {
            final Exception serializationException = new TransportSerializationException(
                "Failed to deserialize response from handler [" + handler + "]",
                e
            );
            logger.warn(new ParameterizedMessage("Failed to deserialize response from [{}]", remoteAddress), serializationException);
            handleException(handler, serializationException);
            return;
        }
        final String executor = handler.executor();
        if (ThreadPool.Names.SAME.equals(executor)) {
            doHandleResponse(handler, response);
        } else {
            threadPool.executor(executor).execute(() -> doHandleResponse(handler, response));
        }
    }

    /** Invokes the handler; an exception it throws is reported back to it as a handler failure. */
    private <T extends TransportResponse> void doHandleResponse(TransportResponseHandler<T> handler, T response) {
        try {
            handler.handleResponse(response);
        } catch (Exception e) {
            handleException(handler, new ResponseHandlerFailureTransportException(e));
        }
    }

    /**
     * Decodes a remote exception and passes it to {@code handler}. A decoding failure is passed on
     * as a {@link TransportSerializationException} instead.
     */
    private void handlerResponseError(long requestId, StreamInput stream, TransportResponseHandler<?> handler) {
        Exception error;
        try {
            error = stream.readException();
            checkStreamIsFullyConsumed(requestId, handler, stream, true);
        } catch (Exception e) {
            error = new TransportSerializationException(
                "Failed to deserialize exception response from stream for handler [" + handler + "]",
                e
            );
        }
        handleException(handler, error);
    }

    /**
     * Delivers {@code error} to the handler on the handler's executor, wrapping it in a
     * {@link RemoteTransportException} unless it already is one. Exceptions thrown by the handler
     * are logged.
     */
    private void handleException(TransportResponseHandler<?> handler, Throwable error) {
        final RemoteTransportException rtx = error instanceof RemoteTransportException
            ? (RemoteTransportException) error
            : new RemoteTransportException(error.getMessage(), error);
        threadPool.executor(handler.executor()).execute(() -> {
            try {
                handler.handleException(rtx);
            } catch (Exception e) {
                logger.error(() -> new ParameterizedMessage("failed to handle exception response [{}]", handler), e);
            }
        });
    }

    /** Wraps {@code delegate} so that named writeables are resolved through the registry. */
    private StreamInput namedWriteableStream(StreamInput delegate) {
        return new NamedWriteableAwareStreamInput(delegate, namedWriteableRegistry);
    }

    /** Asserts (when assertions are enabled) that the stream's version matches the header version. */
    static void assertRemoteVersion(StreamInput in, Version version) {
        assert version.equals(in.getVersion()) : "Stream version [" + in.getVersion() + "] does not match version [" + version + "]";
    }

    /** Forwards the listener to the outbound handler. */
    @Override
    public void setMessageListener(TransportMessageListener listener) {
        outboundHandler.setMessageListener(listener);
    }

    /** Runs a decoded request on a non-{@code SAME} executor; failures are answered with an error response. */
    private static final class RequestHandler<T extends TransportRequest> extends AbstractRunnable {
        private final RequestHandlerRegistry<T> reg;
        private final T request;
        private final TransportChannel transportChannel;

        RequestHandler(RequestHandlerRegistry<T> reg, T request, TransportChannel transportChannel) {
            this.reg = reg;
            this.request = request;
            this.transportChannel = transportChannel;
        }

        @Override
        protected void doRun() throws Exception {
            reg.processMessageReceived(request, transportChannel);
        }

        @Override
        public boolean isForceExecution() {
            return reg.isForceExecution();
        }

        @Override
        public void onFailure(Exception e) {
            sendErrorResponse(reg.getAction(), transportChannel, e);
        }
    }
}

