/*
 * Decompiled with CFR 0.152.
 */
package com.regnosys.rosetta.utils;

import com.regnosys.rosetta.interpreter.RosettaInterpreter;
import com.regnosys.rosetta.interpreter.RosettaInterpreterContext;
import com.regnosys.rosetta.interpreter.RosettaValue;
import com.regnosys.rosetta.rosetta.RosettaCallableWithArgs;
import com.regnosys.rosetta.rosetta.RosettaSymbol;
import com.regnosys.rosetta.rosetta.expression.RosettaExpression;
import com.regnosys.rosetta.rosetta.expression.RosettaLiteral;
import com.regnosys.rosetta.rosetta.expression.RosettaSymbolReference;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import javax.inject.Inject;

/**
 * Solves small systems of "simple" equations for a chosen set of variables.
 *
 * <p>An equation is simple when each side is either a variable reference or a literal, and at most
 * one side is a variable that must be solved. Each simple equation either binds a variable-to-solve
 * to the expression on the other side, or is kept as a condition that is checked when the solution
 * is evaluated through {@link SolutionSet#getSolution}.
 */
public class RosettaSimpleSystemSolver {
    @Inject
    private RosettaInterpreter interpreter;

    /**
     * Attempts to solve {@code equations} for exactly the symbols in {@code variablesToSolve}.
     *
     * <p>For every equation, a side that references a variable to solve is mapped to the other side.
     * If neither side does, the equation is kept as a condition. A variable bound more than once keeps
     * its first binding and gains a condition requiring the competing expressions to be equal.
     *
     * @param equations        the equations to solve
     * @param variablesToSolve the variables that must each receive a solution
     * @return the solution set, or empty if some equation is not simple (see {@link #isSimple}) or
     *         not every variable in {@code variablesToSolve} received a binding
     */
    public Optional<SolutionSet> solve(Collection<Equation> equations, Set<? extends RosettaSymbol> variablesToSolve) {
        SolutionSet solution = new SolutionSet();
        for (Equation eq : equations) {
            if (!this.isSimple(eq, variablesToSolve)) {
                return Optional.empty();
            }
            RosettaExpression left = eq.getLeft();
            RosettaExpression right = eq.getRight();
            // isSimple rules out both sides being variables to solve, so at most one branch can bind.
            if (this.isVariableToSolve(left, variablesToSolve)) {
                solution.addMapping(((RosettaSymbolReference) left).getSymbol(), right);
            } else if (this.isVariableToSolve(right, variablesToSolve)) {
                solution.addMapping(((RosettaSymbolReference) right).getSymbol(), left);
            } else {
                // Neither side is an unknown: the equation only constrains already-known values.
                solution.addCondition(eq);
            }
        }
        // Mappings are only ever added for members of variablesToSolve, so set equality means
        // every requested variable was bound.
        if (!solution.getSolvedVariables().equals(variablesToSolve)) {
            return Optional.empty();
        }
        return Optional.of(solution);
    }

    /**
     * Checks whether {@code equation} has the shape {@link #solve} can handle.
     *
     * @param equation         the equation to inspect
     * @param variablesToSolve the variables being solved for
     * @return {@code true} if both sides are variables or literals and not both sides are
     *         variables contained in {@code variablesToSolve}
     */
    public boolean isSimple(Equation equation, Collection<? extends RosettaSymbol> variablesToSolve) {
        RosettaExpression left = equation.getLeft();
        RosettaExpression right = equation.getRight();
        if (!this.isVariableOrLiteral(left) || !this.isVariableOrLiteral(right)) {
            return false;
        }
        // An equation between two unknowns (e.g. x = y) is rejected: neither side supplies a value.
        return !(this.isVariableToSolve(left, variablesToSolve) && this.isVariableToSolve(right, variablesToSolve));
    }

    /** Returns true if {@code expr} is a variable reference whose symbol is in {@code variablesToSolve}. */
    private boolean isVariableToSolve(RosettaExpression expr, Collection<? extends RosettaSymbol> variablesToSolve) {
        return this.isVariable(expr) && variablesToSolve.contains(((RosettaSymbolReference) expr).getSymbol());
    }

    /** Returns true if {@code expr} is a variable reference or a literal. */
    private boolean isVariableOrLiteral(RosettaExpression expr) {
        return this.isVariable(expr) || this.isLiteral(expr);
    }

    /**
     * Returns true if {@code expr} is a symbol reference whose symbol is not a
     * {@link RosettaCallableWithArgs} (presumably excluding references to functions/callables).
     */
    private boolean isVariable(RosettaExpression expr) {
        return expr instanceof RosettaSymbolReference && !(((RosettaSymbolReference) expr).getSymbol() instanceof RosettaCallableWithArgs);
    }

    /** Returns true if {@code expr} is a {@link RosettaLiteral}. */
    private boolean isLiteral(RosettaExpression expr) {
        return expr instanceof RosettaLiteral;
    }

    /**
     * Result of {@link #solve}: a binding from each solved variable to the expression it equals,
     * plus equations (conditions) that must hold for the bindings to be valid.
     *
     * <p>This is an inner (non-static) class because evaluation uses the enclosing solver's
     * interpreter.
     */
    public class SolutionSet {
        private final Map<RosettaSymbol, RosettaExpression> solutionMap = new HashMap<>();
        private final Set<Equation> conditions = new HashSet<>();

        /**
         * Binds {@code var} to {@code solution}. If {@code var} is already bound, the first binding
         * is kept and a condition requiring both expressions to be equal is recorded instead.
         *
         * @param var      the variable being solved
         * @param solution the expression the variable equals
         */
        public void addMapping(RosettaSymbol var, RosettaExpression solution) {
            RosettaExpression existingSolution = this.solutionMap.get(var);
            if (existingSolution == null) {
                this.solutionMap.put(var, solution);
            } else {
                this.conditions.add(new Equation(solution, existingSolution));
            }
        }

        /**
         * Records an equation that must hold for this solution to be valid.
         *
         * @param eq the condition to record
         */
        public void addCondition(Equation eq) {
            this.conditions.add(eq);
        }

        /**
         * Returns the variables that currently have a binding.
         *
         * @return a read-only live view of the bound variables; callers cannot remove bindings through it
         */
        public Set<RosettaSymbol> getSolvedVariables() {
            return Collections.unmodifiableSet(this.solutionMap.keySet());
        }

        /**
         * Evaluates this solution in {@code context}.
         *
         * <p>Every condition is interpreted first. If the two sides of any condition evaluate to
         * unequal values (compared with {@link Objects#equals}, so {@code null} results are
         * tolerated), the solution is rejected. Otherwise each bound expression is interpreted.
         *
         * @param context the interpreter context used for all evaluations
         * @return a fresh map from each solved variable to its interpreted value, or empty if a
         *         condition does not hold
         */
        public Optional<Map<RosettaSymbol, RosettaValue>> getSolution(RosettaInterpreterContext context) {
            RosettaInterpreter interp = RosettaSimpleSystemSolver.this.interpreter;
            for (Equation condition : this.conditions) {
                RosettaValue evalLeft = interp.interpret(condition.getLeft(), context);
                RosettaValue evalRight = interp.interpret(condition.getRight(), context);
                if (!Objects.equals(evalLeft, evalRight)) {
                    return Optional.empty();
                }
            }
            Map<RosettaSymbol, RosettaValue> solution = new HashMap<>();
            for (Map.Entry<RosettaSymbol, RosettaExpression> entry : this.solutionMap.entrySet()) {
                solution.put(entry.getKey(), interp.interpret(entry.getValue(), context));
            }
            return Optional.of(solution);
        }
    }

    /**
     * An equality between two expressions, {@code left = right}.
     *
     * <p>Does not override {@code equals}/{@code hashCode}, so instances compare by identity.
     */
    public static class Equation {
        private final RosettaExpression left;
        private final RosettaExpression right;

        /**
         * @param left  the left-hand expression
         * @param right the right-hand expression
         */
        public Equation(RosettaExpression left, RosettaExpression right) {
            this.left = left;
            this.right = right;
        }

        /** @return the left-hand expression */
        public RosettaExpression getLeft() {
            return this.left;
        }

        /** @return the right-hand expression */
        public RosettaExpression getRight() {
            return this.right;
        }
    }
}

