Skip to content

Commit

Permalink
Moved csv to dev dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Sep 6, 2020
1 parent 6e8f79f commit 5a833c7
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 30 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ serde = { version = "1.0.114", features = ["derive"] }
dirs = "3.0.1"
itertools = "0.9.0"
ordered-float = "2.0.0"
csv = "1.1.3"
reqwest = "0.10.7"
lazy_static = "1.4.0"
uuid = { version = "0.8.1", features = ["v4"] }
Expand All @@ -46,3 +45,4 @@ thiserror = "1.0.20"

[dev-dependencies]
anyhow = "1.0.32"
csv = "1.1.3"
25 changes: 23 additions & 2 deletions examples/sst2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,31 @@
extern crate anyhow;
extern crate dirs;

use rust_bert::pipelines::sentiment::{ss2_processor, SentimentModel};
use std::env;
use rust_bert::pipelines::sentiment::SentimentModel;
use serde::Deserialize;
use std::error::Error;
use std::path::PathBuf;
use std::{env, fs};

#[derive(Debug, Deserialize)]
struct Record {
sentence: String,
label: i8,
}

fn ss2_processor(file_path: PathBuf) -> Result<Vec<String>, Box<dyn Error>> {
let file = fs::File::open(file_path).expect("unable to open file");
let mut csv = csv::ReaderBuilder::new()
.has_headers(true)
.delimiter(b'\t')
.from_reader(file);
let mut records = Vec::new();
for result in csv.deserialize() {
let record: Record = result?;
records.push(record.sentence);
}
Ok(records)
}
fn main() -> anyhow::Result<()> {
// Set-up classifier
let sentiment_classifier = SentimentModel::new(Default::default())?;
Expand Down
24 changes: 0 additions & 24 deletions src/pipelines/sentiment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ use crate::common::error::RustBertError;
use crate::pipelines::sequence_classification::{
SequenceClassificationConfig, SequenceClassificationModel,
};
use serde::Deserialize;
use std::error::Error;
use std::fs;
use std::path::PathBuf;

#[derive(Debug, PartialEq)]
/// Enum with the possible sentiment polarities. Note that the pre-trained SST2 model does not include neutral sentiment.
Expand Down Expand Up @@ -154,23 +150,3 @@ impl SentimentModel {
sentiments
}
}

#[derive(Debug, Deserialize)]
struct Record {
sentence: String,
label: i8,
}

pub fn ss2_processor(file_path: PathBuf) -> Result<Vec<String>, Box<dyn Error>> {
let file = fs::File::open(file_path).expect("unable to open file");
let mut csv = csv::ReaderBuilder::new()
.has_headers(true)
.delimiter(b'\t')
.from_reader(file);
let mut records = Vec::new();
for result in csv.deserialize() {
let record: Record = result?;
records.push(record.sentence);
}
Ok(records)
}
16 changes: 13 additions & 3 deletions tests/bart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use rust_bert::bart::{
BartVocabResources,
};
use rust_bert::pipelines::summarization::{SummarizationConfig, SummarizationModel};
use rust_bert::pipelines::zero_shot_classification::ZeroShotClassificationModel;
use rust_bert::pipelines::zero_shot_classification::{
ZeroShotClassificationConfig, ZeroShotClassificationModel,
};
use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::{RobertaTokenizer, Tokenizer, TruncationStrategy};
Expand Down Expand Up @@ -162,7 +164,11 @@ about exoplanets like K2-18b."];
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_zero_shot_classification() -> anyhow::Result<()> {
// Set-up model model
let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
let zero_shot_config = ZeroShotClassificationConfig {
device: Device::Cpu,
..Default::default()
};
let sequence_classification_model = ZeroShotClassificationModel::new(zero_shot_config)?;

let input_sentence = "Who are you voting for in 2020?";
let input_sequence_2 = "The prime minister has announced a stimulus package.";
Expand Down Expand Up @@ -191,7 +197,11 @@ fn bart_zero_shot_classification() -> anyhow::Result<()> {
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn bart_zero_shot_classification_multilabel() -> anyhow::Result<()> {
// Set-up model model
let sequence_classification_model = ZeroShotClassificationModel::new(Default::default())?;
let zero_shot_config = ZeroShotClassificationConfig {
device: Device::Cpu,
..Default::default()
};
let sequence_classification_model = ZeroShotClassificationModel::new(zero_shot_config)?;

let input_sentence = "Who are you voting for in 2020?";
let input_sequence_2 = "The prime minister has announced a stimulus package which was widely criticized by the opposition.";
Expand Down

0 comments on commit 5a833c7

Please sign in to comment.