Created
January 24, 2019 17:40
-
-
Save aeftimia/a5249168c84bc541ace2fc4e1d22a13e to your computer and use it in GitHub Desktop.
Keras style orthogonality constraint
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy | |
| import tensorflow | |
| from keras.constraints import Constraint | |
| from tensorflow.linalg import expm | |
class Orthogonal(Constraint):
    """Orthogonal weight constraint.

    Constrains the weights incident to each hidden unit
    to be orthogonal when there are more inputs than hidden units.
    When there are more hidden units than there are inputs,
    the rows of the layer's weight matrix are constrained
    to be orthogonal.

    How the constraint is built:

    1. Each weight block is viewed as a matrix whose rows hold the weights
       incident to one output unit.
    2. That matrix is zero-padded to a square.
    3. The strictly upper triangle of the square matrix generates an
       antisymmetric matrix `A - A^T`. Its matrix exponential is always
       orthogonal.
    4. The diagonal of the padded matrix does not enter the generator. When
       `orthonormal` is `False`, it supplies one scale factor per orthogonal
       vector.

    # Arguments
        axis: Axis or axes spanning each weight vector to be orthogonalized.
            `None` to use all but the last (output) axis. Negative values
            count from the end, but the last (output) axis itself may not
            be included.
            For instance, in a `Dense` layer the weight matrix
            has shape `(input_dim, output_dim)`,
            set `axis` to `0` to constrain each weight vector
            of length `(input_dim,)`.
            In a `Conv2D` layer with `data_format="channels_last"`,
            the weight tensor has shape
            `(rows, cols, input_depth, output_depth)`,
            set `axis` to `[0, 1, 2]`
            to constrain the weights of each filter tensor of size
            `(rows, cols, input_depth)`.
        orthonormal: If `True`, the weight matrix is further
            constrained to be orthonormal along the appropriate axis.

    # Raises
        ValueError: when called on a tensor for which `axis` is out of
            range, repeats an axis, or includes the last (output) axis.
    """

    def __init__(self, axis=None, orthonormal=False):
        # `axis` is stored exactly as given, so `get_config` round-trips it
        # and the constraint can be applied to tensors of different rank.
        self.axis = axis
        self.orthonormal = orthonormal

    def _resolve_axes(self, ndim):
        """Return the constrained axes as non-negative ints for a rank-`ndim` tensor."""
        if self.axis is None:
            return list(range(ndim - 1))
        if isinstance(self.axis, (int, numpy.integer)):
            requested = [self.axis]
        else:
            requested = list(self.axis)
        axes = []
        for a in requested:
            a = int(a)
            if not -ndim <= a < ndim:
                raise ValueError(
                    'axis %d is out of range for a rank-%d weight tensor' % (a, ndim))
            axes.append(a % ndim)
        if len(set(axes)) != len(axes):
            raise ValueError('axis contains repeated entries: %r' % (self.axis,))
        if ndim - 1 in axes:
            raise ValueError('axis may not include the last (output) axis')
        return axes

    def __call__(self, w):
        ndim = len(w.shape)
        axes = self._resolve_axes(ndim)
        axis_shape = [int(w.shape[a]) for a in axes]
        # Move the unconstrained axes to the front, then the constrained
        # axes, then the output axis. Each slice along the new leading
        # dimension is then one independent (axis dims..., output) block.
        perm = [i for i in range(ndim - 1) if i not in axes]
        perm.extend(axes)
        perm.append(ndim - 1)
        w = tensorflow.transpose(w, perm=perm)
        transposed_shape = w.shape
        w = tensorflow.reshape(w, [-1] + axis_shape + [transposed_shape[-1]])
        w = tensorflow.map_fn(self.orthogonalize, w)
        w = tensorflow.reshape(w, transposed_shape)
        # Undo the permutation so the result has the caller's layout.
        return tensorflow.transpose(w, perm=numpy.argsort(perm))

    def orthogonalize(self, w):
        """Map one weight block of shape (axis dims..., output) onto an orthogonal one.

        The returned tensor has the same shape as `w`.
        """
        shape = w.shape
        output_shape = int(shape[-1])
        input_shape = int(numpy.prod(shape[:-1]))
        final_shape = max(input_shape, output_shape)
        # View the block as an (input, output) matrix, then transpose it so
        # each row holds the weights incident to one output unit. This must
        # be a transpose: reshaping the data to (output, input) would mix
        # entries belonging to different units.
        w_matrix = tensorflow.transpose(
            tensorflow.reshape(w, (input_shape, output_shape)))
        # Zero-pad to a square matrix so a square rotation can be generated.
        w_matrix = tensorflow.pad(
            w_matrix,
            [[0, final_shape - output_shape], [0, final_shape - input_shape]])
        # U - U^T is antisymmetric, and the exponential of an antisymmetric
        # matrix is orthogonal. The diagonal cancels out of U - U^T.
        upper_triangular = tensorflow.linalg.band_part(w_matrix, 0, -1)
        antisymmetric = upper_triangular - tensorflow.transpose(upper_triangular)
        rotation = expm(antisymmetric)
        # The top-left (output, input) corner of a square orthogonal matrix
        # has orthonormal rows when output <= input, and orthonormal columns
        # otherwise.
        w_final = rotation[:output_shape, :input_shape]
        if not self.orthonormal:
            # The unused diagonal of w_matrix provides the scale factors.
            # Scale the orthonormal vectors themselves (rows when
            # output <= input, columns otherwise). Scaling along the other
            # dimension would destroy their mutual orthogonality.
            scales = tensorflow.linalg.diag_part(w_matrix)
            if input_shape >= output_shape:
                w_final = scales[:output_shape, None] * w_final
            else:
                w_final = w_final * scales[None, :input_shape]
        # Back to (input, output) orientation, then to the block's shape.
        return tensorflow.reshape(tensorflow.transpose(w_final), shape)

    def get_config(self):
        # Return plain Python ints so the config stays JSON-serializable.
        axis = self.axis
        if isinstance(axis, numpy.integer):
            axis = int(axis)
        elif axis is not None and not isinstance(axis, int):
            axis = [int(a) for a in axis]
        return {'axis': axis,
                'orthonormal': self.orthonormal}
Author
Please feel free to use this code however you want.
What's the licence?
Could you please give me an example of how to use this code? Many thanks!
Thanks for this code!
A question: in most papers (among the ones I have read, at least), I usually see the orthogonality constraint expressed as a regularization term like this:
$$\frac{C}{2}\,\bigl\lVert W^\top W - I \bigr\rVert$$
So why do I need expm? Shouldn't the following code be sufficient? Any explanation/references are greatly appreciated. :)
```python
C = 1e-3
def orthogonal_reg(w): # 1703.01827
    units = w.shape[-1]
    w = tf.reshape(w, (-1, units))
    w = tf.transpose(w) @ w
    return (C/2)*tf.linalg.norm(w - tf.eye(units))
```
Please feel free to use this code however you want.
I also have a question about which paper to cite for this method, as @lucasdavid also mentioned.
Hi,
Thank you, it's very useful. Does it work with complex numbers?
Author
Yes, the technique should generalize
…On Sat, Apr 22, 2023 at 8:58 AM payami20 ***@***.***> wrote:
***@***.**** commented on this gist.
------------------------------
Hi,
Thank you it's so useful. Does it work with complex numbers?
—
Reply to this email directly, view it on GitHub
<https://gist.github.com/aeftimia/a5249168c84bc541ace2fc4e1d22a13e#gistcomment-4545231>
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/ABDLXXVWWNVX5WDKFT3TVTLXCP5ZTBFKMF2HI4TJMJ2XIZLTSKBKK5TBNR2WLJDHNFZXJJDOMFWWLK3UNBZGKYLEL52HS4DFQKSXMYLMOVS2I5DSOVS2I3TBNVS3W5DIOJSWCZC7OBQXE5DJMNUXAYLOORPWCY3UNF3GS5DZVRZXKYTKMVRXIX3UPFYGLK2HNFZXIQ3PNVWWK3TUUZ2G64DJMNZZDAVEOR4XAZNEM5UXG5FFOZQWY5LFVA4TIMRWGAYDKOFHORZGSZ3HMVZKMY3SMVQXIZI>
.
You are receiving this email because you authored the thread.
Triage notifications on the go with GitHub Mobile for iOS
<https://apps.apple.com/app/apple-store/id1477376905?ct=notification-email&mt=8&pt=524675>
or Android
<https://play.google.com/store/apps/details?id=com.github.android&referrer=utm_campaign%3Dnotification-email%26utm_medium%3Demail%26utm_source%3Dgithub>
.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
What's the licence?