diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 04f12d087f65..964df133cf66 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -457,19 +457,6 @@ def render_object(val: tvm.Object) -> str: return str(val) -@tvm.register_global_func("relax.run.shape_to_tensor") -def relax_shape_to_tensor(shape_tuple: tvm_ffi.Shape) -> tvm.runtime.Tensor: - """ - Takes a Shape and convert it to Tensor. - - Parameters - ---------- - shape_tuple: tvm_ffi.Shape - Shape tuple that we want to convert to Tensor at runtime - """ - return tvm.runtime.tensor([int(v) for v in shape_tuple]) - - @tvm.register_global_func("relax.run.print") def relax_print(format_str: str, *format_args: tvm.Object) -> None: """ diff --git a/src/relax/backend/vm/lower_runtime_builtin.cc b/src/relax/backend/vm/lower_runtime_builtin.cc index 093ddc3c9916..a1f316bfac09 100644 --- a/src/relax/backend/vm/lower_runtime_builtin.cc +++ b/src/relax/backend/vm/lower_runtime_builtin.cc @@ -52,6 +52,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { return Reshape(call); } else if (call->op == shape_of_op_) { return ShapeOf(call); + } else if (call->op == shape_to_tensor_op_) { + return ShapeToTensor(call); } else if (call->op == tensor_to_shape_op_) { return TensorToShape(call); } else if (call->op == call_py_func_op_) { @@ -141,6 +143,13 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)}); } + Expr ShapeToTensor(const Call& call_node) { + TVM_FFI_ICHECK(call_node->args.size() == 1); + TVM_FFI_ICHECK(call_node->struct_info_.defined()); + + return Call(builtin_shape_to_tensor_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + } + Expr TensorToShape(const Call& call_node) { TVM_FFI_ICHECK(call_node->args.size() == 1); TVM_FFI_ICHECK(call_node->struct_info_.defined()); @@ -223,6 +232,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); const Op& reshape_op_ = Op::Get("relax.reshape"); const Op& shape_of_op_ = Op::Get("relax.shape_of"); + const Op& shape_to_tensor_op_ = Op::Get("relax.shape_to_tensor"); const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape"); const Op& call_py_func_op_ = Op::Get("relax.call_py_func"); const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice"); @@ -242,6 +252,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator { const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"}; const ExternFunc builtin_reshape_{"vm.builtin.reshape"}; const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"}; + const ExternFunc builtin_shape_to_tensor_{"vm.builtin.shape_to_tensor"}; const ExternFunc builtin_tensor_to_shape_{"vm.builtin.tensor_to_shape"}; const ExternFunc builtin_call_py_func_{"vm.builtin.call_py_func"}; const ExternFunc builtin_to_device_{"vm.builtin.to_device"}; diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 71df381bb72a..6d27cce75795 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1200,7 +1200,7 @@ TVM_REGISTER_OP("relax.shape_to_tensor") .set_num_inputs(1) .add_argument("input", "Expr", "The input expression") .set_attr("FInferStructInfo", ReturnShapeToTensorStructInfo) - .set_attr("FCallPacked", "relax.run.shape_to_tensor") + .set_attr("FCallPacked", "vm.builtin.shape_to_tensor") .set_attr("FPurity", true); Expr MakeShapeToTensor(Expr expr) { diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 75b1b09bd48b..68afee992bda 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -394,8 +394,9 @@ class ConstantFolder : public ExprMutator { is_known &= (val.dtype() == DataType::Int(64)); } if (is_known) { - const auto func = tvm::ffi::Function::GetGlobalRequired("relax.run.shape_to_tensor"); - runtime::Tensor vals = func(arr).cast(); + ffi::Shape shape_obj(arr); + const auto func = tvm::ffi::Function::GetGlobalRequired("vm.builtin.shape_to_tensor"); + runtime::Tensor vals = func(shape_obj).cast(); return Constant(vals); } } diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc index 322a0a137c17..f5467b1e4daf 100644 --- a/src/runtime/vm/builtin.cc +++ b/src/runtime/vm/builtin.cc @@ -675,6 +675,16 @@ TVM_FFI_STATIC_INIT_BLOCK() { } *rv = arr; }) + .def("vm.builtin.shape_to_tensor", + [](ffi::Shape shape) -> Tensor { + int64_t size = static_cast(shape.size()); + Tensor out_tensor = Tensor::Empty({size}, DataType::Int(64), {kDLCPU, 0}); + int64_t* ptr = static_cast(out_tensor->data); + for (int64_t i = 0; i < size; ++i) { + ptr[i] = shape[i]; + } + return out_tensor; + }) .def("vm.builtin.tensor_to_shape", [](Tensor data) { Tensor arr = data; diff --git a/tests/python/relax/test_vm_builtin_lower.py b/tests/python/relax/test_vm_builtin_lower.py index 59ac5c3f12d0..05a2bf040729 100644 --- a/tests/python/relax/test_vm_builtin_lower.py +++ b/tests/python/relax/test_vm_builtin_lower.py @@ -148,5 +148,33 @@ def main(A: R.Tensor([16], "float32"), shape_tensor: R.Tensor([2], "int64")): tvm.ir.assert_structural_equal(Expected, After) +def test_vm_shape_to_tensor(): + """R.shape_to_tensor lowers to vm.builtin.shape_to_tensor""" + + @I.ir_module + class Before: + @R.function + def main(s: R.Shape([4, 4, 4])): + R.func_attr({"relax.force_pure": True}) + t = R.shape_to_tensor(s) + return t + + @I.ir_module + class Expected: + @R.function + def main(s: R.Shape([4, 4, 4])): + R.func_attr({"relax.force_pure": True}) + t = R.call_packed( + "vm.builtin.shape_to_tensor", + s, + sinfo_args=R.Tensor([3], dtype="int64"), + ) + return t + + After = relax.transform.VMBuiltinLower()(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main()