Commit b78b35d7 authored by Chinmay Garde's avatar Chinmay Garde

Implement addition of constraints to the solver

parent af67d087
...@@ -11,6 +11,9 @@ class Expression extends EquationMember { ...@@ -11,6 +11,9 @@ class Expression extends EquationMember {
double get value => terms.fold(constant, (value, term) => value + term.value); double get value => terms.fold(constant, (value, term) => value + term.value);
Expression(this.terms, this.constant); Expression(this.terms, this.constant);
Expression.fromExpression(Expression expr)
: this.terms = new List<Term>.from(expr.terms),
this.constant = expr.constant;
Expression asExpression() => this; Expression asExpression() => this;
......
...@@ -5,41 +5,43 @@ ...@@ -5,41 +5,43 @@
part of cassowary; part of cassowary;
class Row { class Row {
final Map<Symbol, double> _cells = new Map<Symbol, double>(); final Map<Symbol, double> cells;
double _constant = 0.0; double constant = 0.0;
double get constant => _constant; Row(this.constant) : this.cells = new Map<Symbol, double>();
Map<Symbol, double> get cells => _cells; Row.fromRow(Row row)
: this.cells = new Map<Symbol, double>.from(row.cells),
this.constant = row.constant;
double add(double value) => _constant += value; double add(double value) => constant += value;
void insertSymbol(Symbol symbol, [double coefficient = 1.0]) { void insertSymbol(Symbol symbol, [double coefficient = 1.0]) {
double val = _elvis(_cells[symbol], 0.0) + coefficient; double val = _elvis(cells[symbol], 0.0) + coefficient;
if (_nearZero(val)) { if (_nearZero(val)) {
_cells.remove(symbol); cells.remove(symbol);
} else { } else {
_cells[symbol] = val + coefficient; cells[symbol] = val + coefficient;
} }
} }
void insertRow(Row other, [double coefficient = 1.0]) { void insertRow(Row other, [double coefficient = 1.0]) {
_constant += other.constant * coefficient; constant += other.constant * coefficient;
other.cells.forEach((s, v) => insertSymbol(s, v * coefficient)); other.cells.forEach((s, v) => insertSymbol(s, v * coefficient));
} }
void removeSymbol(Symbol symbol) { void removeSymbol(Symbol symbol) {
_cells.remove(symbol); cells.remove(symbol);
} }
void reverseSign() => _cells.forEach((s, v) => _cells[s] = -v); void reverseSign() => cells.forEach((s, v) => cells[s] = -v);
void solveForSymbol(Symbol symbol) { void solveForSymbol(Symbol symbol) {
assert(_cells.containsKey(symbol)); assert(cells.containsKey(symbol));
double coefficient = -1.0 / _cells[symbol]; double coefficient = -1.0 / cells[symbol];
_cells.remove(symbol); cells.remove(symbol);
_constant *= coefficient; constant *= coefficient;
_cells.forEach((s, v) => _cells[s] = v * coefficient); cells.forEach((s, v) => cells[s] = v * coefficient);
} }
void solveForSymbols(Symbol lhs, Symbol rhs) { void solveForSymbols(Symbol lhs, Symbol rhs) {
...@@ -47,16 +49,16 @@ class Row { ...@@ -47,16 +49,16 @@ class Row {
solveForSymbol(rhs); solveForSymbol(rhs);
} }
double coefficientForSymbol(Symbol symbol) => _elvis(_cells[symbol], 0.0); double coefficientForSymbol(Symbol symbol) => _elvis(cells[symbol], 0.0);
void substitute(Symbol symbol, Row row) { void substitute(Symbol symbol, Row row) {
double coefficient = _cells[symbol]; double coefficient = cells[symbol];
if (coefficient == null) { if (coefficient == null) {
return; return;
} }
_cells.remove(symbol); cells.remove(symbol);
insertRow(row, coefficient); insertRow(row, coefficient);
} }
} }
...@@ -10,11 +10,45 @@ class Solver { ...@@ -10,11 +10,45 @@ class Solver {
final Map<Variable, Symbol> _vars = new Map<Variable, Symbol>(); final Map<Variable, Symbol> _vars = new Map<Variable, Symbol>();
final Map<Variable, EditInfo> _edits = new Map<Variable, EditInfo>(); final Map<Variable, EditInfo> _edits = new Map<Variable, EditInfo>();
final List<Symbol> _infeasibleRows = new List<Symbol>(); final List<Symbol> _infeasibleRows = new List<Symbol>();
final Row _objective = new Row(); final Row _objective = new Row(0.0);
final Row _artificial = new Row(); Row _artificial = new Row(0.0);
int tick = 0;
Result addConstraint(Constraint c) { Result addConstraint(Constraint constraint) {
return Result.unimplemented; if (_constraints.containsKey(constraint)) {
return Result.duplicateConstraint;
}
Tag tag = new Tag(
new Symbol(SymbolType.invalid, 0), new Symbol(SymbolType.invalid, 0));
Row row = _createRow(constraint, tag);
Symbol subject = _chooseSubjectForRow(row, tag);
if (subject.type == SymbolType.invalid && _allDummiesInRow(row)) {
if (!_nearZero(row.constant)) {
return Result.unsatisfiableConstraint;
} else {
subject = tag.marker;
}
}
if (subject.type == SymbolType.invalid) {
if (!_addWithArtificialVariableOnRow(row)) {
return Result.unsatisfiableConstraint;
}
} else {
row.solveForSymbol(subject);
_substitute(subject, row);
_rows[subject] = row;
}
_constraints[constraint] = tag;
_optimizeObjectiveRow(_objective);
return Result.success;
} }
Result removeContraint(Constraint c) { Result removeContraint(Constraint c) {
...@@ -44,6 +78,238 @@ class Solver { ...@@ -44,6 +78,238 @@ class Solver {
void updateVariable() {} void updateVariable() {}
Solver operator <<(Constraint c) => this..addConstraint(c); Solver operator <<(Constraint c) => this..addConstraint(c);
Symbol _getSymbolForVariable(Variable variable) {
Symbol symbol = _vars[variable];
if (symbol != null) {
return symbol;
}
symbol = new Symbol(SymbolType.external, tick++);
_vars[variable] = symbol;
return symbol;
}
Row _createRow(Constraint constraint, Tag tag) {
Expression expr = new Expression.fromExpression(constraint.expression);
Row row = new Row(expr.constant);
expr.terms.forEach((term) {
if (!_nearZero(term.coefficient)) {
Symbol symbol = _getSymbolForVariable(term.variable);
Row foundRow = _rows[symbol];
if (foundRow != null) {
row.insertRow(foundRow, term.coefficient);
} else {
row.insertSymbol(symbol, term.coefficient);
}
}
});
switch (constraint.relation) {
case Relation.lessThanOrEqualTo:
case Relation.greaterThanOrEqualTo:
{
double coefficient =
constraint.relation == Relation.lessThanOrEqualTo ? 1.0 : -1.0;
Symbol slack = new Symbol(SymbolType.slack, tick++);
tag.marker = slack;
row.insertSymbol(slack, coefficient);
if (!constraint.required) {
Symbol error = new Symbol(SymbolType.error, tick++);
tag.other = error;
row.insertSymbol(error, -coefficient);
_objective.insertSymbol(error, constraint.priority);
}
}
break;
case Relation.equalTo:
if (!constraint.required) {
Symbol errPlus = new Symbol(SymbolType.error, tick++);
Symbol errMinus = new Symbol(SymbolType.error, tick++);
tag.marker = errPlus;
tag.other = errMinus;
row.insertSymbol(errPlus, -1.0);
row.insertSymbol(errMinus, 1.0);
_objective.insertSymbol(errPlus, constraint.priority);
_objective.insertSymbol(errMinus, constraint.priority);
} else {
Symbol dummy = new Symbol(SymbolType.dummy, tick++);
tag.marker = dummy;
row.insertSymbol(dummy);
}
break;
}
if (row.constant < 0.0) {
row.reverseSign();
}
return row;
}
Symbol _chooseSubjectForRow(Row row, Tag tag) {
for (Symbol symbol in row.cells.keys) {
if (symbol.type == SymbolType.external) {
return symbol;
}
}
if (tag.marker.type == SymbolType.slack ||
tag.marker.type == SymbolType.error) {
if (row.coefficientForSymbol(tag.marker) < 0.0) {
return tag.marker;
}
}
if (tag.other.type == SymbolType.slack ||
tag.other.type == SymbolType.error) {
if (row.coefficientForSymbol(tag.other) < 0.0) {
return tag.other;
}
}
return new Symbol(SymbolType.invalid, 0);
}
bool _allDummiesInRow(Row row) {
for (Symbol symbol in row.cells.keys) {
if (symbol.type != SymbolType.dummy) {
return false;
}
}
return true;
}
bool _addWithArtificialVariableOnRow(Row row) {
Symbol artificial = new Symbol(SymbolType.slack, tick++);
_rows[artificial] = new Row.fromRow(row);
_artificial = new Row.fromRow(row);
Result result = _optimizeObjectiveRow(_artificial);
if (result.error) {
// FIXME(csg): Propagate this up!
return false;
}
bool success = _nearZero(_artificial.constant);
_artificial = new Row(0.0);
Row foundRow = _rows[artificial];
if (foundRow != null) {
_rows.remove(artificial);
if (foundRow.cells.isEmpty) {
return success;
}
Symbol entering = _anyPivotableSymbol(foundRow);
if (entering.type == SymbolType.invalid) {
return false;
}
foundRow.solveForSymbols(artificial, entering);
_substitute(entering, foundRow);
_rows[entering] = foundRow;
}
for (Row row in _rows.values) {
row.removeSymbol(artificial);
}
_objective.removeSymbol(artificial);
return success;
}
Result _optimizeObjectiveRow(Row objective) {
while (true) {
Symbol entering = _getEnteringSymbolForObjectiveRow(objective);
if (entering.type == SymbolType.invalid) {
return Result.success;
}
_Pair<Symbol, Row> leavingPair =
_getLeavingRowForEnteringSymbol(entering);
if (leavingPair == null) {
return Result.internalSolverError;
}
Symbol leaving = leavingPair.first;
Row row = leavingPair.second;
_rows.remove(leavingPair.first);
row.solveForSymbols(leaving, entering);
_substitute(entering, row);
_rows[entering] = row;
}
}
Symbol _getEnteringSymbolForObjectiveRow(Row objective) {
Map<Symbol, double> cells = objective.cells;
for (Symbol symbol in cells.keys) {
if (symbol.type != SymbolType.dummy && cells[symbol] < 0.0) {
return symbol;
}
}
return new Symbol(SymbolType.invalid, 0);
}
_Pair<Symbol, Row> _getLeavingRowForEnteringSymbol(Symbol entering) {
double ratio = double.MAX_FINITE;
_Pair<Symbol, Row> result = new _Pair(null, null);
_rows.forEach((symbol, row) {
if (symbol.type != SymbolType.external) {
double temp = row.coefficientForSymbol(entering);
if (temp < 0.0) {
double temp_ratio = -row.constant / temp;
if (temp_ratio < ratio) {
ratio = temp_ratio;
result.first = symbol;
result.second = row;
}
}
}
});
if (result.first == null || result.second == null) {
return null;
}
return result;
}
void _substitute(Symbol symbol, Row row) {
_rows.forEach((first, second) {
second.substitute(symbol, row);
if (first.type != SymbolType.external && second.constant < 0.0) {
_infeasibleRows.add(first);
}
});
_objective.substitute(symbol, row);
if (_artificial != null) {
_artificial.substitute(symbol, row);
}
}
Symbol _anyPivotableSymbol(Row row) {
for (Symbol symbol in row.cells.keys) {
if (symbol.type == SymbolType.slack || symbol.type == SymbolType.error) {
return symbol;
}
}
return new Symbol(SymbolType.invalid, 0);
}
} }
class Tag { class Tag {
......
...@@ -7,7 +7,8 @@ part of cassowary; ...@@ -7,7 +7,8 @@ part of cassowary;
enum SymbolType { invalid, external, slack, error, dummy, } enum SymbolType { invalid, external, slack, error, dummy, }
class Symbol { class Symbol {
SymbolType type; final SymbolType type;
int tick;
Symbol(this.type); Symbol(this.type, this.tick);
} }
...@@ -13,3 +13,9 @@ bool _nearZero(double value) { ...@@ -13,3 +13,9 @@ bool _nearZero(double value) {
// instead. Sadly, due the lack of generic types on functions, we have to use // instead. Sadly, due the lack of generic types on functions, we have to use
// dynamic instead. // dynamic instead.
_elvis(a, b) => a != null ? a : b; _elvis(a, b) => a != null ? a : b;
class _Pair<X, Y> {
X first;
Y second;
_Pair(this.first, this.second);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment