@@ -51,7 +51,7 @@ def approx(
51
51
nps_tf .dtype64 = tf .float64
52
52
53
53
54
- @pytest .fixture (params = [nps_torch , nps_tf ], scope = "module" )
54
+ @pytest .fixture (params = [nps_tf , nps_torch ], scope = "module" )
55
55
def nps (request ):
56
56
return request .param
57
57
@@ -64,14 +64,17 @@ def generate_data(
64
64
n_context = 5 ,
65
65
n_target = 7 ,
66
66
binary = False ,
67
+ dtype = None ,
67
68
):
68
- xc = B .randn (nps .dtype , batch_size , dim_x , n_context )
69
- yc = B .randn (nps .dtype , batch_size , dim_y , n_context )
70
- xt = B .randn (nps .dtype , batch_size , dim_x , n_target )
71
- yt = B .randn (nps .dtype , batch_size , dim_y , n_target )
69
+ if dtype is None :
70
+ dtype = nps .dtype
71
+ xc = B .randn (dtype , batch_size , dim_x , n_context )
72
+ yc = B .randn (dtype , batch_size , dim_y , n_context )
73
+ xt = B .randn (dtype , batch_size , dim_x , n_target )
74
+ yt = B .randn (dtype , batch_size , dim_y , n_target )
72
75
if binary :
73
- yc = B .cast (nps . dtype , yc >= 0 )
74
- yt = B .cast (nps . dtype , yt >= 0 )
76
+ yc = B .cast (dtype , yc >= 0 )
77
+ yt = B .cast (dtype , yt >= 0 )
75
78
return xc , yc , xt , yt
76
79
77
80
0 commit comments