package net.covers1624.coffeegrinder.bytecode;

import net.covers1624.coffeegrinder.DecompilerSettings;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.transform.transformers.generics.GenericTransform;
import net.covers1624.coffeegrinder.type.*;
import net.covers1624.coffeegrinder.util.EnumBitSet;
import net.covers1624.coffeegrinder.util.None;
import net.covers1624.quack.collection.FastStream;

import java.util.List;

import static net.covers1624.coffeegrinder.bytecode.InstructionFlag.END_POINT_UNREACHABLE;

/**
 * Visitor that asserts structural and typing invariants over an instruction tree.
 * <p>
 * The checks are expressed as Java {@code assert} statements, so violations are only reported when
 * the JVM runs with assertions enabled. {@link #checkInvariants(Instruction)} additionally returns
 * immediately when {@link DecompilerSettings#ASSERTIONS_ENABLED} is {@code false}.
 * <p>
 * Created by covers1624 on 18/1/22.
 */
public class InvariantVisitor extends SimpleInsnVisitor<None> {

    /** Access flags a field must carry for a reference to it to be accepted as a switch case value. */
    private static final EnumBitSet<AccessFlag> LOAD_FLAG_MASK = EnumBitSet.of(AccessFlag.STATIC, AccessFlag.FINAL);

    /** Union of the integer constants {@code 0} and {@code 1}; types assignable to it are accepted where a boolean is expected. */
    private static final AType BOOLEAN_CONSTANTS = new IntegerConstantUnion(List.of(new IntegerConstantType(0), new IntegerConstantType(1)));

    /**
     * Runs the invariant checks of this visitor over {@code insn}.
     * Does nothing when {@link DecompilerSettings#ASSERTIONS_ENABLED} is {@code false}.
     *
     * @param insn The instruction to check.
     */
    public static void checkInvariants(Instruction insn) {
        if (!DecompilerSettings.ASSERTIONS_ENABLED) return;

        insn.accept(new InvariantVisitor());
    }

    @Override
    public None visitDefault(Instruction insn, None ctx) {
        super.visitDefault(insn, ctx);

        // Let every child slot verify its own invariants as well.
        for (InstructionSlot<?> slot = insn.firstChild; slot != null; slot = slot.nextSibling) {
            slot.checkInvariant();
        }

        return NONE;
    }

    @Override
    public None visitArrayElementReference(ArrayElementReference elemRef, None ctx) {
        assert elemRef.getArray().getResultType() instanceof ArrayType
                : "Array element reference on non-array type: " + elemRef.getArray().getResultType();
        assert isAssignableTo(elemRef.getIndex().getResultType(), PrimitiveType.INT)
                : "Array index is not assignable to int: " + elemRef.getIndex().getResultType();

        return super.visitArrayElementReference(elemRef, ctx);
    }

    @Override
    public None visitArrayLen(ArrayLen arrayLen, None ctx) {
        assert arrayLen.getArray().getResultType() instanceof ArrayType
                : "ArrayLen on non-array type: " + arrayLen.getArray().getResultType();

        return super.visitArrayLen(arrayLen, ctx);
    }

    @Override
    public None visitBlock(Block block, None ctx) {
        // A block inside a BlockContainer must be the target of at least one edge; any other block must have none.
        if (block.getParent() instanceof BlockContainer) {
            assert block.getIncomingEdgeCount() > 0 : "Unreachable block: " + block.getName();
        } else {
            assert block.getIncomingEdgeCount() == 0 : "Block should not have any edges: " + block.getName();
        }

        // Only the last instruction of a block may have an unreachable end point.
        for (Instruction insn : block.instructions) {
            assert !insn.hasFlag(END_POINT_UNREACHABLE) || insn == block.getLastChild()
                    : "Non-terminal instruction with unreachable end point in block: " + block.getName();
        }

        return super.visitBlock(block, ctx);
    }

    @Override
    public None visitBlockContainer(BlockContainer container, None ctx) {
        assert !container.blocks.isEmpty() : "BlockContainer has no blocks";
        assert container.getEntryPoint() == container.blocks.first() : "BlockContainer entry point is not its first block";
        assert !container.isConnected() || container.getEntryPoint().getIncomingEdgeCount() >= 1
                : "Entry point of a connected BlockContainer has no incoming edges";
        assert FastStream.of(container.blocks).allMatch(e -> e.hasFlag(END_POINT_UNREACHABLE))
                : "Every block in a BlockContainer must have an unreachable end point";

        return super.visitBlockContainer(container, ctx);
    }

    @Override
    public None visitBranch(Branch branch, None ctx) {
        assert branch.getTargetBlock().getParentOrNull() instanceof BlockContainer
                : "Branch target is not inside a BlockContainer: " + branch.getTargetBlock().getName();
        assert branch.isDescendantOf(branch.getTargetBlock().getParentOrNull())
                : "Branch is not nested inside the container of its target: " + branch.getTargetBlock().getName();

        return super.visitBranch(branch, ctx);
    }

    @Override
    public None visitCheckCast(Cast cast, None ctx) {
        // TODO, boxing checks
        AType argType = cast.getArgument().getResultType();
        if (cast.getType() instanceof ReferenceType targetType && argType instanceof ReferenceType sourceType) {
            assert isCastableTo(sourceType, targetType) : "Cannot cast " + sourceType + " to " + targetType;
        }
        assertRepresentable(cast, cast.getType());

        return super.visitCheckCast(cast, ctx);
    }

    /**
     * Asserts that {@code type} is representable at {@code scope}, as decided by
     * {@link GenericTransform#isRepresentable}.
     */
    private static void assertRepresentable(Instruction scope, AType type) {
        assert GenericTransform.isRepresentable(scope, type) : "Unable to represent: " + type;
    }

    /**
     * Asserts that a call argument is assignable to its declared parameter type.
     * {@link Nop} arguments are exempt from the check.
     *
     * @param parameter The declared parameter.
     * @param arg       The argument instruction supplied for it.
     * @param index     The argument's position, used only in the failure message.
     */
    private void assertArgumentAssignable(Parameter parameter, Instruction arg, int index) {
        assert arg instanceof Nop || isAssignableTo(arg.getResultType(), parameter.getType())
                : "Argument " + index + " of type " + arg.getResultType() + " is not assignable to parameter type " + parameter.getType();
    }

    @Override
    public None visitContinue(Continue cont, None ctx) {
        assert cont.getLoop().isConnected() : "Continue targets a loop that is not connected";

        return super.visitContinue(cont, ctx);
    }

    @Override
    public None visitFieldDecl(FieldDecl fieldDecl, None ctx) {
        assert fieldDecl.getParent() instanceof ClassDecl : "FieldDecl parent must be a ClassDecl";

        return super.visitFieldDecl(fieldDecl, ctx);
    }

    @Override
    public None visitIfInstruction(IfInstruction ifInsn, None ctx) {
        assert isAssignableTo(ifInsn.getCondition().getResultType(), PrimitiveType.BOOLEAN)
                : "If condition is not boolean: " + ifInsn.getCondition().getResultType();

        return super.visitIfInstruction(ifInsn, ctx);
    }

    @Override
    public None visitLocalVariable(LocalVariable localVariable, None ctx) {
        assert !localVariable.isDead() : "Visited a dead local variable: " + localVariable;

        if (localVariable.getKind() == LocalVariable.VariableKind.PARAMETER) {
            // Implicit parameters must not be referenced.
            ParameterVariable p = (ParameterVariable) localVariable;
            assert !p.isImplicit() || p.getReferenceCount() == 0 : "Implicit parameter is referenced: " + localVariable;
        } else {
            // Every non-parameter variable must be stored to at least once.
            assert localVariable.getStoreCount() > 0 : "Local variable is never stored: " + localVariable;
        }

        return super.visitLocalVariable(localVariable, ctx);
    }

    @Override
    public None visitInvoke(Invoke invoke, None ctx) {
        List<Parameter> parameters = invoke.getMethod().getParameters();
        assert parameters.size() == invoke.getArguments().size()
                : "Invoke has " + invoke.getArguments().size() + " arguments, expected " + parameters.size();
        for (int i = 0; i < parameters.size(); i++) {
            assertArgumentAssignable(parameters.get(i), invoke.getArguments().get(i), i);
        }

        if (invoke.explicitTypeArgs) {
            ((ParameterizedMethod) invoke.getMethod()).getTypeArguments().forEach(type -> assertRepresentable(invoke, type));
        }
        assert TypeSystem.isFullyDefined(invoke.getMethod()) : "Invoke target method is not fully defined";

        return super.visitInvoke(invoke, ctx);
    }

    @Override
    public None visitLeave(Leave leave, None ctx) {
        assert leave.isDescendantOf(leave.getTargetContainer()) : "Leave is not nested inside its target container";

        return super.visitLeave(leave, ctx);
    }

    @Override
    public None visitLocalReference(LocalReference localRef, None ctx) {
        assert localRef.variable.isConnected() : "Reference to a disconnected variable: " + localRef.variable;

        return super.visitLocalReference(localRef, ctx);
    }

    @Override
    public None visitNew(New newInsn, None ctx) {
        List<Parameter> parameters = newInsn.getMethod().getParameters();
        assert parameters.size() == newInsn.getArguments().size()
                : "New has " + newInsn.getArguments().size() + " arguments, expected " + parameters.size();
        for (int i = 0; i < parameters.size(); i++) {
            assertArgumentAssignable(parameters.get(i), newInsn.getArguments().get(i), i);
        }

        if (newInsn.explicitTypeArgs) {
            ((ParameterizedMethod) newInsn.getMethod()).getTypeArguments().forEach(type -> assertRepresentable(newInsn, type));
        }
        if (newInsn.explicitClassTypeArgs) {
            ((ParameterizedClass) newInsn.getResultType()).getTypeArguments().forEach(type -> assertRepresentable(newInsn, type));
        }

        assert TypeSystem.isFullyDefined(newInsn.getResultType()) : "New result type is not fully defined: " + newInsn.getResultType();
        assert TypeSystem.isFullyDefined(newInsn.getMethod()) : "New constructor method is not fully defined";

        return super.visitNew(newInsn, ctx);
    }

    @Override
    public None visitReturn(Return ret, None ctx) {
        assert isAssignableTo(ret.getValue().getResultType(), ret.getMethod().getReturnType())
                : "Return value of type " + ret.getValue().getResultType() + " is not assignable to " + ret.getMethod().getReturnType();
        assert ret.getMethod().getReturns().contains(ret) : "Return is not registered with its method";
        assert ret.firstAncestorOfType(MethodDecl.class) == ret.getMethod() : "Return's enclosing MethodDecl is not its method";

        return super.visitReturn(ret, ctx);
    }

    @Override
    public None visitStore(Store store, None ctx) {
        assert isAssignableTo(store.getValue().getResultType(), store.getReference().getType())
                : "Stored value of type " + store.getValue().getResultType() + " is not assignable to " + store.getReference().getType();

        return super.visitStore(store, ctx);
    }

    @Override
    public None visitSwitch(Switch switchInsn, None ctx) {
        // Called only for whatever validation getSwitchTable() performs internally; the result is deliberately discarded.
        switchInsn.getSwitchTable();
        assert switchInsn.getYields().isEmpty() == (switchInsn.getResultType() == PrimitiveType.VOID)
                : "Switch must have yields exactly when its result type is non-void, result type: " + switchInsn.getResultType();

        return super.visitSwitch(switchInsn, ctx);
    }

    @Override
    public None visitSwitchTable(SwitchTable switchTable, None ctx) {
        // A section containing a Nop value is the default section; at most one may exist.
        boolean hasDefault = false;
        for (SwitchTable.SwitchSection section : switchTable.sections) {
            if (section.values.anyMatch(e -> e instanceof Nop)) {
                assert !hasDefault : "Switch must only have one default block.";
                hasDefault = true;
            }
        }

        return super.visitSwitchTable(switchTable, ctx);
    }

    @Override
    public None visitSwitchSection(SwitchTable.SwitchSection switchSection, None ctx) {
        assert !switchSection.values.isEmpty() : "Switch section has no values";
        assert switchSection.values.allMatch(e -> e instanceof LdcInsn
                || e instanceof Nop
                // TODO, IField.isConstant()
                || (e instanceof FieldReference ref && ref.getField().getAccessFlags().toSet().containsAll(LOAD_FLAG_MASK.toSet())))
                : "Switch section values must be constants, Nop, or static final field references";

        return super.visitSwitchSection(switchSection, ctx);
    }

    @Override
    public None visitTernary(Ternary ternary, None ctx) {
        assert isAssignableTo(ternary.getCondition().getResultType(), PrimitiveType.BOOLEAN)
                : "Ternary condition is not boolean: " + ternary.getCondition().getResultType();

        return super.visitTernary(ternary, ctx);
    }

    @Override
    public None visitTryCatch(TryCatch tryCatch, None ctx) {
        assert !tryCatch.resources.isEmpty() || !tryCatch.handlers.isEmpty() || tryCatch.getFinallyBody() != null
                : "TryCatch must have at least one resource, handler, or finally body";

        return super.visitTryCatch(tryCatch, ctx);
    }

    @Override
    public None visitYield(Yield yield, None ctx) {
        assert yield.firstAncestorOfType(Switch.class) == yield.getSwitch() : "Yield's nearest enclosing Switch is not its target";

        return super.visitYield(yield, ctx);
    }

    /**
     * Assignability check used by these invariants.
     * <p>
     * Delegates to {@link TypeSystem#isAssignableTo}, except that when the target is
     * {@code boolean} it also accepts any type assignable to the union of the integer
     * constants {@code 0} and {@code 1}.
     *
     * @param from The source type.
     * @param to   The target type.
     * @return {@code true} if {@code from} is considered assignable to {@code to}.
     */
    public boolean isAssignableTo(AType from, AType to) {
        if (to == PrimitiveType.BOOLEAN) {
            return from == PrimitiveType.BOOLEAN || TypeSystem.isAssignableTo(from, BOOLEAN_CONSTANTS);
        }

        return TypeSystem.isAssignableTo(from, to);
    }

    /**
     * Castability check used by these invariants.
     * Delegates to {@link TypeSystem#isCastableTo} with its boolean flag set to {@code true}.
     *
     * @param from The source reference type.
     * @param to   The target reference type.
     * @return {@code true} if {@code from} may be cast to {@code to}.
     */
    public boolean isCastableTo(ReferenceType from, ReferenceType to) {
        return TypeSystem.isCastableTo(from, to, true);
    }
}
