Fix inconsistent training results with RGBA/PNG images #1193
## Issue summary
The training relies on PIL to resize the input images and extracts the resized alpha channel to mask the rendered image during training. Since PIL pre-multiplies the resized RGB with the resized alpha, training produces different Gaussian points depending on whether the input gets resized or not. Moreover, the alpha channel extracted from PIL is not perfectly binarized, causing floaters around the edges. The issue has come up in #1039, #1121, and #1114, since those runs trained on either PNG images or a dataset containing masks in the 4th channel (preprocessed DTU, NeRF Synthetic).
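The alpha distortion is easy to reproduce. The snippet below is a standalone illustration (not code from this patch; the image content is made up): it downscales an RGBA image whose alpha is strictly binary and shows that the resized alpha contains intermediate values along the edge.

```python
import numpy as np
from PIL import Image

# Build a 64x64 RGBA image: opaque white square on a transparent background,
# with a strictly binary alpha channel (0 or 255).
rgba = np.zeros((64, 64, 4), dtype=np.uint8)
rgba[17:47, 17:47, :3] = 255  # white foreground
rgba[17:47, 17:47, 3] = 255   # opaque only on the square
img = Image.fromarray(rgba, mode="RGBA")

# Downscale by 2, as the -r 2 flag would.
small = np.asarray(img.resize((32, 32), Image.BILINEAR))
alpha = small[..., 3]

# The interpolated alpha is no longer binary along the square's edge, so
# using it as a mask without thresholding lets edge pixels leak through.
print(np.unique(alpha))
```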
The fix is self-contained in the `PILtoTorch` function. It checks whether the input is RGBA and, if so, manually masks the RGB channels with the alpha. The alpha channel is then discarded and processing continues as if the input were RGB, making the alpha multiplication step in the train script a no-op.

## Details
In the current commit, here is how a ground-truth RGBA image is treated during training:

1. The image is resized with `PIL.Image.Image.resize` in the `PILtoTorch` function.
2. The RGB channels become `gt_image` in `train.py` (via `Camera.original_image`).
3. The alpha channel is saved as `alpha_mask`. This mask is then multiplied with the rendered `image` in `train.py`, and the loss is computed on `gt_image` and the masked `image`.

If the input RGBA is actually resized in `PILtoTorch` (i.e. the `resolution` param differs from the image's resolution), PIL automatically multiplies the resized RGB with the resized alpha (see the `resize` implementation in PIL).
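To make the pre-multiplication concrete: for RGBA inputs, Pillow resamples in premultiplied-alpha space (the `RGBa` mode). The snippet below is an illustrative equivalent of that behavior, not code from this repository, and shows that RGB values under fully transparent pixels do not survive the round trip.

```python
import numpy as np
from PIL import Image

# White RGB everywhere, but only the right half is opaque.
rgba = np.full((16, 16, 4), 255, dtype=np.uint8)
rgba[:, :8, 3] = 0  # left half: alpha 0, RGB still 255
img = Image.fromarray(rgba, mode="RGBA")

# Roughly what PIL does internally when resizing an RGBA image:
premult = img.convert("RGBa")               # multiply RGB by alpha
resized = premult.resize((8, 8), Image.BILINEAR)
out = np.asarray(resized.convert("RGBA"))   # un-multiply where alpha > 0

# The fully transparent region had white RGB, but premultiplication
# zeroed it out: the original colors under the mask are lost.
print(out[0, 0])  # top-left pixel, from the transparent half
```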
This creates two different scenarios:

- If the input is not resized (no `-r` flag), the RGB ground truth is the original image without masking, and the saved `alpha_mask` is perfectly binarized.
- If the input is resized, the RGB ground truth is pre-multiplied with the alpha, and the resized `alpha_mask` is distorted along the edges.

## Scenario 1: no resize
The Gaussian points undergo tension during training, since the rendered image is masked before being fed into the loss while the ground truth is the original, unmasked image:
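A one-pixel toy calculation (hypothetical values, purely to illustrate the mismatch) shows why this penalty cannot be optimized away: wherever the binary mask is 0, the masked render is 0 no matter what the Gaussians do, while the unmasked ground truth keeps its background value.

```python
# Toy illustration of scenario 1 at a single background pixel.
gt = 0.8       # unmasked ground-truth background intensity
alpha = 0.0    # binary alpha says "background"
render = 0.5   # whatever the Gaussians currently produce here

masked = render * alpha   # always 0.0, independent of the Gaussians
loss = abs(masked - gt)   # stuck at 0.8: an irreducible L1 penalty
print(loss)
```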
## Scenario 2: RGBA is resized
The resized `alpha_mask` is not perfectly binarized along the edges due to interpolation. This imperfect mask is multiplied with the rendered image, causing floaters:

`-r 2` (Iter 7000)

## The fix
To keep the modification minimal, when `PILtoTorch` encounters RGBA, we manually extract the alpha, mask the RGB channels with it, and let this masked RGB become the new input. The remaining logic is unchanged, and the alpha multiplication step in `train.py` becomes a no-op.

`-r 2` (Iter 7000)

## Test environment
- 3.9
- 2.4.0
- 12.4
- 19.43
## Notes

`render.py` might need to be fixed to export the masked GT (rather than the original RGB) when running on a model trained with the original resolution settings (no `-r`).
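For reference, the patched `PILtoTorch` behavior can be sketched as follows. This is a simplified illustration, not the verbatim diff: the function name, and the use of a NumPy array instead of a Torch tensor as the return value, are my assumptions.

```python
import numpy as np
from PIL import Image

def pil_to_torch_rgba_fix(pil_image, resolution):
    """Sketch of the fixed logic: mask RGB with the original (still binary)
    alpha BEFORE any resizing, then treat the input as plain RGB.
    Simplified stand-in for PILtoTorch (returns a numpy array)."""
    if pil_image.mode == "RGBA":
        rgba = np.asarray(pil_image).astype(np.float32) / 255.0
        alpha = rgba[..., 3:4]
        # Manual masking; the alpha channel is discarded afterwards, so the
        # later `image * alpha_mask` step in train.py becomes a no-op.
        rgb = (rgba[..., :3] * alpha * 255.0).round().astype(np.uint8)
        pil_image = Image.fromarray(rgb, mode="RGB")
    resized = np.asarray(pil_image.resize(resolution)).astype(np.float32) / 255.0
    return np.transpose(resized, (2, 0, 1))  # HWC -> CHW
```

Because the masking happens on the original, perfectly binary alpha, the premultiplication inside PIL's RGBA resize path never runs, and no soft-edged mask reaches the loss.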