|
| 1 | +/* |
| 2 | + * ------------------------------------------------------------------------ |
| 3 | + * |
| 4 | + * Copyright by KNIME GmbH, Konstanz, Germany |
| 5 | + * Website: http://www.knime.org; Email: [email protected] |
| 6 | + * |
| 7 | + * This program is free software; you can redistribute it and/or modify |
| 8 | + * it under the terms of the GNU General Public License, Version 3, as |
| 9 | + * published by the Free Software Foundation. |
| 10 | + * |
| 11 | + * This program is distributed in the hope that it will be useful, but |
| 12 | + * WITHOUT ANY WARRANTY; without even the implied warranty of |
| 13 | + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 14 | + * GNU General Public License for more details. |
| 15 | + * |
| 16 | + * You should have received a copy of the GNU General Public License |
| 17 | + * along with this program; if not, see <http://www.gnu.org/licenses>. |
| 18 | + * |
| 19 | + * Additional permission under GNU GPL version 3 section 7: |
| 20 | + * |
| 21 | + * KNIME interoperates with ECLIPSE solely via ECLIPSE's plug-in APIs. |
| 22 | + * Hence, KNIME and ECLIPSE are both independent programs and are not |
| 23 | + * derived from each other. Should, however, the interpretation of the |
| 24 | + * GNU GPL Version 3 ("License") under any applicable laws result in |
| 25 | + * KNIME and ECLIPSE being a combined program, KNIME GMBH herewith grants |
| 26 | + * you the additional permission to use and propagate KNIME together with |
| 27 | + * ECLIPSE with only the license terms in place for ECLIPSE applying to |
| 28 | + * ECLIPSE and the GNU GPL Version 3 applying for KNIME, provided the |
| 29 | + * license terms of ECLIPSE themselves allow for the respective use and |
| 30 | + * propagation of ECLIPSE together with KNIME. |
| 31 | + * |
| 32 | + * Additional permission relating to nodes for KNIME that extend the Node |
| 33 | + * Extension (and in particular that are based on subclasses of NodeModel, |
| 34 | + * NodeDialog, and NodeView) and that only interoperate with KNIME through |
| 35 | + * standard APIs ("Nodes"): |
| 36 | + * Nodes are deemed to be separate and independent programs and to not be |
| 37 | + * covered works. Notwithstanding anything to the contrary in the |
| 38 | + * License, the License does not apply to Nodes, you are not required to |
| 39 | + * license Nodes under the License, and you are granted a license to |
| 40 | + * prepare and propagate Nodes, in each case even if such Nodes are |
| 41 | + * propagated with or for interoperation with KNIME. The owner of a Node |
| 42 | + * may freely choose the license terms applicable to such Node, including |
| 43 | + * when such Node is propagated with or for interoperation with KNIME. |
| 44 | + * --------------------------------------------------------------------- |
| 45 | + * |
| 46 | + * History |
| 47 | + * May 23, 2017 (marcel): created |
| 48 | + */ |
| 49 | +package org.knime.dl.keras.testing; |
| 50 | + |
| 51 | +import java.io.IOException; |
| 52 | +import java.net.MalformedURLException; |
| 53 | +import java.net.URL; |
| 54 | +import java.nio.file.InvalidPathException; |
| 55 | +import java.util.Collections; |
| 56 | +import java.util.HashMap; |
| 57 | +import java.util.Map; |
| 58 | +import java.util.Set; |
| 59 | +import java.util.function.Consumer; |
| 60 | + |
| 61 | +import org.junit.Test; |
| 62 | +import org.knime.core.util.FileUtil; |
| 63 | +import org.knime.dl.core.DLLayerData; |
| 64 | +import org.knime.dl.core.DLLayerDataSpec; |
| 65 | +import org.knime.dl.core.data.writables.DLWritableFloatBuffer; |
| 66 | +import org.knime.dl.core.execution.DLLayerDataInput; |
| 67 | +import org.knime.dl.core.execution.DLLayerDataOutput; |
| 68 | +import org.knime.dl.keras.core.DLKerasDefaultBackend; |
| 69 | +import org.knime.dl.keras.core.DLKerasExecutableNetwork; |
| 70 | +import org.knime.dl.keras.core.DLKerasExecutableNetwork.DLKerasExecutableNetworkSpec; |
| 71 | +import org.knime.dl.keras.core.DLKerasNetwork; |
| 72 | +import org.knime.dl.keras.core.io.DLKerasNetworkReader; |
| 73 | +import org.knime.dl.util.DLUtils; |
| 74 | + |
| 75 | +/** |
| 76 | + * |
| 77 | + * @author Marcel Wiedenmann, KNIME, Konstanz, Germany |
| 78 | + * @author Christian Dietz, KNIME, Konstanz, Germany |
| 79 | + */ |
| 80 | +public class DLKerasNetworkExecutor1To1Test { |
| 81 | + |
| 82 | + private static final String BUNDLE_ID = "org.knime.dl.keras.testing"; |
| 83 | + |
| 84 | + @Test |
| 85 | + public void test() throws IOException, InvalidPathException, MalformedURLException { |
| 86 | + final URL source = |
| 87 | + FileUtil.toURL(DLUtils.Files.getFileFromBundle(BUNDLE_ID, "data/my_2d_input_model.h5").getAbsolutePath()); |
| 88 | + |
| 89 | + final DLKerasDefaultBackend backend = new DLKerasDefaultBackend(); |
| 90 | + final DLKerasNetworkReader reader = backend.createReader(); |
| 91 | + DLKerasNetwork network; |
| 92 | + try { |
| 93 | + network = reader.readNetwork(source); |
| 94 | + } catch (IllegalArgumentException | IOException e) { |
| 95 | + throw new RuntimeException(e); |
| 96 | + } |
| 97 | + final DLKerasExecutableNetwork execNetwork = backend.toExecutableNetwork(network); |
| 98 | + final DLKerasExecutableNetworkSpec execSpec = execNetwork.getSpec(); |
| 99 | + |
| 100 | + final HashMap<DLLayerDataSpec, ?> inputs = new HashMap<>(execSpec.getInputSpecs().length); |
| 101 | + for (final DLLayerDataSpec inputSpec : execSpec.getInputSpecs()) { |
| 102 | + final DLLayerDataInput<?> input = execNetwork.getInputForSpec(inputSpec, 1); |
| 103 | + populate(input.getBatch()[0]); |
| 104 | + } |
| 105 | + final Set<DLLayerDataSpec> selectedOutputs = Collections.singleton(execSpec.getOutputSpecs()[0]); |
| 106 | + |
| 107 | + execNetwork.execute(selectedOutputs, new Consumer<Map<DLLayerDataSpec, DLLayerDataOutput<?>>>() { |
| 108 | + |
| 109 | + @Override |
| 110 | + public void accept(final Map<DLLayerDataSpec, DLLayerDataOutput<?>> t) { |
| 111 | + // TODO: test against known results - this is sth. that should rather be tested via a test workflow |
| 112 | + } |
| 113 | + }); |
| 114 | + } |
| 115 | + |
| 116 | + private static void populate(final DLLayerData<?> data) { |
| 117 | + if (data.getBuffer() instanceof DLWritableFloatBuffer) { |
| 118 | + final DLWritableFloatBuffer buffer = (DLWritableFloatBuffer)data.getBuffer(); |
| 119 | + buffer.resetWrite(); |
| 120 | + for (int i = 0; i < buffer.getCapacity(); i++) { |
| 121 | + buffer.put(5f); |
| 122 | + } |
| 123 | + } else { |
| 124 | + throw new IllegalStateException("Unexpected input buffer type."); |
| 125 | + } |
| 126 | + } |
| 127 | +} |
0 commit comments