/*
 * Decompiled with CFR 0.152.
 */
package org.redfx.strange.local;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import org.redfx.strange.Block;
import org.redfx.strange.BlockGate;
import org.redfx.strange.Complex;
import org.redfx.strange.ControlledBlockGate;
import org.redfx.strange.Gate;
import org.redfx.strange.QuantumExecutionEnvironment;
import org.redfx.strange.Step;
import org.redfx.strange.gate.Identity;
import org.redfx.strange.gate.Oracle;
import org.redfx.strange.gate.PermutationGate;
import org.redfx.strange.gate.ProbabilitiesGate;
import org.redfx.strange.gate.SingleQubitGate;
import org.redfx.strange.gate.ThreeQubitGate;
import org.redfx.strange.gate.TwoQubitGate;
import org.redfx.strange.local.SimpleQuantumExecutionEnvironment;

/**
 * Static helpers for the local simulator: building per-step matrices, splitting steps so
 * that multi-qubit gates act on consecutive qubits, applying gates to a state vector, and a
 * few number-theory utilities (gcd, modular inverse, continued fractions).
 */
public class Computations {

    /** Number of {@link #calculateNewState} invocations currently in progress. */
    static int nested = 0;

    /**
     * Forwards a debug message to {@link SimpleQuantumExecutionEnvironment#dbg(String)}.
     *
     * @param s the message
     */
    static void dbg(String s) {
        SimpleQuantumExecutionEnvironment.dbg(s);
    }

    /**
     * Returns the first gate in {@code gates} whose highest affected qubit is {@code idx},
     * or a new {@link Identity} on {@code idx} when there is none.
     */
    private static Gate gateEndingAt(List<Gate> gates, int idx) {
        return gates.stream()
                .filter(gate -> gate.getHighestAffectedQubitIndex() == idx)
                .findFirst()
                .orElseGet(() -> new Identity(idx));
    }

    /**
     * Builds the full matrix for one step.
     * <p>
     * Qubit indices are scanned from {@code nQubits - 1} down to 0. At each index the gate
     * whose highest affected qubit equals that index is selected ({@link Identity} if none)
     * and combined with the accumulated matrix via
     * {@code Complex.tensor(accumulated, gateMatrix)}. The scan then skips the remaining
     * qubits covered by a multi-qubit gate. An {@link Oracle} replaces the accumulated
     * matrix with its own matrix and ends the scan.
     *
     * @param gates   the gates of the step
     * @param nQubits the number of qubits
     * @param qee     environment passed to {@link BlockGate#getMatrix(QuantumExecutionEnvironment)}
     * @return the step matrix
     * @throws RuntimeException if a {@link PermutationGate} is encountered
     */
    public static Complex[][] calculateStepMatrix(List<Gate> gates, int nQubits, QuantumExecutionEnvironment qee) {
        Complex[][] a = new Complex[1][1];
        a[0][0] = Complex.ONE;
        for (int idx = nQubits - 1; idx >= 0; --idx) {
            Gate myGate = gateEndingAt(gates, idx);
            dbg("stepmatrix, cnt = " + idx + ", idx = " + idx + ", myGate = " + myGate);
            if (myGate instanceof BlockGate) {
                BlockGate blockGate = (BlockGate) myGate;
                dbg("calculateStepMatrix for blockgate " + myGate + " of class " + myGate.getClass());
                a = Complex.tensor(a, blockGate.getMatrix(qee));
                dbg("calculateStepMatrix for blockgate calculated " + myGate);
                // skip the other qubits covered by the block
                idx = idx - blockGate.getSize() + 1;
            }
            if (myGate instanceof SingleQubitGate) {
                a = Complex.tensor(a, myGate.getMatrix());
            }
            if (myGate instanceof TwoQubitGate) {
                a = Complex.tensor(a, myGate.getMatrix());
                --idx;
            }
            if (myGate instanceof ThreeQubitGate) {
                a = Complex.tensor(a, myGate.getMatrix());
                idx -= 2;
            }
            if (myGate instanceof PermutationGate) {
                throw new RuntimeException("No perm allowed ");
            }
            if (myGate instanceof Oracle) {
                // the oracle's matrix replaces everything accumulated so far
                a = myGate.getMatrix();
                break;
            }
        }
        return a;
    }

    /**
     * Splits a step so that every multi-qubit gate in it acts on consecutive qubits in
     * descending order (main qubit highest), by surrounding the step with
     * {@link PermutationGate} steps.
     * <p>
     * The returned list always contains {@code s}. It is returned as a singleton when
     * {@code s} is a PSEUDO step (its complex step is then set to its own index), has no
     * gates, has only single-qubit gates, or consists of a single {@link Oracle}. If a
     * {@link ProbabilitiesGate} is encountered, {@code s} is marked informal and the list
     * built so far is returned.
     * <p>
     * Whenever permutations are added, {@code s}'s complex step is set to -1 and
     * {@code s.getIndex()} is recorded on a trailing permutation step (presumably so the
     * result is attributed to the original step — confirm against callers).
     *
     * @param s      the step to decompose
     * @param nqubit the number of qubits
     * @return the steps to execute, in order
     * @throws IllegalArgumentException if a gate touches a qubit index {@code >= nqubit}
     * @throws RuntimeException         for a two-qubit gate with identical indices or an
     *                                  unsupported gate type
     */
    public static List<Step> decomposeStep(Step s, int nqubit) {
        List<Step> answer = new ArrayList<>();
        answer.add(s);
        if (s.getType() == Step.Type.PSEUDO) {
            s.setComplexStep(s.getIndex());
            return answer;
        }
        List<Gate> gates = s.getGates();
        if (gates.isEmpty()
                || gates.stream().allMatch(g -> g instanceof SingleQubitGate)
                || (gates.size() == 1 && gates.get(0) instanceof Oracle)) {
            return answer;
        }
        for (Gate gate : gates) {
            // qubit indices are 0-based, so index nqubit is already out of range
            if (gate.getHighestAffectedQubitIndex() >= nqubit) {
                throw new IllegalArgumentException("Only " + nqubit + " qubits available while Gate " + gate + " requires qubit " + gate.getHighestAffectedQubitIndex());
            }
            if (gate instanceof ProbabilitiesGate) {
                s.setInformalStep(true);
                return answer;
            }
            if (gate instanceof BlockGate) {
                if (gate instanceof ControlledBlockGate) {
                    processBlockGate((ControlledBlockGate) gate, answer);
                }
            } else if (gate instanceof SingleQubitGate) {
                // acts on one qubit: nothing to rearrange
            } else if (gate instanceof TwoQubitGate) {
                decomposeTwoQubitGate(s, (TwoQubitGate) gate, nqubit, answer);
            } else if (gate instanceof ThreeQubitGate) {
                decomposeThreeQubitGate(s, (ThreeQubitGate) gate, nqubit, answer);
            } else {
                throw new RuntimeException("Gate must be SingleQubit or TwoQubit");
            }
        }
        return answer;
    }

    /**
     * Inserts a step applying {@code pg} at {@code position} and another step applying
     * {@code pg} at {@code answer.size() - position} (measured after the first insertion),
     * so successive calls with increasing positions nest symmetrically around {@code s}.
     * The trailing step gets {@code s.getIndex()} as complex step; {@code s} gets -1.
     */
    private static void surroundWithPermutation(List<Step> answer, Step s, PermutationGate pg, int position) {
        Step prePermutation = new Step(pg);
        Step postPermutation = new Step(pg);
        answer.add(position, prePermutation);
        answer.add(answer.size() - position, postPermutation);
        postPermutation.setComplexStep(s.getIndex());
        s.setComplexStep(-1);
    }

    /**
     * Adds permutation steps around {@code s} so the two-qubit gate's main qubit sits
     * directly above its second qubit (assumes {@code PermutationGate(i, j, n)} swaps
     * qubits i and j — TODO confirm).
     */
    private static void decomposeTwoQubitGate(Step s, TwoQubitGate gate, int nqubit, List<Step> answer) {
        int first = gate.getMainQubitIndex();
        int second = gate.getSecondQubitIndex();
        if (first >= nqubit || second >= nqubit) {
            throw new IllegalArgumentException("Step " + s + " uses a gate with invalid index " + first + " or " + second);
        }
        if (first == second + 1) {
            return; // already in the expected layout
        }
        if (first == second) {
            throw new RuntimeException("Wrong gate, first == second for " + gate);
        }
        if (first > second) {
            // bring the second qubit right below the main qubit
            surroundWithPermutation(answer, s, new PermutationGate(first - 1, second, nqubit), 0);
            return;
        }
        // first < second: swap main and second, then move second (now at 'first') below main
        PermutationGate pg = new PermutationGate(first, second, nqubit);
        Step prePermutation = new Step(pg);
        Step prePermutationInv = new Step(pg);
        int realStep = s.getIndex();
        s.setComplexStep(-1);
        answer.add(0, prePermutation);
        answer.add(prePermutationInv);
        if (first != second - 1) {
            PermutationGate pg2 = new PermutationGate(second - 1, first, nqubit);
            Step postPermutation = new Step(new Gate[0]);
            Step postPermutationInv = new Step(new Gate[0]);
            postPermutation.addGate(pg2);
            postPermutationInv.addGate(pg2);
            answer.add(1, postPermutation);
            answer.add(3, postPermutationInv);
        }
        prePermutationInv.setComplexStep(realStep);
    }

    /**
     * Adds permutation steps around {@code s} so the three-qubit gate acts on qubits
     * {@code f, f-1, f-2} (main, second, third). Positions of the three qubits are tracked
     * after every swap (assumes {@code PermutationGate(i, j, n)} swaps qubits i and j —
     * TODO confirm).
     */
    private static void decomposeThreeQubitGate(Step s, ThreeQubitGate gate, int nqubit, List<Step> answer) {
        int first = gate.getMainQubit();
        int second = gate.getSecondQubit();
        int third = gate.getThirdQubit();
        if (first == second + 1 && second == third + 1) {
            return; // already in the expected layout
        }
        int sFirst = first;
        int sSecond = second;
        int sThird = third;
        int position = 0;
        int maxs = Math.max(second, third);
        if (first < maxs) {
            // move the main qubit to the highest of the three positions
            surroundWithPermutation(answer, s, new PermutationGate(first, maxs, nqubit), position++);
            sFirst = maxs;
            if (second > third) {
                sSecond = first;
            } else {
                sThird = first;
            }
        }
        if (sSecond != sFirst - 1) {
            surroundWithPermutation(answer, s, new PermutationGate(sFirst - 1, sSecond, nqubit), position++);
            if (sThird == sFirst - 1) {
                // the swap moved the third qubit to where the second one was
                sThird = sSecond;
            }
            sSecond = sFirst - 1;
        }
        if (sThird != sFirst - 2) {
            surroundWithPermutation(answer, s, new PermutationGate(sFirst - 2, sThird, nqubit), position);
        }
    }

    /**
     * Prints each row of {@code a} to standard output as {@code m[i]: } followed by the
     * row's entries.
     *
     * @param a the matrix to print
     */
    public static void printMatrix(Complex[][] a) {
        for (int i = 0; i < a.length; ++i) {
            StringBuilder sb = new StringBuilder();
            for (int j = 0; j < a[i].length; ++j) {
                sb.append(a[i][j]).append("    ");
            }
            System.out.println("m[" + i + "]: " + sb);
        }
    }

    /**
     * Returns the multiplicative inverse of {@code a} modulo {@code b}, computed with the
     * extended Euclidean algorithm. Intended for positive arguments.
     *
     * @param a the value to invert
     * @param b the modulus
     * @return {@code x} in {@code 1..b} with {@code a * x ≡ 1 (mod b)} (1 when {@code b == 1})
     * @throws ArithmeticException if {@code a} has no inverse modulo {@code b}
     */
    public static int getInverseModulus(int a, int b) {
        int r0 = a;
        int r1 = b;
        int s0 = 1;
        int s1 = 0;
        while (r1 != 1) {
            if (r1 == 0) {
                // the remainders reached 0 without passing 1, so gcd(a, b) != 1
                throw new ArithmeticException(a + " has no inverse modulo " + b);
            }
            int q = r0 / r1;
            int r2 = r0 % r1;
            int s2 = s0 - q * s1;
            r0 = r1;
            r1 = r2;
            s0 = s1;
            s1 = s2;
        }
        return s1 > 0 ? s1 : s1 + b;
    }

    /**
     * Returns the greatest common divisor of {@code a} and {@code b} (Euclid's algorithm).
     * Intended for non-negative arguments; {@code gcd(x, 0) == x}.
     *
     * @param a first value
     * @param b second value
     * @return the gcd
     */
    public static int gcd(int a, int b) {
        int x = a > b ? a : b;
        int y = x == a ? b : a;
        while (y != 0) {
            int z = x % y;
            x = y;
            y = z;
        }
        return x;
    }

    /**
     * Interprets {@code p} as a measurement on {@code ceil(log2(max))} qubits, i.e. as the
     * fraction {@code p / 2^ceil(log2(max))} (plus 1e-6), and passes it to
     * {@link #fraction(double, int)}.
     *
     * @param p   the measured value
     * @param max the bound on the denominator
     * @return the result of {@link #fraction(double, int)}
     */
    public static int fraction(int p, int max) {
        // ceil(log2(max)) computed exactly; floating-point logs can round up at powers of two
        int offset = max <= 1 ? 0 : Integer.SIZE - Integer.numberOfLeadingZeros(max - 1);
        int dim = 1 << offset;
        double r = (double) p / (double) dim + 1.0E-6;
        return fraction(r, max);
    }

    /**
     * Expands {@code d} as a continued fraction, computing convergent denominators until
     * one reaches {@code max} or the remaining fractional part drops to 1e-15 or below.
     *
     * @param d   the value to expand
     * @param max the bound on the denominator
     * @return the second-to-last denominator computed (1 if none was computed, 0 if only
     *         one was)
     */
    public static int fraction(double d, int max) {
        final double eps = 1.0E-15;
        int a = (int) d;
        double r = d - a;
        int k = -1;
        int kPrev = 0;     // most recent denominator
        int kPrevPrev = 1; // denominator before that
        while (k < max && r > eps) {
            k = a * kPrev + kPrevPrev;
            kPrevPrev = kPrev;
            kPrev = k;
            double rec = 1.0 / r;
            a = (int) rec;
            r = rec - a;
        }
        return kPrevPrev;
    }

    /**
     * Creates a {@code dim x dim} identity matrix.
     *
     * @param dim the dimension
     * @return a new identity matrix
     */
    public static Complex[][] createIdentity(int dim) {
        Complex[][] matrix = new Complex[dim][dim];
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < dim; ++j) {
                matrix[i][j] = i == j ? Complex.ONE : Complex.ZERO;
            }
        }
        return matrix;
    }

    /** Currently does nothing. */
    public static void printMemory() {
    }

    /**
     * Returns a copy of {@code vector} in which the roles of index bits {@code a} and
     * {@code b} are exchanged: entry {@code i} of the result is the entry of
     * {@code vector} at {@code i} with bits {@code a} and {@code b} swapped.
     *
     * @param vector the input vector
     * @param a      first bit position
     * @param b      second bit position
     * @return the permuted vector
     * @throws IllegalArgumentException if a bit position is negative or not below
     *                                  {@code log2(vector.length)}
     */
    public static Complex[] permutateVector(Complex[] vector, int a, int b) {
        int dim = vector.length;
        // range-check before shifting: 1 << n silently wraps for negative n or n >= 31
        if (a < 0 || b < 0 || a >= Integer.SIZE - 1 || b >= Integer.SIZE - 1
                || (1 << a) >= dim || (1 << b) >= dim) {
            throw new IllegalArgumentException("Can not permutate element " + a + " and " + b + " of vector sized " + vector.length);
        }
        int amask = 1 << a;
        int bmask = 1 << b;
        Complex[] answer = new Complex[dim];
        for (int i = 0; i < dim; ++i) {
            boolean bitA = (i & amask) != 0;
            boolean bitB = (i & bmask) != 0;
            answer[i] = vector[bitA != bitB ? i ^ amask ^ bmask : i];
        }
        return answer;
    }

    /**
     * Applies one step's gates to a state vector.
     *
     * @param gates  the gates of the step
     * @param vector the current state vector
     * @param length the number of qubits
     * @return the new state vector
     * @throws IllegalArgumentException if a gate touches a qubit index {@code >= length}
     */
    public static Complex[] calculateNewState(List<Gate> gates, Complex[] vector, int length) {
        ++nested;
        try {
            return getNextProbability(getAllGates(gates, length), vector);
        } finally {
            // keep the counter balanced even when the computation throws
            --nested;
        }
    }

    /**
     * Applies {@code gates} to {@code v}. The first gate acts on the high-order index bits:
     * {@code v} is viewed as {@code gatedim = 1 << gates.get(0).getSize()} consecutive
     * parts of length {@code v.length / gatedim}. The remaining gates are applied
     * recursively to each part, then the first gate's matrix mixes corresponding entries
     * across parts. When all remaining gates are {@link Identity}, only the first gate is
     * applied.
     */
    private static Complex[] getNextProbability(List<Gate> gates, Complex[] v) {
        Gate gate = gates.get(0);
        int gatedim = 1 << gate.getSize();
        int size = v.length;
        dbg("GETNEXTPROBABILITY asked for size = " + size + " and gates = " + gates);
        if (gates.size() == 1) {
            if (gatedim != size) {
                System.err.println("problem with matrix for gate " + gate);
                throw new IllegalArgumentException("wrong matrix size " + gatedim + " vs vector size " + v.length);
            }
            return gate.hasOptimization() ? gate.applyOptimize(v) : multiply(gate.getMatrix(), v);
        }
        int partdim = size / gatedim;
        List<Gate> nextGates = gates.subList(1, gates.size());
        if (nextGates.stream().allMatch(g -> g instanceof Identity)) {
            return applyFirstGateOnly(gate, v, gatedim, partdim);
        }
        Complex[][] vsub = new Complex[gatedim][];
        for (int i = 0; i < gatedim; ++i) {
            Complex[] part = new Complex[partdim];
            System.arraycopy(v, i * partdim, part, 0, partdim);
            vsub[i] = getNextProbability(nextGates, part);
        }
        Complex[][] matrix = gate.getMatrix();
        Complex[] answer = new Complex[size];
        for (int i = 0; i < gatedim; ++i) {
            for (int j = 0; j < partdim; ++j) {
                Complex sum = Complex.ZERO;
                for (int k = 0; k < gatedim; ++k) {
                    sum = sum.add(matrix[i][k].mul(vsub[k][j]));
                }
                answer[j + i * partdim] = sum;
            }
        }
        return answer;
    }

    /**
     * Applies {@code gate} to each slice {@code v[j], v[partdim + j], v[2*partdim + j], ...}
     * of {@code v}, leaving the other qubits untouched. Uses
     * {@link Gate#applyOptimize} when the gate has an optimization. Otherwise the matrix is
     * fetched once and reused for every slice.
     */
    private static Complex[] applyFirstGateOnly(Gate gate, Complex[] v, int gatedim, int partdim) {
        dbg("ONLY IDENTITY!! partdim = " + partdim);
        Complex[] answer = new Complex[v.length];
        Complex[][] matrix = null;
        for (int j = 0; j < partdim; ++j) {
            dbg("do part " + j + " from " + partdim);
            Complex[] oldv = new Complex[gatedim];
            for (int i = 0; i < gatedim; ++i) {
                oldv[i] = v[i * partdim + j];
            }
            Complex[] newv;
            if (gate.hasOptimization()) {
                dbg("OPTPART!");
                newv = gate.applyOptimize(oldv);
            } else {
                if (matrix == null) {
                    // loop-invariant: fetch the (possibly expensive) matrix only once
                    dbg("GET MATRIX for  " + gate);
                    matrix = gate.getMatrix();
                }
                newv = multiply(matrix, oldv);
            }
            for (int i = 0; i < gatedim; ++i) {
                answer[i * partdim + j] = newv[i];
            }
            dbg("done part");
        }
        return answer;
    }

    /** Returns {@code matrix * v}, with {@code matrix} assumed to be {@code v.length} square. */
    private static Complex[] multiply(Complex[][] matrix, Complex[] v) {
        int n = v.length;
        Complex[] answer = new Complex[n];
        for (int i = 0; i < n; ++i) {
            Complex sum = Complex.ZERO;
            for (int j = 0; j < n; ++j) {
                sum = sum.add(matrix[i][j].mul(v[j]));
            }
            answer[i] = sum;
        }
        return answer;
    }

    /**
     * Throws if any gate's highest affected qubit index is not below {@code nQubits}.
     */
    private static void validateGates(List<Gate> gates, int nQubits) {
        for (Gate gate : gates) {
            if (gate.getHighestAffectedQubitIndex() >= nQubits) {
                throw new IllegalArgumentException("Gate " + gate + " operates on qubit " + gate.getHighestAffectedQubitIndex() + " but we have only " + nQubits + " qubits.");
            }
        }
    }

    /**
     * Returns the gates of a step ordered from the highest qubit index down to 0, filling
     * qubits not covered by any gate with {@link Identity}. Scanning skips the qubits
     * covered by multi-qubit gates and stops after an {@link Oracle}.
     */
    private static List<Gate> getAllGates(List<Gate> gates, int nQubits) {
        validateGates(gates, nQubits);
        dbg("getAllGates, orig = " + gates);
        List<Gate> answer = new ArrayList<>();
        for (int idx = nQubits - 1; idx >= 0; --idx) {
            Gate myGate = gateEndingAt(gates, idx);
            dbg("stepmatrix, cnt = " + idx + ", idx = " + idx + ", myGate = " + myGate);
            answer.add(myGate);
            if (myGate instanceof BlockGate) {
                BlockGate blockGate = (BlockGate) myGate;
                idx = idx - blockGate.getSize() + 1;
                dbg("processed blockgate, size = " + blockGate.getSize() + ", idx = " + idx);
            }
            if (myGate instanceof TwoQubitGate) {
                --idx;
            }
            if (myGate instanceof ThreeQubitGate) {
                idx -= 2;
            }
            if (myGate instanceof PermutationGate) {
                throw new RuntimeException("No perm allowed ");
            }
            if (myGate instanceof Oracle) {
                break;
            }
        }
        return answer;
    }

    /**
     * Surrounds the last step in {@code answer} (the "master") with permutation steps that
     * rearrange a {@link ControlledBlockGate}'s control qubit relative to its block
     * (presumably {@code PermutationGate(i, j, n)} swaps qubits i and j — TODO confirm).
     * Each permutation is appended after and prepended before the existing steps. The last
     * appended one takes over the master's complex step, and the master's is set to -1.
     *
     * @throws IllegalArgumentException if the control lies above the block but inside it
     */
    private static void processBlockGate(ControlledBlockGate gate, List<Step> answer) {
        Step master = answer.get(answer.size() - 1);
        gate.calculateHighLow();
        int control = gate.getControlQubit();
        int idx = gate.getMainQubitIndex();
        Block block = gate.getBlock();
        int bs = block.getNQubits();
        List<PermutationGate> perm = new ArrayList<>();
        if (control > idx) {
            int gap = control - idx;
            if (gap < bs) {
                throw new IllegalArgumentException("Can't have control at " + control + " for gate with size " + bs + " starting at " + idx);
            }
            if (gap > bs) {
                // move the control down to the position just above the block
                perm.add(new PermutationGate(control, idx + bs, control + 1));
            }
        } else {
            int high = idx + bs - 1;
            // adjacent swaps from the control up to 'high', stored in reverse order
            for (int i = control; i < high; ++i) {
                perm.add(0, new PermutationGate(i, i + 1, high + 1));
            }
        }
        for (int i = 0; i < perm.size(); ++i) {
            PermutationGate pg = perm.get(i);
            Step lpg = new Step(pg);
            if (i < perm.size() - 1) {
                lpg.setComplexStep(-1);
            } else {
                lpg.setComplexStep(master.getComplexStep());
                master.setComplexStep(-1);
            }
            answer.add(lpg);
            answer.add(0, new Step(pg));
        }
    }

    /**
     * For each qubit {@code i}, sums {@code abssqr()} of every amplitude whose index has
     * bit {@code i} set. The qubit count is {@code round(log2(vectorresult.length))}.
     *
     * @param vectorresult the state vector
     * @return per-qubit totals, indexed by qubit
     */
    public static double[] calculateQubitStatesFromVector(Complex[] vectorresult) {
        int nq = (int) Math.round(Math.log(vectorresult.length) / Math.log(2.0));
        double[] answer = new double[nq];
        int ressize = 1 << nq;
        for (int i = 0; i < nq; ++i) {
            for (int j = 0; j < ressize; ++j) {
                if (((j >> i) & 1) == 1) {
                    answer[i] += vectorresult[j].abssqr();
                }
            }
        }
        return answer;
    }
}

