package sorters;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;

/**
 * Benchmarks the registered {@link Sorter} implementations on shuffled integer lists and
 * prints, per sorter, an approximate confidence interval for the mean sort time in milliseconds.
 */
public class TimeTest {

	/**
	 * Sorter implementations benchmarked by {@link #main}; each is keyed by its class name.
	 * Kept as a raw array because generic array creation is not allowed.
	 */
	static Sorter[] sorters = new Sorter[]{new HeapSorter(), new MergeSorter(), new QuickSorter()};

	/**
	 * Runs the benchmark and prints one result line per sorter.
	 *
	 * @param args optional {@code [size [reps]]}: list length (default 10000) and number of
	 *             repetitions (default 1000)
	 */
	public static void main(String[] args) {
		int size = args.length == 0 ? 10000 : Integer.parseInt(args[0]);
		int reps = args.length < 2 ? 1000 : Integer.parseInt(args[1]);
		HashMap<String, ArrayList<Long>> totals = makeTotals();

		for (int i = 0; i < reps; i++) {
			// One shuffled input per repetition. Each sorter receives SorterTest.copy(target),
			// presumably an independent copy, so all sorters see the same unsorted data.
			ArrayList<Integer> target = makeTarget(size);
			for (Sorter<?> candidate : sorters) {
				// NOTE(review): assumes every registered sorter can sort Integer lists.
				@SuppressWarnings("unchecked")
				Sorter<Integer> sorter = (Sorter<Integer>) candidate;
				runTest(SorterTest.copy(target), sorter, totals);
			}
		}

		// makeTotals() returns an insertion-ordered map, so lines print in sorters order.
		for (String sorter : totals.keySet()) {
			printResultFor(sorter, totals);
		}
	}

	/**
	 * Runs {@code sorter.sort(target)} and returns how long the call took.
	 *
	 * @param sorter the sorter to time
	 * @param target the list handed to the sorter
	 * @return elapsed time in milliseconds, measured with {@link System#currentTimeMillis()}
	 */
	public static long durationFor(Sorter<Integer> sorter, ArrayList<Integer> target) {
		// NOTE(review): currentTimeMillis() is a wall clock whose granularity is OS-dependent;
		// System.nanoTime() is the elapsed-time API, but switching would change the unit
		// that runTest/printResultFor (and any external callers) expect.
		long start = System.currentTimeMillis();
		sorter.sort(target);
		return System.currentTimeMillis() - start;
	}

	/**
	 * Prints {@code sorterName} and a nanosecond duration converted to seconds (two decimals).
	 * Not used by {@link #main}, whose durations are in milliseconds.
	 *
	 * @param sorterName label printed before the value
	 * @param nano       duration in nanoseconds
	 */
	public static void displayDuration(String sorterName, long nano) {
		// printf with %n emits exactly one line terminator (no trailing blank line).
		System.out.printf("%s: %8.2f%n", sorterName, nano / 1e9);
	}

	/**
	 * Builds the list {@code 0 .. size-1} in random order.
	 *
	 * @param size number of elements; a non-positive value yields an empty list
	 * @return a new, shuffled, mutable list
	 */
	public static ArrayList<Integer> makeTarget(int size) {
		// Presize; Math.max guards the ArrayList constructor against negative capacities.
		ArrayList<Integer> target = new ArrayList<>(Math.max(size, 0));
		for (int i = 0; i < size; i++) {
			target.add(i);
		}
		Collections.shuffle(target);
		return target;
	}

	/**
	 * Creates an empty duration list for every sorter in {@link #sorters}, keyed by class name.
	 *
	 * @return an insertion-ordered map (a {@link LinkedHashMap}) so results print deterministically
	 */
	public static HashMap<String, ArrayList<Long>> makeTotals() {
		HashMap<String, ArrayList<Long>> totals = new LinkedHashMap<>();
		for (Sorter<?> sorter : sorters) {
			totals.put(sorter.getClass().getName(), new ArrayList<>());
		}
		return totals;
	}

	/**
	 * Times one sort and appends the duration (ms) under the sorter's class name.
	 *
	 * @param target the list to sort
	 * @param sorter the sorter to time
	 * @param totals per-sorter duration lists; a missing entry is created on demand
	 */
	public static void runTest(ArrayList<Integer> target, Sorter<Integer> sorter, HashMap<String, ArrayList<Long>> totals) {
		long duration = durationFor(sorter, target);
		totals.computeIfAbsent(sorter.getClass().getName(), name -> new ArrayList<>()).add(duration);
	}

	/**
	 * Prints {@code mean ± 2 standard errors} of the recorded durations for {@code sorter}.
	 * Under a normal approximation this is roughly a 95% confidence interval for the mean.
	 *
	 * @param sorter key into {@code totals}
	 * @param totals per-sorter durations in milliseconds
	 */
	public static void printResultFor(String sorter, HashMap<String, ArrayList<Long>> totals) {
		ArrayList<Long> durations = totals.get(sorter);
		if (durations == null || durations.isEmpty()) {
			// Avoids an NPE for unknown keys and a NaN interval when no trials ran.
			System.out.printf("%s:\tno samples%n", sorter);
			return;
		}
		double trials = durations.size();
		double mean = durations.stream().mapToLong(Long::longValue).sum() / trials;
		double ssd = durations.stream().mapToDouble(x -> (x - mean) * (x - mean)).sum();
		// Sample (Bessel-corrected) standard deviation; a single trial has no spread estimate.
		double stdDev = trials > 1 ? Math.sqrt(ssd / (trials - 1)) : 0.0;
		double interval = 2 * stdDev / Math.sqrt(trials);
		System.out.printf("%s:\t[%6.2f, %6.2f] ms%n", sorter, mean - interval, mean + interval);
	}
}
