Skip to content

Commit 601bca3

Browse files
committed
Add plot_mix_temperature.py; update README.md
1 parent c256035 commit 601bca3

14 files changed

+184
-13
lines changed

README.md

+123-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,124 @@
1-
# RG-Flow: A hierarchical and explainable flow models based on renormalization group and sparse prior
1+
# RG-Flow
22

3-
The code requires Python >= 3.7 and PyTorch >= 1.6, with optional CUDA support.
3+
This repository contains the code for the paper "RG-Flow: A hierarchical and explainable flow model based on renormalization group and sparse prior".
4+
5+
TODO: add the link and the bibtex entry to the paper
6+
7+
# Dependencies
8+
9+
The code requires `Python >= 3.7` and `PyTorch >= 1.6`, with optional CUDA support. Other dependencies can be installed via
10+
11+
`pip install -r requirements.txt`
12+
13+
# Running experiments
14+
15+
`main.py` is the code for training the network. All adjustable arguments are stored in `args.py`, together with their default values when we were training on the CelebA datset. They can be displayed via `python main.py --help`:
16+
17+
```
18+
usage: main.py [-h] [--data {celeba32,celeba64,mnist32,cifar10,chair600}] [--data_path DATA_PATH]
19+
[--nchannels NCHANNELS] [--L L] [--prior {gaussian,laplace}] [--subnet {rnvp,ar}]
20+
[--kernel_size KERNEL_SIZE] [--nlayers NLAYERS] [--nresblocks NRESBLOCKS]
21+
[--nmlp NMLP] [--nhidden NHIDDEN] [--dtype {float32,float64}]
22+
[--batch_size BATCH_SIZE] [--lr LR] [--weight_decay WEIGHT_DECAY] [--epoch EPOCH]
23+
[--clip_grad CLIP_GRAD] [--no_stdout] [--print_step PRINT_STEP]
24+
[--save_epoch SAVE_EPOCH] [--keep_epoch KEEP_EPOCH] [--plot_epoch PLOT_EPOCH]
25+
[--cuda CUDA] [--out_infix OUT_INFIX] [-o OUT_DIR]
26+
27+
optional arguments:
28+
-h, --help show this help message and exit
29+
30+
dataset parameters:
31+
--data {celeba32,celeba64,mnist32,cifar10,chair600}
32+
dataset name
33+
--data_path DATA_PATH
34+
dataset path
35+
--nchannels NCHANNELS
36+
number of channels
37+
--L L edge length of images
38+
39+
network parameters:
40+
--prior {gaussian,laplace}
41+
prior of latent variables
42+
--subnet {rnvp,ar} type of subnet in an RG block
43+
--kernel_size KERNEL_SIZE
44+
edge length of an RG block
45+
--nlayers NLAYERS number of subnet layers in an RG block
46+
--nresblocks NRESBLOCKS
47+
number of residual blocks in a subnet layer
48+
--nmlp NMLP number of MLP hidden layers in an residual block
49+
--nhidden NHIDDEN width of MLP hidden layers
50+
--dtype {float32,float64}
51+
dtype
52+
53+
optimizer parameters:
54+
--batch_size BATCH_SIZE
55+
batch size
56+
--lr LR learning rate
57+
--weight_decay WEIGHT_DECAY
58+
weight decay
59+
--epoch EPOCH number of epoches
60+
--clip_grad CLIP_GRAD
61+
global norm to clip gradients, 0 for disabled
62+
63+
system parameters:
64+
--no_stdout do not print log to stdout, for better performance
65+
--print_step PRINT_STEP
66+
number of batches to print log, 0 for disabled
67+
--save_epoch SAVE_EPOCH
68+
number of epochs to save network weights, 0 for disabled
69+
--keep_epoch KEEP_EPOCH
70+
number of epochs to keep saved network weights, 0 for disabled
71+
--plot_epoch PLOT_EPOCH
72+
number of epochs to plot samples, 0 for disabled
73+
--cuda CUDA IDs of GPUs to use, empty for disabled
74+
--out_infix OUT_INFIX
75+
infix in output filename to distinguish repeated runs
76+
-o OUT_DIR, --out_dir OUT_DIR
77+
directory for output, empty for disabled
78+
```
79+
80+
During training, the log file and the network weights will be saved in `out_dir`.
81+
82+
After the network is trained, `plot_mix_temperature.py` can be used to plot samples using mixed effective temperatures, described in Appendix B of the paper.
83+
84+
# Gallery
85+
86+
## RG-Flow structure
87+
88+
![RG-Flow structure](docs/structure.png)
89+
90+
## Random walk in high-level latent representations
91+
92+
![Random walk in high-level latent representations](docs/high_level_walk.gif)
93+
94+
## Random walk in mid-level latent representations
95+
96+
![Random walk in mid-level latent representations](docs/mid_level_walk.gif)
97+
98+
## Learned receptive fields
99+
100+
![Learned receptive fields](docs/RF.png)
101+
102+
## Learned factors
103+
104+
![Learned factors](docs/factors.png)
105+
106+
### High-level factor: emotion
107+
108+
![High-level factor: emotion](docs/smile_video.gif)
109+
110+
### High-level factor: gender
111+
112+
![High-level factor: gender](docs/gender_video.gif)
113+
114+
### Mid-level factor: light direction
115+
116+
![Mid-level factor: light direction](docs/projection_video.gif)
117+
118+
### Mid-level factor: rotation
119+
120+
![Mid-level factor: rotation](docs/rotation_video.gif)
121+
122+
## Face mixing in the scaling direction
123+
124+
![Face mixing in the scaling direction](docs/mix.png)

args.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@
1313
default='celeba32',
1414
choices=['celeba32', 'celeba64', 'mnist32', 'cifar10', 'chair600'],
1515
help='dataset name')
16-
group.add_argument(
17-
'--data_path',
18-
type=str,
19-
default='./data', # You should put your dataset directory here
20-
help='dataset path')
16+
group.add_argument('--data_path',
17+
type=str,
18+
default='./data',
19+
help='dataset path')
2120
group.add_argument('--nchannels',
2221
type=int,
2322
default=3,
@@ -109,7 +108,7 @@
109108
'--out_dir',
110109
type=str,
111110
default='./saved_model',
112-
help='directory prefix for output, empty for disabled')
111+
help='directory for output, empty for disabled')
113112

114113
args = parser.parse_args()
115114

docs/RF.png

985 KB
Loading

docs/factors.png

2.97 MB
Loading

docs/gender_video.gif

221 KB
Loading

docs/high_level_walk.gif

17.1 MB
Loading

docs/mid_level_walk.gif

17.1 MB
Loading

docs/mix.png

1.23 MB
Loading

docs/projection_video.gif

215 KB
Loading

docs/rotation_video.gif

217 KB
Loading

docs/smile_video.gif

226 KB
Loading

docs/structure.png

289 KB
Loading

plot_mix_temperature.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/env python3
2+
3+
from math import sqrt
4+
5+
import numpy as np
6+
import torch
7+
from matplotlib import pyplot as plt
8+
from torch.distributions.laplace import Laplace
9+
10+
import utils
11+
from args import args
12+
from main import build_mera
13+
14+
T_low = 0.2 # Effective temperature for low-level latent variables
15+
T_high = 0.8 # Effective temperature for high-level latent variables
16+
level_cutoff = 1 # Cutoff level (\lambda in the paper)
17+
18+
19+
def main():
20+
flow = build_mera()
21+
last_epoch = utils.get_last_checkpoint_step()
22+
utils.load_checkpoint(last_epoch, flow)
23+
flow.train(False)
24+
25+
shape = (16, args.nchannels, args.L, args.L)
26+
prior_low = Laplace(torch.tensor(0.), torch.tensor(T_low / sqrt(2)))
27+
z = prior_low.sample(shape)
28+
prior_high = Laplace(torch.tensor(0.), torch.tensor(T_high / sqrt(2)))
29+
z_high = prior_high.sample(shape)
30+
k = 2**level_cutoff
31+
z[:, :, ::k, ::k] = z_high[:, :, ::k, ::k]
32+
z = z.to(args.device)
33+
34+
with torch.no_grad():
35+
x, _ = flow.inverse(z)
36+
37+
samples = x.permute(0, 2, 3, 1).detach().cpu().numpy()
38+
samples = 1 / (1 + np.exp(-samples))
39+
40+
fig, axes = plt.subplots(4, 4, figsize=(4, 4), sharex=True, sharey=True)
41+
for i in range(4):
42+
for j in range(4):
43+
ax = axes[i, j]
44+
ax.imshow(samples[j * 4 + i])
45+
ax.axis('off')
46+
plt.tight_layout()
47+
plt.savefig('./mix_T.pdf', bbox_inches='tight')
48+
49+
50+
if __name__ == '__main__':
51+
main()

requirements.txt

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
matplotlib
2-
numpy
3-
scikit-image
4-
torch
5-
torchvision
1+
matplotlib>=3.3
2+
numpy>=1.19
3+
scikit-image>=0.17
4+
torch>=1.6
5+
torchvision>=0.7

0 commit comments

Comments
 (0)