/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hertzbeat.ai.service.impl;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import org.apache.hertzbeat.ai.dao.ChatConversationDao;
import org.apache.hertzbeat.ai.dao.ChatMessageDao;
import org.apache.hertzbeat.ai.pojo.dto.ChatRequestContext;
import org.apache.hertzbeat.ai.pojo.dto.ChatResponseChunk;
import org.apache.hertzbeat.ai.service.ChatClientProviderService;
import org.apache.hertzbeat.ai.service.ConversationService;
import org.apache.hertzbeat.common.entity.ai.ChatConversation;
import org.apache.hertzbeat.common.entity.ai.ChatMessage;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.Sort;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;

@Service
public class ConversationServiceImpl
implements ConversationService {
    private static final Logger log = LoggerFactory.getLogger(ConversationServiceImpl.class);
    @Autowired
    private ChatConversationDao conversationDao;
    @Autowired
    private ChatMessageDao messageDao;
    @Autowired
    private ChatClientProviderService chatClientProviderService;

    @Override
    public Flux<ServerSentEvent<ChatResponseChunk>> streamChat(String message, Long conversationId) {
        if (!this.chatClientProviderService.isConfigured()) {
            ChatResponseChunk errorResponse = ChatResponseChunk.builder().conversationId(conversationId).response("Provider is not configured. Please configure your AI Provider.").build();
            return Flux.just((Object)ServerSentEvent.builder((Object)errorResponse).event("error").build());
        }
        log.info("Starting streaming conversation: {}", (Object)conversationId);
        ChatConversation conversation = (ChatConversation)this.conversationDao.findById(conversationId).orElseThrow(() -> new IllegalArgumentException("Conversation not found: " + conversationId));
        List<ChatMessage> messages = this.messageDao.findByConversationIdOrderByGmtCreateAsc(conversationId);
        conversation.setMessages(messages);
        if (conversation.getTitle().startsWith("conversation")) {
            String title = message.length() > 30 ? message.substring(0, 27) + "..." : message;
            conversation.setTitle(title);
            this.conversationDao.save(conversation);
        }
        ChatMessage chatMessage = ChatMessage.builder().conversationId(conversationId).content(message).role("user").build();
        chatMessage = (ChatMessage)this.messageDao.save(chatMessage);
        ChatRequestContext context = ChatRequestContext.builder().message(message).conversationId(conversationId).conversationHistory(CollectionUtils.isEmpty((Collection)conversation.getMessages()) ? null : conversation.getMessages().subList(0, conversation.getMessages().size() - 1)).build();
        StringBuilder fullResponse = new StringBuilder();
        ChatMessage finalChatMessage = chatMessage;
        return this.chatClientProviderService.streamChat(context).map(chunk -> {
            fullResponse.append((String)chunk);
            ChatResponseChunk responseChunk = ChatResponseChunk.builder().conversationId(conversationId).userMessageId(finalChatMessage.getId()).response((String)chunk).build();
            return ServerSentEvent.builder((Object)responseChunk).event("message").build();
        }).concatWith((Publisher)Flux.defer(() -> {
            ChatMessage assistantMessage = ChatMessage.builder().conversationId(conversationId).content(fullResponse.toString()).role("assistant").build();
            assistantMessage = (ChatMessage)this.messageDao.save(assistantMessage);
            ChatResponseChunk finalResponse = ChatResponseChunk.builder().conversationId(conversationId).response("").assistantMessageId(assistantMessage.getId()).build();
            return Flux.just((Object)ServerSentEvent.builder((Object)finalResponse).event("complete").build());
        })).doOnComplete(() -> log.info("Streaming completed for conversation: {}", (Object)conversationId)).doOnError(error -> log.error("Error in streaming chat for conversation {}: {}", new Object[]{conversationId, error.getMessage(), error})).onErrorResume(error -> {
            ChatResponseChunk errorResponse = ChatResponseChunk.builder().conversationId(conversationId).response("An error occurred: " + error.getMessage()).userMessageId(finalChatMessage.getId()).build();
            return Flux.just((Object)ServerSentEvent.builder((Object)errorResponse).event("error").build());
        });
    }

    @Override
    public ChatConversation createConversation() {
        ChatConversation conversation = new ChatConversation();
        conversation.setTitle("conversation-" + UUID.randomUUID().toString().substring(0, 4));
        return (ChatConversation)this.conversationDao.save(conversation);
    }

    @Override
    public ChatConversation getConversation(Long conversationId) {
        if (conversationId == null) {
            return null;
        }
        ChatConversation conversation = this.conversationDao.findById(conversationId).orElse(null);
        if (conversation != null) {
            List<ChatMessage> messages = this.messageDao.findByConversationIdOrderByGmtCreateAsc(conversationId);
            conversation.setMessages(messages);
        }
        return conversation;
    }

    @Override
    public List<ChatConversation> getAllConversations() {
        List conversations = this.conversationDao.findAll(Sort.by((Sort.Direction)Sort.Direction.DESC, (String[])new String[]{"id"}));
        if (conversations.isEmpty()) {
            return conversations;
        }
        List<Long> conversationIds = conversations.stream().map(ChatConversation::getId).toList();
        List<ChatMessage> allMessages = this.messageDao.findByConversationIdInOrderByGmtCreateAsc(conversationIds);
        Map<Long, List<ChatMessage>> messagesByConversationId = allMessages.stream().collect(Collectors.groupingBy(ChatMessage::getConversationId));
        for (ChatConversation conversation : conversations) {
            List messages = messagesByConversationId.getOrDefault(conversation.getId(), Collections.emptyList());
            conversation.setMessages(messages);
        }
        return conversations;
    }

    @Override
    @Transactional(rollbackFor={Exception.class})
    public void deleteConversation(Long conversationId) {
        List<ChatMessage> messages = this.messageDao.findByConversationIdOrderByGmtCreateAsc(conversationId);
        if (!messages.isEmpty()) {
            this.messageDao.deleteAll(messages);
        }
        this.conversationDao.deleteById(conversationId);
    }
}

