/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.cloud.gateway.filter;

import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.cloud.gateway.filter.headers.HttpHeadersFilter;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.core.Ordered;
import org.springframework.http.HttpHeaders;
import org.springframework.util.StringUtils;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import org.springframework.web.reactive.socket.client.WebSocketClient;
import org.springframework.web.reactive.socket.server.WebSocketService;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.UriComponentsBuilder;
import reactor.core.publisher.Mono;

/**
 * Global filter that proxies WebSocket requests to a {@code ws://} or {@code wss://}
 * request URL.
 * <p>
 * Steps:
 * <ol>
 * <li>If the request carries an {@code Upgrade: websocket} header and the routed URL
 * uses {@code http}/{@code https}, rewrite the scheme to {@code ws}/{@code wss}.</li>
 * <li>Hand the request to the {@link WebSocketService}.</li>
 * <li>Once the client session is established, open a second session to the backend with
 * the {@link WebSocketClient}.</li>
 * <li>Relay messages and close statuses in both directions.</li>
 * </ol>
 */
public class WebsocketRoutingFilter
implements GlobalFilter,
Ordered {
    /** Header carrying the sub-protocols requested by the client. */
    public static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol";
    private static final Log log = LogFactory.getLog(WebsocketRoutingFilter.class);
    private final WebSocketClient webSocketClient;
    private final WebSocketService webSocketService;
    private final ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider;
    // Lazily built by getHeadersFilters(); only ever assigned a fully constructed list.
    private volatile List<HttpHeadersFilter> headersFilters;

    /**
     * @param webSocketClient client used to open the session to the backend
     * @param webSocketService service that performs the handshake with the downstream client
     * @param headersFiltersProvider source of the user-configured request header filters
     */
    public WebsocketRoutingFilter(WebSocketClient webSocketClient, WebSocketService webSocketService, ObjectProvider<List<HttpHeadersFilter>> headersFiltersProvider) {
        this.webSocketClient = webSocketClient;
        this.webSocketService = webSocketService;
        this.headersFiltersProvider = headersFiltersProvider;
    }

    /**
     * Maps an HTTP scheme to its WebSocket counterpart.
     *
     * @param scheme a URI scheme, matched case-insensitively
     * @return {@code "ws"} for http, {@code "wss"} for https, otherwise the scheme
     *         lower-cased; {@code null} if {@code scheme} is {@code null}
     */
    static String convertHttpToWs(String scheme) {
        if (scheme == null) {
            return null;
        }
        // Locale.ROOT: the default locale (e.g. Turkish) can change how ASCII letters case-fold.
        String lower = scheme.toLowerCase(Locale.ROOT);
        if ("http".equals(lower)) {
            return "ws";
        }
        if ("https".equals(lower)) {
            return "wss";
        }
        return lower;
    }

    /** Runs just before the lowest precedence. */
    @Override
    public int getOrder() {
        return Ordered.LOWEST_PRECEDENCE - 1;
    }

    /**
     * Proxies the exchange over WebSocket when the (possibly rewritten) request URL uses
     * {@code ws}/{@code wss} and no other filter has routed it yet. Otherwise the exchange
     * is passed down the chain unchanged.
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        changeSchemeIfIsWebSocketUpgrade(exchange);
        URI requestUrl = exchange.getRequiredAttribute(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR);
        String scheme = requestUrl.getScheme();
        if (ServerWebExchangeUtils.isAlreadyRouted(exchange) || (!"ws".equals(scheme) && !"wss".equals(scheme))) {
            return chain.filter(exchange);
        }
        ServerWebExchangeUtils.setAlreadyRouted(exchange);
        HttpHeaders headers = exchange.getRequest().getHeaders();
        HttpHeaders filtered = HttpHeadersFilter.filterRequest(this.getHeadersFilters(), exchange);
        List<String> protocols = this.getProtocols(headers);
        return this.webSocketService.handleRequest(exchange,
                new ProxyWebSocketHandler(requestUrl, this.webSocketClient, filtered, protocols));
    }

    /**
     * Collects the sub-protocols from every {@code Sec-WebSocket-Protocol} header value.
     * Comma-separated entries are split into separate items.
     *
     * @return the tokenized protocols, or {@code null} when the header is absent
     */
    List<String> getProtocols(HttpHeaders headers) {
        List<String> rawValues = headers.get(SEC_WEBSOCKET_PROTOCOL);
        if (rawValues == null) {
            return null;
        }
        List<String> protocols = new ArrayList<>();
        for (String value : rawValues) {
            protocols.addAll(Arrays.asList(StringUtils.tokenizeToStringArray(value, ",")));
        }
        return protocols;
    }

    /**
     * Returns the configured header filters, followed by the two WebSocket-specific
     * filters: Host handling and removal of {@code Sec-WebSocket-*} headers.
     * <p>
     * The provider's list is copied rather than mutated, so a shared bean list never gains
     * extra entries. The copy is published through the volatile field only after it is
     * complete, so concurrent callers never see a half-built list. If two threads race,
     * each builds an equivalent list and one of them wins.
     */
    List<HttpHeadersFilter> getHeadersFilters() {
        List<HttpHeadersFilter> filters = this.headersFilters;
        if (filters == null) {
            filters = new ArrayList<>(this.headersFiltersProvider.getIfAvailable(ArrayList::new));
            filters.add(WebsocketRoutingFilter::filterHostHeader);
            filters.add(WebsocketRoutingFilter::removeSecWebSocketHeaders);
            this.headersFilters = filters;
        }
        return filters;
    }

    /**
     * Drops the Host header, unless the exchange's {@code PRESERVE_HOST_HEADER_ATTRIBUTE}
     * is true. In that case the incoming request's Host value is kept, when one exists.
     */
    private static HttpHeaders filterHostHeader(HttpHeaders headers, ServerWebExchange exchange) {
        HttpHeaders filtered = new HttpHeaders();
        filtered.addAll(headers);
        filtered.remove("Host");
        boolean preserveHost = exchange.getAttributeOrDefault(ServerWebExchangeUtils.PRESERVE_HOST_HEADER_ATTRIBUTE, false);
        if (preserveHost) {
            String host = exchange.getRequest().getHeaders().getFirst("Host");
            // Never forward a null Host value.
            if (host != null) {
                filtered.add("Host", host);
            }
        }
        return filtered;
    }

    /**
     * Removes every header whose name starts with {@code sec-websocket} (case-insensitive).
     * The backend handshake is negotiated by the WebSocket client, not copied from the
     * downstream request.
     */
    private static HttpHeaders removeSecWebSocketHeaders(HttpHeaders headers, ServerWebExchange exchange) {
        HttpHeaders filtered = new HttpHeaders();
        for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
            if (entry.getKey().toLowerCase(Locale.ROOT).startsWith("sec-websocket")) {
                continue;
            }
            filtered.addAll(entry.getKey(), entry.getValue());
        }
        return filtered;
    }

    /**
     * Rewrites {@code GATEWAY_REQUEST_URL_ATTR} from http/https to ws/wss when the request
     * asks for a WebSocket upgrade. URLs without a scheme are left untouched.
     */
    static void changeSchemeIfIsWebSocketUpgrade(ServerWebExchange exchange) {
        URI requestUrl = exchange.getRequiredAttribute(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR);
        String rawScheme = requestUrl.getScheme();
        if (rawScheme == null) {
            // Relative URI: nothing to convert. filter() then falls through to the chain.
            return;
        }
        String scheme = rawScheme.toLowerCase(Locale.ROOT);
        String upgrade = exchange.getRequest().getHeaders().getUpgrade();
        if ("WebSocket".equalsIgnoreCase(upgrade) && ("http".equals(scheme) || "https".equals(scheme))) {
            String wsScheme = convertHttpToWs(scheme);
            // Keep the URI's existing encoding state so it is not encoded twice.
            boolean encoded = ServerWebExchangeUtils.containsEncodedParts(requestUrl);
            URI wsRequestUrl = UriComponentsBuilder.fromUri(requestUrl).scheme(wsScheme).build(encoded).toUri();
            exchange.getAttributes().put(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR, wsRequestUrl);
            if (log.isTraceEnabled()) {
                log.trace("changeSchemeTo:[" + wsRequestUrl + "]");
            }
        }
    }

    /**
     * Server-side handler that, for each accepted client session, connects to the backend
     * URL and pipes messages and close statuses between the two sessions.
     */
    private static class ProxyWebSocketHandler
    implements WebSocketHandler {
        private final WebSocketClient client;
        private final URI url;
        private final HttpHeaders headers;
        private final List<String> subProtocols;

        ProxyWebSocketHandler(URI url, WebSocketClient client, HttpHeaders headers, List<String> protocols) {
            this.client = client;
            this.url = url;
            this.headers = headers;
            this.subProtocols = protocols != null ? protocols : Collections.<String>emptyList();
        }

        @Override
        public List<String> getSubProtocols() {
            return this.subProtocols;
        }

        @Override
        public Mono<Void> handle(final WebSocketSession session) {
            return this.client.execute(this.url, this.headers, new WebSocketHandler(){

                /**
                 * Passes through close codes that may be sent in a close frame. That is
                 * 3000-4999, plus the whitelisted 1000-1003 and 1007-1011. Any other code
                 * (e.g. the reserved 1004-1006 and 1015, or unlisted values such as
                 * 1012-1014) becomes PROTOCOL_ERROR.
                 */
                private CloseStatus adaptCloseStatus(CloseStatus closeStatus) {
                    int code = closeStatus.getCode();
                    if (code > 2999 && code < 5000) {
                        return closeStatus;
                    }
                    switch (code) {
                        case 1000:
                        case 1001:
                        case 1002:
                        case 1003:
                        case 1007:
                        case 1008:
                        case 1009:
                        case 1010:
                        case 1011:
                            return closeStatus;
                        default:
                            return CloseStatus.PROTOCOL_ERROR;
                    }
                }

                @Override
                public Mono<Void> handle(WebSocketSession proxySession) {
                    // Propagate a close on either side to the other side while it is still open.
                    Mono<Void> serverClose = proxySession.closeStatus()
                            .filter(ignored -> session.isOpen())
                            .map(this::adaptCloseStatus)
                            .flatMap(session::close);
                    Mono<Void> proxyClose = session.closeStatus()
                            .filter(ignored -> proxySession.isOpen())
                            .map(this::adaptCloseStatus)
                            .flatMap(proxySession::close);
                    // retain() keeps each message's buffer alive while it is re-sent on the other session.
                    Mono<Void> proxySessionSend = proxySession.send(session.receive().doOnNext(WebSocketMessage::retain).doOnNext(webSocketMessage -> {
                        if (log.isTraceEnabled()) {
                            log.trace("proxySession(send from client): " + proxySession.getId() + ", corresponding session:" + session.getId() + ", packet: " + webSocketMessage.getPayloadAsText());
                        }
                    }));
                    Mono<Void> serverSessionSend = session.send(proxySession.receive().doOnNext(WebSocketMessage::retain).doOnNext(webSocketMessage -> {
                        if (log.isTraceEnabled()) {
                            log.trace("session(send from backend): " + session.getId() + ", corresponding proxySession:" + proxySession.getId() + " packet: " + webSocketMessage.getPayloadAsText());
                        }
                    }));
                    // The close propagation is subscribed separately, independent of the returned Mono.
                    Mono.when(serverClose, proxyClose).subscribe();
                    return Mono.zip(proxySessionSend, serverSessionSend).then();
                }

                @Override
                public List<String> getSubProtocols() {
                    return subProtocols;
                }
            });
        }
    }
}

