/*
 * Decompiled with CFR 0.152.
 */
package org.teavm.backend.wasm.transformation.gc;

import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.function.Supplier;
import org.teavm.backend.wasm.BaseWasmFunctionRepository;
import org.teavm.backend.wasm.WasmFunctionTypes;
import org.teavm.backend.wasm.generate.gc.classes.WasmGCClassInfoProvider;
import org.teavm.backend.wasm.model.WasmFunction;
import org.teavm.backend.wasm.model.WasmLocal;
import org.teavm.backend.wasm.model.WasmType;
import org.teavm.backend.wasm.model.expression.WasmBlock;
import org.teavm.backend.wasm.model.expression.WasmCall;
import org.teavm.backend.wasm.model.expression.WasmConditional;
import org.teavm.backend.wasm.model.expression.WasmExpression;
import org.teavm.backend.wasm.model.expression.WasmFloat32Constant;
import org.teavm.backend.wasm.model.expression.WasmFloat64Constant;
import org.teavm.backend.wasm.model.expression.WasmGetLocal;
import org.teavm.backend.wasm.model.expression.WasmInt32Constant;
import org.teavm.backend.wasm.model.expression.WasmInt64Constant;
import org.teavm.backend.wasm.model.expression.WasmNullConstant;
import org.teavm.backend.wasm.model.expression.WasmReturn;
import org.teavm.backend.wasm.model.expression.WasmSetLocal;
import org.teavm.backend.wasm.transformation.SuspensionPointCollector;
import org.teavm.backend.wasm.transformation.gc.CoroutineFunctions;
import org.teavm.backend.wasm.transformation.gc.CoroutineTransformationVisitor;
import org.teavm.backend.wasm.transformation.gc.NonOptimizableBlockCollector;

public class CoroutineTransformation {
    private static final String FIBER = "org.teavm.runtime.Fiber";
    private WasmFunctionTypes functionTypes;
    private WasmGCClassInfoProvider classInfoProvider;
    private CoroutineFunctions coroutineFunctions;

    public CoroutineTransformation(WasmFunctionTypes functionTypes, BaseWasmFunctionRepository functions, WasmGCClassInfoProvider classInfoProvider) {
        this.functionTypes = functionTypes;
        this.classInfoProvider = classInfoProvider;
        this.coroutineFunctions = new CoroutineFunctions(functions);
    }

    public void transform(final WasmFunction function) {
        int lastIndex;
        WasmExpression last;
        SuspensionPointCollector suspensionPoints = new SuspensionPointCollector();
        for (WasmExpression wasmExpression : function.getBody()) {
            wasmExpression.acceptVisitor(suspensionPoints);
        }
        NonOptimizableBlockCollector nonOptimizableBlockCollector = new NonOptimizableBlockCollector();
        nonOptimizableBlockCollector.suspendable = suspensionPoints;
        nonOptimizableBlockCollector.nonOptimizableBlocks = new LinkedHashSet<WasmBlock>();
        for (WasmExpression part : function.getBody()) {
            part.acceptVisitor(nonOptimizableBlockCollector);
        }
        CoroutineTransformationVisitor coroutineTransformationVisitor = new CoroutineTransformationVisitor(this.functionTypes, this.coroutineFunctions);
        coroutineTransformationVisitor.nonOptimizableBlocks = nonOptimizableBlockCollector.nonOptimizableBlocks;
        coroutineTransformationVisitor.collector = suspensionPoints;
        List<WasmLocal> locals = List.copyOf(function.getLocalVariables());
        WasmLocal stateLocal = new WasmLocal(WasmType.INT32, "_teavm_fiberState");
        WasmLocal fiberLocal = new WasmLocal(this.classInfoProvider.getClassInfo(FIBER).getType(), "_teavm_fiber");
        function.add(stateLocal);
        function.add(fiberLocal);
        coroutineTransformationVisitor.stateLocal = stateLocal;
        coroutineTransformationVisitor.fiberLocal = fiberLocal;
        coroutineTransformationVisitor.tmpValueLocalSupplier = new Supplier<WasmLocal>(){
            private WasmLocal local;

            @Override
            public WasmLocal get() {
                if (this.local == null) {
                    this.local = new WasmLocal(WasmType.Reference.FUNC, "_teavm_fiberTmp");
                    function.add(this.local);
                }
                return this.local;
            }
        };
        coroutineTransformationVisitor.init();
        for (WasmExpression part : function.getBody()) {
            part.acceptVisitor(coroutineTransformationVisitor);
        }
        if (!coroutineTransformationVisitor.resultList.isEmpty() && !(last = coroutineTransformationVisitor.resultList.get(lastIndex = coroutineTransformationVisitor.resultList.size() - 1)).isTerminating()) {
            coroutineTransformationVisitor.resultList.set(lastIndex, new WasmReturn(last));
        }
        function.getBody().clear();
        function.getBody().addAll(this.generatePrologue(fiberLocal, stateLocal, locals));
        coroutineTransformationVisitor.mainBlock.getBody().addAll(coroutineTransformationVisitor.resultList);
        function.getBody().add(coroutineTransformationVisitor.mainBlock);
        function.getBody().addAll(this.generateEpilogue(fiberLocal, stateLocal, locals, function.getType().getSingleReturnType()));
        coroutineTransformationVisitor.complete();
    }

    private List<WasmExpression> generatePrologue(WasmLocal fiberLocal, WasmLocal stateLocal, List<WasmLocal> locals) {
        ArrayList<WasmExpression> prologue = new ArrayList<WasmExpression>();
        prologue.add(new WasmSetLocal(fiberLocal, new WasmCall(this.coroutineFunctions.currentFiber())));
        WasmConditional restoreCond = new WasmConditional(new WasmCall(this.coroutineFunctions.isResuming(), new WasmGetLocal(fiberLocal)));
        prologue.add(restoreCond);
        restoreCond.getElseBlock().getBody().add(new WasmSetLocal(stateLocal, new WasmInt32Constant(0)));
        List<WasmExpression> restoreBody = restoreCond.getThenBlock().getBody();
        restoreBody.add(new WasmSetLocal(stateLocal, new WasmCall(this.coroutineFunctions.popInt(), new WasmGetLocal(fiberLocal))));
        for (int i = locals.size() - 1; i >= 0; --i) {
            WasmLocal local = locals.get(i);
            restoreBody.add(new WasmSetLocal(local, this.coroutineFunctions.restoreValue(local.getType(), fiberLocal)));
        }
        return prologue;
    }

    private List<WasmExpression> generateEpilogue(WasmLocal fiberLocal, WasmLocal stateLocal, List<WasmLocal> locals, WasmType returnType) {
        ArrayList<WasmExpression> epilogue = new ArrayList<WasmExpression>();
        for (WasmLocal local : locals) {
            epilogue.add(this.coroutineFunctions.saveValue(local.getType(), fiberLocal, new WasmGetLocal(local)));
        }
        epilogue.add(new WasmCall(this.coroutineFunctions.pushInt(), new WasmGetLocal(stateLocal), new WasmGetLocal(fiberLocal)));
        if (returnType != null) {
            if (returnType instanceof WasmType.Number) {
                switch (((WasmType.Number)returnType).number) {
                    case INT32: {
                        epilogue.add(new WasmInt32Constant(0));
                        break;
                    }
                    case INT64: {
                        epilogue.add(new WasmInt64Constant(0L));
                        break;
                    }
                    case FLOAT32: {
                        epilogue.add(new WasmFloat32Constant(0.0f));
                        break;
                    }
                    case FLOAT64: {
                        epilogue.add(new WasmFloat64Constant(0.0));
                    }
                }
            } else {
                epilogue.add(new WasmNullConstant((WasmType.Reference)returnType));
            }
        }
        return epilogue;
    }
}

