// solver.dart
// Copyright (c) 2015 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

part of cassowary;

class Solver {
  // Every constraint added to the solver, mapped to the tag holding its
  // marker/other symbols.
  final Map<Constraint, _Tag> _constraints = new Map<Constraint, _Tag>();

  // The tableau: each row keyed by the symbol it has been solved for.
  final Map<_Symbol, _Row> _rows = new Map<_Symbol, _Row>();

  // The external symbol created for each variable seen in a constraint.
  final Map<Variable, _Symbol> _vars = new Map<Variable, _Symbol>();

  // Bookkeeping for every variable registered with [addEditVariable].
  final Map<Variable, _EditInfo> _edits = new Map<Variable, _EditInfo>();

  // Symbols whose rows were observed with a negative constant; drained by
  // [_dualOptimize].
  final List<_Symbol> _infeasibleRows = <_Symbol>[];

  // The objective row; error symbols are inserted weighted by priority.
  final _Row _objective = new _Row(0.0);

  // Temporary objective used only while [_addWithArtificialVariableOnRow]
  // runs; reset to an empty row afterwards.
  _Row _artificial = new _Row(0.0);

  /// Counter passed to every new [_Symbol] this solver creates; incremented
  /// on each use.
  int tick = 1;

  /// Attempts to add the constraints in the list to the solver. If it cannot
  /// add any for some reason, a cleanup is attempted so that either all
  /// constraints will be added or none.
  ///
  /// Returns the [Result] of the first failing [addConstraint] call, or
  /// [Result.success] if every constraint was added.
  Result addConstraints(List<Constraint> constraints) {
    // The closure parameters are left untyped (dynamic). A closure declared
    // as `(Constraint c) => ...` is not assignable to [_SolverBulkUpdate]
    // (`Result Function(dynamic)`) under Dart's sound type system.
    _SolverBulkUpdate applier = (c) => addConstraint(c);
    _SolverBulkUpdate undoer = (c) => removeConstraint(c);

    return _bulkEdit(constraints, applier, undoer);
  }

  /// Adds [constraint] to the solver and re-optimizes the objective.
  ///
  /// Returns [Result.duplicateConstraint] if [constraint] was already added,
  /// [Result.unsatisfiableConstraint] if it cannot be satisfied, and
  /// otherwise the result of optimizing the objective row.
  Result addConstraint(Constraint constraint) {
    if (_constraints.containsKey(constraint)) {
      return Result.duplicateConstraint;
    }

    _Tag tag = new _Tag(new _Symbol(_SymbolType.invalid, 0),
        new _Symbol(_SymbolType.invalid, 0));

    _Row row = _createRow(constraint, tag);

    _Symbol subject = _chooseSubjectForRow(row, tag);

    // A row made only of dummy symbols can only be satisfied if its constant
    // is zero. In that case the dummy marker itself becomes the subject.
    if (subject.type == _SymbolType.invalid && _allDummiesInRow(row)) {
      if (!_nearZero(row.constant)) {
        return Result.unsatisfiableConstraint;
      } else {
        subject = tag.marker;
      }
    }

    if (subject.type == _SymbolType.invalid) {
      // No usable subject: fall back to the artificial-variable technique.
      if (!_addWithArtificialVariableOnRow(row)) {
        return Result.unsatisfiableConstraint;
      }
    } else {
      row.solveForSymbol(subject);
      _substitute(subject, row);
      _rows[subject] = row;
    }

    _constraints[constraint] = tag;

    return _optimizeObjectiveRow(_objective);
  }

  /// Attempts to remove every constraint in [constraints]. If one removal
  /// fails, the constraints already removed are re-added (best effort).
  ///
  /// Returns the [Result] of the first failing [removeConstraint] call, or
  /// [Result.success].
  Result removeConstraints(List<Constraint> constraints) {
    // Untyped (dynamic) parameters; see [addConstraints].
    _SolverBulkUpdate applier = (c) => removeConstraint(c);
    _SolverBulkUpdate undoer = (c) => addConstraint(c);

    return _bulkEdit(constraints, applier, undoer);
  }

  /// Removes [constraint] from the solver and re-optimizes the objective.
  ///
  /// The constraint's error terms are taken back out of the objective. If
  /// its marker symbol is not basic, it is first pivoted into the basis so
  /// that its row can be dropped.
  ///
  /// Returns [Result.unknownConstraint] if [constraint] was never added,
  /// [Result.internalSolverError] if no row can be found to pivot the marker
  /// into the basis, and otherwise the result of optimizing the objective.
  Result removeConstraint(Constraint constraint) {
    _Tag tag = _constraints[constraint];
    if (tag == null) {
      return Result.unknownConstraint;
    }

    tag = new _Tag.fromTag(tag);
    _constraints.remove(constraint);

    _removeConstraintEffects(constraint, tag);

    _Row row = _rows[tag.marker];
    if (row != null) {
      _rows.remove(tag.marker);
    } else {
      _Pair<_Symbol, _Row> rowPair = _leavingRowPairForMarkerSymbol(tag.marker);

      if (rowPair == null) {
        return Result.internalSolverError;
      }

      _Symbol leaving = rowPair.first;
      row = rowPair.second;
      var removed = _rows.remove(rowPair.first);
      assert(removed != null);
      row.solveForSymbols(leaving, tag.marker);
      _substitute(tag.marker, row);
    }

    return _optimizeObjectiveRow(_objective);
  }

  /// Whether [constraint] is currently part of this solver.
  bool hasConstraint(Constraint constraint) {
    return _constraints.containsKey(constraint);
  }

  /// Registers every variable in [variables] as an edit variable with
  /// [priority]. If one registration fails, the ones already made are
  /// removed again (best effort).
  Result addEditVariables(List<Variable> variables, double priority) {
    // Untyped (dynamic) parameters; see [addConstraints].
    _SolverBulkUpdate applier = (v) => addEditVariable(v, priority);
    _SolverBulkUpdate undoer = (v) => removeEditVariable(v);

    return _bulkEdit(variables, applier, undoer);
  }

  /// Registers [variable] as an edit variable with the non-required
  /// [priority].
  ///
  /// Internally this adds the constraint `1.0 * variable + 0.0 == 0` at
  /// [priority]. New values are then supplied via [suggestValueForVariable].
  ///
  /// Returns [Result.duplicateEditVariable] if [variable] is already an edit
  /// variable, [Result.badRequiredStrength] if [priority] is negative or not
  /// below [Priority.required], and [Result.internalSolverError] if the
  /// internal constraint cannot be added.
  Result addEditVariable(Variable variable, double priority) {
    if (_edits.containsKey(variable)) {
      return Result.duplicateEditVariable;
    }

    if (!_isValidNonRequiredPriority(priority)) {
      return Result.badRequiredStrength;
    }

    Constraint constraint = new Constraint(
        new Expression([new Term(variable, 1.0)], 0.0), Relation.equalTo);
    constraint.priority = priority;

    if (addConstraint(constraint) != Result.success) {
      return Result.internalSolverError;
    }

    _EditInfo info = new _EditInfo();
    info.tag = _constraints[constraint];
    info.constraint = constraint;
    info.constant = 0.0;

    _edits[variable] = info;

    return Result.success;
  }

  /// Unregisters every variable in [variables] as an edit variable. If one
  /// removal fails, the variables already removed are re-registered with
  /// their original priorities (best effort). Their suggested values start
  /// again from 0.0.
  Result removeEditVariables(List<Variable> variables) {
    // The priorities must be captured before removal. Once
    // [removeEditVariable] succeeds, the variable's entry in [_edits] is
    // gone, so the undoer can no longer look it up there.
    Map<Variable, double> priorities = new Map<Variable, double>();

    _SolverBulkUpdate applier = (v) {
      _EditInfo info = _edits[v];
      Result result = removeEditVariable(v);
      if (result == Result.success) {
        // Success implies [info] was non-null.
        priorities[v] = info.constraint.priority;
      }
      return result;
    };
    _SolverBulkUpdate undoer = (v) => addEditVariable(v, priorities[v]);

    return _bulkEdit(variables, applier, undoer);
  }

  /// Unregisters [variable] as an edit variable and removes its internal
  /// constraint.
  ///
  /// Returns [Result.unknownEditVariable] if [variable] is not an edit
  /// variable, and [Result.internalSolverError] if its constraint cannot be
  /// removed.
  Result removeEditVariable(Variable variable) {
    _EditInfo info = _edits[variable];
    if (info == null) {
      return Result.unknownEditVariable;
    }

    if (removeConstraint(info.constraint) != Result.success) {
      return Result.internalSolverError;
    }

    _edits.remove(variable);
    return Result.success;
  }

  /// Whether [variable] is currently registered as an edit variable.
  bool hasEditVariable(Variable variable) {
    return _edits.containsKey(variable);
  }

  /// Suggests [value] for the edit variable [variable], then restores
  /// feasibility with [_dualOptimize].
  ///
  /// Returns [Result.unknownEditVariable] if [variable] was not registered
  /// with [addEditVariable].
  Result suggestValueForVariable(Variable variable, double value) {
    if (!_edits.containsKey(variable)) {
      return Result.unknownEditVariable;
    }

    _suggestValueForEditInfoWithoutDualOptimization(_edits[variable], value);

    return _dualOptimize();
  }

  /// Pushes the current solution into every variable known to the solver.
  ///
  /// A variable whose symbol has a row in the tableau receives that row's
  /// constant. Any other variable receives 0.0. Returns the set of non-null
  /// `context` objects of the owners of variables whose `_applyUpdate`
  /// returned true (presumably: their value changed).
  Set flushUpdates() {
    Set updates = new HashSet<dynamic>();

    for (Variable variable in _vars.keys) {
      _Symbol symbol = _vars[variable];
      _Row row = _rows[symbol];

      double updatedValue = row == null ? 0.0 : row.constant;

      if (variable._applyUpdate(updatedValue) && variable._owner != null) {
        dynamic context = variable._owner.context;
        if (context != null) {
          updates.add(context);
        }
      }
    }

    return updates;
  }

  /// Applies [applier] to each of [items] in order, stopping at the first
  /// result other than [Result.success].
  ///
  /// On failure, [undoer] is called on the items already applied, in reverse
  /// order, and the failing result is returned. Results of [undoer] are
  /// ignored. Returns [Result.success] if every item applied (including when
  /// [items] is empty).
  Result _bulkEdit(
      Iterable items, _SolverBulkUpdate applier, _SolverBulkUpdate undoer) {
    List<dynamic> applied = <dynamic>[];

    for (dynamic item in items) {
      Result result = applier(item);
      if (result != Result.success) {
        for (dynamic done in applied.reversed) {
          undoer(done);
        }
        return result;
      }
      applied.add(item);
    }

    return Result.success;
  }

  /// Returns the external symbol for [variable], creating and recording a
  /// new one the first time the variable is seen.
  _Symbol _symbolForVariable(Variable variable) {
    _Symbol symbol = _vars[variable];

    if (symbol != null) {
      return symbol;
    }

    symbol = new _Symbol(_SymbolType.external, tick++);
    _vars[variable] = symbol;

    return symbol;
  }

  /// Builds a tableau row for [constraint] and records its marker (and, if
  /// any, other) symbol in [tag].
  ///
  /// Terms with a near-zero coefficient are skipped. Variables whose symbol
  /// already has a row are replaced by that row. The following symbols are
  /// then added, and non-required error symbols are also inserted into the
  /// objective weighted by the constraint's priority:
  ///  * inequalities: a slack symbol (the marker), plus an error symbol
  ///    (other) when not required;
  ///  * non-required equalities: a plus/minus pair of error symbols;
  ///  * required equalities: a dummy symbol.
  ///
  /// The returned row always has a non-negative constant.
  _Row _createRow(Constraint constraint, _Tag tag) {
    Expression expr = new Expression.fromExpression(constraint.expression);
    _Row row = new _Row(expr.constant);

    expr.terms.forEach((term) {
      if (!_nearZero(term.coefficient)) {
        _Symbol symbol = _symbolForVariable(term.variable);

        _Row foundRow = _rows[symbol];

        if (foundRow != null) {
          row.insertRow(foundRow, term.coefficient);
        } else {
          row.insertSymbol(symbol, term.coefficient);
        }
      }
    });

    switch (constraint.relation) {
      case Relation.lessThanOrEqualTo:
      case Relation.greaterThanOrEqualTo:
        {
          double coefficient =
              constraint.relation == Relation.lessThanOrEqualTo ? 1.0 : -1.0;

          _Symbol slack = new _Symbol(_SymbolType.slack, tick++);
          tag.marker = slack;
          row.insertSymbol(slack, coefficient);

          if (constraint.priority < Priority.required) {
            _Symbol error = new _Symbol(_SymbolType.error, tick++);
            tag.other = error;
            row.insertSymbol(error, -coefficient);
            _objective.insertSymbol(error, constraint.priority);
          }
        }
        break;
      case Relation.equalTo:
        if (constraint.priority < Priority.required) {
          _Symbol errPlus = new _Symbol(_SymbolType.error, tick++);
          _Symbol errMinus = new _Symbol(_SymbolType.error, tick++);
          tag.marker = errPlus;
          tag.other = errMinus;
          row.insertSymbol(errPlus, -1.0);
          row.insertSymbol(errMinus, 1.0);
          _objective.insertSymbol(errPlus, constraint.priority);
          _objective.insertSymbol(errMinus, constraint.priority);
        } else {
          _Symbol dummy = new _Symbol(_SymbolType.dummy, tick++);
          tag.marker = dummy;
          row.insertSymbol(dummy);
        }
        break;
    }

    if (row.constant < 0.0) {
      row.reverseSign();
    }

    return row;
  }

  /// Picks the symbol that [row] should be solved for.
  ///
  /// Returns, in order of preference:
  ///  1. the first external symbol in the row;
  ///  2. the tag's marker, if it is a slack/error symbol with a negative
  ///     coefficient;
  ///  3. the tag's other symbol, under the same condition.
  ///
  /// Returns an invalid symbol otherwise.
  _Symbol _chooseSubjectForRow(_Row row, _Tag tag) {
    for (_Symbol symbol in row.cells.keys) {
      if (symbol.type == _SymbolType.external) {
        return symbol;
      }
    }

    if (tag.marker.type == _SymbolType.slack ||
        tag.marker.type == _SymbolType.error) {
      if (row.coefficientForSymbol(tag.marker) < 0.0) {
        return tag.marker;
      }
    }

    if (tag.other.type == _SymbolType.slack ||
        tag.other.type == _SymbolType.error) {
      if (row.coefficientForSymbol(tag.other) < 0.0) {
        return tag.other;
      }
    }

    return new _Symbol(_SymbolType.invalid, 0);
  }

  /// Whether every symbol in [row] is a dummy symbol (true for an empty
  /// row).
  bool _allDummiesInRow(_Row row) {
    for (_Symbol symbol in row.cells.keys) {
      if (symbol.type != _SymbolType.dummy) {
        return false;
      }
    }
    return true;
  }

  /// Adds [row] to the tableau using a temporary artificial variable.
  ///
  /// A copy of [row] is inserted under a fresh artificial symbol, and a
  /// second copy becomes the artificial objective, which is then optimized.
  /// Returns true only if that optimization succeeds and brings the
  /// artificial objective's constant to (near) zero. The artificial row and
  /// symbol are taken out of the tableau before returning.
  bool _addWithArtificialVariableOnRow(_Row row) {
    _Symbol artificial = new _Symbol(_SymbolType.slack, tick++);
    _rows[artificial] = new _Row.fromRow(row);
    _artificial = new _Row.fromRow(row);

    Result result = _optimizeObjectiveRow(_artificial);

    // FIXME(csg): Propagate optimization errors up instead of reporting a
    // plain failure. Even on error we continue to the cleanup below.
    // Returning early would leave the artificial row in [_rows], and a stale
    // [_artificial] objective that every later [_substitute] keeps updating.
    bool success = !result.error && _nearZero(_artificial.constant);
    _artificial = new _Row(0.0);

    _Row foundRow = _rows[artificial];
    if (foundRow != null) {
      _rows.remove(artificial);
      if (foundRow.cells.isEmpty) {
        return success;
      }

      _Symbol entering = _anyPivotableSymbol(foundRow);
      if (entering.type == _SymbolType.invalid) {
        return false;
      }

      foundRow.solveForSymbols(artificial, entering);
      _substitute(entering, foundRow);
      _rows[entering] = foundRow;
    }

    for (_Row tableauRow in _rows.values) {
      tableauRow.removeSymbol(artificial);
    }
    _objective.removeSymbol(artificial);
    return success;
  }

  /// Repeatedly pivots to improve [objective] until no entering symbol
  /// remains.
  ///
  /// Returns [Result.success] when done, or [Result.internalSolverError] if
  /// no row can leave the basis for a chosen entering symbol.
  Result _optimizeObjectiveRow(_Row objective) {
    while (true) {
      _Symbol entering = _enteringSymbolForObjectiveRow(objective);
      if (entering.type == _SymbolType.invalid) {
        return Result.success;
      }

      _Pair<_Symbol, _Row> leavingPair = _leavingRowForEnteringSymbol(entering);

      if (leavingPair == null) {
        return Result.internalSolverError;
      }

      _Symbol leaving = leavingPair.first;
      _Row row = leavingPair.second;
      _rows.remove(leavingPair.first);
      row.solveForSymbols(leaving, entering);
      _substitute(entering, row);
      _rows[entering] = row;
    }
  }

  /// Returns the first non-dummy symbol with a negative coefficient in
  /// [objective], or an invalid symbol if there is none.
  _Symbol _enteringSymbolForObjectiveRow(_Row objective) {
    Map<_Symbol, double> cells = objective.cells;

    for (_Symbol symbol in cells.keys) {
      if (symbol.type != _SymbolType.dummy && cells[symbol] < 0.0) {
        return symbol;
      }
    }

    return new _Symbol(_SymbolType.invalid, 0);
  }

  /// Chooses the row that leaves the basis when [entering] enters.
  ///
  /// Only non-external rows whose coefficient for [entering] is negative are
  /// considered. Among them, the row with the smallest ratio
  /// `-constant / coefficient` wins (the first one seen on ties). Returns
  /// null if no row qualifies.
  _Pair<_Symbol, _Row> _leavingRowForEnteringSymbol(_Symbol entering) {
    double ratio = double.MAX_FINITE;
    _Symbol leavingSymbol;
    _Row leavingRow;

    _rows.forEach((symbol, row) {
      if (symbol.type != _SymbolType.external) {
        double temp = row.coefficientForSymbol(entering);

        if (temp < 0.0) {
          double tempRatio = -row.constant / temp;

          if (tempRatio < ratio) {
            ratio = tempRatio;
            leavingSymbol = symbol;
            leavingRow = row;
          }
        }
      }
    });

    if (leavingSymbol == null || leavingRow == null) {
      return null;
    }

    return new _Pair<_Symbol, _Row>(leavingSymbol, leavingRow);
  }

  /// Substitutes [row] for [symbol] in every tableau row, the objective, and
  /// the artificial objective.
  ///
  /// Any non-external row whose constant becomes negative is queued in
  /// [_infeasibleRows].
  void _substitute(_Symbol symbol, _Row row) {
    _rows.forEach((first, second) {
      second.substitute(symbol, row);
      if (first.type != _SymbolType.external && second.constant < 0.0) {
        _infeasibleRows.add(first);
      }
    });

    _objective.substitute(symbol, row);
    if (_artificial != null) {
      _artificial.substitute(symbol, row);
    }
  }

  /// Returns the first slack or error symbol in [row], or an invalid symbol
  /// if there is none.
  _Symbol _anyPivotableSymbol(_Row row) {
    for (_Symbol symbol in row.cells.keys) {
      if (symbol.type == _SymbolType.slack ||
          symbol.type == _SymbolType.error) {
        return symbol;
      }
    }
    return new _Symbol(_SymbolType.invalid, 0);
  }

  /// Takes the error-symbol contributions of [cn] (described by [tag]) back
  /// out of the objective.
  void _removeConstraintEffects(Constraint cn, _Tag tag) {
    if (tag.marker.type == _SymbolType.error) {
      _removeMarkerEffects(tag.marker, cn.priority);
    }
    if (tag.other.type == _SymbolType.error) {
      _removeMarkerEffects(tag.other, cn.priority);
    }
  }

  /// Subtracts [marker], weighted by [strength], from the objective. If
  /// [marker] has a row, that row is subtracted instead.
  void _removeMarkerEffects(_Symbol marker, double strength) {
    _Row row = _rows[marker];
    if (row != null) {
      _objective.insertRow(row, -strength);
    } else {
      _objective.insertSymbol(marker, -strength);
    }
  }

  /// Chooses the row to pivot on so that the non-basic [marker] can enter
  /// the basis.
  ///
  /// Only rows containing [marker] are considered. The preference order is:
  ///  1. the non-external row with a negative coefficient and the smallest
  ///     `-constant / c`;
  ///  2. the non-external row with a positive coefficient and the smallest
  ///     `constant / c`;
  ///  3. the last external row seen.
  ///
  /// Returns null if [marker] appears in no row.
  _Pair<_Symbol, _Row> _leavingRowPairForMarkerSymbol(_Symbol marker) {
    double r1 = double.MAX_FINITE;
    double r2 = double.MAX_FINITE;

    _Pair<_Symbol, _Row> first, second, third;

    _rows.forEach((symbol, row) {
      double c = row.coefficientForSymbol(marker);

      if (c == 0.0) {
        return;
      }

      if (symbol.type == _SymbolType.external) {
        third = new _Pair(symbol, row);
      } else if (c < 0.0) {
        double r = -row.constant / c;
        if (r < r1) {
          r1 = r;
          first = new _Pair(symbol, row);
        }
      } else {
        double r = row.constant / c;
        if (r < r2) {
          r2 = r;
          second = new _Pair(symbol, row);
        }
      }
    });

    if (first != null) {
      return first;
    }
    if (second != null) {
      return second;
    }
    return third;
  }

  /// Records [value] as the edit's new suggestion and shifts the tableau by
  /// the change, without re-optimizing.
  ///
  /// The cases are:
  ///  * If the edit constraint's marker symbol has a row, only that row's
  ///    constant changes (by `-delta`).
  ///  * Otherwise, if the other symbol has a row, only that row changes (by
  ///    `+delta`).
  ///  * Otherwise, every row containing the marker shifts by
  ///    `delta * coefficient`.
  ///
  /// Rows left with a negative constant are queued in [_infeasibleRows]
  /// (external rows excluded). [_Row.add] is assumed to return the row's
  /// updated constant.
  void _suggestValueForEditInfoWithoutDualOptimization(
      _EditInfo info, double value) {
    double delta = value - info.constant;
    info.constant = value;

    _Symbol symbol = info.tag.marker;
    _Row row = _rows[info.tag.marker];

    if (row != null) {
      if (row.add(-delta) < 0.0) {
        _infeasibleRows.add(symbol);
      }
      return;
    }

    symbol = info.tag.other;
    row = _rows[info.tag.other];

    if (row != null) {
      if (row.add(delta) < 0.0) {
        _infeasibleRows.add(symbol);
      }
      return;
    }

    for (_Symbol rowSymbol in _rows.keys) {
      _Row tableauRow = _rows[rowSymbol];
      double coeff = tableauRow.coefficientForSymbol(info.tag.marker);
      if (coeff != 0.0 &&
          tableauRow.add(delta * coeff) < 0.0 &&
          rowSymbol.type != _SymbolType.external) {
        _infeasibleRows.add(rowSymbol);
      }
    }
  }

  /// Drains [_infeasibleRows], pivoting each row that still has a negative
  /// constant back to feasibility.
  ///
  /// Returns [Result.internalSolverError] if such a row has no valid
  /// entering symbol, and [Result.success] otherwise.
  Result _dualOptimize() {
    while (_infeasibleRows.isNotEmpty) {
      _Symbol leaving = _infeasibleRows.removeLast();
      _Row row = _rows[leaving];

      if (row != null && row.constant < 0.0) {
        _Symbol entering = _dualEnteringSymbolForRow(row);

        if (entering.type == _SymbolType.invalid) {
          return Result.internalSolverError;
        }

        _rows.remove(leaving);

        row.solveForSymbols(leaving, entering);
        _substitute(entering, row);
        _rows[entering] = row;
      }
    }
    return Result.success;
  }

  /// Chooses the entering symbol for a dual pivot on [row].
  ///
  /// Among non-dummy symbols with a positive coefficient in [row], picks the
  /// one minimizing (objective coefficient / row coefficient). If none
  /// qualifies, `_elvis` (presumably a null-coalescing helper) yields an
  /// invalid symbol.
  _Symbol _dualEnteringSymbolForRow(_Row row) {
    _Symbol entering;

    double ratio = double.MAX_FINITE;

    Map<_Symbol, double> rowCells = row.cells;

    for (_Symbol symbol in rowCells.keys) {
      double value = rowCells[symbol];

      if (value > 0.0 && symbol.type != _SymbolType.dummy) {
        double coeff = _objective.coefficientForSymbol(symbol);
        double r = coeff / value;
        if (r < ratio) {
          ratio = r;
          entering = symbol;
        }
      }
    }

    return _elvis(entering, new _Symbol(_SymbolType.invalid, 0));
  }

  /// Returns a multi-section debug dump: the objective, the tableau,
  /// infeasible rows, variables, edit variables, and constraints.
  @override
  String toString() {
    StringBuffer buffer = new StringBuffer();
    String separator = "\n~~~~~~~~~";

    // Objective
    buffer.writeln(separator + " Objective");
    buffer.writeln(_objective.toString());

    // Tableau
    buffer.writeln(separator + " Tableau");
    _rows.forEach((symbol, row) {
      buffer.write(symbol.toString());
      buffer.write(" | ");
      buffer.writeln(row.toString());
    });

    // Infeasible
    buffer.writeln(separator + " Infeasible");
    _infeasibleRows.forEach((symbol) => buffer.writeln(symbol.toString()));

    // Variables
    buffer.writeln(separator + " Variables");
    _vars.forEach((variable, symbol) =>
        buffer.writeln("${variable.toString()} = ${symbol.toString()}"));

    // Edit Variables
    buffer.writeln(separator + " Edit Variables");
    _edits.forEach((variable, editinfo) => buffer.writeln(variable));

    // Constraints
    buffer.writeln(separator + " Constraints");
    _constraints.forEach((constraint, _) => buffer.writeln(constraint));

    return buffer.toString();
  }
}

/// The pair of symbols the solver introduced for one constraint.
class _Tag {
  /// The symbol used to locate the constraint's row (slack, error, or
  /// dummy).
  _Symbol marker;

  /// The second error symbol, if any; invalid otherwise.
  _Symbol other;

  _Tag(this.marker, this.other);

  /// Creates an independent copy holding the same two symbols as [tag].
  _Tag.fromTag(_Tag tag) : this(tag.marker, tag.other);
}

/// Bookkeeping the solver keeps for each edit variable.
class _EditInfo {
  /// The internal constraint that was added for the edit variable.
  Constraint constraint;

  /// The tag (marker/other symbols) of [constraint].
  _Tag tag;

  /// The value most recently suggested for the variable.
  double constant;
}

/// Whether [priority] is usable for a non-required constraint: at least 0.0
/// and strictly below [Priority.required].
bool _isValidNonRequiredPriority(double priority) =>
    priority >= 0.0 && priority < Priority.required;

/// A single-item operation (or its undo) applied by the solver's bulk
/// add/remove helpers.
typedef Result _SolverBulkUpdate(dynamic element);