/*
 * Decompiled with CFR 0.152.
 */
package com.mongodb.internal.connection.tlschannel;

import com.mongodb.diagnostics.logging.Logger;
import com.mongodb.diagnostics.logging.Loggers;
import com.mongodb.internal.connection.tlschannel.BufferAllocator;
import com.mongodb.internal.connection.tlschannel.SniSslContextFactory;
import com.mongodb.internal.connection.tlschannel.TlsChannel;
import com.mongodb.internal.connection.tlschannel.TlsChannelBuilder;
import com.mongodb.internal.connection.tlschannel.TlsChannelCallbackException;
import com.mongodb.internal.connection.tlschannel.TrackingAllocator;
import com.mongodb.internal.connection.tlschannel.impl.BufferHolder;
import com.mongodb.internal.connection.tlschannel.impl.ByteBufferSet;
import com.mongodb.internal.connection.tlschannel.impl.TlsChannelImpl;
import com.mongodb.internal.connection.tlschannel.impl.TlsExplorer;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ClosedChannelException;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.StandardConstants;

public class ServerTlsChannel
implements TlsChannel {
    private static final Logger LOGGER = Loggers.getLogger("connection.tls");
    private final ByteChannel underlying;
    private final SslContextStrategy sslContextStrategy;
    private final Function<SSLContext, SSLEngine> engineFactory;
    private final Consumer<SSLSession> sessionInitCallback;
    private final boolean runTasks;
    private final TrackingAllocator plainBufAllocator;
    private final TrackingAllocator encryptedBufAllocator;
    private final boolean releaseBuffers;
    private final boolean waitForCloseConfirmation;
    private final Lock initLock = new ReentrantLock();
    private BufferHolder inEncrypted;
    private volatile boolean sniRead = false;
    private SSLContext sslContext = null;
    private TlsChannelImpl impl = null;

    private static SSLEngine defaultSSLEngineFactory(SSLContext sslContext) {
        SSLEngine engine = sslContext.createSSLEngine();
        engine.setUseClientMode(false);
        return engine;
    }

    public static Builder newBuilder(ByteChannel underlying, SSLContext sslContext) {
        return new Builder(underlying, sslContext);
    }

    public static Builder newBuilder(ByteChannel underlying, SniSslContextFactory sslContextFactory) {
        return new Builder(underlying, sslContextFactory);
    }

    private ServerTlsChannel(ByteChannel underlying, SslContextStrategy internalSslContextFactory, Function<SSLContext, SSLEngine> engineFactory, Consumer<SSLSession> sessionInitCallback, boolean runTasks, BufferAllocator plainBufAllocator, BufferAllocator encryptedBufAllocator, boolean releaseBuffers, boolean waitForCloseConfirmation) {
        this.underlying = underlying;
        this.sslContextStrategy = internalSslContextFactory;
        this.engineFactory = engineFactory;
        this.sessionInitCallback = sessionInitCallback;
        this.runTasks = runTasks;
        this.plainBufAllocator = new TrackingAllocator(plainBufAllocator);
        this.encryptedBufAllocator = new TrackingAllocator(encryptedBufAllocator);
        this.releaseBuffers = releaseBuffers;
        this.waitForCloseConfirmation = waitForCloseConfirmation;
        this.inEncrypted = new BufferHolder("inEncrypted", Optional.empty(), encryptedBufAllocator, 4096, 17408, false, releaseBuffers);
    }

    @Override
    public ByteChannel getUnderlying() {
        return this.underlying;
    }

    public SSLContext getSslContext() {
        return this.sslContext;
    }

    @Override
    public SSLEngine getSslEngine() {
        return this.impl == null ? null : this.impl.engine();
    }

    @Override
    public Consumer<SSLSession> getSessionInitCallback() {
        return this.sessionInitCallback;
    }

    @Override
    public boolean getRunTasks() {
        return this.impl.getRunTasks();
    }

    @Override
    public TrackingAllocator getPlainBufferAllocator() {
        return this.plainBufAllocator;
    }

    @Override
    public TrackingAllocator getEncryptedBufferAllocator() {
        return this.encryptedBufAllocator;
    }

    @Override
    public long read(ByteBuffer[] dstBuffers, int offset, int length) throws IOException {
        ByteBufferSet dest = new ByteBufferSet(dstBuffers, offset, length);
        TlsChannelImpl.checkReadBuffer(dest);
        if (!this.sniRead) {
            try {
                this.initEngine();
            }
            catch (TlsChannelImpl.EofException e) {
                return -1L;
            }
        }
        return this.impl.read(dest);
    }

    @Override
    public long read(ByteBuffer[] dstBuffers) throws IOException {
        return this.read(dstBuffers, 0, dstBuffers.length);
    }

    @Override
    public int read(ByteBuffer dstBuffer) throws IOException {
        return (int)this.read(new ByteBuffer[]{dstBuffer});
    }

    @Override
    public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
        ByteBufferSet source = new ByteBufferSet(srcs, offset, length);
        if (!this.sniRead) {
            try {
                this.initEngine();
            }
            catch (TlsChannelImpl.EofException e) {
                throw new ClosedChannelException();
            }
        }
        return this.impl.write(source);
    }

    @Override
    public long write(ByteBuffer[] srcs) throws IOException {
        return this.write(srcs, 0, srcs.length);
    }

    @Override
    public int write(ByteBuffer srcBuffer) throws IOException {
        return (int)this.write(new ByteBuffer[]{srcBuffer});
    }

    @Override
    public void renegotiate() throws IOException {
        if (!this.sniRead) {
            try {
                this.initEngine();
            }
            catch (TlsChannelImpl.EofException e) {
                throw new ClosedChannelException();
            }
        }
        this.impl.renegotiate();
    }

    @Override
    public void handshake() throws IOException {
        if (!this.sniRead) {
            try {
                this.initEngine();
            }
            catch (TlsChannelImpl.EofException e) {
                throw new ClosedChannelException();
            }
        }
        this.impl.handshake();
    }

    @Override
    public void close() throws IOException {
        if (this.impl != null) {
            this.impl.close();
        }
        if (this.inEncrypted != null) {
            this.inEncrypted.dispose();
        }
        this.underlying.close();
    }

    @Override
    public boolean isOpen() {
        return this.underlying.isOpen();
    }

    private void initEngine() throws IOException, TlsChannelImpl.EofException {
        block5: {
            this.initLock.lock();
            try {
                SSLEngine engine;
                if (this.sniRead) break block5;
                this.sslContext = this.sslContextStrategy.getSslContext(this::getServerNameIndication);
                try {
                    engine = this.engineFactory.apply(this.sslContext);
                }
                catch (Exception e) {
                    LOGGER.trace("client threw exception in SSLEngine factory", e);
                    throw new TlsChannelCallbackException("SSLEngine creation callback failed", e);
                }
                this.impl = new TlsChannelImpl(this.underlying, this.underlying, engine, Optional.of(this.inEncrypted), this.sessionInitCallback, this.runTasks, this.plainBufAllocator, this.encryptedBufAllocator, this.releaseBuffers, this.waitForCloseConfirmation);
                this.inEncrypted = null;
                this.sniRead = true;
            }
            finally {
                this.initLock.unlock();
            }
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    private Optional<SNIServerName> getServerNameIndication() throws IOException, TlsChannelImpl.EofException {
        this.inEncrypted.prepare();
        try {
            int recordHeaderSize = this.readRecordHeaderSize();
            while (this.inEncrypted.buffer.position() < recordHeaderSize) {
                if (!this.inEncrypted.buffer.hasRemaining()) {
                    this.inEncrypted.enlarge();
                }
                TlsChannelImpl.readFromChannel(this.underlying, this.inEncrypted.buffer);
            }
            ((Buffer)this.inEncrypted.buffer).flip();
            Map<Integer, SNIServerName> serverNames = TlsExplorer.explore(this.inEncrypted.buffer);
            this.inEncrypted.buffer.compact();
            SNIServerName hostName = serverNames.get(0);
            if (hostName != null && hostName instanceof SNIHostName) {
                SNIHostName sniHostName = (SNIHostName)hostName;
                Optional<SNIServerName> optional = Optional.of(sniHostName);
                return optional;
            }
            Optional<SNIServerName> optional = Optional.empty();
            return optional;
        }
        finally {
            this.inEncrypted.release();
        }
    }

    private int readRecordHeaderSize() throws IOException, TlsChannelImpl.EofException {
        while (this.inEncrypted.buffer.position() < 5) {
            if (!this.inEncrypted.buffer.hasRemaining()) {
                throw new IllegalStateException("inEncrypted too small");
            }
            TlsChannelImpl.readFromChannel(this.underlying, this.inEncrypted.buffer);
        }
        ((Buffer)this.inEncrypted.buffer).flip();
        int recordHeaderSize = TlsExplorer.getRequiredSize(this.inEncrypted.buffer);
        this.inEncrypted.buffer.compact();
        return recordHeaderSize;
    }

    @Override
    public boolean shutdown() throws IOException {
        return this.impl != null && this.impl.shutdown();
    }

    @Override
    public boolean shutdownReceived() {
        return this.impl != null && this.impl.shutdownReceived();
    }

    @Override
    public boolean shutdownSent() {
        return this.impl != null && this.impl.shutdownSent();
    }

    private static interface SslContextStrategy {
        public SSLContext getSslContext(SniReader var1) throws IOException, TlsChannelImpl.EofException;

        @FunctionalInterface
        public static interface SniReader {
            public Optional<SNIServerName> readSni() throws IOException, TlsChannelImpl.EofException;
        }
    }

    public static class Builder
    extends TlsChannelBuilder<Builder> {
        private final SslContextStrategy internalSslContextFactory;
        private Function<SSLContext, SSLEngine> sslEngineFactory = x$0 -> ServerTlsChannel.access$200(x$0);

        private Builder(ByteChannel underlying, SSLContext sslContext) {
            super(underlying);
            this.internalSslContextFactory = new FixedSslContextStrategy(sslContext);
        }

        private Builder(ByteChannel wrapped, SniSslContextFactory sslContextFactory) {
            super(wrapped);
            this.internalSslContextFactory = new SniSslContextStrategy(sslContextFactory);
        }

        @Override
        Builder getThis() {
            return this;
        }

        public Builder withEngineFactory(Function<SSLContext, SSLEngine> sslEngineFactory) {
            this.sslEngineFactory = sslEngineFactory;
            return this;
        }

        public ServerTlsChannel build() {
            return new ServerTlsChannel(this.underlying, this.internalSslContextFactory, this.sslEngineFactory, this.sessionInitCallback, this.runTasks, this.plainBufferAllocator, this.encryptedBufferAllocator, this.releaseBuffers, this.waitForCloseConfirmation);
        }
    }

    private static class FixedSslContextStrategy
    implements SslContextStrategy {
        private final SSLContext sslContext;

        public FixedSslContextStrategy(SSLContext sslContext) {
            this.sslContext = sslContext;
        }

        @Override
        public SSLContext getSslContext(SslContextStrategy.SniReader sniReader) {
            return this.sslContext;
        }
    }

    private static class SniSslContextStrategy
    implements SslContextStrategy {
        private SniSslContextFactory sniSslContextFactory;

        public SniSslContextStrategy(SniSslContextFactory sniSslContextFactory) {
            this.sniSslContextFactory = sniSslContextFactory;
        }

        @Override
        public SSLContext getSslContext(SslContextStrategy.SniReader sniReader) throws IOException, TlsChannelImpl.EofException {
            Optional<SSLContext> chosenContext;
            Optional<SNIServerName> nameOpt = sniReader.readSni();
            try {
                chosenContext = this.sniSslContextFactory.getSslContext(nameOpt);
            }
            catch (Exception e) {
                LOGGER.trace("client code threw exception during evaluation of server name indication", e);
                throw new TlsChannelCallbackException("SNI callback failed", e);
            }
            return chosenContext.orElseThrow(() -> new SSLHandshakeException("No ssl context available for received SNI: " + nameOpt));
        }
    }
}

