Skip to content

Commit 18eb6d6

Browse files
PerhapsS44Matei Barbu
authored and
Matei Barbu
committed
Added support for setting the floating-point data type
Users may want to reduce their memory consumption by using fp16. However, in my tests, such attempts resulted in lower-quality renders. Some data type conversions did not have any impact, so I removed them completely.
1 parent 472689c commit 18eb6d6

File tree

9 files changed

+53
-29
lines changed

9 files changed

+53
-29
lines changed

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
194194
Influence of SSIM on total loss from 0 to 1, ```0.2``` by default.
195195
#### --percent_dense
196196
Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default.
197+
#### --data_dtype
198+
The data type (float32, float16 or float64) in which images are stored when computing the loss. Unrecognized values fall back to float32. ```float32``` by default.
197199

198200
</details>
199201
<br>

arguments/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(self, parser, sentinel=False):
5353
self._resolution = -1
5454
self._white_background = False
5555
self.data_device = "cuda"
56+
self.data_dtype = "float32"
5657
self.eval = False
5758
super().__init__(parser, "Loading Parameters", sentinel)
5859

gaussian_renderer/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
9292
rotations = rotations,
9393
cov3D_precomp = cov3D_precomp)
9494

95+
# After rasterization, convert the resulting image to the target dtype.
96+
# The rasterizer expects parameters as float32, so the result is also float32.
97+
rendered_image = rendered_image.to(viewpoint_camera.original_image.dtype)
98+
9599
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
96100
# They will be excluded from value updates used in the splitting criteria.
97101
return {"render": rendered_image,

render.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ def render_set(model_path, name, iteration, views, gaussians, pipeline, backgrou
3434
torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
3535
torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
3636

37-
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
37+
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, dtype=torch.float32):
3838
with torch.no_grad():
39-
gaussians = GaussianModel(dataset.sh_degree)
39+
gaussians = GaussianModel(dataset.sh_degree, dtype=dtype)
4040
scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
4141

4242
bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
43-
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
43+
background = torch.tensor(bg_color, dtype=dtype, device="cuda")
4444

4545
if not skip_train:
4646
render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
@@ -62,5 +62,7 @@ def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParam
6262

6363
# Initialize system state (RNG)
6464
safe_state(args.quiet)
65+
66+
dtype = torch.float32
6567

66-
render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
68+
render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, dtype)

scene/cameras.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
class Camera(nn.Module):
1818
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
1919
image_name, uid,
20-
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
20+
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32
2121
):
2222
super(Camera, self).__init__()
2323

@@ -28,6 +28,7 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
2828
self.FoVx = FoVx
2929
self.FoVy = FoVy
3030
self.image_name = image_name
31+
self.data_dtype = data_dtype
3132

3233
try:
3334
self.data_device = torch.device(data_device)
@@ -36,12 +37,12 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
3637
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
3738
self.data_device = torch.device("cuda")
3839

39-
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
40+
self.original_image = image.clamp(0.0, 1.0).to(self.data_dtype).to(self.data_device)
4041
self.image_width = self.original_image.shape[2]
4142
self.image_height = self.original_image.shape[1]
4243

4344
if gt_alpha_mask is not None:
44-
self.original_image *= gt_alpha_mask.to(self.data_device)
45+
self.original_image *= gt_alpha_mask.to(self.data_dtype).to(self.data_device)
4546
else:
4647
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
4748

scene/gaussian_model.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323

2424
class GaussianModel:
2525

26-
def setup_functions(self):
26+
def setup_functions(self, dtype):
2727
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
28-
L = build_scaling_rotation(scaling_modifier * scaling, rotation)
28+
L = build_scaling_rotation(scaling_modifier * scaling, rotation, dtype)
2929
actual_covariance = L @ L.transpose(1, 2)
30-
symm = strip_symmetric(actual_covariance)
30+
symm = strip_symmetric(actual_covariance, dtype)
3131
return symm
3232

3333
self.scaling_activation = torch.exp
@@ -41,7 +41,7 @@ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
4141
self.rotation_activation = torch.nn.functional.normalize
4242

4343

44-
def __init__(self, sh_degree : int):
44+
def __init__(self, sh_degree : int, dtype=torch.float32):
4545
self.active_sh_degree = 0
4646
self.max_sh_degree = sh_degree
4747
self._xyz = torch.empty(0)
@@ -56,7 +56,8 @@ def __init__(self, sh_degree : int):
5656
self.optimizer = None
5757
self.percent_dense = 0
5858
self.spatial_lr_scale = 0
59-
self.setup_functions()
59+
self.dtype = dtype
60+
self.setup_functions(dtype)
6061

6162
def capture(self):
6263
return (
@@ -136,15 +137,15 @@ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
136137
rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
137138
rots[:, 0] = 1
138139

139-
opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
140+
opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=self.dtype, device="cuda"))
140141

141142
self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
142143
self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
143144
self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
144145
self._scaling = nn.Parameter(scales.requires_grad_(True))
145146
self._rotation = nn.Parameter(rots.requires_grad_(True))
146147
self._opacity = nn.Parameter(opacities.requires_grad_(True))
147-
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
148+
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda", dtype=self.dtype)
148149

149150
def training_setup(self, training_args):
150151
self.percent_dense = training_args.percent_dense
@@ -246,12 +247,12 @@ def load_ply(self, path):
246247
for idx, attr_name in enumerate(rot_names):
247248
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
248249

249-
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
250-
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
251-
self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
252-
self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
253-
self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
254-
self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
250+
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=self.dtype, device="cuda").requires_grad_(True))
251+
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
252+
self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
253+
self._opacity = nn.Parameter(torch.tensor(opacities, dtype=self.dtype, device="cuda").requires_grad_(True))
254+
self._scaling = nn.Parameter(torch.tensor(scales, dtype=self.dtype, device="cuda").requires_grad_(True))
255+
self._rotation = nn.Parameter(torch.tensor(rots, dtype=self.dtype, device="cuda").requires_grad_(True))
255256

256257
self.active_sh_degree = self.max_sh_degree
257258

train.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from gaussian_renderer import render, network_gui
1717
import sys
1818
from scene import Scene, GaussianModel
19-
from utils.general_utils import safe_state
19+
from utils.general_utils import get_data_dtype, safe_state
2020
import uuid
2121
from tqdm import tqdm
2222
from utils.image_utils import psnr
@@ -216,6 +216,7 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
216216
# Start GUI server, configure and run training
217217
network_gui.init(args.ip, args.port)
218218
torch.autograd.set_detect_anomaly(args.detect_anomaly)
219+
219220
training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)
220221

221222
# All done

utils/camera_utils.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from scene.cameras import Camera
1313
import numpy as np
14-
from utils.general_utils import PILtoTorch
14+
from utils.general_utils import PILtoTorch, get_data_dtype
1515
from utils.graphics_utils import fov2focal
1616

1717
WARNED = False
@@ -39,6 +39,8 @@ def loadCam(args, id, cam_info, resolution_scale):
3939
resolution = (int(orig_w / scale), int(orig_h / scale))
4040

4141
resized_image_rgb = PILtoTorch(cam_info.image, resolution)
42+
43+
# resized_image_rgb = resized_image_rgb.to(get_data_dtype(args.data_dtype))
4244

4345
gt_image = resized_image_rgb[:3, ...]
4446
loaded_mask = None
@@ -49,7 +51,8 @@ def loadCam(args, id, cam_info, resolution_scale):
4951
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
5052
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
5153
image=gt_image, gt_alpha_mask=loaded_mask,
52-
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
54+
image_name=cam_info.image_name, uid=id, data_device=args.data_device,
55+
data_dtype=get_data_dtype(args.data_dtype))
5356

5457
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
5558
camera_list = []

utils/general_utils.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def helper(step):
6161

6262
return helper
6363

64-
def strip_lowerdiag(L):
65-
uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
64+
def strip_lowerdiag(L, dtype=torch.float32):
65+
uncertainty = torch.zeros((L.shape[0], 6), dtype=dtype, device="cuda")
6666

6767
uncertainty[:, 0] = L[:, 0, 0]
6868
uncertainty[:, 1] = L[:, 0, 1]
@@ -72,8 +72,8 @@ def strip_lowerdiag(L):
7272
uncertainty[:, 5] = L[:, 2, 2]
7373
return uncertainty
7474

75-
def strip_symmetric(sym):
76-
return strip_lowerdiag(sym)
75+
def strip_symmetric(sym, dtype=torch.float32):
76+
return strip_lowerdiag(sym, dtype=dtype)
7777

7878
def build_rotation(r):
7979
norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
@@ -98,8 +98,8 @@ def build_rotation(r):
9898
R[:, 2, 2] = 1 - 2 * (x*x + y*y)
9999
return R
100100

101-
def build_scaling_rotation(s, r):
102-
L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
101+
def build_scaling_rotation(s, r, dtype=torch.float32):
102+
L = torch.zeros((s.shape[0], 3, 3), dtype=dtype, device="cuda")
103103
R = build_rotation(r)
104104

105105
L[:,0,0] = s[:,0]
@@ -131,3 +131,12 @@ def flush(self):
131131
np.random.seed(0)
132132
torch.manual_seed(0)
133133
torch.cuda.set_device(torch.device("cuda:0"))
134+
135+
def get_data_dtype(dtype):
    """Translate a data-type name into the matching ``torch.dtype``.

    Args:
        dtype: Name of a floating-point type -- ``"float32"``, ``"float64"``
            or ``"float16"`` (letter case and surrounding whitespace are
            ignored) -- or an existing ``torch.dtype``, which is returned
            unchanged.

    Returns:
        The corresponding ``torch.dtype``. Any unrecognized value falls back
        to ``torch.float32`` after a warning is printed.
    """
    # A value that is already a torch dtype is handed back as-is instead of
    # being silently downgraded to float32 by the string comparison below.
    if isinstance(dtype, torch.dtype):
        return dtype

    supported = {
        "float32": torch.float32,
        "float64": torch.float64,
        "float16": torch.float16,
    }

    # Only strings can name a dtype; anything else takes the fallback path.
    resolved = None
    if isinstance(dtype, str):
        resolved = supported.get(dtype.strip().lower())

    if resolved is None:
        # Keep the historical float32 fallback, but make it visible so that a
        # mistyped name does not quietly defeat a request for lower precision.
        print(f"[Warning] Unknown data dtype {dtype!r}, fallback to float32")
        return torch.float32

    return resolved

0 commit comments

Comments
 (0)