Skip to content

Instantly share code, notes, and snippets.

@jackalcooper
Created February 20, 2022 02:44
Show Gist options
  • Save jackalcooper/1dfefbe6dbc8c8a2138d92082d20f448 to your computer and use it in GitHub Desktop.
"""Demo: lower a nested-``pmap`` gradient computation to MHLO IR.

Forces 8 virtual host (CPU) devices via ``XLA_FLAGS`` so the nested
``pmap`` (2 outer x 2 inner replicas) can run on one machine, then
prints the MHLO produced by ``jit(f).lower(x)``.
"""
import os

# Must be set before JAX initializes its XLA backend (i.e. before the
# first JAX computation), otherwise the flag is silently ignored.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax.numpy as jnp
from jax import grad, jit, pmap, random, vmap  # noqa: F401 (random/vmap kept from original)

x = jnp.ones((2, 2))


@pmap
def f(x):
    """Per-replica: differentiate the sum of the inner-pmap result g(x)."""
    y = jnp.sin(x)

    @pmap
    def g(z):
        # Closes over both the per-replica slice ``y`` and the outer ``x``.
        return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()

    # Gradient of w -> sum(g(w)) evaluated at this replica's input.
    return grad(lambda w: jnp.sum(g(w)))(x)


if __name__ == "__main__":
    # NOTE: the 'mhlo' dialect name is what older jax versions accept;
    # newer releases expose 'stablehlo' instead.
    print(jit(f).lower(x).compiler_ir(dialect='mhlo'))
@jackalcooper
Copy link
Author

// MHLO module produced by lowering the jitted nested-pmap function `f`.
// Four replicas total (2 outer x 2 inner). The recurring
// replica_id / divide / remainder triple computes this replica's index
// along one pmap axis; that index drives dynamic-slice (read my shard)
// and dynamic-update-slice + cross-replica-sum (write my shard back and
// all-reduce it across the replica group).
module @jit_f.1 {
  // Entry point: per-replica forward pass of f plus its hand-unrolled
  // reverse-mode gradient, returning a 2x2 gradient tensor.
  func public @main(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
    %0 = mhlo.constant dense<0> : tensor<ui32>
    %1 = mhlo.constant dense<2> : tensor<ui32>
    %2 = mhlo.constant dense<2> : tensor<ui32>
    // Outer pmap: replica index along axis 0 = (replica_id / 2) % 2.
    %3 = "mhlo.replica_id"() : () -> tensor<ui32>
    %4 = mhlo.divide %3, %1 : tensor<ui32>
    %5 = mhlo.remainder %4, %2 : tensor<ui32>
    // Select this replica's row of the 2x2 input (outer pmap shard).
    %6 = "mhlo.dynamic-slice"(%arg0, %5, %0) {slice_sizes = dense<[1, 2]> : tensor<2xi64>} : (tensor<2x2xf32>, tensor<ui32>, tensor<ui32>) -> tensor<1x2xf32>
    %7 = "mhlo.reshape"(%6) : (tensor<1x2xf32>) -> tensor<2xf32>
    // y = sin(x) from the Python source.
    %8 = "mhlo.sine"(%7) : (tensor<2xf32>) -> tensor<2xf32>
    %9 = mhlo.constant dense<0> : tensor<ui32>
    %10 = mhlo.constant dense<1> : tensor<ui32>
    %11 = mhlo.constant dense<2> : tensor<ui32>
    // Inner pmap: replica index = (replica_id / 1) % 2.
    %12 = "mhlo.replica_id"() : () -> tensor<ui32>
    %13 = mhlo.divide %12, %10 : tensor<ui32>
    %14 = mhlo.remainder %13, %11 : tensor<ui32>
    // Select this inner replica's scalar element z.
    %15 = "mhlo.dynamic-slice"(%7, %14) {slice_sizes = dense<1> : tensor<1xi64>} : (tensor<2xf32>, tensor<ui32>) -> tensor<1xf32>
    %16 = "mhlo.reshape"(%15) : (tensor<1xf32>) -> tensor<f32>
    // cos(z) for the primal, sin(z) saved for the cosine's gradient.
    %17 = "mhlo.cosine"(%16) : (tensor<f32>) -> tensor<f32>
    %18 = "mhlo.sine"(%16) : (tensor<f32>) -> tensor<f32>
    %19 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    // y.sum() over the per-replica slice.
    %20 = mhlo.reduce(%8 init: %19) across dimensions = [0] : (tensor<2xf32>, tensor<f32>) -> tensor<f32>
     reducer(%arg1: tensor<f32>, %arg2: tensor<f32>)  {
      %133 = mhlo.add %arg1, %arg2 : tensor<f32>
      "mhlo.return"(%133) : (tensor<f32>) -> ()
    }
    // jnp.tan lowers to the fallback helper defined below.
    %21 = call @tan(%20) : (tensor<f32>) -> tensor<f32>
    %22 = mhlo.multiply %17, %21 : tensor<f32>
    // tanh(x).sum() factor of g.
    %23 = "mhlo.tanh"(%7) : (tensor<2xf32>) -> tensor<2xf32>
    %24 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %25 = mhlo.reduce(%23 init: %24) across dimensions = [0] : (tensor<2xf32>, tensor<f32>) -> tensor<f32>
     reducer(%arg1: tensor<f32>, %arg2: tensor<f32>)  {
      %133 = mhlo.add %arg1, %arg2 : tensor<f32>
      "mhlo.return"(%133) : (tensor<f32>) -> ()
    }
    // Primal output of g for this inner replica: cos(z)*tan(sum y)*sum tanh(x).
    %26 = mhlo.multiply %22, %25 : tensor<f32>
    %27 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %28 = "mhlo.broadcast"(%27) {broadcast_sizes = dense<2> : tensor<1xi64>} : (tensor<f32>) -> tensor<2xf32>
    %29 = mhlo.constant dense<0> : tensor<ui32>
    %30 = mhlo.constant dense<1> : tensor<ui32>
    %31 = mhlo.constant dense<2> : tensor<ui32>
    %32 = "mhlo.replica_id"() : () -> tensor<ui32>
    %33 = mhlo.divide %32, %30 : tensor<ui32>
    %34 = mhlo.remainder %33, %31 : tensor<ui32>
    // Un-shard the inner pmap: each replica writes its scalar into a
    // zero vector, then cross-replica-sum gathers the full 2-vector.
    %35 = "mhlo.broadcast"(%26) {broadcast_sizes = dense<1> : tensor<1xi64>} : (tensor<f32>) -> tensor<1xf32>
    %36 = "mhlo.dynamic-update-slice"(%28, %35, %34) : (tensor<2xf32>, tensor<1xf32>, tensor<ui32>) -> tensor<2xf32>
    %37 = "mhlo.cross-replica-sum"(%36) {replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>} : (tensor<2xf32>) -> tensor<2xf32>
    // Same gather pattern for the residuals needed by the backward pass:
    // %48 <- sum(tanh(x)), %59 <- tan(sum y), %70 <- sin(z).
    %38 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %39 = "mhlo.broadcast"(%38) {broadcast_sizes = dense<2> : tensor<1xi64>} : (tensor<f32>) -> tensor<2xf32>
    %40 = mhlo.constant dense<0> : tensor<ui32>
    %41 = mhlo.constant dense<1> : tensor<ui32>
    %42 = mhlo.constant dense<2> : tensor<ui32>
    %43 = "mhlo.replica_id"() : () -> tensor<ui32>
    %44 = mhlo.divide %43, %41 : tensor<ui32>
    %45 = mhlo.remainder %44, %42 : tensor<ui32>
    %46 = "mhlo.broadcast"(%25) {broadcast_sizes = dense<1> : tensor<1xi64>} : (tensor<f32>) -> tensor<1xf32>
    %47 = "mhlo.dynamic-update-slice"(%39, %46, %45) : (tensor<2xf32>, tensor<1xf32>, tensor<ui32>) -> tensor<2xf32>
    %48 = "mhlo.cross-replica-sum"(%47) {replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>} : (tensor<2xf32>) -> tensor<2xf32>
    %49 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %50 = "mhlo.broadcast"(%49) {broadcast_sizes = dense<2> : tensor<1xi64>} : (tensor<f32>) -> tensor<2xf32>
    %51 = mhlo.constant dense<0> : tensor<ui32>
    %52 = mhlo.constant dense<1> : tensor<ui32>
    %53 = mhlo.constant dense<2> : tensor<ui32>
    %54 = "mhlo.replica_id"() : () -> tensor<ui32>
    %55 = mhlo.divide %54, %52 : tensor<ui32>
    %56 = mhlo.remainder %55, %53 : tensor<ui32>
    %57 = "mhlo.broadcast"(%21) {broadcast_sizes = dense<1> : tensor<1xi64>} : (tensor<f32>) -> tensor<1xf32>
    %58 = "mhlo.dynamic-update-slice"(%50, %57, %56) : (tensor<2xf32>, tensor<1xf32>, tensor<ui32>) -> tensor<2xf32>
    %59 = "mhlo.cross-replica-sum"(%58) {replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>} : (tensor<2xf32>) -> tensor<2xf32>
    %60 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %61 = "mhlo.broadcast"(%60) {broadcast_sizes = dense<2> : tensor<1xi64>} : (tensor<f32>) -> tensor<2xf32>
    %62 = mhlo.constant dense<0> : tensor<ui32>
    %63 = mhlo.constant dense<1> : tensor<ui32>
    %64 = mhlo.constant dense<2> : tensor<ui32>
    %65 = "mhlo.replica_id"() : () -> tensor<ui32>
    %66 = mhlo.divide %65, %63 : tensor<ui32>
    %67 = mhlo.remainder %66, %64 : tensor<ui32>
    %68 = "mhlo.broadcast"(%18) {broadcast_sizes = dense<1> : tensor<1xi64>} : (tensor<f32>) -> tensor<1xf32>
    %69 = "mhlo.dynamic-update-slice"(%61, %68, %67) : (tensor<2xf32>, tensor<1xf32>, tensor<ui32>) -> tensor<2xf32>
    %70 = "mhlo.cross-replica-sum"(%69) {replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>} : (tensor<2xf32>) -> tensor<2xf32>
    // Forward scalar loss: jnp.sum(g(w)) over the gathered vector %37.
    %71 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %72 = mhlo.reduce(%37 init: %71) across dimensions = [0] : (tensor<2xf32>, tensor<f32>) -> tensor<f32>
     reducer(%arg1: tensor<f32>, %arg2: tensor<f32>)  {
      %133 = mhlo.add %arg1, %arg2 : tensor<f32>
      "mhlo.return"(%133) : (tensor<f32>) -> ()
    }
    // Backward pass: cotangent 1.0 broadcast through the sum...
    %73 = mhlo.constant dense<1.000000e+00> : tensor<f32>
    %74 = "mhlo.broadcast_in_dim"(%73) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<2xf32>
    // ...then each inner replica re-slices its own residuals out of the
    // gathered vectors (%48, %59, %70, %74).
    %75 = mhlo.constant dense<0> : tensor<ui32>
    %76 = mhlo.constant dense<1> : tensor<ui32>
    %77 = mhlo.constant dense<2> : tensor<ui32>
    %78 = "mhlo.replica_id"() : () -> tensor<ui32>
    %79 = mhlo.divide %78, %76 : tensor<ui32>
    %80 = mhlo.remainder %79, %77 : tensor<ui32>
    %81 = "mhlo.dynamic-slice"(%48, %80) {slice_sizes = dense<1> : tensor<1xi64>} : (tensor<2xf32>, tensor<ui32>) -> tensor<1xf32>
    %82 = "mhlo.reshape"(%81) : (tensor<1xf32>) -> tensor<f32>
    %83 = mhlo.constant dense<0> : tensor<ui32>
    %84 = mhlo.constant dense<1> : tensor<ui32>
    %85 = mhlo.constant dense<2> : tensor<ui32>
    %86 = "mhlo.replica_id"() : () -> tensor<ui32>
    %87 = mhlo.divide %86, %84 : tensor<ui32>
    %88 = mhlo.remainder %87, %85 : tensor<ui32>
    %89 = "mhlo.dynamic-slice"(%59, %88) {slice_sizes = dense<1> : tensor<1xi64>} : (tensor<2xf32>, tensor<ui32>) -> tensor<1xf32>
    %90 = "mhlo.reshape"(%89) : (tensor<1xf32>) -> tensor<f32>
    %91 = mhlo.constant dense<0> : tensor<ui32>
    %92 = mhlo.constant dense<1> : tensor<ui32>
    %93 = mhlo.constant dense<2> : tensor<ui32>
    %94 = "mhlo.replica_id"() : () -> tensor<ui32>
    %95 = mhlo.divide %94, %92 : tensor<ui32>
    %96 = mhlo.remainder %95, %93 : tensor<ui32>
    %97 = "mhlo.dynamic-slice"(%70, %96) {slice_sizes = dense<1> : tensor<1xi64>} : (tensor<2xf32>, tensor<ui32>) -> tensor<1xf32>
    %98 = "mhlo.reshape"(%97) : (tensor<1xf32>) -> tensor<f32>
    %99 = mhlo.constant dense<0> : tensor<ui32>
    %100 = mhlo.constant dense<1> : tensor<ui32>
    %101 = mhlo.constant dense<2> : tensor<ui32>
    %102 = "mhlo.replica_id"() : () -> tensor<ui32>
    %103 = mhlo.divide %102, %100 : tensor<ui32>
    %104 = mhlo.remainder %103, %101 : tensor<ui32>
    %105 = "mhlo.dynamic-slice"(%74, %104) {slice_sizes = dense<1> : tensor<1xi64>} : (tensor<2xf32>, tensor<ui32>) -> tensor<1xf32>
    %106 = "mhlo.reshape"(%105) : (tensor<1xf32>) -> tensor<f32>
    // d/dz [cos(z)*tan*tanh_sum] = -sin(z)*tan*tanh_sum (chain rule).
    %107 = mhlo.multiply %106, %82 : tensor<f32>
    %108 = mhlo.multiply %107, %90 : tensor<f32>
    %109 = "mhlo.negate"(%108) : (tensor<f32>) -> tensor<f32>
    %110 = mhlo.multiply %109, %98 : tensor<f32>
    // Gather per-element gradients back into a 2-vector (inner pmap)...
    %111 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %112 = "mhlo.broadcast"(%111) {broadcast_sizes = dense<2> : tensor<1xi64>} : (tensor<f32>) -> tensor<2xf32>
    %113 = mhlo.constant dense<0> : tensor<ui32>
    %114 = mhlo.constant dense<1> : tensor<ui32>
    %115 = mhlo.constant dense<2> : tensor<ui32>
    %116 = "mhlo.replica_id"() : () -> tensor<ui32>
    %117 = mhlo.divide %116, %114 : tensor<ui32>
    %118 = mhlo.remainder %117, %115 : tensor<ui32>
    %119 = "mhlo.broadcast"(%110) {broadcast_sizes = dense<1> : tensor<1xi64>} : (tensor<f32>) -> tensor<1xf32>
    %120 = "mhlo.dynamic-update-slice"(%112, %119, %118) : (tensor<2xf32>, tensor<1xf32>, tensor<ui32>) -> tensor<2xf32>
    %121 = "mhlo.cross-replica-sum"(%120) {replica_groups = dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>} : (tensor<2xf32>) -> tensor<2xf32>
    // ...then into the 2x2 result (outer pmap; note the transposed
    // replica_groups [[0,2],[1,3]] for the outer axis).
    %122 = mhlo.constant dense<0.000000e+00> : tensor<f32>
    %123 = "mhlo.broadcast"(%122) {broadcast_sizes = dense<2> : tensor<2xi64>} : (tensor<f32>) -> tensor<2x2xf32>
    %124 = mhlo.constant dense<0> : tensor<ui32>
    %125 = mhlo.constant dense<2> : tensor<ui32>
    %126 = mhlo.constant dense<2> : tensor<ui32>
    %127 = "mhlo.replica_id"() : () -> tensor<ui32>
    %128 = mhlo.divide %127, %125 : tensor<ui32>
    %129 = mhlo.remainder %128, %126 : tensor<ui32>
    %130 = "mhlo.broadcast"(%121) {broadcast_sizes = dense<1> : tensor<1xi64>} : (tensor<2xf32>) -> tensor<1x2xf32>
    %131 = "mhlo.dynamic-update-slice"(%123, %130, %129, %124) : (tensor<2x2xf32>, tensor<1x2xf32>, tensor<ui32>, tensor<ui32>) -> tensor<2x2xf32>
    %132 = "mhlo.cross-replica-sum"(%131) {replica_groups = dense<[[0, 2], [1, 3]]> : tensor<2x2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
    return %132 : tensor<2x2xf32>
  }
  // jnp.tan has no direct MHLO op here; it routes through an XLA fallback.
  func private @tan(%arg0: tensor<f32>) -> tensor<f32> {
    %0 = call @xla_fallback_tan(%arg0) : (tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
  // Fallback lowering: tan(x) = sin(x) / cos(x).
  func private @xla_fallback_tan(%arg0: tensor<f32>) -> tensor<f32> {
    // NOTE(review): %0 and %1 are unused — presumably residue of the
    // generic fallback expansion; harmless dead code.
    %0 = mhlo.constant dense<false> : tensor<i1>
    %1 = mhlo.constant dense<false> : tensor<i1>
    %2 = "mhlo.sine"(%arg0) : (tensor<f32>) -> tensor<f32>
    %3 = "mhlo.cosine"(%arg0) : (tensor<f32>) -> tensor<f32>
    %4 = mhlo.divide %2, %3 : tensor<f32>
    return %4 : tensor<f32>
  }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment