# RG-Flow

This repository contains the code for the paper "RG-Flow: A hierarchical and explainable flow model based on renormalization group and sparse prior".

TODO: add the link and the bibtex entry to the paper

# Dependencies

The code requires `Python >= 3.7` and `PyTorch >= 1.6`, with optional CUDA support. Other dependencies can be installed via

`pip install -r requirements.txt`

# Running experiments

`main.py` is the script for training the network. All adjustable arguments are stored in `args.py`, together with the default values we used when training on the CelebA dataset. They can be displayed via `python main.py --help`:

```
usage: main.py [-h] [--data {celeba32,celeba64,mnist32,cifar10,chair600}] [--data_path DATA_PATH]
               [--nchannels NCHANNELS] [--L L] [--prior {gaussian,laplace}] [--subnet {rnvp,ar}]
               [--kernel_size KERNEL_SIZE] [--nlayers NLAYERS] [--nresblocks NRESBLOCKS]
               [--nmlp NMLP] [--nhidden NHIDDEN] [--dtype {float32,float64}]
               [--batch_size BATCH_SIZE] [--lr LR] [--weight_decay WEIGHT_DECAY] [--epoch EPOCH]
               [--clip_grad CLIP_GRAD] [--no_stdout] [--print_step PRINT_STEP]
               [--save_epoch SAVE_EPOCH] [--keep_epoch KEEP_EPOCH] [--plot_epoch PLOT_EPOCH]
               [--cuda CUDA] [--out_infix OUT_INFIX] [-o OUT_DIR]

optional arguments:
  -h, --help            show this help message and exit

dataset parameters:
  --data {celeba32,celeba64,mnist32,cifar10,chair600}
                        dataset name
  --data_path DATA_PATH
                        dataset path
  --nchannels NCHANNELS
                        number of channels
  --L L                 edge length of images

network parameters:
  --prior {gaussian,laplace}
                        prior of latent variables
  --subnet {rnvp,ar}    type of subnet in an RG block
  --kernel_size KERNEL_SIZE
                        edge length of an RG block
  --nlayers NLAYERS     number of subnet layers in an RG block
  --nresblocks NRESBLOCKS
                        number of residual blocks in a subnet layer
  --nmlp NMLP           number of MLP hidden layers in a residual block
  --nhidden NHIDDEN     width of MLP hidden layers
  --dtype {float32,float64}
                        dtype

optimizer parameters:
  --batch_size BATCH_SIZE
                        batch size
  --lr LR               learning rate
  --weight_decay WEIGHT_DECAY
                        weight decay
  --epoch EPOCH         number of epochs
  --clip_grad CLIP_GRAD
                        global norm to clip gradients, 0 for disabled

system parameters:
  --no_stdout           do not print log to stdout, for better performance
  --print_step PRINT_STEP
                        number of batches to print log, 0 for disabled
  --save_epoch SAVE_EPOCH
                        number of epochs to save network weights, 0 for disabled
  --keep_epoch KEEP_EPOCH
                        number of epochs to keep saved network weights, 0 for disabled
  --plot_epoch PLOT_EPOCH
                        number of epochs to plot samples, 0 for disabled
  --cuda CUDA           IDs of GPUs to use, empty for disabled
  --out_infix OUT_INFIX
                        infix in output filename to distinguish repeated runs
  -o OUT_DIR, --out_dir OUT_DIR
                        directory for output, empty for disabled
```
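For example, a training run might be launched as follows. The flag values below are illustrative, not necessarily the settings used in the paper; the dataset path in particular depends on your setup.

```shell
# Illustrative invocation; adjust the dataset path and hyperparameters
# to your setup. All flags and their defaults are defined in args.py.
python main.py \
    --data celeba32 \
    --data_path ./data/celeba \
    --prior laplace \
    --subnet rnvp \
    --batch_size 64 \
    --cuda 0 \
    -o ./out
```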

During training, the log file and the network weights will be saved in `out_dir`.

After the network is trained, `plot_mix_temperature.py` can be used to plot samples using mixed effective temperatures, described in Appendix B of the paper.
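The exact mixing scheme is defined in Appendix B; as a rough illustration of the idea, sampling at an effective temperature amounts to rescaling the spread of the prior before the latents are pushed through the inverse flow. Below is a minimal NumPy sketch assuming a Laplace prior; the latent shapes and the high-level/low-level split are illustrative only, and the final decoding step would require the trained network.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_latent(shape, temperature, rng):
    """Sample Laplace latents at an effective temperature.

    For a Laplace prior, raising the density to the power 1/T and
    renormalizing yields another Laplace distribution whose scale is
    multiplied by T, so temperature scaling reduces to rescaling.
    """
    return rng.laplace(loc=0.0, scale=temperature, size=shape)

# Illustrative split: coarse (high-level) latents drawn at a lower
# temperature than the fine (low-level) latents.
z_high = sample_latent((4, 4), temperature=0.5, rng=rng)
z_low = sample_latent((32, 32), temperature=1.0, rng=rng)

# In RG-Flow, these latents would then be mapped to an image by the
# inverse of the trained flow (not reproduced in this sketch).
print(z_high.shape, z_low.shape)
```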

# Gallery

## RG-Flow structure

## Random walk in high-level latent representations

## Random walk in mid-level latent representations

## Learned receptive fields

## Learned factors

### High-level factor: emotion

### High-level factor: gender

### Mid-level factor: light direction

### Mid-level factor: rotation

## Face mixing in the scaling direction