Verified Commit 6d8a9417 authored by Julien Jerphanion's avatar Julien Jerphanion
Browse files

Use tf.float32 as a dtype for tf.Tensors

parent 6da5b067
Pipeline #143235 failed with stage
in 11 minutes and 40 seconds
......@@ -415,7 +415,7 @@ class SplineSphere(SplineSurface):
self._south_pole = points[-1]
self._coefs[2 * self.Mt:-2 * self.Mt] = points[1:-1]
if tf.is_tensor(points):
self._coefs = tf.convert_to_tensor(self._coefs)
self._coefs = tf.convert_to_tensor(self._coefs, dtype=tf.float32)
self._is_initialised = False
......@@ -551,7 +551,7 @@ class SplineSphere(SplineSurface):
# Support for GPU computations
if tf.is_tensor(self._coefs):
phi = tf.convert_to_tensor(phi)
phi = tf.convert_to_tensor(phi, dtype=tf.float32)
SplineSphere._phi[key] = phi
return phi
......@@ -658,7 +658,7 @@ class SplineSphere(SplineSurface):
# GPU support
if tf.is_tensor(self._coefs):
self._coefs = tf.convert_to_tensor(reshaped_coefs)
self._coefs = tf.convert_to_tensor(reshaped_coefs, dtype=tf.float32)
else:
self._coefs = reshaped_coefs
......@@ -743,7 +743,7 @@ class SplineSphere(SplineSurface):
# GPU support
if tf.is_tensor(sample_points):
self._coefs = tf.convert_to_tensor(self._coefs)
self._coefs = tf.convert_to_tensor(self._coefs, dtype=tf.float32)
vs0 = self._basis_s.value(0.0)
vs1 = self._basis_s.value(1.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment