/*
 * Decompiled with CFR 0.152.
 */
package it.auties.protobuf.decoder;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import it.auties.protobuf.decoder.ArrayInputStream;
import it.auties.protobuf.decoder.DeserializationException;
import it.auties.protobuf.decoder.ProtobufField;
import it.auties.protobuf.decoder.ProtobufType;
import it.auties.protobuf.decoder.ProtobufTypeDescriptor;
import it.auties.protobuf.util.ProtobufUtils;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.NonNull;

/**
 * Decodes protobuf wire data into a map keyed by field number and, from there, into an
 * instance of a model class through Jackson's {@code ObjectMapper.convertValue}.
 *
 * <p>The schema used to interpret length-delimited fields is discovered reflectively: either
 * from a {@code descriptor} method (when {@link ProtobufTypeDescriptor#hasDescriptor} reports
 * one) or from fields annotated with {@link JsonProperty}. Discovered schemas are cached per
 * class in a static cache shared by every decoder.
 *
 * <p>A decoder keeps the stack of nested message classes in an instance field while decoding,
 * so a single instance must not be used by several threads at the same time.
 *
 * @param <T> the type produced by {@link #decode(byte[])}
 */
public class ProtobufDecoder<T> {
    private static final Logger log = Logger.getLogger(ProtobufDecoder.class.getName());
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper()
            .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
            .registerModule(new Jdk8Module());

    // Shared by every decoder instance, hence a concurrent map: decoders living on different
    // threads still populate the same cache.
    private static final Map<Class<?>, List<ProtobufField>> cache = new ConcurrentHashMap<>();

    @NonNull
    private final Class<? extends T> modelClass;
    private boolean warnUnknownFields;

    // Stack of the message classes currently being decoded; the head is the innermost one.
    private final LinkedList<Class<?>> classes = new LinkedList<>();

    /**
     * Decodes {@code input} and converts the resulting field map into the model class.
     *
     * @param input the encoded message
     * @return the decoded model instance
     * @throws IOException if the input cannot be parsed or the decoded map cannot be converted
     *                     into the model class
     */
    public T decode(byte[] input) throws IOException {
        Map<Integer, Object> map = this.decodeAsMap(input);
        try {
            return OBJECT_MAPPER.convertValue(map, this.modelClass);
        } catch (RuntimeException exception) {
            // Only conversion failures are translated; JVM errors are left to propagate untouched.
            log.warning("Map value -> %s".formatted(map));
            throw new IOException("An exception occurred while decoding a message", exception);
        }
    }

    /**
     * Decodes {@code input} into a map from field number to decoded value. A field number that
     * occurs more than once is mapped to a list holding all of its values.
     *
     * @param input the encoded message
     * @return the decoded fields keyed by field number
     * @throws IOException if the input cannot be parsed
     */
    public Map<Integer, Object> decodeAsMap(byte[] input) throws IOException {
        return this.decode(new ArrayInputStream(input));
    }

    /**
     * Decodes {@code input} and renders the resulting field map as pretty-printed JSON.
     *
     * @param input the encoded message
     * @return the JSON representation of {@link #decodeAsMap(byte[])}
     * @throws IOException if the input cannot be parsed or serialized
     */
    public String decodeAsJson(byte[] input) throws IOException {
        Map<Integer, Object> decoded = this.decodeAsMap(input);
        return OBJECT_MAPPER.writerWithDefaultPrettyPrinter().writeValueAsString(decoded);
    }

    /**
     * Reads fields until a zero tag is read or a field yields no value, then merges the
     * collected entries by field number.
     */
    private Map<Integer, Object> decode(ArrayInputStream input) throws IOException {
        List<Map.Entry<Integer, Object>> results = new ArrayList<>();
        int tag;
        while ((tag = input.readTag()) != 0) {
            Optional<Map.Entry<Integer, Object>> current = this.parseField(input, tag);
            if (current.isEmpty()) {
                // A field without content (an end-group marker) terminates this message.
                break;
            }
            results.add(current.get());
        }
        input.checkLastTagWas(0);
        return results.stream()
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, this::handleDuplicatedFields));
    }

    /**
     * Merge function for repeated field numbers: flattens both values into one unmodifiable
     * list, expanding values that are already collections.
     */
    private <F, S> List<?> handleDuplicatedFields(F first, S second) {
        return Stream.of(first, second)
                .<Object>flatMap(value -> value instanceof Collection<?> values ? values.stream() : Stream.of(value))
                .toList();
    }

    /**
     * Reads the field introduced by {@code tag}.
     *
     * @return the field number paired with its value, or empty when the field has no content
     * @throws IOException if the tag carries field number zero or the content cannot be read
     */
    private Optional<Map.Entry<Integer, Object>> parseField(ArrayInputStream input, int tag) throws IOException {
        int number = tag >>> 3;
        if (number == 0) {
            throw DeserializationException.invalidTag();
        }
        Object content = this.readFieldContent(input, tag, number);
        if (content == null) {
            return Optional.empty();
        }
        return Optional.of(Map.entry(number, content));
    }

    /**
     * Dispatches on the wire type stored in the three low bits of {@code tag}.
     */
    private Object readFieldContent(ArrayInputStream input, int tag, int number) throws IOException {
        int type = tag & 7;
        return switch (type) {
            case 0 -> input.readInt64();
            case 1 -> input.readFixed64();
            case 2 -> this.readDelimited(input, number);
            case 3 -> this.readGroup(input);
            case 4 -> this.endGroup();
            case 5 -> input.readFixed32();
            default -> throw new DeserializationException("Protocol message(%s) had invalid wire type(%s)".formatted(number, type));
        };
    }

    /**
     * Handles an end-group tag. Returning {@code null} makes the enclosing decode loop stop.
     *
     * <p>The schema stack is deliberately left untouched: entries are pushed and popped only by
     * {@link #readDelimited(Class, byte[])}, so popping here would discard the schema of an
     * enclosing message.
     */
    private Object endGroup() {
        return null;
    }

    /**
     * Reads a group as a nested message.
     */
    private Object readGroup(ArrayInputStream input) throws IOException {
        // NOTE(review): the payload is read as a length-prefixed byte array; the protobuf
        // encoding delimits groups with start/end tags instead - confirm against producers.
        byte[] payload = input.readBytes();
        return this.decode(new ArrayInputStream(payload));
    }

    /**
     * Reads a length-delimited field and converts it according to the schema entry registered
     * for {@code fieldNumber}, falling back to raw bytes when the schema has no such entry.
     */
    private Object readDelimited(ArrayInputStream input, int fieldNumber) throws IOException {
        byte[] read = input.readBytes();
        ProtobufField protobufField = this.getFields().stream()
                .filter(field -> field.index() == fieldNumber)
                .findFirst()
                .orElseGet(() -> this.getFallbackType(fieldNumber));
        return this.convertValueToObject(read, protobufField);
    }

    /**
     * Builds the schema entry used for fields missing from the schema: unpacked raw bytes.
     * Logs a warning first when {@link #warnUnknownFields(boolean)} was enabled.
     */
    private ProtobufField getFallbackType(int fieldNumber) {
        if (this.warnUnknownFields) {
            log.warning("Falling back to BYTES for %s in schema %s".formatted(fieldNumber, this.currentSchema()));
        }
        return new ProtobufField(fieldNumber, byte[].class, false);
    }

    /**
     * Converts a length-delimited payload into packed integers, raw bytes, a UTF-8 string or a
     * nested message, depending on the schema entry.
     */
    private Object convertValueToObject(byte[] read, ProtobufField value) throws IOException {
        if (value.packed()) {
            return this.readPacked(read);
        }
        Class<?> type = value.type();
        if (byte[].class.isAssignableFrom(type)) {
            return read;
        }
        if (String.class.isAssignableFrom(type)) {
            return new String(read, StandardCharsets.UTF_8);
        }
        return this.readDelimited(type, read);
    }

    /**
     * Reads a packed sequence of 32-bit varints.
     */
    private List<Integer> readPacked(byte[] read) throws IOException {
        ArrayInputStream stream = new ArrayInputStream(read);
        // NOTE(review): the first varint is taken as a byte length and every element is assumed
        // to account for four bytes of it; confirm this matches the producers' packed encoding.
        int length = stream.readRawVarint32();
        List<Integer> results = new ArrayList<>();
        while (results.size() * 4 != length) {
            results.add(stream.readRawVarint32());
        }
        return results;
    }

    /**
     * Decodes {@code read} as a nested message whose schema is {@code currentClass}. When that
     * fails with an {@link IOException} the payload is returned as a UTF-8 string instead
     * (best-effort fallback kept from the original behaviour).
     */
    private Object readDelimited(Class<?> currentClass, byte[] read) {
        this.classes.push(currentClass);
        try {
            return this.decode(new ArrayInputStream(read));
        } catch (IOException ignored) {
            // The payload is not a well-formed message: treat it as text.
            return new String(read, StandardCharsets.UTF_8);
        } finally {
            this.classes.poll();
        }
    }

    /**
     * Returns the class whose schema applies right now: the innermost nested message class, or
     * the model class when no nested message is being decoded.
     */
    private Class<?> currentSchema() {
        Class<?> current = this.classes.peekFirst();
        return current != null ? current : this.modelClass;
    }

    /**
     * Returns the schema of the message currently being decoded.
     */
    private List<ProtobufField> getFields() {
        return this.getFields(this.currentSchema());
    }

    /**
     * Returns the schema of {@code clazz}, including the entries of its superclasses, computing
     * and caching it on first use.
     *
     * @param clazz the class to inspect, may be {@code null}
     * @return the schema entries; empty when {@code clazz} is {@code null}
     */
    private List<ProtobufField> getFields(Class<?> clazz) {
        if (clazz == null) {
            return List.of();
        }
        List<ProtobufField> cached = cache.get(clazz);
        if (cached != null) {
            return cached;
        }
        List<ProtobufField> results = new ArrayList<>();
        if (ProtobufTypeDescriptor.hasDescriptor(clazz)) {
            this.invokeDescriptorMethod(clazz)
                    .forEach((index, type) -> results.add(new ProtobufField(index, type, false)));
        } else {
            // getFields() and getDeclaredFields() overlap on public declared fields.
            Stream.of(clazz.getFields(), clazz.getDeclaredFields())
                    .flatMap(Arrays::stream)
                    .distinct()
                    .filter(this::isProperty)
                    .map(field -> new ProtobufField(ProtobufUtils.parseIndex(field), this.getPropertyType(field), this.isPacked(field)))
                    .forEach(results::add);
        }
        results.addAll(this.getFields(clazz.getSuperclass()));
        cache.put(clazz, results);
        return results;
    }

    /**
     * Tells whether {@code field} carries a {@link JsonProperty} annotation.
     */
    private boolean isProperty(Field field) {
        return field.getAnnotation(JsonProperty.class) != null;
    }

    /**
     * Instantiates {@code clazz} through its public no-argument constructor and invokes its
     * {@code descriptor} method.
     *
     * @throws IllegalArgumentException if the instance cannot be created or the method cannot
     *                                  be invoked; the underlying failure is kept as the cause
     */
    private Map<Integer, Class<?>> invokeDescriptorMethod(Class<?> clazz) {
        Method descriptorMethod = this.getDescriptorMethod(clazz, true);
        try {
            // The method may come from the declared-method fallback and thus be non-public.
            descriptorMethod.trySetAccessible();
            Object instance = clazz.getConstructor().newInstance();
            @SuppressWarnings("unchecked")
            Map<Integer, Class<?>> descriptor = (Map<Integer, Class<?>>) descriptorMethod.invoke(instance);
            return descriptor;
        } catch (ReflectiveOperationException | RuntimeException exception) {
            throw new IllegalArgumentException("Cannot use descriptor to infer type inside class %s: cannot invoke descriptor method using an instance".formatted(clazz.getName()), exception);
        }
    }

    /**
     * Looks up the no-argument {@code descriptor} method of {@code clazz}.
     *
     * @param accessible {@code true} to search public methods first and fall back to declared
     *                   methods, {@code false} to search declared methods only
     * @throws IllegalArgumentException if neither lookup finds the method
     */
    private Method getDescriptorMethod(Class<?> clazz, boolean accessible) {
        try {
            return accessible ? clazz.getMethod("descriptor") : clazz.getDeclaredMethod("descriptor");
        } catch (NoSuchMethodException exception) {
            if (accessible) {
                return this.getDescriptorMethod(clazz, false);
            }
            throw new IllegalArgumentException("Cannot use descriptor to infer type inside class %s: missing descriptor method".formatted(clazz.getName()), exception);
        }
    }

    /**
     * Resolves the schema type of {@code field}: the inferred type, unless that type is itself
     * annotated with {@link ProtobufType}, in which case the annotation value wins.
     */
    private Class<?> getPropertyType(Field field) {
        Class<?> inferredType = this.inferPropertyType(field);
        ProtobufType annotation = inferredType.getAnnotation(ProtobufType.class);
        return annotation != null ? annotation.value() : inferredType;
    }

    /**
     * Infers the type of {@code field}: an explicit {@link ProtobufType} on the field, the
     * field type itself for non-collections, or the first type argument for collections.
     */
    private Class<?> inferPropertyType(Field field) {
        ProtobufType explicitType = field.getAnnotation(ProtobufType.class);
        if (explicitType != null) {
            return explicitType.value();
        }
        Class<?> declaredType = field.getType();
        if (!Collection.class.isAssignableFrom(declaredType)) {
            return declaredType;
        }
        if (field.getGenericType() instanceof ParameterizedType parameterizedType) {
            // The cast assumes the type argument is a plain class.
            return (Class<?>) parameterizedType.getActualTypeArguments()[0];
        }
        return this.inferPropertyType(declaredType.getGenericSuperclass());
    }

    /**
     * Walks up the superclass chain until a parameterized superclass is found and returns its
     * first type argument.
     *
     * @throws NullPointerException if the chain ends without a parameterized superclass
     */
    private Class<?> inferPropertyType(Type superClass) {
        Objects.requireNonNull(superClass, "Serialization issue: cannot deduce generic type of field through class hierarchy");
        if (superClass instanceof ParameterizedType parameterizedType) {
            return (Class<?>) parameterizedType.getActualTypeArguments()[0];
        }
        Class<?> concreteSuperClass = (Class<?>) superClass;
        return this.inferPropertyType(concreteSuperClass.getGenericSuperclass());
    }

    /**
     * Tells whether {@code field} is flagged as packed, i.e. its
     * {@link JsonPropertyDescription} contains the marker {@code [packed]}.
     */
    private boolean isPacked(Field field) {
        JsonPropertyDescription description = field.getAnnotation(JsonPropertyDescription.class);
        return description != null && description.value().contains("[packed]");
    }

    private ProtobufDecoder(@NonNull Class<? extends T> modelClass) {
        this.modelClass = Objects.requireNonNull(modelClass, "modelClass is marked non-null but is null");
    }

    /**
     * Creates a decoder producing instances of {@code modelClass}.
     *
     * @param modelClass the class to decode into, must not be {@code null}
     * @param <T>        the decoded type
     * @return a new decoder
     * @throws NullPointerException if {@code modelClass} is {@code null}
     */
    public static <T> ProtobufDecoder<T> forType(@NonNull Class<? extends T> modelClass) {
        return new ProtobufDecoder<T>(modelClass);
    }

    /**
     * Enables or disables the warning logged when a length-delimited field is not part of the
     * schema and is decoded as raw bytes.
     *
     * @param warnUnknownFields whether to log the warning
     * @return this decoder, for chaining
     */
    public ProtobufDecoder<T> warnUnknownFields(boolean warnUnknownFields) {
        this.warnUnknownFields = warnUnknownFields;
        return this;
    }
}

