package net.covers1624.coffeegrinder.bytecode;

import net.covers1624.coffeegrinder.bytecode.insns.*;
import org.jetbrains.annotations.Nullable;

import java.util.function.Supplier;

import static java.util.Objects.requireNonNull;
import static net.covers1624.coffeegrinder.bytecode.ScopeVisitor.ScopeVisitRule.*;

/**
 * An {@link InsnVisitor} decorator that tracks which {@link LocalVariable}s are in scope while
 * walking an instruction tree. The per-instruction work is delegated to an {@code inner} visitor
 * (see {@link #visitDefault}).
 * <p>
 * A new lexical scope is opened for method declarations, block containers, switch sections,
 * for / for-each loops, try-catch statements and catch handlers. Variables are declared in the
 * innermost scope when:
 * <ul>
 *     <li>they are parameters of a visited {@link MethodDecl};</li>
 *     <li>a connected {@link LocalReference} that is not read from is visited for a variable
 *     that is not yet in scope (pattern operands of {@link InstanceOf} excluded).</li>
 * </ul>
 * <p>
 * Pattern variables introduced by {@link InstanceOf} are tracked as "when true" / "when false"
 * sets. These sets propagate through {@link LogicNot}, {@link LogicAnd} and {@link LogicOr}.
 * They become visible only in the branches, or in the code after an if / loop, where the
 * condition is known to have that value.
 * <p>
 * Children of conditional instructions are expected to be visited through
 * {@link #visit(Instruction, Object)} in sibling order. {@link ConditionalScopeHelper} enforces
 * this ordering.
 * <p>
 * Created by covers1624 on 11/24/25.
 */
public class ScopeVisitor<R, C> extends InsnVisitor<R, C> {

    /**
     * One entry of the scope stack, kept as an immutable linked list.
     * An entry whose {@code var} is {@code null} marks the start of a lexical scope (or is the root).
     * Every other entry is a variable declared in the nearest enclosing marked scope.
     */
    private record Scope(@Nullable Scope prev, @Nullable LocalVariable var) { }

    /** Top of the scope stack; starts as the root marker. */
    private Scope scope = new Scope(null, null);

    /** The visitor that performs the actual per-instruction work. */
    private final InsnVisitor<R, C> inner;

    public ScopeVisitor(InsnVisitor<R, C> inner) {
        this.inner = inner;
    }

    /**
     * @return {@code true} if at least one variable has been declared in the innermost scope.
     */
    public boolean currentScopeHasDeclarations() {
        // Declarations are pushed on top of the scope marker, so a non-null top means "declared here".
        return scope.var != null;
    }

    /**
     * How one child of a conditional instruction is visited, relative to the pattern variables
     * gathered from earlier children of the same instruction.
     */
    enum ScopeVisitRule {
        /** Visit normally; conditional variables are not involved. */
        NONE,
        /** Visit normally, then take the child's conditional variables as the current set. */
        COLLECT,
        /**
         * Visit with the current "when true" variables in scope. The set then becomes those
         * variables plus the child's "when true" variables, with no "when false" variables.
         */
        COLLECT_TRUE,
        /**
         * Visit with the current "when false" variables in scope. The set then becomes those
         * variables plus the child's "when false" variables, with no "when true" variables.
         */
        COLLECT_FALSE,
        /** Visit with the current "when true" variables in scope. */
        TRUE,
        /** Visit with the current "when false" variables in scope. */
        FALSE
    }

    /**
     * Per-instruction state used while visiting the children of a conditional instruction.
     * Each child is assigned a {@link ScopeVisitRule} by position. {@link #vars} accumulates
     * the pattern variables gathered from the children visited so far.
     */
    private class ConditionalScopeHelper {

        /** Conditional variables gathered from the children visited so far. */
        private ConditionalVars vars = ConditionalVars.NONE;
        private final ScopeVisitRule[] rules;
        /** The child expected next; used to verify children are visited in sibling order. */
        private @Nullable Instruction nextChild;
        private int childIndex;

        private ConditionalScopeHelper(Instruction insn, ScopeVisitRule... rules) {
            this.rules = rules;
            nextChild = insn.getFirstChildOrNull();
        }

        /**
         * Returns the rule for {@code insn}, which must be the next child in sibling order.
         *
         * @param insn The child about to be visited.
         * @return The rule assigned to that child's position.
         * @throws IllegalStateException If children are visited out of order, or there are more
         *                               children than rules.
         */
        private ScopeVisitRule getRule(Instruction insn) {
            if (insn != nextChild) {
                throw new IllegalStateException("Visited children in the wrong order?");
            }
            if (childIndex >= rules.length) {
                throw new IllegalStateException("No scope rule for child " + childIndex + ", only " + rules.length + " rules defined.");
            }
            nextChild = nextChild.getNextSiblingOrNull();
            return rules[childIndex++];
        }

        /** Runs {@code func} inside a fresh scope in which {@code declared} are already declared. */
        private R visitWith(@Nullable VarList declared, Supplier<R> func) {
            return withScope(() -> {
                declare(declared);
                return func.get();
            });
        }

        private R visitCollect(Supplier<R> func) {
            try {
                return func.get();
            } finally {
                // The child's visit method publishes its conditional variables through
                // retConditionalScope. visit(...) clears it only after this block has run.
                vars = retConditionalScope;
            }
        }

        private R visitCollectTrue(Supplier<R> func) {
            try {
                return visitTrue(func);
            } finally {
                vars = new ConditionalVars(
                        combine(vars.trueVars, retConditionalScope.trueVars),
                        null
                );
            }
        }

        private R visitCollectFalse(Supplier<R> func) {
            try {
                return visitFalse(func);
            } finally {
                vars = new ConditionalVars(
                        null,
                        combine(vars.falseVars, retConditionalScope.falseVars)
                );
            }
        }

        private R visitTrue(Supplier<R> func) {
            return visitWith(vars.trueVars, func);
        }

        private R visitFalse(Supplier<R> func) {
            return visitWith(vars.falseVars, func);
        }

        /** Returns a list holding the entries of {@code b} (in reverse order) prepended onto {@code a}. */
        private @Nullable VarList combine(@Nullable VarList a, @Nullable VarList b) {
            while (b != null) {
                a = new VarList(b.var, a);
                b = b.next;
            }
            return a;
        }
    }

    /** Immutable singly-linked list of variables. */
    private record VarList(LocalVariable var, @Nullable VarList next) { }

    /** Pattern variables that are in scope when a condition evaluated to true / false. */
    private record ConditionalVars(@Nullable VarList trueVars, @Nullable VarList falseVars) {

        public static final ConditionalVars NONE = new ConditionalVars(null, null);

        /** Swaps the true and false sets, as needed for a logical not. */
        public ConditionalVars inverted() {
            return new ConditionalVars(falseVars, trueVars);
        }
    }

    /**
     * Helper installed by a conditional instruction's visit method for its direct children.
     * {@link #visit} takes it for the child being visited and restores it afterward.
     */
    private @Nullable ConditionalScopeHelper conditionalScopeHelper = null;

    /**
     * Conditional variables produced by the instruction currently finishing its visit. A parent's
     * COLLECT* rule reads this before {@link #visit} resets it to {@link ConditionalVars#NONE}.
     */
    private ConditionalVars retConditionalScope = ConditionalVars.NONE;

    /**
     * Visits {@code insn}, applying the {@link ScopeVisitRule} assigned to it by the enclosing
     * conditional instruction, if any.
     * <p>
     * The pending helper is cleared during the visit so it only applies to direct children.
     * {@link #retConditionalScope} is reset afterward so conditional variables never leak to
     * later siblings.
     */
    public R visit(Instruction insn, C ctx) {
        var helper = conditionalScopeHelper;
        conditionalScopeHelper = null;
        try {
            if (helper == null) {
                return insn.accept(this, ctx);
            }
            Supplier<R> visitChild = () -> insn.accept(this, ctx);
            return switch (helper.getRule(insn)) {
                case NONE -> visitChild.get();
                case COLLECT -> helper.visitCollect(visitChild);
                case COLLECT_TRUE -> helper.visitCollectTrue(visitChild);
                case COLLECT_FALSE -> helper.visitCollectFalse(visitChild);
                case TRUE -> helper.visitTrue(visitChild);
                case FALSE -> helper.visitFalse(visitChild);
            };
        } finally {
            conditionalScopeHelper = helper;
            retConditionalScope = ConditionalVars.NONE;
        }
    }

    /** Delegates the instruction to the wrapped {@code inner} visitor. */
    @Override
    public R visitDefault(Instruction insn, C ctx) {
        return insn.accept(inner, ctx);
    }

    /** A type pattern's variable is in scope only where the instanceof evaluated to true. */
    @Override
    public R visitInstanceOf(InstanceOf instanceOf, C ctx) {
        try {
            return super.visitInstanceOf(instanceOf, ctx);
        } finally {
            if (instanceOf.getPattern() instanceof LocalReference ref) {
                retConditionalScope = new ConditionalVars(new VarList(ref.variable, null), null);
            }
        }
    }

    @Override
    public R visitIfInstruction(IfInstruction ifInsn, C ctx) {
        // Children: condition, true branch, false branch.
        var helper = new ConditionalScopeHelper(ifInsn, COLLECT, TRUE, FALSE);
        conditionalScopeHelper = helper;
        try {
            return super.visitIfInstruction(ifInsn, ctx);
        } finally {
            boolean trueFallthrough = !ifInsn.getTrueInsn().hasFlag(InstructionFlag.END_POINT_UNREACHABLE);
            boolean falseFallthrough = !ifInsn.getFalseInsn().hasFlag(InstructionFlag.END_POINT_UNREACHABLE);

            // If exactly one branch can complete normally, code after the if is only reached
            // through that branch, so that branch's pattern variables stay in the current scope.
            if (trueFallthrough && !falseFallthrough) {
                declare(helper.vars.trueVars);
            }
            if (!trueFallthrough && falseFallthrough) {
                declare(helper.vars.falseVars);
            }
        }
    }

    @Override
    public R visitLogicNot(LogicNot logicNot, C ctx) {
        var helper = new ConditionalScopeHelper(logicNot, COLLECT);
        conditionalScopeHelper = helper;
        try {
            return super.visitLogicNot(logicNot, ctx);
        } finally {
            retConditionalScope = helper.vars.inverted();
        }
    }

    @Override
    public R visitLogicAnd(LogicAnd logicAnd, C ctx) {
        // The right operand sees the left operand's "when true" variables.
        var helper = new ConditionalScopeHelper(logicAnd, COLLECT, COLLECT_TRUE);
        conditionalScopeHelper = helper;
        try {
            return super.visitLogicAnd(logicAnd, ctx);
        } finally {
            retConditionalScope = helper.vars;
        }
    }

    @Override
    public R visitLogicOr(LogicOr logicOr, C ctx) {
        // The right operand sees the left operand's "when false" variables.
        var helper = new ConditionalScopeHelper(logicOr, COLLECT, COLLECT_FALSE);
        conditionalScopeHelper = helper;
        try {
            return super.visitLogicOr(logicOr, ctx);
        } finally {
            retConditionalScope = helper.vars;
        }
    }

    @Override
    public R visitTernary(Ternary ternary, C ctx) {
        // Children: condition, true value, false value.
        conditionalScopeHelper = new ConditionalScopeHelper(ternary, COLLECT, TRUE, FALSE);
        return super.visitTernary(ternary, ctx);
    }

    @Override
    public R visitMethodDecl(MethodDecl methodDecl, C ctx) {
        return withScope(() -> {
            // declare all the params
            methodDecl.parameters.forEach(this::declare);
            return visitDefault(methodDecl, ctx);
        });
    }

    @Override
    public R visitBlockContainer(BlockContainer container, C ctx) {
        return visitWithScope(container, ctx);
    }

    @Override
    public R visitSwitchSection(SwitchTable.SwitchSection switchSection, C ctx) {
        return visitWithScope(switchSection, ctx);
    }

    @Override
    public R visitWhileLoop(WhileLoop whileLoop, C ctx) {
        // Children: condition, body.
        var helper = new ConditionalScopeHelper(whileLoop, COLLECT, TRUE);
        conditionalScopeHelper = helper;
        try {
            return super.visitWhileLoop(whileLoop, ctx);
        } finally {
            // NOTE(review): the condition's "when false" variables are declared after the loop
            // unconditionally. A reachable break targeting the loop is not checked here; confirm
            // that such loops never reach this point.
            declare(helper.vars.falseVars);
        }
    }

    @Override
    public R visitForLoop(ForLoop loop, C ctx) {
        // Children: initializer, condition, then two children visited with the condition's
        // "when true" variables in scope.
        var helper = new ConditionalScopeHelper(loop, NONE, COLLECT, TRUE, TRUE);
        conditionalScopeHelper = helper;
        try {
            return visitWithScope(loop, ctx);
        } finally {
            // Runs after the loop's own scope has been popped, so these land in the enclosing scope.
            // NOTE(review): a reachable break is not checked here, as with while loops.
            declare(helper.vars.falseVars);
        }
    }

    @Override
    public R visitDoWhileLoop(DoWhileLoop doWhileLoop, C ctx) {
        // Children: body, condition.
        var helper = new ConditionalScopeHelper(doWhileLoop, NONE, COLLECT);
        conditionalScopeHelper = helper;
        try {
            return super.visitDoWhileLoop(doWhileLoop, ctx);
        } finally {
            // NOTE(review): a reachable break is not checked here, as with while loops.
            declare(helper.vars.falseVars);
        }
    }

    @Override
    public R visitForEachLoop(ForEachLoop forEachLoop, C ctx) {
        return visitWithScope(forEachLoop, ctx);
    }

    @Override
    public R visitTryCatchHandler(TryCatch.TryCatchHandler catchHandler, C ctx) {
        return visitWithScope(catchHandler, ctx);
    }

    @Override
    public R visitTryCatch(TryCatch tryCatch, C ctx) {
        return visitWithScope(tryCatch, ctx);
    }

    /** Delegates {@code insn} to the inner visitor inside a fresh lexical scope. */
    private R visitWithScope(Instruction insn, C ctx) {
        return withScope(() -> visitDefault(insn, ctx));
    }

    /**
     * Runs {@code func} inside a fresh lexical scope. Everything declared in that scope is
     * discarded afterward, even if {@code func} throws.
     */
    private R withScope(Supplier<R> func) {
        scope = new Scope(scope, null);
        try {
            return func.get();
        } finally {
            // Pop every declaration down to our marker, then the marker itself.
            while (scope.var != null) {
                scope = requireNonNull(scope.prev);
            }
            scope = requireNonNull(scope.prev);
        }
    }

    /** Declares every variable in {@code vars} in the current scope. */
    private void declare(@Nullable VarList vars) {
        while (vars != null) {
            declare(vars.var);
            vars = vars.next;
        }
    }

    @Override
    public R visitLocalReference(LocalReference local, C ctx) {
        var r = super.visitLocalReference(local, ctx);

        // Pattern operands of instanceof are scoped by the conditional logic instead (see visitInstanceOf).
        if (local.isConnected() && !local.isReadFrom() && !isDeclared(local.variable) && !(local.getParent() instanceof InstanceOf)) {
            declare(local.variable);
        }

        return r;
    }

    /** Declares {@code var} in the current scope. */
    private void declare(LocalVariable var) {
        scope = new Scope(scope, var);
    }

    /**
     * @return {@code true} if {@code var} (compared by identity) is declared in the current
     * scope or any enclosing scope.
     */
    public boolean isDeclared(LocalVariable var) {
        var s = scope;
        while (s != null) {
            if (s.var == var) {
                return true;
            }

            s = s.prev;
        }

        return false;
    }

    /** @return {@code true} if a variable named {@code name} is currently in scope. */
    public boolean isDeclared(String name) {
        return getVariableInScope(name) != null;
    }

    /**
     * Finds the innermost in-scope variable named {@code name}.
     *
     * @return The variable, or {@code null} if none is in scope.
     */
    public @Nullable LocalVariable getVariableInScope(String name) {
        var s = scope;
        while (s != null) {
            if (s.var != null && s.var.getName().equals(name)) {
                return s.var;
            }

            s = s.prev;
        }

        return null;
    }
}
