/*
 * Decompiled with CFR 0.152.
 */
package io.debezium.ai.embeddings;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import io.debezium.DebeziumException;
import io.debezium.Module;
import io.debezium.ai.embeddings.EmbeddingsModelFactory;
import io.debezium.config.Configuration;
import io.debezium.config.Field;
import io.debezium.data.vector.FloatVector;
import io.debezium.transforms.ConnectRecordUtil;
import io.debezium.transforms.SmtManager;
import io.debezium.util.BoundedConcurrentHashMap;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.ServiceLoader;
import org.apache.kafka.common.config.ConfigDef;
import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.connect.components.Versioned;
import org.apache.kafka.connect.connector.ConnectRecord;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.Struct;
import org.apache.kafka.connect.transforms.Transformation;
import org.apache.kafka.connect.transforms.util.Requirements;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Kafka Connect transformation that computes a text embedding for a field of a Debezium change-event
 * envelope and adds it to the record.
 * <p>
 * The input text is read from the field configured by {@code field.source}; the value may be a
 * dot-separated path to a nested field. The embedding is produced by the {@link EmbeddingModel} obtained
 * from the {@link EmbeddingsModelFactory} implementation discovered through {@link ServiceLoader}.
 * If {@code field.embedding} is set, the vector is appended to the record value under that (possibly
 * nested) name; otherwise the record value is replaced by the embedding vector itself.
 * <p>
 * NOTE(review): {@code MODEL_FACTORY} is static and therefore shared by every instance of this
 * transformation, and {@link #configure(Map)} reconfigures it before obtaining the model. Confirm this is
 * safe when several instances with different model settings are configured concurrently.
 *
 * @param <R> the type of record being transformed
 */
public class FieldToEmbedding<R extends ConnectRecord<R>> implements Transformation<R>, Versioned {
    private static final Logger LOGGER = LoggerFactory.getLogger(FieldToEmbedding.class);

    /** Prefix of the deprecated option names accepted as aliases (e.g. {@code embeddings.field.source}). */
    public static final String LEGACY_EMBEDDINGS_PREFIX = "embeddings.";

    private static final Field TEXT_FIELD = Field.create("field.source")
            .withDisplayName("Name of the record field from which embeddings should be created.")
            .withType(ConfigDef.Type.STRING)
            .withWidth(ConfigDef.Width.SHORT)
            .withImportance(ConfigDef.Importance.HIGH)
            .required()
            .withDescription("Name of the field of the record whose content will be used as an input for embeddings. Supports also nested fields.")
            .withDeprecatedAliases("embeddings.field.source");

    private static final Field EMBEDDINGS_FIELD = Field.create("field.embedding")
            .withDisplayName("Name of the field which would contain the embeddings of the input field.")
            .withType(ConfigDef.Type.STRING)
            .withWidth(ConfigDef.Width.SHORT)
            .withImportance(ConfigDef.Importance.HIGH)
            .withDescription("Name of the field which will be appended to the record and which would contain the embeddings of the content of the `field.source` field. Supports also nested fields.")
            .withDeprecatedAliases("embeddings.field.embedding");

    private static final Schema EMBEDDING_SCHEMA = FloatVector.schema();

    // Resolved once at class-initialization time; fails class loading if no implementation is present.
    private static final EmbeddingsModelFactory MODEL_FACTORY = EmbeddingsModelFactoryLoader.getModelFactory();

    /** The SMT's own options plus every option declared by the discovered model factory. */
    public static final Field.Set ALL_FIELDS = Field.setOf(TEXT_FIELD, EMBEDDINGS_FIELD).with(MODEL_FACTORY.getConfigFields());

    private static final String NESTING_SPLIT_REG_EXP = "\\.";
    private static final int CACHE_SIZE = 64;

    private SmtManager<R> smtManager;
    private String sourceField;
    private String embeddingsField;
    private List<String> sourceFieldPath;
    private EmbeddingModel model;

    // Maps an incoming value schema to the schema extended with the embedding field, so the derived
    // schema is built once per distinct input schema rather than once per record.
    private final BoundedConcurrentHashMap<Schema, Schema> schemaUpdateCache = new BoundedConcurrentHashMap<>(CACHE_SIZE);

    /**
     * Reads and validates the configuration, configures the shared model factory and obtains the
     * embedding model used by this instance.
     *
     * @param configs the raw transformation configuration
     */
    @Override
    public void configure(Map<String, ?> configs) {
        Configuration config = Configuration.from(configs);
        this.smtManager = new SmtManager<>(config);
        this.smtManager.validate(config, ALL_FIELDS);
        this.sourceField = config.getString(TEXT_FIELD);
        this.embeddingsField = config.getString(EMBEDDINGS_FIELD);
        MODEL_FACTORY.configure(config);
        this.validateConfiguration();
        this.sourceFieldPath = Arrays.asList(this.sourceField.split(NESTING_SPLIT_REG_EXP));
        this.model = MODEL_FACTORY.getModel();
    }

    /**
     * Adds the embedding of the configured source field to the record.
     *
     * @param record the record to transform
     * @return the updated record, or {@code record} unchanged when it has no value, is not a valid
     *         envelope, or the source text could not be found
     */
    @Override
    public R apply(R record) {
        if (record.value() == null || !this.smtManager.isValidEnvelope(record)) {
            LOGGER.trace("Record {} has null value or invalid envelope and will be skipped.", record.value());
            return record;
        }
        String text = this.getSourceString(record);
        return text == null ? record : this.buildUpdatedRecord(record, text);
    }

    /**
     * Declares this SMT's own options. Model-specific options from {@code MODEL_FACTORY} are not part of
     * this {@link ConfigDef}; they are validated in {@link #configure(Map)} through {@code ALL_FIELDS}.
     */
    @Override
    public ConfigDef config() {
        ConfigDef config = new ConfigDef();
        Field.group(config, null, TEXT_FIELD, EMBEDDINGS_FIELD);
        return config;
    }

    /** No resources are held by this transformation itself. */
    @Override
    public void close() {
    }

    @Override
    public String version() {
        return Module.version();
    }

    /**
     * Validates the SMT options and delegates to the model factory's own validation.
     *
     * @throws ConfigException if {@code field.source} is null or blank
     */
    protected void validateConfiguration() {
        if (this.sourceField == null || this.sourceField.isBlank()) {
            throw new ConfigException(String.format("'%s' must be set to non-empty value.", TEXT_FIELD));
        }
        MODEL_FACTORY.validateConfiguration();
    }

    /**
     * Resolves the configured (possibly nested) source field and returns its string content.
     * <p>
     * A missing field name or a non-string leaf is reported by Kafka's {@code Struct} accessors, which
     * throw their own exception.
     *
     * @param record the record to read from
     * @return the text to embed, or {@code null} when the record is not a valid struct envelope or an
     *         intermediate struct on the path is {@code null}
     * @throws IllegalArgumentException if an intermediate path segment exists but is not a struct
     */
    protected String getSourceString(R record) {
        if (record.value() == null
                || !this.smtManager.isValidEnvelope(record)
                || record.valueSchema().type() != Schema.Type.STRUCT) {
            LOGGER.debug("Skipping record {}, it has either null value or invalid structure", record);
            return null;
        }

        Struct current = Requirements.requireStruct(record.value(), "Obtaining source field for embeddings");
        int lastIndex = this.sourceFieldPath.size() - 1;
        for (int i = 0; i < lastIndex; ++i) {
            String segment = this.sourceFieldPath.get(i);
            // Check the declared type of the segment itself: the Struct we are holding is always a
            // STRUCT, so the segment's schema is what tells us whether we can descend further.
            var segmentField = current.schema().field(segment);
            if (segmentField != null && segmentField.schema().type() != Schema.Type.STRUCT) {
                throw new IllegalArgumentException(String.format("Invalid field name %s, %s is not struct.", this.sourceField, segment));
            }
            current = current.getStruct(segment);
            if (current == null) {
                LOGGER.debug("Skipping record {}, the structure is not present", record);
                return null;
            }
        }
        return current.getString(this.sourceFieldPath.getLast());
    }

    /**
     * Computes the embedding of {@code text} and builds the transformed record.
     *
     * @param original the record being transformed; its value must be a {@link Struct}
     * @param text     the text to embed
     * @return a new record whose value is either the embedding vector alone (no {@code field.embedding}
     *         configured) or the original value extended with the embedding field
     */
    protected R buildUpdatedRecord(R original, String text) {
        Struct value = Requirements.requireStruct(original.value(), "Original value must be struct");
        Embedding embedding = this.model.embed(TextSegment.from(text)).content();

        Schema updatedSchema;
        Object updatedValue;
        if (this.embeddingsField == null) {
            updatedSchema = EMBEDDING_SCHEMA;
            updatedValue = embedding.vectorAsList();
        }
        else {
            List<ConnectRecordUtil.NewEntry> newEntries = List.of(
                    new ConnectRecordUtil.NewEntry(this.embeddingsField, EMBEDDING_SCHEMA, embedding.vectorAsList()));
            updatedSchema = this.schemaUpdateCache.computeIfAbsent(value.schema(),
                    valueSchema -> ConnectRecordUtil.makeNewSchema(valueSchema, newEntries));
            updatedValue = ConnectRecordUtil.makeUpdatedValue(value, newEntries, updatedSchema);
        }
        return original.newRecord(original.topic(), original.kafkaPartition(), original.keySchema(), original.key(),
                updatedSchema, updatedValue, original.timestamp(), original.headers());
    }

    /**
     * Discovers the {@link EmbeddingsModelFactory} implementation through {@link ServiceLoader}.
     * <p>
     * The type parameter is unused; it is kept only so existing references to this public class keep
     * compiling.
     *
     * @param <R> unused
     */
    public static class EmbeddingsModelFactoryLoader<R extends ConnectRecord<R>> {
        private static final Logger LOGGER = LoggerFactory.getLogger(EmbeddingsModelFactoryLoader.class);

        /**
         * Returns an instance of the first discovered factory, warning when several are available.
         *
         * @throws DebeziumException if no implementation is found
         */
        static EmbeddingsModelFactory getModelFactory() {
            List<ServiceLoader.Provider<EmbeddingsModelFactory>> providers = ServiceLoader.load(EmbeddingsModelFactory.class)
                    .stream()
                    .toList();
            if (providers.isEmpty()) {
                throw new DebeziumException("No implementation of Debezium embeddings model factory found.");
            }
            if (providers.size() > 1) {
                LOGGER.warn("More than one Debezium embeddings model factory found. Order of loading is not defined and you may load different factory than you intended.");
                LOGGER.warn("Found following factories:");
                // Provider.type() names the class without instantiating every factory just to log it.
                providers.forEach(provider -> LOGGER.warn(provider.type().getName()));
            }
            return providers.get(0).get();
        }
    }
}

