from dataclasses import MISSING
from .sexp import *


# Sentinel field annotation: when walking a class's dataclass fields, both the
# serializer and the sexp_type decorator in this module stop at the first field
# whose type is SEXP_END, so fields declared after it are not (de)serialized.
SEXP_END = type('SEXP_END', (), {})


class AtomChoice:
    """Field type that accepts exactly one value out of a fixed set of choices."""

    def __init__(self, *choices):
        # Kept as the tuple of positional arguments, in the order given.
        self.choices = choices

    def __contains__(self, value):
        return value in self.choices

    def __atoms__(self):
        """Return the accepted choices."""
        return self.choices

    def __map__(self, obj, parent=None):
        """Unpack the single-element sequence ``obj`` and return its element.

        Raises:
            ValueError: if ``obj`` does not hold exactly one element.
            TypeError: if the element is not one of the configured choices.
        """
        (candidate,) = obj
        if candidate in self:
            return candidate

        allowed = ", ".join(map(str, self.choices))
        raise TypeError(f'Invalid atom {candidate} for {type(self)}, valid choices are: {allowed}')

    def __sexp__(self, value):
        """Yield ``value`` unchanged as its own serialization."""
        yield value


class Flag:
    """Boolean field encoded by the presence or absence of a single atom.

    When ``invert`` is false, the atom being present maps to ``True`` and a true
    value is serialized as the atom. When ``invert`` is true, the atom being
    present maps to ``False`` and a false value is serialized as the atom.
    """

    def __init__(self, atom=None, invert=None):
        self.atom, self.invert = atom, invert

    def __bind_field__(self, field):
        """Fill in unset parameters from the dataclass field this type annotates.

        The atom defaults to the field's name; ``invert`` defaults to the
        truthiness of the field's default value.
        """
        if self.atom is None:
            self.atom = Atom(field.name)
        if self.invert is None:
            # A field declared without a default carries the MISSING sentinel,
            # which is a plain truthy object. Treat "no default" as a false
            # default instead of silently inverting the flag.
            self.invert = field.default is not MISSING and bool(field.default)

    def __atoms__(self):
        """Return the one atom that selects this field while parsing."""
        return [self.atom]

    def __map__(self, obj, parent=None):
        """Return the value implied by the atom being present; ``obj`` is not inspected."""
        return not self.invert

    def __sexp__(self, value):
        """Yield the atom when ``value`` is the non-default state, otherwise nothing."""
        if bool(value) == (not self.invert):
            yield self.atom


def sexp(t, v):
    """Serialize Python value ``v`` into a list of s-expression elements.

    ``t`` is the type descriptor that selects the encoding:

    * ``None`` values serialize to an empty list regardless of ``t``.
    * ``int``, ``float``, ``str`` and ``Atom`` yield ``[t(v)]``.
    * Any other ``t`` with a ``__sexp__`` attribute is asked to serialize
      ``v``; everything it yields is collected into a list.
    * A one-element list ``[elem_type]`` serializes every element of ``v``
      with ``elem_type``; each element contributes its own nested list.

    Raises:
        TypeError: if ``t`` matches none of the cases above.
    """
    if v is None:
        return []

    if t in (int, float, str, Atom):
        return [t(v)]

    if hasattr(t, '__sexp__'):
        return list(t.__sexp__(v))

    if not isinstance(t, list):
        raise TypeError(f'Python type {t} has no defined s-expression serialization')

    (elem_type,) = t
    return [sexp(elem_type, item) for item in v]


def map_sexp(t, v, parent=None):
    """Convert parsed s-expression data ``v`` into a Python value of type ``t``.

    * A ``t`` other than ``Atom`` that has a ``__map__`` attribute handles the
      conversion itself and receives ``parent`` as a keyword argument.
    * For ``int``, ``float``, ``str`` and ``Atom``, ``v`` must be a
      single-element sequence. The element is returned as-is when it already
      is an instance of ``t``; int/float and str/Atom are converted into each
      other via ``t(...)``.
    * A one-element list ``[elem_type]`` maps every element of ``v`` with
      ``elem_type``.

    Raises:
        TypeError: if the value cannot be converted to ``t`` or ``t`` is not a
            supported type descriptor.
        ValueError: if a scalar ``t`` is given a sequence that does not hold
            exactly one element.
    """
    if t is not Atom and hasattr(t, '__map__'):
        return t.__map__(v, parent=parent)

    if t in (int, float, str, Atom):
        (scalar,) = v
        if isinstance(scalar, t):
            return scalar

        # Only these two pairs of types may be coerced into one another.
        pair = {type(scalar), t}
        if pair == {int, float} or pair == {str, Atom}:
            return t(scalar)

        raise TypeError(f'Cannot map s-expression value {scalar} of type {type(scalar)} to Python type {t}')

    if isinstance(t, list):
        (elem_type,) = t
        return [map_sexp(elem_type, item, parent=parent) for item in v]

    raise TypeError(f'Python type {t} has no defined s-expression deserialization')


class WrapperType:
    """Base class for type descriptors that wrap another type descriptor."""

    def __init__(self, next_type):
        # The wrapped type descriptor.
        self.next_type = next_type

    def __bind_field__(self, field):
        """Remember ``field`` and pass it on to the wrapped type if it binds fields too."""
        self.field = field
        inner = self.next_type
        if hasattr(inner, '__bind_field__'):
            inner.__bind_field__(field)

    def __atoms__(self):
        """Return the atoms that select this field while parsing.

        The wrapper's own ``name_atom`` takes precedence when the attribute
        exists. Otherwise the wrapped type is asked, except for ``Atom``
        itself, which contributes no atoms.
        """
        if hasattr(self, 'name_atom'):
            return [self.name_atom]

        if self.next_type is Atom:
            return []

        # list() gives the same empty result as a wrapped type without __atoms__.
        atoms_of = getattr(self.next_type, '__atoms__', list)
        return atoms_of()

class Named(WrapperType):
    """Wrap a type so its value is stored as ``(name value...)``.

    When no name is given, the name atom is taken from the bound dataclass
    field's name.
    """

    def __init__(self, next_type, name=None):
        super().__init__(next_type)
        self.name_atom = None
        if name:
            self.name_atom = Atom(name)

    def __bind_field__(self, field):
        """Bind the wrapped type (unless it is ``Atom``) and default the name atom."""
        inner = self.next_type
        if inner is not Atom and hasattr(inner, '__bind_field__'):
            inner.__bind_field__(field)

        if self.name_atom is None:
            self.name_atom = Atom(field.name)

    def __map__(self, obj, parent=None):
        """Drop the leading name element of ``obj`` and map the remainder with the wrapped type."""
        _name, *payload = obj
        return map_sexp(self.next_type, payload, parent=parent)

    def __sexp__(self, value):
        """Yield one list: the name atom followed by the wrapped type's serialization."""
        inner = sexp(self.next_type, value)
        if inner is None:
            return
        yield [self.name_atom, *inner]


class Rename(WrapperType):
    """Serialize the wrapped type under a different leading name atom.

    When no name is given, the name atom is taken from the bound dataclass
    field's name.
    """

    def __init__(self, next_type, name=None):
        super().__init__(next_type)
        self.name_atom = Atom(name) if name else None

    def __bind_field__(self, field):
        """Default the name atom to the bound field's name."""
        if self.name_atom is None:
            self.name_atom = Atom(field.name)

    def __map__(self, obj, parent=None):
        """Map ``obj`` with the wrapped type; the leading key is left for it to handle."""
        return map_sexp(self.next_type, obj, parent=parent)

    def __sexp__(self, value):
        """Yield the wrapped serialization with its key replaced by the name atom.

        Scalar wrapped types (``str``, ``float``, ``int``, ``Atom``) produce
        ``[name_atom, scalar]``. Any other wrapped type must serialize to a
        single list whose first element is dropped in favour of the name atom.
        Nothing is yielded when the wrapped serialization is empty, which is
        what ``sexp`` returns for ``None``.

        Raises:
            ValueError: if the wrapped type serializes to more than one element.
        """
        serialized = sexp(self.next_type, value)
        if not serialized:
            # Nothing to rename; previously this crashed on unpacking.
            return

        inner, = serialized
        if self.next_type in (str, float, int, Atom):
            # The scalar is emitted as one element. Unpacking it with * would
            # split a str into characters and raise TypeError for numbers.
            yield [self.name_atom, inner]
        else:
            _key, *rest = inner
            yield [self.name_atom, *rest]


class OmitDefault(WrapperType):
    """Wrap a type so that values equal to the field's default are not serialized."""

    def __bind_field__(self, field):
        """Bind the wrapped type and capture the field's default for later comparison.

        A ``default_factory`` is called once here; its result is the reference
        value. Otherwise the field's plain ``default`` is used.
        """
        inner = self.next_type
        if hasattr(inner, '__bind_field__'):
            inner.__bind_field__(field)

        factory = field.default_factory
        self.default = factory() if factory != MISSING else field.default

    def __map__(self, obj, parent=None):
        """Map ``obj`` with the wrapped type unchanged."""
        return map_sexp(self.next_type, obj, parent=parent)

    def __sexp__(self, value):
        """Yield the wrapped serialization only when ``value`` differs from the default."""
        differs = value != self.default
        if differs:
            yield from sexp(self.next_type, value)


class YesNoAtom:
    """Boolean field type stored as one of two atoms (``yes`` / ``no`` by default)."""

    def __init__(self, yes=Atom.yes, no=Atom.no):
        self.yes = yes
        self.no = no

    def __map__(self, value, parent=None):
        """Return whether the single element of ``value`` equals the ``yes`` atom."""
        (atom,) = value
        return atom == self.yes

    def __sexp__(self, value):
        """Yield the ``yes`` atom for a truthy ``value`` and the ``no`` atom otherwise."""
        if value:
            yield self.yes
        else:
            yield self.no


class Wrap(WrapperType):
    """Wrap a type so each serialized element sits inside one extra list level."""

    def __map__(self, value, parent=None):
        """Unpack the single element of ``value`` and map it with the wrapped type."""
        (inner,) = value
        return map_sexp(self.next_type, inner, parent=parent)

    def __sexp__(self, value):
        """Yield every element of the wrapped serialization as a one-element list."""
        yield from ([item] for item in sexp(self.next_type, value))


class Array(WrapperType):
    """Wrap a type so a sequence of values is stored as consecutive elements."""

    def __map__(self, value, parent=None):
        """Map each element of ``value`` on its own and return the results as a list."""
        mapped = []
        for elem in value:
            # Each element is presented to the wrapped type as a one-element list.
            mapped.append(map_sexp(self.next_type, [elem], parent=parent))
        return mapped

    def __sexp__(self, value):
        """Yield the serialization of every element of ``value``, one after another."""
        for elem in value:
            for part in sexp(self.next_type, elem):
                yield part


class List(WrapperType):
    """Wrap a type so repeated occurrences accumulate in a list attribute of the parent."""

    def __bind_field__(self, field):
        """Remember the name of the parent attribute that collects the items."""
        self.attr = field.name

    def __map__(self, value, parent):
        """Map one occurrence and append it to the parent's list attribute.

        Returns ``None``; the result is stored on ``parent`` directly. A new
        list is started when ``parent`` does not have the attribute yet.
        """
        collected = getattr(parent, self.attr, [])
        item = map_sexp(self.next_type, value, parent=parent)
        collected.append(item)
        setattr(parent, self.attr, collected)

    def __sexp__(self, value):
        """Yield the serialization of every element of ``value``, one after another."""
        for elem in value:
            for part in sexp(self.next_type, elem):
                yield part


class _SexpTemplate:
    """Default method implementations that ``sexp_type`` copies onto decorated classes.

    All methods are static here and take the target class (or instance) as
    their first argument explicitly.
    """

    @staticmethod
    def __atoms__(kls):
        """Return the class's own name atom as its only selecting atom."""
        return [kls.name_atom]

    @staticmethod
    def __map__(kls, value, parent=None):
        """Build an instance of ``kls`` from the parsed s-expression ``value``.

        The first element of ``value`` is skipped as the expression's key.
        Of the remaining elements:

        * an ``Atom`` registered in ``kls.keys`` is mapped as ``[atom]``;
        * a list is dispatched by its first element through ``kls.keys``;
        * anything else is assigned, unconverted, to the next positional field.

        A mapped result of ``None`` is not assigned. After all elements are
        processed, the instance's ``__after_parse__`` hook, if any, is called
        with ``parent``.

        Raises:
            TypeError: if there are more positional values than positional fields.
            KeyError: if a list's first element is not registered in ``kls.keys``.
        """
        remaining = iter(kls.positional)
        inst = kls()

        for item in value[1:]:
            is_known_atom = isinstance(item, Atom) and item in kls.keys
            if is_known_atom or isinstance(item, list):
                if is_known_atom:
                    key, payload = item, [item]
                else:
                    key, payload = item[0], item

                attr, elem_type = kls.keys[key]
                result = map_sexp(elem_type, payload, parent=inst)
                if result is not None:
                    setattr(inst, attr, result)
                continue

            try:
                target = next(remaining)
                setattr(inst, target.name, item)
            except StopIteration:
                raise TypeError(f'Unhandled positional argument {item!r} while parsing {kls}')

        after_parse = getattr(inst, '__after_parse__', lambda x: None)
        after_parse(parent)
        return inst

    @staticmethod
    def __sexp__(kls, value):
        """Yield one list: the name atom followed by every field's serialization.

        The instance's ``__before_sexp__`` hook, if any, runs first. Field
        processing stops at the first field annotated with ``SEXP_END``.
        """
        before_sexp = getattr(value, '__before_sexp__', lambda: None)
        before_sexp()

        out = [kls.name_atom]
        for f in fields(kls):
            if f.type is SEXP_END:
                break
            out.extend(sexp(f.type, getattr(value, f.name)))
        yield out

    @staticmethod
    def parse(kls, data):
        """Parse ``data`` with ``parse_sexp`` and map the result onto ``kls``."""
        parsed = parse_sexp(data)
        return kls.__map__(parsed)

    @staticmethod
    def sexp(self):
        """Return the first item produced by the object's own ``__sexp__``."""
        serialized = self.__sexp__(self)
        return next(serialized)


def sexp_type(name=None):
    """Return a class decorator that turns a class into an s-expression type.

    The decorator makes the class a dataclass, stores ``Atom(name)`` (or
    ``None`` when no name is given) as ``name_atom``, and installs the default
    ``__sexp__``, ``__map__``, ``__atoms__``, ``parse`` and ``sexp``
    implementations from ``_SexpTemplate`` wherever the class does not already
    have an attribute of that name. It then builds two lookup structures from
    the dataclass fields, stopping at the first field annotated ``SEXP_END``:

    * ``cls.keys`` maps each atom reported by a field type's ``__atoms__`` to
      ``(field name, field type)``;
    * ``cls.positional`` lists the fields whose type reports no atoms.
    """
    def register(cls):
        cls = dataclass(cls)
        cls.name_atom = None if name is None else Atom(name)

        # These become classmethods; 'sexp' below is installed as a plain
        # function and therefore acts as an instance method.
        for hook in ('__sexp__', '__map__', '__atoms__', 'parse'):
            if not hasattr(cls, hook):
                setattr(cls, hook, classmethod(getattr(_SexpTemplate, hook)))

        if not hasattr(cls, 'sexp'):
            cls.sexp = _SexpTemplate.sexp

        cls.positional = []
        cls.keys = {}
        for f in fields(cls):
            spec = f.type
            if spec is SEXP_END:
                break

            if hasattr(spec, '__bind_field__'):
                spec.__bind_field__(f)

            # list() gives the same empty result as a type without __atoms__.
            atoms_of = getattr(spec, '__atoms__', list)
            atoms = list(atoms_of())
            if atoms:
                cls.keys.update((atom, (f.name, spec)) for atom in atoms)
            else:
                cls.positional.append(f)

        return cls
    return register