package net.covers1624.coffeegrinder.bytecode.transform.transformers.statement;

import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.matching.LdcMatching;
import net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformer;
import net.covers1624.coffeegrinder.type.AType;
import net.covers1624.coffeegrinder.type.TypeSystem;
import net.covers1624.coffeegrinder.util.None;

import java.util.LinkedList;
import java.util.List;

import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchStoreLocal;
import static net.covers1624.coffeegrinder.type.PrimitiveType.INT;

/**
 * Statement transformer that rewrites an {@code if}/{@code else} whose branch blocks each end in a
 * single store to the same {@link LocalVariable.VariableKind#STACK_SLOT stack-slot} variable into a
 * single store of a {@link Ternary} expression:
 * <pre>
 * if (cond) STORE(A, V1) else STORE(A, V2)  ->  STORE(A, cond ? V1 : V2)
 * </pre>
 * Created by covers1624 on 13/9/21.
 */
public class TernaryExpressions extends SimpleInsnVisitor<None> implements StatementTransformer {

    /** Context of the statement currently being transformed; assigned on every {@link #transform} call. */
    @SuppressWarnings ("NotNullFieldNotInitialized")
    private StatementTransformContext ctx;

    @Override
    public void transform(Instruction statement, StatementTransformContext ctx) {
        this.ctx = ctx;
        statement.accept(this);
    }

    @Override
    public None visitBlock(Block block, None unused) {
        return NONE; // Don't visit child blocks.
    }

    @Override
    public None visitIfInstruction(IfInstruction ifInsn, None unused) {
        // Visit both branches first. Branches that are Blocks are not descended into (visitBlock is a no-op).
        ifInsn.getTrueInsn().accept(this);
        ifInsn.getFalseInsn().accept(this);

        // Capture the condition before the rewrite: on success it is moved into the new Ternary,
        // and the if instruction itself is replaced.
        Instruction cond = ifInsn.getCondition();
        handleConditionalOperator(ifInsn);

        cond.accept(this);

        return NONE;
    }

    /**
     * Attempts to replace {@code insn} with a store of a ternary expression.
     *
     * @param insn The if instruction to inspect.
     * @return The created {@link Ternary} when the transform applied, otherwise {@code insn} unchanged.
     */
    private Instruction handleConditionalOperator(IfInstruction insn) {
        // if (cond) STORE (A, V1) else STORE (A, V2) -> STORE(A, if (cond) V1 else V2)
        if (!(insn.getTrueInsn() instanceof Block trueInsn)) return insn;

        if (!(insn.getFalseInsn() instanceof Block falseInsn)) return insn;

        // Inlining side effects are deferred here and only run once both branches have matched.
        List<Runnable> extraTransforms = new LinkedList<>();
        // The matched store must have no following sibling in its branch block.
        Store store1 = Inlining.matchWithPotentialInline(trueInsn.getFirstChildOrNull(), extraTransforms, ctx,
                e -> e.getNextSiblingOrNull() != null ? null : matchStoreLocal(e));

        if (store1 == null) return insn;

        LocalVariable variable = store1.getVariable();
        if (variable.getKind() != LocalVariable.VariableKind.STACK_SLOT) return insn;

        // Use the null-tolerant accessor, as for the true branch, so an empty else block
        // simply declines the transform instead of taking a different path.
        Store store2 = Inlining.matchWithPotentialInline(falseInsn.getFirstChildOrNull(), extraTransforms, ctx,
                e -> e.getNextSiblingOrNull() != null ? null : matchStoreLocal(e, variable));
        if (store2 == null) return insn;

        ctx.pushStep("Produce ternary for store");
        extraTransforms.forEach(Runnable::run);
        Ternary ternary = new Ternary(insn.getCondition(), store1.getValue(), store2.getValue()).withOffsets(insn);
        insn.replaceWith(new Store(new LocalReference(variable), ternary).withOffsets(insn));
        insertCastForConstantTernaryIfNecessary(ternary, variable.getType());
        ctx.popStep();
        return ternary;
    }

    /**
     * Wraps the true branch of an {@code int}-typed ternary in a cast when its result type is not
     * assignable to the destination type {@code type}.
     *
     * @param ternary The freshly created ternary.
     * @param type    The type of the variable being stored to.
     */
    private void insertCastForConstantTernaryIfNecessary(Ternary ternary, AType type) {
        // IF(...) LDC_INT a ELSE LDC_INT b
        // ->
        // IF(...) CAST.X(LDC_INT a) ELSE LDC_INT b

        if (ternary.getResultType() != INT) return;
        if (TypeSystem.isAssignableTo(ternary.getResultType(), type)) return;
        // Only int-constant leaves individually assignable to the target type are expected here.
        assert isAssignableIntegerConstant(ternary, type);

        // todo, nesting order?
        ctx.pushStep("Add cast");
        ternary.getTrueInsn().replaceWith(new Cast(ternary.getTrueInsn(), type));
        ctx.popStep();
    }

    /**
     * Checks whether {@code insn} is an int constant, or a (possibly nested) ternary whose leaves
     * are all int constants, each assignable to {@code expectedType}.
     */
    private static boolean isAssignableIntegerConstant(Instruction insn, AType expectedType) {
        if (insn instanceof Ternary ternary) {
            return isAssignableIntegerConstant(ternary.getTrueInsn(), expectedType) && isAssignableIntegerConstant(ternary.getFalseInsn(), expectedType);
        }

        LdcNumber ldc = LdcMatching.matchLdcInt(insn);
        if (ldc != null) {
            return TypeSystem.isAssignableTo(ldc.getResultType(), expectedType);
        }

        return false;
    }
}
