Working with datasets: Env, FrameEnv, SeqEnv, UserDataset
Although this project is primarily built for the ML20M dataset, it comes with native support for any similar dataset that more or less looks like the latter. I named the main class environment (as in most RL libraries); it takes two string locations (embeddings and ratings) as its main arguments. In most of the snippets it is defined as:
# frame_size, batch_size and cuda (a torch.device) are assumed to be defined earlier
import torch.nn as nn
import recnn

# fixed-length frame
env = recnn.env.FrameEnv('../../data/embeddings/ml20_pca128.pkl',
                         '../../data/ml-20m/ratings.csv', frame_size, batch_size)
# OR
# dynamic-length sequential
state_encoder = nn.LSTM(129, 256, batch_first=True).to(cuda)
env = recnn.env.SeqEnv('../../data/embeddings/ml20_pca128.pkl',
                       '../../data/ml-20m/ratings.csv', batch_size, state_encoder, cuda)
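A minimal sketch of consuming the environment; train_dataloader and recnn.data.get_base_batch are the names used in the library's example notebooks, so verify them against your recnn version:

for batch in env.train_dataloader:
    state, action, reward, next_state, done = recnn.data.get_base_batch(batch)
    # these are torch tensors ready for an RL update; with FrameEnv, state is
    # roughly the flattened embeddings + ratings of the last frame_size items
    break  # peek at a single batch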
- Label encoding items (movies). In the original ml20m dataset, the movie IDs are not contiguous (the largest ID is far bigger than the number of movies), so they are label encoded into a dense range for efficient indexing and better memory usage (see the sketch after this list).
- Sorting and clustering. By default, items are sorted by timestamp; alternatively, you can provide a different argument to change the sort key (see the pandas sketch after the table below).
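Conceptually, the label encoding above is the same remapping scikit-learn's LabelEncoder performs. A rough illustration of the idea (not the library's internal code):

import pandas as pd

ratings = pd.read_csv('../../data/ml-20m/ratings.csv')
# map the sparse movieId space onto a dense 0..n_items-1 range
unique_items = ratings['movieId'].unique()
item_to_idx = {item: idx for idx, item in enumerate(unique_items)}
ratings['movieId'] = ratings['movieId'].map(item_to_idx)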
A quick reminder of what ml20m looks like:

   user  movie  rating   timestamp
0    55     19       2  1568310847
1    39      5       5  1568318423
2    88      8       2  1568323421
3    79     16       7  1568334534
4    98     19       5  1568343643
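For illustration, the default by-timestamp ordering within each user amounts to roughly the following, using the column names from the table above (a pandas sketch, not the env's internal code; the file path is a placeholder):

import pandas as pd

df = pd.read_csv('your dataset')  # columns: user, movie, rating, timestamp
# order each user's items chronologically before frames/sequences are built
df = df.sort_values(['user', 'timestamp'])
user_groups = {u: g[['movie', 'rating']].values for u, g in df.groupby('user')}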
You can make it work with your data by providing the data_cols argument. Since this dataset's columns differ from the ml20m defaults, they are mapped explicitly:
data_cols = dict(user_id='user', rating='rating', timestamp='timestamp', item='movie', sort_users=False)
env = recnn.env.FrameEnv('your embeddings',
'your dataset', frame_size, batch_size, data_cols=data_cols)
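The first argument points at a pickled file of item embeddings. Purely as a sketch, I am assuming a dict mapping raw item IDs to fixed-size vectors; check the data preparation notebooks for the exact format your recnn version expects:

import pickle
import numpy as np

# ASSUMED format: {raw_item_id: np.ndarray of shape (d,)}, here d = 128
# df is the ratings dataframe from the sketch above
item_embeddings = {item_id: np.random.randn(128).astype(np.float32)
                   for item_id in df['movie'].unique()}
with open('your embeddings', 'wb') as f:
    pickle.dump(item_embeddings, f)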
As in the example above, the sequential state representation is defined as:
# dynamic-length sequential
state_encoder = nn.LSTM(129, 256, batch_first=True).to(cuda)
env = recnn.env.SeqEnv('../../data/embeddings/ml20_pca128.pkl',
                       '../../data/ml-20m/ratings.csv', batch_size, state_encoder, cuda)
You can use different state encoders, ranging from basic LSTM/RNN/GRU to more efficient models such as Chaos-Free Networks and temporal convolutions.
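For instance, swapping in a GRU is a one-line change, provided it keeps the same input/output contract as the LSTM above (129-dimensional input: the 128-d item embedding plus the rating); the sizes below simply mirror the earlier example:

# same sizes as the LSTM example: 128-d embedding + 1 rating -> 256-d hidden state
state_encoder = nn.GRU(129, 256, batch_first=True).to(cuda)
env = recnn.env.SeqEnv('../../data/embeddings/ml20_pca128.pkl',
                       '../../data/ml-20m/ratings.csv', batch_size, state_encoder, cuda)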
The important thing is to include the state representation model in the optimizers together with the other networks!
# both parameter lists also include the shared state encoder
pm1 = list(policy_net.parameters()) + list(state_encoder.parameters())
pm2 = list(value_net.parameters()) + list(state_encoder.parameters())
policy_optimizer = recnn.optim.RAdam(pm1, ...)
value_optimizer = recnn.optim.RAdam(pm2, ...)
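A quick illustrative check, assuming RAdam follows the standard torch optimizer interface with param_groups: the shared encoder's parameters should appear in both optimizers, so both losses update it.

# sanity check (illustrative): the encoder is registered in both optimizers
enc_ids = {id(p) for p in state_encoder.parameters()}
for opt in (policy_optimizer, value_optimizer):
    opt_ids = {id(p) for group in opt.param_groups for p in group['params']}
    assert enc_ids <= opt_ids, 'state_encoder must be in both optimizers'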