package net.covers1624.coffeegrinder.bytecode;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import net.covers1624.coffeegrinder.bytecode.insns.*;
import net.covers1624.coffeegrinder.type.*;
import net.covers1624.coffeegrinder.util.None;
import net.covers1624.quack.collection.ColUtils;
import org.jetbrains.annotations.Nullable;

import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

import static net.covers1624.coffeegrinder.bytecode.matching.LoadStoreMatching.matchLocalRef;
import static net.covers1624.coffeegrinder.type.TypeSystem.isIntegerConstant;

/**
 * Builds a control-flow graph of local variable stores and uses it to resolve which
 * {@link LocalVariable} instance each local load refers to.
 * <p>
 * Nodes are created for block entries, stores and named markers. Each node links back to its
 * predecessors. Whenever two different variable instances for the same slot meet, they are merged
 * by {@link #makeEquivalent}. This happens at a block's recorded entry stack, or when a node's
 * memoized read differs from what a newly linked predecessor sees. The second instance is
 * redirected to the first, the first's type is widened if it is synthetic, and every reference
 * to the second is rewritten.
 * <p>
 * Created by covers1624 on 13/10/21.
 */
public class VariableLivenessGraph {

    /** Number of local slots; sizes every {@link CFNode}'s memo table. */
    public final int maxLocals;
    /** First local index as supplied by the caller. Not read within this class. */
    public final int firstLocalIndex;

    /**
     * Maps a variable that was merged away to the variable that replaced it.
     * Chains are collapsed lazily by {@link #getEquivalentVar}.
     */
    private final Map<LocalVariable, LocalVariable> equivalentLocals = new HashMap<>();

    /** Operand-stack variables recorded by the first edge into each block. */
    private final Map<Block, List<LocalVariable>> branchStacks = new HashMap<>();
    /** Source of unique suffixes for generated node ids. */
    public final AtomicInteger nodeCounter = new AtomicInteger();
    /** Entry node of every block reached by an edge or registered as an exception handler. */
    private final Map<Block, CFNode> blockNodeMap = new HashMap<>();
    /** Every load recorded by {@link #readLocal(int)}, in recording order (append-only). */
    private final List<LVLoad> lvLoads = new ArrayList<>();
    /** Every node created, in creation order. */
    private final List<CFNode> allNodes = new ArrayList<>();

    /** Node that subsequent stores, loads and outgoing edges attach to. */
    private CFNode currentNode;

    /**
     * Creates the graph and its entry node.
     *
     * @param maxLocals       number of local slots.
     * @param firstLocalIndex stored as {@link #firstLocalIndex}.
     * @param start           entry block; registered with an empty entry stack and made current.
     */
    public VariableLivenessGraph(int maxLocals, int firstLocalIndex, Block start) {
        this.maxLocals = maxLocals;
        this.firstLocalIndex = firstLocalIndex;
        currentNode = markBlockStack(start, Collections.emptyList());
    }

    /** @return an unmodifiable view of every node created so far. */
    public List<CFNode> getAllNodes() {
        return Collections.unmodifiableList(allNodes);
    }

    /** @return an unmodifiable view of every load recorded by {@link #readLocal(int)}. */
    public List<LVLoad> getLVLoads() {
        return Collections.unmodifiableList(lvLoads);
    }

    /**
     * Registers {@code handler} as an exception handler whose entry stack holds {@code var}.
     * Its node id is tagged with {@code "eh"}.
     */
    public void addExceptionHandler(Block handler, LocalVariable var) {
        markBlockStack(handler, Collections.singletonList(var)).addInfo("eh");
    }

    /**
     * Resolves the variable visible in slot {@code index} at the current node and records the load.
     *
     * @param index local slot to read.
     * @return a new reference to the (equivalence-resolved) variable.
     */
    public LocalReference readLocal(int index) {
        LocalVariable v = getEquivalentVar(currentNode.readLocal(index));
        lvLoads.add(new LVLoad(v, currentNode));
        return new LocalReference(v);
    }

    /** Appends a node for {@code store} after the current node and makes it current. */
    public void visitStore(Store store) {
        updateCFNode(new CFNode(store));
    }

    /**
     * Adds the current node as a predecessor of {@code handler}'s entry node.
     *
     * @throws NullPointerException if {@code handler} has not been registered yet.
     */
    public void addExceptionLink(Block handler) {
        requireBlockNode(handler).addLink(currentNode);
    }

    /**
     * Adds the current node as a predecessor of {@code handler}'s entry node and marks that
     * link as a handler link (see {@link CFNode#isHandlerLink}).
     *
     * @throws NullPointerException if {@code handler} has not been registered yet.
     */
    public void addHandlerLink(Block handler) {
        requireBlockNode(handler).addHandlerLink(currentNode);
    }

    /**
     * Records a control-flow edge from the current node to {@code target}.
     * Stack variables are merged slot-by-slot with any stack already recorded for the target.
     */
    public void addCFEdge(Block target, List<LocalVariable> currentStack) {
        markBlockStack(target, currentStack).addLink(currentNode);
    }

    /** Appends a named marker node after the current node and makes it current. */
    public void markNode(String name) {
        updateCFNode(new CFNode(name));
    }

    /**
     * Makes {@code block}'s entry node current.
     *
     * @return the entry stack recorded for {@code block}.
     * @throws NullPointerException if the block was never reached.
     */
    public List<LocalVariable> visitBlock(Block block) {
        currentNode = Objects.requireNonNull(blockNodeMap.get(block));
        return branchStacks.get(block);
    }

    /** Copies debug (LVT) info from {@code from} onto the variable visible in its slot at the current node. */
    public void applyLVInfo(LocalVariable from) {
        applyLVInfo(from, Objects.requireNonNull(getEquivalentVar(currentNode.readLocal(from.getIndex()))));
    }

    /** @return {@code true} if no edge or handler registration ever reached {@code block}. */
    public boolean isDead(Block block) {
        return !blockNodeMap.containsKey(block);
    }

    /** Walks every store under {@code mainContainer}, widening synthetic locals to fit the stored values. */
    public void applyAllReplacements(BlockContainer mainContainer) {
        mainContainer.accept(new SimpleInsnVisitor<None>() {
            @Override
            public None visitStore(Store store, None ctx) {
                super.visitStore(store, ctx);

                // with synthetic locals, the type information is updated and widened as they're assigned.
                // It's possible to load a synthetic local into a stack slot, and then the local variable changes type, and is no longer assignable to the stack slot
                // A single pass should be sufficient due to bytecode ordering and dependence
                // Note that makeVariableCompatibleWithType doesn't change non-synthetic variables
                LocalReference stLocVar = matchLocalRef(store.getReference());
                if (stLocVar != null) {
                    makeVariableCompatibleWithType(stLocVar.variable, store.getValue().getResultType());
                }
                return NONE;
            }
        });
    }

    /** Returns {@code block}'s entry node, failing with a descriptive NPE if it was never registered. */
    private CFNode requireBlockNode(Block block) {
        return Objects.requireNonNull(blockNodeMap.get(block), () -> "No CFNode registered for block " + block.getName());
    }

    /**
     * Applies LV info from one variable to another (same slot).
     * A synthetic {@code to} adopts {@code from}'s signature, type and name and becomes non-synthetic.
     * Nothing is copied when {@code from} is itself synthetic.
     */
    private void applyLVInfo(LocalVariable from, LocalVariable to) {
        assert from.getIndex() == to.getIndex();
        if (!from.isSynthetic()) {
            if (!to.isSynthetic()) {
                // If this variable is not synthetic, the type and name must match.
                assert to.getType().equals(from.getType());
                assert to.getName().equals(from.getName());
            } else {
                to.setGenericSignature(from.getGenericSignature());
                to.setType(from.getType());
                to.setName(from.getName());
                to.setSynthetic(false);
            }

            // javac has a bad habit of only annotating one of the split ranges of an LV, so all merges need to be checked for useful incoming annotation info
            if (!from.getAnnotationSupplier().isEmpty()) {
                assert to.getAnnotationSupplier().isEmpty() || to.getAnnotationSupplier().equals(from.getAnnotationSupplier());
                to.setAnnotationSupplier(from.getAnnotationSupplier());
            }
        }
    }

    /**
     * Follows the merge chain from {@code var} to the variable that currently represents it,
     * compressing the chain as it goes.
     *
     * @return {@code var} itself if it was never merged away.
     */
    // TODO: this could disappear completely if cf node held LocalReferences and the variables were mutable
    public LocalVariable getEquivalentVar(LocalVariable var) {
        LocalVariable equiv = equivalentLocals.get(var);
        if (equiv == null) return var;
        assert equiv == var || !var.isConnected();

        LocalVariable equiv2 = getEquivalentVar(equiv);
        if (equiv2 != equiv) {
            assert !equiv.isConnected();
            // Path compression: point straight at the representative next time.
            equivalentLocals.put(var, equiv2);
        }

        return equiv2;
    }

    /**
     * Marks two variables of the same slot as equivalent. The representative of {@code var2} is
     * merged into the representative of {@code var1}, and all of its references are rewritten.
     */
    private void makeEquivalent(LocalVariable var1, LocalVariable var2) {
        assert var1.getIndex() == var2.getIndex();
        var1 = getEquivalentVar(var1);
        var2 = getEquivalentVar(var2);
        if (var1 == var2) return;
        assert var1.isConnected() && var2.isConnected();

        equivalentLocals.put(var2, var1);
        makeVariableCompatibleWithType(var1, var2.getType());
        if (var2.getKind() == LocalVariable.VariableKind.LOCAL) {
            applyLVInfo(var2, var1);
        }

        LocalVariable finalVar = var1;
        // Copy first: replaceWith mutates var2's reference list.
        List.copyOf(var2.getReferences()).forEach(e -> e.replaceWith(new LocalReference(finalVar)));
    }

    /**
     * Records {@code currentStack} as {@code target}'s entry stack. If a stack was already
     * recorded, the two are merged slot-by-slot instead.
     *
     * @return the target's entry node, created on first use.
     */
    private CFNode markBlockStack(Block target, List<LocalVariable> currentStack) {
        List<LocalVariable> previousStack = branchStacks.get(target);
        if (previousStack != null) {
            assert currentStack.size() == previousStack.size();
            for (int i = 0; i < currentStack.size(); i++) {
                LocalVariable currVariable = currentStack.get(i);
                LocalVariable prevVariable = previousStack.get(i);
                makeEquivalent(prevVariable, currVariable);
            }
        } else {
            branchStacks.put(target, new ArrayList<>(currentStack));
        }

        return blockNodeMap.computeIfAbsent(target, l -> new CFNode(target));
    }

    /** Widens a synthetic variable's type so {@code type} can be stored into it; non-synthetic variables are left untouched. */
    private void makeVariableCompatibleWithType(LocalVariable var, AType type) {
        if (var.isSynthetic()) {
            var.setType(combineStackTypes(var.getType(), type));
        } else {
            // unchecked casts aren't added immediately when applying LV info, so we can't check assignability yet.
            // assert isAssignableTo(type, varType);
        }
    }

    /**
     * Combines two stack types into one that both can be stored into, using the first rule that applies:
     * <ul>
     *     <li>{@code a} if {@code a} is a {@link ReferenceUnionType} ({@code b} must be one of its members);</li>
     *     <li>whichever operand the other is assignable to;</li>
     *     <li>the least upper bound for reference types;</li>
     *     <li>{@code boolean} if either is {@code boolean};</li>
     *     <li>{@code int} if either is not an integer constant;</li>
     *     <li>otherwise, a union of every integer constant in both operands (unions are flattened).</li>
     * </ul>
     */
    private static AType combineStackTypes(AType a, AType b) {
        if (a == b) return a;

        if (a instanceof ReferenceUnionType) {
            assert ColUtils.anyMatch(((ReferenceUnionType) a).getTypes(), t -> t.equals(b));
            return a;
        }

        if (TypeSystem.isAssignableTo(a, b)) return b;
        if (TypeSystem.isAssignableTo(b, a)) return a;

        if (a instanceof ReferenceType) {
            return TypeSystem.lub((ReferenceType) a, (ReferenceType) b);
        }

        if (a == PrimitiveType.BOOLEAN || b == PrimitiveType.BOOLEAN) {
            return PrimitiveType.BOOLEAN;
        }

        if (!isIntegerConstant(a) || !isIntegerConstant(b)) {
            assert TypeSystem.isAssignableTo(a, PrimitiveType.INT);
            assert TypeSystem.isAssignableTo(b, PrimitiveType.INT);
            return PrimitiveType.INT;
        }

        // flatten them! Either side may already be a union (e.g. a previously merged synthetic).
        LinkedList<IntegerConstantType> types = new LinkedList<>();
        addIntegerConstants(types, a);
        addIntegerConstants(types, b);

        return new IntegerConstantUnion(types);
    }

    /** Appends {@code type}'s constants to {@code into}, expanding an {@link IntegerConstantUnion} into its members. */
    private static void addIntegerConstants(List<IntegerConstantType> into, AType type) {
        if (type instanceof IntegerConstantUnion union) {
            into.addAll(union.getTypes());
        } else {
            into.add((IntegerConstantType) type);
        }
    }

    /** Links {@code newNode} after the current node and makes it current. */
    private void updateCFNode(CFNode newNode) {
        newNode.addLink(currentNode);
        currentNode = newNode;
    }

    /**
     * Represents a node in a ControlFlow graph.
     * <p>
     * This graph is specific to Local Variable reads and writes,
     * it is intended for this graph to only ever be iterated 'backwards'
     * relative to a given Local Variable load.
     */
    public class CFNode {

        /** Human-readable identifier; extended by {@link #addInfo}. */
        private String id;
        /** Predecessor nodes, in insertion order. */
        private final Set<CFNode> parents = new LinkedHashSet<>();
        /** Predecessors that were added through {@link #addHandlerLink}. */
        private final Set<CFNode> isHandlerLink = new LinkedHashSet<>();
        /** Per-slot cache of the variable visible at this node; a store node is seeded with its own variable. */
        private final LocalVariable[] memoTable = new LocalVariable[maxLocals];

        @Nullable
        private Store store;

        @Nullable
        private Block block;

        private CFNode(Store store) {
            allNodes.add(this);
            this.store = store;
            id = "w" + nodeCounter.getAndIncrement();
            memoTable[store.getVariable().getIndex()] = store.getVariable();
        }

        private CFNode(Block block) {
            allNodes.add(this);
            this.block = block;
            id = block.getName();
        }

        private CFNode(String cfInfo) {
            allNodes.add(this);
            id = cfInfo + " " + nodeCounter.getAndIncrement();
        }

        private void addInfo(String info) {
            id += " " + info;
        }

        /** @return the variable written by this node; only valid for store nodes. */
        public LocalVariable getWriteVar() {
            assert store != null;
            return store.getVariable();
        }

        private void addHandlerLink(CFNode other) {
            addLink(other);
            isHandlerLink.add(other);
        }

        /**
         * Adds {@code pred} as a predecessor. Every slot this node has already resolved (except the
         * slot this node itself writes) is merged with what {@code pred} sees in that slot, so
         * earlier reads stay consistent with the new incoming edge.
         */
        private void addLink(CFNode pred) {
            parents.add(pred);

            for (int i = 0; i < memoTable.length; i++) {
                LocalVariable v = memoTable[i];
                if (v == null || store != null && store.getVariable().getIndex() == i) continue;
                makeEquivalent(v, pred.readLocal(i));
            }
        }

        /** Resolves (and memoizes) the variable visible in slot {@code index} at this node. */
        private LocalVariable readLocal(int index) {
            LocalVariable v = memoTable[index];
            if (v != null) return v;

            HashMultimap<CFNode, CFNode> pendingReads = HashMultimap.create();
            v = readLocal(new HashSet<>(), pendingReads, index);
            assert v != null;
            assert pendingReads.isEmpty();
            return v;
        }

        /**
         * Depth-first backwards search over {@link #parents} for slot {@code index}.
         *
         * @param visiting     nodes on the current search path, used to stop on cycles.
         * @param pendingReads parent -> dependants that could not yet be resolved through that parent.
         * @return the resolved variable, or {@code null} if this node is already on the search path
         *         and has no memoized entry.
         */
        @Nullable
        private LocalVariable readLocal(Set<CFNode> visiting, Multimap<CFNode, CFNode> pendingReads, int index) {
            LocalVariable v = memoTable[index];
            if (v != null || visiting.contains(this)) {
                return v;
            }

            visiting.add(this);
            // ask all our parents to read
            for (CFNode parent : parents) {
                LocalVariable v2 = parent.readLocal(visiting, pendingReads, index);
                if (v2 != null) {
                    memoLocal(v2, pendingReads);
                } else { // when cfNode actually gets a memoTable entry, we need to update again (either to inherit it or merge it)
                    // TODO: pendingReads is never hit currently, parents always have at least one store, and the lvtable ensures loops are read as they are built
                    //  it's likely that parsing a method without an lv table, where a loop is fully built before a read has to pass through it, will require this code
                    pendingReads.put(parent, this);
                }
            }
            visiting.remove(this);
            return memoTable[index];
        }

        /**
         * Records {@code f} for its slot if the slot is empty and propagates it to any pending
         * dependants; otherwise merges {@code f} with the existing entry.
         */
        private void memoLocal(LocalVariable f, Multimap<CFNode, CFNode> pendingReads) {
            LocalVariable v = memoTable[f.getIndex()];
            if (v == null) {
                memoTable[f.getIndex()] = f;
                for (CFNode cfNode : pendingReads.removeAll(this)) {
                    cfNode.memoLocal(f, pendingReads);
                }
            } else {
                assert !pendingReads.containsKey(this);
                makeEquivalent(v, f);
            }
        }

        public String getId() {
            return id;
        }

        public @Nullable Store getStore() {
            return store;
        }

        /** @return an unmodifiable view of this node's predecessors. */
        public Set<CFNode> getParents() {
            return Collections.unmodifiableSet(parents);
        }

        /** @return the memoized variable for slot {@code idx}, or {@code null} if none has been resolved. */
        public @Nullable LocalVariable getMemoEntry(int idx) {
            return memoTable[idx];
        }

        /** @return {@code true} if {@code node} was linked to this node through {@link #addHandlerLink}. */
        public boolean isHandlerLink(CFNode node) {
            return isHandlerLink.contains(node);
        }
    }

    /** A recorded load: the resolved variable and the node at which it was read. */
    public record LVLoad(LocalVariable variable, CFNode node) { }
}
