package net.covers1624.coffeegrinder.bytecode.transform.transformers.statement;

import net.covers1624.coffeegrinder.bytecode.Instruction;
import net.covers1624.coffeegrinder.bytecode.SimpleInsnVisitor;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformContext;
import net.covers1624.coffeegrinder.bytecode.transform.StatementTransformer;
import net.covers1624.coffeegrinder.type.ClassType;
import net.covers1624.coffeegrinder.type.Method;
import net.covers1624.coffeegrinder.type.TypeResolver;
import net.covers1624.coffeegrinder.util.None;
import net.covers1624.quack.collection.ColUtils;
import net.covers1624.quack.collection.FastStream;
import org.jetbrains.annotations.Nullable;
import org.objectweb.asm.Type;

import java.util.*;

import static net.covers1624.coffeegrinder.bytecode.insns.Invoke.InvokeKind.VIRTUAL;
import static net.covers1624.coffeegrinder.bytecode.matching.InvokeMatching.matchInvokeDynamic;
import static net.covers1624.quack.util.SneakyUtils.notPossible;

/**
 * Rebuilds source-level string concatenation ({@code a + b + ...}) from the shapes compilers emit:
 * <ul>
 *     <li>{@code new StringBuilder().append(a).append(b)...toString()} chains.</li>
 *     <li>{@code invokedynamic} calls bootstrapped by {@code java/lang/invoke/StringConcatFactory}
 *     ({@code makeConcat} and {@code makeConcatWithConstants}).</li>
 * </ul>
 * A matched instruction is replaced with a left-folded tree of {@link Binary} {@code ADD} instructions.
 * <p>
 * Created by covers1624 on 16/1/22.
 */
public class StringConcat extends SimpleInsnVisitor<StatementTransformContext> implements StatementTransformer {

    // Recipe tag (char 1): the next dynamic (call-site) argument is inserted here.
    private static final char TAG_ARG = '\u0001';
    // Recipe tag (char 2): the next constant from the bootstrap arguments following the recipe is inserted here.
    private static final char TAG_CONST = '\u0002';

    private final ClassType string;
    private final @Nullable ClassType stringConcatFactory;
    private final ClassType stringBuilder;
    private final Method sb_toString;

    public StringConcat(TypeResolver typeResolver) {
        string = typeResolver.resolveClassDecl(TypeResolver.STRING_TYPE);
        stringConcatFactory = typeResolver.tryResolveClassDecl("java/lang/invoke/StringConcatFactory");
        stringBuilder = typeResolver.resolveClassDecl("java/lang/StringBuilder");
        sb_toString = Objects.requireNonNull(stringBuilder.resolveMethod("toString", Type.getMethodType("()Ljava/lang/String;")));
    }

    @Override
    public void transform(Instruction statement, StatementTransformContext ctx) {
        statement.accept(this, ctx);
    }

    @Override
    public None visitInvokeDynamic(InvokeDynamic indy, StatementTransformContext ctx) {
        super.visitInvokeDynamic(indy, ctx); // Depth first.
        // StringConcatFactory could not be resolved (it was added in Java 9), so no indy here can be a concat.
        if (stringConcatFactory == null) return NONE;

        // 2 variants of InvokeDynamic string concat.
        // The first simply concatenates every call-site argument in order.
        InvokeDynamic makeConcat = matchInvokeDynamic(indy, stringConcatFactory, "makeConcat");
        if (makeConcat != null) {
            makeStringConcat(makeConcat, makeConcat.arguments.toList(), ctx);
            return NONE;
        }
        // The second uses a 'recipe' string with special tag chars denoting where the args and constants go.
        InvokeDynamic makeConcatWithConstants = matchInvokeDynamic(indy, stringConcatFactory, "makeConcatWithConstants");
        if (makeConcatWithConstants != null) {
            List<Instruction> concatArgs = parseConcatWithConstants(makeConcatWithConstants);
            makeStringConcat(makeConcatWithConstants, concatArgs, ctx);
            return NONE;
        }

        return NONE;
    }

    @Override
    public None visitInvoke(Invoke invoke, StatementTransformContext ctx) {
        super.visitInvoke(invoke, ctx); // Depth first.

        // Match: INVOKE.VIRTUAL (...) java/lang/StringBuilder.toString()
        Invoke toString = InvokeMatching.matchInvoke(invoke, VIRTUAL, sb_toString);
        if (toString != null) {
            transformStringConcat(toString, ctx);
        }

        return NONE;
    }

    /**
     * Replaces a {@code new StringBuilder().append(..)...toString()} chain with a string concat.
     *
     * @param invoke The matched {@code StringBuilder.toString()} invoke.
     * @param ctx    The transform context.
     */
    private void transformStringConcat(Invoke invoke, StatementTransformContext ctx) {
        // Find the elements of the append chain.
        List<Instruction> appendChain = followAppendChain(invoke.getTarget());
        // If the result is empty, we didn't match, and if we found a single element,
        // this is not a pattern generated by the compiler, and should be ignored.
        if (appendChain.size() <= 1) return;

        makeStringConcat(invoke, appendChain, ctx);
    }

    /**
     * Walks an append chain backwards from {@code start} to the {@code new StringBuilder()} that begins it.
     *
     * @param start The receiver of the {@code toString()} call.
     * @return The appended values in source (left to right) order, or an empty list if the chain does not match.
     */
    private List<Instruction> followAppendChain(Instruction start) {
        // Continually match the following:
        // INVOKE.VIRTUAL (next) java/lang/StringBuilder.append()
        //
        // Until:
        // NEW java/lang/StringBuilder

        List<Instruction> reversedChain = new ArrayList<>();

        // Follow append chain all the way to the `new StringBuilder()` call.
        Instruction next = start;
        while (!isStringBuilderNew(next)) {
            Instruction arg = getAppendArgument(next);
            if (arg == null) return Collections.emptyList();
            // getAppendArgument only succeeds for an Invoke, so this cast is safe. Re-apply on the target.
            next = ((Invoke) next).getTarget();

            reversedChain.add(arg);
        }
        // We walked the chain from the last append to the first; flip it back into declaration order.
        Collections.reverse(reversedChain);
        return reversedChain;
    }

    /**
     * Extracts the single argument of a {@code StringBuilder.append(x)} call.
     *
     * @param insn The instruction to match.
     * @return The appended value, or {@code null} if {@code insn} is not a single-argument
     * {@code append} declared on {@code StringBuilder}.
     */
    @Nullable
    private Instruction getAppendArgument(Instruction insn) {
        // Match: INVOKE.VIRTUAL java/lang/StringBuilder.append(...)

        // Match any append method. (Object, int, char, etc)
        Invoke append = InvokeMatching.matchInvoke(insn, VIRTUAL, "append");
        if (append == null) return null;

        // Make sure it's declared on StringBuilder
        Method method = append.getMethod();
        if (method.getDeclaringClass().getDeclaration() != stringBuilder) return null;

        // Only single argument
        if (append.getArguments().size() != 1) return null;

        // Return first argument.
        return append.getArguments().first();
    }

    /**
     * @return {@code true} if {@code insn} is a call to the no-argument {@code StringBuilder} constructor.
     */
    private boolean isStringBuilderNew(Instruction insn) {
        // Match the void constructor of StringBuilder.

        if (!(insn instanceof New newInsn)) return false;

        Method method = newInsn.getMethod();
        return method.getName().equals("<init>")
               && newInsn.getArguments().isEmpty()
               && method.getDeclaringClass().getDeclaration() == stringBuilder;
    }

    /**
     * Expands a {@code makeConcatWithConstants} recipe into an ordered list of concat operands.
     * <p>
     * Literal recipe text becomes {@link LdcString} operands. Each {@link #TAG_ARG} consumes the next call-site
     * argument. Each {@link #TAG_CONST} consumes the next bootstrap argument after the recipe, which is literal
     * text and so is merged into the surrounding literal run.
     *
     * @param indy The matched {@code makeConcatWithConstants} invokedynamic.
     * @return The concat operands in source order.
     */
    private List<Instruction> parseConcatWithConstants(InvokeDynamic indy) {
        String recipe = (String) indy.bootstrapArguments[0];
        List<Instruction> concatArgs = new ArrayList<>();
        StringBuilder literal = new StringBuilder();
        int argIdx = 0;
        // Bootstrap argument 0 is the recipe itself; constants referenced by TAG_CONST follow it in order.
        int constIdx = 1;
        for (int i = 0; i < recipe.length(); i++) {
            char c = recipe.charAt(i);
            if (c == TAG_ARG) {
                if (!literal.isEmpty()) {
                    // Emit any literal chars we have collected before this argument.
                    concatArgs.add(new LdcString(string, literal.toString()));
                    literal.setLength(0);
                }
                concatArgs.add(indy.arguments.get(argIdx++));
            } else if (c == TAG_CONST) {
                // Compilers use a constant when literal text itself contains tag chars. It is plain text, so
                // fold it into the current literal instead of consuming a call-site argument.
                literal.append(indy.bootstrapArguments[constIdx++]);
            } else {
                literal.append(c);
            }
        }
        if (!literal.isEmpty()) {
            // Emit any remaining chars
            concatArgs.add(new LdcString(string, literal.toString()));
        }
        return concatArgs;
    }

    /**
     * Replaces {@code toReplace} with {@code concatArgs[0] + concatArgs[1] + ...}.
     *
     * @param toReplace  The instruction producing the concatenated string.
     * @param concatArgs The operands in source order. Not modified.
     * @param ctx        The transform context.
     */
    private void makeStringConcat(Instruction toReplace, List<Instruction> concatArgs, StatementTransformContext ctx) {
        // Copy, the caller's list may be immutable and must not be mutated.
        List<Instruction> operands = new ArrayList<>(concatArgs);
        // `+` is left-associative, so the first `+` is a string concat only if one of the first two operands is a
        // String; otherwise `i + j + "x"` would be a numeric addition. A lone operand also needs `"" + x` to keep
        // String conversion semantics (e.g. null -> "null"). Prepend an empty string to coerce the concat.
        boolean leadsWithString = operands.size() >= 2 && (isString(operands.get(0)) || isString(operands.get(1)));
        if (!leadsWithString) {
            operands.add(0, new LdcString(string, ""));
        }
        Instruction concat = FastStream.of(operands)
                .fold((a, b) -> new Binary(BinaryOp.ADD, a, b))
                .orElseThrow(notPossible());
        ctx.pushStep("Produce string concat");
        toReplace.replaceWith(concat.withOffsets(toReplace));
        ctx.popStep();
    }

    /**
     * @return {@code true} if {@code insn} produces a {@code java.lang.String}.
     */
    private boolean isString(Instruction insn) {
        return insn.getResultType().equals(string);
    }
}
