@@ -23,11 +23,11 @@
 
 class GaussianModel:
 
-    def setup_functions(self):
+    def setup_functions(self, dtype):
 
        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
-            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
+            L = build_scaling_rotation(scaling_modifier * scaling, rotation, dtype)
            actual_covariance = L @ L.transpose(1, 2)
-            symm = strip_symmetric(actual_covariance)
+            symm = strip_symmetric(actual_covariance, dtype)
            return symm
 
        self.scaling_activation = torch.exp
@@ -41,7 +41,7 @@ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
        self.rotation_activation = torch.nn.functional.normalize
 
 
-    def __init__(self, sh_degree : int):
+    def __init__(self, sh_degree : int, dtype=torch.float32):
        self.active_sh_degree = 0
        self.max_sh_degree = sh_degree
        self._xyz = torch.empty(0)
@@ -56,7 +56,8 @@ def __init__(self, sh_degree : int):
        self.optimizer = None
        self.percent_dense = 0
        self.spatial_lr_scale = 0
-        self.setup_functions()
+        self.dtype = dtype
+        self.setup_functions(dtype)
 
    def capture(self):
        return (
@@ -136,15 +137,15 @@ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
        rots[:, 0] = 1
 
-        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
+        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=self.dtype, device="cuda"))
 
        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
        self._scaling = nn.Parameter(scales.requires_grad_(True))
        self._rotation = nn.Parameter(rots.requires_grad_(True))
        self._opacity = nn.Parameter(opacities.requires_grad_(True))
-        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
+        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda", dtype=self.dtype)
 
    def training_setup(self, training_args):
        self.percent_dense = training_args.percent_dense
@@ -246,12 +247,12 @@ def load_ply(self, path):
        for idx, attr_name in enumerate(rot_names):
            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
 
-        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
-        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
-        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
-        self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
-        self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
-        self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
+        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=self.dtype, device="cuda").requires_grad_(True))
+        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
+        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
+        self._opacity = nn.Parameter(torch.tensor(opacities, dtype=self.dtype, device="cuda").requires_grad_(True))
+        self._scaling = nn.Parameter(torch.tensor(scales, dtype=self.dtype, device="cuda").requires_grad_(True))
+        self._rotation = nn.Parameter(torch.tensor(rots, dtype=self.dtype, device="cuda").requires_grad_(True))
 
        self.active_sh_degree = self.max_sh_degree
 
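Taken together, the hunks make the model's storage precision configurable at construction time instead of hard-coding torch.float. A minimal usage sketch follows; it assumes the patched class is importable from scene/gaussian_model.py as in the upstream repository, and that build_scaling_rotation and strip_symmetric have been updated to accept the extra dtype argument, as the first hunk implies.

# Hypothetical usage sketch, not part of the diff.
import torch
from scene.gaussian_model import GaussianModel

# Construct the model at half precision; the default remains torch.float32.
model = GaussianModel(sh_degree=3, dtype=torch.float16)

# self.dtype is threaded into setup_functions(), so the covariance helper built
# there, and the tensors created later in create_from_pcd() and load_ply(),
# are allocated with dtype=self.dtype instead of the previously hard-coded torch.float.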