# ------------
#
# To start, download the data ZIP file
- # `here <https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html>`__
+ # `here <https://zissou.infosci.cornell.edu/convokit/datasets/movie-corpus/movie-corpus.zip>`__
+
# and put it in a ``data/`` directory under the current directory.
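As a convenience for readers following along, here is a minimal sketch of fetching and unpacking the archive from the URL added above; the helper is illustrative and not part of the tutorial's code:

    # Illustrative download helper (not part of the diff); uses the URL added above.
    import os
    import urllib.request
    import zipfile

    url = "https://zissou.infosci.cornell.edu/convokit/datasets/movie-corpus/movie-corpus.zip"
    os.makedirs("data", exist_ok=True)
    zip_path = os.path.join("data", "movie-corpus.zip")
    urllib.request.urlretrieve(url, zip_path)  # fetch the ZIP
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall("data")  # yields data/movie-corpus/utterances.jsonl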
#
# After that, let’s import some necessities.
from io import open
import itertools
import math
+ import json


USE_CUDA = torch.cuda.is_available()
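Review note: later code presumably consumes this flag through the standard device-selection idiom; a one-line sketch (the name ``device`` is an assumption here, not necessarily the tutorial's):

    # Standard PyTorch pattern for routing tensors to GPU when available.
    device = torch.device("cuda" if USE_CUDA else "cpu")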
# original format.
#
- corpus_name = "cornell movie-dialogs corpus"
+ corpus_name = "movie-corpus"
corpus = os.path.join("data", corpus_name)


def printLines(file, n=10):
@@ -149,7 +151,7 @@ def printLines(file, n=10):
    for line in lines[:n]:
        print(line)

- printLines(os.path.join(corpus, "movie_lines.txt"))
+ printLines(os.path.join(corpus, "utterances.jsonl"))
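For orientation: each line of ``utterances.jsonl`` is one standalone JSON object. Judging from the fields the new loader reads below, a printed line looks roughly like this (abridged sketch; the corpus may carry additional keys):

    {"id": "L1045", "conversation_id": "L1044", "speaker": "u0", "text": "They do not!", "meta": {"movie_id": "m0"}}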
######################################################################
@@ -160,55 +162,47 @@ def printLines(file, n=10):
# contains a tab-separated *query sentence* and a *response sentence* pair.
#
# The following functions facilitate the parsing of the raw
- # *movie_lines.txt* data file.
+ # *utterances.jsonl* data file.
#
- # - ``loadLines`` splits each line of the file into a dictionary of
- #   fields (lineID, characterID, movieID, character, text)
- # - ``loadConversations`` groups fields of lines from ``loadLines`` into
- #   conversations based on *movie_conversations.txt*
+ # - ``loadLinesAndConversations`` splits each line of the file into a dictionary
+ #   of lines with fields (lineID, characterID, text), then groups the lines
+ #   into conversations with fields (conversationID, movieID, lines)
# - ``extractSentencePairs`` extracts pairs of sentences from
#   conversations
#

- # Splits each line of the file into a dictionary of fields
- def loadLines(fileName, fields):
+ # Parses each line of the file into line and conversation dicts
+ def loadLinesAndConversations(fileName):
    lines = {}
+     conversations = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
-             values = line.split(" +++$+++ ")
-             # Extract fields
+             lineJson = json.loads(line)
+             # Extract fields for line object
            lineObj = {}
-             for i, field in enumerate(fields):
-                 lineObj[field] = values[i]
+             lineObj["lineID"] = lineJson["id"]
+             lineObj["characterID"] = lineJson["speaker"]
+             lineObj["text"] = lineJson["text"]
            lines[lineObj['lineID']] = lineObj
-     return lines

+             # Extract fields for conversation object
+             if lineJson["conversation_id"] not in conversations:
+                 convObj = {}
+                 convObj["conversationID"] = lineJson["conversation_id"]
+                 convObj["movieID"] = lineJson["meta"]["movie_id"]
+                 convObj["lines"] = [lineObj]
+             else:
+                 convObj = conversations[lineJson["conversation_id"]]
+                 convObj["lines"].insert(0, lineObj)
+             conversations[convObj["conversationID"]] = convObj

- # Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
- def loadConversations(fileName, lines, fields):
-     conversations = []
-     with open(fileName, 'r', encoding='iso-8859-1') as f:
-         for line in f:
-             values = line.split(" +++$+++ ")
-             # Extract fields
-             convObj = {}
-             for i, field in enumerate(fields):
-                 convObj[field] = values[i]
-             # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
-             utterance_id_pattern = re.compile('L[0-9]+')
-             lineIds = utterance_id_pattern.findall(convObj["utteranceIDs"])
-             # Reassemble lines
-             convObj["lines"] = []
-             for lineId in lineIds:
-                 convObj["lines"].append(lines[lineId])
-             conversations.append(convObj)
-     return conversations
+     return lines, conversations
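A quick usage sketch of the new loader, illustrative only: it returns two dicts, keyed by lineID and conversationID respectively, and each conversation carries its line objects in order:

    # Illustrative only: load the corpus and peek at one conversation.
    lines, conversations = loadLinesAndConversations(os.path.join(corpus, "utterances.jsonl"))
    conv = next(iter(conversations.values()))
    print(conv["conversationID"], conv["movieID"], len(conv["lines"]))
    print(conv["lines"][0]["text"])  # first utterance of this conversation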


# Extracts pairs of sentences from conversations
def extractSentencePairs(conversations):
    qa_pairs = []
-     for conversation in conversations:
+     for conversation in conversations.values():
        # Iterate over all the lines of the conversation
        for i in range(len(conversation["lines"]) - 1):  # We ignore the last line (no answer for it)
            inputLine = conversation["lines"][i]["text"].strip()
@@ -231,18 +225,12 @@ def extractSentencePairs(conversations):
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

- # Initialize lines dict, conversations list, and field ids
+ # Initialize lines dict and conversations dict
lines = {}
- conversations = []
- MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
- MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]
-
- # Load lines and process conversations
- print("\nProcessing corpus...")
- lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
- print("\nLoading conversations...")
- conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),
-                                   lines, MOVIE_CONVERSATIONS_FIELDS)
+ conversations = {}
+ # Load lines and conversations
+ print("\nProcessing corpus into lines and conversations...")
+ lines, conversations = loadLinesAndConversations(os.path.join(corpus, "utterances.jsonl"))

# Write new csv file
print("\nWriting newly formatted file...")
@@ -1341,7 +1329,7 @@ def evaluateInput(encoder, decoder, searcher, voc):
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cuda()
-
+
# Run training iterations
print("Starting Training!")
trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,