// Copyright 2014 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

5

6
import 'dart:math' as math;
7

8 9
import 'package:flutter/foundation.dart';

// TODO(abarth): Consider using vector_math.
class _Vector {
  /// Creates a zero-filled vector with [size] elements and its own storage.
  _Vector(int size) : this.fromVOL(Float64List(size), 0, size);

  /// Creates a vector that views [_length] consecutive elements of
  /// [_elements], beginning at index [_offset].
  ///
  /// No copy is made: the view and the backing list share storage, so writes
  /// through either one are visible through the other.
  _Vector.fromVOL(this._elements, this._offset, this._length);

  // Index in [_elements] of this vector's first element.
  final int _offset;

  // Number of elements visible through this vector.
  final int _length;

  // Backing storage, possibly shared with other vectors or a matrix.
  final List<double> _elements;

  /// Reads the element at position [i], counted from [_offset].
  double operator [](int i) => _elements[_offset + i];

  /// Writes [value] at position [i], counted from [_offset].
  void operator []=(int i, double value) => _elements[_offset + i] = value;

  /// Returns the dot product of this vector and [a].
  ///
  /// The sum runs over this vector's first [_length] elements; [a] is
  /// expected to have at least that many elements (its length is not
  /// checked).
  double operator *(_Vector a) {
    double sum = 0.0;
    for (int k = 0; k < _length; k++) {
      sum += _elements[_offset + k] * a[k];
    }
    return sum;
  }

  /// Returns the Euclidean length of this vector.
  double norm() => math.sqrt(this * this);
}

// TODO(abarth): Consider using vector_math.
class _Matrix {
  /// Creates a zero-filled matrix with [rows] rows and [cols] columns.
  ///
  /// Elements are stored row by row in a single flat buffer.
  _Matrix(int rows, int cols)
    : _columns = cols,
      _elements = Float64List(rows * cols);

  // Row stride of [_elements].
  final int _columns;

  // Flat row-major storage; element (row, col) lives at row * _columns + col.
  final List<double> _elements;

  // Maps a (row, col) pair to its position in [_elements].
  int _flatIndex(int row, int col) => row * _columns + col;

  /// Returns the element at ([row], [col]).
  double get(int row, int col) => _elements[_flatIndex(row, col)];

  /// Stores [value] at ([row], [col]).
  void set(int row, int col, double value) {
    _elements[_flatIndex(row, col)] = value;
  }

  /// Returns a live view of row [row].
  ///
  /// The returned vector shares storage with this matrix, so writes through
  /// it update the matrix.
  _Vector getRow(int row) {
    final int start = _flatIndex(row, 0);
    return _Vector.fromVOL(_elements, start, _columns);
  }
}

/// An nth-degree polynomial fit to a dataset.
class PolynomialFit {
  /// Creates a polynomial fit of the given degree.
  ///
  /// A fit of degree n has n + 1 coefficients, all initially zero.
  PolynomialFit(int degree) : coefficients = Float64List(degree + 1);

  /// The polynomial coefficients of the fit, lowest power first.
  ///
  /// For each `i`, the element `coefficients[i]` is the coefficient of
  /// the `i`-th power of the variable.
  final List<double> coefficients;

  /// An indicator of the quality of the fit.
  ///
  /// Larger values indicate greater quality. The value ranges from 0.0 to 1.0.
  ///
  /// The confidence is defined as the fraction of the dataset's variance
  /// that is captured by variance in the fit polynomial. In statistics
  /// textbooks this is often called "r-squared".
  late double confidence;

  /// Describes the fit as its coefficients (3 significant digits each)
  /// followed by the confidence (3 decimal places).
  ///
  /// Reads [confidence], so it must have been assigned first.
  @override
  String toString() {
    final Iterable<String> formatted =
        coefficients.map((double c) => c.toStringAsPrecision(3));
    final String coefficientString = '[${formatted.join(', ')}]';
    return '${objectRuntimeType(this, 'PolynomialFit')}'
        '($coefficientString, confidence: ${confidence.toStringAsFixed(3)})';
  }
}

/// Uses the weighted least-squares algorithm to fit a polynomial to a set of
/// data points.
class LeastSquaresSolver {
  /// Creates a least-squares solver.
  ///
  /// The [x], [y], and [w] lists must all have the same length; this is
  /// checked with asserts in debug builds.
  LeastSquaresSolver(this.x, this.y, this.w)
    : assert(x.length == y.length),
      assert(y.length == w.length);

  /// The x-coordinates of each data point.
  final List<double> x;

  /// The y-coordinates of each data point.
  final List<double> y;

  /// The weight to use for each data point.
  final List<double> w;

  /// Fits a polynomial of the given degree to the data points.
  ///
  /// A polynomial of degree `degree` has `degree + 1` coefficients, so at
  /// least `degree + 1` data points are required. [degree] must be
  /// non-negative.
  ///
  /// Returns null when there is not enough data to fit a curve. Also returns
  /// null when the weighted powers of [x] are linearly dependent (to within
  /// [precisionErrorTolerance]). For example, this happens when fewer than
  /// `degree + 1` distinct x-coordinates have a non-zero weight.
  PolynomialFit? solve(int degree) {
    assert(degree >= 0, 'degree must be non-negative, but was $degree');
    // A degree-`degree` fit has `degree + 1` unknowns. With only `degree` (or
    // fewer) points the system is underdetermined. Reject it here rather than
    // rely on the Gram-Schmidt norm check below. That check only catches it
    // if rounding error stays under precisionErrorTolerance, which fails for
    // large-magnitude x values.
    if (degree >= x.length) {
      // Not enough data to fit a curve.
      return null;
    }

    final PolynomialFit result = PolynomialFit(degree);

    // Shorthands for the purpose of notation equivalence to original C++ code.
    final int m = x.length;
    final int n = degree + 1;

    // Expand the X vector to a matrix A, pre-multiplied by the weights.
    // Row i of A holds w[h] * x[h]^i for every data point h.
    final _Matrix a = _Matrix(n, m);
    for (int h = 0; h < m; h += 1) {
      a.set(0, h, w[h]);
      for (int i = 1; i < n; i += 1) {
        a.set(i, h, a.get(i - 1, h) * x[h]);
      }
    }

    // Apply the Gram-Schmidt process to A to obtain its QR decomposition.

    // Orthonormal basis, column-major order: row j of this storage holds the
    // j-th column of the m x n matrix Q.
    final _Matrix q = _Matrix(n, m);
    // Upper triangular matrix, row-major order.
    final _Matrix r = _Matrix(n, n);
    for (int j = 0; j < n; j += 1) {
      for (int h = 0; h < m; h += 1) {
        q.set(j, h, a.get(j, h));
      }
      // Remove the components along the basis vectors found so far.
      for (int i = 0; i < j; i += 1) {
        final double dot = q.getRow(j) * q.getRow(i);
        for (int h = 0; h < m; h += 1) {
          q.set(j, h, q.get(j, h) - dot * q.get(i, h));
        }
      }

      final double norm = q.getRow(j).norm();
      if (norm < precisionErrorTolerance) {
        // Vectors are linearly dependent or zero so no solution.
        return null;
      }

      final double inverseNorm = 1.0 / norm;
      for (int h = 0; h < m; h += 1) {
        q.set(j, h, q.get(j, h) * inverseNorm);
      }
      for (int i = 0; i < n; i += 1) {
        r.set(j, i, i < j ? 0.0 : q.getRow(j) * a.getRow(i));
      }
    }

    // Solve R B = Qt W Y to find B. This is easy because R is upper triangular.
    // We just work from bottom-right to top-left calculating B's coefficients.
    final _Vector wy = _Vector(m);
    for (int h = 0; h < m; h += 1) {
      wy[h] = y[h] * w[h];
    }
    for (int i = n - 1; i >= 0; i -= 1) {
      result.coefficients[i] = q.getRow(i) * wy;
      for (int j = n - 1; j > i; j -= 1) {
        result.coefficients[i] -= r.get(i, j) * result.coefficients[j];
      }
      // r(i, i) is the norm computed above, which passed the tolerance check,
      // so this division is safe.
      result.coefficients[i] /= r.get(i, i);
    }

    // Calculate the coefficient of determination (confidence) as:
    //   1 - (sumSquaredError / sumSquaredTotal)
    // ...where sumSquaredError is the residual sum of squares (variance of the
    // error), and sumSquaredTotal is the total sum of squares (variance of the
    // data) where each has been weighted by w[h]^2. yMean itself is the plain
    // (unweighted) mean of y, as in the C++ code this was ported from.
    double yMean = 0.0;
    for (int h = 0; h < m; h += 1) {
      yMean += y[h];
    }
    yMean /= m;

    double sumSquaredError = 0.0;
    double sumSquaredTotal = 0.0;
    for (int h = 0; h < m; h += 1) {
      // Evaluate the fitted polynomial at x[h] incrementally, subtracting each
      // term from y[h] to obtain the residual.
      double term = 1.0;
      double err = y[h] - result.coefficients[0];
      for (int i = 1; i < n; i += 1) {
        term *= x[h];
        err -= term * result.coefficients[i];
      }
      sumSquaredError += w[h] * w[h] * err * err;
      final double v = y[h] - yMean;
      sumSquaredTotal += w[h] * w[h] * v * v;
    }

    // If the data has (effectively) no variance, any fit explains it fully.
    result.confidence = sumSquaredTotal <= precisionErrorTolerance
        ? 1.0
        : 1.0 - (sumSquaredError / sumSquaredTotal);

    return result;
  }
}