package net.covers1624.coffeegrinder.bytecode.transform.transformers.statement;

import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.Comparison.ComparisonKind;
import net.covers1624.coffeegrinder.bytecode.matching.LdcMatching;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformer;
import net.covers1624.coffeegrinder.bytecode.transform.TransformContextBase;
import net.covers1624.coffeegrinder.type.AType;
import net.covers1624.coffeegrinder.type.PrimitiveType;
import net.covers1624.coffeegrinder.util.None;

import static net.covers1624.coffeegrinder.bytecode.matching.LdcMatching.matchLdcBoolean;
import static net.covers1624.coffeegrinder.bytecode.matching.LogicMatching.matchLogicNot;

/**
 * Created by covers1624 on 20/7/21.
 */
public class ExpressionTransforms extends SimpleInsnVisitor<TransformContextBase> implements StatementTransformer {

    private static final ExpressionTransforms INSTANCE = new ExpressionTransforms();

    public static void runOnExpression(Instruction statement, TransformContextBase ctx) {
        statement.accept(INSTANCE, ctx);
    }

    @Override
    public void transform(Instruction statement, StatementTransformContext ctx) {
        statement.accept(this, ctx);
    }

    @Override
    public None visitBlock(Block block, TransformContextBase ctx) {
        return NONE; // Don't visit child blocks.
    }

    @Override
    public None visitCheckCast(Cast cast, TransformContextBase ctx) {
        // sometimes javac will duplicate a cast instruction perfectly
        if (cast.getArgument() instanceof Cast innerCast) {
            if (innerCast.getType() == cast.getType()) {
                ctx.pushStep("Remove duplicate cast");
                cast = cast.replaceWith(innerCast);
                ctx.popStep();
            }
        }
        return super.visitCheckCast(cast, ctx);
    }

    @Override
    public None visitComparison(Comparison comparison, TransformContextBase ctx) {
        // COMPARISON.op (COMPARISON.L/G *(a, b), 0)
        if (comparison.getLeft() instanceof Compare) {
            Compare left = (Compare) comparison.getLeft();
            assert LdcMatching.matchLdcInt(comparison.getRight(), 0) != null;

            ctx.pushStep("Unwrap NaN comparison");
            comparison = comparison.replaceWith(new Comparison(comparison.getKind(), left.getLeft(), left.getRight()));
            if (!isValidNaNComparison(left.getKind(), comparison.getKind())) {
                comparison.setKind(comparison.getKind().negate());
                comparison.replaceWith(new LogicNot(comparison));
            }
            ctx.popStep();
            comparison.accept(this, ctx);
            return NONE;
        }

        return super.visitComparison(comparison, ctx);
    }

    private static boolean isValidNaNComparison(Compare.Kind nanKind, ComparisonKind kind) {
        if (nanKind == Compare.Kind.NAN_L) return kind != ComparisonKind.LESS_THAN && kind != ComparisonKind.LESS_THAN_EQUAL;
        if (nanKind == Compare.Kind.NAN_G) return kind != ComparisonKind.GREATER_THAN && kind != ComparisonKind.GREATER_THAN_EQUAL;
        return true;
    }

    @Override
    public None visitTernary(Ternary ternary, TransformContextBase ctx) {
        ternary.getTrueInsn().accept(this, ctx);
        ternary.getFalseInsn().accept(this, ctx);

        if (unwrapBooleanOperator(ternary, ctx)) {
            return NONE;
        }

        ternary.getCondition().accept(this, ctx);

        return NONE;
    }

    @Override
    public None visitLogicNot(LogicNot logicNot, TransformContextBase ctx) {
        Instruction arg = logicNot.getArgument();
        // If we are visiting a logic not, that has a logic not as a child, fully unwrap the expression.
        LogicNot nestedNot = matchLogicNot(arg);
        if (nestedNot != null) {
            ctx.pushStep("Unwrap not chain");
            Instruction repl = logicNot.replaceWith(nestedNot.getArgument());
            ctx.popStep();
            repl.accept(this, ctx);
            return NONE;
        }
        if (arg instanceof Comparison) {
            Comparison comp = (Comparison) arg;
            AType leftType = comp.getLeft().getResultType();
            AType rightType = comp.getRight().getResultType();
            // We can only push negations into float comparisons prior to NaN unwrapping.
            // Otherwise, we change NaN comparison result semantics.
            boolean isFloatComparison = leftType == PrimitiveType.FLOAT || leftType == PrimitiveType.DOUBLE
                                        || rightType == PrimitiveType.FLOAT || rightType == PrimitiveType.DOUBLE;
            if (!isFloatComparison || comp.getKind() == ComparisonKind.EQUAL || comp.getKind() == ComparisonKind.NOT_EQUAL) {
                ctx.pushStep("Push negation into comparison");
                comp.setKind(comp.getKind().negate());
                comp.setOffsets(logicNot);
                logicNot.replaceWith(comp);
                ctx.popStep();
            }
            comp.accept(this, ctx);
            return NONE;
        }
        if (arg instanceof LogicAnd and) {
            ctx.pushStep("Push inversion into logic and");
            LogicOr or = logicNot.replaceWith(new LogicOr(new LogicNot(and.getLeft()), new LogicNot(and.getRight())));
            or.accept(this, ctx);
            ctx.popStep();
            return NONE;
        }
        if (arg instanceof LogicOr or) {
            ctx.pushStep("Push inversion into logic or");
            LogicAnd and = logicNot.replaceWith(new LogicAnd(new LogicNot(or.getLeft()), new LogicNot(or.getRight())));
            and.accept(this, ctx);
            ctx.popStep();
            return NONE;
        }
        arg.accept(this, ctx);
        return NONE;
    }

    public boolean unwrapBooleanOperator(Ternary ternary, TransformContextBase ctx) {
        if (ternary.getParent() instanceof Block) return false;

        // if cond ? true : false -> cond
        if (matchLdcBoolean(ternary.getTrueInsn(), true) != null && matchLdcBoolean(ternary.getFalseInsn(), false) != null) {
            ctx.pushStep("Unwrap if cond ? true : false");
            ternary.replaceWith(ternary.getCondition())
                    .accept(this, ctx);
            ctx.popStep();
            return true;
        }

        // if cond ? false : true -> !cond
        if (matchLdcBoolean(ternary.getFalseInsn(), false) != null && matchLdcBoolean(ternary.getTrueInsn(), true) != null) {
            ctx.pushStep("Unwrap if cond ? false : true");
            Instruction replacement = new LogicNot(ternary.getCondition());
            ternary.replaceWith(replacement);
            ctx.popStep();
            replacement.accept(this, ctx);
            return true;
        }
        return false;
    }
}
