Skip to content

Commit

Permalink
Accept arrays of messages in MutateImpl
Browse files Browse the repository at this point in the history
  • Loading branch information
vitalybuka committed Feb 4, 2020
1 parent 74afb61 commit f62086c
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 37 deletions.
89 changes: 53 additions & 36 deletions src/mutator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ enum class Mutation : uint8_t {

using MutationBitset = std::bitset<static_cast<size_t>(Mutation::Last)>;

using Messages = std::vector<Message*>;
using ConstMessages = std::vector<const Message*>;

// Return random integer from [0, count)
size_t GetRandomIndex(RandomEngine* random, size_t count) {
assert(count > 0);
Expand Down Expand Up @@ -173,25 +176,27 @@ class CanCopyAndDifferentField
class MutationSampler {
public:
MutationSampler(bool keep_initialized, MutationBitset allowed_mutations,
RandomEngine* random, Message* message)
RandomEngine* random)
: keep_initialized_(keep_initialized),
allowed_mutations_(allowed_mutations),
random_(random),
sampler_(random) {
Sample(message);
assert(mutation() != Mutation::None ||
!allowed_mutations_[static_cast<size_t>(Mutation::Mutate)] ||
message->GetDescriptor()->field_count() == 0);
}
sampler_(random) {}

// Returns selected field.
const FieldInstance& field() const { return sampler_.selected().field; }

// Returns selected mutation.
Mutation mutation() const { return sampler_.selected().mutation; }

private:
void Sample(Message* message) {
SampleImpl(message);
assert(mutation() != Mutation::None ||
!allowed_mutations_[static_cast<size_t>(Mutation::Mutate)] ||
message->GetDescriptor()->field_count() == 0);
}

private:
void SampleImpl(Message* message) {
const Descriptor* descriptor = message->GetDescriptor();
const Reflection* reflection = message->GetReflection();

Expand Down Expand Up @@ -256,9 +261,9 @@ class MutationSampler {
if (field->is_repeated()) {
const int field_size = reflection->FieldSize(*message, field);
for (int j = 0; j < field_size; ++j)
Sample(reflection->MutableRepeatedMessage(message, field, j));
SampleImpl(reflection->MutableRepeatedMessage(message, field, j));
} else if (reflection->HasField(*message, field)) {
Sample(reflection->MutableMessage(message, field));
SampleImpl(reflection->MutableMessage(message, field));
}
}
}
Expand Down Expand Up @@ -289,13 +294,13 @@ class MutationSampler {
class DataSourceSampler {
public:
DataSourceSampler(const ConstFieldInstance& match, RandomEngine* random,
int size_increase_hint, const Message& message)
int size_increase_hint)
: match_(match),
random_(random),
size_increase_hint_(size_increase_hint),
sampler_(random) {
Sample(message);
}
sampler_(random) {}

void Sample(const Message& message) { SampleImpl(message); }

// Returns selected field.
const ConstFieldInstance& field() const {
Expand All @@ -306,7 +311,7 @@ class DataSourceSampler {
bool IsEmpty() const { return sampler_.IsEmpty(); }

private:
void Sample(const Message& message) {
void SampleImpl(const Message& message) {
const Descriptor* descriptor = message.GetDescriptor();
const Reflection* reflection = message.GetReflection();

Expand All @@ -317,10 +322,10 @@ class DataSourceSampler {
if (field->is_repeated()) {
const int field_size = reflection->FieldSize(message, field);
for (int j = 0; j < field_size; ++j) {
Sample(reflection->GetRepeatedMessage(message, field, j));
SampleImpl(reflection->GetRepeatedMessage(message, field, j));
}
} else if (reflection->HasField(message, field)) {
Sample(reflection->GetMessage(message, field));
SampleImpl(reflection->GetMessage(message, field));
}
}

Expand Down Expand Up @@ -357,8 +362,8 @@ class DataSourceSampler {

class PostProcessing {
public:
using PostProcessors = std::unordered_multimap<const protobuf::Descriptor*,
Mutator::PostProcess>;
using PostProcessors =
std::unordered_multimap<const Descriptor*, Mutator::PostProcess>;

PostProcessing(bool keep_initialized, const PostProcessors& post_processors,
RandomEngine* random)
Expand Down Expand Up @@ -421,12 +426,12 @@ class PostProcessing {
class FieldMutator {
public:
FieldMutator(int size_increase_hint, bool enforce_changes,
bool enforce_utf8_strings, const Message& source,
bool enforce_utf8_strings, const ConstMessages& sources,
Mutator* mutator)
: size_increase_hint_(size_increase_hint),
enforce_changes_(enforce_changes),
enforce_utf8_strings_(enforce_utf8_strings),
source_(source),
sources_(sources),
mutator_(mutator) {}

void Mutate(int32_t* value) const {
Expand Down Expand Up @@ -478,7 +483,8 @@ class FieldMutator {
assert(*message);
if (GetRandomBool(mutator_->random(), mutator_->random_to_default_ratio_))
return;
mutator_->MutateImpl(source_, message->get(), false, size_increase_hint_);
mutator_->MutateImpl(sources_, {message->get()}, false,
size_increase_hint_);
}

private:
Expand All @@ -498,7 +504,7 @@ class FieldMutator {
int size_increase_hint_;
size_t enforce_changes_;
bool enforce_utf8_strings_;
const Message& source_;
const ConstMessages& sources_;
Mutator* mutator_;
};

Expand All @@ -507,10 +513,11 @@ namespace {
struct MutateField : public FieldFunction<MutateField> {
template <class T>
void ForType(const FieldInstance& field, int size_increase_hint,
const Message& source, Mutator* mutator) const {
const ConstMessages& sources, Mutator* mutator) const {
T value;
field.Load(&value);
FieldMutator(size_increase_hint, true, field.EnforceUtf8(), source, mutator)
FieldMutator(size_increase_hint, true, field.EnforceUtf8(), sources,
mutator)
.Mutate(&value);
field.Store(value);
}
Expand All @@ -520,12 +527,12 @@ struct CreateField : public FieldFunction<CreateField> {
public:
template <class T>
void ForType(const FieldInstance& field, int size_increase_hint,
const Message& source, Mutator* mutator) const {
const ConstMessages& sources, Mutator* mutator) const {
T value;
field.GetDefault(&value);
FieldMutator field_mutator(size_increase_hint,
false /* defaults could be useful */,
field.EnforceUtf8(), source, mutator);
field.EnforceUtf8(), sources, mutator);
field_mutator.Mutate(&value);
field.Create(value);
}
Expand All @@ -536,7 +543,10 @@ struct CreateField : public FieldFunction<CreateField> {
void Mutator::Seed(uint32_t value) { random_.seed(value); }

void Mutator::Mutate(Message* message, size_t max_size_hint) {
MutateImpl(*message, message, false,
Messages messages;
messages.push_back(message);
ConstMessages sources(messages.begin(), messages.end());
MutateImpl(sources, messages, false,
static_cast<int>(max_size_hint) -
static_cast<int>(message->ByteSizeLong()));

Expand All @@ -547,10 +557,13 @@ void Mutator::Mutate(Message* message, size_t max_size_hint) {

void Mutator::CrossOver(const Message& message1, Message* message2,
size_t max_size_hint) {
Messages messages;
messages.push_back(message2);
ConstMessages sources;
sources.push_back(&message1);
int size_increase_hint = static_cast<int>(max_size_hint) -
static_cast<int>(message2->ByteSizeLong());
MutateImpl(message1, message2, true, size_increase_hint) ||
MutateImpl(*message2, message2, true, size_increase_hint);
MutateImpl(sources, messages, true, size_increase_hint);

PostProcessing(keep_initialized_, post_processors_, &random_)
.Run(message2, kMaxInitializeDepth);
Expand All @@ -562,7 +575,7 @@ void Mutator::RegisterPostProcessor(const Descriptor* desc,
post_processors_.emplace(desc, callback);
}

bool Mutator::MutateImpl(const Message& source, Message* message,
bool Mutator::MutateImpl(const ConstMessages& sources, const Messages& messages,
bool copy_clone_only, int size_increase_hint) {
if (size_increase_hint > 0) size_increase_hint /= 2;
MutationBitset mutations;
Expand All @@ -575,32 +588,36 @@ bool Mutator::MutateImpl(const Message& source, Message* message,
mutations.set();
}
while (mutations.any()) {
MutationSampler mutation(keep_initialized_, mutations, &random_, message);
MutationSampler mutation(keep_initialized_, mutations, &random_);
for (Message* message : messages) mutation.Sample(message);

// Don't try same mutation next time.
mutations[static_cast<size_t>(mutation.mutation())] = false;
switch (mutation.mutation()) {
case Mutation::None:
return true;
case Mutation::Add:
CreateField()(mutation.field(), size_increase_hint, source, this);
CreateField()(mutation.field(), size_increase_hint, sources, this);
return true;
case Mutation::Mutate:
MutateField()(mutation.field(), size_increase_hint, source, this);
MutateField()(mutation.field(), size_increase_hint, sources, this);
return true;
case Mutation::Delete:
DeleteField()(mutation.field());
return true;
case Mutation::Clone: {
CreateDefaultField()(mutation.field());
DataSourceSampler source_sampler(mutation.field(), &random_,
size_increase_hint, source);
size_increase_hint);
for (const Message* source : sources) source_sampler.Sample(*source);
if (source_sampler.IsEmpty()) return true; // CreateField is enough.
CopyField()(source_sampler.field(), mutation.field());
return true;
}
case Mutation::Copy: {
DataSourceSampler source_sampler(mutation.field(), &random_,
size_increase_hint, source);
size_increase_hint);
for (const Message* source : sources) source_sampler.Sample(*source);
if (source_sampler.IsEmpty()) break;
CopyField()(source_sampler.field(), mutation.field());
return true;
Expand Down
3 changes: 2 additions & 1 deletion src/mutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ class Mutator {
private:
friend class FieldMutator;
friend class TestMutator;
bool MutateImpl(const protobuf::Message& source, protobuf::Message* message,
bool MutateImpl(const std::vector<const protobuf::Message*>& sources,
const std::vector<protobuf::Message*>& messages,
bool copy_clone_only, int size_increase_hint);
std::string MutateUtf8String(const std::string& value,
int size_increase_hint);
Expand Down

0 comments on commit f62086c

Please sign in to comment.