
Commit 6ce28c3
committed: data processing code and instructions
1 parent 44f0e44 commit 6ce28c3

10 files changed: +728 -2 lines changed

README.md (+11 -2)

@@ -35,8 +35,17 @@ We start training from the official SD1.4 model (with the first layer modified t
### Data Processing
The data processing code can be found under the `data_processing` folder. You can simply put all the videos in a directory and pass that directory as the folder name in `data_processing/moments_processing.py`. If your videos are long (e.g. more than ~5 seconds and contain cut scenes), you will want to use PySceneDetect to detect the cut scenes and split the videos accordingly (see the sketch below).
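This repo does not ship a scene-splitting script, but as a rough sketch (assuming `pip install scenedetect`, `ffmpeg` available on your PATH, and a hypothetical input file `my_long_video.mp4`), splitting on detected cuts with PySceneDetect could look like:

```python
# Hedged sketch using PySceneDetect's high-level API; not part of this repo.
from scenedetect import detect, ContentDetector, split_video_ffmpeg

# Detect cut scenes in the (hypothetical) input video.
scene_list = detect('my_long_video.mp4', ContentDetector())

# Write one clip per detected scene via ffmpeg; the resulting clips can then be
# placed in the folder passed to data_processing/moments_processing.py.
split_video_ffmpeg('my_long_video.mp4', scene_list)
```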
For data processing, you also need to download the checkpoint for Segment Anything (SAM) and install softmax-splatting. You can set up softmax-splatting and SAM by following:

```
cd data_processing
git clone https://github.com/sniklaus/softmax-splatting.git
pip install segment_anything
cd sam_model
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
```

For softmax-splatting to run, you also need to install CuPy with `pip install cupy` (or `pip install cupy-cuda11x` / `pip install cupy-cuda12x`, depending on your CUDA version; make sure to load the appropriate CUDA module).
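If you are unsure which CuPy wheel matches your setup, one quick way to check the CUDA version your PyTorch build was compiled against (assuming PyTorch is already installed) is:

```python
import torch

# Prints e.g. '11.8' or '12.1'; pick cupy-cuda11x or cupy-cuda12x accordingly.
# Prints None for a CPU-only PyTorch build.
print(torch.version.cuda)
```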

Then run `python moments_processing.py` to start processing frames from the provided example videos (included under `data_processing/example_videos`). For the version provided, we used the [Moments in Time Dataset](http://moments.csail.mit.edu).

### Running the training script
Make sure that you have downloaded the pretrained SD1.4 model above. Once you have downloaded the training dataset and the pretrained model, you can simply start training the model with
@@ -45,7 +54,7 @@ Make sure that you have downloaded the pretrained SD1.4 model above. Once you d
The training code is in `main.py` and relies mainly on `pytorch_lightning` for training.

Note that you need to modify the train and val paths in the chosen config file to point to the location of your processed data.

Note: we use DeepSpeed to lower the memory requirements, so the saved model weights will be sharded. The script to reconstruct the model weights is created in the checkpoint directory under the name `zero_to_fp32.py`. One bug in that file is that it does not recognize checkpoints saved with DeepSpeed stage 1 (which is the stage we use), so simply find and replace the string `== 2` with the string `<= 2` and it will work.
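That find-and-replace can also be scripted; a minimal sketch (the checkpoint path below is hypothetical, substitute your own):

```python
from pathlib import Path

# Hypothetical path: zero_to_fp32.py is generated inside the DeepSpeed checkpoint directory.
script = Path('logs/my_run/checkpoints/last.ckpt/zero_to_fp32.py')

# Loosen the stage check so the script also accepts DeepSpeed stage-1 checkpoints.
script.write_text(script.read_text().replace('== 2', '<= 2'))
```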

4 binary files changed (not shown)

data_processing/moments_dataset.py (+54)

@@ -0,0 +1,54 @@
# Copyright 2024 Adobe. All rights reserved.

#%%
import glob
import torch
import torchvision
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import numpy as np


# %%
class MomentsDataset(Dataset):
    def __init__(self, videos_folder, num_frames, samples_per_video, frame_size=512) -> None:
        super().__init__()

        self.videos_paths = glob.glob(f'{videos_folder}/*mp4')
        self.resize = torchvision.transforms.Resize(size=frame_size)
        self.center_crop = torchvision.transforms.CenterCrop(size=frame_size)
        self.num_samples_per_video = samples_per_video
        self.num_frames = num_frames

    def __len__(self):
        # Each video contributes `num_samples_per_video` samples.
        return len(self.videos_paths) * self.num_samples_per_video

    def __getitem__(self, idx):
        video_idx = idx // self.num_samples_per_video
        video_path = self.videos_paths[video_idx]

        try:
            # Randomize the first sampled frame so repeated samples of the same
            # video cover slightly different temporal windows.
            start_idx = np.random.randint(0, 20)

            unsampled_video_frames, audio_frames, info = torchvision.io.read_video(video_path, output_format="TCHW")
            # Pick `num_frames` frames evenly spaced between start_idx and the last frame.
            sampled_indices = torch.tensor(np.linspace(start_idx, len(unsampled_video_frames) - 1, self.num_frames).astype(int))
            sampled_frames = unsampled_video_frames[sampled_indices]
            processed_frames = []

            # Resize the short side to `frame_size`, then center-crop to a square.
            for frame in sampled_frames:
                resized_cropped_frame = self.center_crop(self.resize(frame))
                processed_frames.append(resized_cropped_frame)
            frames = torch.stack(processed_frames, dim=0)
            frames = frames.float() / 255.0
        except Exception as e:
            # On a decoding/sampling failure, fall back to a random other sample.
            print('oops', e)
            rand_idx = np.random.randint(0, len(self))
            return self.__getitem__(rand_idx)

        out_dict = {'frames': frames,
                    'caption': 'none',
                    'keywords': 'none'}

        return out_dict
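For reference, a minimal (hedged) usage sketch of this dataset with a standard PyTorch `DataLoader`; the folder path and hyperparameters below are placeholders, not values prescribed by the repo:

```python
from torch.utils.data import DataLoader

from moments_dataset import MomentsDataset  # assumes you are running inside data_processing/

# Placeholder arguments; point videos_folder at a directory of .mp4 clips.
dataset = MomentsDataset(videos_folder='example_videos', num_frames=16,
                         samples_per_video=4, frame_size=512)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)

batch = next(iter(loader))
print(batch['frames'].shape)  # expected: torch.Size([2, 16, 3, 512, 512])
```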
