/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering.kmeans;

import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.io.IOException;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.OutputFactory;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.ClusteringFactory;
import org.tribuo.clustering.evaluation.ClusteringEvaluation;
import org.tribuo.clustering.kmeans.KMeansTrainer;
import org.tribuo.data.DataOptions;

public class TrainTest {
    private static final Logger logger = Logger.getLogger(TrainTest.class.getName());

    public static void main(String[] args) throws IOException {
        ConfigurationManager cm;
        LabsLogFormatter.setAllLogFormatters();
        KMeansOptions o = new KMeansOptions();
        try {
            cm = new ConfigurationManager(args, (Options)o);
        }
        catch (UsageException e) {
            logger.info(e.getMessage());
            return;
        }
        if (o.general.trainingPath == null) {
            logger.info(cm.usage());
            return;
        }
        ClusteringFactory factory = new ClusteringFactory();
        Pair data = o.general.load((OutputFactory)factory);
        Dataset train = (Dataset)data.getA();
        KMeansTrainer trainer = new KMeansTrainer(o.centroids, o.iterations, o.distance, o.initialisation, o.numThreads, o.general.seed);
        Model model = trainer.train(train);
        logger.info("Finished training model");
        ClusteringEvaluation evaluation = (ClusteringEvaluation)factory.getEvaluator().evaluate(model, train);
        logger.info("Finished evaluating model");
        System.out.println("Normalized MI = " + evaluation.normalizedMI());
        System.out.println("Adjusted MI = " + evaluation.adjustedMI());
        if (o.general.outputPath != null) {
            o.general.saveModel(model);
        }
    }

    public static class KMeansOptions
    implements Options {
        public DataOptions general;
        @Option(charName=110, longName="num-clusters", usage="Number of clusters to infer.")
        public int centroids = 5;
        @Option(charName=105, longName="iterations", usage="Maximum number of iterations.")
        public int iterations = 10;
        @Option(charName=100, longName="distance", usage="Distance function to use in the e step.")
        public KMeansTrainer.Distance distance = KMeansTrainer.Distance.EUCLIDEAN;
        @Option(charName=115, longName="initialisation", usage="Type of initialisation to use for centroids.")
        public KMeansTrainer.Initialisation initialisation = KMeansTrainer.Initialisation.RANDOM;
        @Option(charName=116, longName="num-threads", usage="Number of threads to use (range (1, num hw threads)).")
        public int numThreads = 4;

        public String getOptionsDescription() {
            return "Trains and evaluates a K-Means model on the specified dataset.";
        }
    }
}

