/*
 * Decompiled with CFR 0.152.
 */
package org.redfx.strange.algorithm;

import java.util.List;
import java.util.function.Function;
import org.redfx.strange.Complex;
import org.redfx.strange.ControlledBlockGate;
import org.redfx.strange.Gate;
import org.redfx.strange.Program;
import org.redfx.strange.QuantumExecutionEnvironment;
import org.redfx.strange.Qubit;
import org.redfx.strange.Result;
import org.redfx.strange.Step;
import org.redfx.strange.gate.Cr;
import org.redfx.strange.gate.Fourier;
import org.redfx.strange.gate.Hadamard;
import org.redfx.strange.gate.InvFourier;
import org.redfx.strange.gate.MulModulus;
import org.redfx.strange.gate.Oracle;
import org.redfx.strange.gate.ProbabilitiesGate;
import org.redfx.strange.gate.X;
import org.redfx.strange.local.Computations;
import org.redfx.strange.local.SimpleQuantumExecutionEnvironment;

/**
 * Ready-made quantum algorithms built as Strange {@link Program}s: a random bit, a
 * Fourier-basis adder, Grover search and a Shor-style integer factoring routine.
 *
 * <p>All programs run on a single, statically shared {@link QuantumExecutionEnvironment}
 * that can be replaced with {@link #setQuantumExecutionEnvironment}. The field is not
 * synchronized. NOTE(review): callers that swap environments from several threads must
 * coordinate externally.
 */
public class Classic {
    /** Environment that runs every program built here; defaults to the local simulator. */
    private static QuantumExecutionEnvironment qee = new SimpleQuantumExecutionEnvironment();

    /**
     * Replaces the execution environment used by all subsequent calls in this class.
     *
     * @param val the environment to run programs on
     */
    public static void setQuantumExecutionEnvironment(QuantumExecutionEnvironment val) {
        qee = val;
    }

    /**
     * Applies a Hadamard gate to a single qubit and measures it.
     *
     * @return the measured value of qubit 0
     */
    public static int randomBit() {
        Program program = new Program(1, new Step(new Hadamard(0)));
        Result result = qee.runProgram(program);
        Qubit[] qubits = result.getQubits();
        return qubits[0].measure();
    }

    /**
     * Adds two non-negative integers with a quantum Fourier-transform adder.
     *
     * <p>The larger operand {@code y} is loaded into qubits {@code 0..m-1}. The smaller
     * operand {@code x} is loaded into qubits {@code m..m+n-1}. Bit {@code k} of each value
     * sits at offset {@code k} of its register. The {@code y} register is
     * Fourier-transformed, then controlled rotations driven by the bits of {@code x} are
     * applied, then the inverse transform is applied. Finally the {@code y} register is
     * measured and returned.
     *
     * @param a first summand, must be non-negative
     * @param b second summand, must be non-negative
     * @return the value measured from the sum register
     * @throws IllegalArgumentException if {@code a} or {@code b} is negative
     */
    public static int qsum(int a, int b) {
        if (a < 0 || b < 0) {
            throw new IllegalArgumentException(
                    "qsum requires non-negative operands, got " + a + " and " + b);
        }
        int y = Math.max(a, b);
        int x = Math.min(a, b);
        // The y register receives the sum. Since y + x <= 2y < 2^(bitLength(y) + 1), one bit
        // beyond y's own width always suffices. The previous 1 + ceil(log2(y)) gave no spare
        // bit when y was 1 or a power of two, so e.g. qsum(4, 4) wrapped around to 0.
        int m = bitLength(y) + 1;
        int n = Math.max(1, bitLength(x));
        Program program = new Program(m + n, new Step[0]);

        Step prep = new Step(new Gate[0]);
        for (int k = m - 1; k >= 0; --k) {
            if (((y >> k) & 1) == 1) {
                prep.addGate(new X(k));
            }
        }
        for (int k = n - 1; k >= 0; --k) {
            if (((x >> k) & 1) == 1) {
                prep.addGate(new X(m + k));
            }
        }
        program.addStep(prep);

        program.addStep(new Step(new Fourier(m, 0)));
        // Pair y-qubit i with x-qubit cr0 = m + k, where k = m - 1 - i - j. Pairs whose k lies
        // outside the x register are skipped.
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < m - i; ++j) {
                int cr0 = 2 * m - j - i - 1;
                if (cr0 >= m + n) {
                    continue;
                }
                program.addStep(new Step(new Cr(i, cr0, 2, 1 + j)));
            }
        }
        program.addStep(new Step(new InvFourier(m, 0)));

        Qubit[] qubits = qee.runProgram(program).getQubits();
        int answer = 0;
        for (int k = 0; k < m; ++k) {
            if (qubits[k].measure() == 1) {
                answer |= 1 << k;
            }
        }
        return answer;
    }

    /**
     * Grover search over {@code list} for an element on which {@code function} returns a
     * non-zero value.
     *
     * <p>The oracle flips the sign of every index {@code i < list.size()} whose element is
     * marked by {@code function}. Roughly {@code pi/4 * sqrt(N)} oracle and diffusion rounds
     * are applied. The element whose basis state ends with the largest {@code abssqr()} is
     * returned. The result is probabilistic and is not re-checked against {@code function}.
     *
     * @param list candidates to search, must not be empty
     * @param function marks the wanted element(s) by returning non-zero
     * @param <T> element type
     * @return the element selected by the measurement
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public static <T> T search(List<T> list, Function<T, Integer> function) {
        int size = list.size();
        if (size == 0) {
            throw new IllegalArgumentException("Cannot search an empty list");
        }
        if (size == 1) {
            // A single candidate would require a 0-qubit program; it is the only possible answer.
            return list.get(0);
        }
        // Integer ceil(log2(size)) avoids floating-point rounding in log(size) / log(2).
        int n = ceilLog2(size);
        int N = 1 << n;
        double cnt = Math.PI * Math.sqrt(N) / 4.0;

        Oracle oracle = createGroverOracle(n, list, function);
        oracle.setCaption("O");
        Oracle difOracle = new Oracle(createDiffMatrix(n));
        difOracle.setCaption("D");

        Program p = new Program(n, new Step[0]);
        Step s0 = new Step(new Gate[0]);
        for (int i = 0; i < n; ++i) {
            s0.addGate(new Hadamard(i));
        }
        p.addStep(s0);
        for (int i = 1; i < cnt; ++i) {
            Step s1 = new Step("Oracle " + i, new Gate[0]);
            s1.addGate(oracle);
            Step s2 = new Step("Diffusion " + i, new Gate[0]);
            s2.addGate(difOracle);
            Step s3 = new Step("Prob " + i, new Gate[0]);
            s3.addGate(new ProbabilitiesGate(0));
            p.addStep(s1);
            p.addStep(s2);
            p.addStep(s3);
        }
        System.out.println(" n = " + n + ", steps = " + cnt);

        Result res = qee.runProgram(p);
        Complex[] probability = res.getProbability();
        int winner = 0;
        double wv = 0.0;
        for (int k = 0; k < probability.length; ++k) {
            double weight = probability[k].abssqr();
            if (weight > wv) {
                wv = weight;
                winner = k;
            }
        }
        System.err.println("winner = " + winner + " with prob " + wv);
        return list.get(winner);
    }

    /**
     * Estimates the period of {@code a^x mod mod}.
     *
     * <p>Runs the period-measuring circuit up to two times, retrying only when the raw
     * measurement is 0. It then converts the measurement with
     * {@link Computations#fraction(int, int)}. NOTE(review): that helper presumably performs
     * a continued-fraction expansion; confirm.
     *
     * @param a base, expected to be in {@code [1, mod)}
     * @param mod the modulus, must be positive
     * @return the period candidate, or {@code -1} if every attempt measured 0
     * @throws IllegalArgumentException if {@code mod} is not positive
     */
    public static int findPeriod(int a, int mod) {
        final int maxTries = 2;
        int p = 0;
        // The original loop never advanced its try counter and could spin forever.
        for (int tries = 0; p == 0 && tries < maxTries; ++tries) {
            p = measurePeriod(a, mod);
            if (p == 0) {
                System.err.println("We measured a periodicity of 0, and have to start over.");
            }
        }
        if (p == 0) {
            return -1;
        }
        return Computations.fraction(p, mod);
    }

    /**
     * Tries to find a factor of {@code N} following Shor's algorithm.
     *
     * <p>Picks a random {@code a} in {@code [1, N)}. If {@code gcd(a, N) != 1}, that gcd is
     * returned. Otherwise the period {@code p} of {@code a^x mod N} is estimated. Odd
     * periods, failed measurements and {@code a^(p/2) == -1 (mod N)} trigger a fresh attempt
     * by recursion, and the number of retries is unbounded. The final
     * {@code gcd(N, a^(p/2) - 1)} is not re-checked for being a non-trivial factor.
     *
     * @param N the number to factor
     * @return a divisor of {@code N}
     */
    public static int qfactor(int N) {
        System.out.println("We need to factor " + N);
        int a = 1 + (int) ((double) (N - 1) * Math.random());
        System.out.println("Pick a random number a, a < N: " + a);
        int gcdan = Computations.gcd(N, a);
        System.out.println("calculate gcd(a, N):" + gcdan);
        if (gcdan != 1) {
            return gcdan;
        }
        int p = findPeriod(a, N);
        if (p == -1) {
            System.err.println("After too many tries with " + a + ", we need to pick a new random number.");
            return qfactor(N);
        }
        System.out.println("period of f = " + p);
        if (p % 2 == 1) {
            System.out.println("bummer, odd period, restart.");
            return qfactor(N);
        }
        // Exact a^(p/2) mod N. The former (int) Math.pow(a, p / 2) overflowed for larger
        // inputs, and reducing mod N leaves gcd(N, a^(p/2) - 1) unchanged.
        int half = modPow(a, p / 2, N);
        if ((half + 1) % N == 0) {
            System.out.println("bummer, m^p/2 + 1 = 0 mod N, restart");
            return qfactor(N);
        }
        int f2 = Math.floorMod(half - 1, N);
        return Computations.gcd(N, f2);
    }

    /**
     * Builds and runs the period-finding circuit for {@code a^x mod mod}.
     *
     * <p>Qubits {@code 0..offset-1} form the counting register and are put into superposition
     * with Hadamard gates. Qubit {@code length + 1 + offset} is flipped with an X gate. For
     * each counting qubit {@code i}, a {@link MulModulus} by {@code a^(2^i) mod mod} is applied
     * as a {@link ControlledBlockGate} controlled by that qubit. An inverse Fourier transform is
     * then applied to the counting register. NOTE(review): the exact qubit layout used by
     * {@code MulModulus} and {@code ControlledBlockGate} is defined outside this file.
     *
     * @param a base, expected non-negative
     * @param mod the modulus, must be positive
     * @return the counting register read out as an integer (qubit {@code i} has weight 2^i)
     * @throws IllegalArgumentException if {@code mod} is not positive
     */
    private static int measurePeriod(int a, int mod) {
        int length = ceilLog2(mod);
        int offset = length;
        Program p = new Program(2 * length + 3 + offset, new Step[0]);
        Step prep = new Step(new Gate[0]);
        for (int i = 0; i < offset; ++i) {
            prep.addGate(new Hadamard(i));
        }
        p.addStep(prep);
        p.addStep(new Step(new X(length + 1 + offset)));
        for (int i = length - 1; i >= 0; --i) {
            // a^(2^i) mod mod by repeated squaring instead of 2^i multiplications.
            int m = modPow(a, 1L << i, mod);
            MulModulus mul = new MulModulus(length, 2 * length, m, mod);
            p.addStep(new Step(new ControlledBlockGate(mul, offset, i)));
        }
        p.addStep(new Step(new InvFourier(offset, 0)));
        System.err.println("Calculate periodicity using " + qee);
        Result result = qee.runProgram(p);
        Qubit[] q = result.getQubits();
        int answer = 0;
        for (int i = 0; i < offset; ++i) {
            answer += q[i].measure() * (1 << i);
        }
        return answer;
    }

    /**
     * Builds the diagonal Grover oracle over {@code 2^n} basis states.
     *
     * <p>The diagonal entry {@code i} is {@code -1} when {@code i < list.size()} and
     * {@code function} returns non-zero for {@code list.get(i)}. Every other entry on the
     * diagonal is {@code 1}, and all off-diagonal entries are {@code 0}.
     */
    private static <T> Oracle createGroverOracle(int n, List<T> list, Function<T, Integer> function) {
        int N = 1 << n;
        int listSize = list.size();
        Complex[][] matrix = new Complex[N][N];
        for (int i = 0; i < N; ++i) {
            for (int j = 0; j < N; ++j) {
                matrix[i][j] = Complex.ZERO;
            }
            boolean marked = i < listSize && function.apply(list.get(i)) != 0;
            matrix[i][i] = marked ? Complex.ONE.mul(-1.0) : Complex.ONE;
        }
        return new Oracle(matrix);
    }

    /**
     * Builds the Grover diffusion matrix {@code H^n * D * H^n}.
     *
     * <p>{@code H^n} is the n-fold tensor power of the single-qubit Hadamard matrix, and
     * {@code D} is the identity with entry {@code [0][0]} negated.
     */
    private static Complex[][] createDiffMatrix(int n) {
        int N = 1 << n;
        Complex[][] h1 = new Hadamard(0).getMatrix();
        Complex[][] hn = h1;
        for (int i = 1; i < n; ++i) {
            hn = Complex.tensor(hn, h1);
        }
        Complex[][] reflect = new Complex[N][N];
        for (int i = 0; i < N; ++i) {
            for (int j = 0; j < N; ++j) {
                reflect[i][j] = i == j ? Complex.ONE : Complex.ZERO;
            }
        }
        reflect[0][0] = Complex.ONE.mul(-1.0);
        return Complex.mmul(Complex.mmul(hn, reflect), hn);
    }

    /** Number of bits needed to represent a non-negative {@code value}, with 0 for 0. */
    private static int bitLength(int value) {
        return Integer.SIZE - Integer.numberOfLeadingZeros(value);
    }

    /** Exact {@code ceil(log2(value))} for {@code value >= 1}, computed without floating point. */
    private static int ceilLog2(int value) {
        if (value < 1) {
            throw new IllegalArgumentException("value must be positive, got " + value);
        }
        return Integer.SIZE - Integer.numberOfLeadingZeros(value - 1);
    }

    /**
     * Computes {@code base^exponent mod mod} by square-and-multiply.
     *
     * <p>Intermediate products use {@code long} arithmetic so they cannot overflow.
     *
     * @param base the base; it is reduced into {@code [0, mod)} first
     * @param exponent non-negative exponent
     * @param mod positive modulus
     * @return the result in {@code [0, mod)}
     */
    private static int modPow(long base, long exponent, int mod) {
        long result = 1 % mod;
        long b = Math.floorMod(base, (long) mod);
        for (long e = exponent; e > 0; e >>= 1) {
            if ((e & 1L) == 1L) {
                result = result * b % mod;
            }
            b = b * b % mod;
        }
        return (int) result;
    }
}

