package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.tags.SwitchRecordPatternTag;
import net.covers1624.coffeegrinder.bytecode.matching.BranchLeaveMatching;
import net.covers1624.coffeegrinder.bytecode.transform.*;
import net.covers1624.coffeegrinder.bytecode.transform.transformers.statement.ExpressionTransforms;

import static net.covers1624.coffeegrinder.bytecode.matching.IfMatching.matchNopFalseIf;

/**
 * Created by covers1624 on 11/8/21.
 */
public class SwitchInlining implements BlockTransformer {

    @Override
    public void transform(Block block, BlockTransformContext ctx) {
        if (block.getFirstChild() instanceof SwitchTable switchTable) {
            BlockContainer switchContainer = (BlockContainer) block.getParent();
            ctx.pushStep("Inline switch sections " + switchContainer.getEntryPoint().getName());

            for (SwitchTable.SwitchSection section : switchTable.sections) {
                if (!(section.getBody() instanceof Branch branch)) continue;

                Block target = branch.getTargetBlock();
                if (target.getIncomingEdgeCount() != 1) continue;

                ctx.pushStep("Inline section");
                branch.replaceWith(target);

                if (section.values.onlyOrDefault() instanceof SwitchTable.SwitchPattern pattern) {
                    captureRecordPattern(pattern, target, ctx);
                    capturePatternGuard(pattern, target, ctx);
                }
                ctx.popStep();

                if (section.getNextSiblingOrNull() instanceof SwitchTable.SwitchSection nextSection
                    && target.instructions.count() > 1
                    && BranchLeaveMatching.compatibleExitInstruction(target.getLastChild(), nextSection.getBody())) {
                    ctx.pushStep("Create fallthrough");
                    target.getLastChild().remove();
                    ctx.popStep();
                    continue;
                }

                ConditionDetection.tryUnwrap(target, ctx);
            }

            ctx.popStep();
        }
    }

    private void captureRecordPattern(SwitchTable.SwitchPattern pattern, Block body, BlockTransformContext ctx) {
        // Capture record pattern, match the following:
        // case ... case$?[SwitchRecordPatternTag]: br L_Body;
        // L_Body: {
        //     IF (!(patternVar INSTANCEOF (record_pattern))) {
        //         SWITCH_GUARD
        //     }
        //     ... body
        // }
        // ->
        // case ... record_pattern: br L_Body;
        // L_Body: {
        //    ... body
        // }
        if (!(pattern.getPattern() instanceof LocalReference patternRef)) return;
        if (!(patternRef.variable.getTag() instanceof SwitchRecordPatternTag)) return;

        ctx.pushStep("Capture record pattern");
        var patternIf = (IfInstruction) body.getFirstChild();
        ConditionDetection.invertIf(patternIf, ctx);
        var inversion = (LogicNot) patternIf.getCondition();
        var patternInstanceof = (InstanceOf) inversion.getArgument();
        assert patternInstanceof.getPattern() instanceof RecordPattern;
        assert patternIf.getTrueInsn() instanceof Switch.SwitchGuard;
        assert patternIf.getFalseInsn() instanceof Nop;
        pattern.getPattern().replaceWith(patternInstanceof.getPattern());
        patternIf.remove();
        ctx.popStep();
    }

    private void capturePatternGuard(SwitchTable.SwitchPattern pattern, Block body, BlockTransformContext ctx) {
        // Capture switch guard, match the following:
        // case ... pattern: br L_Body;
        // L_Body: {
        //     IF (!guard_expression) {
        //         SWITCH_GUARD
        //     }
        //     ... body
        // }
        // ->
        // case ... pattern when guard_expression: br L_Body;
        // L_Body: {
        //    ... body
        // }
        if (!(matchNopFalseIf(body.getFirstChild()) instanceof IfInstruction ifInsn)) return;
        if (!(ifInsn.getTrueInsn() instanceof Switch.SwitchGuard)) {
            if (!(ifInsn.getNextSiblingOrNull() instanceof Switch.SwitchGuard)) return;
            ConditionDetection.invertIf(ifInsn, ctx);
        }

        ctx.pushStep("Capture pattern guard");
        pattern.getCondition().replaceWith(new LogicNot(ifInsn.getCondition()));
        ifInsn.remove();
        ExpressionTransforms.runOnExpression(pattern.getCondition(), ctx);
        ctx.popStep();
    }
}
