package org.allenai.ml.sequences;

import com.gs.collections.api.list.MutableList;
import com.gs.collections.api.map.primitive.MutableIntObjectMap;
import com.gs.collections.api.tuple.Pair;
import com.gs.collections.impl.list.mutable.FastList;
import com.gs.collections.impl.map.mutable.primitive.IntObjectHashMap;
import com.gs.collections.impl.tuple.Tuples;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.allenai.ml.util.IOUtils;

/**
 * A finite set of states plus the allowed directed transitions between them.
 *
 * <p>States are addressed by their position in {@link #states()}; transitions are addressed by
 * their position in {@link #transitions()}. The static builders ({@link #buildFullStateSpace},
 * {@link #buildFromSequences}) place the start state at index 0 and the stop state at index 1,
 * which is what {@link #startStateIndex()} and {@link #stopStateIndex()} report. The public
 * constructor does not enforce that layout; callers using it directly must put the start and stop
 * states first themselves.
 *
 * <p>States are used as hash keys, so {@code S} needs consistent {@code equals}/{@code hashCode}.
 *
 * @param <S> the state type
 */
public class StateSpace<S> {
    private static final String DATA_VERSION = "1.0";

    /** Position of the start state, as laid out by the static builders. */
    private static final int START_STATE_INDEX = 0;

    /** Position of the stop state, as laid out by the static builders. */
    private static final int STOP_STATE_INDEX = 1;

    private final List<S> states;

    /** Reverse lookup for {@link #states}; well-defined because duplicate states are rejected. */
    private final Map<S, Integer> stateIndices;

    private final List<Transition> transitions;

    /** Transitions grouped by source state index, in construction order. */
    private final MutableIntObjectMap<MutableList<Transition>> fromTransitions;

    /** Transitions grouped by target state index, in construction order. */
    private final MutableIntObjectMap<MutableList<Transition>> toTransitions;

    /**
     * Creates a state space over the given states and transitions.
     *
     * @param stateList the states, in index order; must not contain duplicates
     * @param transitionPairs (from, to) state pairs; must not contain duplicates, and every state
     *     they mention must appear in {@code stateList}
     * @throws IllegalArgumentException if either list contains duplicates, or a pair mentions a
     *     state that is not in {@code stateList}
     */
    public StateSpace(List<S> stateList, List<Pair<S, S>> transitionPairs) {
        if (stateList.size() != new HashSet<>(stateList).size()) {
            throw new IllegalArgumentException("Passed in duplicate states: " + stateList);
        }
        if (transitionPairs.size() != new HashSet<>(transitionPairs).size()) {
            throw new IllegalArgumentException(
                "Passed in duplicate transition pairs: " + transitionPairs);
        }
        this.states = new ArrayList<>(stateList);
        this.stateIndices = new HashMap<>();
        for (int i = 0; i < this.states.size(); i++) {
            this.stateIndices.put(this.states.get(i), i);
        }
        this.transitions = new ArrayList<>(transitionPairs.size());
        for (Pair<S, S> pair : transitionPairs) {
            int from = requireStateIndex(this.stateIndices, pair.getOne(), pair);
            int to = requireStateIndex(this.stateIndices, pair.getTwo(), pair);
            // The third argument is the transition's own position in this.transitions.
            this.transitions.add(new Transition(from, to, this.transitions.size()));
        }
        this.fromTransitions = new IntObjectHashMap<>(this.states.size());
        this.toTransitions = new IntObjectHashMap<>(this.states.size());
        for (Transition transition : this.transitions) {
            this.fromTransitions.getIfAbsentPut(transition.fromState, FastList::new).add(transition);
            this.toTransitions.getIfAbsentPut(transition.toState, FastList::new).add(transition);
        }
    }

    /** Looks up {@code state}'s index, failing with a descriptive error if it is unknown. */
    private static <S> int requireStateIndex(Map<S, Integer> indices, S state, Pair<S, S> pair) {
        Integer index = indices.get(state);
        if (index == null) {
            throw new IllegalArgumentException(
                "Transition " + pair + " references unknown state: " + state);
        }
        return index;
    }

    /**
     * Finds the transition from {@code s} to {@code s2}.
     *
     * @return the transition, or empty if either state is unknown or no such transition exists
     */
    public Optional<Transition> transitionFor(S s, S s2) {
        Integer from = this.stateIndices.get(s);
        Integer to = this.stateIndices.get(s2);
        if (from == null || to == null) {
            return Optional.empty();
        }
        return transitionFor(from.intValue(), to.intValue());
    }

    /**
     * Finds the transition from state index {@code i} to state index {@code i2}. Because duplicate
     * pairs are rejected at construction, at most one such transition exists.
     *
     * @return the transition, or empty if none exists
     */
    public Optional<Transition> transitionFor(int i, int i2) {
        for (Transition transition : transitionsFrom(i)) {
            if (transition.toState == i2) {
                return Optional.of(transition);
            }
        }
        return Optional.empty();
    }

    /**
     * Returns an unmodifiable list of the transitions leaving state index {@code i}, in
     * construction order. The list is empty if there are none.
     */
    public List<Transition> transitionsFrom(int i) {
        return Collections.unmodifiableList(this.fromTransitions.getIfAbsent(i, FastList::new));
    }

    /**
     * Returns an unmodifiable list of the transitions entering state index {@code i}, in
     * construction order. The list is empty if there are none.
     */
    public List<Transition> transitionsTo(int i) {
        return Collections.unmodifiableList(this.toTransitions.getIfAbsent(i, FastList::new));
    }

    /**
     * Returns the (from, to) states of the transition at position {@code i} in
     * {@link #transitions()}.
     *
     * @throws IndexOutOfBoundsException if {@code i} is not a valid transition index
     */
    public Pair<S, S> transition(int i) {
        Transition transition = this.transitions.get(i);
        return Tuples.pair(this.states.get(transition.fromState), this.states.get(transition.toState));
    }

    /** Returns the state at {@link #startStateIndex()}. */
    public S startState() {
        return this.states.get(START_STATE_INDEX);
    }

    /** Returns the state at {@link #stopStateIndex()}. */
    public S stopState() {
        return this.states.get(STOP_STATE_INDEX);
    }

    /** Returns the index of {@code s}, or -1 if it is not a state of this space. */
    public int stateIndex(S s) {
        Integer index = this.stateIndices.get(s);
        return index == null ? -1 : index;
    }

    /** Returns an unmodifiable view of the states, in index order. */
    public List<S> states() {
        return Collections.unmodifiableList(this.states);
    }

    /** Returns an unmodifiable view of the transitions, in index order. */
    public List<Transition> transitions() {
        return Collections.unmodifiableList(this.transitions);
    }

    /** Returns the index where the static builders place the start state (always 0). */
    public int startStateIndex() {
        return START_STATE_INDEX;
    }

    /** Returns the index where the static builders place the stop state (always 1). */
    public int stopStateIndex() {
        return STOP_STATE_INDEX;
    }

    /**
     * Returns a new list equal to {@code list} but guaranteed to begin with {@code s} and end with
     * {@code s2}. Each of these is added only when it is not already in place. The input is not
     * modified.
     *
     * @param list the sequence to pad
     * @param s the start state
     * @param s2 the stop state
     * @return the padded copy, which always has at least two elements
     */
    public static <S> List<S> ensureStartStopPadded(List<S> list, S s, S s2) {
        List<S> padded = new ArrayList<>(list.size() + 2);
        if (list.isEmpty() || !Objects.equals(list.get(0), s)) {
            padded.add(s);
        }
        padded.addAll(list);
        // Fewer than two elements means padded is only the start state; always append the stop
        // state so that start and stop occupy distinct positions.
        if (padded.size() < 2 || !Objects.equals(padded.get(padded.size() - 1), s2)) {
            padded.add(s2);
        }
        return padded;
    }

    /** Returns each pair of consecutive elements of {@code list}; empty if size is below 2. */
    private static <S> List<Pair<S, S>> adjacentPairs(List<S> list) {
        List<Pair<S, S>> pairs = new ArrayList<>(Math.max(0, list.size() - 1));
        for (int i = 0; i + 1 < list.size(); i++) {
            pairs.add(Tuples.pair(list.get(i), list.get(i + 1)));
        }
        return pairs;
    }

    /**
     * Builds a fully connected state space. The states are {@code s} (start, index 0), {@code s2}
     * (stop, index 1), then the remaining members of {@code set} in {@link HashSet} iteration
     * order. Every (from, to) pair is allowed except those leaving the stop state and those
     * entering the start state.
     *
     * @param set the states; {@code s} and {@code s2} may or may not be included
     * @param s the start state
     * @param s2 the stop state
     */
    public static <S> StateSpace<S> buildFullStateSpace(Set<S> set, S s, S s2) {
        List<S> ordered = new ArrayList<>();
        ordered.add(s);
        ordered.add(s2);
        Set<S> others = new HashSet<>(set);
        others.remove(s);
        others.remove(s2);
        ordered.addAll(others);

        int n = ordered.size();
        List<Pair<S, S>> pairs = new ArrayList<>(n * n);
        for (int from = 0; from < n; from++) {
            if (from == STOP_STATE_INDEX) {
                continue; // nothing leaves the stop state
            }
            for (int to = 0; to < n; to++) {
                if (to == START_STATE_INDEX) {
                    continue; // nothing enters the start state
                }
                pairs.add(Tuples.pair(ordered.get(from), ordered.get(to)));
            }
        }
        return new StateSpace<>(ordered, pairs);
    }

    /**
     * Builds a state space from observed sequences. The states are {@code s} (start, index 0),
     * {@code s2} (stop, index 1), then every other state seen in {@code collection} in unspecified
     * order. The transitions are the distinct adjacent pairs of each sequence after padding it with
     * {@link #ensureStartStopPadded}.
     *
     * @param collection the observed state sequences
     * @param s the start state
     * @param s2 the stop state
     */
    public static <S> StateSpace<S> buildFromSequences(Collection<List<S>> collection, S s, S s2) {
        List<S> stateList = new ArrayList<>(Arrays.asList(s, s2));
        Set<S> seen = collection.stream()
            .flatMap(List::stream)
            .filter(state -> !Objects.equals(state, s) && !Objects.equals(state, s2))
            .collect(Collectors.toSet());
        stateList.addAll(seen);

        List<Pair<S, S>> pairs = collection.stream()
            .map(sequence -> ensureStartStopPadded(sequence, s, s2))
            .flatMap(padded -> adjacentPairs(padded).stream())
            .distinct()
            .collect(Collectors.toList());
        return new StateSpace<>(stateList, pairs);
    }

    /**
     * Reads a state space in the format written by {@link #save}. States come back as their
     * {@code toString()} forms.
     *
     * @throws UncheckedIOException if reading from {@code dataInputStream} fails
     * @throws IllegalArgumentException if the decoded states or transitions are invalid
     */
    public static StateSpace<String> load(DataInputStream dataInputStream) {
        try {
            IOUtils.ensureVersionMatch(dataInputStream, DATA_VERSION);
            List<String> stateNames = IOUtils.loadList(dataInputStream);
            int numTransitions = dataInputStream.readInt();
            List<Pair<String, String>> pairs = new ArrayList<>();
            for (int i = 0; i < numTransitions; i++) {
                // Read "from" before "to" explicitly, matching the order save() writes them.
                String from = stateNames.get(dataInputStream.readInt());
                String to = stateNames.get(dataInputStream.readInt());
                pairs.add(Tuples.pair(from, to));
            }
            return new StateSpace<>(stateNames, pairs);
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to load StateSpace", e);
        }
    }

    /**
     * Writes, in order: the data version, the states' {@code toString()} forms, the transition
     * count, and then each transition's (from, to) state indices.
     *
     * @throws IOException if writing to {@code dataOutputStream} fails
     */
    public void save(DataOutputStream dataOutputStream) throws IOException {
        dataOutputStream.writeUTF(DATA_VERSION);
        List<String> stateNames = this.states.stream()
            .map(Object::toString)
            .collect(Collectors.toList());
        IOUtils.saveList(dataOutputStream, stateNames);
        dataOutputStream.writeInt(this.transitions.size());
        for (Transition transition : this.transitions) {
            dataOutputStream.writeInt(transition.fromState);
            dataOutputStream.writeInt(transition.toState);
        }
    }
}
