"""Cubic interpolation of sin(x) using a hand-written Gaussian elimination.

The numerical helpers (``gauss_eliminate``, ``backsubstitute``, ``solve``,
``poly4``) can be imported on their own; running the file as a script solves
the Vandermonde system for the cubic through sin at four equally spaced nodes
on [0, 2*pi] and plots it against sin.
"""
from numpy import *


def backsubstitute(A, b):
    """Solve the upper-triangular system ``A x = b`` by back substitution.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Upper-triangular matrix; entries below the diagonal are never read.
    b : array_like, shape (n,) or (n, m)
        Right-hand side, or m right-hand sides stacked as columns.

    Returns
    -------
    ndarray
        Solution with the same shape as ``b``. The dtype is at least float
        (complex if ``A`` or ``b`` is complex), so integer inputs are not
        truncated.

    Raises
    ------
    ValueError
        If ``A`` does not have shape ``(len(b), len(b))``.
    """
    A = asarray(A)
    b = asarray(b)
    n = len(b)
    rows, cols = A.shape
    if rows != n or cols != n:
        raise ValueError("A should be shape (%d, %d), but is %s" % (n, n, A.shape))
    # Copy b into a floating dtype: storing fractional results into an
    # integer copy of b would silently truncate them.
    x = array(b, dtype=result_type(A, b, float))
    for i in range(n - 1, -1, -1):
        # x[i + 1:] is already solved and x[i] still holds b[i]; subtract the
        # known contributions, then divide by the diagonal entry.
        x[i] = (x[i] - dot(A[i, i + 1:], x[i + 1:])) / A[i, i]
    return x


def gauss_eliminate(A_in, b_in, *, pivot=False):
    """Reduce ``A x = b`` to an equivalent upper-triangular system.

    Parameters
    ----------
    A_in : array_like, shape (n, n)
        Coefficient matrix. It is not modified.
    b_in : array_like, shape (n,) or (n, m)
        Right-hand side(s). Not modified.
    pivot : bool, optional (keyword-only)
        If True, use partial (row) pivoting: before eliminating column ``r``
        the row ``p >= r`` with the largest ``|A[p, r]|`` is swapped into row
        ``r`` (in both ``A`` and ``b``). This avoids dividing by a zero pivot
        and limits round-off growth. Default False keeps the plain,
        non-pivoting elimination order.

    Returns
    -------
    (A, b) : tuple of ndarray
        Upper-triangular matrix and the correspondingly transformed right-hand
        side, as new arrays in a floating (or complex) dtype.

    Raises
    ------
    ValueError
        If ``A_in`` does not have shape ``(len(b_in), len(b_in))``, or, when
        ``pivot`` is True, if no nonzero pivot exists (singular matrix).
    """
    A_in = asarray(A_in)
    b_in = asarray(b_in)
    # Promote to a floating dtype: in-place updates on integer arrays either
    # raise a casting error (A) or silently truncate the result (b).
    dtype = result_type(A_in, b_in, float)
    A = array(A_in, dtype=dtype)
    b = array(b_in, dtype=dtype)
    n = len(b)
    rows, cols = A.shape
    if rows != n or cols != n:
        raise ValueError("A should be shape (%d, %d), but is %s" % (n, n, A.shape))
    # r: the pivot row/column currently being used to eliminate below it.
    for r in range(n):
        if pivot:
            p = r + argmax(abs(A[r:, r]))
            if A[p, r] == 0:
                raise ValueError("Matrix is singular")
            if p != r:
                # Fancy-indexed RHS is a copy, so this swaps safely.
                A[[r, p]] = A[[p, r]]
                b[[r, p]] = b[[p, r]]
        # Eliminate column r from every row below the pivot row.
        for i in range(r + 1, n):
            factor = A[i, r] / A[r, r]
            A[i, r] = 0
            b[i] -= factor * b[r]
            A[i, r + 1:] -= factor * A[r, r + 1:]
    return A, b


def solve(A, b):
    """Solve the square linear system ``A x = b``.

    Uses Gaussian elimination with partial pivoting followed by back
    substitution.

    Parameters
    ----------
    A : array_like, shape (n, n)
    b : array_like, shape (n,) or (n, m)

    Returns
    -------
    ndarray
        Solution with the same shape as ``b``.

    Raises
    ------
    ValueError
        If the shapes are inconsistent or ``A`` is singular.
    """
    Ares, bres = gauss_eliminate(A, b, pivot=True)
    return backsubstitute(Ares, bres)


def poly4(x, a=None):
    """Evaluate ``a[0] + a[1]*x + a[2]*x**2 + ...`` (coefficients in increasing degree).

    Parameters
    ----------
    x : scalar or ndarray
        Evaluation point(s); arrays are evaluated element-wise.
    a : sequence of numbers, optional
        Polynomial coefficients, lowest degree first; any length is accepted.
        Defaults to the four coefficients (rounded to 8 decimals) of the cubic
        interpolating sin at ``linspace(0, 2*pi, 4)`` -- the values this
        function originally hard-coded.

    Returns
    -------
    scalar or ndarray
        The polynomial value(s) at ``x``.
    """
    if a is None:
        a = array([0., 1.86073502, -0.88843553, 0.09426594])
    # Summed lowest degree first, matching the original term order.
    result = 0
    for k, coeff in enumerate(a):
        result = result + coeff * x**k
    return result


def main():
    """Interpolate sin at 4 equally spaced nodes on [0, 2*pi] and plot the cubic."""
    # Imported here so the numerical functions above do not require matplotlib.
    import matplotlib.pyplot as plt

    nodes = linspace(0, 2*pi, 4)
    values = sin(nodes)
    plt.plot(nodes, values, 'o')

    # Vandermonde matrix with columns 1, x, x**2, x**3.
    A = vander(nodes, len(nodes), increasing=True)
    as_ = solve(A, values)

    xs = linspace(0, 2*pi, 100)
    plt.plot(xs, sin(xs), label='sin(x)')
    # Plot the polynomial from the coefficients just computed.
    plt.plot(xs, poly4(xs, as_), label='p(x)')
    plt.legend()
    plt.show()


if __name__ == "__main__":
    main()