package handwriting.neural;

/**
 * A two-layer feed-forward network built from two {@link Perceptron} layers:
 * one mapping the input nodes to the hidden nodes, and one mapping the hidden
 * nodes to the output nodes.
 *
 * <p>Invariant: {@code inputToHidden.numOutputNodes() == hiddenToOutput.numInputNodes()}.
 * The layer-based constructor checks this.
 */
public class NeuralNet implements Supervised {
    // Both layers are assigned exactly once, in the constructor.
    private final Perceptron inputToHidden;
    private final Perceptron hiddenToOutput;

    /** Returns the number of input nodes (the hidden layer's input count). */
    public int numInputNodes() {
        return inputToHidden.numInputNodes();
    }

    /** Returns the number of hidden nodes (the hidden layer's output count). */
    public int numHiddenNodes() {
        return inputToHidden.numOutputNodes();
    }

    /** Returns the number of output nodes (the output layer's output count). */
    public int numOutputNodes() {
        return hiddenToOutput.numOutputNodes();
    }

    /** Returns the input-to-hidden layer. */
    public Perceptron getHiddenLayer() {
        return inputToHidden;
    }

    /** Returns the hidden-to-output layer. */
    public Perceptron getOutputLayer() {
        return hiddenToOutput;
    }

    /**
     * Creates a network with freshly constructed layers of the given sizes.
     *
     * @param numIn  number of input nodes
     * @param numHid number of hidden nodes
     * @param numOut number of output nodes
     */
    public NeuralNet(int numIn, int numHid, int numOut) {
        this(new Perceptron(numIn, numHid), new Perceptron(numHid, numOut));
    }

    /**
     * Creates a network from two existing layers.
     *
     * @param hidden the input-to-hidden layer
     * @param output the hidden-to-output layer
     * @throws NullPointerException     if either layer is {@code null}
     * @throws IllegalArgumentException if
     *         {@code hidden.numOutputNodes() != output.numInputNodes()}
     */
    public NeuralNet(Perceptron hidden, Perceptron output) {
        if (hidden == null) {
            throw new NullPointerException("hidden");
        }
        if (output == null) {
            throw new NullPointerException("output");
        }
        // Fail fast here rather than building a network whose layers cannot be chained.
        if (hidden.numOutputNodes() != output.numInputNodes()) {
            throw new IllegalArgumentException(
                    "hidden layer has " + hidden.numOutputNodes()
                    + " output nodes but output layer expects "
                    + output.numInputNodes() + " input nodes");
        }
        inputToHidden = hidden;
        hiddenToOutput = output;
    }

    /**
     * Feeds {@code inputs} through the hidden layer and then the output layer.
     *
     * <p>Pre: {@code inputs.length == numInputNodes()}
     *
     * @param inputs values for the input nodes
     * @return the values of the output nodes
     */
    public double[] compute(double[] inputs) {
        double[] hiddenValues = inputToHidden.compute(inputs);
        return hiddenToOutput.compute(hiddenValues);
    }

    /**
     * Applies one training cycle's worth of weight updates by delegating to
     * the hidden layer's {@code updateWeights()} and then the output layer's.
     *
     * <p>Pre: {@code train()} has been called some number of times.
     * Post (per the layers' contract): weights are updated for one training
     * cycle and the incremental deltas are reset to zero for the next cycle.
     */
    public void updateWeights() {
        inputToHidden.updateWeights();
        hiddenToOutput.updateWeights();
    }

    /**
     * Backpropagates the output layer's errors and weight changes to the hidden layer.
     *
     * <p>Pre: {@code getOutputLayer()} has just had its weights and errors changed.
     *
     * <p>NOTE: this method is currently an empty stub. As a result, {@code train()}
     * never passes any error information to the hidden layer.
     *
     * @param inputs the training inputs that produced the current activations
     * @param rate   the learning rate
     */
    protected void backpropagate(double[] inputs, double rate) {
        /* TODO: Calculate the backpropagated error for each hidden/output node pair,
           then call inputToHidden.setError() and inputToHidden.addToWeightDeltas()
           to store the errors.
         */
    }

    /**
     * Accumulates deltas for one training pair following the backpropagation
     * learning rule. It trains the output layer on the hidden layer's
     * activations, then calls {@link #backpropagate(double[], double)}.
     *
     * <p>Pre: {@code inputs.length == numInputNodes()},
     * {@code targets.length == numOutputNodes()}, {@code 0 < rate <= 1.0}
     *
     * @param inputs  values for the input nodes
     * @param targets desired values for the output nodes
     * @param rate    the learning rate
     */
    public void train(double[] inputs, double[] targets, double rate) {
        double[] hiddenValues = inputToHidden.compute(inputs);
        hiddenToOutput.train(hiddenValues, targets, rate);
        backpropagate(inputs, rate);
    }

    /** Returns a textual dump of both layers, each preceded by a heading line. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Input to Hidden\n");
        sb.append(inputToHidden);
        sb.append("\nHidden to Output\n");
        sb.append(hiddenToOutput);
        return sb.toString();
    }
}
