import random
import copy

import Mixins

class Maze(object):
    """Rectangular grid maze whose cells are either passages or walls.

    Cells are addressed by ``(col, row)`` tuples.  Instances expose:

    * ``width`` / ``height`` -- grid dimensions.
    * ``cells`` -- dict mapping ``(col, row)`` to ``True`` (passage) or
      ``False`` (wall).
    * ``entry`` -- opening in the leftmost column, or ``None`` if a parsed
      maze has none.
    * ``exit`` -- opening in the rightmost column, or ``None`` if there is
      none.
    * ``treasures`` -- list of ``(col, row)`` cells holding a treasure.
    """

    def __init__(self, info, treasures = 0, perfection = 1.0):
        """Build a maze from a textual layout or generate a random one.

        Args:
            info: Either a ``str`` layout (``'*'`` is a wall, ``'T'`` is a
                passage with a treasure, any other character is a plain
                passage) or a ``(width, height)`` pair for a random maze.
            treasures: Number of treasures to place; only used for random
                mazes.
            perfection: Only used for random mazes.  A candidate cell with
                fewer than three blocked neighbours is still carved when
                ``random.random() > perfection``.

        Raises:
            ValueError: For a random maze narrower than 2 or shorter than 3
                cells, or when more treasures are requested than there are
                free passage cells.
        """
        if isinstance(info, str):
            self.__maze_from_str__(info)
        else:
            self.__randomized_maze__(info, treasures, perfection)

    def __maze_from_str__(self, info):
        """Populate the maze from a newline-separated textual layout."""
        lines = info.strip().split('\n')
        self.height = len(lines)
        # The first line defines the width of the grid.
        self.width = len(lines[0])
        self.cells = {}
        self.treasures = []
        for row, line in enumerate(lines):
            for col, char in enumerate(line):
                self.cells[(col, row)] = char != '*'
                if char == 'T':
                    self.treasures.append((col, row))

        # Default to None so a layout without openings still yields a fully
        # initialised object instead of one missing these attributes.
        self.entry = None
        self.exit = None
        last_col = self.width - 1
        # Top and bottom rows are skipped; if several openings exist in a
        # column the bottom-most one wins.
        for row in range(1, self.height - 1):
            if self.cells[(0, row)]:
                self.entry = (0, row)
            if self.cells[(last_col, row)]:
                self.exit = (last_col, row)

    def __randomized_maze__(self, info, treasures, perfection):
        """Generate a random maze of size ``info`` and scatter treasures."""
        self.perfection = perfection
        self.width, self.height = info
        # Width 1 would start tunnelling outside the grid and never stop;
        # height below 3 leaves no non-corner row for the entry.
        if self.width < 2 or self.height < 3:
            raise ValueError(
                'random maze needs width >= 2 and height >= 3, got %r x %r'
                % (self.width, self.height))
        self.cells = {(x, y): False
                      for x in range(self.width)
                      for y in range(self.height)}
        self.entry = (0, random.randint(1, self.height - 2))
        self.exit = None
        self.cells[self.entry] = True
        self.__tunnel__()
        self.__place_treasures__(treasures)

    def __repr__(self):
        """Render the maze as text: ``T`` treasure, ``.`` passage, ``*`` wall.

        Every row, including the last, is terminated by a newline.
        """
        treasure_cells = set(self.treasures)
        rendered = []
        for row in range(self.height):
            for col in range(self.width):
                cell = (col, row)
                if cell in treasure_cells:
                    rendered.append('T')
                elif self.is_passage(cell):
                    rendered.append('.')
                else:
                    rendered.append('*')
            rendered.append('\n')
        return ''.join(rendered)

    def is_passage(self, cell):
        """Return whether ``cell`` lies inside the grid and is open."""
        return self.within_maze(cell) and self.cells[cell]

    def has_treasure(self, cell):
        """Return whether ``cell`` is an open cell holding a treasure."""
        return self.is_passage(cell) and cell in self.treasures

    def all_passages(self):
        """Return every open cell, ordered row by row, left to right."""
        return [(col, row)
                for row in range(self.height)
                for col in range(self.width)
                if self.is_passage((col, row))]

    def within_maze(self, cell):
        """Return whether ``cell`` falls inside the grid bounds."""
        (x, y) = cell
        return 0 <= x < self.width and 0 <= y < self.height

    def neighbors_of(self, cell):
        """Return the four orthogonal neighbours of ``cell``.

        The order is: below, above, right, left.  No bounds check is made.
        """
        (x, y) = cell
        return [(x, y + 1), (x, y - 1), (x + 1, y), (x - 1, y)]

    def blocked_neighbor_count(self, cell):
        """Count neighbours of ``cell`` that are walls or outside the grid."""
        return len([n for n in self.neighbors_of(cell) if not self.is_passage(n)])

    def open_neighbors(self, cell):
        """Return the neighbours of ``cell`` that are passages."""
        return [n for n in self.neighbors_of(cell) if self.is_passage(n)]

    def is_edge(self, cell):
        """Return whether ``cell`` sits on the outer border of the grid."""
        (x, y) = cell
        return x == 0 or y == 0 or x == self.width - 1 or y == self.height - 1

    def is_valid_path(self, path):
        """Check a sequence of cells walked from the entry.

        Returns ``True`` when each cell in ``path`` is an open neighbour of
        the previous position (starting at ``entry``) and every treasure
        appears in ``path``.  Whether the path ends at the exit is not
        checked here.
        """
        current = self.entry
        for move in path:
            if move not in self.open_neighbors(current):
                return False
            current = move
        return all(t in path for t in self.treasures)

    def __tunnel__(self):
        """Carve passages depth-first, starting right of the entry.

        A popped cell that passes ``__can_tunnel__`` is opened.  The first
        such cell in the rightmost column becomes the exit; any other border
        cell is closed again; interior cells push their neighbours in random
        order.
        """
        stack = [(1, self.entry[1])]
        while stack:
            cell = stack.pop()
            if not self.__can_tunnel__(cell):
                continue
            self.cells[cell] = True
            if self.__carveable_exit__(cell):
                self.exit = cell
            elif not self.is_edge(cell):
                stack.extend(self.__random_neighbors__(cell))
            else:
                # Border cells other than the exit must stay walls.
                self.cells[cell] = False

    def __carveable_exit__(self, cell):
        """Return whether ``cell`` is in the last column and no exit exists."""
        return cell[0] == self.width - 1 and self.exit is None

    def __can_tunnel__(self, cell):
        """Return whether the wall at ``cell`` may be opened.

        A wall qualifies when at least three neighbours are blocked, or, at
        random, when ``random.random()`` exceeds ``perfection``.
        """
        return not self.is_passage(cell) and (self.blocked_neighbor_count(cell) >= 3 or random.random() > self.perfection)

    def __random_neighbors__(self, cell):
        """Return the neighbours of ``cell`` in shuffled order."""
        neigh = self.neighbors_of(cell)
        random.shuffle(neigh)
        return neigh

    def __place_treasures__(self, num):
        """Put ``num`` treasures on distinct random passages.

        The entry and exit cells never receive a treasure.

        Raises:
            ValueError: If ``num`` exceeds the number of eligible cells.
        """
        self.treasures = []
        reserved = (self.entry, self.exit)
        candidates = [cell for cell in self.all_passages() if cell not in reserved]
        # Explicit check rather than ``assert`` so it survives ``python -O``.
        if num > len(candidates):
            raise ValueError(
                'cannot place %r treasures: only %d free passage cells'
                % (num, len(candidates)))
        for _ in range(num):
            t = random.randint(0, len(candidates) - 1)
            self.treasures.append(candidates.pop(t))

class State(Mixins.Equality):
    """A position inside a maze together with the treasures picked up so far.

    ``start_state`` holds the current ``(col, row)`` cell and ``treasures`` is
    a frozenset of treasure cells already visited.  Comparison behaviour comes
    from ``Mixins.Equality`` (defined outside this file -- presumably it
    supplies ``__eq__``; confirm there).
    """

    def __init__(self, maze):
        """Start at the maze entry with no treasures collected."""
        self.treasures = frozenset()
        self.start_state = maze.entry
        self.maze = maze

    def __repr__(self):
        """Show only the current cell."""
        return str(self.start_state)

    def successor(self, destination):
        """Return a new state after stepping onto ``destination``.

        ``destination`` must be an open cell adjacent to the current one
        (enforced with assertions).  This state is left unchanged; the
        returned shallow copy records the treasure at ``destination`` if
        there is one that has not been collected yet.
        """
        assert self.maze.is_passage(destination)
        assert destination in self.maze.neighbors_of(self.start_state)
        moved = copy.copy(self)
        moved.start_state = destination
        newly_found = (self.maze.has_treasure(destination)
                       and destination not in self.treasures)
        if newly_found:
            moved.treasures = frozenset(tuple(self.treasures) + (destination,))
        return moved

    def goal_satisfied(self):
        """Return whether we stand on the exit holding every treasure."""
        at_exit = self.start_state == self.maze.exit
        return at_exit and len(self.treasures) == len(self.maze.treasures)

    def all_valid_actions(self):
        """Return the open cells reachable in one step from here."""
        here = self.start_state
        return self.maze.open_neighbors(here)

    def __hash__(self):
        """Combine the hashes of the position and the collected treasures."""
        position_hash = hash(self.start_state)
        loot_hash = hash(self.treasures)
        return position_hash + loot_hash


