from javax.swing import JPanel, JFrame, BorderFactory, JButton
from java.awt import BorderLayout, FlowLayout
import java.io.File
import javax.imageio.ImageIO
import java.awt.Color

import copy

def divBy8(value):
    """Round *value* up to the nearest multiple of 8.

    Despite the name, nothing is divided: 450 becomes 456, 456 stays 456.
    Tuples are rounded element-wise (recursively), so a (width, height)
    pair is rounded in both dimensions.

    >>> divBy8(450)
    456
    >>> divBy8((450, 64))
    (456, 64)
    """
    if isinstance(value, tuple):
        # Previously only value[0] and value[1] were handled; recurse over
        # every element so tuples of any length (and nested tuples) work.
        return tuple(divBy8(item) for item in value)
    # The outer "% 8" turns the "already a multiple" case (8 - 0) into 0.
    return value + (8 - (value % 8)) % 8

def a2p(a):
    """Convert an algebraic square name such as "e4" to (column, row).

    Column 0 is file 'a'; row 0 is rank 8, so "a8" maps to (0, 0) and
    "h1" to (7, 7).
    """
    fileLetter = a[0]
    rankDigit = a[1]
    column = ord(fileLetter) - ord('a')
    row = ord('8') - ord(rankDigit)   # rank 8 -> row 0, rank 1 -> row 7
    return (column, row)

def p2a(p):
    """Convert (column, row) coordinates back to an algebraic name.

    Inverse of a2p: (0, 0) -> "a8", (7, 7) -> "h1". Coordinates are not
    range-checked, so off-board cells produce off-board names.
    """
    column, row = p[0], p[1]
    fileLetter = chr(ord('a') + column)
    rankDigit = chr(ord('8') - row)
    return fileLetter + rankDigit

def readImage(filename):
    """Load the image file *filename* with javax.imageio and return it.

    NOTE(review): a relative *filename* resolves against the JVM's working
    directory, so the piece PNGs are expected to live there.
    """
    source = java.io.File(filename)
    return javax.imageio.ImageIO.read(source)

# Sprite cache: piece image filename (e.g. "WhiteKnight.png") -> image loaded
# once at import time via readImage.
_pieceImageFiles = ("BlackKing.png", "BlackQueen.png", "BlackRook.png",
                    "BlackKnight.png", "BlackBishop.png", "BlackPawn.png",
                    "WhiteKing.png", "WhiteQueen.png", "WhiteRook.png",
                    "WhiteKnight.png", "WhiteBishop.png", "WhitePawn.png")
images = dict((name, readImage(name)) for name in _pieceImageFiles)

class Piece(object):
    """Base class for a chess piece: a sprite plus move helpers.

    A piece stores no position; the board keeps per-side dicts that map
    algebraic squares to pieces.
    """

    def __init__(self, image):
        # Sprite handed to graphics.drawImage when the piece is drawn.
        self.image = image

    def __deepcopy__(self, memo):
        # A piece holds only its sprite, so board copies share the same
        # piece objects instead of duplicating them.
        return self

    def draw(self, graphics, position, size):
        """Draw the sprite into cell *position* = (column, row), scaled to
        *size* = (cellWidth, cellHeight) pixels."""
        width, height = size[0], size[1]
        left = position[0] * width
        top = position[1] * height
        graphics.drawImage(self.image, left, top, width, height, None)

    def moveSuffix(self, board, end):
        """Return the destination part of a move's notation: *end*, with an
        "x" prefix when *end* is occupied (a capture)."""
        return "x" + end if board.isOccupied(end) else end

    def movePiece(self, board, start, end):
        """Move the side-to-move's piece from *start* to *end*, removing any
        opposing piece standing on *end*."""
        mover = board.turn
        mover[end] = mover[start]
        if end in board.other:
            del board.other[end]
        del mover[start]

    
class NonPawn(Piece):
    """A king, queen, rook, bishop or knight: a piece whose moves are
    written with a leading letter, as in "Nf3"."""

    def __init__(self, image, symbol):
        Piece.__init__(self, image)
        # Letter written before this piece's moves ("K", "Q", "R", "B", "N").
        self.symbol = symbol

    def moveStr(self, board, start, end):
        """Return the notation for this piece moving to *end*, e.g. "Nf3",
        or "Nxf3" for a capture. No file/rank disambiguation is added."""
        destination = self.moveSuffix(board, end)
        return self.symbol + destination
    
class Pawn(Piece):
    """A pawn: its moves are written without a piece letter."""

    def moveStr(self, board, start, end):
        """Return the notation for moving from *start* to *end*: "e4" for a
        plain move, or the departure file plus the capture, e.g. "exd5"."""
        destination = self.moveSuffix(board, end)
        if destination[0] != 'x':
            return destination
        # Pawn captures are identified by the file the pawn leaves.
        return start[0] + destination
        
def BlackKing():
    """Return a new black king."""
    return NonPawn(image=images["BlackKing.png"], symbol="K")

def WhiteKing():
    """Return a new white king."""
    return NonPawn(image=images["WhiteKing.png"], symbol="K")

def BlackQueen():
    """Return a new black queen."""
    return NonPawn(image=images["BlackQueen.png"], symbol="Q")

def WhiteQueen():
    """Return a new white queen."""
    return NonPawn(image=images["WhiteQueen.png"], symbol="Q")

def BlackRook():
    """Return a new black rook."""
    return NonPawn(image=images["BlackRook.png"], symbol="R")

def WhiteRook():
    """Return a new white rook."""
    return NonPawn(image=images["WhiteRook.png"], symbol="R")

def BlackBishop():
    """Return a new black bishop."""
    return NonPawn(image=images["BlackBishop.png"], symbol="B")

def WhiteBishop():
    """Return a new white bishop."""
    return NonPawn(image=images["WhiteBishop.png"], symbol="B")

def BlackKnight():
    """Return a new black knight (symbol "N", since "K" is the king)."""
    return NonPawn(image=images["BlackKnight.png"], symbol="N")

def WhiteKnight():
    """Return a new white knight (symbol "N", since "K" is the king)."""
    return NonPawn(image=images["WhiteKnight.png"], symbol="N")

def BlackPawn():
    """Return a new black pawn."""
    return Pawn(image=images["BlackPawn.png"])

def WhitePawn():
    """Return a new white pawn."""
    return Pawn(image=images["WhitePawn.png"])

def squareColor(pos):
    """Return the fill colour for board square *pos*.

    *pos* may be an algebraic name such as "e4" or a (column, row) tuple,
    which is first converted with p2a. Squares whose file and rank indices
    have an even sum (a1, c1, ..., h8) are the dark squares and are drawn
    gray; the rest are white.
    """
    # isinstance (rather than comparing type() objects) also accepts str
    # subclasses, which previously fell through to p2a and crashed there.
    if not isinstance(pos, str):
        pos = p2a(pos)
    column = ord(pos[0]) - ord('a')
    rank = ord(pos[1]) - ord('1')
    if (column + rank) % 2 == 0:
        return java.awt.Color.gray
    return java.awt.Color.white
    
def findOtherKnight(target, squares, side):
    """Return the squares in *squares*, other than *target*, that hold a
    knight belonging to *side* (a square -> piece dict).

    NOTE(review): presumably meant for move disambiguation ("Nbd2") when two
    knights can reach the same square; no caller is visible to confirm.

    The previous version compared against a class named ``Knight``, which
    does not exist: knights are NonPawn pieces with symbol "N" (see
    BlackKnight/WhiteKnight), so that raised NameError. Knights are now
    recognised by their symbol; pieces without one (pawns) never match.
    """
    return [sq for sq in squares
            if sq != target and sq in side
            and getattr(side[sq], "symbol", None) == "N"]
        
class Chessboard(JPanel):
    """Swing panel that draws the board, moves pieces on mouse clicks and
    keeps a browsable history of positions."""

    def __init__(self, size):
        self.setBorder(BorderFactory.createRaisedBevelBorder())
        self.setSize(size[0], size[1])
        # Every position reached so far; states[currentState] is displayed.
        self.states = [BoardState()]
        self.currentState = 0
        # NOTE(review): self.size is rounded up to a multiple of 8, but
        # squareSize below uses the unrounded size. Also, on a Swing
        # component Jython may route "size" to the setSize bean property.
        # Confirm which size is intended.
        self.size = divBy8(size)
        self.squareSize = (size[0] // 8, size[1] // 8)
        # Every (column, row) cell of the 8x8 grid.
        self.squares = [(x, y) for x in range(8) for y in range(8)]

        # Algebraic name of the selected square, or None if nothing is selected.
        self.prevKey = None
        self.mousePressed = self.storeClick

        # When True, log() echoes debug information to stdout.
        self.printing = True

    def currentBoard(self):
        """Return the BoardState currently displayed."""
        return self.states[self.currentState]

    def isLegalSquare(self, square):
        """Return True if *square* is an algebraic name from "a1" to "h8"."""
        return len(square) == 2 and ord('a') <= ord(square[0]) <= ord('h') and ord('1') <= ord(square[1]) <= ord('8')

    # TODO: Returns true for all legal moves, but also returns true for some illegal moves
    def isLegalMove(self, start, end):
        """Return True if both squares are on the board and the displayed
        position's (permissive) isLegalMove accepts start -> end."""
        return self.isLegalSquare(start) and self.isLegalSquare(end) and self.currentBoard().isLegalMove(start, end)

    def getMoves(self):
        """Return the move notations that led to the displayed position."""
        return self.currentBoard().moves

    def log(self, info):
        """Print *info* when debug printing is enabled."""
        if self.printing:
            # Single parenthesised argument: same output as a Py2 print
            # statement, and also valid Python 3 syntax.
            print(info)

    def storeClick(self, e):
        """Mouse handler. The first click selects a piece of the side to move;
        the next click attempts to move it to the clicked square."""
        self.log("Clicked at:" + str((e.point.x, e.point.y)))
        click = self.decodeClick(e.point)
        clickKey = p2a(click)
        if self.prevKey:
            self.movePiece(clickKey)
        elif self.currentBoard().hasMovablePiece(clickKey):
            self.prevKey = clickKey
        self.log(self.currentBoard().moves)
        self.repaint()

    def movePiece(self, clickKey):
        """Try to move the selected piece to *clickKey*.

        On a legal move, positions after the displayed one are discarded and
        the new position is appended and shown. The selection is cleared
        whether or not the move was legal.
        """
        if self.isLegalMove(self.prevKey, clickKey):
            self.states = self.states[:self.currentState + 1]
            self.states.append(self.currentBoard().successor(self.prevKey, clickKey))
            self.currentState += 1
        self.prevKey = None

    def decodeClick(self, point):
        """Return the (column, row) cell under pixel *point*. Points past the
        8x8 grid yield indices >= 8 (not clamped)."""
        return point.x // self.squareSize[0], point.y // self.squareSize[1]

    def paintComponent(self, comp):
        """Fill each cell with its square colour (yellow for the selected
        square), then draw the displayed position's pieces on top."""
        for (x, y) in self.squares:
            if p2a((x, y)) == self.prevKey:
                comp.setColor(java.awt.Color.yellow)
            else:
                comp.setColor(squareColor((x, y)))
            comp.fillRect(x * self.squareSize[0], y * self.squareSize[1], self.squareSize[0], self.squareSize[1])
        self.currentBoard().draw(comp, self.squareSize)

    def toPrevMove(self, e):
        """Button handler: show the previous position, if there is one."""
        self.currentState = max(0, self.currentState - 1)
        # A selection made in another position is meaningless here; keeping
        # it would highlight, and let the next click move, an unrelated square.
        self.prevKey = None
        self.repaint()

    def toNextMove(self, e):
        """Button handler: show the next position, if there is one."""
        self.currentState = min(len(self.states) - 1, self.currentState + 1)
        # See toPrevMove: drop the selection when the position changes.
        self.prevKey = None
        self.repaint()
    
class BoardState(object):
    """One position: each side's pieces, the side to move, and the move
    notations that led here."""

    def __init__(self):
        # Each side maps algebraic squares ("e1") to piece objects.
        self.white = {"a1":WhiteRook(), "b1":WhiteKnight(), "c1":WhiteBishop(), "d1":WhiteQueen(),
                      "e1":WhiteKing(), "f1":WhiteBishop(), "g1":WhiteKnight(), "h1":WhiteRook(),
                      "a2":WhitePawn(), "b2":WhitePawn(),   "c2":WhitePawn(),   "d2":WhitePawn(),
                      "e2":WhitePawn(), "f2":WhitePawn(),   "g2":WhitePawn(),   "h2":WhitePawn()}
        self.black = {"a8":BlackRook(), "b8":BlackKnight(), "c8":BlackBishop(), "d8":BlackQueen(),
                      "e8":BlackKing(), "f8":BlackBishop(), "g8":BlackKnight(), "h8":BlackRook(),
                      "a7":BlackPawn(), "b7":BlackPawn(),   "c7":BlackPawn(),   "d7":BlackPawn(),
                      "e7":BlackPawn(), "f7":BlackPawn(),   "g7":BlackPawn(),   "h7":BlackPawn()}
        # Notation of every move played so far, oldest first.
        self.moves = []

        # Aliases of the side dicts: turn is the side to move, other is its
        # opponent; switchSides() flips them after each move.
        self.turn = self.white
        self.other = self.black

    def isOccupied(self, pos):
        """Return True if a piece of either side stands on *pos*."""
        return pos in self.turn or pos in self.other

    def hasMovablePiece(self, pos):
        """Return True if the side to move has a piece on *pos*."""
        return pos in self.turn

    def isLegalMove(self, start, end):
        """Permissive check: the side to move owns *start* and does not
        occupy *end*. Piece movement rules are not enforced."""
        return start in self.turn and end not in self.turn

    def draw(self, surface, squareSize):
        """Draw all pieces of both sides onto *surface*."""
        self.paintSide(self.white, surface, squareSize)
        self.paintSide(self.black, surface, squareSize)

    def paintSide(self, side, surface, squareSize):
        """Draw every piece in *side* in the cell for its square."""
        for square, piece in side.items():
            piece.draw(surface, a2p(square), squareSize)

    def movePiece(self, start, end):
        """Apply start -> end in place if isLegalMove allows it.

        Records the move's notation, moves the piece and passes the turn.
        Illegal moves are silently ignored.
        """
        if self.isLegalMove(start, end):
            movingPiece = self.turn[start]
            # Notation is computed before moving, while *end* still shows
            # whether the move is a capture.
            self.moves.append(movingPiece.moveStr(self, start, end))
            movingPiece.movePiece(self, start, end)
            self.switchSides()

    def switchSides(self):
        """Give the move to the other side."""
        # Identity, not ==: comparing dict contents is O(n) and would
        # misfire whenever the two sides compare equal (e.g. both empty).
        if self.turn is self.white:
            self.turn, self.other = self.black, self.white
        else:
            self.turn, self.other = self.white, self.black

    def successor(self, start, end):
        """Return a new BoardState with start -> end applied; self is left
        unchanged.

        deepcopy keeps turn/other aliasing the copied side dicts. Pieces
        whose __deepcopy__ returns self (as Piece's does) are shared
        between states rather than copied.
        """
        result = copy.deepcopy(self)
        result.movePiece(start, end)
        return result
    
class TestPoint(object):
    """Minimal stand-in for a java.awt.Point: exposes .x and .y only."""

    def __init__(self, x, y):
        self.x, self.y = x, y
        
class TestClick(object):
    """Fake mouse event for Chessboard.storeClick: carries a .point with
    .x/.y pixel coordinates."""

    def __init__(self, x, y):
        clickPoint = TestPoint(x, y)
        self.point = clickPoint
    
if __name__ == '__main__':
    frame = JFrame("View Chessboard", defaultCloseOperation=JFrame.EXIT_ON_CLOSE, size=(600, 600))
    board = Chessboard((450, 450))
    frame.contentPane.setLayout(BorderLayout())
    frame.contentPane.add(board, BorderLayout.CENTER)

    # History navigation buttons. Named *Button so the module-level
    # script does not shadow the builtin next().
    controls = JPanel()
    controls.setLayout(FlowLayout())
    prevButton = JButton("Last move", actionPerformed=board.toPrevMove)
    nextButton = JButton("Next move", actionPerformed=board.toNextMove)
    controls.add(prevButton)
    controls.add(nextButton)
    frame.contentPane.add(controls, BorderLayout.NORTH)

    frame.visible = True
    
