package deepimagej.stamp;

import deepimagej.BuildDialog;
import deepimagej.Constants;
import deepimagej.DeepLearningModel;
import deepimagej.Parameters;
import deepimagej.components.HTMLPane;
import deepimagej.tools.DijTensor;
import deepimagej.tools.Index;
import ij.IJ;
import java.awt.Dimension;
import java.awt.GridBagConstraints;
import java.awt.GridBagLayout;
import java.awt.Insets;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import javax.swing.BoxLayout;
import javax.swing.JComboBox;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JScrollPane;
import org.jfree.chart.axis.ValueAxis;

/* loaded from: input_file:deepimagej/stamp/TensorStamp.class */
public class TensorStamp extends AbstractStamp implements ActionListener {
    // Axis selectors for the input tensors: one combo box per dimension of every
    // input tensor, stored as a single flat list in tensor order (see buildPanel()).
    private List<JComboBox<String>> inputs;
    // Axis selectors for the output tensors, laid out like `inputs`.
    // NOTE(review): static, so the list is shared by every TensorStamp instance and
    // replaced on each buildPanel() call - confirm only one instance ever exists.
    private static List<JComboBox<String>> outputs;
    // Tensor-type selector ("image"/"parameter") for each input tensor.
    private List<JComboBox<String>> inTags;
    // Tensor-type selector ("image"/"list"/"ignore") for each output tensor.
    // NOTE(review): static, same caveat as `outputs`.
    private static List<JComboBox<String>> outTags;
    // Axis labels offered for image tensors.
    private String[] in;
    // Axis labels offered for output image tensors when the network is pyramidal.
    private String[] outPyramidal;
    // Choices shown in the input tensor-type selectors.
    private String[] inputOptions;
    // Choices shown in the output tensor-type selectors.
    private static String[] outOptions = {"image", "list", "ignore"};
    // HTML pane listing the tensor shapes and explaining the tensor types.
    private HTMLPane pnDim;
    // Root panel of this stamp; added to the inherited `panel` in buildPanel().
    private JPanel pn;
    // Panel holding one row of selectors per tensor; shown inside a scroll pane.
    private JPanel pnInOut;
    // Running offset into `inputs`/`outputs` while finish() walks the tensors.
    private int iterateOverComboBox;
    // Model path the panel was last built for; init() compares against it.
    private String model;
    // Pyramidal flag the panel was last built for; init() compares against it.
    private boolean pyramidal;

    /**
     * Creates the stamp with its default axis labels, input type choices and
     * empty panels. The panel content itself is produced later by
     * {@link #buildPanel()}.
     *
     * @param buildDialog dialog this stamp belongs to; forwarded to the superclass
     */
    public TensorStamp(BuildDialog buildDialog) {
        super(buildDialog);
        // Nothing has been built yet: no model path, non-pyramidal by default.
        model = "";
        pyramidal = false;
        // Containers that buildPanel() fills in.
        pn = new JPanel();
        pnInOut = new JPanel();
        // Choices offered by the selectors.
        inputOptions = new String[] {"image", "parameter"};
        in = new String[] {"B", "Y", "X", "C", "Z"};
        outPyramidal = new String[] {"B", "Y", "X", "C", "N/i/z"};
    }

    /**
     * Builds the "Tensor Organization" panel from the current plugin parameters.
     *
     * <p>The panel consists of a header, an HTML pane listing the shape of every
     * input/output tensor together with a legend of the tensor types, and a
     * scrollable area with one row of selectors per tensor (a tensor-type combo
     * box, the tensor name, and one axis combo box per dimension).
     *
     * <p>The model path and pyramidal flag used for the build are remembered so
     * that {@link #init()} can tell whether a rebuild is needed.
     */
    @Override // deepimagej.stamp.AbstractStamp
    public void buildPanel() {
        Parameters parameters = this.parent.getDeepPlugin().params;
        List<DijTensor> inputTensors = parameters.totalInputList;
        List<DijTensor> outputTensors = parameters.totalOutputList;

        HTMLPane header = new HTMLPane(Constants.width, 100);
        header.append("h2", "Tensor Organization");
        header.append("p", "Each dimension of input and output tensors must be specified to process  the image correctly, i.e. the first dimension of the input tensor corresponds to the batch size, the second dimension to the width and so on.<br><b>Note that for the moment DeepImageJ only supports BATCH_SIZE = 1.</b>");
        header.setMaximumSize(new Dimension(Constants.width, 100));

        this.pnInOut.removeAll();

        // Shapes of all tensors plus the legend explaining the selectable types.
        this.pnDim = new HTMLPane(Constants.width, 250);
        File modelFile = new File(parameters.path2Model);
        this.pnDim.append("h2", "Input/Output tensor dimensions of " + (modelFile.exists() ? modelFile.getName() : "untitled"));
        // Snapshot of what this panel was built for; compared in init().
        this.model = parameters.path2Model;
        this.pyramidal = parameters.pyramidalNetwork;
        for (DijTensor tensor : inputTensors) {
            this.pnDim.append("<p><b>Input tensor --> </b>" + tensor.name + " : " + Arrays.toString(tensor.tensor_shape) + "</p>");
        }
        for (DijTensor tensor : outputTensors) {
            this.pnDim.append("<p><b>Output tensor --> </b>" + tensor.name + " : " + Arrays.toString(tensor.tensor_shape) + "</p>");
        }
        this.pnDim.append("<br>");
        this.pnDim.append("<p><b>Input tensor types:</b></p>");
        this.pnDim.append("<ul><li><p><b>Image:</b> <b>X</b> for width (axis X), <b>Y</b> for height (axis Y), <b>Z</b> for depth (axis Z), <b>B</b> for batch, <b>C</b> for channel. DeepImageJ can only process one image at a time. If the input is not an image, you should include a pre-processing written in Java.</p></li>");
        this.pnDim.append("<li><p><b>Parameter:</b> If the input is a parameter, the corresponding tensor must be created using Java pre-processing, thus no dimension specification is needed. The tensor is fed directly to the model from pre-processing.</p></li></ul>");
        this.pnDim.append("<p><b>Output tensor types:</b></p>");
        if (parameters.pyramidalNetwork) {
            this.pnDim.append("<ul><li><p><b>Image:</b> <b>X</b> for width (axis X), <b>Y</b> for height (axis Y), <b>N/i/z</b> for number of components/patches/objects or depth (axis Z), <b>B</b> for the batch, <b>C</b> for channel or class</p></li>");
        } else {
            this.pnDim.append("<ul><li><p><b>Image:</b> <b>X</b> for width (axis X), <b>Y</b> for height (axis Y), <b>Z</b> for depth (axis Z), <b>B</b> for the batch, <b>C</b> for channel</p></li>");
        }
        this.pnDim.append("<li><p><b>List:</b> the tensor corresponds to a batch of matrices. <b>R</b> for rows, <b>C</b> for columns, <b>B</b> for the batch. This type can be used for tensors with 3 dimensions at most (being one of them the batch).</p></li>");
        this.pnDim.append("<li><p><b>Ignore:</b> DeepImageJ will not retrieve the tensor from the model.</p></li></ul>");
        this.pnDim.setMaximumSize(new Dimension(Constants.width, 250));

        // Fresh selector lists; finish() and updateTensorDisplay() index into them.
        this.inTags = new ArrayList<>();
        outTags = new ArrayList<>();
        this.inputs = new ArrayList<>();
        outputs = new ArrayList<>();

        int rowCount = 0;
        for (DijTensor tensor : inputTensors) {
            this.pnInOut.add(createTensorRow(tensor, this.inputOptions, this.in, this.inTags, this.inputs));
            rowCount++;
        }
        String[] outputAxes = parameters.pyramidalNetwork ? this.outPyramidal : this.in;
        for (DijTensor tensor : outputTensors) {
            this.pnInOut.add(createTensorRow(tensor, outOptions, outputAxes, outTags, outputs));
            rowCount++;
        }

        // Size the scrollable area according to the number of tensor rows.
        JScrollPane scroll = new JScrollPane();
        this.pnInOut.setPreferredSize(new Dimension(500, rowCount * 60));
        scroll.setPreferredSize(new Dimension(600, (rowCount * 70) + 50));
        scroll.setViewportView(this.pnInOut);

        this.pn.removeAll();
        this.pn.setLayout(new BoxLayout(this.pn, BoxLayout.PAGE_AXIS));
        this.pn.add(header.getPane());
        this.pn.add(this.pnDim.getPane());
        this.pn.add(scroll);
        this.panel.add(this.pn);
    }

    /**
     * Creates the row of selectors for one tensor and registers its combo boxes.
     *
     * @param tensor      tensor the row describes
     * @param typeChoices entries of the tensor-type combo box
     * @param axisChoices entries of each per-dimension axis combo box
     * @param typeBoxes   list that receives the tensor-type combo box
     * @param axisBoxes   list that receives one axis combo box per dimension
     * @return the populated row panel
     */
    private JPanel createTensorRow(DijTensor tensor, String[] typeChoices, String[] axisChoices,
            List<JComboBox<String>> typeBoxes, List<JComboBox<String>> axisBoxes) {
        JPanel row = new JPanel(new GridBagLayout());

        // Changing the tensor type must refresh the axis selectors of the row.
        JComboBox<String> typeBox = new JComboBox<>(typeChoices);
        typeBox.addActionListener(this);
        row.add(typeBox, rowConstraints(0));
        typeBoxes.add(typeBox);

        row.add(new JLabel(tensor.name), rowConstraints(3));

        for (int d = 0; d < tensor.tensor_shape.length; d++) {
            JComboBox<String> axisBox = new JComboBox<>(axisChoices);
            axisBox.setPreferredSize(new Dimension(50, 50));
            row.add(axisBox);
            axisBoxes.add(axisBox);
        }
        return row;
    }

    /** Returns the grid-bag constraints of a three-column-wide cell starting at {@code gridx}. */
    private static GridBagConstraints rowConstraints(int gridx) {
        GridBagConstraints constraints = new GridBagConstraints();
        constraints.gridwidth = 3;
        constraints.gridx = gridx;
        constraints.insets = new Insets(3, 5, 3, 5);
        return constraints;
    }

    /**
     * Rebuilds the panel when the plugin parameters no longer match the model
     * path or pyramidal flag the panel was last built for; otherwise does nothing.
     */
    @Override // deepimagej.stamp.AbstractStamp
    public void init() {
        boolean stale = !this.parent.getDeepPlugin().params.path2Model.equals(this.model)
                || this.pyramidal != this.parent.getDeepPlugin().params.pyramidalNetwork;
        if (stale) {
            buildPanel();
        }
    }

    /**
     * Reads the user's selections back into the plugin parameters and validates them.
     *
     * <p>For every input and output tensor the axis string ({@code form}) is
     * assembled from its axis combo boxes and the tensor type is read from its
     * type combo box. Validation fails (an error dialog is shown and
     * {@code false} is returned) when
     * <ul>
     *   <li>more than one input, or no input, is of type "image";</li>
     *   <li>an axis label is repeated within a tensor (skipped for "parameter"
     *       inputs and "ignore" outputs);</li>
     *   <li>the batch size reported by {@code DeepLearningModel.nBatch} is not "1";</li>
     *   <li>every output is ignored.</li>
     * </ul>
     * Patching is enabled for non-pyramidal networks and disabled again as soon
     * as one output is of type "list".
     *
     * @return {@code true} when all selections are valid
     */
    @Override // deepimagej.stamp.AbstractStamp
    public boolean finish() {
        Parameters parameters = this.parent.getDeepPlugin().params;
        if (!parameters.pyramidalNetwork) {
            parameters.allowPatching = true;
        }

        // ---- inputs ----
        parameters.inputList = new ArrayList<>();
        boolean hasImageInput = false;
        int tagIndex = 0;
        this.iterateOverComboBox = 0;
        for (DijTensor tensor : parameters.totalInputList) {
            int rank = tensor.tensor_shape.length;
            // The axis combo boxes of all tensors live in one flat list; this
            // tensor owns the next `rank` entries.
            StringBuilder form = new StringBuilder();
            for (int k = this.iterateOverComboBox; k < this.iterateOverComboBox + rank; k++) {
                form.append((String) this.inputs.get(k).getSelectedItem());
            }
            tensor.form = form.toString();
            tensor.tensorType = (String) this.inTags.get(tagIndex++).getSelectedItem();

            if (tensor.tensorType.contains("image")) {
                if (hasImageInput) {
                    IJ.error("The current DeepImageJ version only admits one input image tensor.");
                    return false;
                }
                hasImageInput = true;
            }
            this.iterateOverComboBox += rank;

            if (!checkRepeated(tensor.form) && !tensor.tensorType.equals("parameter")) {
                IJ.error("Dimension repetition is not allowed");
                return false;
            }
            if (!DeepLearningModel.nBatch(tensor.tensor_shape, tensor.form).equals("1") && !tensor.tensorType.equals("ignore")) {
                IJ.error("The plugin only supports models with batch size (B) = 1");
                return false;
            }
            parameters.inputList.add(tensor);
        }

        // ---- outputs ----
        parameters.outputList = new ArrayList<>();
        List<DijTensor> outputTensors = parameters.totalOutputList;
        tagIndex = 0;
        this.iterateOverComboBox = 0;
        for (DijTensor tensor : outputTensors) {
            int rank = tensor.tensor_shape.length;
            StringBuilder form = new StringBuilder();
            for (int k = this.iterateOverComboBox; k < this.iterateOverComboBox + rank; k++) {
                String axis = (String) outputs.get(k).getSelectedItem();
                // Any label containing a 'z' (e.g. the pyramidal "N/i/z") is stored as "Z".
                form.append(axis.toLowerCase().contains("z") ? "Z" : axis);
            }
            tensor.form = form.toString();
            tensor.auxForm = tensor.form;
            this.iterateOverComboBox += rank;
            tensor.tensorType = (String) outTags.get(tagIndex++).getSelectedItem();

            if (tensor.tensorType.contains("list")) {
                parameters.allowPatching = false;
            }
            if (!checkRepeated(tensor.form) && !tensor.tensorType.equals("ignore")) {
                IJ.error("Dimension repetition is not allowed");
                return false;
            }
            if (!DeepLearningModel.nBatch(tensor.tensor_shape, tensor.form).equals("1") && !tensor.tensorType.equals("ignore")) {
                IJ.error("The plugin only supports models with batch size (B) = 1");
                return false;
            }
        }

        // Only collected once every output passed validation.
        for (DijTensor tensor : outputTensors) {
            if (!tensor.tensorType.contains("ignore")) {
                parameters.outputList.add(tensor);
            }
        }

        if (!hasImageInput) {
            IJ.error("The model must have at least 1 input image.");
            return false;
        }
        if (parameters.outputList.isEmpty()) {
            IJ.error("The model must have at least 1 output.");
            return false;
        }
        return true;
    }

    /**
     * Synchronises the axis combo boxes with the tensor type currently selected
     * for each tensor.
     *
     * <p>Inputs: a "parameter" tensor gets a single disabled "-" entry per
     * dimension; switching back to "image" restores the image axis labels.
     * Outputs: "ignore" disables the axis selectors, "list" offers B/R/C, and
     * "image" offers the image axis labels (the pyramidal variant when
     * {@code parameters.pyramidalNetwork} is set). Whenever the entries of a
     * selector are replaced, a non-empty {@code form} of the tensor is reset to
     * the empty string.
     *
     * @param parameters plugin parameters providing the tensor lists and the pyramidal flag
     */
    public void updateTensorDisplay(Parameters parameters) {
        List<DijTensor> inputTensors = parameters.totalInputList;
        int offset = 0;
        for (int t = 0; t < this.inTags.size(); t++) {
            String type = this.inputOptions[this.inTags.get(t).getSelectedIndex()];
            DijTensor tensor = inputTensors.get(t);
            int end = offset + tensor.tensor_shape.length;
            for (int k = offset; k < end; k++) {
                JComboBox<String> axisBox = this.inputs.get(k);
                if (type.contains("parameter") && axisBox.getItemAt(0).equals("B")) {
                    // Parameter tensors have no axes to choose from.
                    resetAxisChoices(axisBox, new String[] {"-"}, false);
                    clearForm(tensor);
                } else if (type.contains("image") && axisBox.getItemAt(0).equals("-")) {
                    resetAxisChoices(axisBox, this.in, true);
                    clearForm(tensor);
                }
            }
            offset = end;
        }

        List<DijTensor> outputTensors = parameters.totalOutputList;
        offset = 0;
        for (int t = 0; t < outTags.size(); t++) {
            String type = outOptions[outTags.get(t).getSelectedIndex()];
            DijTensor tensor = outputTensors.get(t);
            int end = offset + tensor.tensor_shape.length;
            for (int k = offset; k < end; k++) {
                JComboBox<String> axisBox = outputs.get(k);
                if (type.contains("ignore")) {
                    axisBox.setEnabled(!type.equals("ignore"));
                } else if (type.contains("list")) {
                    // Second entry "Y" means the box still shows image axes.
                    if (axisBox.getItemAt(1).equals("Y")) {
                        resetAxisChoices(axisBox, new String[] {"B", "R", "C"}, true);
                        clearForm(tensor);
                    } else if (!axisBox.isEnabled()) {
                        axisBox.setEnabled(true);
                    }
                } else if (type.contains("image")) {
                    // Second entry "R" means the box still shows list axes.
                    if (axisBox.getItemAt(1).equals("R")) {
                        String[] imageAxes = parameters.pyramidalNetwork ? this.outPyramidal : this.in;
                        resetAxisChoices(axisBox, imageAxes, true);
                        clearForm(tensor);
                    } else if (!axisBox.isEnabled()) {
                        axisBox.setEnabled(true);
                    }
                }
            }
            offset = end;
        }
    }

    /** Replaces all entries of {@code box} with {@code choices} and sets its enabled state. */
    private static void resetAxisChoices(JComboBox<String> box, String[] choices, boolean enabled) {
        box.removeAllItems();
        for (String choice : choices) {
            box.addItem(choice);
        }
        box.setEnabled(enabled);
    }

    /** Resets a non-null, non-empty {@code form} of {@code tensor} to the empty string. */
    private static void clearForm(DijTensor tensor) {
        String current = tensor.form;
        if (current != null && !current.contentEquals("")) {
            tensor.form = "";
        }
    }

    /**
     * Checks that no character occurs more than once in {@code str}.
     *
     * <p>Despite its name, the method returns {@code true} when there is
     * <em>no</em> repetition.
     *
     * @param str axis string to inspect, one character per dimension
     * @return {@code true} if every character of {@code str} is unique,
     *         {@code false} as soon as a repeated character is found
     */
    private boolean checkRepeated(String str) {
        for (int i = 0; i < str.length(); i++) {
            // A character is repeated when a later occurrence of it exists.
            if (str.lastIndexOf(str.charAt(i)) != i) {
                return false;
            }
        }
        return true;
    }

    /**
     * Reacts to a change in one of the tensor-type combo boxes by refreshing the
     * axis selectors from the current plugin parameters.
     *
     * @param actionEvent the triggering event (its content is not inspected)
     */
    public void actionPerformed(ActionEvent actionEvent) {
        Parameters current = this.parent.getDeepPlugin().params;
        updateTensorDisplay(current);
    }
}
