package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import com.google.common.collect.ImmutableList;
import net.covers1624.coffeegrinder.bytecode.IndexedInstructionCollection;
import net.covers1624.coffeegrinder.bytecode.InsnOpcode;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.insns.Comparison.ComparisonKind;
import net.covers1624.coffeegrinder.bytecode.matching.ComparisonMatching;
import net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.MethodTransformer;
import net.covers1624.coffeegrinder.type.*;
import net.covers1624.coffeegrinder.util.None;
import net.covers1624.quack.collection.FastStream;

import java.util.List;

import static java.util.Objects.requireNonNull;
import static net.covers1624.coffeegrinder.bytecode.matching.LdcMatching.matchLdcBoolean;
import static net.covers1624.coffeegrinder.bytecode.matching.LdcMatching.matchLdcInt;
import static net.covers1624.coffeegrinder.type.PrimitiveType.*;
import static net.covers1624.coffeegrinder.type.TypeSystem.isIntegerConstant;

/**
 * Resolves integer-constant types into concrete types by propagating the type each consumer
 * expects down into its operands.
 * <p>
 * Expectations come from:
 * <ul>
 *     <li>store targets;</li>
 *     <li>invocation parameter types;</li>
 *     <li>the method return type;</li>
 *     <li>the opposite operand of a comparison;</li>
 *     <li>{@code if} conditions, which expect {@code boolean};</li>
 *     <li>the switched-on value of a {@link SwitchTable}.</li>
 * </ul>
 * <p>
 * When an expectation reaches its target:
 * <ul>
 *     <li>A local that still carries an integer-constant type is re-typed from it, defaulting to
 *     {@code int} when there is no expectation.</li>
 *     <li>Int {@link LdcNumber} literals that are consumed as {@code boolean} or {@code char} are
 *     rewritten to {@link LdcBoolean} / {@link LdcChar}.</li>
 *     <li>An equality comparison whose right-hand side is the boolean literal {@code false} is
 *     collapsed into a plain or negated boolean.</li>
 * </ul>
 * <p>
 * Created by covers1624 on 28/7/21.
 */
public class IntegerConstantInference implements MethodTransformer {

    /**
     * The integer-constant union of {@code 0} and {@code 1}.
     * <p>
     * Types assignable to this union are treated as possible booleans (see {@link #isBooleanConstant}).
     */
    public static final IntegerConstantUnion BOOLEAN_CONSTANTS = new IntegerConstantUnion(ImmutableList.of(
            new IntegerConstantType(0),
            new IntegerConstantType(1)
    ));

    @Override
    public void transform(MethodDecl function, MethodTransformContext ctx) {
        Visitor inference = new Visitor(ctx);
        function.accept(inference, VOID);
    }

    /**
     * Instruction visitor that carries the parent's type expectation for the visited instruction.
     * <p>
     * {@code VOID} means the parent imposes no expectation.
     */
    private static class Visitor extends SimpleInsnVisitor<AType> {

        private final MethodTransformContext transformCtx;

        public Visitor(MethodTransformContext transformCtx) {
            this.transformCtx = transformCtx;
        }

        @Override
        public None visitDefault(Instruction insn, AType expected) {
            // Instructions without a dedicated handler give their children no expectation.
            return super.visitDefault(insn, VOID);
        }

        @Override
        public None visitLoad(Load load, AType expected) {
            Reference ref = load.getReference();
            if (ref.opcode == InsnOpcode.LOCAL_REFERENCE) {
                LocalVariable local = load.getVariable();
                // Only locals still carrying a constant type are re-typed, and only when the
                // expectation is not itself another integer-constant type.
                if (!isIntegerConstant(local.getType()) || isIntegerConstant(expected)) {
                    return NONE;
                }
                AType resolved = expected == VOID ? INT : expected;
                local.setType(resolved);
                // Re-visit every instruction that writes this local so its value is typed against the new type.
                for (Reference use : local.getReferences()) {
                    if (use.isWrittenTo()) {
                        use.getParent().accept(this, VOID);
                    }
                }
                return NONE;
            }
            // Non-local loads (fields, array elements, ...) only need their reference subtree visited.
            ref.accept(this, VOID);
            return NONE;
        }

        @Override
        public None visitLdcNumber(LdcNumber ldcNumber, AType expected) {
            LdcNumber intLdc = matchLdcInt(ldcNumber);
            if (intLdc != null) {
                int value = intLdc.intValue();
                if (expected == BOOLEAN) {
                    assert value == 0 || value == 1;
                    intLdc.replaceWith(new LdcBoolean(value == 1));
                } else if (expected == CHAR) {
                    assert value >= Character.MIN_VALUE && value <= Character.MAX_VALUE;
                    intLdc.replaceWith(new LdcChar((char) value));
                }
            }
            return NONE;
        }

        @Override
        public None visitStore(Store store, AType expected) {
            Reference target = store.getReference();
            if (target.opcode == InsnOpcode.LOCAL_REFERENCE) {
                LocalVariable local = store.getVariable();
                // A synthetic local with no loads never receives an expectation from visitLoad,
                // so a concrete type is chosen for it here.
                boolean unresolvedUnreadSynthetic = local.isSynthetic()
                        && local.getLoadCount() == 0
                        && isIntegerConstant(local.getType());
                if (unresolvedUnreadSynthetic) {
                    AType fallback = isBooleanConstant(local.getType()) ? BOOLEAN : INT;
                    local.setType(fallback);
                }
            }
            // The stored value is expected to match the type of whatever it is stored into.
            store.getValue().accept(this, target.getType());
            store.getReference().accept(this, VOID);
            return NONE;
        }

        @Override
        public None visitInvoke(Invoke invoke, AType expected) {
            invoke.getTarget().accept(this, VOID);
            List<Parameter> params = invoke.getMethod().getParameters();
            IndexedInstructionCollection<Instruction> args = invoke.getArguments();
            // Each argument is expected to have its declared parameter type.
            for (int idx = 0; idx < args.size(); idx++) {
                Instruction arg = args.get(idx);
                arg.accept(this, params.get(idx).getType());
            }
            return NONE;
        }

        @Override
        public None visitReturn(Return ret, AType expected) {
            AType returnType = ret.getMethod().getReturnType();
            ret.getValue().accept(this, returnType);
            return NONE;
        }

        @Override
        public None visitComparison(Comparison comparison, AType expected) {
            AType lhsType = comparison.getLeft().getResultType();
            AType rhsType = comparison.getRight().getResultType();
            if (isBooleanConstant(lhsType) && isBooleanConstant(rhsType)) {
                lhsType = rhsType = BOOLEAN;
            } else if (isIntegerConstant(lhsType) && isIntegerConstant(rhsType)) {
                lhsType = rhsType = INT;
            }
            // Each operand is typed by what the opposite operand produces.
            comparison.getLeft().accept(this, rhsType);
            comparison.getRight().accept(this, lhsType);

            unwrapFalseComparison(comparison);
            return NONE;
        }

        /**
         * Collapses a comparison against a boolean {@code false} literal on its right-hand side.
         * <p>
         * {@code x == false} becomes {@code !x} and {@code x != false} becomes {@code x}. Any other
         * comparison is left untouched.
         */
        private void unwrapFalseComparison(Comparison comparison) {
            ComparisonKind kind = comparison.getKind();
            boolean isEquality = kind == ComparisonKind.EQUAL;
            if (!isEquality && kind != ComparisonKind.NOT_EQUAL) {
                return;
            }
            if (matchLdcBoolean(comparison.getRight(), false) == null) {
                return;
            }

            if (isEquality) {
                transformCtx.pushStep("Unwrap boolean logic to not.");
                comparison.replaceWith(new LogicNot(comparison.getLeft()));
            } else {
                transformCtx.pushStep("Unwrap boolean conversion.");
                comparison.replaceWith(comparison.getLeft());
            }
            transformCtx.popStep();
        }

        @Override
        public None visitIfInstruction(IfInstruction ifInsn, AType expected) {
            // Conditions must be boolean; the branch bodies carry no expectation.
            ifInsn.getCondition().accept(this, BOOLEAN);
            ifInsn.getTrueInsn().accept(this, VOID);
            ifInsn.getFalseInsn().accept(this, VOID);
            return NONE;
        }

        @Override
        public None visitBinary(Binary binary, AType expected) {
            // For logic operators the operands follow the parent's expectation rather than the
            // binary's own result type.
            // Presumably this is because such operators are valid on both boolean and int operands.
            AType operandType = binary.getOp().isLogic() ? expected : binary.getResultType();
            binary.getLeft().accept(this, operandType);
            binary.getRight().accept(this, operandType);
            return NONE;
        }

        @Override
        public None visitSwitchTable(SwitchTable switchTable, AType expected) {
            super.visitSwitchTable(switchTable, expected);
            // Case values take the type of the switched-on value.
            AType valueType = switchTable.getValue().getResultType();
            for (SwitchTable.SwitchSection section : switchTable.sections) {
                for (Instruction caseValue : section.values) {
                    caseValue.accept(this, valueType);
                }
            }
            return NONE;
        }
    }

    /**
     * Checks whether {@code type} is assignable to {@link #BOOLEAN_CONSTANTS}.
     * <p>
     * If it is, the type holds only the constants {@code 0}/{@code 1} and may stand for a boolean.
     *
     * @param type the type to test
     * @return {@code true} if {@code type} is assignable to the {@code {0, 1}} constant union
     */
    public static boolean isBooleanConstant(AType type) {
        return TypeSystem.isAssignableTo(type, BOOLEAN_CONSTANTS);
    }
}
