ONLY DO WHAT ONLY YOU CAN DO

こけたら立ちなはれ 立ったら歩きなはれ

Haskell で連立一次方程式を解く(ガウスの消去法)

連立一次方程式
9x+2y+z+u=20
2x+8y-2z+u=16
-x-2y+7z-2u=8
x-y-2z+6u=17
が与えられた場合, まず(1)式を9で割って(1)'とする.
x+\frac{2}{9}y+\frac{1}{9}z+\frac{1}{9}u=\frac{20}{9}

(2)式から(1)'式に2を掛けたものを引くとxの項が消える.
(3)式に(1)'式を足すとxの項が消える.
(4)式から(1)'式を引くとxの項が消える.

同様にして,
2番目の式を利用して, 3番目以降の式のyを消す.
3番目の式を利用して, 4番目の式のzを消す.

このようにして、上三角型方程式を得る.
9x+2y+z+u=20
 7.55555y - 2.22222z + 0.77777u = 11.55555
 6.58823z - 1.70588u = 12.94117
 5.37500u = 21.50000
(前進消去)

(4)式から u の値が求まるので, (3)式に代入すると z の値が求まる.
u, z の値を (2)式に代入すると y の値が求まる.
u, z, y の値を (1)式に代入すると x の値が求まる.
(後退代入)

import Text.Printf
import Debug.Trace
import Control.Monad

n = 4::Int

disp_matrix::[[Double]]->IO()
disp_matrix matrix = do
    forM_ matrix $ \row -> do
        forM_ row $ \elem -> do
            printf "%14.10f\t" elem
        putStrLn ""

disp_vector::[Double]->IO()
disp_vector vector = do
    forM_ vector $ \elem -> do
        printf "%14.10f\t" elem 
    putStrLn ""

-- 各列で 一番値が大きい行を 探す
get_max_row::Int->Int->[[Double]]->Int->Double->(Int, Double)
get_max_row row col a max_row max_val =
    let
        -- 一番値が大きい行
        (max_row2, max_val2) = 
            if max_val < abs(a!!row!!col)
                then (row,     abs(a!!row!!col))
                else (max_row, max_val)
    in
        if row >= length(a) - 1
            then
                (max_row2, max_val2)
            else
                get_max_row (row+1) col a max_row2 max_val2

-- ピボット選択
pivoting::Int->[[Double]]->[Double]->[[Double]]->[Double]->([[Double]],[Double])
pivoting pivot a b a2 b2 = 
    let
        (max_row, max_val) = get_max_row 0 pivot a 0 0.0
        a3 = (a!!max_row):a2
        b3 = (b!!max_row):b2
        a4 = (take (max_row) a) ++ (drop (max_row+1) a)
        b4 = (take (max_row) b) ++ (drop (max_row+1) b)
    in
        if pivot >= (n - 1)
            then
                (reverse a3, reverse b3)
            else
                pivoting (pivot+1) a4 b4 a3 b3
 
-- 前進消去
forward_elim_loop::Int->Int->[[Double]]->[Double]->([[Double]],[Double])
forward_elim_loop pivot row a b = 
    let
        s = a!!row!!pivot / a!!pivot!!pivot

        a1 = map (\(a_pivot, a_row) -> a_row - a_pivot * s) $ zip (a!!pivot) (a!!row)
        a2 = (take row a) ++ (a1:(drop (row+1) a))

        b1 = b!!row - b!!pivot * s
        b2 = (take row b) ++ (b1:(drop (row+1) b))
    in
        if row < (n - 1)
            then
                forward_elim_loop pivot (row+1) a2 b2
            else
                (a2, b2)

forward_elimination::Int->[[Double]]->[Double]->([[Double]],[Double])
forward_elimination pivot a b = 
    let
        (a2, b2) = 
            if pivot < (n - 1)
                then forward_elim_loop pivot (pivot+1) a b
                else (a, b)
    in
        if pivot < (n - 1)
            then
                forward_elimination (pivot+1) a2 b2
            else
                (a2, b2)

-- 後退代入
backward_substitution::Int->[[Double]]->[Double]->[Double]
backward_substitution row a b = 
    let
        a2 = zip (drop (row+1) (a!!row)) (drop (row+1) b)
        a3 = sum $ map(\(a_col, b_col) -> a_col * b_col) $ a2
        x  = (b!!row - a3) / (a!!row!!row)
        b2 = (take row b) ++ (x:(drop (row+1) b))
    in
        if row > 0
            then x:(backward_substitution (row-1) a b2)
            else x:[]

main = do
    let a  = [[-1,-2,7,-2],[1,-1,-2,6],[9,2,1,1],[2,8,-2,1::Double]]
    let b  = [8,17,20,16::Double]

    -- ピボット選択
    let (a1, b1) = pivoting 0 a b [] []

    putStrLn "pivoting"
    putStrLn "A"
    disp_matrix a1
    putStrLn "B"
    disp_vector b1
    putStrLn ""

    -- 前進消去
    let (a2, b2) = forward_elimination 0 a1 b1

    putStrLn "forward_elimination"
    putStrLn "A"
    disp_matrix a2
    putStrLn "B"
    disp_vector b2
    putStrLn ""

    -- 後退代入
    let x = backward_substitution (n-1) a2 b2

    putStrLn "backward_substitution"
    putStrLn "X"
    disp_vector (reverse x)
pivoting
A
  9.0000000000    2.0000000000    1.0000000000    1.0000000000
  2.0000000000    8.0000000000   -2.0000000000    1.0000000000
 -1.0000000000   -2.0000000000    7.0000000000   -2.0000000000
  1.0000000000   -1.0000000000   -2.0000000000    6.0000000000
B
 20.0000000000   16.0000000000    8.0000000000   17.0000000000

forward_elimination
A
  9.0000000000    2.0000000000    1.0000000000    1.0000000000
  0.0000000000    7.5555555556   -2.2222222222    0.7777777778
  0.0000000000    0.0000000000    6.5882352941   -1.7058823529
  0.0000000000    0.0000000000    0.0000000000    5.3750000000
B
 20.0000000000   11.5555555556   12.9411764706   21.5000000000

backward_substitution
X
  1.0000000000    2.0000000000    3.0000000000    4.0000000000
参考文献