Skip to content

Instantly share code, notes, and snippets.

@AmosLewis
Created December 17, 2022 20:36
Show Gist options
  • Save AmosLewis/03e551db516fb3d789e0b322bed71336 to your computer and use it in GitHub Desktop.
Save AmosLewis/03e551db516fb3d789e0b322bed71336 to your computer and use it in GitHub Desktop.
// Standalone function containing a single torch.aten.gather on tensors whose
// dimensions are all dynamic (`?`): gathers from %arg0 (f32) along dim -1 using
// the si64 indices in %arg1, with the trailing bool operand set to false
// (presumably `sparse_grad` per the aten::gather schema
// (self, dim, index, sparse_grad) — confirm).
// NOTE(review): PyTorch's gather requires `index` to have the same number of
// dimensions as `input`, and the result takes the shape of `index`. Here %arg1
// is rank 4 while %arg0 and the result are rank 3 — confirm this is intentional
// (e.g. a negative/repro case) rather than a typo for `[?,?,?]`.
func.func @torch.aten.gather(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?],f32> {
// dim = -1: gather along the last dimension of %arg0.
%int-1 = torch.constant.int -1
// Trailing gather operand (sparse_grad) = false.
%false = torch.constant.bool false
%0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.vtensor<[?,?,?,?],si64>, !torch.bool -> !torch.vtensor<[?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?],f32>
}
@AmosLewis
Copy link
Author

// Bail out when the size of dimension `axis_val` is negative, i.e. (per the
// message) a dynamic dimension that this lowering does not support.
// NOTE(review): the InFlightDiagnostic returned by emitOpError is discarded and
// no `return` follows in this excerpt, so control presumably continues past the
// check after the error is emitted. Confirm that the enclosing function actually
// returns failure here (e.g. `return op->emitOpError(...)`, or emit and then
// `return std::nullopt`).
// NOTE(review): the message names convertReduceMean — confirm it matches the
// enclosing function if this guard was copied from elsewhere.
if (inputShape[axis_val] < 0)
      op->emitOpError("Failed convertReduceMean: support for dynamic input "
                      "shape not implemented");

@AmosLewis
Copy link
Author

AmosLewis commented Dec 17, 2022

    // View both operands as ranked tensors. `cast<>` does not return null on
    // mismatch (it asserts in assert-enabled builds), so lhs/rhs are assumed to
    // already be ranked tensors here. NOTE(review): if that is not guaranteed by
    // the caller, use dyn_cast and report a match failure instead.
    auto lhsTy = lhs.getType().cast<RankedTensorType>();
    auto rhsTy = rhs.getType().cast<RankedTensorType>();

    // Number of dimensions of each operand.
    auto lhsRank = lhsTy.getRank();
    auto rhsRank = rhsTy.getRank();

    // Shapes converted to the Torch convention. Presumably this maps MLIR's
    // dynamic-dimension marker to Torch's -1 "unknown size" sentinel — confirm
    // against makeShapeTorchCompatible before relying on negative checks.
    auto lhsShape = makeShapeTorchCompatible(lhsTy.getShape());
    auto rhsShape = makeShapeTorchCompatible(rhsTy.getShape());

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment