徐霁的博客 | 矩阵类计算B=(XTX)−1XTY多元回归分析
徐霁

矩阵类计算B=(XTX)−1XTY多元回归分析

阿里面试题目,根据参数进行回归分析,这里是放假,重点是公式,矩阵类

import java.io.*;
import java.util.Scanner;
import java.math.BigInteger;
import java.util.*;
class Matrix
{
    private int row;
    private int column;
    private double [][] matrix;
    Matrix(int r, int c)
    {
        row = r;
        column = c;
        matrix = new double[r][c];
    }
    Matrix(double[][] m, int r, int c)
    {
        row = r;
        column = c;
        matrix = m;
    }
    Matrix(double[][] m)
    {
        row = m.length;
        column = m[0].length;
        matrix = m;
    }
    public int getrow()
    {
        return row;
    }
    public int getcolumn()
    {
        return column;
    }
    public Matrix ADD(Matrix m)
    {
        Matrix tmp = new Matrix(row, column);
        for (int i = 0; i < row; i++)
        {
            for (int j = 0; j < column; j++)
            {
                tmp.matrix[i][j] = matrix[i][j] + m.matrix[i][j];
            }
        }
        return tmp;
    }
    public Matrix SUB(Matrix m)
    {
        Matrix tmp = new Matrix(row, column);
        for (int i = 0; i < row; i++)
        {
            for (int j = 0; j < column; j++)
            {
                tmp.matrix[i][j] = matrix[i][j] - m.matrix[i][j];
            }
        }
        return tmp;
    }
    public Matrix MUL(Matrix m)
    {
        Matrix tmp = new Matrix(row, m.column);
        for (int i = 0; i < row; i++)
            for (int j = 0; j < m.column; j++)
            {
                tmp.matrix[i][j] = 0;
                for (int k = 0; k < column; k++)
                    tmp.matrix[i][j] += matrix[i][k] * m.matrix[k][j];
            }
        return tmp;
    }
    public Matrix T()
    {
        Matrix tmp = new Matrix(column, row);
        for (int i = 0; i < row; i++)
            for (int j = 0; j < column; j++)
            {
                tmp.matrix[j][i] = matrix[i][j];
            }
        return tmp;
    }
    public Matrix getAStar(int h, int v)
    {
        Matrix tmp = new Matrix(column-1, row-1);
        for (int i = 0; i < row-1; i++)
        {
            if (i < h - 1)
            {
                for (int j = 0; j < column-1; j++)
                {
                    if (j < v - 1)
                    {
                        tmp.matrix[i][j] = matrix[i][j];
                    }
                    else
                    {
                        tmp.matrix[i][j] = matrix[i][j + 1];
                    }
                }
            }
            else
            {
                for (int j = 0; j < column-1; j++)
                {
                    if (j < v - 1)
                    {
                        tmp.matrix[i][j] = matrix[i + 1][j];
                    }
                    else
                    {
                        tmp.matrix[i][j] = matrix[i + 1][j + 1];
                    }
                }
            }
        }
        return tmp;
    }
    public double getMartrixResult()
    {
        //二维矩阵计算
        if (row == 2)
        {
            return matrix[0][0] * matrix[1][1] - matrix[0][1] * matrix[1][0];
        }
        double result = 0;
        int num = row;
        double[] nums = new double[num];
        for (int i = 0; i < row; i++)
        {
            if (i % 2 == 0)
            {
                nums[i] = matrix[0][i] * getAStar(1, i + 1).getMartrixResult();
            }
            else
            {
                nums[i] = -matrix[0][i] * getAStar(1, i + 1).getMartrixResult();
            }
        }
        for (int i = 0; i < row; i++)
        {
            result += nums[i];
        }
//      System.out.println(result);
        return result;
    }
    public Matrix getReverseMartrix()
    {
        Matrix tmp = new Matrix(row, column);
        double A = getMartrixResult();
        for (int i = 0; i < row; i++)
        {
            for (int j = 0; j < column; j++)
            {
                if ((i + j) % 2 == 0)
                {
                    tmp.matrix[i][j] = getAStar(i + 1, j + 1).getMartrixResult() / A;
                }
                else
                {
                    tmp.matrix[i][j] = -getAStar(i + 1, j + 1).getMartrixResult()/ A;
                }
            }
        }
        tmp = tmp.T();
        return tmp;
    }
    public void print()
    {
        for (int i = 0; i < row; i++)
        {
            for (int j = 0; j < column; j++)
            {
                if (j == 0)
                {
                    System.out.print(matrix[i][j]);
                }
                else
                    System.out.print(" " + matrix[i][j]);
            }
            System.out.println("");
        }
    }
}
public class Main
{
    public static void main(String args[]) throws FileNotFoundException
    {
        // Scanner cin = new Scanner(System.in);
        InputStream in = new FileInputStream(new File("/Users/sky/Documents/SVN/1.in"));
        Scanner cin = new Scanner(in);
        int num_of_param, num_of_test, num_of_untest;
        double x[][];
        double y[][];
        double testdata[][];
        while (cin.hasNext())
        {
            num_of_param = cin.nextInt()+1;
            num_of_test = cin.nextInt();
            x = new double[num_of_test][num_of_param];
            y = new double[num_of_test][1];
            for (int i = 0; i < num_of_test; i++)
            {
                for (int j = 1; j < num_of_param; j++)
                    x[i][j] = cin.nextDouble();
                y[i][0]=cin.nextDouble();
            }
            for (int j=0;j<num_of_test;j++)
                x[j][0]=1;
            Matrix X = new Matrix(x);
            Matrix Y = new Matrix(y);
            Matrix B = new Matrix(num_of_test,1);
            //B=(XTX)−1XTY
            B=X.T().MUL(X).getReverseMartrix().MUL(X.T()).MUL(Y);
            num_of_untest=cin.nextInt();
            testdata=new double[num_of_untest][num_of_param];
            for (int i = 0; i < num_of_untest; i++)
            {
                for (int j = 1; j < num_of_param; j++)
                    testdata[i][j] = cin.nextDouble();
            }
            for (int j=0;j<num_of_untest;j++)
                testdata[j][0]=1;
            Matrix result=new Matrix(testdata).MUL(B);
            result.print();
        }
    }
}

码字很辛苦,转载请注明来自徐霁的博客《矩阵类计算B=(XTX)−1XTY多元回归分析》

评论

你需要 登录 才可以回复.