/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.mistralai;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolParameters;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.mistralai.MistralAiChatCompletionResponse;
import dev.langchain4j.model.mistralai.MistralAiChatMessage;
import dev.langchain4j.model.mistralai.MistralAiFunction;
import dev.langchain4j.model.mistralai.MistralAiFunctionCall;
import dev.langchain4j.model.mistralai.MistralAiParameters;
import dev.langchain4j.model.mistralai.MistralAiResponseFormat;
import dev.langchain4j.model.mistralai.MistralAiResponseFormatType;
import dev.langchain4j.model.mistralai.MistralAiRole;
import dev.langchain4j.model.mistralai.MistralAiTool;
import dev.langchain4j.model.mistralai.MistralAiToolCall;
import dev.langchain4j.model.mistralai.MistralAiToolType;
import dev.langchain4j.model.mistralai.MistralAiUsage;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import okhttp3.Headers;

public class DefaultMistralAiHelper {

    /** Base URL of the Mistral AI REST API. */
    static final String MISTRALAI_API_URL = "https://api.mistral.ai/v1";

    /** Encoding format value used when requesting embeddings. */
    static final String MISTRALAI_API_CREATE_EMBEDDINGS_ENCODING_FORMAT = "float";

    /**
     * Matches an Authorization value of the form {@code "Bearer <token>"} where the token is
     * 1-32 ASCII letters/digits. Group 1 is the scheme part ({@code "Bearer"} plus any extra
     * whitespace before the final separating space), group 2 is the token.
     */
    private static final Pattern MISTRAL_AI_API_KEY_BEARER_PATTERN =
            Pattern.compile("^(Bearer\\s*) ([A-Za-z0-9]{1,32})$");

    /** Placeholder written instead of a credential that cannot be partially revealed safely. */
    private static final String MASKED_CREDENTIAL = "***";

    /** Number of leading and trailing token characters left visible when a token is masked. */
    private static final int VISIBLE_TOKEN_EDGE = 2;

    /**
     * Tokens shorter than this are masked completely, so the visible edges never expose
     * half or more of the secret.
     */
    private static final int MIN_PARTIALLY_MASKED_TOKEN_LENGTH = 4 * VISIBLE_TOKEN_EDGE;

    /**
     * Converts langchain4j chat messages to Mistral AI chat messages, preserving order.
     *
     * @param messages the messages to convert
     * @return a new list with one converted message per input message
     * @throws IllegalArgumentException if a message has an unsupported type
     */
    static List<MistralAiChatMessage> toMistralAiMessages(List<ChatMessage> messages) {
        return messages.stream()
                .map(DefaultMistralAiHelper::toMistralAiMessage)
                .collect(Collectors.toList());
    }

    /**
     * Converts a single langchain4j chat message to its Mistral AI representation.
     *
     * <p>System, user and AI messages map to the SYSTEM, USER and ASSISTANT roles. Tool
     * execution results map to the TOOL role and carry the tool name. An AI message that has
     * tool execution requests also carries the converted tool calls; its text content is sent
     * as {@code null} when it is null or blank.
     *
     * @param message the message to convert
     * @return the converted message
     * @throws IllegalArgumentException if the message type is not one of the supported types
     */
    static MistralAiChatMessage toMistralAiMessage(ChatMessage message) {
        if (message instanceof SystemMessage) {
            return MistralAiChatMessage.builder()
                    .role(MistralAiRole.SYSTEM)
                    .content(((SystemMessage) message).text())
                    .build();
        }
        if (message instanceof AiMessage) {
            AiMessage aiMessage = (AiMessage) message;
            if (!aiMessage.hasToolExecutionRequests()) {
                return MistralAiChatMessage.builder()
                        .role(MistralAiRole.ASSISTANT)
                        .content(aiMessage.text())
                        .build();
            }
            List<MistralAiToolCall> toolCalls = aiMessage.toolExecutionRequests().stream()
                    .map(DefaultMistralAiHelper::toMistralAiToolCall)
                    .collect(Collectors.toList());
            // Alongside tool calls, blank text is normalised to null content.
            String content = Utils.isNullOrBlank(aiMessage.text()) ? null : aiMessage.text();
            return MistralAiChatMessage.builder()
                    .role(MistralAiRole.ASSISTANT)
                    .content(content)
                    .toolCalls(toolCalls)
                    .build();
        }
        if (message instanceof UserMessage) {
            return MistralAiChatMessage.builder()
                    .role(MistralAiRole.USER)
                    .content(message.text())
                    .build();
        }
        if (message instanceof ToolExecutionResultMessage) {
            ToolExecutionResultMessage toolResult = (ToolExecutionResultMessage) message;
            return MistralAiChatMessage.builder()
                    .role(MistralAiRole.TOOL)
                    .name(toolResult.toolName())
                    .content(toolResult.text())
                    .build();
        }
        throw new IllegalArgumentException("Unknown message type: " + message.type());
    }

    /**
     * Converts a langchain4j tool execution request into a Mistral AI tool call, copying its
     * id, function name and arguments.
     *
     * @param toolExecutionRequest the request to convert
     * @return the equivalent Mistral AI tool call
     */
    static MistralAiToolCall toMistralAiToolCall(ToolExecutionRequest toolExecutionRequest) {
        MistralAiFunctionCall functionCall = MistralAiFunctionCall.builder()
                .name(toolExecutionRequest.name())
                .arguments(toolExecutionRequest.arguments())
                .build();
        return MistralAiToolCall.builder()
                .id(toolExecutionRequest.id())
                .function(functionCall)
                .build();
    }

    /**
     * Converts Mistral AI usage statistics into a langchain4j {@link TokenUsage}.
     *
     * @param mistralAiUsage the usage reported by the API, may be {@code null}
     * @return the token usage (prompt, completion, total), or {@code null} if no usage was given
     */
    public static TokenUsage tokenUsageFrom(MistralAiUsage mistralAiUsage) {
        if (mistralAiUsage == null) {
            return null;
        }
        return new TokenUsage(
                mistralAiUsage.getPromptTokens(),
                mistralAiUsage.getCompletionTokens(),
                mistralAiUsage.getTotalTokens());
    }

    /**
     * Maps a Mistral AI finish reason string to a langchain4j {@link FinishReason}.
     *
     * @param mistralAiFinishReason the raw finish reason, may be {@code null}
     * @return the mapped reason, or {@code null} for {@code null} or unrecognised values
     */
    public static FinishReason finishReasonFrom(String mistralAiFinishReason) {
        if (mistralAiFinishReason == null) {
            return null;
        }
        switch (mistralAiFinishReason) {
            case "stop":
                return FinishReason.STOP;
            case "length":
                return FinishReason.LENGTH;
            case "tool_calls":
                return FinishReason.TOOL_EXECUTION;
            case "content_filter":
                return FinishReason.CONTENT_FILTER;
            default:
                return null;
        }
    }

    /**
     * Builds an {@link AiMessage} from the first choice of a chat completion response.
     *
     * <p>If that choice contains tool calls, the returned message holds only the converted tool
     * execution requests (its text content is not included). Otherwise it holds the text content.
     *
     * @param response the chat completion response; assumed to contain at least one choice
     * @return the AI message for the first choice
     */
    public static AiMessage aiMessageFrom(MistralAiChatCompletionResponse response) {
        MistralAiChatMessage aiMistralMessage = response.getChoices().get(0).getMessage();
        List<MistralAiToolCall> toolCalls = aiMistralMessage.getToolCalls();
        if (!Utils.isNullOrEmpty(toolCalls)) {
            return AiMessage.from(toToolExecutionRequests(toolCalls));
        }
        return AiMessage.from(aiMistralMessage.getContent());
    }

    /**
     * Converts Mistral AI tool calls to langchain4j tool execution requests. Tool calls whose
     * type is not {@link MistralAiToolType#FUNCTION} are dropped.
     *
     * @param mistralAiToolCalls the tool calls to convert
     * @return the converted requests, in input order
     */
    public static List<ToolExecutionRequest> toToolExecutionRequests(List<MistralAiToolCall> mistralAiToolCalls) {
        return mistralAiToolCalls.stream()
                .filter(toolCall -> toolCall.getType() == MistralAiToolType.FUNCTION)
                .map(DefaultMistralAiHelper::toToolExecutionRequest)
                .collect(Collectors.toList());
    }

    /**
     * Converts a single Mistral AI tool call into a langchain4j tool execution request, copying
     * its id, function name and function arguments.
     *
     * @param mistralAiToolCall the tool call to convert; its function must be non-null
     * @return the equivalent tool execution request
     */
    public static ToolExecutionRequest toToolExecutionRequest(MistralAiToolCall mistralAiToolCall) {
        MistralAiFunctionCall function = mistralAiToolCall.getFunction();
        return ToolExecutionRequest.builder()
                .id(mistralAiToolCall.getId())
                .name(function.getName())
                .arguments(function.getArguments())
                .build();
    }

    /**
     * Converts langchain4j tool specifications into Mistral AI tools, preserving order.
     *
     * @param toolSpecifications the specifications to convert
     * @return the converted tools
     */
    static List<MistralAiTool> toMistralAiTools(List<ToolSpecification> toolSpecifications) {
        return toolSpecifications.stream()
                .map(DefaultMistralAiHelper::toMistralAiTool)
                .collect(Collectors.toList());
    }

    /**
     * Converts a single tool specification into a Mistral AI function tool, copying name,
     * description and parameters.
     *
     * @param toolSpecification the specification to convert
     * @return the equivalent Mistral AI tool
     */
    static MistralAiTool toMistralAiTool(ToolSpecification toolSpecification) {
        MistralAiFunction function = MistralAiFunction.builder()
                .name(toolSpecification.name())
                .description(toolSpecification.description())
                .parameters(toMistralAiParameters(toolSpecification.parameters()))
                .build();
        return MistralAiTool.from(function);
    }

    /**
     * Converts tool parameters into Mistral AI parameters.
     *
     * @param parameters the parameters, may be {@code null}
     * @return the converted parameters, or a builder-default instance when {@code parameters}
     *     is {@code null}
     */
    static MistralAiParameters toMistralAiParameters(ToolParameters parameters) {
        if (parameters == null) {
            return MistralAiParameters.builder().build();
        }
        return MistralAiParameters.from(parameters);
    }

    /**
     * Maps a response format name to a Mistral AI response format.
     *
     * @param responseFormat {@code "text"}, {@code "json_object"}, or {@code null}
     * @return the matching response format, or {@code null} when {@code responseFormat} is null
     * @throws IllegalArgumentException for any other value
     */
    static MistralAiResponseFormat toMistralAiResponseFormat(String responseFormat) {
        if (responseFormat == null) {
            return null;
        }
        switch (responseFormat) {
            case "text":
                return MistralAiResponseFormat.fromType(MistralAiResponseFormatType.TEXT);
            case "json_object":
                return MistralAiResponseFormat.fromType(MistralAiResponseFormatType.JSON_OBJECT);
            default:
                throw new IllegalArgumentException("Unknown response format: " + responseFormat);
        }
    }

    /**
     * Renders HTTP headers as {@code "[Name: value], [Name: value], ..."} for logging. The
     * value of any {@code Authorization} header (matched case-insensitively) is masked.
     *
     * @param headers the headers to render
     * @return the rendered, comma-separated header list
     */
    static String getHeaders(Headers headers) {
        return StreamSupport.stream(headers.spliterator(), false)
                .map(header -> {
                    String headerKey = header.component1();
                    String headerValue = header.component2();
                    // HTTP field names are case-insensitive, so "authorization" must be masked too.
                    if ("Authorization".equalsIgnoreCase(headerKey)) {
                        headerValue = maskAuthorizationHeaderValue(headerValue);
                    }
                    return String.format("[%s: %s]", headerKey, headerValue);
                })
                .collect(Collectors.joining(", "));
    }

    /**
     * Masks an Authorization header value so that it can be logged.
     *
     * <p>A recognised {@code "Bearer <token>"} value keeps only the first and last
     * {@value #VISIBLE_TOKEN_EDGE} token characters. Short tokens are replaced entirely. Any
     * value that does not match the expected pattern is replaced entirely as well, so an
     * unexpected credential format is never logged verbatim.
     *
     * @param authorizationHeaderValue the raw header value
     * @return a masked representation that never contains the full credential
     */
    private static String maskAuthorizationHeaderValue(String authorizationHeaderValue) {
        if (authorizationHeaderValue == null) {
            return MASKED_CREDENTIAL;
        }
        Matcher matcher = MISTRAL_AI_API_KEY_BEARER_PATTERN.matcher(authorizationHeaderValue);
        if (!matcher.matches()) {
            // Unrecognised scheme or key format: hide it completely rather than leak it.
            return MASKED_CREDENTIAL;
        }
        String bearer = matcher.group(1);
        String token = matcher.group(2);
        if (token.length() < MIN_PARTIALLY_MASKED_TOKEN_LENGTH) {
            // Revealing the edges of a short token would expose most of it.
            return bearer + " " + MASKED_CREDENTIAL;
        }
        return bearer + " "
                + token.substring(0, VISIBLE_TOKEN_EDGE)
                + "..."
                + token.substring(token.length() - VISIBLE_TOKEN_EDGE);
    }
}

