Skip to content

Commit

Permalink
Bug in driver argument parser. Fix Levenshtein_distance. (ROCm#2020)
Browse files Browse the repository at this point in the history
  • Loading branch information
lakhinderwalia authored Aug 4, 2023
1 parent e46a6a5 commit aeb9f78
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
17 changes: 14 additions & 3 deletions src/driver/argument_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,19 @@ struct argument_parser
if(params.empty())
throw std::runtime_error("No argument passed.");
if(not fs::exists(params.back()))
throw std::runtime_error("Path does not exists: " + params.back());
throw std::runtime_error("Path does not exist: " + params.back());
});
}

MIGRAPHX_DRIVER_STATIC auto matches(const std::unordered_set<std::string>& names)
{
return validate([=](auto&, auto&, auto& params) {
for(const auto& p : params)
{
if(names.count(p) == 0)
throw std::runtime_error("Invalid argument: " + p + ". Valid arguments are {" +
to_string_range(names) + "}");
}
});
}

Expand Down Expand Up @@ -570,8 +582,7 @@ struct argument_parser
continue;
if(flag[0] != '-')
continue;
auto d =
levenshtein_distance(flag.begin(), flag.end(), input.begin(), input.end());
std::ptrdiff_t d = levenshtein_distance(flag, input);
if(d < result.distance)
result = result_t{&arg, flag, input, d};
}
Expand Down
3 changes: 2 additions & 1 deletion src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ struct loader
{"--model"},
ap.help("Load model"),
ap.type("resnet50|inceptionv3|alexnet"),
ap.matches({"resnet50", "inceptionv3", "alexnet"}),
ap.group("input"));
ap(file_type, {"--onnx"}, ap.help("Load as onnx"), ap.set_value("onnx"));
ap(file_type, {"--tf"}, ap.help("Load as tensorflow"), ap.set_value("tf"));
Expand Down Expand Up @@ -769,7 +770,7 @@ struct main_command
{
std::cout << "'" << color::fg_yellow << wrong_commands.front() << color::reset
<< "' is not a valid command." << std::endl;
std::cout << get_command_help("Available commands:") << std::endl;
std::cout << get_command_help("Available commands:");
}
else
{
Expand Down
37 changes: 37 additions & 0 deletions src/include/migraphx/algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,43 @@ levenshtein_distance(Iterator1 first1, Iterator1 last1, Iterator2 first2, Iterat
return std::ptrdiff_t{1} + std::min({x1, x2, x3});
}

inline size_t levenshtein_distance(const std::string& s1, const std::string& s2)
{
const size_t l1 = s1.length();
const size_t l2 = s2.length();

if(l1 < l2)
levenshtein_distance(s2, s1);

std::vector<size_t> d(l2 + 1);

for(size_t j = 1; j <= l2; j++)
d[j] = j;

for(size_t i = 1; i <= l1; i++)
{
size_t prev_cost = d[0];
d[0] = i;

for(size_t j = 1; j <= l2; j++)
{
if(s1[i - 1] == s2[j - 1])
{
d[j] = prev_cost;
}
else
{
size_t cost_insert_or_delete = std::min(d[j - 1], d[j]);
size_t cost_substitute = prev_cost;
prev_cost = d[j];
d[j] = std::min(cost_substitute, cost_insert_or_delete) + 1;
}
}
}

return d[l2];
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

Expand Down

0 comments on commit aeb9f78

Please sign in to comment.