/**************************************************************************
 *{@C
 *      Copyright:      1988-2025 Paul Obermeier (obermeier@poSoft.de)
 *
 *      Module:         Utilities
 *      Filename:       UT_MathGauss.c
 *
 *      Author:         Paul Obermeier
 *
 *      Description:    Functions to solve sets of linear equations
 *                      using Gauss' algorithm.
 *
 *      Additional documentation:
 *                      None.
 *
 *      Exported functions:
 *                      UT_MathGaussSolve
 *
 **************************************************************************/

#include "UT_Compat.h"
#include "UT_Const.h"
#include "UT_Macros.h"
#include "UT_Error.h"
#include "UT_Memory.h"
#include "UT_Math.h"

/* Error messages passed to UT_ErrSetNum (with UT_ErrUnxpInput) when the LU
   decomposition ("decompose") or the substitution step ("solve") fails. */
#define str_nodecompose "Unable to decompose linear equation"
#define str_nosolve     "Unable to solve linear equation"

/* Computes an LU decomposition of the n x n matrix "mat" (stored row by row)
   by Gaussian elimination with scaled partial (column) pivoting. The
   results, "lu" and "perm", are used by "solve".

   On success:
     - "lu" holds the strictly lower part of the unit lower triangular
       factor L below the diagonal (the unit diagonal is implied) and the
       upper triangular factor U on and above the diagonal.
     - "perm[k]" is the index of the row of "mat" that ended up in row k
       of "lu".
     - "*signdet" is +1.0 or -1.0, the sign of the row permutation, so that
       det(mat) == *signdet * (product of the diagonal of "lu").

   Returns UT_False in the following cases, with *signdet set to 0.0:
     - n < 1;
     - the temporary scale array cannot be obtained;
     - the matrix has a zero row;
     - a (scaled) pivot falls below Float64Prec.
   In the error cases other than the allocation failure, an error is set
   via UT_ErrSetNum.

   The scale array "d" comes from UT_MemTempArray. UT_MathGaussSolve
   brackets the call with UT_MemRemember/UT_MemRestore. */

static UT_Bool decompose
        (Int32 n, const Float64 *mat, Float64 *lu,
         Int32 *perm, Float64 *signdet)
{
    Int32   j0, k, j, m;
    Float64 piv, tmp, rowmax, *d;

    /* Report "singular" until the decomposition has actually succeeded, so
       that no failure path leaves *signdet unwritten. */
    *signdet = 0.0;

    /* An empty system would make the final pivot check index lu[-1]. */
    if (n < 1) {
        UT_ErrSetNum (UT_ErrUnxpInput, str_nodecompose);
        return UT_False;
    }
    if (!(d = UT_MemTempArray (n, Float64))) {
        return UT_False;
    }
    for (k = 0; k < n; ++k) {
        for (j = 0; j < n; ++j) {
            UT_MATRIX (lu, n, k, j) = UT_MATRIX (mat, n, k, j);
        }
    }

    /* Scale factors: d[k] is the largest absolute element of row k.
       A zero row means the matrix is singular. */
    for (k = 0; k < n; ++k) {
        perm[k] = k;
        rowmax = 0.0;
        for (j = 0; j < n; ++j) {
            rowmax = UT_MAX (rowmax, UT_ABS (UT_MATRIX (lu, n, k, j)));
        }
        if (rowmax == 0.0) {
            UT_ErrSetNum (UT_ErrUnxpInput, str_nodecompose);
            return UT_False;
        }
        d[k] = rowmax;
    }

    *signdet = 1.0;
    for (k = 0; k < n - 1; ++k) {
        /* Scaled pivot search in column k: choose the row whose entry is
           largest relative to that row's scale factor. */
        piv = UT_ABS (UT_MATRIX (lu, n, k, k)) / d[k];
        j0 = k;
        for (j = k + 1; j < n; ++j) {
            tmp = UT_ABS (UT_MATRIX (lu, n, j, k)) / d[j];
            if (piv < tmp) {
                piv = tmp;
                j0 = j;
            }
        }
        if (piv < Float64Prec) {
            *signdet = 0.0;
            UT_ErrSetNum (UT_ErrUnxpInput, str_nodecompose);
            return UT_False;
        }
        if (j0 != k) {
            *signdet = -*signdet;
            UT_SWAP (perm[j0], perm[k], Int32);
            /* The scale factor belongs to the row, so it must move with it.
               Otherwise later pivot searches compare rows against the wrong
               scale. */
            UT_SWAP (d[j0], d[k], Float64);
            for (m = 0; m < n; ++m) {
                UT_SWAP (UT_MATRIX (lu, n, j0, m), UT_MATRIX (lu, n, k, m), Float64);
            }
        }

        /* Eliminate column k below the pivot. The multiplier is stored in
           place as the L entry. */
        for (j = k + 1; j < n; ++j) {
            if (UT_MATRIX (lu, n, j, k) != 0.0) {
                tmp = (UT_MATRIX (lu, n, j, k) /= UT_MATRIX (lu, n, k, k));
                for (m = k + 1; m < n; ++m) {
                    UT_MATRIX (lu, n, j, m) -= tmp * UT_MATRIX (lu, n, k, m);
                }
            }
        }
    }

    /* The last diagonal element is not covered by the pivot loop above. */
    if (UT_ABS (UT_MATRIX (lu, n, n - 1, n - 1)) < Float64Prec) {
        *signdet = 0.0;
        UT_ErrSetNum (UT_ErrUnxpInput, str_nodecompose);
        return UT_False;
    }
    return UT_True;
}

/* Solves mat * x == b, given the LU decomposition "lu" and row permutation
   "perm" produced by "decompose".

   The method has two phases:
     1. Forward substitution with the unit lower triangular factor (the
        entries below the diagonal of "lu") applied to "b", taken in the
        row order given by "perm".
     2. Back substitution with the upper triangular factor (the entries on
        and above the diagonal).

   Returns UT_True on success. If a diagonal element of "lu" has magnitude
   below Float64Prec, it sets an error and returns UT_False. In that case
   "x" holds only partial results. */

static UT_Bool solve
        (Int32 n, const Float64 *lu, const Int32 *perm,
         const Float64 *b, Float64 *x)
{
    Int32   row, col;
    Float64 acc, diag;

    /* Phase 1: L * y = P * b; y is stored in x. */
    for (row = 0; row < n; ++row) {
        acc = b[perm[row]];
        for (col = 0; col < row; ++col) {
            acc -= UT_MATRIX (lu, n, row, col) * x[col];
        }
        x[row] = acc;
    }

    /* Phase 2: U * x = y, from the last row upwards. */
    row = n;
    while (row-- > 0) {
        diag = UT_MATRIX (lu, n, row, row);
        if (UT_ABS (diag) < Float64Prec) {
            UT_ErrSetNum (UT_ErrUnxpInput, str_nosolve);
            return UT_False;
        }
        acc = 0.0;
        for (col = row + 1; col < n; ++col) {
            acc += UT_MATRIX (lu, n, row, col) * x[col];
        }
        x[row] = (x[row] - acc) / diag;
    }
    return UT_True;
}

/***************************************************************************
 *[@e
 *      Name:           UT_MathGaussSolve
 *
 *      Usage:          Solves a set of linear equations.
 *
 *      Synopsis:       UT_Bool UT_MathGaussSolve
 *                              (UT_Bool rep,
 *                              Int32 n,
 *                              const Float64 *mat,
 *                              Float64 *lu,
 *                              Int32 *perm,
 *                              const Float64 *b,
 *                              Float64 *x,
 *                              Float64 *signdet)
 *
 *      Description:    Size of arrays in parameter list:
 *                      Float64 mat[n*n]
 *                      Float64 lu[n*n]
 *                      Int32   perm[n]
 *                      Float64 b[n]
 *                      Float64 x[n]
 *
 *                      Function "UT_MathGaussSolve" attempts to solve a system
 *                      of linear equations,
 *
 *                              mat * x == b,
 *
 *                      where "mat" is a regular (non-singular) matrix, "b" is
 *                      the right-hand side, and "x" is the solution of the set
 *                      of equations. The solution method is Gauss' algorithm
 *                      with LU decomposition and scaled partial (column)
 *                      pivoting.
 *
 *                      If "rep" is UT_False, an LU decomposition, "lu", of
 *                      "mat" is computed before solving the equations.
 *                      If "rep" is UT_True, "UT_MathGaussSolve" assumes that
 *                      "lu" already contains an LU decomposition of "mat".
 *                      This can be handy when several sets of equations with
 *                      the same coefficients in "mat" but with different right
 *                      hand sides have to be solved.
 *
 *                      "N" is the number of equations to be solved. "B", "x"
 *                      and "perm" must be vectors of length n; "mat" and
 *                      "lu" must be vectors of length (n * n).
 *                      "N" must not be less than 2.
 *
 *                      "Mat" contains the coefficients of the equations to
 *                      be solved.
 *                      "Lu" holds the LU decomposition of "mat". Depending
 *                      on whether "rep" is UT_False or UT_True, 
 *                      "UT_MathGaussSolve" either computes the LU decomposition
 *                      and stores it in "lu" or it reads an existing
 *                      decomposition from "lu".
 *                      "Perm" holds the row permutation applied to "mat":
 *                      row k of "lu" corresponds to row perm[k] of "mat".
 *                      The elements in "mat" and "lu" are stored row by
 *                      row; it is recommended to use the "UT_MATRIX" macro to
 *                      select matrix elements.
 *
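 *                      "Signdet" receives the sign of the row permutation
 *                      (+1.0 or -1.0) determined during LU decomposition, so
 *                      that det(mat) equals "signdet" times the product of
 *                      the diagonal elements of "lu". It is only written
 *                      when "rep" is UT_False. If the decomposition fails,
 *                      do not rely on its value.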
 *
 *                      "UT_MathGaussSolve" is an adaptation of the program
 *                      found in:
 *
 *                              Gisela Engeln-Muellges, Fritz Reutter
 *                              "Formelsammlung zur Numerischen Mathematik
 *                              mit Turbo Pascal-Programmen"
 *                              Bibliographisches Institut, Mannheim, 1991
 *                              pp. 535-538
 *
 *      Return value:   UT_True if a unique solution could be found,
 *                      UT_False if the set of equations has no unique solution
 *                      which can be computed by "UT_MathGaussSolve".
 *
 *      See also:
 *
 ***************************************************************************/

UT_Bool UT_MathGaussSolve
        (UT_Bool rep, Int32 n, const Float64 *mat, Float64 *lu,
         Int32 *perm, const Float64 *b, Float64 *x, Float64 *signdet)
{
    /* Mark the temporary memory state first, so that scratch storage
       obtained by "decompose" is released again below on every path. */
    UT_MemState mark = UT_MemRemember ();

    /* With "rep" set, "lu" and "perm" are reused as they are; otherwise they
       are (re)computed from "mat". */
    UT_Bool     ok   = rep ? UT_True : decompose (n, mat, lu, perm, signdet);

    if (ok) {
        ok = solve (n, lu, perm, b, x);
    }

    UT_MemRestore (mark);
    return ok;
}
