Skip to content

Commit f07b84e

Browse files
Remove dependence of dorado_lib on dorado_models_lib
1 parent cd6d2bf commit f07b84e

File tree

4 files changed

+43
-14
lines changed

4 files changed

+43
-14
lines changed

dorado/cli/basecaller.cpp

+19-7
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ void setup(std::vector<std::string> args,
6363
const std::string& dump_stats_filter,
6464
const std::string& resume_from_file,
6565
argparse::ArgumentParser& resume_parser) {
66+
auto model_config = dorado::load_crf_model_config(model_path);
67+
std::string model_name = std::filesystem::canonical(model_path).filename().string();
68+
6669
torch::set_num_threads(1);
6770

6871
if (!DataLoader::is_read_data_present(data_path, recursive_file_loading)) {
@@ -72,7 +75,11 @@ void setup(std::vector<std::string> args,
7275

7376
// Check sample rate of model vs data.
7477
auto data_sample_rate = DataLoader::get_sample_rate(data_path, recursive_file_loading);
75-
auto model_sample_rate = get_model_sample_rate(model_path);
78+
auto model_sample_rate = model_config.sample_rate;
79+
if (model_sample_rate < 0) {
80+
// If the model config does not provide a sample rate (negative value), look it up by model name.
81+
model_sample_rate = utils::get_sample_rate_by_model_name(model_name);
82+
}
7683
if (!skip_model_compatibility_check &&
7784
!sample_rates_compatible(data_sample_rate, model_sample_rate)) {
7885
std::stringstream err;
@@ -93,11 +100,9 @@ void setup(std::vector<std::string> args,
93100
throw std::runtime_error("Modified base models cannot be used with FASTQ output");
94101
}
95102

96-
auto model_config = dorado::load_crf_model_config(model_path);
97103
auto [runners, num_devices] =
98104
create_basecall_runners(model_config, device, num_runners, 0, batch_size, chunk_size);
99105

100-
std::string model_name = std::filesystem::canonical(model_path).filename().string();
101106
auto read_groups = DataLoader::load_read_groups(data_path, model_name, recursive_file_loading);
102107
auto read_list = utils::load_read_list(read_list_file_path);
103108

@@ -135,10 +140,17 @@ void setup(std::vector<std::string> args,
135140
{read_converter}, min_qscore, default_parameters.min_sequence_length,
136141
std::unordered_set<std::string>{}, thread_allocations.read_filter_threads);
137142

138-
pipelines::create_simplex_pipeline(pipeline_desc, std::move(runners), std::move(remora_runners),
139-
overlap, thread_allocations.scaler_node_threads,
140-
thread_allocations.remora_threads * num_devices,
141-
read_filter_node);
143+
auto mean_qscore_start_pos = model_config.mean_qscore_start_pos;
144+
if (mean_qscore_start_pos < 0) {
145+
mean_qscore_start_pos = utils::get_mean_qscore_start_pos_by_model_name(model_name);
146+
if (mean_qscore_start_pos < 0) {
147+
throw std::runtime_error("Mean q-score start position cannot be < 0");
148+
}
149+
}
150+
pipelines::create_simplex_pipeline(
151+
pipeline_desc, std::move(runners), std::move(remora_runners), overlap,
152+
mean_qscore_start_pos, thread_allocations.scaler_node_threads,
153+
thread_allocations.remora_threads * num_devices, read_filter_node);
142154

143155
// Create the Pipeline from our description.
144156
std::vector<dorado::stats::StatsReporter> stats_reporters;

dorado/cli/duplex.cpp

+16-2
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,12 @@ int duplex(int argc, char* argv[]) {
243243

244244
// Check sample rate of model vs data.
245245
auto data_sample_rate = DataLoader::get_sample_rate(reads, recursive_file_loading);
246-
auto model_sample_rate = get_model_sample_rate(model_path);
246+
auto model_sample_rate = model_config.sample_rate;
247+
if (model_sample_rate < 0) {
248+
// If the model config does not provide a sample rate (negative value), look it up by model name.
249+
model_sample_rate = utils::get_sample_rate_by_model_name(
250+
model_config.model_path.filename().string());
251+
}
247252
auto skip_model_compatibility_check =
248253
internal_parser.get<bool>("--skip-model-compatibility-check");
249254
if (!skip_model_compatibility_check &&
@@ -312,9 +317,18 @@ int duplex(int argc, char* argv[]) {
312317
pairing_parameters = std::move(template_complement_map);
313318
}
314319

320+
auto mean_qscore_start_pos = model_config.mean_qscore_start_pos;
321+
if (mean_qscore_start_pos < 0) {
322+
mean_qscore_start_pos =
323+
utils::get_mean_qscore_start_pos_by_model_name(stereo_model_name);
324+
if (mean_qscore_start_pos < 0) {
325+
throw std::runtime_error("Mean q-score start position cannot be < 0");
326+
}
327+
}
315328
pipelines::create_stereo_duplex_pipeline(
316329
pipeline_desc, std::move(runners), std::move(stereo_runners), overlap,
317-
num_devices * 2, num_devices, std::move(pairing_parameters), read_filter_node);
330+
mean_qscore_start_pos, num_devices * 2, num_devices,
331+
std::move(pairing_parameters), read_filter_node);
318332

319333
std::vector<dorado::stats::StatsReporter> stats_reporters;
320334
pipeline = Pipeline::create(std::move(pipeline_desc), &stats_reporters);

dorado/read_pipeline/Pipelines.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,
1818
std::vector<dorado::Runner>&& runners,
1919
std::vector<std::unique_ptr<dorado::ModBaseRunner>>&& modbase_runners,
2020
size_t overlap,
21+
uint32_t mean_qscore_start_pos,
2122
int scaler_node_threads,
2223
int modbase_node_threads,
2324
NodeHandle sink_node_handle,
@@ -43,7 +44,7 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,
4344

4445
auto basecaller_node = pipeline_desc.add_node<BasecallerNode>(
4546
{}, std::move(runners), overlap, kBatchTimeoutMS, model_name, 1000, "BasecallerNode",
46-
false, get_model_mean_qscore_start_pos(model_config));
47+
false, mean_qscore_start_pos);
4748

4849
NodeHandle last_node_handle = PipelineDescriptor::InvalidNodeHandle;
4950
if (mod_base_caller_node != PipelineDescriptor::InvalidNodeHandle) {
@@ -71,6 +72,7 @@ void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc,
7172
std::vector<dorado::Runner>&& runners,
7273
std::vector<dorado::Runner>&& stereo_runners,
7374
size_t overlap,
75+
uint32_t mean_qscore_start_pos,
7476
int scaler_node_threads,
7577
int splitter_node_threads,
7678
PairingParameters pairing_parameters,
@@ -89,8 +91,7 @@ void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc,
8991

9092
auto stereo_basecaller_node = pipeline_desc.add_node<BasecallerNode>(
9193
{}, std::move(stereo_runners), adjusted_stereo_overlap, kStereoBatchTimeoutMS,
92-
duplex_rg_name, 1000, "StereoBasecallerNode", true,
93-
get_model_mean_qscore_start_pos(stereo_model_config));
94+
duplex_rg_name, 1000, "StereoBasecallerNode", true, mean_qscore_start_pos);
9495

9596
auto simplex_model_stride = runners.front()->model_stride();
9697
auto stereo_node = pipeline_desc.add_node<StereoDuplexEncoderNode>({stereo_basecaller_node},
@@ -118,8 +119,7 @@ void create_stereo_duplex_pipeline(PipelineDescriptor& pipeline_desc,
118119
const int kSimplexBatchTimeoutMS = 100;
119120
auto basecaller_node = pipeline_desc.add_node<BasecallerNode>(
120121
{splitter_node}, std::move(runners), adjusted_simplex_overlap, kSimplexBatchTimeoutMS,
121-
model_name, 1000, "BasecallerNode", true,
122-
get_model_mean_qscore_start_pos(model_config));
122+
model_name, 1000, "BasecallerNode", true, mean_qscore_start_pos);
123123

124124
auto scaler_node = pipeline_desc.add_node<ScalerNode>(
125125
{basecaller_node}, model_config.signal_norm_params, scaler_node_threads);

dorado/read_pipeline/Pipelines.h

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "ReadPipeline.h"
44

5+
#include <cstdint>
56
#include <map>
67
#include <memory>
78
#include <string>
@@ -25,6 +26,7 @@ void create_simplex_pipeline(PipelineDescriptor& pipeline_desc,
2526
std::vector<dorado::Runner>&& runners,
2627
std::vector<std::unique_ptr<dorado::ModBaseRunner>>&& modbase_runners,
2728
size_t overlap,
29+
uint32_t mean_qscore_start_pos,
2830
int scaler_node_threads,
2931
int modbase_threads,
3032
NodeHandle sink_node_handle = PipelineDescriptor::InvalidNodeHandle,
@@ -38,6 +40,7 @@ void create_stereo_duplex_pipeline(
3840
std::vector<dorado::Runner>&& runners,
3941
std::vector<dorado::Runner>&& stereo_runners,
4042
size_t overlap,
43+
uint32_t mean_qscore_start_pos,
4144
int scaler_node_threads,
4245
int splitter_node_threads,
4346
PairingParameters pairing_parameters,

0 commit comments

Comments
 (0)