package handwriting.learning;

import handwriting.editor.Drawing;
import handwriting.editor.SampleData;

import java.util.ArrayList;

/**
 * Base class for handwriting recognizers that learn from a set of labeled
 * sample drawings.
 *
 * <p>Subclasses supply the actual model by implementing {@link #classify(Drawing)}
 * and {@link #trainOnce(double)}; this class provides the training loop, accuracy
 * counting, and two small statistics helpers.
 */
public abstract class RecognizerAI {
	/** Samples handed to the constructor; used by {@link #trainUntil(double, int)}. */
	private final SampleData trainingData;

	/**
	 * Stores the samples this recognizer trains against.
	 *
	 * @param samples the training samples (stored as-is, not copied)
	 */
	protected RecognizerAI(SampleData samples) {
		this.trainingData = samples;
	}

	/**
	 * @return the same {@link SampleData} instance that was passed to the constructor
	 */
	public SampleData getSamples() {
		return trainingData;
	}

	/**
	 * Produces a label for the given drawing.
	 *
	 * @param d the drawing to classify
	 * @return the predicted label; {@link #numCorrectTests(SampleData)} calls
	 *         {@code equals} on this value, so it should not be {@code null}
	 */
	public abstract String classify(Drawing d);

	/**
	 * Performs a single training step.
	 *
	 * @param learningRate the learning rate forwarded from {@link #trainUntil(double, int)}
	 */
	public abstract void trainOnce(double learningRate);

	/**
	 * Repeats {@link #trainOnce(double)} until every training sample is classified
	 * correctly or the iteration cap is reached, whichever comes first.
	 *
	 * @param learningRate the learning rate passed to each training step
	 * @param max the maximum number of training steps to run
	 * @return the number of training steps actually performed
	 */
	public int trainUntil(double learningRate, int max) {
		int rounds = 0;
		for (; rounds < max; ++rounds) {
			// Accuracy is checked before each step, so an already-perfect
			// recognizer performs zero training steps.
			if (allTestsCorrect(trainingData)) {
				break;
			}
			trainOnce(learningRate);
		}
		return rounds;
	}

	/**
	 * Reports whether every drawing in {@code testData} is classified correctly.
	 *
	 * @param testData the labeled drawings to check
	 * @return {@code true} when the correct count equals {@code testData.numDrawings()}
	 */
	public boolean allTestsCorrect(SampleData testData) {
		int correct = numCorrectTests(testData);
		return correct == testData.numDrawings();
	}

	/**
	 * Counts the drawings in {@code testData} whose {@link #classify(Drawing)}
	 * result equals the label they are stored under.
	 *
	 * @param testData the labeled drawings to check
	 * @return the number of correctly classified drawings
	 */
	public int numCorrectTests(SampleData testData) {
		int correct = 0;
		for (String label : testData.allLabels()) {
			int index = 0;
			while (index < testData.numDrawingsFor(label)) {
				Drawing sample = testData.getDrawing(label, index);
				if (classify(sample).equals(label)) {
					correct++;
				}
				index++;
			}
		}
		return correct;
	}

	/**
	 * Computes the population standard deviation of the values, i.e. the square
	 * root of the summed squared deviations divided by {@code data.size()}.
	 *
	 * @param data the values to summarize
	 * @return the standard deviation; {@code NaN} when {@code data} is empty
	 */
	public static double stdDev(ArrayList<Integer> data) {
		final double average = mean(data);
		double sumOfSquares = 0.0;
		for (int i = 0; i < data.size(); i++) {
			double deviation = average - data.get(i);
			sumOfSquares += deviation * deviation;
		}
		return Math.sqrt(sumOfSquares / data.size());
	}

	/**
	 * Computes the arithmetic mean of the values.
	 *
	 * @param data the values to average
	 * @return the mean; {@code NaN} when {@code data} is empty
	 */
	public static double mean(ArrayList<Integer> data) {
		double total = 0.0;
		for (Integer value : data) {
			total += value;
		}
		return total / data.size();
	}

}
