/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.rewrite.LopRewriteRule;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/**
 * Lop rewrite rule that inserts {@link Checkpoint} lops into the body of a loop.
 *
 * <p>For every variable that the loop both reads and updates, the rule collects the Spark
 * roots of the loop body whose input chains reach a transient read of that variable, lets
 * {@code OperatorOrderingUtils.markSharedSparkOps} count the operators those roots share, and
 * places a checkpoint behind every operator that was counted more than once.
 */
public class RewriteAddChkpointInLoop
extends LopRewriteRule {
    /**
     * Adds checkpoint lops to the first body block of {@code sb} when {@code sb} is a loop
     * block accepted by {@code HopRewriteUtils.isLastLevelLoopStatementBlock} and
     * checkpointing is enabled in the configuration.
     *
     * <p>The lop DAG is modified in place; the statement block itself is never replaced.
     *
     * @param sb the statement block to rewrite; may be {@code null}
     * @return a single-element list holding {@code sb} (the element is {@code null} when
     *         {@code sb} is {@code null})
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb) {
        // List.of rejects null elements, so a null block must be passed through explicitly;
        // this check also has to run before any other early return that wraps sb.
        if (sb == null) {
            return Collections.singletonList(sb);
        }
        if (!ConfigurationManager.isCheckpointEnabled() || !HopRewriteUtils.isLastLevelLoopStatementBlock(sb)) {
            return List.of(sb);
        }

        // Variables that are both read and updated by the loop block.
        Set<String> readUpdatedVars = sb.variablesRead().getVariableNames().stream()
            .filter(v -> sb.variablesUpdated().containsVariable(v))
            .collect(Collectors.toSet());
        if (readUpdatedVars.isEmpty()) {
            return List.of(sb);
        }

        // NOTE(review): every non-while block is cast to ForStatement; presumably the
        // isLastLevelLoopStatementBlock check above guarantees this is safe - confirm.
        List<StatementBlock> body = sb instanceof WhileStatementBlock
            ? ((WhileStatement) sb.getStatement(0)).getBody()
            : ((ForStatement) sb.getStatement(0)).getBody();
        if (body.isEmpty()) {
            // Nothing to analyze; avoids an IndexOutOfBoundsException on get(0).
            return List.of(sb);
        }
        // Only the first block of the loop body is analyzed and rewritten.
        StatementBlock csb = body.get(0);

        ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(csb);
        List<Lop> roots = lops.stream().filter(OperatorOrderingUtils::isLopRoot).collect(Collectors.toList());
        HashSet<Lop> sparkRoots = new HashSet<>();
        // A fresh count map is handed to every call; only the collected roots are kept.
        roots.forEach(r -> OperatorOrderingUtils.collectSparkRoots(r, new HashMap<>(), sparkRoots));
        if (sparkRoots.isEmpty()) {
            return List.of(sb);
        }

        Map<Long, Integer> operatorJobCount = new HashMap<>();
        this.findOverlappingJobs(sparkRoots, readUpdatedVars, operatorJobCount);
        if (operatorJobCount.isEmpty()) {
            return List.of(sb);
        }

        this.addChkpointLop(lops, operatorJobCount, csb);
        return List.of(sb);
    }

    /**
     * Returns the given list unchanged; this rule only rewrites individual statement blocks.
     *
     * @param sbs the statement blocks
     * @return {@code sbs} itself
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
        return sbs;
    }

    /**
     * Places a {@link Checkpoint} behind every lop in {@code nodes} whose entry in
     * {@code operatorJobCount} is greater than one, rewiring {@code l -> out} to
     * {@code l -> checkpoint -> out} for each former output of the lop.
     *
     * @param nodes            lops of the loop body block
     * @param operatorJobCount count per lop ID, as filled by {@link #findOverlappingJobs}
     * @param sb               block that is told about each rewired lop and its old outputs
     */
    private void addChkpointLop(List<Lop> nodes, Map<Long, Integer> operatorJobCount, StatementBlock sb) {
        for (Lop l : nodes) {
            if (operatorJobCount.getOrDefault(l.getID(), 0) <= 1) {
                continue;
            }
            // Snapshot the outputs first: the loop below removes entries from l's output list.
            ArrayList<Lop> oldOuts = new ArrayList<>(l.getOutputs());
            Checkpoint checkpoint = new Checkpoint(l, l.getDataType(), l.getValueType(),
                Checkpoint.getDefaultStorageLevelString(), false);
            for (Lop out : oldOuts) {
                checkpoint.addOutput(out);
                out.replaceInput(l, checkpoint);
                l.removeOutput(out);
            }
            sb.setCheckpointPosition(l, oldOuts);
        }
    }

    /**
     * For each read-and-updated variable, gathers the roots whose input chains contain that
     * variable (see {@link #ifJobContains}) and passes them to
     * {@code OperatorOrderingUtils.markSharedSparkOps}, which fills {@code operatorJobCount}.
     *
     * @param sparkRoots       roots collected from the loop body
     * @param ruVars           names of the variables that are both read and updated
     * @param operatorJobCount output map handed to {@code markSharedSparkOps}
     */
    private void findOverlappingJobs(Set<Lop> sparkRoots, Set<String> ruVars, Map<Long, Integer> operatorJobCount) {
        HashSet<Lop> sharedRoots = new HashSet<>();
        for (String var : ruVars) {
            for (Lop root : sparkRoots) {
                if (this.ifJobContains(root, var)) {
                    sharedRoots.add(root);
                }
                // Clear the visit flags set by the traversal before the next root is checked.
                root.resetVisitStatus();
            }
            if (!sharedRoots.isEmpty()) {
                OperatorOrderingUtils.markSharedSparkOps(sharedRoots, operatorJobCount);
            }
            sharedRoots.clear();
        }
    }

    /**
     * Recursively checks whether the operator chain below {@code root} contains a transient
     * read whose label equals {@code var}, ignoring case.
     *
     * <p>An input is followed only if it is a {@link Data} lop, or if it executes in Spark and
     * is not the broadcast input of {@code root}. Every lop that is examined is marked as
     * visited, and a lop that is already visited yields {@code false}.
     *
     * @param root lop at which the search starts
     * @param var  variable name to look for
     * @return {@code true} if a matching transient read was found below or at {@code root}
     */
    private boolean ifJobContains(Lop root, String var) {
        if (root.isVisited()) {
            return false;
        }
        boolean found = false;
        for (Lop input : root.getInputs()) {
            boolean followed = input instanceof Data
                || (input.isExecSpark() && root.getBroadcastInput() != input);
            if (followed && this.ifJobContains(input, var)) {
                found = true;
                break;
            }
        }
        if (!found) {
            found = root instanceof Data
                && ((Data) root).isTransientRead()
                && root.getOutputParameters().getLabel().equalsIgnoreCase(var);
        }
        root.setVisited();
        return found;
    }
}

