r"""
This file is not free software yet, please use for testing purposes only.
Written by Rubén Muñoz-\-Bertrand (2025).
Joint work with Christophe Levrat.
Tested on SageMath 10.8.beta1 running on Debian 13.1, to run, use:
sage -python $THISFILENAME
"""

from itertools import islice
from sage.all import *
from sage.modules.free_module_element import FreeModuleElement
from sage.modules.module import Module
from sage.modules.with_basis.indexed_element import IndexedFreeModuleElement
from sage.modules.with_basis.morphism import ModuleMorphismByLinearity
from sage.rings.abc import IntegerModRing
from sage.rings.function_field.divisor import divisor, prime_divisor
from sage.rings.function_field.element import FunctionFieldElement
from sage.schemes.curves.projective_curve import (
    IntegralProjectiveCurve,
    IntegralProjectiveCurve_finite_field
)
from sage.schemes.generic.scheme import Scheme
from sage.sets.finite_enumerated_set import FiniteEnumeratedSet
from sage.structure.element import Matrix, ModuleElement
from sage.structure.indexed_generators import parse_indices_names

def jordan_chevalley_frobenius(M):
    r"""
    Return the matrix of a large iterate of the Frobenius-semilinear map
    `\varphi(x) = M \cdot \mathrm{frob}(x)`.

    The `n`-th iterate is `\varphi^n(x) = A_n \cdot \mathrm{frob}^n(x)` with
    `A_n = M \, \mathrm{frob}(M) \cdots \mathrm{frob}^{n-1}(M)`. This function
    computes `A_n` for `n = 2, 4, 8, \dots` with the doubling rule
    `A_{2n} = A_n \cdot \mathrm{frob}^n(A_n)`. It stops at the first step
    where the rank does not drop. The images of the iterates are nested, so
    from then on the image of `\varphi^n` is stable.

    INPUT:

    - ``M`` -- square matrix over a ring of prime characteristic providing
      ``frobenius_endomorphism``

    OUTPUT:

    The matrix `A_n` of the first doubling step whose rank equals the rank of
    the previous step.

    NOTE(review): as a set, `\ker \varphi^n = \mathrm{frob}^{-n}(\ker A_n)`.
    Unless the base ring is the prime field, the right kernel of the returned
    matrix is therefore a Frobenius twist of the generalised kernel of
    `\varphi`. Callers using that kernel should confirm this is what they
    need.

    Raises ``ValueError`` if ``M`` is not square or if the characteristic of
    its base ring is not prime.
    """
    if not M.is_square():
        raise ValueError("matrix must be square")
    FF = M.base_ring()
    if not FF.characteristic().is_prime():
        raise ValueError("characterisitic must be a prime number")
    # Invariant: M is the matrix A_power of phi^power.
    power = 1
    s = M.rank()
    while True:
        # A_{2n} = A_n * frob^n(A_n). Always applying frob (instead of
        # frob^n) is only correct over the prime field.
        M = M * M.apply_map(FF.frobenius_endomorphism(power))
        power *= 2
        r, s = s, M.rank()
        # Ranks never increase, so equality means the image has stabilised.
        if r == s:
            break
    return M

def basis_of_kernel_of_linear_polynomial(v):
    r"""
    Return a basis of the roots, in the base ring, of an additive polynomial.

    Given ``v = (v_0, ..., v_{n-1})`` over a ring ``R`` of characteristic
    `p > 0`, consider the additive (`p`-linearised) polynomial

    .. MATH::

        P(x) = \sum_{i=0}^{n-1} v_{n-1-i} \, x^{p^i},

    so that the last entry of ``v`` is the coefficient of `x`. Because `P` is
    additive, its zeros in ``R`` form a vector space over the prime field.
    This function returns a basis of that space.

    INPUT:

    - ``v`` -- vector over ``R``

    OUTPUT:

    A list of elements of ``R``.

    Raises ``ValueError`` if the characteristic of ``R`` is zero.
    """
    R = v.base_ring()
    p = R.characteristic()
    if p == 0:
        raise ValueError("characteristic must be positive")
    n = v.degree()
    x = PolynomialRing(R, names=('x',))._first_ngens(1)[0]
    P = sum(v[n - 1 - i] * x**(p**i) for i in range(n))
    # R as a free module V; from_V maps V -> R and to_V maps R -> V
    # (presumably V is over the prime field -- TODO confirm).
    V, from_V, to_V = R.free_module()
    # Row i is P applied to the i-th basis element of V. P is additive, so a
    # vector c in the (left) kernel gives the root sum_i c_i * basis_i.
    images = [to_V(P(from_V(b))) for b in V.basis()]
    return [from_V(c) for c in matrix(images).kernel().basis()]

def diagonalise_frobenius_linear_companion_matrix (M, check = True) :
    r"""
    Return a square matrix built from the roots of an additive polynomial
    attached to the last column of ``M``.

    Write ``n + 1`` for the size of ``M``. Only the last column
    ``M[0, n], ..., M[n, n]`` is read. The roots come from
    :func:`basis_of_kernel_of_linear_polynomial` applied to the vector
    ``(frob^n(M[0, n]), frob^(n-1)(M[1, n]), ..., M[n, n], -1)``, that is the
    polynomial ``sum_i frob^(n-i)(M[i, n]) * x^(p^(n+1-i)) - x``. Each element
    ``s`` of that basis of roots gives one column (see the loop below).

    NOTE(review): judging from the name, ``M`` is presumably the companion
    matrix of a Frobenius-semilinear map, and the columns are meant to be
    fixed vectors of that map. This is not verified here.

    INPUT:

    - ``M`` -- square matrix over a ring of prime characteristic
    - ``check`` -- boolean (default: ``True``); if ``True``, check that the
      characteristic is prime and that ``M`` is square

    OUTPUT:

    A matrix with ``n + 1`` rows and ``n + 1`` columns.

    Raises ``ValueError`` in two cases:

    - a check fails;
    - the basis of roots does not have exactly ``n + 1`` elements (the
      message is "matrix is not diagonalisable").
    """
    R = M.base_ring ()
    n = M.nrows () - 1
    if check :
        if not R.characteristic ().is_prime () :
            raise ValueError ("characterisitic must be a prime number")
        if not M.is_square () :
            raise ValueError ("matrix must be square")
    # Coefficients of the additive polynomial, highest Frobenius power first.
    # The trailing -1 is the coefficient of x.
    sol = basis_of_kernel_of_linear_polynomial(
              vector([R.frobenius_endomorphism(n - i)(M [i, n])
                     for i in range(n+1)] + [-1]))
    frob = R.frobenius_endomorphism ()
    res = matrix(R, n + 1, 0)
    if(len(sol) != n + 1):
        raise ValueError("matrix is not diagonalisable")
    for s in sol:
        # Column attached to s:
        #   v[0]     = M[0, n] * frob(s),
        #   v[i + 1] = frob(v[i]) + M[i + 1, n] * frob(s)  for i < n - 1,
        #   v[n]     = s.
        v = []
        if n != 0:
            b = frob(s)
            v.append(M [0, n] * b)
            for i in range(n - 1) :
                v = v + [frob(v [i]) + M [i + 1, n] * b]
        v.append(s)
        res = res.augment(matrix (v).transpose ())
    return res

def diagonalise_frobenius_linear_invertible_matrix(M, check = True) :
    r"""
    Return a square matrix built, block by block, from the orbits of the
    standard basis vectors under `\varphi(x) = M \cdot \mathrm{frob}(x)`.

    For each standard basis vector ``e``:

    1. The vectors `\varphi(e), \varphi^2(e), \dots` are appended to the
       columns found so far, until a linear dependency appears.
    2. The dependency is read off an echelon form and passed to
       :func:`diagonalise_frobenius_linear_companion_matrix`.
    3. The result is completed with roots of polynomials ``x^p - x + c``.
    4. The new combinations of columns are appended to the result.

    The function returns as soon as the result is square.

    NOTE(review): the columns are presumably meant to be vectors fixed by
    `\varphi`. :func:`diagonalise_frobenius_linear_matrix` uses this result
    to build its "fixed points".

    INPUT:

    - ``M`` -- invertible square matrix over a ring of prime characteristic
    - ``check`` -- boolean (default: ``True``); if ``True``, check that ``M``
      is invertible and that the characteristic is prime

    OUTPUT:

    A square matrix over the base ring of ``M`` (empty if ``M`` is).

    Raises ``ValueError`` if a check fails, or if the standard basis is
    exhausted before a square matrix is obtained. An ``IndexError`` may
    surface when some ``x^p - x + c`` has no root in the base ring.
    """
    R = M.base_ring()
    p = R.characteristic()
    if check :
        if M.is_singular():
            raise ValueError ("matrix must be invertible")
        if not p.is_prime():
            raise ValueError ("characterisitic must be a prime number")
    res = matrix(R, M.nrows(), 0)
    # Loop invariants: Frobenius of R and the generator of R[x].
    frob = R.frobenius_endomorphism()
    A = PolynomialRing(R, names=('x',))
    x = A._first_ngens(1)[0]
    for vec in M.parent().identity_matrix().columns():
        # Columns found so far, followed by phi(e), phi^2(e), ...
        family = copy(res)
        n = 1
        s = res.ncols()
        while True:
            vec = M * vec.apply_map(frob)
            family = family.augment(matrix(vec).transpose())
            r = family.rank()
            if (r != n + s):
                # The newest iterate depends on the previous columns.
                if n == 1:
                    # phi(e) already lies in the span of res (assuming the
                    # columns of res are independent): try the next e.
                    break
                N = family.echelon_form().delete_columns([i
                                                          for i in range(s+1)])
                combinations = diagonalise_frobenius_linear_companion_matrix(
                                   N.matrix_from_rows(range(s, r)),
                                   check=False)
                n -= 2
                new_combinations = matrix(R, s, n + 1)
                for i in range(n + 1) :
                    b = frob(combinations[n, i])
                    for j in range (s) :
                        # Artin-Schreier equation; IndexError if no root in R.
                        P = x**p - x + b * N [j, n]
                        new_combinations [j, i] = P.roots()[0][0]
                combinations = new_combinations.stack(combinations)
                # Drop the dependent last column, keep the combinations.
                res = res.augment(family.delete_columns([r]) * combinations)
                if res.is_square():
                    return res
                break
            n += 1
    # Only reached for an empty matrix or when no square result was found;
    # previously this silently returned None.
    if res.is_square():
        return res
    raise ValueError("matrix is not diagonalisable over its base ring")

def diagonalise_frobenius_linear_matrix(M):
    r"""
    Split `\varphi(x) = M \cdot \mathrm{frob}(x)` into a nilpotent part and a
    part spanned by (intended) fixed vectors.

    INPUT:

    - ``M`` -- square matrix over a ring of prime characteristic

    OUTPUT:

    A couple ``(nil, fix)`` of lists of vectors:

    - ``nil`` -- a basis of the right kernel of the matrix returned by
      :func:`jordan_chevalley_frobenius` (the nilpotent part);
    - ``fix`` -- the columns of ``im * sol``, where ``sol`` is obtained from
      :func:`diagonalise_frobenius_linear_invertible_matrix` applied to the
      matrix of `\varphi` restricted to the span of the columns of ``im``
      (the fixed points).

    If the column space of that matrix is zero, ``nil`` is the list of rows
    of the identity matrix and ``fix`` is empty.

    NOTE(review): the right kernel of the matrix of `\varphi^n` is
    `\mathrm{frob}^n(\ker \varphi^n)`, which differs from `\ker \varphi^n`
    unless the base ring is the prime field. Confirm that ``nil`` is what
    callers expect.
    """
    jordan_chevalley_M = jordan_chevalley_frobenius(M)
    im_basis = jordan_chevalley_M.column_module().basis()
    if not im_basis:
        return jordan_chevalley_M.parent().identity_matrix().rows(), []
    im = jordan_chevalley_M * matrix(im_basis).transpose()
    # N is the semilinear matrix of phi in the basis given by the columns of
    # im: M * frob(im) = im * N.
    N = im.solve_right(M
                       * im.apply_map(im.base_ring().frobenius_endomorphism()),
                       extend=False)
    sol = diagonalise_frobenius_linear_invertible_matrix(N, check=False)
    R = sol.base_ring()
    return (jordan_chevalley_M.change_ring(R).right_kernel().basis(),
            (im.change_ring(R) * sol).columns())

def solve_inhomogeneous_equation(M, m):
    r"""
    Return a vector ``x`` such that `M \cdot \mathrm{frob}(x) - x = m`.

    :func:`diagonalise_frobenius_linear_matrix` splits the space into two
    parts for `\varphi(x) = M \cdot \mathrm{frob}(x)`:

    - On the nilpotent part, the solution is
      `x = -(m + \varphi(m) + \varphi^2(m) + \cdots)`.
    - On the part spanned by the fixed vectors `f_i`, write
      `x = \sum_i c_i f_i` and `m = \sum_i m_i f_i`. Since
      `\varphi(c f) = c^p \varphi(f)`, the equation becomes the
      Artin--Schreier equations `c_i^p - c_i = m_i`.

    INPUT:

    - ``M`` -- square matrix over a field of prime characteristic `p`
    - ``m`` -- vector (or anything ``vector`` accepts) of matching size

    OUTPUT:

    A vector over the base ring of ``M``.

    Raises ``ValueError`` in two cases:

    - some `c^p - c = m_i` has no root in the base ring of ``M``;
    - the nilpotent component of ``m`` is not annihilated by iterating
      `\varphi`.
    """
    nil, fix = diagonalise_frobenius_linear_matrix(M)
    n = len(nil)
    f = len(fix)
    r = n + f
    # Columns of N: the nilpotent basis followed by the fixed basis. Building
    # N from a single list also handles an empty ``fix`` (nilpotent phi).
    N = matrix(list(nil) + list(fix)).transpose()
    inv = N.inverse()
    R = M.base_ring()
    frob = R.frobenius_endomorphism()
    p = R.characteristic()
    # Coordinates of m in the basis (nil | fix).
    coords = inv * vector(m)
    # Component of m along the nilpotent part, in standard coordinates.
    m_nil = N * vector(coords[:n].list() + [R.zero()] * f)
    r_nil = zero_vector(R, r)
    # phi is nilpotent on an n-dimensional part, so r >= n steps suffice;
    # the bound turns a potential infinite loop into an error.
    for _ in range(r + 1):
        if m_nil.is_zero():
            break
        r_nil -= m_nil
        m_nil = M * m_nil.apply_map(frob)
    else:
        raise ValueError("the nilpotent component of the right-hand side is "
                         "not annihilated by the semilinear map")
    r_fix = zero_vector(R, r)
    x = PolynomialRing(R, names=('x',))._first_ngens(1)[0]
    for i in range(f):
        roots = (x**p - x - coords[n + i]).roots()
        if not roots:
            raise ValueError(f"the equation x^{p} - x = {coords[n + i]} has "
                             f"no solution in {R}")
        r_fix += roots[0][0] * fix[i]
    return r_nil + r_fix


class ModuleRestrictionOfScalarsElement(ModuleElement):
    """
    Element of a ``ModuleRestrictionOfScalars``.

    It wraps an element of the underlying module. A scalar ``a`` of the base
    ring acts on it as ``parent.restriction_map()(a)`` acts on the wrapped
    element. Addition, negation and comparison are delegated to the
    underlying module.
    """
    def __init__(self, parent, x):
        """
        Wrap ``x`` after converting it into ``parent.underlying_module()``.
        """
        self._element = parent.underlying_module()(x)
        super().__init__(parent)

    def _add_(self, other):
        """
        Return ``self + other`` (both have the same parent), computed in the
        underlying module.
        """
        return self.parent()(self._element + other._element)

    def _sub_(self, other):
        """
        Return ``self - other``, computed in the underlying module.
        """
        return self.parent()(self._element - other._element)

    def _neg_(self):
        """
        Return ``-self``, computed in the underlying module.
        """
        return self.parent()(-self._element)

    def _richcmp_(self, other, op):
        """
        Compare the wrapped elements of ``self`` and ``other``.
        """
        from sage.structure.richcmp import richcmp
        return richcmp(self._element, other._element, op)

    def _lmul_(self, right):
        """
        Return ``self * right`` for ``right`` in the base ring, i.e. the
        wrapped element multiplied by ``restriction_map(right)``.
        """
        parent = self.parent()
        return parent(self._element * parent.restriction_map()(right))

    def _repr_(self):
        """
        Return the representation of the wrapped element.
        """
        return repr(self._element)

    def _rmul_(self, left):
        """
        Return ``left * self`` for ``left`` in the base ring, i.e.
        ``restriction_map(left)`` times the wrapped element.
        """
        parent = self.parent()
        return parent(parent.restriction_map()(left) * self._element)


class ModuleRestrictionOfScalars(Module):
    r"""
    Restriction of scalars of a module through a ring endomorphism.

    Given a module ``M`` over a ring `R` and an endomorphism
    `\sigma : R \to R`, this is ``M`` viewed again as an `R`-module. A scalar
    `a \in R` acts as `\sigma(a)` acts on ``M``.
    """
    Element = ModuleRestrictionOfScalarsElement

    def __init__(self, M, restriction_map):
        """
        INPUT:

        - ``M`` -- a module
        - ``restriction_map`` -- an endomorphism of the base ring of ``M``

        Raises ``TypeError`` if ``M`` is not a module. Raises ``ValueError``
        if ``restriction_map`` is not an endomorphism of the base ring of
        ``M``.
        """
        if not isinstance(M, Module):
            raise TypeError(f"{M} is not a module")
        base = restriction_map.domain()
        if base is not M.base_ring():
            raise ValueError(f"{restriction_map} is not defined on the base "
                             "ring of the module")
        if base is not restriction_map.codomain():
            raise ValueError(f"{restriction_map} is not an endomorphism")
        self._restriction_map = restriction_map
        self._underlying_module = M
        # Reuse the variable names of M when it has some.
        try:
            names = M.variable_names()
        except ValueError:
            names = None
        super().__init__(base, category=M.category(), names=names)

    def _repr_(self):
        """
        Return a string representation of ``self``.
        """
        return (f"Restriction of scalars through {self._restriction_map} of "
                f"{self._underlying_module}")

    def linear_combination(self, iter_of_elements_coeff, factor_on_left=True):
        """
        Return the linear combination of the given elements.

        INPUT:

        - ``iter_of_elements_coeff`` -- iterable of pairs
          ``(element, coeff)``
        - ``factor_on_left`` -- boolean (default: ``True``); if ``True``, sum
          ``coeff * element``, otherwise sum ``element * coeff``

        OUTPUT:

        An element of ``self``. An empty iterable gives ``self.zero()``
        rather than the Python integer ``0``.
        """
        if factor_on_left:
            terms = (coeff * element
                     for element, coeff in iter_of_elements_coeff)
        else:
            terms = (element * coeff
                     for element, coeff in iter_of_elements_coeff)
        return sum(terms, self.zero())

    def restriction_map(self):
        """
        Return the ring endomorphism through which scalars are restricted.
        """
        return self._restriction_map

    def underlying_module(self):
        """
        Return the module whose scalars are restricted.
        """
        return self._underlying_module


class ModuleSemilinearMapFromMatrix(ModuleMorphismByLinearity):
    """
    Semilinear map between finite dimensional modules with basis, given by a
    matrix.

    The basis element of ``domain`` of rank ``j`` (in the order of
    ``domain.basis().keys()``) is sent to the element of ``codomain`` whose
    coordinate vector is column ``j`` of ``matrix``. That element is viewed
    in the restriction of scalars of ``codomain`` through
    ``restriction_map``, and the map is extended by linearity.
    """
    def __init__(self, domain, codomain, restriction_map, matrix,
                 category=None):
        """
        INPUT:

        - ``domain``, ``codomain`` -- finite dimensional modules with basis
          over the base ring of ``domain``
        - ``restriction_map`` -- endomorphism of the base ring of
          ``codomain``
        - ``matrix`` -- matrix with ``codomain.dimension()`` rows and one
          column per basis element of ``domain``
        - ``category`` -- (default: ``None``) passed on to
          ``ModuleMorphismByLinearity``

        Raises ``ValueError`` if a module is not finite dimensional or if the
        dimensions do not match. Raises ``TypeError`` if ``matrix`` is not a
        matrix.
        """
        C = ModulesWithBasis(domain.base_ring()).FiniteDimensional()
        if domain not in C:
            raise ValueError(f"the domain {domain} is not finite dimensional")
        if codomain not in C:
            raise ValueError(f"the codomain {codomain} is not finite "
                             "dimensional")
        if not isinstance(matrix, Matrix):
            raise TypeError(f"{matrix} is not a matrix")
        import sage.combinat.ranker
        indices = tuple(domain.basis().keys())
        rank_domain = sage.combinat.ranker.rank_from_list(indices)
        # After transposition, row j holds the image of the j-th basis element.
        matrix = matrix.transpose()
        if matrix.nrows() != len(indices):
            raise ValueError(f"the dimension of the matrix ({matrix.nrows()}) "
                             "does not match with the dimension of the domain "
                             f"({len(indices)})")
        if matrix.ncols() != codomain.dimension():
            raise ValueError(f"the dimension of the matrix ({matrix.ncols()}) "
                             "does not match with the dimension of the "
                             f"codomain ({codomain.dimension()})")
        self._matrix = matrix
        codomain_rs = ModuleRestrictionOfScalars(codomain, restriction_map)
        d = {xt: codomain_rs(codomain.from_vector(matrix.row(rank_domain(xt))))
             for xt in indices}
        super().__init__(on_basis=lambda i : d[i], domain=domain,
                         codomain=codomain_rs, category=category)


class FunctionFieldElement_2(FunctionFieldElement):
    """
    Extra methods for function field elements.

    The method can be called unbound on any function field element ``f``, as
    ``FunctionFieldElement_2.coefficients(f, place)``.
    """
    def coefficients(self, place, lower_bound=None):
        r"""
        Iterate over the coefficients of the expansion of ``self`` at
        ``place`` in powers of the local uniformiser `\pi` of ``place``.

        The values lie in the residue field of ``place``. The ``i``-th value
        yielded is the coefficient of `\pi^{\mathrm{lower\_bound} + i}`; each
        coefficient is lifted back with the residue field's lifting map before
        the next one is computed. The iterator is infinite; use
        :func:`itertools.islice` to take finitely many terms.

        INPUT:

        - ``place`` -- place of the parent of ``self``
        - ``lower_bound`` -- integer (default: ``None``); exponent of the
          first coefficient yielded. It must not exceed the valuation of
          ``self`` at ``place``; zeros are yielded below that valuation. If
          ``None``, the valuation is used.

        Raises ``ValueError`` (on the first ``next``) in two cases:

        - ``lower_bound`` is larger than the valuation;
        - ``self`` is zero and no ``lower_bound`` is given.
        """
        val = self.valuation(place)
        if lower_bound is None:
            if self.is_zero():
                # The valuation of zero is infinite: there is no first term.
                raise ValueError("a lower bound is required to expand zero")
            lower_bound = val
        elif lower_bound > val:
            raise ValueError("the given lower bound is larger than the "
                             "valuation of self")
        residue_field, from_res, to_res = place.residue_field()
        # Zero coefficients below the valuation (forever when self is zero).
        while lower_bound < val:
            yield residue_field.zero()
            lower_bound += 1
        uniformiser = place.local_uniformizer()
        unit_part = self * uniformiser ** -val
        while True:
            residue = to_res(unit_part)
            yield residue
            # Remove the leading term and shift by one power of the uniformiser.
            unit_part = (unit_part - from_res(residue)) / uniformiser


class H_et_Scheme(UniqueRepresentation, Module):
    """
    Étale cohomology group of a scheme.

    This class is a dispatching front end. Calling it validates the arguments
    and returns an instance of one of:

    - ``H_et_Scheme_free`` when ``coeff`` is ``None``;
    - ``H_et_Scheme_Fp_free`` when the characteristic of ``coeff`` is the
      characteristic of the base field;
    - ``H_et_Scheme_ZpnZ_free`` when it is a higher power of it.
    """
    def __classcall_private__(cls, X, n, coeff=None):
        """
        Validate the arguments and construct the appropriate subclass.

        INPUT:

        - ``X`` -- a ``NormalIntegralProjectiveCurve_finite_field``
        - ``n`` -- degree; only ``1`` is implemented
        - ``coeff`` -- a commutative ring (default: ``None``). Use ``None``
          for the structure sheaf; otherwise ``coeff`` must be a ring of
          integers modulo `N` or a finite field, whose characteristic is a
          positive power of the characteristic of the base field of ``X``.

        Raises ``NotImplementedError`` for any other input.
        """
        if n != 1:
            raise NotImplementedError("only the first étale cohomology group "
                                      "is implemented")
        if not isinstance(X, NormalIntegralProjectiveCurve_finite_field):
            raise NotImplementedError("étale cohomology is only implemented "
                                      "on normal integral projective curves "
                                      "over finite fields")
        if coeff is None:
            target = H_et_Scheme_free
        else:
            if not (isinstance(coeff, IntegerModRing)
                    or coeff in FiniteFields()):
                raise NotImplementedError("étale cohomology is only "
                                          "implemented with coefficients in "
                                          "either the structure sheaf, or in "
                                          "the constant sheaf associated to "
                                          "the ring of integers modulo a "
                                          "positive power of the "
                                          "characteristic of the base field")
            prime, exponent = coeff.characteristic().is_prime_power(
                get_data=True)
            if exponent == 0 or prime != X.base_ring().characteristic():
                raise NotImplementedError("étale cohomology with coefficients "
                                          "in the constant sheaf associated "
                                          "to the ring of integers modulo an "
                                          "integer is only implemented when "
                                          "that integer is a positive power "
                                          "of the characteristic of the base "
                                          "field")
            if exponent == 1:
                target = H_et_Scheme_Fp_free
            else:
                target = H_et_Scheme_ZpnZ_free
        return target.__classcall__(target, X, n, coeff)

    def degree(self):
        """
        Return the cohomological degree ``n`` of this group.
        """
        return self._degree

    def scheme(self):
        """
        Return the scheme whose cohomology this is.
        """
        return self._scheme


class H_et_Scheme_free_Element(IndexedFreeModuleElement):
    """
    Element of ``H_et_Scheme_free``.
    """
    def apply_frobenius(self):
        """
        Return the image of ``self`` under the Frobenius action.

        The coordinate vector ``v`` of ``self`` is sent to ``M * frob(v)``.
        Here ``M`` is ``parent.frobenius_action()`` and ``frob`` is the
        Frobenius endomorphism of the base ring, applied entrywise.
        """
        H1_OX = self.parent()
        F = H1_OX.base_ring().frobenius_endomorphism()
        M = H1_OX.frobenius_action()
        v = self.to_vector()
        return H1_OX(M * v.apply_map(F))

    def extension_function(self, f=None):
        """
        Return a function whose principal part at the ``i``-th basis place is
        ``r**p * pi**(-p) - r * pi**(-1)``. Here ``r`` is the ``i``-th
        coordinate of ``self``, ``pi`` the local uniformiser and ``p`` the
        characteristic.

        The function is found with ``find_function`` of the parent, which
        raises ``ValueError`` when no such function exists.

        INPUT:

        - ``f`` -- (default: ``None``) only ``None`` is supported

        Raises ``NotImplementedError`` if ``f`` is not ``None``.
        """
        if f is not None:
            raise NotImplementedError("only f=None is supported")
        p = self.base_ring().characteristic()
        # Coordinates (not the (key, coeff) pairs that iterating self gives),
        # one principal part [r^p, 0, ..., 0, -r] of length p per coordinate.
        fself_min_self = [[r**p] + [0] * (p - 2) + [-r]
                          for r in self.to_vector()]
        return self.parent().find_function(fself_min_self)


class H_et_Scheme_free(CombinatorialFreeModule, H_et_Scheme):
    """
    First cohomology of a curve with coefficients in the structure sheaf.

    The basis is indexed by the places in the support of a non-special
    divisor, built by ``_generate_non_special_family``. An element is
    represented by principal parts at these places: the basis element indexed
    by a place ``P`` is the class of ``pi_P**(-1)``, where ``pi_P`` is the
    local uniformiser of ``P``.

    .. WARNING::

        Do not instantiate directly; go through ``H_et_Scheme`` (for instance
        via the curve's ``etale_cohomology`` method).
    """
    Element = H_et_Scheme_free_Element

    def __init__(self, C, n, coeff=None):
        """
        Initialise the module for the curve ``C`` and degree ``n``.

        ``coeff`` is accepted for signature compatibility and ignored. The
        base ring is ``C.base_ring()``.
        """
        self._degree = n
        self._scheme = C
        super().__init__(C.base_ring(),
                         basis_keys=self._generate_non_special_family())

    def _element_constructor_(self, data):
        """
        Convert ``data`` into an element of ``self``.

        INPUT:

        - ``data`` -- one of:

          - a ``FreeModuleElement``, read as coordinates in the basis;
          - a list with one NONEMPTY list per basis place. ``data[i]`` holds
            the coefficients of the principal part at the ``i``-th place, for
            the exponents ``-len(data[i])`` up to ``-1``;
          - anything else, handed to ``CombinatorialFreeModule``.

        For a list of principal parts, the result is the combination
        ``sum_i c_i pi_i**(-1)`` that differs from the given principal parts
        by the principal parts of a function in ``L(D)``, with
        ``D = sum_i len(data[i]) P_i``.
        """
        basis = self.indices()
        dictionary = {}
        if isinstance(data, FreeModuleElement):
            for i in range(self.dimension()):
                dictionary[basis[i]] = data[i]
            return self._from_dict(dictionary)
        if not isinstance(data, list):
            return super()._element_constructor_(data)
        g = self.dimension()
        # First use of ``dictionary``: the divisor D, keyed by place.
        for i in range(g):
            dictionary[basis[i]] = len(data[i])
        k = self.base_ring()
        size = sum(dictionary.values())
        mat = zero_matrix(k, g, size)
        vec = vector(k, size)
        pos = -1
        for i in range(g):
            for value in data[i]:
                pos += 1
                vec[pos] = value
            # Row i selects the pi_i**(-1) coefficient (last of block i).
            mat [i, pos] = k.one()
        big_divisor = divisor(self._scheme.function_field(), dictionary)
        # One extra row per basis function of L(D): its principal parts.
        for fun in big_divisor.basis_function_space():
            line = sum([list(islice(FunctionFieldElement_2.coefficients(fun,
                         b, lower_bound=-dictionary[b]), dictionary[b]))
                        for b in basis], [])
            mat = mat.stack(vector(line))
        sol = mat.solve_left(vec, extend=False)
        # Second use of ``dictionary``: coordinates of the result.
        for i in range(g):
            dictionary[basis[i]] = sol[i]
        return self._from_dict(dictionary)

    def _generate_non_special_family(self):
        """
        Return the support of a non-special divisor built greedily.

        Places are added one at a time with multiplicity 1. A place is kept
        when the resulting divisor ``D`` still has ``D.dimension() == 1``. The
        process stops once genus-many places have been kept.

        Raises ``ValueError`` if the candidate places run out first.
        """
        def primes_for_divisors(F):
            # Yield candidate prime ideals of the maximal order of F.
            # - Rational function field: ideals generated by t + c, where t
            #   is the generator of F and c runs over each subfield of the
            #   constant field in the order given by ``subfields()``
            #   (elements shared by several subfields come more than once).
            # - Extension: the primes above those of the base field.
            K = F.base_field()
            O = F.maximal_order()
            if K is F:
                R = O._ring
                # Use F's own generator rather than any global ``x``.
                t = F.gen()
                for k, f in R.base_ring().subfields():
                    for a in k:
                        h = t + f(a)
                        yield O.ideal(h).place().prime_ideal()
            else:
                for prime in primes_for_divisors(K):
                    for p, _, _ in O.decomposition(prime):
                        yield p
        F = self._scheme.function_field()
        g = self._scheme.genus()
        non_special = divisor(F, {})
        for prime in primes_for_divisors(F):
            place = prime.place()
            if place in non_special.dict():
                continue
            test_divisor = non_special + prime_divisor(F, place, 1)
            if test_divisor.dimension() == 1:
                non_special = test_divisor
                g -= 1
            if g == 0:
                break
        else:
            raise ValueError("There is no non special family whose "
                             "cardinality is the genus of the scheme")
        return non_special.support()

    def find_function(self, principal_parts):
        """
        Return a function with the given principal parts at the basis places.

        INPUT:

        - ``principal_parts`` -- list with one NONEMPTY list per basis place.
          ``principal_parts[i]`` holds the coefficients for the exponents
          ``-len(principal_parts[i])`` up to ``-1`` at the ``i``-th place.

        OUTPUT:

        An element ``f`` of ``L(D)``, with ``D = sum_i len(principal_parts[i])
        P_i``, whose coefficients at those exponents match
        ``principal_parts``.

        Raises ``ValueError`` if no such function exists.
        """
        basis = self.indices()
        dictionary = {}
        g = self.dimension()
        for i in range(g):
            dictionary[basis[i]] = len(principal_parts[i])
        k = self.base_ring()
        size = sum(dictionary.values())
        mat = zero_matrix(k, 0, size)
        vec = vector(k, size)
        pos = 0
        for pp in principal_parts:
            for value in pp:
                vec[pos] = value
                pos += 1
        big_divisor = divisor(self._scheme.function_field(), dictionary)
        basis_function_space = big_divisor.basis_function_space()
        # Row j: principal parts of the j-th basis function of L(D).
        for fun in basis_function_space:
            line = sum([list(islice(FunctionFieldElement_2.coefficients(fun,
                         b, lower_bound=-dictionary[b]), dictionary[b]))
                        for b in basis], [])
            mat = mat.stack(vector(line))
        try:
            sol = mat.solve_left(vec, extend=False)
        except ValueError:
            raise ValueError("the list of principal parts does not correspond "
                             "to a rational function")
        return sum(sol[i] * basis_function_space[i]
                   for i in range(len(basis_function_space)))

    @cached_method
    def frobenius_action(self):
        """
        Return the matrix of the Frobenius action on ``self``.

        OUTPUT:

        A square matrix in the basis ``self.indices()``. Column ``i`` is the
        coordinate vector of the class of ``pi_i**(-p)``, i.e. the Frobenius
        of the ``i``-th basis element ``pi_i**(-1)``. That class is given by
        the principal part ``[1, 0, ..., 0]`` of length ``p`` at place ``i``
        and ``[0]`` elsewhere.
        """
        g = self.dimension()
        k = self.base_ring()
        frob = matrix(k, g, 0)
        p = k.characteristic()
        zero_principal_parts = [[k.zero()]] * g
        p_power = [k.zero()] * p
        p_power[0] = k.one()
        for i in range(g):
            # Shallow copy is enough: inner lists are replaced, never mutated.
            uniformiser_power = zero_principal_parts.copy()
            uniformiser_power[i] = p_power
            frob = frob.augment(self(uniformiser_power).to_vector())
        return frob


class H_et_Scheme_Fp_free_Element(IndexedFreeModuleElement):
    """
    Element of ``H_et_Scheme_Fp_free``.
    """
    def extension_function(self):
        """
        Return a function on the curve attached to ``self``.

        The steps are:

        1. Each basis key of the parent is an element of the cohomology with
           coefficients in the structure sheaf. ``self`` is turned into the
           coordinate vector ``v = sum_i self[key_i] * key_i.to_vector()``.
        2. For each entry ``r`` of ``v``, the principal part
           ``[r^p, 0, ..., 0, -r]`` (``p`` entries) is formed.
        3. ``find_function`` of that cohomology group is asked for a function
           with these principal parts.
        """
        H = self.parent()
        keys = H.indices()
        p = self.base_ring().characteristic()
        v = sum(self[keys[i]] * keys[i].to_vector()
                for i in range(H.dimension()))
        principal_parts = []
        for r in v:
            principal_parts.append([r**p] + [0] * (p - 2) + [-r])
        H1_OX = H.scheme().etale_cohomology(H.degree())
        return H1_OX.find_function(principal_parts)


class H_et_Scheme_Fp_free(CombinatorialFreeModule, H_et_Scheme):
    """
    First étale cohomology of a curve with coefficients of prime
    characteristic ``p`` (the characteristic of the base field).

    The basis keys are elements of the cohomology with coefficients in the
    structure sheaf. They are taken from the fixed-point part returned by
    ``diagonalise_frobenius_linear_matrix`` for its Frobenius action.

    .. WARNING::

        Do not instantiate directly; go through ``H_et_Scheme`` (for instance
        via the curve's ``etale_cohomology`` method).
    """
    Element = H_et_Scheme_Fp_free_Element

    def __init__(self, C, n, coeff=None):
        """
        Initialise the module for the curve ``C``, degree ``n`` and
        coefficient ring ``coeff`` (used as base ring).
        """
        self._degree = n
        self._scheme = C
        keys = self._solve_fixed_points_family()
        super().__init__(coeff, basis_keys=keys)

    def _solve_fixed_points_family(self):
        """
        Return the basis keys: the fixed-point vectors of the Frobenius
        action, as elements of the cohomology with coefficients in the
        structure sheaf.
        """
        H1_OX = self._scheme.etale_cohomology(self._degree)
        _, fixed = diagonalise_frobenius_linear_matrix(
            H1_OX.frobenius_action())
        return list(map(H1_OX, fixed))


class H_et_Scheme_ZpnZ_free(CombinatorialFreeModule, H_et_Scheme):
    """
    First étale cohomology of a curve with coefficients in a ring of
    characteristic ``p^m`` with ``m > 1``, where ``p`` is the characteristic
    of the base field.

    The basis keys are Witt vectors of length ``m`` over the function field.
    They are obtained by lifting, one Witt component at a time, the basis of
    the cohomology with coefficients in the base field of the curve.

    .. WARNING::

        Do not instantiate directly; go through ``H_et_Scheme`` (for instance
        via the curve's ``etale_cohomology`` method).
    """
    def __init__(self, C, n, coeff=None):
        """
        Initialise the module for the curve ``C``, degree ``n`` and
        coefficient ring ``coeff`` (used as base ring).
        """
        # NOTE(review): presumably set early so that ``self.base_ring()``
        # already returns ``coeff`` inside ``_solve_fixed_vectors_family``,
        # which runs before ``super().__init__`` -- confirm.
        self._base = coeff
        self._degree = n
        self._scheme = C
        super().__init__(coeff, basis_keys=self._solve_fixed_vectors_family())

    def _solve_fixed_vectors_family(self):
        """
        Return the list of basis keys (Witt vectors of length ``m`` over the
        function field). There is one key per basis element of the cohomology
        with coefficients in the base field of the curve.
        """
        H_et = self._scheme.etale_cohomology
        n = self._degree
        # NOTE(review): despite its name, ``Fp`` is the base ring of the
        # scheme (its constant field), not necessarily the prime field.
        Fp = self._scheme.base_ring()
        H1_Fp = H_et(n, Fp)
        basis_H1_Fp = H1_Fp.indices()
        H1_OX = H_et(n)
        basis_H1_OX = H1_OX.indices()
        d = H1_Fp.dimension()
        F = self._scheme.function_field()
        frob = H1_OX.frobenius_action()
        g = H1_OX.dimension()
        p = Fp.characteristic()
        # The coefficient ring has characteristic p^m.
        m = self.base_ring().characteristic().log(p)
        # P[j-1] is the j-th component of
        #   W([x^p] + [0]) - W([x] + [0]) - W([y] + [0])
        # in Witt vectors of length j+1 over F[x0..x_{m-1}, y0..y_{m-1}],
        # with only the first j entries of x and y nonzero.
        P = []
        R = PolynomialRing(F, ['x%d'%i for i in range(m)]
                              + ['y%d'%i for i in range(m)])
        var = R._first_ngens(2 * m)
        for j in range(1, m):
            W = WittVectorRing(R, prec=j+1, algorithm='finotti')
            w = (W([x**p for x in var[:j]] + [R.zero()])
                 - W([x for x in var[:j]] + [R.zero()])
                 - W([x for x in var[m:m+j]] + [R.zero()]))
            P.append(w[j])
        W = WittVectorRing(F, prec=m, algorithm='finotti')
        w_functions = []
        basis_H1_ZpnZ = []
        for i in range(d):
            # Component 0. The function is the extension function of the
            # i-th basis element of H1_Fp. The repartition is that element's
            # coordinates times the inverse local uniformisers of the basis
            # places of H1_OX.
            w_functions.append([H1_Fp(basis_H1_Fp[i]).extension_function()]
                               + [F.zero()] * (m - 1))
            repartition = basis_H1_Fp[i].to_vector()
            r = [[repartition[e] * basis_H1_OX[e].local_uniformizer()**(-1)
                  for e in range(g)]] + [[F.zero()] * g] * (m - 1)
            basis_H1_ZpnZ.append([sum(r[0][e] for e in range(g))]
                                  + [F.zero()] * (m - 1))
            for j in range(1, m):
                # Substitute x_l -> r[l][e] and y_l -> w_functions[i][l] in
                # P[j-1], at each basis place e.
                v = [-P[j-1]([r[l][e] for l in range(m)] + w_functions[i])
                     for e in range(g)]
                # Coefficients of v[e] for exponents mm..-1 at place e.
                # NOTE(review): islice needs -mm >= 0, so this assumes mm <= 0.
                mm = min(v[e].valuation(basis_H1_OX[e]) for e in range(g))
                vv = [list(islice(FunctionFieldElement_2.coefficients(v[e], basis_H1_OX[e], lower_bound=mm), -mm)) for e in range(g)]
                # Coordinates in H1_OX of the class of those principal parts.
                vv = H1_OX(vv).to_vector()
                # Solve frob * Frobenius(r[j]) - r[j] = vv.
                r[j] = solve_inhomogeneous_equation(frob, vv)
                # NOTE(review): r[j] (j >= 1) holds H1_OX coordinates
                # (constants, no 1/uniformiser factor), whereas r[0] holds
                # function field elements that include that factor. Both are
                # substituted into P at the next step -- confirm this is
                # intended.
                f_repartition = [[r[j][e]**p] + [0] * (p - 2) + [-r[j][e]-vv[e]] for e in range(g)]
                w_functions[i][j] = H1_OX.find_function(f_repartition)
                basis_H1_ZpnZ[i][j] = sum(r[j][e] * basis_H1_OX[e].local_uniformizer()**(-1) for e in range(g))
            w_functions[i] = W(w_functions[i])
            basis_H1_ZpnZ[i] = W(basis_H1_ZpnZ[i])
        return basis_H1_ZpnZ


class NormalIntegralProjectiveCurve(IntegralProjectiveCurve):
    """
    Normal integral projective curve, with access to its étale cohomology.
    """
    def etale_cohomology(self, n, coeff=None):
        """
        Return the ``n``-th étale cohomology group of ``self``.

        INPUT:

        - ``n`` -- degree (``H_et_Scheme`` only implements ``1``)
        - ``coeff`` -- (default: ``None``) coefficients. ``None`` means the
          structure sheaf; otherwise a ring as accepted by ``H_et_Scheme``.
        """
        group = H_et_Scheme(self, n, coeff)
        return group


class NormalIntegralProjectiveCurve_finite_field(
        IntegralProjectiveCurve_finite_field,
        NormalIntegralProjectiveCurve):
    """
    Normal integral projective curve over a finite field.

    It combines the finite field curve behaviour of
    ``IntegralProjectiveCurve_finite_field`` with the étale cohomology of
    ``NormalIntegralProjectiveCurve``.
    """


class NormalIntegralProjectiveCurve_function_field(
        NormalIntegralProjectiveCurve, UniqueRepresentation):
    """
    Normal integral projective curve defined by a function field.

    Calling the class with a function field ``F`` constructs, through
    ``__classcall__``, a
    ``NormalIntegralProjectiveCurve_function_field_finite_field`` attached to
    ``F``. Characteristic zero is not supported.
    """
    def __classcall_private__(cls, F):
        """
        Dispatch on the characteristic of ``F``.

        Raises ``NotImplementedError`` if ``F`` has characteristic zero.
        """
        if F.characteristic() != 0:
            target = NormalIntegralProjectiveCurve_function_field_finite_field
            return target.__classcall__(target, F)
        raise NotImplementedError("the construction of the normal "
                                  "integral projective curve associated "
                                  "to a function field of characteristic "
                                  "zero is not implemented")

    def _repr_(self):
        """
        Return a string representation of this curve.
        """
        base = self.base_ring()
        return (f"Normal integral projective curve over {base} "
                f"defined by {self.function_field()}")


class NormalIntegralProjectiveCurve_function_field_finite_field(
        NormalIntegralProjectiveCurve_finite_field,
        NormalIntegralProjectiveCurve_function_field):
    """
    Normal integral projective curve given by a function field whose
    constant field is finite.
    """
    def __init__(self, F):
        """
        Initialise the curve from the function field ``F``.

        The genus and the function field are stored directly, and the scheme
        is initialised over the constant base field of ``F``.

        Raises ``TypeError`` if ``F`` is not a function field. Raises
        ``ValueError`` if its constant base field is not finite.
        """
        if F not in FunctionFields():
            raise TypeError(f"{F} is not a function field")
        constant_field = F.constant_base_field()
        if constant_field not in FiniteFields():
            raise ValueError(f"the constant base field of {F} is not finite")
        self._genus = F.genus()
        self._function_field = F
        Scheme.__init__(self, constant_field)


# Example: first étale cohomology group, with coefficients in Integers(9),
# of the curve y^2 = x^5 + x^2 + 1 over the finite field with 3^9 elements.
k = GF(3**9)

K = FunctionField(k, names=('x',))
(x,) = K._first_ngens(1)
R = K['y']
(y,) = R._first_ngens(1)

pol = y**2 - (x**5 + x**2 + 1)

F = K.extension(pol, names=('y',))
C = NormalIntegralProjectiveCurve_function_field(F)

print(C.etale_cohomology(1, Integers(3**2)))
