Commit b78b35d7 authored by Chinmay Garde's avatar Chinmay Garde

Implement addition of constraints to the solver

parent af67d087
......@@ -11,6 +11,9 @@ class Expression extends EquationMember {
double get value => terms.fold(constant, (value, term) => value + term.value);
Expression(this.terms, this.constant);
Expression.fromExpression(Expression expr)
: this.terms = new List<Term>.from(expr.terms),
this.constant = expr.constant;
Expression asExpression() => this;
......
......@@ -5,41 +5,43 @@
part of cassowary;
class Row {
final Map<Symbol, double> _cells = new Map<Symbol, double>();
double _constant = 0.0;
final Map<Symbol, double> cells;
double constant = 0.0;
double get constant => _constant;
Map<Symbol, double> get cells => _cells;
Row(this.constant) : this.cells = new Map<Symbol, double>();
Row.fromRow(Row row)
: this.cells = new Map<Symbol, double>.from(row.cells),
this.constant = row.constant;
double add(double value) => _constant += value;
double add(double value) => constant += value;
void insertSymbol(Symbol symbol, [double coefficient = 1.0]) {
double val = _elvis(_cells[symbol], 0.0) + coefficient;
double val = _elvis(cells[symbol], 0.0) + coefficient;
if (_nearZero(val)) {
_cells.remove(symbol);
cells.remove(symbol);
} else {
_cells[symbol] = val + coefficient;
cells[symbol] = val + coefficient;
}
}
void insertRow(Row other, [double coefficient = 1.0]) {
_constant += other.constant * coefficient;
constant += other.constant * coefficient;
other.cells.forEach((s, v) => insertSymbol(s, v * coefficient));
}
void removeSymbol(Symbol symbol) {
_cells.remove(symbol);
cells.remove(symbol);
}
void reverseSign() => _cells.forEach((s, v) => _cells[s] = -v);
void reverseSign() => cells.forEach((s, v) => cells[s] = -v);
void solveForSymbol(Symbol symbol) {
assert(_cells.containsKey(symbol));
double coefficient = -1.0 / _cells[symbol];
_cells.remove(symbol);
_constant *= coefficient;
_cells.forEach((s, v) => _cells[s] = v * coefficient);
assert(cells.containsKey(symbol));
double coefficient = -1.0 / cells[symbol];
cells.remove(symbol);
constant *= coefficient;
cells.forEach((s, v) => cells[s] = v * coefficient);
}
void solveForSymbols(Symbol lhs, Symbol rhs) {
......@@ -47,16 +49,16 @@ class Row {
solveForSymbol(rhs);
}
double coefficientForSymbol(Symbol symbol) => _elvis(_cells[symbol], 0.0);
double coefficientForSymbol(Symbol symbol) => _elvis(cells[symbol], 0.0);
void substitute(Symbol symbol, Row row) {
double coefficient = _cells[symbol];
double coefficient = cells[symbol];
if (coefficient == null) {
return;
}
_cells.remove(symbol);
cells.remove(symbol);
insertRow(row, coefficient);
}
}
......@@ -10,11 +10,45 @@ class Solver {
final Map<Variable, Symbol> _vars = new Map<Variable, Symbol>();
final Map<Variable, EditInfo> _edits = new Map<Variable, EditInfo>();
final List<Symbol> _infeasibleRows = new List<Symbol>();
final Row _objective = new Row();
final Row _artificial = new Row();
final Row _objective = new Row(0.0);
Row _artificial = new Row(0.0);
int tick = 0;
Result addConstraint(Constraint c) {
return Result.unimplemented;
Result addConstraint(Constraint constraint) {
if (_constraints.containsKey(constraint)) {
return Result.duplicateConstraint;
}
Tag tag = new Tag(
new Symbol(SymbolType.invalid, 0), new Symbol(SymbolType.invalid, 0));
Row row = _createRow(constraint, tag);
Symbol subject = _chooseSubjectForRow(row, tag);
if (subject.type == SymbolType.invalid && _allDummiesInRow(row)) {
if (!_nearZero(row.constant)) {
return Result.unsatisfiableConstraint;
} else {
subject = tag.marker;
}
}
if (subject.type == SymbolType.invalid) {
if (!_addWithArtificialVariableOnRow(row)) {
return Result.unsatisfiableConstraint;
}
} else {
row.solveForSymbol(subject);
_substitute(subject, row);
_rows[subject] = row;
}
_constraints[constraint] = tag;
_optimizeObjectiveRow(_objective);
return Result.success;
}
Result removeContraint(Constraint c) {
......@@ -44,6 +78,238 @@ class Solver {
void updateVariable() {}
Solver operator <<(Constraint c) => this..addConstraint(c);
Symbol _getSymbolForVariable(Variable variable) {
Symbol symbol = _vars[variable];
if (symbol != null) {
return symbol;
}
symbol = new Symbol(SymbolType.external, tick++);
_vars[variable] = symbol;
return symbol;
}
Row _createRow(Constraint constraint, Tag tag) {
Expression expr = new Expression.fromExpression(constraint.expression);
Row row = new Row(expr.constant);
expr.terms.forEach((term) {
if (!_nearZero(term.coefficient)) {
Symbol symbol = _getSymbolForVariable(term.variable);
Row foundRow = _rows[symbol];
if (foundRow != null) {
row.insertRow(foundRow, term.coefficient);
} else {
row.insertSymbol(symbol, term.coefficient);
}
}
});
switch (constraint.relation) {
case Relation.lessThanOrEqualTo:
case Relation.greaterThanOrEqualTo:
{
double coefficient =
constraint.relation == Relation.lessThanOrEqualTo ? 1.0 : -1.0;
Symbol slack = new Symbol(SymbolType.slack, tick++);
tag.marker = slack;
row.insertSymbol(slack, coefficient);
if (!constraint.required) {
Symbol error = new Symbol(SymbolType.error, tick++);
tag.other = error;
row.insertSymbol(error, -coefficient);
_objective.insertSymbol(error, constraint.priority);
}
}
break;
case Relation.equalTo:
if (!constraint.required) {
Symbol errPlus = new Symbol(SymbolType.error, tick++);
Symbol errMinus = new Symbol(SymbolType.error, tick++);
tag.marker = errPlus;
tag.other = errMinus;
row.insertSymbol(errPlus, -1.0);
row.insertSymbol(errMinus, 1.0);
_objective.insertSymbol(errPlus, constraint.priority);
_objective.insertSymbol(errMinus, constraint.priority);
} else {
Symbol dummy = new Symbol(SymbolType.dummy, tick++);
tag.marker = dummy;
row.insertSymbol(dummy);
}
break;
}
if (row.constant < 0.0) {
row.reverseSign();
}
return row;
}
Symbol _chooseSubjectForRow(Row row, Tag tag) {
for (Symbol symbol in row.cells.keys) {
if (symbol.type == SymbolType.external) {
return symbol;
}
}
if (tag.marker.type == SymbolType.slack ||
tag.marker.type == SymbolType.error) {
if (row.coefficientForSymbol(tag.marker) < 0.0) {
return tag.marker;
}
}
if (tag.other.type == SymbolType.slack ||
tag.other.type == SymbolType.error) {
if (row.coefficientForSymbol(tag.other) < 0.0) {
return tag.other;
}
}
return new Symbol(SymbolType.invalid, 0);
}
bool _allDummiesInRow(Row row) {
for (Symbol symbol in row.cells.keys) {
if (symbol.type != SymbolType.dummy) {
return false;
}
}
return true;
}
bool _addWithArtificialVariableOnRow(Row row) {
Symbol artificial = new Symbol(SymbolType.slack, tick++);
_rows[artificial] = new Row.fromRow(row);
_artificial = new Row.fromRow(row);
Result result = _optimizeObjectiveRow(_artificial);
if (result.error) {
// FIXME(csg): Propagate this up!
return false;
}
bool success = _nearZero(_artificial.constant);
_artificial = new Row(0.0);
Row foundRow = _rows[artificial];
if (foundRow != null) {
_rows.remove(artificial);
if (foundRow.cells.isEmpty) {
return success;
}
Symbol entering = _anyPivotableSymbol(foundRow);
if (entering.type == SymbolType.invalid) {
return false;
}
foundRow.solveForSymbols(artificial, entering);
_substitute(entering, foundRow);
_rows[entering] = foundRow;
}
for (Row row in _rows.values) {
row.removeSymbol(artificial);
}
_objective.removeSymbol(artificial);
return success;
}
Result _optimizeObjectiveRow(Row objective) {
while (true) {
Symbol entering = _getEnteringSymbolForObjectiveRow(objective);
if (entering.type == SymbolType.invalid) {
return Result.success;
}
_Pair<Symbol, Row> leavingPair =
_getLeavingRowForEnteringSymbol(entering);
if (leavingPair == null) {
return Result.internalSolverError;
}
Symbol leaving = leavingPair.first;
Row row = leavingPair.second;
_rows.remove(leavingPair.first);
row.solveForSymbols(leaving, entering);
_substitute(entering, row);
_rows[entering] = row;
}
}
Symbol _getEnteringSymbolForObjectiveRow(Row objective) {
Map<Symbol, double> cells = objective.cells;
for (Symbol symbol in cells.keys) {
if (symbol.type != SymbolType.dummy && cells[symbol] < 0.0) {
return symbol;
}
}
return new Symbol(SymbolType.invalid, 0);
}
_Pair<Symbol, Row> _getLeavingRowForEnteringSymbol(Symbol entering) {
double ratio = double.MAX_FINITE;
_Pair<Symbol, Row> result = new _Pair(null, null);
_rows.forEach((symbol, row) {
if (symbol.type != SymbolType.external) {
double temp = row.coefficientForSymbol(entering);
if (temp < 0.0) {
double temp_ratio = -row.constant / temp;
if (temp_ratio < ratio) {
ratio = temp_ratio;
result.first = symbol;
result.second = row;
}
}
}
});
if (result.first == null || result.second == null) {
return null;
}
return result;
}
void _substitute(Symbol symbol, Row row) {
_rows.forEach((first, second) {
second.substitute(symbol, row);
if (first.type != SymbolType.external && second.constant < 0.0) {
_infeasibleRows.add(first);
}
});
_objective.substitute(symbol, row);
if (_artificial != null) {
_artificial.substitute(symbol, row);
}
}
Symbol _anyPivotableSymbol(Row row) {
for (Symbol symbol in row.cells.keys) {
if (symbol.type == SymbolType.slack || symbol.type == SymbolType.error) {
return symbol;
}
}
return new Symbol(SymbolType.invalid, 0);
}
}
class Tag {
......
......@@ -7,7 +7,8 @@ part of cassowary;
enum SymbolType { invalid, external, slack, error, dummy, }
class Symbol {
SymbolType type;
final SymbolType type;
int tick;
Symbol(this.type);
Symbol(this.type, this.tick);
}
......@@ -13,3 +13,9 @@ bool _nearZero(double value) {
// instead. Sadly, due the lack of generic types on functions, we have to use
// dynamic instead.
_elvis(a, b) => a != null ? a : b;
class _Pair<X, Y> {
X first;
Y second;
_Pair(this.first, this.second);
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment