Add default_batch_size to IterationData. (#5588)
Add IterationData::default_batch_size.
Simplify batch size handling in ExecutorImpl.

Signed-off-by: Michał Zientkiewicz <[email protected]>
mzient authored Jul 30, 2024
1 parent 552ffe2 commit c9591d1
Showing 3 changed files with 21 additions and 24 deletions.
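The gist of the change, as a minimal sketch: operators that cannot infer the batch size from their inputs now fall back to a per-iteration default stored in IterationData instead of a per-stage argument. The Workspace/IterationData stand-ins below are simplified assumptions for illustration only; just the member and function names mirror the diff.

```cpp
#include <cstdint>
#include <memory>

// Simplified stand-ins for the DALI types touched by this commit; the real
// Workspace and IterationData have many more members.
struct IterationData {
  std::int64_t iteration_index = 0;
  // Default batch size for the current iteration, used when an operator has
  // no inputs to infer the batch size from.
  int default_batch_size = 0;
};

struct Workspace {
  int num_inputs = 0;
  int first_input_batch_size = 0;
  std::shared_ptr<IterationData> iter_data;

  int NumInput() const { return num_inputs; }
  int GetInputBatchSize(int) const { return first_input_batch_size; }
  const std::shared_ptr<IterationData> &GetIterationData() const { return iter_data; }
};

// After this commit: no per-stage batch size argument - an operator with no
// inputs reads the default from the shared per-iteration data instead.
// (The argument-input branch of the real function is omitted for brevity.)
inline int InferBatchSizeFromInput(const Workspace &ws) {
  if (ws.NumInput() > 0)
    return ws.GetInputBatchSize(0);
  return ws.GetIterationData()->default_batch_size;
}
```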
33 changes: 10 additions & 23 deletions dali/pipeline/executor/executor_impl.cc
@@ -95,14 +95,14 @@ void ClearOutputs(Workspace &ws, const OpSpec &spec) {
* size between all inputs and outputs. The notable exception
* of split and merge operators cannot rely on this value.
*/
inline int InferBatchSizeFromInput(const Workspace &ws, int stage_batch_size) {
inline int InferBatchSizeFromInput(const Workspace &ws) {
if (ws.NumInput() > 0) {
return ws.GetInputBatchSize(0);
}
if (ws.NumArgumentInput() > 0) {
return ws.ArgumentInput(0).num_samples();
}
return stage_batch_size;
return ws.GetIterationData()->default_batch_size;
}

template <typename WorkspacePolicy, typename QueuePolicy>
@@ -124,9 +124,7 @@ void Executor<WorkspacePolicy, QueuePolicy>::HandleError(const std::string &cont
template <typename WorkspacePolicy, typename QueuePolicy>
void Executor<WorkspacePolicy, QueuePolicy>::PreRun() {
auto batch_size = InferBatchSize(batch_size_providers_);
batch_sizes_cpu_.push(batch_size);
batch_sizes_mixed_.push(batch_size);
batch_sizes_gpu_.push(batch_size);
upcoming_batch_sizes_.push(batch_size);
}

template <typename WorkspacePolicy, typename QueuePolicy>
@@ -157,7 +155,8 @@ void Executor<WorkspacePolicy, QueuePolicy>::SyncDevice() {

template <typename WorkspacePolicy, typename QueuePolicy>
void Executor<WorkspacePolicy, QueuePolicy>::RunCPUImpl(size_t iteration_id) {
GetCurrentIterationData(iteration_id)->iteration_index = iteration_id;
auto iter_data = GetCurrentIterationData(iteration_id);
iter_data->iteration_index = iteration_id;
PreRun();
const char placement_error[] =
"Cannot run a pipeline with Mixed/GPU ops in CPU-only mode. Please provide "
@@ -189,18 +188,16 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunCPUImpl(size_t iteration_id) {
return;
}

int stage_batch_size = batch_sizes_cpu_.front();
batch_sizes_cpu_.pop();
int upcoming_batch_size = upcoming_batch_sizes_.front();
upcoming_batch_sizes_.pop();
iter_data->default_batch_size = upcoming_batch_size;

// Run the cpu-ops in the thread
// Process each CPU Op in batch
for (int cpu_op_id = 0; cpu_op_id < graph_->NumOp(OpType::CPU) && !exec_error_; ++cpu_op_id) {
OpNode &op_node = graph_->Node(OpType::CPU, cpu_op_id);
decltype(auto) ws = ws_policy_.template GetWorkspace<OpType::CPU>(cpu_idxs, *graph_, cpu_op_id);

int batch_size = InferBatchSizeFromInput(ws, stage_batch_size);
ws.SetBatchSizes(batch_size);

DomainTimeRange tr("[DALI][CPU op] " + op_node.instance_name, DomainTimeRange::kBlue1);

try {
@@ -234,17 +231,11 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunMixedImpl(size_t iteration_id) {
if (device_id_ != CPU_ONLY_DEVICE_ID)
CUDA_CALL(cudaEventSynchronize(mixed_stage_event_));

int stage_batch_size = batch_sizes_mixed_.front();
batch_sizes_mixed_.pop();

for (int i = 0; i < graph_->NumOp(OpType::MIXED) && !exec_error_; ++i) {
OpNode &op_node = graph_->Node(OpType::MIXED, i);
try {
decltype(auto) ws = ws_policy_.template GetWorkspace<OpType::MIXED>(mixed_idxs, *graph_, i);

int batch_size = InferBatchSizeFromInput(ws, stage_batch_size);
ws.SetBatchSizes(batch_size);

DomainTimeRange tr("[DALI][Mixed op] " + op_node.instance_name, DomainTimeRange::kOrange);
RunHelper(op_node, ws, iteration_id);
FillStats(mixed_memory_stats_, ws, "MIXED_" + op_node.instance_name,
@@ -298,17 +289,11 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunGPUImpl(size_t iteration_id) {
// iterations of a stage of the pipeline.
CUDA_CALL(cudaEventSynchronize(gpu_stage_event_));

int stage_batch_size = batch_sizes_gpu_.front();
batch_sizes_gpu_.pop();

for (int i = 0; i < graph_->NumOp(OpType::GPU) && !exec_error_; ++i) {
OpNode &op_node = graph_->Node(OpType::GPU, i);
try {
decltype(auto) ws = ws_policy_.template GetWorkspace<OpType::GPU>(gpu_idxs, *graph_, i);

int batch_size = InferBatchSizeFromInput(ws, stage_batch_size);
ws.SetBatchSizes(batch_size);

auto parent_events = ws.ParentEvents();

for (auto &event : parent_events) {
@@ -419,6 +404,8 @@ void Executor<WorkspacePolicy, QueuePolicy>::RunHelper(OpNode &op_node, Workspac
SmallVector<int, 16> empty_layout_in_idxs;

ws.InjectIterationData(GetCurrentIterationData(iteration_id));
int batch_size = InferBatchSizeFromInput(ws);
ws.SetBatchSizes(batch_size);
ws.ClearOperatorTraces();

auto ws_order = ws.has_stream() ? AccessOrder(ws.stream()) : AccessOrder::host();
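A sketch of the queue simplification above, under assumptions (the ToyExecutor below is not the real Executor template; only the member names mirror the diff): PreRun pushes the inferred batch size once instead of once per stage, and the CPU stage pops it and publishes it through IterationData, so the mixed and GPU stages no longer need queues of their own.

```cpp
#include <cstdint>
#include <memory>
#include <queue>

struct IterationData {
  std::int64_t iteration_index = 0;
  int default_batch_size = 0;
};

// Toy executor skeleton showing the single upcoming_batch_sizes_ queue that
// replaces batch_sizes_cpu_, batch_sizes_mixed_ and batch_sizes_gpu_.
class ToyExecutor {
 public:
  void PreRun(int inferred_batch_size) {
    // One push per iteration instead of one per stage.
    upcoming_batch_sizes_.push(inferred_batch_size);
  }

  // Called once per iteration, after PreRun.
  std::shared_ptr<IterationData> RunCPUImpl(std::int64_t iteration_id) {
    auto iter_data = std::make_shared<IterationData>();
    iter_data->iteration_index = iteration_id;

    // The CPU stage consumes the queued value and records it in the shared
    // per-iteration data, where the other stages and RunHelper can read it.
    iter_data->default_batch_size = upcoming_batch_sizes_.front();
    upcoming_batch_sizes_.pop();
    return iter_data;
  }

 private:
  std::queue<int> upcoming_batch_sizes_;
};
```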
2 changes: 1 addition & 1 deletion dali/pipeline/executor/executor_impl.h
@@ -315,7 +315,7 @@ class DLL_PUBLIC Executor : public ExecutorBase, public QueuePolicy {

StageQueues stage_queue_depths_;

std::queue<int> batch_sizes_cpu_, batch_sizes_mixed_, batch_sizes_gpu_;
std::queue<int> upcoming_batch_sizes_;

OpGraph *graph_ = nullptr;
EventPool event_pool_;
10 changes: 10 additions & 0 deletions dali/pipeline/workspace/iteration_data.h
@@ -74,7 +74,17 @@ class OperatorTraces {
* a single iteration.
*/
struct IterationData {
/** The index of the current iteration. */
int64_t iteration_index = 0;

/** Default batch size for the current iteration.
*
* Presently this is the batch size set by external sources or the maximum batch size,
* if no external source is present.
* Actual batch size may change, e.g. due to conditional execution.
*/
int default_batch_size = 0;

OperatorTraces operator_traces;
std::shared_ptr<Checkpoint> checkpoint;
};
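A small sketch of how the default described in the new comment might be resolved, under assumptions: ResolveDefaultBatchSize is a hypothetical helper, not DALI code; it only mirrors the two cases the comment names (a batch size coming from external sources, or the maximum batch size when none is present).

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper mirroring the new comment: the per-iteration default is
// the batch size reported by external sources, or the maximum batch size when
// no external source is present. Names and signature are illustrative only.
int ResolveDefaultBatchSize(const std::vector<int> &external_source_batch_sizes,
                            int max_batch_size) {
  if (external_source_batch_sizes.empty())
    return max_batch_size;
  // Assumption for the sketch: all external sources agree on one value, so
  // taking the smallest reported size is enough for illustration.
  return *std::min_element(external_source_batch_sizes.begin(),
                           external_source_batch_sizes.end());
}
```

The actual per-operator batch size can still differ from this default, e.g. when conditional execution splits the batch, which is why operators first try to infer the size from their inputs.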
