/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.vertexai;

import com.google.cloud.aiplatform.util.ValueConverter;
import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictResponse;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Message;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Json;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.model.vertexai.VertexAiEmbeddingInstance;
import dev.langchain4j.model.vertexai.spi.VertexAiEmbeddingModelBuilderFactory;
import dev.langchain4j.spi.ServiceHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class VertexAiEmbeddingModel
implements EmbeddingModel {
    private static final int BATCH_SIZE = 250;
    private final PredictionServiceSettings settings;
    private final EndpointName endpointName;
    private final Integer maxRetries;

    public VertexAiEmbeddingModel(String endpoint, String project, String location, String publisher, String modelName, Integer maxRetries) {
        try {
            this.settings = ((PredictionServiceSettings.Builder)PredictionServiceSettings.newBuilder().setEndpoint(ValidationUtils.ensureNotBlank((String)endpoint, (String)"endpoint"))).build();
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        this.endpointName = EndpointName.ofProjectLocationPublisherModelName((String)ValidationUtils.ensureNotBlank((String)project, (String)"project"), (String)ValidationUtils.ensureNotBlank((String)location, (String)"location"), (String)ValidationUtils.ensureNotBlank((String)publisher, (String)"publisher"), (String)ValidationUtils.ensureNotBlank((String)modelName, (String)"modelName"));
        this.maxRetries = (Integer)Utils.getOrDefault((Object)maxRetries, (Object)3);
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    public Response<List<Embedding>> embedAll(List<TextSegment> segments) {
        try (PredictionServiceClient client = PredictionServiceClient.create((PredictionServiceSettings)this.settings);){
            ArrayList embeddings = new ArrayList();
            int inputTokenCount = 0;
            for (int i = 0; i < segments.size(); i += 250) {
                List<TextSegment> batch = segments.subList(i, Math.min(i + 250, segments.size()));
                ArrayList<Value> instances = new ArrayList<Value>();
                for (TextSegment segment : batch) {
                    Value.Builder instanceBuilder = Value.newBuilder();
                    JsonFormat.parser().merge(Json.toJson((Object)new VertexAiEmbeddingInstance(segment.text())), (Message.Builder)instanceBuilder);
                    instances.add(instanceBuilder.build());
                }
                PredictResponse response = (PredictResponse)RetryUtils.withRetry(() -> client.predict(this.endpointName, instances, ValueConverter.EMPTY_VALUE), (int)this.maxRetries);
                embeddings.addAll(response.getPredictionsList().stream().map(VertexAiEmbeddingModel::toEmbedding).collect(Collectors.toList()));
                for (Value prediction : response.getPredictionsList()) {
                    inputTokenCount += VertexAiEmbeddingModel.extractTokenCount(prediction);
                }
            }
            Response response = Response.from(embeddings, (TokenUsage)new TokenUsage(Integer.valueOf(inputTokenCount)));
            return response;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    private static Embedding toEmbedding(Value prediction) {
        List vector = ((Value)prediction.getStructValue().getFieldsMap().get("embeddings")).getStructValue().getFieldsOrThrow("values").getListValue().getValuesList().stream().map(v -> Float.valueOf((float)v.getNumberValue())).collect(Collectors.toList());
        return Embedding.from(vector);
    }

    private static int extractTokenCount(Value prediction) {
        return (int)((Value)((Value)((Value)prediction.getStructValue().getFieldsMap().get("embeddings")).getStructValue().getFieldsMap().get("statistics")).getStructValue().getFieldsMap().get("token_count")).getNumberValue();
    }

    public static Builder builder() {
        return (Builder)ServiceHelper.loadFactoryService(VertexAiEmbeddingModelBuilderFactory.class, Builder::new);
    }

    public static class Builder {
        private String endpoint;
        private String project;
        private String location;
        private String publisher;
        private String modelName;
        private Integer maxRetries;

        public Builder endpoint(String endpoint) {
            this.endpoint = endpoint;
            return this;
        }

        public Builder project(String project) {
            this.project = project;
            return this;
        }

        public Builder location(String location) {
            this.location = location;
            return this;
        }

        public Builder publisher(String publisher) {
            this.publisher = publisher;
            return this;
        }

        public Builder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        public Builder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        public VertexAiEmbeddingModel build() {
            return new VertexAiEmbeddingModel(this.endpoint, this.project, this.location, this.publisher, this.modelName, this.maxRetries);
        }
    }
}

