package net.covers1624.coffeegrinder.bytecode.transform.transformers;

import net.covers1624.coffeegrinder.bytecode.AccessFlag;
import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching;
import net.covers1624.coffeegrinder.bytecode.transform.ClassTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.ClassTransformer;
import net.covers1624.coffeegrinder.type.ClassType;
import net.covers1624.coffeegrinder.type.Field;
import net.covers1624.coffeegrinder.type.PrimitiveType;
import net.covers1624.coffeegrinder.type.TypeResolver;
import net.covers1624.coffeegrinder.util.None;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;

import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.function.Function;

import static java.util.Objects.requireNonNull;

/**
 * Replaces well-known numeric literals with the symbolic expressions that produce them.
 * <p>
 * For example, a literal equal to {@code Math.PI} becomes a load of {@code Math.PI}, the value of
 * {@code Math.PI / 6} becomes {@code Math.PI / 6.0}, and {@code 2147483647} becomes {@code Integer.MAX_VALUE}.
 * Each replacement is looked up by the exact value its expression computes, so only literals whose value
 * matches bit-for-bit (per the boxed types' {@code equals}) are rewritten.
 * <p>
 * Created by covers1624 on 27/9/22.
 */
public class NumericConstants extends SimpleInsnVisitor<None> implements ClassTransformer {

    private static final Type MATH = Type.getType(Math.class);
    private static final Type INTEGER = Type.getType(Integer.class);
    private static final Type LONG = Type.getType(Long.class);
    private static final Type FLOAT = Type.getType(Float.class);
    private static final Type DOUBLE = Type.getType(Double.class);

    /** {@link Math#PI} in double precision; keys for double PI expressions are computed from this. */
    private static final double PI_D = Math.PI;
    /** {@link Math#PI} narrowed to float; keys for float PI expressions are computed from this. */
    private static final float PI_F = (float) Math.PI;

    /**
     * Maps a boxed literal value ({@link Integer}, {@link Long}, {@link Float} or {@link Double}) to a factory
     * that builds the replacement expression for it.
     * <p>
     * Keys are matched with the boxed types' {@code equals}, so {@code NaN} matches {@code NaN} and
     * {@code 0.0} is distinct from {@code -0.0}. The map is populated once in the static initializer and only
     * read afterwards.
     */
    private static final Map<Object, Function<NumericConstants, Instruction>> REPLACEMENTS = new HashMap<>();

    static {
        // Pi
        REPLACEMENTS.put(PI_D, e -> loadField(e.piField));
        REPLACEMENTS.put(PI_F, e -> d2f(loadField(e.piField)));

        // E
        REPLACEMENTS.put(Math.E, e -> loadField(e.eField));

        // PI / 1-10, 1-10 / PI, PI * 1-10
        for (int i = 1; i <= 10; i++) {
            addAngle(i);
        }

        // Common angles: PI / ang, ang / PI, PI * ang
        for (int angle : new int[] { 30, 45, 60, 90, 180, 270, 360 }) {
            addAngle(angle);
        }

        // a * PI / b
        for (int a = 1; a <= 10; a++) {
            double aD = a;
            float aF = a;
            for (int b = 1; b <= 10; b++) {
                if (a == 1 && b == 1) continue; // Plain PI, already registered above.
                double bD = b;
                float bF = b;
                REPLACEMENTS.putIfAbsent(aD * PI_D / bD, e -> div(mul(ldc(aD), loadField(e.piField)), ldc(bD)));
                REPLACEMENTS.putIfAbsent(aF * PI_F / bF, e -> div(mul(ldc(aF), d2f(loadField(e.piField))), ldc(bF)));
            }
        }

        // n / d for a few denominators that produce repeating fractions.
        for (int n = 1; n <= 10; n++) {
            double nD = n;
            float nF = n;
            for (int d : new int[] { 3, 7, 9 }) {
                if (n % d == 0) continue; // Whole-number quotients are left as plain literals.
                double dD = d;
                float dF = d;
                // Both operands must share the key's precision, otherwise a mixed double/float
                // expression (e.g. 1.0 / 3.0F) would be emitted for a double literal.
                REPLACEMENTS.putIfAbsent(nD / dD, e -> div(ldc(nD), ldc(dD)));
                REPLACEMENTS.putIfAbsent(nF / dF, e -> div(ldc(nF), ldc(dF)));
            }
        }

        // Negated forms of everything registered so far (all of which are Float or Double keys).
        // Iterate a snapshot because REPLACEMENTS is modified in the loop, and capture each factory
        // eagerly instead of reading it back through a live map entry at replacement time.
        new HashMap<>(REPLACEMENTS).forEach((value, repl) -> {
            Function<NumericConstants, Instruction> negated = e -> negate(repl.apply(e));
            // if/else rather than a ternary: a ternary would promote the float negation to double.
            if (value instanceof Double d) {
                REPLACEMENTS.putIfAbsent(-d, negated);
            } else {
                REPLACEMENTS.putIfAbsent(-(Float) value, negated);
            }
        });

        // Min/Max
        REPLACEMENTS.put(Integer.MIN_VALUE, e -> loadField(e.intClass, "MIN_VALUE", Type.INT_TYPE));
        REPLACEMENTS.put(Integer.MAX_VALUE, e -> loadField(e.intClass, "MAX_VALUE", Type.INT_TYPE));
        REPLACEMENTS.put(Long.MIN_VALUE, e -> loadField(e.longClass, "MIN_VALUE", Type.LONG_TYPE));
        REPLACEMENTS.put(Long.MAX_VALUE, e -> loadField(e.longClass, "MAX_VALUE", Type.LONG_TYPE));
        REPLACEMENTS.put(Float.MIN_VALUE, e -> loadField(e.floatClass, "MIN_VALUE", Type.FLOAT_TYPE));
        REPLACEMENTS.put(Float.MAX_VALUE, e -> loadField(e.floatClass, "MAX_VALUE", Type.FLOAT_TYPE));
        REPLACEMENTS.put(Double.MIN_VALUE, e -> loadField(e.doubleClass, "MIN_VALUE", Type.DOUBLE_TYPE));
        REPLACEMENTS.put(Double.MAX_VALUE, e -> loadField(e.doubleClass, "MAX_VALUE", Type.DOUBLE_TYPE));

        // Float specials
        REPLACEMENTS.put(Float.POSITIVE_INFINITY, e -> loadField(e.floatClass, "POSITIVE_INFINITY", Type.FLOAT_TYPE));
        REPLACEMENTS.put(Float.NEGATIVE_INFINITY, e -> loadField(e.floatClass, "NEGATIVE_INFINITY", Type.FLOAT_TYPE));
        REPLACEMENTS.put(Float.MIN_NORMAL, e -> loadField(e.floatClass, "MIN_NORMAL", Type.FLOAT_TYPE));
        REPLACEMENTS.put(Float.NaN, e -> loadField(e.floatClass, "NaN", Type.FLOAT_TYPE));

        // Double specials
        REPLACEMENTS.put(Double.POSITIVE_INFINITY, e -> loadField(e.doubleClass, "POSITIVE_INFINITY", Type.DOUBLE_TYPE));
        REPLACEMENTS.put(Double.NEGATIVE_INFINITY, e -> loadField(e.doubleClass, "NEGATIVE_INFINITY", Type.DOUBLE_TYPE));
        REPLACEMENTS.put(Double.MIN_NORMAL, e -> loadField(e.doubleClass, "MIN_NORMAL", Type.DOUBLE_TYPE));
        REPLACEMENTS.put(Double.NaN, e -> loadField(e.doubleClass, "NaN", Type.DOUBLE_TYPE));
    }

    /**
     * Registers {@code PI / angle}, {@code angle / PI} and {@code angle * PI}, in both double and float
     * precision, without overriding any entry already present for the same value.
     *
     * @param angle The integer operand to combine with PI.
     */
    private static void addAngle(int angle) {
        double aD = angle;
        float aF = angle;
        REPLACEMENTS.putIfAbsent(PI_D / aD, e -> div(loadField(e.piField), ldc(aD)));
        REPLACEMENTS.putIfAbsent(aD / PI_D, e -> div(ldc(aD), loadField(e.piField)));
        REPLACEMENTS.putIfAbsent(PI_F / aF, e -> div(d2f(loadField(e.piField)), ldc(aF)));
        REPLACEMENTS.putIfAbsent(aF / PI_F, e -> div(ldc(aF), d2f(loadField(e.piField))));

        REPLACEMENTS.putIfAbsent(aD * PI_D, e -> mul(ldc(aD), loadField(e.piField)));
        REPLACEMENTS.putIfAbsent(aF * PI_F, e -> mul(ldc(aF), d2f(loadField(e.piField))));
    }

    // Boxed-type classes whose MIN/MAX/special constant fields are loaded by replacements.
    private final ClassType intClass;
    private final ClassType longClass;
    private final ClassType floatClass;
    private final ClassType doubleClass;

    // Math.PI and Math.E, resolved once per transformer instance.
    private final Field piField;
    private final Field eField;

    /**
     * Resolves the JDK classes and fields referenced by the replacement expressions.
     *
     * @param typeResolver Resolver used to look up {@code Math}, {@code Integer}, {@code Long},
     *                     {@code Float} and {@code Double}.
     * @throws NullPointerException If any of those classes, or {@code Math.PI}/{@code Math.E}, cannot be resolved.
     */
    public NumericConstants(TypeResolver typeResolver) {
        ClassType mathClass = requireNonNull(typeResolver.resolveClass(MATH));
        intClass = requireNonNull(typeResolver.resolveClass(INTEGER));
        longClass = requireNonNull(typeResolver.resolveClass(LONG));
        floatClass = requireNonNull(typeResolver.resolveClass(FLOAT));
        doubleClass = requireNonNull(typeResolver.resolveClass(DOUBLE));

        piField = requireNonNull(mathClass.resolveField("PI", Type.DOUBLE_TYPE));
        eField = requireNonNull(mathClass.resolveField("E", Type.DOUBLE_TYPE));
    }

    /**
     * Visits the instruction tree of a top-level class; any other class declaration is ignored.
     */
    @Override
    public void transform(ClassDecl cInsn, ClassTransformContext ctx) {
        if (cInsn.getClazz().getDeclType() != ClassType.DeclType.TOP_LEVEL) return;

        cInsn.accept(this);
    }

    /**
     * Replaces {@code ldc} with its symbolic form when one is registered for its raw value.
     * <p>
     * Directly inside a switch section, only a bare static final field reference is accepted; compound
     * replacements (arithmetic, casts) are rejected and the literal is left untouched.
     */
    @Override
    public None visitLdcNumber(LdcNumber ldc, None ctx) {
        Instruction replacement = getReplacement(ldc.getRawValue());
        if (replacement == null) return NONE;

        // Switches have some special rules..
        if (ldc.getParent() instanceof SwitchTable.SwitchSection) {
            FieldReference ref = LoadStoreMatching.matchLoadFieldRef(replacement);
            if (ref == null) return NONE;
            if (!ref.getField().isStatic() || !ref.getField().getAccessFlags().get(AccessFlag.FINAL)) return NONE;
            replacement = ref;
        }
        ldc.replaceWith(replacement);
        return NONE;
    }

    /**
     * Builds a fresh replacement expression for the given literal value.
     *
     * @param number The boxed literal value.
     * @return A new instruction tree equivalent to {@code number}, or {@code null} if none is registered.
     */
    public @Nullable Instruction getReplacement(Number number) {
        Function<NumericConstants, Instruction> repl = REPLACEMENTS.get(number);
        if (repl == null) return null;

        return repl.apply(this);
    }

    /**
     * Builds a load of the static field {@code name} with descriptor {@code desc} on {@code clazz}.
     *
     * @throws NullPointerException If the field cannot be resolved.
     */
    private static Instruction loadField(ClassType clazz, String name, Type desc) {
        Field field = clazz.resolveField(name, desc);
        return loadField(requireNonNull(field, () -> "Failed to resolve constant field " + name + " " + desc.getDescriptor()));
    }

    /**
     * Builds a load of the given static final field.
     *
     * @throws NullPointerException If {@code field} is {@code null}.
     */
    private static Instruction loadField(@Nullable Field field) {
        requireNonNull(field, "Failed to resolve constant field.");
        assert field.isStatic() && field.getAccessFlags().get(AccessFlag.FINAL);
        return new Load(new FieldReference(field));
    }

    /**
     * Negates an expression built by this class.
     * <p>
     * A {@link Binary} has its left operand negated recursively and is returned after in-place mutation;
     * for the {@code *} and {@code /} trees built here, that negates the whole result. A {@link Cast} has
     * its argument negated. Anything else becomes {@code 0 - insn} in the instruction's precision.
     */
    private static Instruction negate(Instruction insn) {
        if (insn instanceof Binary binary) {
            binary.setLeft(negate(binary.getLeft()));
            return binary;
        }
        if (insn instanceof Cast cast) {
            // We should only ever have Double -> Float casts here.
            assert cast.getType() == PrimitiveType.FLOAT;
            assert cast.getArgument().getResultType() == PrimitiveType.DOUBLE;

            return d2f(negate(cast.getArgument()));
        }
        LdcNumber number = insn.getResultType() == PrimitiveType.FLOAT ? ldc(0F) : ldc(0D);
        return new Binary(BinaryOp.SUB, number, insn);
    }

    // @formatter:off
    private static Instruction d2f(Instruction insn) { return new Cast(insn, PrimitiveType.FLOAT); }
    private static Instruction div(Instruction left, Instruction right) { return new Binary(BinaryOp.DIV, left, right); }
    private static Instruction mul(Instruction left, Instruction right) { return new Binary(BinaryOp.MUL, left, right); }
    private static LdcNumber ldc(Number n) { return new LdcNumber(n); }
    // @formatter:on
}
