From fd008687d97e91c508fc055734e6c3e3f48c244f Mon Sep 17 00:00:00 2001 From: dskkato Date: Sat, 10 Jul 2021 16:59:56 +0900 Subject: [PATCH 1/2] update addition example's python code for tf2.5 --- examples/addition.rs | 2 +- examples/addition/addition.py | 24 +++++++++++++++--------- examples/addition/model.pb | Bin 136 -> 214 bytes 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/examples/addition.rs b/examples/addition.rs index f832c4b346..0e9639b840 100644 --- a/examples/addition.rs +++ b/examples/addition.rs @@ -49,7 +49,7 @@ fn main() -> Result<(), Box> { let mut args = SessionRunArgs::new(); args.add_feed(&graph.operation_by_name_required("x")?, 0, &x); args.add_feed(&graph.operation_by_name_required("y")?, 0, &y); - let z = args.request_fetch(&graph.operation_by_name_required("z")?, 0); + let z = args.request_fetch(&graph.operation_by_name_required("Identity")?, 0); session.run(&mut args)?; // Check our results. diff --git a/examples/addition/addition.py b/examples/addition/addition.py index bc5f869e20..4a220e6b95 100644 --- a/examples/addition/addition.py +++ b/examples/addition/addition.py @@ -1,14 +1,20 @@ -# TODO: Stop using v1 compatibility -import tensorflow.compat.v1 as tf +import tensorflow as tf +# check tensorflow version is 2.x +tf_major_version = tf.__version__.split('.')[0] +assert tf_major_version == '2' -tf.disable_eager_execution() -x = tf.placeholder(tf.int32, name = 'x') -y = tf.placeholder(tf.int32, name = 'y') -z = tf.add(x, y, name = 'z') +@tf.function +def add(x, y): + return tf.add(x, y) -tf.variables_initializer(tf.global_variables(), name = 'init') +x = tf.TensorSpec((), dtype=tf.dtypes.int32, name='x') +y = tf.TensorSpec((), dtype=tf.dtypes.int32, name='y') -definition = tf.Session().graph_def +concrete_function = add.get_concrete_function(x, y) directory = 'examples/addition' -tf.train.write_graph(definition, directory, 'model.pb', as_text=False) +tf.io.write_graph(concrete_function.graph, directory, 'model.pb', as_text=False) + +# check inputs/outputs node names to refer from Rust later on +print(f'input nodes : {concrete_function.inputs}') +print(f'output nodes : {concrete_function.outputs}') \ No newline at end of file diff --git a/examples/addition/model.pb b/examples/addition/model.pb index 19330df59951193e08b1182fc8f34e8f39744f49..85135b240e89010f17910f4a9e5baa1436eb3450 100644 GIT binary patch literal 214 zcmd;b=VGi7;tt43OisH94iJcph1;H8&rwMAW1G}$CMNy5Fy1_A;nm!#m>bT01;tt43Oiseti-(IPGcU75h{Z48zd(tZ;|wbRQ9m7o From 4850f7d3f40d527cacc9dbaac7d66212bd1fb2c7 Mon Sep 17 00:00:00 2001 From: dskkato Date: Mon, 12 Jul 2021 14:37:04 +0900 Subject: [PATCH 2/2] rename add operation to z --- examples/addition.rs | 2 +- examples/addition/addition.py | 6 +----- examples/addition/model.pb | Bin 214 -> 176 bytes 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/addition.rs b/examples/addition.rs index 0e9639b840..f832c4b346 100644 --- a/examples/addition.rs +++ b/examples/addition.rs @@ -49,7 +49,7 @@ fn main() -> Result<(), Box> { let mut args = SessionRunArgs::new(); args.add_feed(&graph.operation_by_name_required("x")?, 0, &x); args.add_feed(&graph.operation_by_name_required("y")?, 0, &y); - let z = args.request_fetch(&graph.operation_by_name_required("Identity")?, 0); + let z = args.request_fetch(&graph.operation_by_name_required("z")?, 0); session.run(&mut args)?; // Check our results. diff --git a/examples/addition/addition.py b/examples/addition/addition.py index 4a220e6b95..859c91fae6 100644 --- a/examples/addition/addition.py +++ b/examples/addition/addition.py @@ -6,7 +6,7 @@ @tf.function def add(x, y): - return tf.add(x, y) + tf.add(x, y, name='z') x = tf.TensorSpec((), dtype=tf.dtypes.int32, name='x') y = tf.TensorSpec((), dtype=tf.dtypes.int32, name='y') @@ -14,7 +14,3 @@ def add(x, y): concrete_function = add.get_concrete_function(x, y) directory = 'examples/addition' tf.io.write_graph(concrete_function.graph, directory, 'model.pb', as_text=False) - -# check inputs/outputs node names to refer from Rust later on -print(f'input nodes : {concrete_function.inputs}') -print(f'output nodes : {concrete_function.outputs}') \ No newline at end of file diff --git a/examples/addition/model.pb b/examples/addition/model.pb index 85135b240e89010f17910f4a9e5baa1436eb3450..1a45a6aeae49d15a8af9a39cf9e6c35d67941163 100644 GIT binary patch delta 36 rcmcb{xPfuPWLa@8#wsCZ$CMN)#tJFMN-cIS#t