Skip to content

Instantly share code, notes, and snippets.

@thomasjpfan
Created February 27, 2023 21:18
Show Gist options
  • Select an option

  • Save thomasjpfan/4e02ec23f7b63b0797cc145ae18d76db to your computer and use it in GitHub Desktop.

Select an option

Save thomasjpfan/4e02ec23f7b63b0797cc145ae18d76db to your computer and use it in GitHub Desktop.
diff --git a/sklearn/utils/_weight_vector.pxd.tp b/sklearn/utils/_weight_vector.pxd.tp
index 9d1779373c..f5e3e4af5a 100644
--- a/sklearn/utils/_weight_vector.pxd.tp
+++ b/sklearn/utils/_weight_vector.pxd.tp
@@ -27,11 +27,12 @@ cdef class WeightVector{{name_suffix}}(object):
cdef readonly {{c_type}}[::1] aw
cdef {{c_type}} *w_data_ptr
cdef {{c_type}} *aw_data_ptr
- cdef {{c_type}} wscale
- cdef {{c_type}} average_a
- cdef {{c_type}} average_b
+
+ cdef double wscale
+ cdef double average_a
+ cdef double average_b
cdef int n_features
- cdef {{c_type}} sq_norm
+ cdef double sq_norm
cdef void add(self, {{c_type}} *x_data_ptr, int *x_ind_ptr,
int xnnz, {{c_type}} c) noexcept nogil
diff --git a/sklearn/utils/_weight_vector.pyx.tp b/sklearn/utils/_weight_vector.pyx.tp
index e2d374813a..caa992d4ad 100644
--- a/sklearn/utils/_weight_vector.pyx.tp
+++ b/sklearn/utils/_weight_vector.pyx.tp
@@ -99,8 +99,8 @@ cdef class WeightVector{{name_suffix}}(object):
cdef int j
cdef int idx
cdef {{c_type}} val
- cdef {{c_type}} innerprod = 0.0
- cdef {{c_type}} xsqnorm = 0.0
+ cdef double innerprod = 0.0
+ cdef double xsqnorm = 0.0
# the next two lines save a factor of 2!
cdef {{c_type}} wscale = self.wscale
@@ -139,8 +139,8 @@ cdef class WeightVector{{name_suffix}}(object):
cdef int idx
cdef {{c_type}} val
cdef {{c_type}} mu = 1.0 / num_iter
- cdef {{c_type}} average_a = self.average_a
- cdef {{c_type}} wscale = self.wscale
+ cdef double average_a = self.average_a
+ cdef double wscale = self.wscale
cdef {{c_type}}* aw_data_ptr = self.aw_data_ptr
for j in range(xnnz):
@@ -174,7 +174,7 @@ cdef class WeightVector{{name_suffix}}(object):
"""
cdef int j
cdef int idx
- cdef {{c_type}} innerprod = 0.0
+ cdef double innerprod = 0.0
cdef {{c_type}}* w_data_ptr = self.w_data_ptr
for j in range(xnnz):
idx = x_ind_ptr[j]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment