/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shenyu.plugin.ai.proxy.enhanced;

import java.nio.charset.StandardCharsets;
import java.util.Objects;
import java.util.Optional;
import org.apache.shenyu.common.dto.RuleData;
import org.apache.shenyu.common.dto.SelectorData;
import org.apache.shenyu.common.dto.convert.rule.AiProxyHandle;
import org.apache.shenyu.common.enums.AiModelProviderEnum;
import org.apache.shenyu.common.enums.PluginEnum;
import org.apache.shenyu.common.utils.JsonUtils;
import org.apache.shenyu.plugin.ai.common.config.AiCommonConfig;
import org.apache.shenyu.plugin.ai.common.spring.ai.AiModelFactory;
import org.apache.shenyu.plugin.ai.common.spring.ai.registry.AiModelFactoryRegistry;
import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.AiProxyApiKeyCache;
import org.apache.shenyu.plugin.ai.proxy.enhanced.cache.ChatClientCache;
import org.apache.shenyu.plugin.ai.proxy.enhanced.handler.AiProxyPluginHandler;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyConfigService;
import org.apache.shenyu.plugin.ai.proxy.enhanced.service.AiProxyExecutorService;
import org.apache.shenyu.plugin.api.ShenyuPluginChain;
import org.apache.shenyu.plugin.api.utils.WebFluxResultUtils;
import org.apache.shenyu.plugin.base.AbstractShenyuPlugin;
import org.apache.shenyu.plugin.base.utils.CacheKeyUtils;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * Shenyu plugin that forwards the incoming request body to an AI chat model.
 *
 * <p>For each request the selector's {@link AiProxyHandle} is resolved into a primary
 * {@link AiCommonConfig}. When proxy mode is enabled on the handle, the caller must present a proxy
 * API key in the {@code X-API-KEY} header, which is exchanged for the real provider key; missing or
 * unknown keys are rejected with {@code 401 Unauthorized}. The call is then executed either as a
 * server-sent-event stream or as a single JSON response, optionally backed by a dynamic
 * (request-supplied) or admin-configured fallback model.
 */
public class AiProxyPlugin extends AbstractShenyuPlugin {

    private static final Logger LOG = LoggerFactory.getLogger(AiProxyPlugin.class);

    /** Header carrying the caller's proxy API key. */
    private static final String PROXY_API_KEY_HEADER = "X-API-KEY";

    /** Suffix appended to the selector id to form the admin-fallback cache key. */
    private static final String ADMIN_FALLBACK_CACHE_SUFFIX = "|adminFallback";

    /** Number of leading key characters that may appear in logs. */
    private static final int LOGGED_KEY_PREFIX_LENGTH = 6;

    private final AiModelFactoryRegistry aiModelFactoryRegistry;

    private final AiProxyConfigService aiProxyConfigService;

    private final AiProxyExecutorService aiProxyExecutorService;

    private final ChatClientCache chatClientCache;

    private final AiProxyPluginHandler aiProxyPluginHandler;

    /**
     * Creates the plugin.
     *
     * @param aiModelFactoryRegistry registry used to look up a model factory per provider
     * @param aiProxyConfigService   resolves primary and fallback model configurations
     * @param aiProxyExecutorService executes chat calls against the main and fallback clients
     * @param chatClientCache        cache of chat clients keyed by selector-derived keys
     * @param aiProxyPluginHandler   source of the cached per-selector {@link AiProxyHandle}
     */
    public AiProxyPlugin(final AiModelFactoryRegistry aiModelFactoryRegistry, final AiProxyConfigService aiProxyConfigService,
                         final AiProxyExecutorService aiProxyExecutorService, final ChatClientCache chatClientCache,
                         final AiProxyPluginHandler aiProxyPluginHandler) {
        this.aiModelFactoryRegistry = aiModelFactoryRegistry;
        this.aiProxyConfigService = aiProxyConfigService;
        this.aiProxyExecutorService = aiProxyExecutorService;
        this.chatClientCache = chatClientCache;
        this.aiProxyPluginHandler = aiProxyPluginHandler;
    }

    /**
     * Reads the whole request body, authenticates the proxy API key when proxy mode is enabled, and
     * dispatches to the streaming or non-streaming handler.
     *
     * <p>This plugin writes the response itself; {@code chain} and {@code rule} are not used.
     *
     * @param exchange the current exchange
     * @param chain    the plugin chain (unused: this plugin terminates the request)
     * @param selector the matched selector; its id keys the handle and client caches
     * @param rule     the matched rule (unused)
     * @return completion of the response write
     */
    protected Mono<Void> doExecute(final ServerWebExchange exchange, final ShenyuPluginChain chain, final SelectorData selector, final RuleData rule) {
        final AiProxyHandle selectorHandle = (AiProxyHandle) this.aiProxyPluginHandler.getSelectorCachedHandle()
                .obtainHandle(CacheKeyUtils.INST.getKey(selector.getId(), "default_rule"));
        return DataBufferUtils.join(exchange.getRequest().getBody()).flatMap(dataBuffer -> {
            final String requestBody = dataBuffer.toString(StandardCharsets.UTF_8);
            DataBufferUtils.release(dataBuffer);
            final AiCommonConfig primaryConfig = this.aiProxyConfigService.resolvePrimaryConfig(selectorHandle);
            final String proxyApiKey = exchange.getRequest().getHeaders().getFirst(PROXY_API_KEY_HEADER);
            final boolean proxyEnabled = Objects.nonNull(selectorHandle)
                    && "true".equalsIgnoreCase(String.valueOf(selectorHandle.getProxyEnabled()));
            // The cached main client is keyed by selector id only, so it may be reused only while the
            // request runs with the selector's own configured API key.
            boolean cacheableMainClient = true;
            if (proxyEnabled) {
                if (Objects.isNull(proxyApiKey) || proxyApiKey.isEmpty()) {
                    return unauthorized(exchange);
                }
                final String realKey = AiProxyApiKeyCache.getInstance().getRealApiKey(selector.getId(), proxyApiKey);
                if (Objects.isNull(realKey)) {
                    LOG.warn("[AiProxy] proxy key invalid, key={}... (masked), selectorId={}", maskKey(proxyApiKey), selector.getId());
                    return unauthorized(exchange);
                }
                // A proxy key may map to a real key that differs from the selector's configured one;
                // a client cached under the selector id would silently keep using the old key.
                cacheableMainClient = Objects.equals(realKey, primaryConfig.getApiKey());
                // NOTE(review): this mutates the object returned by resolvePrimaryConfig; assumes it is a
                // per-request copy rather than shared state — confirm.
                primaryConfig.setApiKey(realKey);
                if (LOG.isDebugEnabled()) {
                    LOG.debug("[AiProxy] proxy key hit, selectorId={}, key={}... (masked)", selector.getId(), maskKey(proxyApiKey));
                    LOG.debug("[AiProxy] proxy key hit, cacheSize={}", AiProxyApiKeyCache.getInstance().size());
                }
            }
            if (Boolean.TRUE.equals(primaryConfig.getStream())) {
                return this.handleStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle, cacheableMainClient);
            }
            return this.handleNonStreamRequest(exchange, selector, requestBody, primaryConfig, selectorHandle, cacheableMainClient);
        });
    }

    /**
     * Executes the call as a stream and writes each {@link ChatResponse} as an SSE {@code data:} event.
     */
    private Mono<Void> handleStreamRequest(final ServerWebExchange exchange, final SelectorData selector, final String requestBody,
                                           final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle,
                                           final boolean cacheableMainClient) {
        final ChatClient mainClient = this.createMainChatClient(selector.getId(), primaryConfig, cacheableMainClient);
        final Optional<ChatClient> fallbackClient = this.resolveFallbackClient(primaryConfig, selectorHandle, selector.getId(), requestBody);
        final ServerHttpResponse response = exchange.getResponse();
        response.getHeaders().setContentType(MediaType.TEXT_EVENT_STREAM);
        final Flux<ChatResponse> chatResponseFlux = this.aiProxyExecutorService.executeStream(mainClient, fallbackClient, requestBody);
        final Flux<DataBuffer> sseFlux = chatResponseFlux.map(chatResponse -> response.bufferFactory().wrap(toSseEvent(chatResponse)));
        return response.writeWith(sseFlux);
    }

    /**
     * Executes the call once and writes the serialized response as the body.
     */
    private Mono<Void> handleNonStreamRequest(final ServerWebExchange exchange, final SelectorData selector, final String requestBody,
                                              final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle,
                                              final boolean cacheableMainClient) {
        final ChatClient mainClient = this.createMainChatClient(selector.getId(), primaryConfig, cacheableMainClient);
        final Optional<ChatClient> fallbackClient = this.resolveFallbackClient(primaryConfig, selectorHandle, selector.getId(), requestBody);
        return this.aiProxyExecutorService.execute(mainClient, fallbackClient, requestBody).flatMap(chatResponse -> {
            final byte[] jsonBytes = JsonUtils.toJson(chatResponse).getBytes(StandardCharsets.UTF_8);
            return WebFluxResultUtils.result(exchange, jsonBytes);
        });
    }

    /**
     * Picks a fallback client: a dynamic (request-derived, non-cached) one if the config service
     * resolves one, otherwise an admin-configured (cached) one, otherwise none.
     */
    private Optional<ChatClient> resolveFallbackClient(final AiCommonConfig primaryConfig, final AiProxyHandle selectorHandle,
                                                       final String selectorId, final String requestBody) {
        return this.aiProxyConfigService.resolveDynamicFallbackConfig(primaryConfig, requestBody).map(cfg -> {
            LOG.info("[AiProxy] use dynamic fallback");
            if (LOG.isDebugEnabled()) {
                LOG.debug("[AiProxy] dynamic fallback config: {}", cfg);
            }
            return this.createDynamicFallbackClient(cfg);
        }).or(() -> this.aiProxyConfigService.resolveAdminFallbackConfig(primaryConfig, selectorHandle).map(adminFallbackConfig -> {
            LOG.info("[AiProxy] use admin fallback");
            if (LOG.isDebugEnabled()) {
                LOG.debug("[AiProxy] admin fallback config: {}", adminFallbackConfig);
            }
            return this.createAdminFallbackClient(selectorId, adminFallbackConfig);
        }));
    }

    /**
     * Returns the main chat client for a selector.
     *
     * @param selectorId the selector id, used as the cache key when caching is allowed
     * @param config     the primary model configuration
     * @param cacheable  {@code false} when {@code config} carries a per-request API key that must not be
     *                   shared through the selector-keyed cache; a fresh client is built instead
     * @return the chat client to use for this request
     */
    private ChatClient createMainChatClient(final String selectorId, final AiCommonConfig config, final boolean cacheable) {
        if (!cacheable) {
            LOG.debug("Creating non-cached main model for selector: {}", selectorId);
            return ChatClient.builder(this.createChatModel(config)).build();
        }
        return this.chatClientCache.computeIfAbsent(selectorId, () -> {
            LOG.info("Creating and caching main model for selector: {}", selectorId);
            return this.createChatModel(config);
        });
    }

    /** Returns the admin fallback client for a selector, cached under {@code selectorId + "|adminFallback"}. */
    private ChatClient createAdminFallbackClient(final String selectorId, final AiCommonConfig fallbackConfig) {
        final String fallbackCacheKey = selectorId + ADMIN_FALLBACK_CACHE_SUFFIX;
        return this.chatClientCache.computeIfAbsent(fallbackCacheKey, () -> {
            LOG.info("Creating and caching admin fallback model for selector: {}", selectorId);
            return this.createChatModel(fallbackConfig);
        });
    }

    /** Builds a fresh, non-cached client for a dynamic fallback configuration. */
    private ChatClient createDynamicFallbackClient(final AiCommonConfig fallbackConfig) {
        LOG.info("Creating non-cached dynamic fallback model.");
        return ChatClient.builder(this.createChatModel(fallbackConfig)).build();
    }

    /**
     * Creates a chat model via the factory registered for the config's provider.
     *
     * @throws IllegalArgumentException if the provider name is unknown or has no registered factory
     * @throws NullPointerException     if the factory returns {@code null}
     */
    private ChatModel createChatModel(final AiCommonConfig config) {
        if (LOG.isDebugEnabled()) {
            LOG.debug("Creating chat model with config: {}", config);
        }
        final AiModelProviderEnum provider = AiModelProviderEnum.getByName(config.getProvider());
        if (Objects.isNull(provider)) {
            throw new IllegalArgumentException("Invalid AI model provider in config: " + config.getProvider());
        }
        final AiModelFactory factory = this.aiModelFactoryRegistry.getFactory(provider);
        if (Objects.isNull(factory)) {
            throw new IllegalArgumentException("AI model factory not found for provider: " + provider.getName());
        }
        return Objects.requireNonNull(factory.createAiModel(config), "The AI model created by the factory must not be null");
    }

    /** Sets {@code 401 Unauthorized} and completes the response without a body. */
    private static Mono<Void> unauthorized(final ServerWebExchange exchange) {
        exchange.getResponse().setStatusCode(HttpStatus.UNAUTHORIZED);
        return exchange.getResponse().setComplete();
    }

    /** Returns at most the first {@value #LOGGED_KEY_PREFIX_LENGTH} characters of a key, for logging. */
    private static String maskKey(final String key) {
        return key.substring(0, Math.min(LOGGED_KEY_PREFIX_LENGTH, key.length()));
    }

    /** Serializes a chat response as one SSE {@code data:} event terminated by a blank line, UTF-8 encoded. */
    private static byte[] toSseEvent(final ChatResponse chatResponse) {
        return ("data: " + JsonUtils.toJson(chatResponse) + "\n\n").getBytes(StandardCharsets.UTF_8);
    }

    /** Returns the plugin order taken from {@link PluginEnum#AI_PROXY}. */
    public int getOrder() {
        return PluginEnum.AI_PROXY.getCode();
    }

    /** Returns the plugin name taken from {@link PluginEnum#AI_PROXY}. */
    public String named() {
        return PluginEnum.AI_PROXY.getName();
    }
}

