Skip to content

Commit e696c09

Browse files
committed
Don't use groups for transposed convs for Keras
1 parent a862c99 commit e696c09

File tree

1 file changed

+11
-1
lines changed
  • neuralprocesses/tensorflow

1 file changed

+11
-1
lines changed

neuralprocesses/tensorflow/nn.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from functools import partial
23
from typing import Optional, Union
34

@@ -119,14 +120,23 @@ def ConvNd(
119120
else:
120121
suffix = ""
121122

123+
if groups > 1:
124+
if transposed:
125+
warnings.warn(
126+
"Keras does not depthwise separable transposed convolutions! "
127+
"Using non-separable convolutions for the transposed convolutions. "
128+
"This could be a LOT more expensive."
129+
)
130+
else:
131+
additional_args["groups"] = groups
132+
122133
conv_layer = getattr(tf.keras.layers, f"Conv{dim}D{suffix}")(
123134
input_shape=(in_channels,) + (None,) * dim,
124135
filters=out_channels,
125136
kernel_size=kernel,
126137
strides=stride,
127138
padding="same",
128139
dilation_rate=dilation,
129-
groups=groups,
130140
use_bias=bias,
131141
data_format=data_format,
132142
dtype=dtype,

0 commit comments

Comments
 (0)