@AmosLewis
Created January 16, 2023 07:26
func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],f64> {
%float8.000000e00 = torch.constant.float 8.000000e+00
%1 = "torch.prim.NumToTensor.Scalar"(%float8.000000e00) : (!torch.float) -> !torch.vtensor<[],f64>
return %1 : !torch.vtensor<[],f64>
}
AmosLewis commented Jan 16, 2023

Attempted workaround: materialize the constant at f32, then use tosa::CastOp to cast it up to f64:

  auto outElemTy = resultType.getElementType();
  if (outElemTy.isa<mlir::IntegerType>() || outElemTy.isF32()) {
    DenseElementsAttr constAttr =
        isInt ? DenseElementsAttr::get(resultType, {intValue})
              : DenseElementsAttr::get(resultType, {floatValue});
    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, resultType, constAttr);
  } else if (outElemTy.isF64()) {
    // Materialize an f32 scalar constant, then cast the result to f64.
    auto resultF32 =
        tosa::getConstTensor<float>(rewriter, op, floatValue, {}).value();
    rewriter.replaceOpWithNewOp<tosa::CastOp>(op, resultType, resultF32);
  }
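The verifier error below suggests the cast-based workaround cannot succeed regardless of how the f32 constant is produced: `tensor<f64>` appears not to be an admissible TOSA tensor type in this build, so any op producing it fails verification. A minimal standalone repro sketch (assuming the same upstream TOSA type constraints) would be:

```mlir
// Sketch: this function is expected to fail verification with
// "'tosa.cast' op result #0 must be tensor of number values, but got 'tensor<f64>'"
// because the TOSA dialect's tensor type constraint excludes 64-bit floats here.
func.func @cast_f32_to_f64(%arg0: tensor<f32>) -> tensor<f64> {
  %0 = "tosa.cast"(%arg0) : (tensor<f32>) -> tensor<f64>
  return %0 : tensor<f64>
}
```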

Bug Output:

// *** IR Dump After Pattern Application ***
mlir-asm-printer: Verifying operation: func.func
ImplicitTypeIDRegistry::lookupOrInsert(mlir::InferShapedTypeOpInterface::Trait<Empty>)
'tosa.cast' op result #0 must be tensor of number values, but got 'tensor<f64>'
mlir-asm-printer: 'func.func' failed to verify and will be printed in generic form
"func.func"() ({
  %0 = "torch.constant.float"() {value = 8.000000e+00 : f64} : () -> !torch.float
  %1 = "builtin.unrealized_conversion_cast"(%0) : (!torch.float) -> f64
  %2 = "tosa.const"() {value = dense<8.000000e+00> : tensor<f32>} : () -> tensor<f32>
  %3 = "tosa.cast"(%2) : (tensor<f32>) -> tensor<f64>
  %4 = "torch.prim.NumToTensor.Scalar"(%0) : (!torch.float) -> !torch.vtensor<[],f64>
  "func.return"(%4) : (!torch.vtensor<[],f64>) -> ()
}) {function_type = () -> !torch.vtensor<[],f64>, sym_name = "torch.prim.NumToTensor.Scalar"} : () -> ()


} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'func.return'(0x55a53fab0c30) {
  "func.return"(%4) : (!torch.vtensor<[],f64>) -> ()

  * Fold {
  } -> FAILURE : unable to fold
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//
** Insert  : 'torch_c.from_builtin_tensor'(0x55a53fb17c60)
ImplicitTypeIDRegistry::lookupOrInsert(mlir::detail::PreservedAnalyses::AllAnalysesType)
/tmp/NumToTensor.mlir:3:8: error: 'tosa.cast' op result #0 must be tensor of number values, but got 'tensor<f64>'
  %1 = "torch.prim.NumToTensor.Scalar"(%float8.000000e00) : (!torch.float) -> !torch.vtensor<[],f64>
       ^
/tmp/NumToTensor.mlir:3:8: note: see current operation: %2 = "tosa.cast"(%1) : (tensor<f32>) -> tensor<f64>


For comparison with the f64 case, the f32 case lowers successfully:

func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],f32> {
  %float8.000000e00 = torch.constant.float 8.000000e+00
  %1 = "torch.prim.NumToTensor.Scalar"(%float8.000000e00) : (!torch.float) -> !torch.vtensor<[],f32>
  return %1 : !torch.vtensor<[],f32>
}

--->

module {
  func.func @torch.prim.NumToTensor.Scalar() -> !torch.vtensor<[],f32> {
    %float8.000000e00 = torch.constant.float 8.000000e+00
    %0 = "tosa.const"() {value = dense<8.000000e+00> : tensor<f32>} : () -> tensor<f32>
    %1 = torch_c.from_builtin_tensor %0 : tensor<f32> -> !torch.vtensor<[],f32>
    return %1 : !torch.vtensor<[],f32>
  }
}
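For the same reason, skipping the cast and emitting the constant directly at f64 would presumably hit the same type constraint. A hypothetical direct lowering (a sketch, not verified) would look like the following, but the `tensor<f64>` result type is expected to be rejected by the same TOSA type-constraint check:

```mlir
// Hypothetical direct f64 lowering (sketch only); tosa.const producing
// tensor<f64> should fail verification under the same type constraint.
%0 = "tosa.const"() {value = dense<8.000000e+00> : tensor<f64>} : () -> tensor<f64>
%1 = torch_c.from_builtin_tensor %0 : tensor<f64> -> !torch.vtensor<[],f64>
```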
