import itertools
import copy
import sys

import Mixins

# A Problem holds everything a search needs to know about a planning problem:
# its objects, goal, and domain, plus the current state reached so far.
# Store these objects in your search nodes.
class Problem(Mixins.Equality):
    """A planning problem together with the state reached so far.

    Search nodes should store Problem objects.  ``start_state`` is the
    current state; ``successor`` returns a shallow copy whose
    ``start_state`` has had one action applied, so the name, domain,
    objects and goal are shared between a problem and its successors.
    """

    def __init__(self, content, name, domain):
        """Build a Problem from the clauses of ``(define (problem NAME) ...)``.

        content -- Cons list of clauses; ``:domain``, ``:objects``, ``:init``
                   and ``:goal`` are read, clauses with any other tag are
                   ignored.
        name    -- the problem's name.
        domain  -- the Domain object whose actions apply to this problem.
        """
        self.name = name
        self.domain = domain
        # Default so a problem without an ':objects' clause still works.
        self.objects = []
        while not isinstance(content, Null):
            clause = content.car
            tag = clause.car
            if tag == ':domain':
                self.domain_name = clause.cdr.car
            elif tag == ':objects':
                self.objects = cons2list(clause.cdr)
            elif tag == ':init':
                self.start_state = State(clause.cdr)
            elif tag == ':goal':
                goal_expr = clause.cdr.car
                if goal_expr.car == 'and':
                    goal_preds = predify(goal_expr.cdr)
                else:
                    # A single goal atom, e.g. (:goal (on a b)), has no 'and'
                    # wrapper; its arguments are not predicates themselves.
                    goal_preds = [Predicate(goal_expr)]
                self.goal_state = Goal(goal_preds)
            content = content.cdr

    def to_cons(self):
        """Return the problem as a ``(define (problem NAME) ...)`` Cons list."""
        goal = Cons(':goal', self.goal_state.to_cons())
        lisp = Cons(goal, Null())
        init = Cons(':init', self.start_state.to_cons())
        lisp = Cons(init, lisp)
        objs = Cons(':objects', list2cons(self.objects))
        lisp = Cons(objs, lisp)
        domain = Cons(':domain', Cons(self.domain_name, Null()))
        lisp = Cons(domain, lisp)
        problem = Cons('problem', Cons(self.name, Null()))
        return Cons('define', Cons(problem, lisp))

    def __repr__(self):
        return str(self.to_cons())

    def all_valid_actions(self):
        """Return every ground instance of every domain action that is
        applicable in ``start_state``."""
        acts = []
        for act in self.domain.actions:
            acts.extend(act.instantiations_in(self.start_state, self.objects))
        return acts

    def successor(self, action):
        """Return a shallow copy of this problem with *action* applied to
        ``start_state``."""
        succ = copy.copy(self)
        succ.start_state = self.start_state.successor(action)
        return succ

    def goal_satisfied(self):
        """Return True if every goal predicate holds in ``start_state``."""
        return self.goal_state.satisfied_in(self.start_state)

    def __hash__(self):
        # Only the current state is hashed, so problems reached by
        # different plans but in the same state hash alike.
        return hash(self.start_state)
        
# A State is the set of predicates that are currently true.  Access it
# through the start_state instance variable of the Problem class.
class State(Mixins.Equality):
    """An immutable set of the Predicates that are currently true.

    Access it through the ``start_state`` attribute of a Problem.
    """

    def __init__(self, pred_list):
        """Build a state from a Cons list of predicate expressions."""
        self.preds = frozenset(predify(pred_list))

    def to_cons(self):
        """Return the predicates as a Cons list (in frozenset iteration order)."""
        return list2cons(list(self.preds))

    def __repr__(self):
        return str(self.to_cons())

    def is_true(self, pred):
        """Return True if a Predicate equal to *pred* is in this state.

        ``self.preds`` is a frozenset, so use its hash lookup instead of a
        linear scan.  This is consistent because keeping Predicates in a
        frozenset already requires their hash to agree with equality.
        """
        return pred in self.preds

    def successor(self, action):
        """Return the state produced by applying the ground *action*.

        Delete effects (``neg_post``) are removed before add effects
        (``pos_post``) are added, so a predicate in both ends up true.
        Raises AssertionError if *action* is not applicable here.
        """
        assert action.applicable_in(self)
        succ = copy.copy(self)
        succ.preds = self.preds.difference(action.neg_post).union(action.pos_post)
        return succ

    def __hash__(self):
        return hash(self.preds)

# Objects of the Goal class should be accessed through the goal_state 
# instance variable of the Problem class.
class Goal(Mixins.Equality):
    """A conjunction of positive Predicates that must all be true.

    Access it through the ``goal_state`` attribute of a Problem.
    """

    def __init__(self, pred_list):
        """pred_list -- list of Predicate objects that must all hold."""
        self.goals = pred_list

    def to_cons(self):
        """Return the goal as a one-element Cons list meant to follow ':goal'.

        One predicate yields ``((p ...))`` and several yield
        ``((and (p ...) ...))``.  Either way, ``Cons(':goal', goal.to_cons())``
        is a proper list.
        """
        if len(self.goals) == 1:
            # Wrap in a one-element list, as the 'and' case does, so the
            # caller's Cons(':goal', ...) does not become a dotted pair.
            return Cons(self.goals[0].to_cons(), Null())
        return Cons(Cons('and', list2cons(self.goals)), Null())

    def __repr__(self):
        return str(self.to_cons())

    def satisfied_in(self, state):
        """Return True if every goal predicate is true in *state*."""
        return all(state.is_true(goal) for goal in self.goals)

# Contains information about the domain in general, such as available
# actions and predicates.  Each Problem must have a reference to a Domain
# object.
class Domain(Mixins.Equality):
    """General information about a planning domain.

    Holds the domain's name, its declared predicates and its actions.  Each
    Problem keeps a reference to the Domain it is posed in.
    """

    def __init__(self, content, name):
        """Read ``:predicates`` and ``:action`` clauses from *content*.

        content -- Cons list of the clauses following ``(domain NAME)``;
                   clauses with any other tag are ignored.
        name    -- the domain's name.
        """
        self.name = name
        self.predicates = []
        self.actions = []
        remaining = content
        while not isinstance(remaining, Null):
            section = remaining.car
            keyword = section.car
            if keyword == ':action':
                self.actions.append(Action(section))
            elif keyword == ':predicates':
                self.predicates = predify(section.cdr)
            remaining = remaining.cdr

    def to_cons(self):
        """Return the domain as a ``(define (domain NAME) ...)`` Cons list.

        The requirements clause is always written as ``(:requirements :strips)``.
        """
        body = Null()
        for action in self.actions[::-1]:
            body = Cons(action.to_cons(), body)
        predicates = Cons(':predicates', list2cons(self.predicates))
        requirements = list2cons([':requirements', ':strips'])
        body = Cons(requirements, Cons(predicates, body))
        header = Cons('domain', Cons(self.name, Null()))
        return Cons('define', Cons(header, body))

    def __repr__(self):
        return str(self.to_cons())
    
# Represents a fact that may be either true or false.
class Predicate(Mixins.Equality):
    """A named relation over arguments, e.g. ``(on a b)``, that may be true or false."""

    def __init__(self, lisp):
        """Build from a Cons list whose head is the name and whose tail holds the arguments."""
        self.name = lisp.car
        self.args = cons2list(lisp.cdr)

    def substituted_copy(self, mapping):
        """Return a shallow copy whose arguments are replaced via *mapping*.

        Every argument must be a key of *mapping* (KeyError otherwise).
        """
        clone = copy.copy(self)
        clone.args = list(map(mapping.__getitem__, self.args))
        return clone

    def to_cons(self):
        """Return ``(name arg...)`` as a Cons list."""
        return Cons(self.name, list2cons(self.args))

    def __repr__(self):
        return str(self.to_cons())

    def __hash__(self):
        # Hash the printed form, so predicates that print identically hash alike.
        return hash(str(self))
    
# Action lists can be very verbose.  These functions make
# a plan easier to read.
def concise_action_list(action_list):
    """Return each action's short ``(name args...)`` form, in plan order."""
    shortened = [action.short_form() for action in action_list]
    return shortened

def concise_action_str(action_list):
    """Return the short forms of the actions in *action_list* as one string."""
    short_forms = [action.short_form() for action in action_list]
    return str(short_forms)

# Given preconditions and postconditions, each Action describes how a State
# can be modified.  Each plan is a list of Actions.
class Action(Mixins.Equality):
    """A STRIPS action schema or ground action.

    ``pos_pre``/``neg_pre`` are Predicates that must be true/false before the
    action; ``pos_post``/``neg_post`` are Predicates it makes true/false.
    Parsed actions are schemas whose predicate arguments are parameter names;
    ``substituted_copy`` binds them to objects.  A plan is a list of ground
    Actions.
    """

    def __init__(self, lisp):
        """Parse ``(:action NAME :parameters (...) :precondition C :effect E)``.

        Keyword/value pairs after the name are read two at a time;
        unrecognised keywords are skipped.
        """
        lisp = lisp.cdr
        self.name = lisp.car
        # ':parameters' may be omitted; default to none so substituted_copy,
        # short_form and to_cons still work for parameter-less actions.
        self.params = []
        self.pos_pre = []
        self.neg_pre = []
        self.pos_post = []
        self.neg_post = []

        lisp = lisp.cdr
        while not isinstance(lisp, Null):
            tag = lisp.car
            content = lisp.cdr.car

            if tag == ':parameters':
                self.params = cons2list(content)
            elif tag == ':precondition':
                self.pos_pre, self.neg_pre = split_conjoined(content)
            elif tag == ':effect':
                self.pos_post, self.neg_post = split_conjoined(content)

            lisp = lisp.cdr.cdr

    def __hash__(self):
        # Order-insensitive sum over the name and all four condition lists.
        conditions = itertools.chain(self.pos_pre, self.neg_pre,
                                     self.pos_post, self.neg_post)
        return hash(self.name) + sum(hash(p) for p in conditions)

    def substituted_copy(self, obj_list):
        """Return a copy with parameters bound positionally to *obj_list*.

        obj_list must have one object per parameter.  Predicate.substituted_copy
        looks every predicate argument up in the parameter mapping, so a
        condition naming a non-parameter raises KeyError.
        """
        assert len(obj_list) == len(self.params)
        mapping = dict(zip(self.params, obj_list))
        result = copy.copy(self)
        result.params = [mapping[param] for param in self.params]
        result.pos_pre = [p.substituted_copy(mapping) for p in self.pos_pre]
        result.neg_pre = [p.substituted_copy(mapping) for p in self.neg_pre]
        result.pos_post = [p.substituted_copy(mapping) for p in self.pos_post]
        result.neg_post = [p.substituted_copy(mapping) for p in self.neg_post]
        return result

    def applicable_in(self, state):
        """Return True if every ``pos_pre`` holds and no ``neg_pre`` holds in *state*."""
        for p in self.pos_pre:
            if not state.is_true(p):
                return False

        for np in self.neg_pre:
            if state.is_true(np):
                return False

        return True

    def instantiations_in(self, state, objs):
        """Return every grounding of this action over *objs* applicable in *state*.

        Uses itertools.permutations, so distinct parameters are always bound
        to distinct objects.
        """
        versions = []
        for obj_perm in itertools.permutations(objs, len(self.params)):
            version = self.substituted_copy(obj_perm)
            if version.applicable_in(state):
                versions.append(version)
        return versions

    def short_form(self):
        """Return ``(name param...)`` as a Cons list."""
        return Cons(self.name, list2cons(self.params))

    def to_cons(self):
        """Return the full ``(:action NAME :parameters ... :precondition ... :effect ...)`` form."""
        post = consify_cond(self.pos_post, self.neg_post)
        pre = consify_cond(self.pos_pre, self.neg_pre)
        params = list2cons(self.params)
        tail = Cons(':effect', Cons(post, Null()))
        tail = Cons(':precondition', Cons(pre, tail))
        tail = Cons(':parameters', Cons(params, tail))
        return Cons(':action', Cons(self.name, tail))

    def __repr__(self):
        return str(self.to_cons())

# Some useful parsing and file I/O operations

def file2str(filename):
    """Read a PDDL file and return its code as one space-separated string.

    Each line has its ';' comment removed and is stripped.  Lines left empty
    are dropped, and every remaining line is followed by a single space.
    PDDL has no string literals, so ';' always starts a comment running to
    the end of the line; this includes comments after code, which
    would otherwise reach the parser as stray atoms.
    """
    parts = []
    with open(filename) as f:
        for line in f:
            code = line.split(';', 1)[0].strip()
            if code:
                parts.append(code + ' ')
    return ''.join(parts)

def file2cons(filename):
    """Read the PDDL file *filename* and parse it into a Cons structure."""
    text = file2str(filename)
    return parse_list(text)

def parse_domain(pddl):
    """Build a Domain from a parsed ``(define (domain NAME) ...)`` expression."""
    domain_name, clauses = parse_pddl_header(pddl)
    assert deftype(pddl) == 'domain'
    return Domain(content=clauses, name=domain_name)

def parse_problem(pddl, domain):
    """Build a Problem from a parsed ``(define (problem NAME) ...)`` expression.

    domain -- the already-parsed Domain the problem refers to.
    """
    problem_name, clauses = parse_pddl_header(pddl)
    assert deftype(pddl) == 'problem'
    return Problem(content=clauses, name=problem_name, domain=domain)
    
# From here on down is mostly boring parsing stuff.

def consify_cond(pos, neg):
    """Render positive and negative literals as a single PDDL condition.

    pos -- Predicates that must be true; neg -- Predicates that must be false.
    No literals give ``(and)``.  A lone positive literal is returned as-is,
    and a lone negative one becomes ``(not p)``.  Otherwise the result is
    ``(and pos... (not neg)...)``.
    """
    total = len(pos) + len(neg)
    if total == 0:
        # An empty condition (e.g. parsed from ':precondition (and)') must
        # still be printable, so emit an empty conjunction.
        return Cons('and', Null())
    if total == 1:
        return pos[0] if pos else not_ify(neg[0])
    result = add_neg_preds(neg, Null())
    for p in reversed(pos):
        result = Cons(p, result)
    return Cons('and', result)
    
def add_neg_preds(pred_list, suffix):
    """Prepend ``(not p)`` for each p in *pred_list* onto *suffix*, keeping order."""
    chain = suffix
    for pred in pred_list[::-1]:
        chain = Cons(not_ify(pred), chain)
    return chain

def not_ify(pred):
    """Return the Cons list ``(not (pred ...))`` negating Predicate *pred*."""
    negated = Cons(pred.to_cons(), Null())
    return Cons('not', negated)
            
def split_conjoined(lisp):
    """Split a condition into ``(positive, negative)`` lists of Predicates.

    Accepts ``(and lit...)``, ``(not p)`` or a bare predicate ``p``, where
    each lit inside ``and`` is either ``p`` or ``(not p)``.
    """
    head = lisp.car
    if head == 'and':
        positives, negatives = [], []
        node = lisp.cdr
        while not isinstance(node, Null):
            literal = node.car
            if literal.car == 'not':
                negatives.append(Predicate(literal.cdr.car))
            else:
                positives.append(Predicate(literal))
            node = node.cdr
        return positives, negatives
    if head == 'not':
        return ([], [Predicate(lisp.cdr.car)])
    return ([Predicate(lisp)], [])

def predify(cons):
    """Turn a Cons list of predicate expressions into a list of Predicates."""
    preds = []
    node = cons
    while not isinstance(node, Null):
        preds.append(Predicate(node.car))
        node = node.cdr
    return preds
    
class Cons(Mixins.Equality):
    """A Lisp-style cons cell; a chain of cells ending in Null is a list."""

    def __init__(self, car, cdr):
        self.car = car
        self.cdr = cdr

    def __repr__(self):
        return '(' + self.list_str()

    def list_str(self):
        """Render this cell and the rest of the chain, after the opening '('.

        Each element is followed by a space, and a Null tail supplies the
        closing ')'.  A non-Cons tail is printed as a dotted pair,
        ``(a . b)``.
        """
        prefix = str(self.car) + ' '
        if isinstance(self.cdr, Cons):
            return prefix + self.cdr.list_str()
        # No Null tail will close a dotted pair, so close it here.
        return prefix + '. ' + str(self.cdr) + ')'

def cons2list(cons):
    """Return the elements of a Null-terminated Cons list as a Python list."""
    items = []
    node = cons
    while not isinstance(node, Null):
        items.append(node.car)
        node = node.cdr
    return items

def list2cons(l):
    """Build a Null-terminated Cons list from a Python sequence.

    Items that have a ``to_cons`` method are converted with it; all others
    are stored as-is.
    """
    chain = Null()
    for item in reversed(l):
        value = item.to_cons() if hasattr(item, 'to_cons') else item
        chain = Cons(value, chain)
    return chain
        
class Null(Cons):
    """The empty list: a Cons cell whose car and cdr are both None.

    It terminates every proper list.  It renders as ')' when reached while
    printing, so an empty list prints as '()'.
    """

    def __init__(self):
        Cons.__init__(self, car=None, cdr=None)

    def list_str(self):
        # Only the closing parenthesis remains to be printed.
        closing = ')'
        return closing

def find_match(s, start = 0):
    """Return the index just past the ')' closing the '(' at ``s[start]``.

    Returns -1 if the string ends before that parenthesis is closed.
    """
    assert s[start] == '('
    depth = 0
    for pos in range(start, len(s)):
        ch = s[pos]
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
            if depth == 0:
                return pos + 1
    return -1

def find_space(s, start = 0):
    """Return the index of the first space, tab or newline at or after *start*.

    If there is none, returns ``len(s)``, or *start* itself when it is
    already past the end.  Other whitespace, such as a carriage return, is
    not treated as a delimiter.
    """
    pos = start
    end = len(s)
    while pos < end:
        if s[pos] in ' \t\n':
            break
        pos += 1
    return pos

def is_list(s):
    """Return True if string *s* starts with '(' and ends with ')'.

    An empty string is simply not a list; the check no longer indexes
    ``s[0]``, which raised IndexError.
    """
    return s.startswith('(') and s.endswith(')')

def starts_list(s):
    """Return True if string *s* begins with '(' (False for an empty string)."""
    return s.startswith('(')

def parse_list(s):
    """Parse a parenthesised string such as '(a (b c))' into a Cons list."""
    text = s.strip()
    assert is_list(text)
    inner = text[1:-1]  # drop the outer '(' and ')'
    return parse_list_contents(inner)

def parse_list_contents(s):
    """Parse the text between a list's outer parentheses into a Cons list.

    The text is stripped and lowercased.  Nested parenthesised lists are
    parsed with parse_list.  Anything else is an atom that runs up to the
    next space, tab or newline (see find_space).  Exits the program via
    sys.exit if a '(' is never closed.

    Elements are collected in a loop and the Cons chain is built at the end.
    Long lists (e.g. a large ':init' section) therefore do not recurse once
    per element, and the text is stripped and lowercased once rather than
    once per element.
    """
    text = s.strip().lower()
    size = len(text)
    items = []
    pos = 0
    while True:
        # Skip the same whitespace that str.strip() removes.
        while pos < size and text[pos].isspace():
            pos += 1
        if pos >= size:
            break
        if text[pos] == '(':
            end = find_match(text, pos)
            if not end > 0:
                sys.exit("Mismatched parentheses in: '" + text[pos:] + "'")
            items.append(parse_list(text[pos:end]))
        else:
            end = find_space(text, pos)
            items.append(text[pos:end])
        pos = end
    result = Null()
    for item in reversed(items):
        result = Cons(item, result)
    return result

def parse_pddl_header(pddl):
    """Split ``(define (KIND NAME) CLAUSE...)`` into ``(NAME, clauses)``.

    clauses is the Cons list of everything after the ``(KIND NAME)`` header.
    """
    assert isinstance(pddl, Cons)
    assert pddl.car == 'define'
    clauses = pddl.cdr.cdr
    return name_of(pddl), clauses

def deftype(pddl):
    """Return KIND from ``(define (KIND NAME) ...)``, e.g. 'domain' or 'problem'."""
    return pddl.cdr.car.car

def name_of(pddl):
    """Return NAME from ``(define (KIND NAME) ...)``."""
    return pddl.cdr.car.cdr.car



