/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.hook.modelcalllimit;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.agent.hook.HookPosition;
import com.alibaba.cloud.ai.graph.agent.hook.HookPositions;
import com.alibaba.cloud.ai.graph.agent.hook.JumpTo;
import com.alibaba.cloud.ai.graph.agent.hook.ModelHook;
import com.alibaba.cloud.ai.graph.agent.hook.modelcalllimit.ModelCallLimitExceededException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import org.springframework.ai.chat.messages.AssistantMessage;

/**
 * Model hook that stops an agent once a configured number of model calls has been reached.
 *
 * <p>Two integer counters are kept in {@code RunnableConfig.context()} under private keys.
 * {@link #afterModel} adds one to both counters; {@link #beforeModel} compares each counter
 * with its limit and, when a limit has been reached, reacts according to the configured
 * {@link ExitBehavior}. A limit that is {@code null} is not checked.
 *
 * <p>Nothing in this class resets either counter; any distinction between "thread" and "run"
 * scope therefore depends on how the caller manages the context map.
 *
 * <p>Instances are created through {@link #builder()}.
 */
@HookPositions(value={HookPosition.BEFORE_MODEL, HookPosition.AFTER_MODEL})
public class ModelCallLimitHook extends ModelHook {
    /** Context key under which the thread-level call counter is stored. */
    private static final String THREAD_COUNT_KEY = "__model_call_limit_thread_count__";
    /** Context key under which the run-level call counter is stored. */
    private static final String RUN_COUNT_KEY = "__model_call_limit_run_count__";

    /** Maximum thread-level calls, or {@code null} when this limit is disabled. */
    private final Integer threadLimit;
    /** Maximum run-level calls, or {@code null} when this limit is disabled. */
    private final Integer runLimit;
    /** What to do when a limit is reached; never {@code null} (enforced by {@link Builder#build()}). */
    private final ExitBehavior exitBehavior;

    private ModelCallLimitHook(Builder builder) {
        this.threadLimit = builder.threadLimit;
        this.runLimit = builder.runLimit;
        this.exitBehavior = builder.exitBehavior;
    }

    /**
     * Creates a builder for this hook.
     *
     * @return a new builder whose exit behavior defaults to {@link ExitBehavior#END}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Checks the stored counters against the configured limits before the model is called.
     *
     * @param state  the current graph state (not read by this hook)
     * @param config the runnable configuration whose context map holds the counters
     * @return a completed future with an empty map when no limit has been reached; with
     *         {@link ExitBehavior#END}, a completed future whose map carries a {@code "messages"}
     *         list containing one assistant message and {@code "jump_to"} set to
     *         {@link JumpTo#end}
     * @throws ModelCallLimitExceededException if a limit has been reached and the exit behavior
     *         is {@link ExitBehavior#ERROR}
     */
    @Override
    public CompletableFuture<Map<String, Object>> beforeModel(OverAllState state, RunnableConfig config) {
        int threadModelCallCount = readCount(config, THREAD_COUNT_KEY);
        int runModelCallCount = readCount(config, RUN_COUNT_KEY);

        boolean threadLimitExceeded = isLimitReached(threadModelCallCount, this.threadLimit);
        boolean runLimitExceeded = isLimitReached(runModelCallCount, this.runLimit);
        if (!threadLimitExceeded && !runLimitExceeded) {
            return CompletableFuture.completedFuture(Map.of());
        }

        if (this.exitBehavior == ExitBehavior.ERROR) {
            throw new ModelCallLimitExceededException(threadModelCallCount, runModelCallCount, this.threadLimit, this.runLimit);
        }

        // ExitBehavior.END: report the exceeded limit as an assistant message and jump to the end.
        String message = buildLimitExceededMessage(threadModelCallCount, runModelCallCount, this.threadLimit, this.runLimit);
        List<AssistantMessage> messages = new ArrayList<>();
        messages.add(new AssistantMessage(message));
        Map<String, Object> updates = new HashMap<>();
        updates.put("messages", messages);
        updates.put("jump_to", JumpTo.end);
        return CompletableFuture.completedFuture(updates);
    }

    /**
     * Records one completed model call by adding one to both counters in the context map.
     *
     * @param state  the current graph state (not read by this hook)
     * @param config the runnable configuration whose context map holds the counters
     * @return a completed future with an empty map; this hook contributes no state updates here
     */
    @Override
    public CompletableFuture<Map<String, Object>> afterModel(OverAllState state, RunnableConfig config) {
        int threadModelCallCount = readCount(config, THREAD_COUNT_KEY);
        int runModelCallCount = readCount(config, RUN_COUNT_KEY);
        config.context().put(THREAD_COUNT_KEY, threadModelCallCount + 1);
        config.context().put(RUN_COUNT_KEY, runModelCallCount + 1);
        return CompletableFuture.completedFuture(Map.of());
    }

    /**
     * Reads a counter from the context map with a single lookup.
     *
     * <p>A missing key and a key mapped to {@code null} both count as zero, so a {@code null}
     * entry cannot trigger a {@link NullPointerException} on unboxing. A value of any other
     * type still fails with {@link ClassCastException}.
     */
    private static int readCount(RunnableConfig config, String key) {
        Object stored = config.context().get(key);
        return stored == null ? 0 : (Integer) stored;
    }

    /** Returns whether {@code limit} is set and {@code count} has reached it. */
    private static boolean isLimitReached(int count, Integer limit) {
        return limit != null && count >= limit;
    }

    /** Builds the user-facing message listing every limit that has been reached. */
    private String buildLimitExceededMessage(int threadCount, int runCount, Integer threadLimit, Integer runLimit) {
        List<String> exceededLimits = new ArrayList<>();
        if (isLimitReached(threadCount, threadLimit)) {
            exceededLimits.add(String.format("thread limit (%d/%d)", threadCount, threadLimit));
        }
        if (isLimitReached(runCount, runLimit)) {
            exceededLimits.add(String.format("run limit (%d/%d)", runCount, runLimit));
        }
        return "Model call limits exceeded: " + String.join(", ", exceededLimits);
    }

    /**
     * Returns the name of this hook.
     *
     * @return the constant {@code "ModelCallLimit"}
     */
    @Override
    public String getName() {
        return "ModelCallLimit";
    }

    /**
     * Lists the jump targets this hook may request.
     *
     * @return a list containing {@link JumpTo#end} when the exit behavior is
     *         {@link ExitBehavior#END}; otherwise an empty list
     */
    @Override
    public List<JumpTo> canJumpTo() {
        if (this.exitBehavior == ExitBehavior.END) {
            return List.of(JumpTo.end);
        }
        return List.of();
    }

    /** Builder for {@link ModelCallLimitHook}. At least one limit must be set before building. */
    public static class Builder {
        private Integer threadLimit;
        private Integer runLimit;
        private ExitBehavior exitBehavior = ExitBehavior.END;

        /**
         * Sets the thread-level call limit.
         *
         * @param threadLimit the limit, or {@code null} to leave this limit disabled
         * @return this builder
         */
        public Builder threadLimit(Integer threadLimit) {
            this.threadLimit = threadLimit;
            return this;
        }

        /**
         * Sets the run-level call limit.
         *
         * @param runLimit the limit, or {@code null} to leave this limit disabled
         * @return this builder
         */
        public Builder runLimit(Integer runLimit) {
            this.runLimit = runLimit;
            return this;
        }

        /**
         * Sets what happens when a limit is reached.
         *
         * @param exitBehavior the behavior to apply; must not be {@code null} by the time
         *                     {@link #build()} is called
         * @return this builder
         */
        public Builder exitBehavior(ExitBehavior exitBehavior) {
            this.exitBehavior = exitBehavior;
            return this;
        }

        /**
         * Validates the configuration and creates the hook.
         *
         * @return the configured hook
         * @throws IllegalArgumentException if neither limit is set, or if the exit behavior is
         *         {@code null}
         */
        public ModelCallLimitHook build() {
            if (this.threadLimit == null && this.runLimit == null) {
                throw new IllegalArgumentException("At least one limit must be specified (threadLimit or runLimit)");
            }
            // A null behavior would match neither branch in beforeModel, silently disabling the limit.
            if (this.exitBehavior == null) {
                throw new IllegalArgumentException("exitBehavior must not be null");
            }
            return new ModelCallLimitHook(this);
        }
    }

    /** Reaction applied by {@link ModelCallLimitHook#beforeModel} when a limit has been reached. */
    public static enum ExitBehavior {
        /** Return an assistant message describing the exceeded limit and request a jump to the end. */
        END,
        /** Throw a {@link ModelCallLimitExceededException}. */
        ERROR;

    }
}

