/**
 * Ordinary least squares (OLS) regression: coefficients, standard errors,
 * t statistics, p-values, R², and variance inflation factors (VIFs).
 *
 * @unstable
 */

import { Matrix, SVD, pseudoInverse } from 'ml-matrix';

import * as jstat from 'jstat';

export class OLSRegression {
  /**
   * Fits an ordinary-least-squares model `y ≈ X·β` with an intercept term.
   *
   * The intercept column is appended as the LAST column of the design matrix,
   * so the intercept's entry is the last element of `coefficients`,
   * `stdErrors`, `tStats` and `pValues`.
   *
   * @param x - Regressors, one row per observation (2D array or ml-matrix
   *   `Matrix`). The input is copied, never mutated.
   * @param y - Response as a single-column 2D array (`[[y0], [y1], ...]`) or
   *   `Matrix`. Column 0 is used for the residual and R² statistics.
   * @returns `coefficients` as `[[b0], [b1], ...]`, per-coefficient
   *   `stdErrors`, `tStats`, two-sided `pValues`, the residual standard error
   *   `stdError`, `rSquared`, and the residual degrees of freedom `df`.
   *   Statistics are not meaningful when `df <= 0` (more parameters than
   *   observations) or when `y` is constant (R² is then undefined).
   */
  static calculateOLS(x, y) {
    const intercept = true;
    const design = new Matrix(x);
    const response = new Matrix(y);
    if (intercept) {
      design.addColumn(new Array(design.rows).fill(1));
    }

    const xt = design.transpose();
    const xx = xt.mmul(design);
    const xy = xt.mmul(response);
    // (X'X)^-1 via SVD, computed once and reused for the coefficient
    // covariance below so that β and its standard errors are derived from the
    // same inverse (same singular-value cut-off).
    const invxx = new SVD(xx).inverse();
    const beta = invxx.transpose().mmul(xy);

    const coefficients = beta.to2DArray();
    const fittedValues = design.mmul(beta);
    // subM on a clone: neither `response` nor `fittedValues` is modified.
    const residuals = response.clone().subM(fittedValues);
    const SSR = residuals.getColumn(0).reduce((acc, r) => acc + r * r, 0);

    // Residual degrees of freedom: observations minus fitted parameters
    // (the intercept column is already included in design.columns).
    const df = response.rows - design.columns;
    const variance = SSR / df;
    const stdError = Math.sqrt(variance);

    // Cov(β) = σ² (X'X)^-1; standard errors are sqrt of its diagonal.
    const stdErrors = invxx.diagonal().map((d) => Math.sqrt(d * variance));
    const tStats = coefficients.map(([b], i) => (stdErrors[i] === 0 ? 0 : b / stdErrors[i]));
    // jStat.ttest(tscore, n, sides) takes the SAMPLE SIZE n and evaluates
    // Student's t with n - 1 degrees of freedom; pass df + 1 so the p-value
    // is computed on exactly `df` degrees of freedom.
    const pValues = tStats.map((t) => jstat.ttest(t, df + 1, 2));

    const yValues = response.getColumn(0);
    const yMean = yValues.reduce((acc, v) => acc + v, 0) / yValues.length;
    const SST = yValues.reduce((acc, v) => acc + (v - yMean) * (v - yMean), 0);
    const rSquared = 1 - SSR / SST;

    return {
      coefficients,
      stdErrors,
      tStats,
      pValues,

      stdError,
      rSquared,
      df,
    };
  }

  /**
   * Variance inflation factor `1 / (1 - R²)`, where R² comes from regressing
   * `y` (one regressor, as a single-column 2D array) on `x` (the remaining
   * regressors) with an intercept.
   */
  static calculateVIF(x, y) {
    const { rSquared } = this.calculateOLS(x, y);

    return 1 / (1 - rSquared);
  }

  /**
   * Fits `y` on `x` and, optionally, computes the VIF of every column of `x`.
   *
   * @param vif - When true (default), the result also carries `vifs`, where
   *   `vifs[c]` is the VIF of column `c` of `x`.
   * @returns `{ result }` with the OLS statistics, plus `vifs` when `vif` is true.
   */
  static calculate(x, y, vif: boolean = true) {
    const result = this.calculateOLS(x, y);
    const vifs: number[] = [];

    if (vif) {
      const xMatrix = new Matrix(x);
      // Convert once; each auxiliary regression slices the same rows.
      const rows = xMatrix.to2DArray();
      for (let c = 0; c < xMatrix.columns; c++) {
        // Regress column c on all other columns.
        const others = rows.map((row) => row.filter((_, j) => j !== c));
        const target = rows.map((row) => [row[c]]);
        vifs.push(this.calculateVIF(others, target));
      }
    }

    return vif ? { result, vifs } : { result };
  }
}
