Commit 3de9678

fix seed and add results for C3D/R21D
1 parent fe37aee

3 files changed: +69 -16 lines

README.md (+44 -6)

@@ -1,5 +1,5 @@
 # Self-supervised Video Representation Learning Using Inter-intra Contrastive Framework
-Official code for paper, Self-supervised Video Representation Learning Using Inter-intra Contrastive Framework [ACMMM'20].
+Official code for the paper "Self-supervised Video Representation Learning Using Inter-intra Contrastive Framework" [ACMMM'20].
 
 [Arxiv paper](https://arxiv.org/abs/2008.02531) [Project page](https://bestjuly.github.io/Inter-intra-video-contrastive-learning/)
 
@@ -10,7 +10,7 @@ Official code for paper, Self-supervised Video Representation Learning Using Int
 - python 3.7.4
 - accimage
 
-## Inter-intra contrastive framework
+## Inter-intra contrastive (IIC) framework
 For samples, we have
 - [ ] Inter-positives: samples with **same labels**, not used for self-supervised learning;
 - [x] Inter-negatives: **different samples**, or samples with different indexes;
@@ -33,8 +33,13 @@ The **inter-intra learning framework** can be extended to
 - Different intra-negative generation methods: frame repeating, frame shuffling ...
 - Different backbones: C3D, R3D, R(2+1)D, I3D ...
 
+## Updates
+Oct. 1, 2020 - Added results for C3D and R(2+1)D; the random seed is now fixed more strictly.
+Aug. 26, 2020 - Added pretrained weights for R3D.
 
 ## Usage of this repo
+> Note: we have added code to fix the random seed more strictly for better reproducibility. However, the results in our paper were produced with the previous seed settings, so performance may differ slightly from the numbers reported there. To reproduce the retrieval results from the paper exactly, please use the provided model weights.
+
 
 ### Data preparation
 You can download the UCF101/HMDB51 datasets from the official websites: [UCF101](http://crcv.ucf.edu/data/UCF101.php) and [HMDB51](http://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/). Then decode the videos into frames.
 I highly recommend the pre-computed optical flow images and resized RGB frames in this [repo](https://github.com/feichtenhofer/twostreamfusion).
@@ -141,7 +146,7 @@ The key code for this part is
 shift_x = torch.roll(x,1,2)
 x = ((shift_x -x) + 1)/2
 ```
-Which is slightly different from that in papers.
+which is slightly different from the formulation in the papers.
 
 We also reimplement VCP in this [repo](https://github.com/BestJuly/VCP). By simply using residual clips, significant improvements can be obtained for both video retrieval and video recognition.
 
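To make the residual-clip operation above concrete, here is a minimal, self-contained sketch; the (batch, channel, time, height, width) layout and the [0, 1] value range are assumptions inferred from the snippet, not confirmed by this diff:

```python
import torch

def residual_clip(x: torch.Tensor) -> torch.Tensor:
    """Difference adjacent frames of a clip, as in the README snippet above.

    Assumes x has shape (batch, channel, time, height, width) with values in [0, 1].
    """
    shift_x = torch.roll(x, shifts=1, dims=2)  # shift the clip one frame along time
    return (shift_x - x + 1) / 2               # map differences from [-1, 1] back to [0, 1]

clip = torch.rand(2, 3, 16, 112, 112)  # dummy batch of 16-frame clips
print(residual_clip(clip).shape)       # torch.Size([2, 3, 16, 112, 112])
```

Note that `torch.roll` wraps around, so the first residual frame is the difference against the last frame of the clip.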
@@ -157,13 +162,46 @@ Pretrained weights from self-supervised training step: R3D[(google drive)](https
 Finetuned weights for action recognition: R3D[(google drive)](https://drive.google.com/file/d/12uzHArg5hMGLuEUz36H4fJgGaeN4QyhZ/view?usp=sharing).
 
 > With this model, for video recognition, you should achieve
-> 72.7% @top1 with `python ft_classify.py --model=r3d --modality=res --mode=test -ckpt=./path/to/model`
+> 72.7% @top1 with `python ft_classify.py --model=r3d --modality=res --mode=test -ckpt=./path/to/model --dataset=ucf101 --split=1`.
 > This result is better than that reported in the paper. Results may be further improved with strong data augmentation.
 
-We may add more pretrained weights to support different network backbones in the future.
-
 For any questions, please contact Li TAO ([email protected]).
 
+### Results for other network architectures
+
+Results are averaged over 3 splits, without optical flow; retrieval top-k and recognition ("Recog.") accuracies are in %. The R3D and R21D backbones are the same as those used in VCOP / VCP / PRP.
+
+UCF101 | top1 | top5 | top10 | top20 | top50 | Recog.
+--- |--- |--- |--- |--- |--- |---
+C3D (VCOP) | 12.5 | 29.0 | 39.0 | 50.6 | 66.9 | 65.6
+C3D (VCP) | 17.3 | 31.5 | 42.0 | 52.6 | 67.7 | 68.5
+C3D (PRP) | 23.2 | 38.1 | 46.0 | 55.7 | 68.4 | 69.1
+C3D (ours, repeat) | **31.9** | **48.2** | **57.3** | **67.1** | **79.1** | **70.0**
+C3D (ours, shuffle) | 28.9 | 45.4 | 55.5 | 66.2 | 78.8 | 69.7
+R21D (VCOP) | 10.7 | 25.9 | 35.4 | 47.3 | 63.9 | 72.4
+R21D (VCP) | 19.9 | 33.7 | 42.0 | 50.5 | 64.4 | 66.3
+R21D (PRP) | 20.3 | 34.0 | 41.9 | 51.7 | 64.2 | 72.1
+R21D (ours, repeat) | **34.7** | **51.7** | **60.9** | **69.4** | **81.9** | 72.4
+R21D (ours, shuffle) | 30.2 | 45.6 | 55.0 | 64.4 | 77.6 | **73.3**
+Res18-3D (ours, repeat) | 36.8 | 54.1 | 63.1 | 72.0 | 83.3 | -
+Res18-3D (ours, shuffle) | 33.0 | 49.2 | 59.1 | 69.1 | 80.6 | -
+
+HMDB51 | top1 | top5 | top10 | top20 | top50 | Recog.
+--- |--- |--- |--- |--- |--- |---
+C3D (VCOP) | 7.4 | 22.6 | 34.4 | 48.5 | 70.1 | 28.4
+C3D (VCP) | 7.8 | 23.8 | 35.3 | 49.3 | 71.6 | 32.5
+C3D (PRP) | 10.5 | 27.2 | 40.4 | 56.2 | 75.9 | **34.5**
+C3D (ours, repeat) | 9.9 | 29.6 | 42.0 | 57.3 | 78.4 | 30.8
+C3D (ours, shuffle) | **11.5** | **31.3** | **43.9** | **60.1** | **80.3** | 29.7
+R21D (VCOP) | 5.7 | 19.5 | 30.7 | 45.6 | 67.0 | 30.9
+R21D (VCP) | 6.7 | 21.3 | 32.7 | 49.2 | 73.3 | 32.2
+R21D (PRP) | 8.2 | 25.3 | 36.2 | 51.0 | 73.0 | **35.0**
+R21D (ours, repeat) | **12.7** | **33.3** | **45.8** | **61.6** | **81.3** | 34.0
+R21D (ours, shuffle) | 12.6 | 31.9 | 44.2 | 59.9 | 80.7 | 31.2
+Res18-3D (ours, repeat) | 15.5 | 34.4 | 48.9 | 63.8 | 83.8 | -
+Res18-3D (ours, shuffle) | 12.4 | 33.6 | 46.9 | 63.2 | 83.5 | -
+
 ## Citation
 If you find our work helpful for your research, please consider citing the paper
 ```
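For readers unfamiliar with how the retrieval numbers above are computed, here is a minimal sketch of the standard top-k retrieval protocol used in this line of work (query test-clip features against training-clip features); the function name, the use of cosine similarity, and NumPy are illustrative assumptions, not code from this repo:

```python
import numpy as np

def topk_retrieval_accuracy(test_feats, test_labels, train_feats, train_labels,
                            ks=(1, 5, 10, 20, 50)):
    """A test clip counts as a hit at k if any of its k nearest training
    clips (by cosine similarity) shares its class label."""
    # L2-normalize so that a dot product equals cosine similarity
    test_feats = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    train_feats = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    sim = test_feats @ train_feats.T           # (num_test, num_train)
    ranked = np.argsort(-sim, axis=1)          # neighbor indices, most similar first
    return {k: (train_labels[ranked[:, :k]] == test_labels[:, None]).any(axis=1).mean()
            for k in ks}
```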

ft_classify.py (+12 -10)

@@ -199,7 +199,7 @@ def parse_args():
     parser.add_argument('--start-epoch', type=int, default=1, help='manual epoch number (useful on restarts)')
     parser.add_argument('--bs', type=int, default=16, help='mini-batch size')
     parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
-    parser.add_argument('--seed', type=int, help='seed for initializing training.')
+    parser.add_argument('--seed', type=int, default=632, help='seed for initializing training.')
     parser.add_argument('--modality', default='res', type=str, help='modality from [rgb, res, u, v]')
     args = parser.parse_args()
     return args
@@ -209,15 +209,16 @@ def parse_args():
 args = parse_args()
 print(vars(args))
 
-torch.backends.cudnn.benchmark = True
-
-if args.seed:
-    print('Set random seed to', args.seed)
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    if args.gpu:
-        torch.cuda.manual_seed_all(args.seed)
+# Fix all random seeds for reproducibility
+seed = args.seed
+torch.backends.cudnn.deterministic = True  # force deterministic cuDNN kernels
+torch.backends.cudnn.benchmark = False     # disable non-deterministic autotuning
+random.seed(seed)
+np.random.seed(seed)
+os.environ['PYTHONHASHSEED'] = str(seed)
+torch.manual_seed(seed)
+torch.cuda.manual_seed(seed)
+torch.cuda.manual_seed_all(seed)
 
 ########### model ##############
 if args.dataset == 'ucf101':
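Both scripts now repeat the same seeding block. A helper along the lines below would consolidate exactly the calls this diff makes; the name `seed_everything` is illustrative, and no such helper exists in the repo:

```python
import os
import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    """Seed every RNG the training scripts touch, mirroring the diff above."""
    random.seed(seed)
    np.random.seed(seed)
    # Note: setting PYTHONHASHSEED here only affects child processes,
    # not hash randomization in the already-running interpreter.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)       # current GPU
    torch.cuda.manual_seed_all(seed)   # all GPUs
    torch.backends.cudnn.deterministic = True  # force deterministic cuDNN kernels
    torch.backends.cudnn.benchmark = False     # disable non-deterministic autotuning
```

Keep in mind that `cudnn.benchmark = False` trades speed for determinism: cuDNN can no longer autotune the fastest convolution kernels, so training may be noticeably slower.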

train_ssl.py (+13)

@@ -236,10 +236,22 @@ def main():
     best_acc = 0 # best test accuracy
     start_epoch = 0 # start from epoch 0 or last checkpoint epoch
 
+    ''' Old version
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     torch.cuda.manual_seed_all(args.seed)
+    '''
+    # Fix all random seeds for reproducibility
+    seed = args.seed
+    torch.backends.cudnn.deterministic = True  # force deterministic cuDNN kernels
+    torch.backends.cudnn.benchmark = False     # disable non-deterministic autotuning
+    random.seed(seed)
+    np.random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
 
     print('[Warning] The training modalities are RGB and [{}]'.format(args.modality))

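One caveat neither script addresses, noted here as a supplementary sketch rather than repo code: with `--workers` greater than 0, each DataLoader worker process draws its own random state, so fully deterministic data loading may additionally require a `worker_init_fn` and a seeded `torch.Generator`, along the lines of PyTorch's reproducibility recipe (the `dataset` below is a placeholder):

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

def seed_worker(worker_id: int) -> None:
    # Derive each worker's seed from the base seed torch assigns to that worker
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(632)  # matches the new --seed default in ft_classify.py

# loader = DataLoader(dataset, batch_size=16, num_workers=4, shuffle=True,
#                     worker_init_fn=seed_worker, generator=g)
```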