Skip to content

Commit

Permalink
Exclude weight related types/shapes from bert loss. (microsoft#4548)
Browse files Browse the repository at this point in the history
  • Loading branch information
codemzs authored Jul 18, 2020
1 parent b42fe49 commit 6c950a1
Showing 1 changed file with 0 additions and 2 deletions.
2 changes: 0 additions & 2 deletions orttraining/orttraining/models/bert/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -700,14 +700,12 @@ static Status RunPerformanceTest(const BertParameters& params, const Environment
{batch_size, params.max_sequence_length},
{batch_size, params.max_predictions_per_sequence},
{batch_size, params.max_predictions_per_sequence},
{batch_size, params.max_predictions_per_sequence},
{batch_size}};
std::vector<onnx::TensorProto_DataType> tensor_types = {onnx::TensorProto_DataType_INT64,
onnx::TensorProto_DataType_INT64,
onnx::TensorProto_DataType_INT64,
onnx::TensorProto_DataType_INT64,
onnx::TensorProto_DataType_INT64,
onnx::TensorProto_DataType_FLOAT,
onnx::TensorProto_DataType_INT64};
const size_t num_of_perf_samples = params.num_train_steps * params.batch_size;
auto random_perf_data = std::make_shared<RandomDataSet>(num_of_perf_samples, tensor_names, tensor_shapes, tensor_types);
Expand Down

0 comments on commit 6c950a1

Please sign in to comment.