Skip to content

Commit

Permalink
some memory optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurent Perron committed Sep 3, 2019
1 parent a6dd04c commit 3d7a5c0
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 67 deletions.
51 changes: 32 additions & 19 deletions ortools/sat/cp_model_expand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -419,68 +419,76 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) {
const int size = element.vars_size();
if (!context->IntersectDomainWith(index_ref, Domain(0, size - 1))) {
VLOG(1) << "Empty domain for the index variable in ExpandElement()";
CHECK(!context->NotifyThatModelIsUnsat());
return;
}

bool all_constants = true;
std::set<int64> reached_values;
absl::flat_hash_set<int64> constant_var_values;
std::vector<int64> invalid_indices;
const Domain initial_index_domain = context->DomainOf(index_ref);
const Domain initial_target_domain = context->DomainOf(target_ref);
for (const ClosedInterval& interval : initial_index_domain) {
Domain index_domain = context->DomainOf(index_ref);
Domain target_domain = context->DomainOf(target_ref);
for (const ClosedInterval& interval : index_domain) {
for (int64 v = interval.start; v <= interval.end; ++v) {
const int var = element.vars(v);
const Domain var_domain = context->DomainOf(var);
if (var_domain.IntersectionWith(initial_target_domain).IsEmpty()) {
if (var_domain.IntersectionWith(target_domain).IsEmpty()) {
invalid_indices.push_back(v);
continue;
}
if (var_domain.Min() != var_domain.Max()) {
all_constants = false;
break;
}
reached_values.insert(var_domain.Min());
constant_var_values.insert(var_domain.Min());
}
}

if (!invalid_indices.empty()) {
if (!context->IntersectDomainWith(
index_ref, Domain::FromValues(invalid_indices).Complement())) {
VLOG(1) << "No compatible variable domains in ExpandElement()";
CHECK(!context->NotifyThatModelIsUnsat());
return;
}
}

const Domain index_domain = context->DomainOf(index_ref);
// Re-read the domain.
index_domain = context->DomainOf(index_ref);
}

std::map<int64, BoolArgumentProto*> supports;
// This BoolOrs implements the deduction that if all index literals pointing
// to the same values in the constant array are false, then this value is no
// no longer valid for the target variable.
// Order is not important.
absl::flat_hash_map<int64, BoolArgumentProto*> supports;
if (all_constants && target_ref != index_ref) {
if (!context->IntersectDomainWith(
target_ref, Domain::FromValues(
{reached_values.begin(), reached_values.end()}))) {
target_ref, Domain::FromValues({constant_var_values.begin(),
constant_var_values.end()}))) {
VLOG(1) << "Empty domain for the target variable in ExpandElement()";
return;
}

const Domain domain = context->DomainOf(target_ref);
if (domain.Size() == 1) {
context->UpdateRuleStats("element: array is constant");
target_domain = context->DomainOf(target_ref);
if (target_domain.Size() == 1) {
context->UpdateRuleStats("element: one value array");
ct->Clear();
return;
}

for (const ClosedInterval& interval : context->DomainOf(target_ref)) {
// TODO(user): only create 1 literal if the value has only one support.

for (const ClosedInterval& interval : target_domain) {
for (int64 v = interval.start; v <= interval.end; ++v) {
const int lit = context->GetOrCreateVarValueEncoding(target_ref, v);
CHECK(gtl::ContainsKey(reached_values, v));
CHECK(constant_var_values.contains(v));
supports[v] =
context->working_model->add_constraints()->mutable_bool_or();
supports[v]->add_literals(NegatedRef(lit));
}
}
}

const Domain target_domain = context->DomainOf(target_ref);

// While this is not stricly needed since all value in the index will be
// covered, it allows to easily detect this fact in the presolve.
auto* bool_or = context->working_model->add_constraints()->mutable_bool_or();
Expand All @@ -502,7 +510,9 @@ void ExpandElement(ConstraintProto* ct, PresolveContext* context) {
} else if (var_domain.Size() == 1) {
context->AddImplyInDomain(index_lit, target_ref, var_domain);
if (all_constants) {
supports[var_domain.Min()]->add_literals(index_lit);
BoolArgumentProto* const support =
gtl::FindOrDie(supports, var_domain.Min());
support->add_literals(index_lit);
}
} else {
ConstraintProto* const ct = context->working_model->add_constraints();
Expand Down Expand Up @@ -556,6 +566,9 @@ void ExpandCpModel(CpModelProto* working_model, PresolveOptions options) {
default:
break;
}

// Early exit if the model is unsat.
if (context.ModelIsUnsat()) return;
}

// Update any changed domain from the context.
Expand Down
6 changes: 1 addition & 5 deletions ortools/sat/cp_model_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,7 @@ void FillDomainInProto(const Domain& domain, ProtoWithDomain* proto) {
// Reads a Domain from the domain field of a proto.
template <typename ProtoWithDomain>
Domain ReadDomainFromProto(const ProtoWithDomain& proto) {
std::vector<ClosedInterval> intervals;
for (int i = 0; i < proto.domain_size(); i += 2) {
intervals.push_back({proto.domain(i), proto.domain(i + 1)});
}
return Domain::FromIntervals(intervals);
return Domain::FromFlatSpanOfIntervals(proto.domain());
}

// Returns the list of values in a given domain.
Expand Down
90 changes: 51 additions & 39 deletions ortools/sat/integer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,47 +50,46 @@ void IntegerEncoder::FullyEncodeVariable(IntegerVariable var) {
}

// Mark var and Negation(var) as fully encoded.
CHECK_LT(var.value(), is_fully_encoded_.size());
CHECK_LT(NegationOf(var).value(), is_fully_encoded_.size());
CHECK(!equality_by_var_[var].empty());
CHECK(!equality_by_var_[NegationOf(var)].empty());
is_fully_encoded_[var] = true;
is_fully_encoded_[NegationOf(var)] = true;
CHECK_LT(GetPositiveOnlyIndex(var), is_fully_encoded_.size());
CHECK(!equality_by_var_[GetPositiveOnlyIndex(var)].empty());
is_fully_encoded_[GetPositiveOnlyIndex(var)] = true;
}

bool IntegerEncoder::VariableIsFullyEncoded(IntegerVariable var) const {
if (var >= is_fully_encoded_.size()) return false;
const PositiveOnlyIndex index = GetPositiveOnlyIndex(var);
if (index >= is_fully_encoded_.size()) return false;

// Once fully encoded, the status never changes.
if (is_fully_encoded_[var]) return true;
if (is_fully_encoded_[index]) return true;
if (!VariableIsPositive(var)) var = PositiveVariable(var);

// TODO(user): Cache result as long as equality_by_var_[var] is unchanged?
// TODO(user): Cache result as long as equality_by_var_[index] is unchanged?
// It might not be needed since if the variable is not fully encoded, then
// PartialDomainEncoding() will filter unreachable values, and so the size
// check will be false until further value have been encoded.
const int64 initial_domain_size = (*domains_)[var].Size();
if (equality_by_var_[var].size() < initial_domain_size) return false;
if (equality_by_var_[index].size() < initial_domain_size) return false;

// This cleans equality_by_var_[var] as a side effect and in particular, sorts
// it by values.
// This cleans equality_by_var_[index] as a side effect and in particular,
// sorts it by values.
PartialDomainEncoding(var);

// TODO(user): Comparing the size might be enough, but we want to be always
// valid even if either (*domains_[var]) or PartialDomainEncoding(var) are
// not properly synced because the propagation is not finished.
const auto& ref = equality_by_var_[var];
int index = 0;
const auto& ref = equality_by_var_[index];
int i = 0;
for (const ClosedInterval interval : (*domains_)[var]) {
for (int64 v = interval.start; v <= interval.end; ++v) {
if (index < ref.size() && v == ref[index].value) {
index++;
if (i < ref.size() && v == ref[i].value) {
i++;
}
}
}
if (index == ref.size()) {
is_fully_encoded_[var] = true;
if (i == ref.size()) {
is_fully_encoded_[index] = true;
}
return is_fully_encoded_[var];
return is_fully_encoded_[index];
}

std::vector<IntegerEncoder::ValueLiteralPair>
Expand All @@ -102,23 +101,31 @@ IntegerEncoder::FullDomainEncoding(IntegerVariable var) const {
std::vector<IntegerEncoder::ValueLiteralPair>
IntegerEncoder::PartialDomainEncoding(IntegerVariable var) const {
CHECK_EQ(sat_solver_->CurrentDecisionLevel(), 0);
if (var >= equality_by_var_.size()) return {};
const PositiveOnlyIndex index = GetPositiveOnlyIndex(var);
if (index >= equality_by_var_.size()) return {};

int new_size = 0;
std::vector<ValueLiteralPair>& ref = equality_by_var_[var];
std::vector<ValueLiteralPair>& ref = equality_by_var_[index];
for (int i = 0; i < ref.size(); ++i) {
const ValueLiteralPair pair = ref[i];
if (sat_solver_->Assignment().LiteralIsFalse(pair.literal)) continue;
if (sat_solver_->Assignment().LiteralIsTrue(pair.literal)) {
ref.clear();
ref.push_back(pair);
return ref;
new_size = 1;
break;
}
ref[new_size++] = pair;
}
ref.resize(new_size);
std::sort(ref.begin(), ref.end());
return ref;

std::vector<IntegerEncoder::ValueLiteralPair> result = ref;
if (!VariableIsPositive(var)) {
std::reverse(result.begin(), result.end());
for (ValueLiteralPair& ref : result) ref.value = -ref.value;
}
return result;
}

// Note that by not inserting the literal in "order" we can in the worst case
Expand Down Expand Up @@ -219,10 +226,18 @@ Literal IntegerEncoder::GetOrCreateAssociatedLiteral(IntegerLiteral i_lit) {
return literal;
}

namespace {
std::pair<PositiveOnlyIndex, IntegerValue> PositiveVarKey(IntegerVariable var,
IntegerValue value) {
return std::make_pair(GetPositiveOnlyIndex(var),
VariableIsPositive(var) ? value : -value);
}
} // namespace

LiteralIndex IntegerEncoder::GetAssociatedEqualityLiteral(
IntegerVariable var, IntegerValue value) const {
const std::pair<IntegerVariable, IntegerValue> key{var, value};
const auto it = equality_to_associated_literal_.find(key);
const auto it =
equality_to_associated_literal_.find(PositiveVarKey(var, value));
if (it != equality_to_associated_literal_.end()) {
return it->second.Index();
}
Expand All @@ -232,8 +247,8 @@ LiteralIndex IntegerEncoder::GetAssociatedEqualityLiteral(
Literal IntegerEncoder::GetOrCreateLiteralAssociatedToEquality(
IntegerVariable var, IntegerValue value) {
{
const std::pair<IntegerVariable, IntegerValue> key{var, value};
const auto it = equality_to_associated_literal_.find(key);
const auto it =
equality_to_associated_literal_.find(PositiveVarKey(var, value));
if (it != equality_to_associated_literal_.end()) {
return it->second;
}
Expand Down Expand Up @@ -315,8 +330,8 @@ void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal,

// We use the "do not insert if present" behavior of .insert() to do just one
// lookup.
const auto insert_result =
equality_to_associated_literal_.insert({{var, value}, literal});
const auto insert_result = equality_to_associated_literal_.insert(
{PositiveVarKey(var, value), literal});
if (!insert_result.second) {
// If this key is already associated, make the two literals equal.
const Literal representative = insert_result.first->second;
Expand All @@ -327,8 +342,6 @@ void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal,
}
return;
}
gtl::InsertOrDieNoPrint(&equality_to_associated_literal_,
{{NegationOf(var), -value}, literal});

// Fix literal for value outside the domain.
if (!domain.Contains(value.value())) {
Expand All @@ -339,14 +352,13 @@ void IntegerEncoder::AssociateToIntegerEqualValue(Literal literal,
// Update equality_by_var. Note that due to the
// equality_to_associated_literal_ hash table, there should never be any
// duplicate values for a given variable.
const int needed_size = std::max(var.value(), NegationOf(var).value()) + 1;
if (needed_size > equality_by_var_.size()) {
equality_by_var_.resize(needed_size);
is_fully_encoded_.resize(needed_size);
}
equality_by_var_[var].push_back(ValueLiteralPair(value, literal));
equality_by_var_[NegationOf(var)].push_back(
ValueLiteralPair(-value, literal));
const PositiveOnlyIndex index = GetPositiveOnlyIndex(var);
if (index >= equality_by_var_.size()) {
equality_by_var_.resize(index.value() + 1);
is_fully_encoded_.resize(index.value() + 1);
}
equality_by_var_[index].push_back(
ValueLiteralPair(VariableIsPositive(var) ? value : -value, literal));

// Fix literal for constant domain.
if (value == domain.Min() && value == domain.Max()) {
Expand Down
16 changes: 12 additions & 4 deletions ortools/sat/integer.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ inline IntegerVariable PositiveVariable(IntegerVariable i) {
return IntegerVariable(i.value() & (~1));
}

// Special type for storing only one thing for var and NegationOf(var).
DEFINE_INT_TYPE(PositiveOnlyIndex, int32);
inline PositiveOnlyIndex GetPositiveOnlyIndex(IntegerVariable var) {
return PositiveOnlyIndex(var.value() / 2);
}

// Returns the vector of the negated variables.
std::vector<IntegerVariable> NegationOf(
const std::vector<IntegerVariable>& vars);
Expand Down Expand Up @@ -450,16 +456,18 @@ class IntegerEncoder {
// Mapping (variable == value) -> associated literal. Note that even if
// there is more than one literal associated to the same fact, we just keep
// the first one that was added.
absl::flat_hash_map<std::pair<IntegerVariable, IntegerValue>, Literal>
//
// Note that we only keep positive IntegerVariable here to reduce memory
// usage.
absl::flat_hash_map<std::pair<PositiveOnlyIndex, IntegerValue>, Literal>
equality_to_associated_literal_;

// Mutable because this is lazily cleaned-up by PartialDomainEncoding().
const std::vector<ValueLiteralPair> empty_value_literal_vector_;
mutable gtl::ITIVector<IntegerVariable, std::vector<ValueLiteralPair>>
mutable gtl::ITIVector<PositiveOnlyIndex, std::vector<ValueLiteralPair>>
equality_by_var_;

// Variables that are fully encoded.
mutable gtl::ITIVector<IntegerVariable, bool> is_fully_encoded_;
mutable gtl::ITIVector<PositiveOnlyIndex, bool> is_fully_encoded_;

// A literal that is always true, convenient to encode trivial domains.
// This will be lazily created when needed.
Expand Down

0 comments on commit 3d7a5c0

Please sign in to comment.