/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.lops.CSVReBlock;
import org.apache.sysds.lops.CentralMoment;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.CoVariance;
import org.apache.sysds.lops.DataGen;
import org.apache.sysds.lops.GroupedAggregate;
import org.apache.sysds.lops.GroupedAggregateM;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.lops.MMZip;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.ParameterizedBuiltin;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.lops.ReBlock;
import org.apache.sysds.lops.SpoofFused;
import org.apache.sysds.lops.UAggOuterChain;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.lops.rewrite.LopRewriteRule;
import org.apache.sysds.parser.StatementBlock;

/**
 * Lop rewrite that inserts a CP {@code PREFETCH} unary lop, marked asynchronous, between
 * qualifying Spark lops and their consumers.
 *
 * <p>A Spark lop qualifies when it produces a (non-single-block) matrix, is not one of a fixed set
 * of excluded Spark operator types, has no parameterized-builtin or grouped-aggregate consumer,
 * and its result is consumed only in CP or collected for a broadcast. The rewrite modifies the lop
 * DAG of a statement block in place and is a no-op when prefetching is disabled in the
 * configuration.
 */
public class RewriteAddPrefetchLop
extends LopRewriteRule {

    /**
     * Rewires every qualifying Spark lop {@code l} of the statement block from
     * {@code l -> consumers} to {@code l -> PREFETCH -> consumers}.
     *
     * @param sb the statement block whose lop DAG is rewritten in place
     * @return a single-element list containing {@code sb} (the block itself is never replaced)
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb) {
        if (!ConfigurationManager.isPrefetchEnabled()) {
            return List.of(sb);
        }
        List<Lop> lops = OperatorOrderingUtils.getLopList(sb);
        if (lops == null) {
            return List.of(sb);
        }
        for (Lop l : lops) {
            if (isPrefetchNeeded(l)) {
                insertPrefetchAfter(l);
            }
        }
        // Prefetch lops are spliced into the DAG in place; no new statement blocks are created.
        return List.of(sb);
    }

    /**
     * This rule works on individual statement blocks only.
     *
     * @param sbs the statement blocks
     * @return {@code sbs}, unchanged
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
        return sbs;
    }

    /**
     * Creates an asynchronous CP {@code PREFETCH} lop on top of {@code l} and moves all of
     * {@code l}'s current consumers so that they read from the prefetch instead of {@code l}.
     *
     * @param l the Spark lop to be followed by a prefetch
     */
    private static void insertPrefetchAfter(Lop l) {
        // Snapshot the consumers first: removeOutput below mutates l's output list.
        List<Lop> consumers = new ArrayList<>(l.getOutputs());
        UnaryCP prefetch = new UnaryCP(l, Types.OpOp1.PREFETCH, l.getDataType(), l.getValueType(), Types.ExecType.CP);
        prefetch.setAsynchronous(true);
        // Reset any asynchronous flag already set on the producer; only the prefetch is asynchronous.
        l.setAsynchronous(false);
        for (Lop consumer : consumers) {
            // Rewire l -> consumer into l -> prefetch -> consumer.
            prefetch.addOutput(consumer);
            consumer.replaceInput(l, prefetch);
            l.removeOutput(consumer);
        }
    }

    /**
     * Decides whether a {@code PREFETCH} lop should be placed after {@code lop}.
     *
     * @param lop the candidate lop
     * @return {@code true} if {@code lop} is a Spark matrix operator whose aggregation type is not
     *     {@code SINGLE_BLOCK}, is not an excluded Spark lop type, has no parameterized-builtin or
     *     grouped-aggregate consumer, and whose outputs are all CP or collected for a broadcast
     */
    private boolean isPrefetchNeeded(Lop lop) {
        if (lop.getExecType() != Types.ExecType.SPARK
            || lop.getDataType() != Types.DataType.MATRIX
            || lop.getAggType() == AggBinaryOp.SparkAggType.SINGLE_BLOCK
            || isExcludedSparkLop(lop)
            || hasParameterizedConsumer(lop)) {
            return false;
        }
        return lop.isAllOutputsCP() || OperatorOrderingUtils.isCollectForBroadcast(lop);
    }

    /**
     * Returns whether {@code lop} is one of the Spark operator types this rule never prefetches.
     */
    private static boolean isExcludedSparkLop(Lop lop) {
        return lop instanceof MapMultChain
            || lop instanceof PickByCount
            || lop instanceof MMZip
            || lop instanceof CentralMoment
            || lop instanceof CoVariance
            || lop instanceof Checkpoint
            || lop instanceof ReBlock
            || lop instanceof CSVReBlock
            || lop instanceof DataGen
            || lop instanceof MMTSJ
            || lop instanceof UAggOuterChain
            || lop instanceof ParameterizedBuiltin
            || lop instanceof SpoofFused;
    }

    /**
     * Returns whether any consumer of {@code lop} is a parameterized builtin or a grouped aggregate.
     * NOTE(review): presumably these consumers reference their inputs through structures that
     * {@code replaceInput} does not rewire, which is why such lops are skipped — TODO confirm.
     */
    private static boolean hasParameterizedConsumer(Lop lop) {
        return lop.getOutputs().stream().anyMatch(out -> out instanceof ParameterizedBuiltin
            || out instanceof GroupedAggregate
            || out instanceof GroupedAggregateM);
    }
}

