diff --git a/predict.py b/predict.py index 2f096a9..96d0087 100644 --- a/predict.py +++ b/predict.py @@ -93,7 +93,7 @@ def predict(self, texts): is_multi = config.task_info.label_type == ClassificationType.MULTI_LABEL for line in codecs.open(sys.argv[2], "r", predictor.dataset.CHARSET): input_texts.append(line.strip("\n")) - epoches = math.ceil(int(len(input_texts)/batch_size)) + epoches = math.ceil(len(input_texts)/batch_size) for i in range(epoches): batch_texts = input_texts[i*batch_size:(i+1)*batch_size] predict_prob = predictor.predict(batch_texts)