Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,19 +457,6 @@ def render_object(val: tvm.Object) -> str:
return str(val)


@tvm.register_global_func("relax.run.shape_to_tensor")
def relax_shape_to_tensor(shape_tuple: tvm_ffi.Shape) -> tvm.runtime.Tensor:
"""
Takes a Shape and convert it to Tensor.

Parameters
----------
shape_tuple: tvm_ffi.Shape
Shape tuple that we want to convert to Tensor at runtime
"""
return tvm.runtime.tensor([int(v) for v in shape_tuple])


@tvm.register_global_func("relax.run.print")
def relax_print(format_str: str, *format_args: tvm.Object) -> None:
"""
Expand Down
11 changes: 11 additions & 0 deletions src/relax/backend/vm/lower_runtime_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
return Reshape(call);
} else if (call->op == shape_of_op_) {
return ShapeOf(call);
} else if (call->op == shape_to_tensor_op_) {
return ShapeToTensor(call);
} else if (call->op == tensor_to_shape_op_) {
return TensorToShape(call);
} else if (call->op == call_py_func_op_) {
Expand Down Expand Up @@ -141,6 +143,13 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}

Expr ShapeToTensor(const Call& call_node) {
TVM_FFI_ICHECK(call_node->args.size() == 1);
TVM_FFI_ICHECK(call_node->struct_info_.defined());

return Call(builtin_shape_to_tensor_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}

Expr TensorToShape(const Call& call_node) {
TVM_FFI_ICHECK(call_node->args.size() == 1);
TVM_FFI_ICHECK(call_node->struct_info_.defined());
Expand Down Expand Up @@ -223,6 +232,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
const Op& reshape_op_ = Op::Get("relax.reshape");
const Op& shape_of_op_ = Op::Get("relax.shape_of");
const Op& shape_to_tensor_op_ = Op::Get("relax.shape_to_tensor");
const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
const Op& call_py_func_op_ = Op::Get("relax.call_py_func");
const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
Expand All @@ -242,6 +252,7 @@ class LowerRuntimeBuiltinMutator : public ExprMutator {
const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
const ExternFunc builtin_shape_to_tensor_{"vm.builtin.shape_to_tensor"};
const ExternFunc builtin_tensor_to_shape_{"vm.builtin.tensor_to_shape"};
const ExternFunc builtin_call_py_func_{"vm.builtin.call_py_func"};
const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
Expand Down
2 changes: 1 addition & 1 deletion src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1200,7 +1200,7 @@ TVM_REGISTER_OP("relax.shape_to_tensor")
.set_num_inputs(1)
.add_argument("input", "Expr", "The input expression")
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnShapeToTensorStructInfo)
.set_attr<FCallPacked>("FCallPacked", "relax.run.shape_to_tensor")
.set_attr<FCallPacked>("FCallPacked", "vm.builtin.shape_to_tensor")
.set_attr<bool>("FPurity", true);

Expr MakeShapeToTensor(Expr expr) {
Expand Down
5 changes: 3 additions & 2 deletions src/relax/transform/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,9 @@ class ConstantFolder : public ExprMutator {
is_known &= (val.dtype() == DataType::Int(64));
}
if (is_known) {
const auto func = tvm::ffi::Function::GetGlobalRequired("relax.run.shape_to_tensor");
runtime::Tensor vals = func(arr).cast<runtime::Tensor>();
ffi::Shape shape_obj(arr);
const auto func = tvm::ffi::Function::GetGlobalRequired("vm.builtin.shape_to_tensor");
runtime::Tensor vals = func(shape_obj).cast<runtime::Tensor>();
return Constant(vals);
}
}
Expand Down
10 changes: 10 additions & 0 deletions src/runtime/vm/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,16 @@ TVM_FFI_STATIC_INIT_BLOCK() {
}
*rv = arr;
})
.def("vm.builtin.shape_to_tensor",
[](ffi::Shape shape) -> Tensor {
int64_t size = static_cast<int64_t>(shape.size());
Tensor out_tensor = Tensor::Empty({size}, DataType::Int(64), {kDLCPU, 0});
int64_t* ptr = static_cast<int64_t*>(out_tensor->data);
for (int64_t i = 0; i < size; ++i) {
ptr[i] = shape[i];
}
return out_tensor;
})
.def("vm.builtin.tensor_to_shape",
[](Tensor data) {
Tensor arr = data;
Expand Down
28 changes: 28 additions & 0 deletions tests/python/relax/test_vm_builtin_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,33 @@ def main(A: R.Tensor([16], "float32"), shape_tensor: R.Tensor([2], "int64")):
tvm.ir.assert_structural_equal(Expected, After)


def test_vm_shape_to_tensor():
"""R.shape_to_tensor lowers to vm.builtin.shape_to_tensor"""

@I.ir_module
class Before:
@R.function
def main(s: R.Shape([4, 4, 4])):
R.func_attr({"relax.force_pure": True})
t = R.shape_to_tensor(s)
return t

@I.ir_module
class Expected:
@R.function
def main(s: R.Shape([4, 4, 4])):
R.func_attr({"relax.force_pure": True})
t = R.call_packed(
"vm.builtin.shape_to_tensor",
s,
sinfo_args=R.Tensor([3], dtype="int64"),
)
return t

After = relax.transform.VMBuiltinLower()(Before)

tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
tvm.testing.main()
Loading