/*
 * Decompiled with CFR 0.152.
 */
package com.microsoft.semantickernel.planner.stepwiseplanner;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.SKBuilders;
import com.microsoft.semantickernel.Verify;
import com.microsoft.semantickernel.orchestration.ContextVariables;
import com.microsoft.semantickernel.orchestration.SKContext;
import com.microsoft.semantickernel.orchestration.SKFunction;
import com.microsoft.semantickernel.planner.PlanningException;
import com.microsoft.semantickernel.planner.actionplanner.Plan;
import com.microsoft.semantickernel.planner.stepwiseplanner.StepwisePlanner;
import com.microsoft.semantickernel.planner.stepwiseplanner.StepwisePlannerConfig;
import com.microsoft.semantickernel.planner.stepwiseplanner.SystemStep;
import com.microsoft.semantickernel.semanticfunctions.PromptTemplate;
import com.microsoft.semantickernel.semanticfunctions.PromptTemplateConfig;
import com.microsoft.semantickernel.semanticfunctions.SemanticFunctionConfig;
import com.microsoft.semantickernel.skilldefinition.FunctionView;
import com.microsoft.semantickernel.skilldefinition.ReadOnlyFunctionCollection;
import com.microsoft.semantickernel.skilldefinition.annotations.DefineSKFunction;
import com.microsoft.semantickernel.skilldefinition.annotations.SKFunctionParameters;
import com.microsoft.semantickernel.textcompletion.CompletionRequestSettings;
import com.microsoft.semantickernel.util.EmbeddedResourceLoader;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SynchronousSink;

/**
 * Planner that answers a goal by repeatedly prompting the LLM for a step
 * ({@code [THOUGHT]} / {@code [ACTION]} / {@code [FINAL ANSWER]}), invoking the requested
 * function for each action, and feeding the accumulated steps back to the model as a
 * "scratch pad" until a final answer is produced or the iteration limit is reached.
 */
public class DefaultStepwisePlanner implements StepwisePlanner {
    private static final Logger LOGGER = LoggerFactory.getLogger(DefaultStepwisePlanner.class);

    /** Skill name under which this planner's own functions are registered; always excluded from the tool list. */
    private static final String RESTRICTED_SKILL_NAME = "StepwisePlanner_Excluded";
    private static final String SCRATCH_PAD_PREFIX = "This was my previous work (but they haven't seen any of it! They only see what I return as final answer):";
    private static final String THOUGHT = "[THOUGHT]";
    private static final String OBSERVATION = "[OBSERVATION]";
    private static final String ACTION = "[ACTION]";

    /** Rough characters-per-token ratio used to estimate the scratch pad's size in tokens. */
    private static final double CHARS_PER_TOKEN = 4.0;
    /** Fraction of the configured max tokens the scratch pad may use before older steps are dropped. */
    private static final double SCRATCH_PAD_TOKEN_BUDGET = 0.75;

    private static final int REGEX_FLAGS = Pattern.DOTALL | Pattern.MULTILINE;
    private static final Pattern S_FINAL_ANSWER_REGEX = Pattern.compile("\\[FINAL ANSWER\\](?<finalanswer>.+)", REGEX_FLAGS);
    private static final Pattern S_THOUGHT_REGEX = Pattern.compile("(\\[THOUGHT\\])?(?<thought>.+)", REGEX_FLAGS);
    private static final Pattern S_THOUGHT_ACTION_REMOVE_REGEX = Pattern.compile("(.*)^\\[ACTION\\].+", REGEX_FLAGS);
    private static final Pattern S_ACTION_REGEX = Pattern.compile("\\[ACTION\\][^{}]*((?:\\{|\\{\\{)(?:[^{}]*\\{[^{}]*\\})*[^{}]*(?:\\}|\\}\\}))$", REGEX_FLAGS);

    /** Shared JSON mapper; it is never reconfigured after construction. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    private final Kernel kernel;
    private final StepwisePlannerConfig config;
    private final SKFunction<?> systemStepFunction;
    private final SKContext context;
    private final ReadOnlyFunctionCollection nativeFunctions;

    /**
     * Creates a planner bound to {@code kernel}.
     *
     * <p>Side effects: adds {@link #RESTRICTED_SKILL_NAME} to the excluded skills of the config
     * (the caller's instance when one is supplied), registers the "StepwiseStep" semantic
     * function and this object's {@code @DefineSKFunction} methods on the kernel.
     *
     * @param kernel kernel used to run the prompt and the chosen actions; must not be null
     * @param config planner settings; a default {@link StepwisePlannerConfig} is used when null
     * @param prompt prompt template text; the embedded {@code skprompt.txt} is used when null
     * @param promptUserConfig prompt configuration; the embedded {@code config.json} is used when null
     * @throws PlanningException if an embedded resource cannot be found or parsed
     */
    public DefaultStepwisePlanner(Kernel kernel, @Nullable StepwisePlannerConfig config, @Nullable String prompt, @Nullable PromptTemplateConfig promptUserConfig) {
        Verify.notNull(kernel);
        this.kernel = kernel;
        this.config = config == null ? new StepwisePlannerConfig() : config;
        this.config.addExcludedSkills(RESTRICTED_SKILL_NAME);

        PromptTemplateConfig promptConfig = promptUserConfig == null ? loadDefaultPromptConfig() : promptUserConfig;
        String promptTemplate = prompt == null ? loadDefaultPrompt() : prompt;

        // Same configuration, but with the completion's max tokens taken from the planner config.
        promptConfig = new PromptTemplateConfig(
                promptConfig.getSchema(),
                promptConfig.getDescription(),
                promptConfig.getType(),
                new PromptTemplateConfig.CompletionConfigBuilder(promptConfig.getCompletionConfig())
                        .maxTokens(this.config.getMaxTokens())
                        .build(),
                promptConfig.getInput());

        this.systemStepFunction = this.importStepwiseFunction(this.kernel, promptTemplate, promptConfig);
        this.nativeFunctions = this.kernel.importSkill(this, RESTRICTED_SKILL_NAME);
        this.context = (SKContext) SKBuilders.context().withKernel(kernel).build();
    }

    /**
     * Builds a single-step plan that calls {@code ExecutePlan} with the goal and the current
     * function descriptions.
     *
     * @throws PlanningException if {@code goal} is null or empty
     */
    @Override
    public Plan createPlan(String goal) {
        if (Verify.isNullOrEmpty(goal)) {
            throw new PlanningException(PlanningException.ErrorCodes.INVALID_GOAL, "The goal specified is empty");
        }
        String functionDescriptions = this.getFunctionDescriptions();
        SKFunction<?> planStep = this.nativeFunctions.getFunction("ExecutePlan", SKFunction.class);
        ContextVariables parameters = (ContextVariables) SKBuilders.variables()
                .withVariable("functionDescriptions", functionDescriptions)
                .withVariable("question", goal)
                .build();
        return new Plan(goal, parameters, () -> this.kernel.getSkills(), planStep);
    }

    /**
     * Runs the step loop until the model gives a final answer.
     *
     * @param question the question to answer; when null/empty the context is updated with
     *     "Question not found." and returned
     * @param functionDescriptions not read here directly — presumably consumed by the prompt via
     *     the context variables; TODO confirm
     * @param context context passed to the step prompt; updated with the final answer and stats
     * @return the updated context; errors with {@link PlanningException} when the iteration limit
     *     is reached without a final answer
     */
    @DefineSKFunction(description = "Execute a plan", name = "ExecutePlan")
    public Mono<SKContext> executePlanAsync(
            @SKFunctionParameters(name = "question", description = "The question to answer") String question,
            @SKFunctionParameters(name = "functionDescriptions", description = "List of tool descriptions") String functionDescriptions,
            SKContext context) {
        if (Verify.isNullOrEmpty(question)) {
            return Mono.just(context.update("Question not found."));
        }
        return this.executeSteps(context, 0, new ArrayList<>())
                .switchIfEmpty(Mono.error(new PlanningException(
                        PlanningException.ErrorCodes.PLAN_EXECUTION_PRODUCED_NO_RESULTS,
                        "Plan execution produced no result")))
                .map(stepsTaken -> this.returnFinalAnswer(context, stepsTaken.get(stepsTaken.size() - 1), stepsTaken));
    }

    /**
     * Executes steps one at a time until the last step carries a final answer.
     * Each step is run exactly once; the next step is chained only after the previous completes.
     *
     * @param completedIterations number of LLM steps already executed
     */
    private Mono<ArrayList<SystemStep>> executeSteps(SKContext context, int completedIterations, ArrayList<SystemStep> stepsTaken) {
        return Mono.defer(() -> {
            if (hasFinalAnswer(stepsTaken)) {
                return Mono.just(stepsTaken);
            }
            // At most getMaxIterations() LLM calls are made.
            if (completedIterations >= this.config.getMaxIterations()) {
                return Mono.<ArrayList<SystemStep>>error(new PlanningException(
                        PlanningException.ErrorCodes.PLAN_EXECUTION_PRODUCED_NO_RESULTS, "Max iterations exceeded"));
            }
            int iteration = completedIterations + 1;
            return this.executeNextStep(context, iteration, stepsTaken)
                    .flatMap(steps -> this.executeSteps(context, iteration, steps));
        });
    }

    /** True when the most recent step has a non-empty final answer. */
    private static boolean hasFinalAnswer(List<SystemStep> stepsTaken) {
        return !stepsTaken.isEmpty()
                && !Verify.isNullOrEmpty(stepsTaken.get(stepsTaken.size() - 1).getFinalAnswer());
    }

    /**
     * Asks the model for the next step, appends it to {@code stepsTaken} and, if it requests an
     * action, invokes that action and records the result as the step's observation.
     *
     * @param iteration 1-based iteration number (for logging)
     */
    private Mono<ArrayList<SystemStep>> executeNextStep(SKContext context, int iteration, ArrayList<SystemStep> stepsTaken) {
        context.setVariable("agentScratchPad", this.createScratchPad(stepsTaken));
        return this.systemStepFunction.invokeAsync(context).flatMap(llmResponse -> {
            String actionText = Objects.requireNonNull(llmResponse.getResult()).trim();
            LOGGER.trace("Response: {}", actionText);
            SystemStep nextStep = this.parseResult(actionText);
            stepsTaken.add(nextStep);
            if (!Verify.isNullOrEmpty(nextStep.getFinalAnswer())) {
                return Mono.just(stepsTaken);
            }
            LOGGER.trace("Thought: {}", nextStep.getThought());
            if (Verify.isNullOrEmpty(nextStep.getAction())) {
                LOGGER.info("Action: No action to take");
                return Mono.just(stepsTaken);
            }
            LOGGER.info("Action: {}. Iteration: {}.", nextStep.getAction(), iteration);
            try {
                LOGGER.trace("Action: {}({}). Iteration: {}.", nextStep.getAction(), OBJECT_MAPPER.writeValueAsString(nextStep.getActionVariables()), iteration);
            } catch (JsonProcessingException e) {
                return Mono.<ArrayList<SystemStep>>error(new PlanningException(PlanningException.ErrorCodes.UNKNOWN_ERROR, "Could not serialize action variables", e));
            }
            return this.invokeActionAsync(nextStep.getAction(), nextStep.getActionVariables()).map(result -> {
                nextStep.setObservation(Verify.isNullOrEmpty(result) ? "Got no result from action" : result);
                LOGGER.trace("Observation: {}", nextStep.getObservation());
                return stepsTaken;
            });
        });
    }

    /**
     * Writes the final answer, the final scratch pad and execution statistics into the context.
     */
    private SKContext returnFinalAnswer(SKContext context, SystemStep nextStep, List<SystemStep> stepsTaken) {
        LOGGER.trace("Final Answer: {}", nextStep.getFinalAnswer());
        context = context.update(nextStep.getFinalAnswer());
        context.setVariable("agentScratchPad", this.createScratchPad(stepsTaken));
        try {
            this.addExecutionStatsToContext(stepsTaken, context);
        } catch (JsonProcessingException e) {
            throw new PlanningException(PlanningException.ErrorCodes.UNKNOWN_ERROR, "Could not serialize execution stats", e);
        }
        return context;
    }

    /**
     * Parses one model response into a {@link SystemStep}.
     * A {@code [FINAL ANSWER]} short-circuits; otherwise the thought text (minus any trailing
     * {@code [ACTION]} line) and the action JSON ({@code {...}} or {@code {{...}}}) are extracted.
     * Parse problems are reported through the step's observation rather than thrown.
     */
    private SystemStep parseResult(String input) {
        SystemStep result = new SystemStep();
        result.setOriginalResponse(input);

        Matcher finalAnswerMatch = S_FINAL_ANSWER_REGEX.matcher(input);
        if (finalAnswerMatch.find()) {
            result.setFinalAnswer(finalAnswerMatch.group("finalanswer").trim());
            return result;
        }

        Matcher thoughtMatch = S_THOUGHT_REGEX.matcher(input);
        if (thoughtMatch.find()) {
            String thought = thoughtMatch.group("thought");
            Matcher actionRemove = S_THOUGHT_ACTION_REMOVE_REGEX.matcher(thought);
            result.setThought(actionRemove.find() ? actionRemove.group(1).trim() : thought.trim());
        } else if (!input.contains(ACTION)) {
            result.setThought(input);
        } else {
            throw new IllegalStateException("Unexpected input format");
        }
        result.setThought(Objects.requireNonNull(result.getThought()).replace(THOUGHT, "").trim());

        Matcher actionMatch = S_ACTION_REGEX.matcher(input);
        if (actionMatch.find()) {
            String json = actionMatch.group(1).trim().replace("`", "");
            if (json.startsWith("{{")) {
                json = json.replaceAll("^\\{\\{", "{").replaceAll("\\}\\}$", "}");
            }
            try {
                SystemStep parsed = OBJECT_MAPPER.readValue(json, SystemStep.class);
                if (parsed == null) {
                    result.setObservation("System step parsing error, empty JSON: " + json);
                } else {
                    result.setAction(parsed.getAction());
                    result.setActionVariables(parsed.getActionVariables());
                }
            } catch (JsonProcessingException e) {
                result.setObservation("System step parsing error, invalid JSON: " + json);
            }
        }

        if (Verify.isNullOrEmpty(result.getThought()) && Verify.isNullOrEmpty(result.getAction())) {
            result.setObservation("System step error, no thought or action found. Please give a valid thought and/or action.");
        }
        return result;
    }

    /**
     * Sets {@code stepCount}, {@code stepsTaken} (JSON) and {@code skillCount}
     * ("total (Name(n), ...)", actions listed in first-use order) on the context.
     */
    private void addExecutionStatsToContext(List<SystemStep> stepsTaken, SKContext context) throws JsonProcessingException {
        context.setVariable("stepCount", Integer.toString(stepsTaken.size()));
        context.setVariable("stepsTaken", OBJECT_MAPPER.writeValueAsString(stepsTaken));

        Map<String, Integer> actionCounts = new LinkedHashMap<>();
        for (SystemStep step : stepsTaken) {
            if (!Verify.isNullOrEmpty(step.getAction())) {
                actionCounts.merge(step.getAction(), 1, Integer::sum);
            }
        }
        String skillCallListWithCounts = actionCounts.entrySet().stream()
                .map(entry -> entry.getKey() + "(" + entry.getValue() + ")")
                .collect(Collectors.joining(", "));
        int skillCallCount = actionCounts.values().stream().mapToInt(Integer::intValue).sum();
        context.setVariable("skillCount", skillCallCount + " (" + skillCallListWithCounts + ")");
    }

    /**
     * Renders previous steps as prompt text: the prefix, the first thought, then each step as
     * THOUGHT / ACTION / OBSERVATION lines in chronological order. Steps are added newest-first
     * and the oldest ones are dropped once the estimated size (chars / 4) exceeds 75% of the
     * configured max tokens.
     *
     * @return the scratch pad, or "" when no steps have been taken
     */
    private String createScratchPad(List<SystemStep> stepsTaken) {
        if (stepsTaken.isEmpty()) {
            return "";
        }
        List<String> scratchPadLines = new ArrayList<>();
        scratchPadLines.add(SCRATCH_PAD_PREFIX);
        scratchPadLines.add(THOUGHT + " " + stepsTaken.get(0).getThought());
        // All step lines are inserted at this fixed index, so later insertions land above earlier ones.
        int insertPoint = scratchPadLines.size();
        int charCount = 0;
        for (String line : scratchPadLines) {
            charCount += line.length();
        }
        double tokenBudget = this.config.getMaxTokens() * SCRATCH_PAD_TOKEN_BUDGET;

        for (int i = stepsTaken.size() - 1; i >= 0; --i) {
            if (charCount / CHARS_PER_TOKEN > tokenBudget) {
                LOGGER.debug("Scratchpad is too long, truncating. Skipping {} steps.", i + 1);
                break;
            }
            SystemStep step = stepsTaken.get(i);
            if (!Verify.isNullOrEmpty(step.getObservation())) {
                charCount += insertLine(scratchPadLines, insertPoint, OBSERVATION + " " + step.getObservation());
            }
            if (!Verify.isNullOrEmpty(step.getAction())) {
                charCount += insertLine(scratchPadLines, insertPoint, ACTION + " " + formatActionJson(step));
            }
            // The first step's thought is already at the top of the scratch pad.
            if (i != 0) {
                charCount += insertLine(scratchPadLines, insertPoint, THOUGHT + " " + step.getThought());
            }
        }

        String scratchPad = String.join("\n", scratchPadLines).trim();
        if (!Verify.isNullOrWhiteSpace(scratchPad)) {
            LOGGER.trace("Scratchpad: {}", scratchPad);
        }
        return scratchPad;
    }

    /** Inserts {@code line} at {@code index} and returns its length, for scratch-pad size accounting. */
    private static int insertLine(List<String> lines, int index, String line) {
        lines.add(index, line);
        return line.length();
    }

    /** Renders a step's action as the double-braced JSON shown to the model in the scratch pad. */
    private static String formatActionJson(SystemStep step) {
        try {
            // writeValueAsString on the name produces a properly quoted and escaped JSON string.
            return "{{\"action\": " + OBJECT_MAPPER.writeValueAsString(step.getAction())
                    + ",\"action_variables\": " + OBJECT_MAPPER.writeValueAsString(step.getActionVariables()) + "}}";
        } catch (JsonProcessingException e) {
            throw new PlanningException(PlanningException.ErrorCodes.UNKNOWN_ERROR, "Could not serialize action variables", e);
        }
    }

    /**
     * Invokes the available function whose fully qualified name equals {@code actionName}.
     *
     * @return the function's result; "" when the function produced no result
     */
    private Mono<String> invokeActionAsync(String actionName, Map<String, String> actionVariables) {
        Optional<SKFunction<?>> targetFunction = this.getAvailableFunctions().stream()
                .filter(f -> f.toFullyQualifiedName().equals(actionName))
                .findFirst();
        if (!targetFunction.isPresent()) {
            return Mono.error(new PlanningException(PlanningException.ErrorCodes.UNKNOWN_ERROR, "The function '" + actionName + "' was not found."));
        }
        SKFunction<?> target = targetFunction.get();
        SKFunction<?> function = this.kernel.getFunction(target.getSkillName(), target.getName());
        SKContext actionContext = this.createActionContext(actionVariables);
        return function.invokeAsync(actionContext)
                .mapNotNull(result -> {
                    LOGGER.trace("Invoked {}. Result: {}", target.getName(), result.getResult());
                    return result.getResult();
                })
                // Without this an empty result would silently end the step chain.
                .defaultIfEmpty("");
    }

    /** Creates a fresh kernel context holding the given action variables (if any). */
    private SKContext createActionContext(@Nullable Map<String, String> actionVariables) {
        SKContext actionContext = (SKContext) SKBuilders.context().withKernel(this.kernel).build();
        if (actionVariables != null) {
            actionVariables.forEach(actionContext::setVariable);
        }
        return actionContext;
    }

    /** All kernel functions not excluded by the config, sorted by skill name then function name. */
    private List<SKFunction<?>> getAvailableFunctions() {
        Comparator<SKFunction<?>> bySkillThenName =
                Comparator.comparing((SKFunction<?> fun) -> fun.getSkillName()).thenComparing(fun -> fun.getName());
        return this.context.getSkills().getAllFunctions().getAll().stream()
                .filter(fun -> !this.config.getExcludedSkills().contains(fun.getSkillName())
                        && !this.config.getExcludedFunctions().contains(fun.getName()))
                .sorted(bySkillThenName)
                .collect(Collectors.toList());
    }

    /** Newline-joined manual entries for every available function. */
    private String getFunctionDescriptions() {
        return this.getAvailableFunctions().stream()
                .map(fun -> toManualString(Objects.requireNonNull(fun.describe())))
                .collect(Collectors.joining("\n"));
    }

    /** Registers the step prompt as the "StepwiseStep" semantic function of the restricted skill. */
    private SKFunction<CompletionRequestSettings> importStepwiseFunction(Kernel kernel, String promptTemplate, PromptTemplateConfig config) {
        PromptTemplate template = (PromptTemplate) SKBuilders.promptTemplate()
                .withPromptTemplate(promptTemplate)
                .withPromptTemplateConfig(config)
                .withPromptTemplateEngine(kernel.getPromptTemplateEngine())
                .build();
        SemanticFunctionConfig functionConfig = new SemanticFunctionConfig(config, template);
        return kernel.registerSemanticFunction(RESTRICTED_SKILL_NAME, "StepwiseStep", functionConfig);
    }

    /**
     * Reads the embedded {@code config.json}; an empty file yields a default config.
     *
     * @throws PlanningException if the file is missing or cannot be parsed
     */
    private static PromptTemplateConfig loadDefaultPromptConfig() {
        try {
            String promptConfigString = EmbeddedResourceLoader.readFile("config.json", DefaultStepwisePlanner.class);
            if (Verify.isNullOrEmpty(promptConfigString)) {
                return new PromptTemplateConfig();
            }
            return OBJECT_MAPPER.readValue(promptConfigString, PromptTemplateConfig.class);
        } catch (JsonProcessingException | FileNotFoundException e) {
            throw new PlanningException(PlanningException.ErrorCodes.INVALID_CONFIGURATION, "Could not find or parse config.json", e);
        }
    }

    /**
     * Reads the embedded {@code skprompt.txt}.
     *
     * @throws PlanningException if the file is missing
     */
    private static String loadDefaultPrompt() {
        try {
            return EmbeddedResourceLoader.readFile("skprompt.txt", DefaultStepwisePlanner.class);
        } catch (FileNotFoundException e) {
            throw new PlanningException(PlanningException.ErrorCodes.INVALID_CONFIGURATION, "Could not find skprompt.txt", e);
        }
    }

    /**
     * Formats a function as "Skill.Name: description" followed by one "  - param: description"
     * line per parameter (with the default value when one was provided).
     */
    private static String toManualString(FunctionView function) {
        String inputs = function.getParameters().stream().map(parameter -> {
            String defaultValueString = "";
            if (!Verify.isNullOrEmpty(parameter.getDefaultValue()) && !parameter.getDefaultValue().equals("SKFunctionParameters__NO_INPUT_PROVIDED")) {
                defaultValueString = "(default='" + parameter.getDefaultValue() + "')";
            }
            return "  - " + parameter.getName() + ": " + parameter.getDescription() + " " + defaultValueString;
        }).collect(Collectors.joining("\n"));
        String header = toFullyQualifiedName(function) + ": " + function.getDescription().trim() + "\n";
        return Verify.isNullOrEmpty(inputs) ? header : header + inputs + "\n";
    }

    /** "SkillName.FunctionName". */
    private static String toFullyQualifiedName(FunctionView function) {
        return function.getSkillName() + "." + function.getName();
    }
}

