TensorFlow 2 – Changing the Datatype of a Tensor

Last updated on August 30, 2021, by A Goodman

In TensorFlow 2, you can convert a tensor to a different datatype with the tf.cast function. It returns a new tensor with the requested dtype and leaves the original tensor unchanged. The examples below show how it works.

Example 1

The code:

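import tensorflow as tf

# A 2x2 tensor stored in half precision (float16)
x1 = tf.constant([
    [1.1, 2.2],
    [3.3, 4.4]
], dtype=tf.float16)

# Cast to single precision; x1 itself is not modified
y1 = tf.cast(x1, dtype=tf.float32)
print(y1.dtype)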

Output:

<dtype: 'float32'>
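
Casting to a wider type does not restore precision that was already lost. As a small sketch (the variable names here are illustrative and not part of the example above), the value 1.1 is rounded when it is stored as float16, and the float32 result carries that rounded value:

```python
import tensorflow as tf

# 1.1 cannot be represented exactly in float16, so it is rounded on creation
half = tf.constant(1.1, dtype=tf.float16)

# Widening to float32 keeps the already-rounded float16 value
widened = tf.cast(half, tf.float32)

# For comparison: 1.1 stored directly as float32
direct = tf.constant(1.1, dtype=tf.float32)

print(widened.numpy())  # close to 1.1, but not the same as `direct`
print(direct.numpy())
print(bool(widened == direct))  # False
```

If you need full float32 precision, create the tensor as float32 from the start rather than casting up from float16.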

Example 2

The code:

import tensorflow as tf
import numpy as np

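# Random values in the range [0, 100)
array = np.random.rand(3, 5)
array = np.multiply(array, 100)
tensor = tf.constant(array)  # dtype = float64 (NumPy's default float type)
# Casting float to int drops the fractional part
tensor = tf.cast(tensor, tf.int32)
print(tensor)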

Output (the values are random, so yours will differ):

tf.Tensor(
[[42 94 25 57 19]
 [41 44 90 71 45]
 [30 32 43 87 19]], shape=(3, 5), dtype=int32)
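
Note that a float-to-integer cast truncates toward zero rather than rounding to the nearest integer. If you want rounding instead, one option is to apply tf.round before the cast. A minimal sketch with made-up values:

```python
import tensorflow as tf

values = tf.constant([1.9, -1.9, 2.4], dtype=tf.float32)

# A plain cast discards the fractional part (truncation toward zero)
truncated = tf.cast(values, tf.int32)

# Rounding first gives the nearest integer before the cast
rounded = tf.cast(tf.round(values), tf.int32)

print(truncated.numpy())  # [ 1 -1  2]
print(rounded.numpy())    # [ 2 -2  2]
```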

You can find more details about the tf.cast function in TensorFlow’s official API docs. Happy coding.
