diff --git a/examples/brutus.jl b/examples/brutus.jl new file mode 100644 index 00000000..bba5835b --- /dev/null +++ b/examples/brutus.jl @@ -0,0 +1,281 @@ +module Brutus + +import LLVM +using MLIR.IR +using MLIR.Dialects: arith, func, cf, std +using Core: PhiNode, GotoNode, GotoIfNot, SSAValue, Argument, ReturnNode, PiNode + +const BrutusScalar = Union{Bool,Int64,Int32,Float32,Float64} + +function cmpi_pred(predicate) + function(ops; loc=Location()) + arith.cmpi(predicate, ops; loc) + end +end + +function single_op_wrapper(fop) + (block::Block, args::Vector{Value}; loc=Location()) -> push!(block, fop(args; loc)) +end + +const intrinsics_to_mlir = Dict([ + Base.add_int => single_op_wrapper(arith.addi), + Base.sle_int => single_op_wrapper(cmpi_pred(arith.Predicates.sle)), + Base.slt_int => single_op_wrapper(cmpi_pred(arith.Predicates.slt)), + Base.:(===) => single_op_wrapper(cmpi_pred(arith.Predicates.eq)), + Base.mul_int => single_op_wrapper(arith.muli), + Base.mul_float => single_op_wrapper(arith.mulf), + Base.not_int => function(block, args; loc=Location()) + arg = only(args) + ones = push!(block, arith.constant(-1, IR.get_type(arg); loc)) |> IR.get_result + push!(block, arith.xori(Value[arg, ones]; loc)) + end, +]) + +"Generates a block argument for each phi node present in the block." +function prepare_block(ir, bb) + b = Block() + + for sidx in bb.stmts + stmt = ir.stmts[sidx] + inst = stmt[:inst] + inst isa Core.PhiNode || continue + + type = stmt[:type] + IR.push_argument!(b, MLIRType(type), Location()) + end + + return b +end + +"Values to populate the Phi Node when jumping from `from` to `to`." +function collect_value_arguments(ir, from, to) + to = ir.cfg.blocks[to] + values = [] + for s in to.stmts + stmt = ir.stmts[s] + inst = stmt[:inst] + inst isa Core.PhiNode || continue + + edge = findfirst(==(from), inst.edges) + if isnothing(edge) # use dummy scalar val instead + val = zero(stmt[:type]) + push!(values, val) + else + push!(values, inst.values[edge]) + end + end + values +end + +""" + code_mlir(f, types::Type{Tuple}) -> IR.Operation + +Returns a `func.func` operation corresponding to the ircode of the provided method. +This only supports a few Julia Core primitives and scalar types of type $BrutusScalar. + +!!! note + The Julia SSAIR to MLIR conversion implemented is very primitive and only supports a + handful of primitives. A better to perform this conversion would to create a dialect + representing Julia IR and progressively lower it to base MLIR dialects. +""" +function code_mlir(f, types) + ctx = context() + ir, ret = Core.Compiler.code_ircode(f, types) |> only + @assert first(ir.argtypes) isa Core.Const + + values = Vector{Value}(undef, length(ir.stmts)) + + for dialect in (LLVM.version() >= v"15" ? ("func", "cf") : ("std",)) + IR.get_or_load_dialect!(dialect) + end + + blocks = [ + prepare_block(ir, bb) + for bb in ir.cfg.blocks + ] + + current_block = entry_block = blocks[begin] + + for argtype in types.parameters + IR.push_argument!(entry_block, MLIRType(argtype), Location()) + end + + function get_value(x)::Value + if x isa Core.SSAValue + @assert isassigned(values, x.id) "value $x was not assigned" + values[x.id] + elseif x isa Core.Argument + IR.get_argument(entry_block, x.n - 1) + elseif x isa BrutusScalar + IR.get_result(push!(current_block, arith.constant(x))) + else + error("could not use value $x inside MLIR") + end + end + + for (block_id, (b, bb)) in enumerate(zip(blocks, ir.cfg.blocks)) + current_block = b + n_phi_nodes = 0 + + for sidx in bb.stmts + stmt = ir.stmts[sidx] + inst = stmt[:inst] + line = ir.linetable[stmt[:line]] + + if Meta.isexpr(inst, :call) + val_type = stmt[:type] + if !(val_type <: BrutusScalar) + error("type $val_type is not supported") + end + out_type = MLIRType(val_type) + + called_func = first(inst.args) + if called_func isa GlobalRef # TODO: should probably use something else here + called_func = getproperty(called_func.mod, called_func.name) + end + + fop! = intrinsics_to_mlir[called_func] + args = get_value.(@view inst.args[begin+1:end]) + + loc = Location(string(line.file), line.line, 0) + res = IR.get_result(fop!(current_block, args; loc)) + + values[sidx] = res + elseif inst isa PhiNode + values[sidx] = IR.get_argument(current_block, n_phi_nodes += 1) + elseif inst isa PiNode + values[sidx] = get_value(inst.val) + elseif inst isa GotoNode + args = get_value.(collect_value_arguments(ir, block_id, inst.label)) + dest = blocks[inst.label] + loc = Location(string(line.file), line.line, 0) + brop = LLVM.version() >= v"15" ? cf.br : std.br + push!(current_block, brop(dest, args; loc)) + elseif inst isa GotoIfNot + false_args = get_value.(collect_value_arguments(ir, block_id, inst.dest)) + cond = get_value(inst.cond) + @assert length(bb.succs) == 2 # NOTE: We assume that length(bb.succs) == 2, this might be wrong + other_dest = setdiff(bb.succs, inst.dest) |> only + true_args = get_value.(collect_value_arguments(ir, block_id, other_dest)) + other_dest = blocks[other_dest] + dest = blocks[inst.dest] + + loc = Location(string(line.file), line.line, 0) + cond_brop = LLVM.version() >= v"15" ? cf.cond_br : std.cond_br + cond_br = cond_brop(cond, other_dest, dest, true_args, false_args; loc) + push!(current_block, cond_br) + elseif inst isa ReturnNode + line = ir.linetable[stmt[:line]] + retop = LLVM.version() >= v"15" ? func.return_ : std.return_ + loc = Location(string(line.file), line.line, 0) + push!(current_block, retop([get_value(inst.val)]; loc)) + elseif Meta.isexpr(inst, :code_coverage_effect) + # Skip + else + error("unhandled ir $(inst)") + end + end + end + + func_name = nameof(f) + + region = Region() + for b in blocks + push!(region, b) + end + + LLVM15 = LLVM.version() >= v"15" + + input_types = MLIRType[ + IR.get_type(IR.get_argument(entry_block, i)) + for i in 1:IR.num_arguments(entry_block) + ] + result_types = [MLIRType(ret)] + + ftype = MLIRType(input_types => result_types) + op = IR.create_operation( + LLVM15 ? "func.func" : "builtin.func", + Location(); + attributes = [ + NamedAttribute("sym_name", IR.Attribute(string(func_name))), + NamedAttribute(LLVM15 ? "function_type" : "type", IR.Attribute(ftype)), + ], + owned_regions = Region[region], + result_inference=false, + ) + + IR.verifyall(op) + + op +end + +""" + @code_mlir f(args...) +""" +macro code_mlir(call) + @assert Meta.isexpr(call, :call) "only calls are supported" + + f = first(call.args) |> esc + args = Expr(:curly, + Tuple, + map(arg -> :($(Core.Typeof)($arg)), + call.args[begin+1:end])..., + ) |> esc + + quote + code_mlir($f, $args) + end +end + +end # module Brutus + +# --- + +function pow(x::F, n) where {F} + p = one(F) + for _ in 1:n + p *= x + end + p +end + +function f(x) + if x == 1 + 2 + else + 3 + end +end + +# --- + +using Test +using MLIR.IR, MLIR + +ctx = Context() +# IR.enable_multithreading!(ctx, false) + +op = Brutus.code_mlir(pow, Tuple{Int, Int}) + +mod = MModule(Location()) +body = IR.get_body(mod) +push!(body, op) + +pm = IR.PassManager() +opm = IR.OpPassManager(pm) + +# IR.enable_ir_printing!(pm) +IR.enable_verifier!(pm, true) + +MLIR.API.mlirRegisterAllPasses() +MLIR.API.mlirRegisterAllLLVMTranslations(ctx) +IR.add_pipeline!(opm, Brutus.LLVM.version() >= v"15" ? "convert-arith-to-llvm,convert-func-to-llvm" : "convert-std-to-llvm") + +IR.run!(pm, mod) + +jit = MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL) +fptr = MLIR.API.mlirExecutionEngineLookup(jit, "pow") + +x, y = 3, 4 + +@test ccall(fptr, Int, (Int, Int), x, y) == pow(x, y) diff --git a/src/Dialects.jl b/src/Dialects.jl new file mode 100644 index 00000000..4b1211db --- /dev/null +++ b/src/Dialects.jl @@ -0,0 +1,186 @@ +module Dialects + +module Arith + +using ...IR +using ...Builder: blockbuilder, _has_blockbuilder + +for (f, t) in Iterators.product( + (:add, :sub, :mul), + (:i, :f), +) + fname = Symbol(f, t) + @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) + op = IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) + push!(blockbuilder().block, op) + return IR.get_result(op, 1) + end +end + +for fname in (:xori, :andi, :ori) + @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) + op = IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) + push!(blockbuilder().block, op) + return IR.get_result(op, 1) + end +end + +for (f, t) in Iterators.product( + (:div, :max, :min), + (:si, :ui, :f), +) + fname = Symbol(f, t) + @eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location()) + op = IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) + push!(blockbuilder().block, op) + return IR.get_result(op, 1) + end +end + +# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithindex_cast-mlirarithindexcastop +for f in (:index_cast, :index_castui) + @eval function $f(operand; loc=Location()) + op = IR.create_operation( + $(string("arith.", f)), + loc; + operands=[operand], + results=[IR.IndexType()], + ) + push!(blockbuilder().block, op) + return IR.get_result(op, 1) + end +end + +# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithextf-mlirarithextfop +function extf(operand, type; loc=Location()) + op = IR.create_operation("arith.exf", loc; operands=[operand], results=[type]) + push!(blockbuilder().block, op) + return IR.get_result(op , 1) +end + +# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithconstant-mlirarithconstantop +function constant(value, type=MLIRType(typeof(value)); loc=Location()) + op = IR.create_operation( + "arith.constant", + loc; + results=[type], + attributes=[ + IR.NamedAttribute("value", + Attribute(value, type)), + ], + ) + push!(blockbuilder().block, op) + return IR.get_result(op, 1) +end + +module Predicates + const eq = 0 + const ne = 1 + const slt = 2 + const sle = 3 + const sgt = 4 + const sge = 5 + const ult = 6 + const ule = 7 + const ugt = 8 + const uge = 9 +end + +function cmpi(predicate, operands; loc=Location()) + op = IR.create_operation( + "arith.cmpi", + loc; + operands, + results=[MLIRType(Bool)], + attributes=[ + IR.NamedAttribute("predicate", + Attribute(predicate)) + ], + ) + push!(blockbuilder().block, op) + return get_result(op, 1) +end + +end # module arith + +module STD +# for llvm 14 + +using ...IR + +function return_(operands; loc=Location()) + IR.create_operation("std.return", loc; operands, result_inference=false) +end + +function br(dest, operands; loc=Location()) + IR.create_operation("std.br", loc; operands, successors=[dest], result_inference=false) +end + +function cond_br( + cond, + true_dest, false_dest, + true_dest_operands, + false_dest_operands; + loc=Location(), +) + IR.create_operation( + "std.cond_br", + loc; + successors=[true_dest, false_dest], + operands=[cond, true_dest_operands..., false_dest_operands...], + attributes=[ + IR.NamedAttribute("operand_segment_sizes", + IR.Attribute(Int32[1, length(true_dest_operands), length(false_dest_operands)])) + ], + result_inference=false, + ) +end + +end # module std + +module Func +# https://mlir.llvm.org/docs/Dialects/Func/ + +using ...IR + +function return_(operands; loc=Location()) + IR.create_operation("func.return", loc; operands, result_inference=false) +end + +end # module func + +module CF + +using ...IR +using ...Builder + +function br(dest, operands=[]; loc=Location()) + op = IR.create_operation("cf.br", loc; operands, successors=[dest], result_inference=false) + push!(Builder.blockbuilder().block, op) + return op # no value so returning operation itself (?) +end + +function cond_br( + cond, + true_dest, false_dest, + true_dest_operands=[], + false_dest_operands=[]; + loc=Location(), +) + op = IR.create_operation( + "cf.cond_br", loc; + operands=[cond, true_dest_operands..., false_dest_operands...], + successors=[true_dest, false_dest], + attributes=[ + IR.NamedAttribute("operand_segment_sizes", + IR.Attribute(Int32[1, length(true_dest_operands), length(false_dest_operands)])) + ], + result_inference=false, + ) + push!(blockbuilder().block, op) + return op +end + +end # module cf + +end # module Dialects diff --git a/src/IR/IR.jl b/src/IR/IR.jl new file mode 100644 index 00000000..72701e7e --- /dev/null +++ b/src/IR/IR.jl @@ -0,0 +1,839 @@ +export + Operation, + OperationState, + Location, + Context, + MModule, + Value, + MLIRType, + Region, + Block, + Attribute, + NamedAttribute + +import Base: ==, String +using .API: + MlirDialectRegistry, + MlirDialectHandle, + MlirAttribute, + MlirNamedAttribute, + MlirDialect, + MlirStringRef, + MlirOperation, + MlirOperationState, + MlirLocation, + MlirBlock, + MlirRegion, + MlirModule, + MlirContext, + MlirType, + MlirValue, + MlirIdentifier, + MlirPassManager, + MlirOpPassManager + +function print_callback(str::MlirStringRef, userdata) + data = unsafe_wrap(Array, Base.convert(Ptr{Cchar}, str.data), str.length; own=false) + write(userdata isa Base.RefValue ? userdata[] : userdata, data) + return Cvoid() +end + +### Dialect + +struct Dialect + dialect::MlirDialect + + Dialect(dialect) = begin + @assert !mlirIsNull(dialect) "cannot create Dialect from null MlirDialect" + new(dialect) + end +end + +Base.convert(::Type{MlirDialect}, dialect::Dialect) = dialect.dialect +function Base.show(io::IO, dialect::Dialect) + print(io, "Dialect(\"", String(API.mlirDialectGetNamespace(dialect)), "\")") +end + +### DialectHandle + +struct DialectHandle + handle::API.MlirDialectHandle +end + +function DialectHandle(s::Symbol) + s = Symbol("mlirGetDialectHandle__", s, "__") + DialectHandle(getproperty(API, s)()) +end + +Base.convert(::Type{MlirDialectHandle}, handle::DialectHandle) = handle.handle + +### Dialect Registry + +mutable struct DialectRegistry + registry::MlirDialectRegistry +end +function DialectRegistry() + registry = API.mlirDialectRegistryCreate() + @assert !mlirIsNull(registry) "cannot create DialectRegistry with null MlirDialectRegistry" + finalizer(DialectRegistry(registry)) do registry + API.mlirDialectRegistryDestroy(registry.registry) + end +end + +function Base.insert!(registry::DialectRegistry, handle::DialectHandle) + API.mlirDialectHandleInsertDialect(registry, handle) +end + +### Context + +struct Context + context::MlirContext +end + +function Context() + context = API.mlirContextCreate() + @assert !mlirIsNull(context) "cannot create Context with null MlirContext" + context = Context(context) + activate(context) + context +end + +function dispose(ctx::Context) + deactivate(ctx) + API.mlirContextDestroy(context.context) +end + +function Context(f::Core.Function) + ctx = Context() + try + f(ctx) + finally + dispose(ctx) + end +end + +Base.convert(::Type{MlirContext}, c::Context) = c.context + +num_loaded_dialects() = API.mlirContextGetNumLoadedDialects(context()) +function get_or_load_dialect!(handle::DialectHandle) + mlir_dialect = API.mlirDialectHandleLoadDialect(handle, context()) + if mlirIsNull(mlir_dialect) + error("could not load dialect from handle $handle") + else + Dialect(mlir_dialect) + end +end +function get_or_load_dialect!(dialect::String) + get_or_load_dialect!(DialectHandle(Symbol(dialect))) +end + +function enable_multithreading!(enable=true) + API.mlirContextEnableMultithreading(context(), enable) + context() +end + +is_registered_operation(opname) = API.mlirContextIsRegisteredOperation(context(), opname) + +### Location + +struct Location + location::MlirLocation + + Location(location) = begin + @assert !mlirIsNull(location) "cannot create Location with null MlirLocation" + new(location) + end +end + +Location() = Location(API.mlirLocationUnknownGet(context())) +Location(filename, line, column) = + Location(API.mlirLocationFileLineColGet(context(), filename, line, column)) + +Base.convert(::Type{MlirLocation}, location::Location) = location.location + +function Base.show(io::IO, location::Location) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + print(io, "Location(#= ") + API.mlirLocationPrint(location, c_print_callback, ref) + print(io, " =#)") +end + +### Type + +struct MLIRType + type::MlirType + + MLIRType(type) = begin + @assert !mlirIsNull(type) + new(type) + end +end + +MLIRType(t::MLIRType) = t +MLIRType(T::Type{<:Signed}) = + MLIRType(API.mlirIntegerTypeGet(context(), sizeof(T) * 8)) +MLIRType(T::Type{<:Unsigned}) = + MLIRType(API.mlirIntegerTypeGet(context(), sizeof(T) * 8)) +MLIRType(::Type{Bool}) = + MLIRType(API.mlirIntegerTypeGet(context(), 1)) +MLIRType(::Type{Float32}) = + MLIRType(API.mlirF32TypeGet(context())) +MLIRType(::Type{Float64}) = + MLIRType(API.mlirF64TypeGet(context())) +MLIRType(ft::Pair) = + MLIRType(API.mlirFunctionTypeGet(context(), + length(ft.first), [MLIRType(t) for t in ft.first], + length(ft.second), [MLIRType(t) for t in ft.second])) +MLIRType(a::AbstractArray{T}) where {T} = MLIRType(MLIRType(T), size(a)) +MLIRType(::Type{<:AbstractArray{T,N}}, dims) where {T,N} = + MLIRType(API.mlirRankedTensorTypeGetChecked( + Location(), + N, collect(dims), + MLIRType(T), + Attribute(), + )) +MLIRType(element_type::MLIRType, dims) = + MLIRType(API.mlirRankedTensorTypeGetChecked( + Location(), + length(dims), collect(dims), + element_type, + Attribute(), + )) +MLIRType(::T) where {T<:Real} = MLIRType(T) +MLIRType(_, type::MLIRType) = type + +IndexType() = MLIRType(API.mlirIndexTypeGet(context())) + +Base.convert(::Type{MlirType}, mtype::MLIRType) = mtype.type +Base.parse(::Type{MLIRType}, context, s) = + MLIRType(API.mlirTypeParseGet(context, s)) + +function Base.eltype(type::MLIRType) + if API.mlirTypeIsAShaped(type) + MLIRType(API.mlirShapedTypeGetElementType(type)) + else + type + end +end + +function show_inner(io::IO, type::MLIRType) + if API.mlirTypeIsAInteger(type) + is_signless = API.mlirIntegerTypeIsSignless(type) + is_signed = API.mlirIntegerTypeIsSigned(type) + + width = API.mlirIntegerTypeGetWidth(type) + t = if is_signed + "si" + elseif is_signless + "i" + else + "u" + end + print(io, t, width) + elseif API.mlirTypeIsAF64(type) + print(io, "f64") + elseif API.mlirTypeIsAF32(type) + print(io, "f32") + elseif API.mlirTypeIsARankedTensor(type) + print(io, "tensor<") + s = size(type) + print(io, join(s, "x"), "x") + show_inner(io, eltype(type)) + print(io, ">") + elseif API.mlirTypeIsAIndex(type) + print(io, "index") + else + print(io, "unknown") + end +end + +function Base.show(io::IO, type::MLIRType) + print(io, "MLIRType(#= ") + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + API.mlirTypePrint(type, c_print_callback, ref) + print(io, " =#)") +end + +function inttype(size, issigned) + size == 1 && issigned && return Bool + ints = (Int8, Int16, Int32, Int64, Int128) + IT = ints[Int(log2(size))-2] + issigned ? IT : unsigned(IT) +end + +function julia_type(type::MLIRType) + if API.mlirTypeIsAInteger(type) + is_signed = API.mlirIntegerTypeIsSigned(type) || + API.mlirIntegerTypeIsSignless(type) + width = API.mlirIntegerTypeGetWidth(type) + + try + inttype(width, is_signed) + catch + t = is_signed ? "i" : "u" + throw("could not convert type $(t)$(width) to julia") + end + elseif API.mlirTypeIsAF32(type) + Float32 + elseif API.mlirTypeIsAF64(type) + Float64 + else + throw("could not convert type $type to julia") + end +end + +Base.ndims(type::MLIRType) = + if API.mlirTypeIsAShaped(type) && API.mlirShapedTypeHasRank(type) + API.mlirShapedTypeGetRank(type) + else + 0 + end + +Base.size(type::MLIRType, i::Int) = API.mlirShapedTypeGetDimSize(type, i - 1) +Base.size(type::MLIRType) = Tuple(size(type, i) for i in 1:ndims(type)) + +function is_tensor(type::MLIRType) + API.mlirTypeIsAShaped(type) +end + +function is_integer(type::MLIRType) + API.mlirTypeIsAInteger(type) +end + +is_function_type(mtype) = API.mlirTypeIsAFunction(mtype) + +function num_inputs(ftype::MLIRType) + @assert is_function_type(ftype) "cannot get the number of inputs on type $(ftype), expected a function type" + API.mlirFunctionTypeGetNumInputs(ftype) +end +function num_results(ftype::MLIRType) + @assert is_function_type(ftype) "cannot get the number of results on type $(ftype), expected a function type" + API.mlirFunctionTypeGetNumResults(ftype) +end + +function get_input(ftype::MLIRType, pos) + @assert is_function_type(ftype) "cannot get input on type $(ftype), expected a function type" + MLIRType(API.mlirFunctionTypeGetInput(ftype, pos - 1)) +end +function get_result(ftype::MLIRType, pos=1) + @assert is_function_type(ftype) "cannot get result on type $(ftype), expected a function type" + MLIRType(API.mlirFunctionTypeGetResult(ftype, pos - 1)) +end + +### Attribute + +struct Attribute + attribute::MlirAttribute +end + +Attribute() = Attribute(API.mlirAttributeGetNull()) +Attribute(s::AbstractString) = Attribute(API.mlirStringAttrGet(context(), s)) +Attribute(type::MLIRType) = Attribute(API.mlirTypeAttrGet(type)) +Attribute(f::F, type=MLIRType(F)) where {F<:AbstractFloat} = Attribute( + API.mlirFloatAttrDoubleGet(context(), type, Float64(f)) +) +Attribute(i::T) where {T<:Integer} = Attribute( + API.mlirIntegerAttrGet(MLIRType(T), Int64(i)) +) +function Attribute(values::T) where {T<:AbstractArray{Int32}} + type = MLIRType(T, size(values)) + Attribute( + API.mlirDenseElementsAttrInt32Get(type, length(values), values) + ) +end +function Attribute(values::T) where {T<:AbstractArray{Int64}} + type = MLIRType(T, size(values)) + Attribute( + API.mlirDenseElementsAttrInt64Get(type, length(values), values) + ) +end +function Attribute(values::T) where {T<:AbstractArray{Float64}} + type = MLIRType(T, size(values)) + Attribute( + API.mlirDenseElementsAttrDoubleGet(type, length(values), values) + ) +end +function Attribute(values::T) where {T<:AbstractArray{Float32}} + type = MLIRType(T, size(values)) + Attribute( + API.mlirDenseElementsAttrFloatGet(type, length(values), values) + ) +end +function Attribute(values::AbstractArray{Int32}, type) + Attribute( + API.mlirDenseElementsAttrInt32Get(type, length(values), values) + ) +end +function Attribute(values::AbstractArray{Int}, type) + Attribute( + API.mlirDenseElementsAttrInt64Get(type, length(values), values) + ) +end +function Attribute(values::AbstractArray{Float32}, type) + Attribute( + API.mlirDenseElementsAttrFloatGet(type, length(values), values) + ) +end +function ArrayAttribute(values::AbstractVector{Int}) + elements = Attribute.((context(),), values) + Attribute( + API.mlirArrayAttrGet(context(), length(elements), elements) + ) +end +function ArrayAttribute(attributes::Vector{Attribute}) + Attribute( + API.mlirArrayAttrGet(context(), length(attributes), attributes), + ) +end +function DenseArrayAttribute(values::AbstractVector{Int}) + Attribute( + API.mlirDenseI64ArrayGet(context(), length(values), collect(values)) + ) +end +function Attribute(value::Int, type::MLIRType) + Attribute( + API.mlirIntegerAttrGet(type, value) + ) +end +function Attribute(value::Bool, ::MLIRType=nothing) + Attribute( + API.mlirBoolAttrGet(context(), value) + ) +end + +Base.convert(::Type{MlirAttribute}, attribute::Attribute) = attribute.attribute +Base.parse(::Type{Attribute}, s) = + Attribute(API.mlirAttributeParseGet(context(), s)) + +function get_type(attribute::Attribute) + MLIRType(API.mlirAttributeGetType(attribute)) +end +function type_value(attribute) + @assert API.mlirAttributeIsAType(attribute) "attribute $(attribute) is not a type" + MLIRType(API.mlirTypeAttrGetValue(attribute)) +end +function bool_value(attribute) + @assert API.mlirAttributeIsABool(attribute) "attribute $(attribute) is not a boolean" + API.mlirBoolAttrGetValue(attribute) +end +function string_value(attribute) + @assert API.mlirAttributeIsAString(attribute) "attribute $(attribute) is not a string attribute" + String(API.mlirStringAttrGetValue(attribute)) +end + +function Base.show(io::IO, attribute::Attribute) + print(io, "Attribute(#= ") + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + API.mlirAttributePrint(attribute, c_print_callback, ref) + print(io, " =#)") +end + +### Named Attribute + +struct NamedAttribute + named_attribute::MlirNamedAttribute +end + +function NamedAttribute(name, attribute) + @assert !mlirIsNull(attribute.attribute) + NamedAttribute(API.mlirNamedAttributeGet( + API.mlirIdentifierGet(context(), name), + attribute + )) +end + +Base.convert(::Type{MlirAttribute}, named_attribute::NamedAttribute) = + named_attribute.named_attribute + +### Value + +struct Value + value::MlirValue + + Value(value) = begin + @assert !mlirIsNull(value) "cannot create Value with null MlirValue" + new(value) + end +end + +get_type(value) = MLIRType(API.mlirValueGetType(value)) + +Base.convert(::Type{MlirValue}, value::Value) = value.value +Base.size(value::Value) = Base.size(get_type(value)) +Base.ndims(value::Value) = Base.ndims(get_type(value)) + +function Base.show(io::IO, value::Value) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + API.mlirValuePrint(value, c_print_callback, ref) +end + +is_a_op_result(value) = API.mlirValueIsAOpResult(value) +is_a_block_argument(value) = API.mlirValueIsABlockArgument(value) + +function set_type!(value, type) + @assert is_a_block_argument(value) "could not set type, value is not a block argument" + API.mlirBlockArgumentSetType(value, type) + value +end + +function get_owner(value::Value) + if is_a_block_argument(value) + raw_block = API.mlirBlockArgumentGetOwner(value) + if mlirIsNull(raw_block) + return nothing + end + + return Block(raw_block, false) + end + + raw_op = API.mlirOpResultGetOwner(value) + if mlirIsNull(raw_op) + return nothing + end + + return Operation(raw_op, false) +end + +### Operation + +mutable struct Operation + operation::MlirOperation + @atomic owned::Bool + + Operation(operation, owned=true) = begin + @assert !mlirIsNull(operation) "cannot create Operation with null MlirOperation" + finalizer(new(operation, owned)) do op + if op.owned + API.mlirOperationDestroy(op.operation) + end + end + end +end + +function create_operation( + name, loc; + results=nothing, + operands=nothing, + owned_regions=nothing, + successors=nothing, + attributes=nothing, + result_inference=isnothing(results) +) + GC.@preserve name loc begin + state = Ref(API.mlirOperationStateGet(name, loc)) + if !isnothing(results) + if result_inference + error("Result inference and provided results conflict") + end + API.mlirOperationStateAddResults(state, length(results), results) + end + if !isnothing(operands) + API.mlirOperationStateAddOperands(state, length(operands), operands) + end + if !isnothing(owned_regions) + lose_ownership!.(owned_regions) + GC.@preserve owned_regions begin + mlir_regions = Base.unsafe_convert.(MlirRegion, owned_regions) + API.mlirOperationStateAddOwnedRegions(state, length(mlir_regions), mlir_regions) + end + end + if !isnothing(successors) + GC.@preserve successors begin + mlir_blocks = Base.unsafe_convert.(MlirBlock, successors) + API.mlirOperationStateAddSuccessors( + state, + length(mlir_blocks), + mlir_blocks, + ) + end + end + if !isnothing(attributes) + API.mlirOperationStateAddAttributes(state, length(attributes), attributes) + end + if result_inference + API.mlirOperationStateEnableResultTypeInference(state) + end + op = API.mlirOperationCreate(state) + if mlirIsNull(op) + error("Create Operation failed") + end + Operation(op, true) + end +end + +Base.copy(operation::Operation) = Operation(API.mlirOperationClone(operation)) + +num_regions(operation) = API.mlirOperationGetNumRegions(operation) +function get_region(operation, i) + i ∉ 1:num_regions(operation) && throw(BoundsError(operation, i)) + Region(API.mlirOperationGetRegion(operation, i - 1), false) +end +num_results(operation::Operation) = API.mlirOperationGetNumResults(operation) +get_results(operation) = [ + get_result(operation, i) + for i in 1:num_results(operation) +] +function get_result(operation::Operation, i=1) + i ∉ 1:num_results(operation) && throw(BoundsError(operation, i)) + Value(API.mlirOperationGetResult(operation, i - 1)) +end +num_operands(operation) = API.mlirOperationGetNumOperands(operation) +function get_operand(operation, i=1) + i ∉ 1:num_operands(operation) && throw(BoundsError(operation, i)) + Value(API.mlirOperationGetOperand(operation, i - 1)) +end +function set_operand!(operation, i, value) + i ∉ 1:num_operands(operation) && throw(BoundsError(operation, i)) + API.mlirOperationSetOperand(operation, i - 1, value) + value +end + +function get_attribute_by_name(operation, name) + raw_attr = API.mlirOperationGetAttributeByName(operation, name) + if mlirIsNull(raw_attr) + return nothing + end + Attribute(raw_attr) +end +function set_attribute_by_name!(operation, name, attribute) + API.mlirOperationSetAttributeByName(operation, name, attribute) + operation +end + +location(operation) = Location(API.mlirOperationGetLocation(operation)) +name(operation) = String(API.mlirOperationGetName(operation)) +block(operation) = Block(API.mlirOperationGetBlock(operation), false) +parent_operation(operation) = Operation(API.mlirOperationGetParentOperation(operation), false) +dialect(operation) = first(split(get_name(operation), '.')) |> Symbol + +function get_first_region(op::Operation) + reg = iterate(RegionIterator(op)) + isnothing(reg) && return nothing + first(reg) +end +function get_first_block(op::Operation) + reg = get_first_region(op) + isnothing(reg) && return nothing + block = iterate(BlockIterator(reg)) + isnothing(block) && return nothing + first(block) +end +function get_first_child_op(op::Operation) + block = get_first_block(op) + isnothing(block) && return nothing + cop = iterate(OperationIterator(block)) + first(cop) +end + +op::Operation == other::Operation = API.mlirOperationEqual(op, other) + +Base.cconvert(::Type{MlirOperation}, operation::Operation) = operation +Base.unsafe_convert(::Type{MlirOperation}, operation::Operation) = operation.operation + +function lose_ownership!(operation::Operation) + @assert operation.owned + @atomic operation.owned = false + operation +end + +function Base.show(io::IO, operation::Operation) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + flags = API.mlirOpPrintingFlagsCreate() + get(io, :debug, false) && API.mlirOpPrintingFlagsEnableDebugInfo(flags, true, true) + API.mlirOperationPrintWithFlags(operation, flags, c_print_callback, ref) + println(io) +end + +verify(operation::Operation) = API.mlirOperationVerify(operation) + +### Block + +mutable struct Block + block::MlirBlock + @atomic owned::Bool + + Block(block::MlirBlock, owned::Bool=true) = begin + @assert !mlirIsNull(block) "cannot create Block with null MlirBlock" + finalizer(new(block, owned)) do block + if block.owned + API.mlirBlockDestroy(block.block) + end + end + end +end + +Block() = Block(MLIRType[], Location[]) +function Block(args::Vector{MLIRType}, locs::Vector{Location}) + @assert length(args) == length(locs) "there should be one args for each locs (got $(length(args)) & $(length(locs)))" + Block(API.mlirBlockCreate(length(args), args, locs)) +end + +function Base.push!(block::Block, op::Operation) + API.mlirBlockAppendOwnedOperation(block, lose_ownership!(op)) + op +end +function Base.insert!(block::Block, pos, op::Operation) + API.mlirBlockInsertOwnedOperation(block, pos - 1, lose_ownership!(op)) + op +end +function Base.pushfirst!(block::Block, op::Operation) + insert!(block, 1, op) + op +end +function insert_after!(block::Block, reference::Operation, op::Operation) + API.mlirBlockInsertOwnedOperationAfter(block, reference, lose_ownership!(op)) + op +end +function insert_before!(block::Block, reference::Operation, op::Operation) + API.mlirBlockInsertOwnedOperationBefore(block, reference, lose_ownership!(op)) + op +end + +num_arguments(block::Block) = + API.mlirBlockGetNumArguments(block) +function get_argument(block::Block, i) + i ∉ 1:num_arguments(block) && throw(BoundsError(block, i)) + Value(API.mlirBlockGetArgument(block, i - 1)) +end +push_argument!(block::Block, type, loc) = + Value(API.mlirBlockAddArgument(block, type, loc)) + +Base.cconvert(::Type{MlirBlock}, block::Block) = block +Base.unsafe_convert(::Type{MlirBlock}, block::Block) = block.block + +function lose_ownership!(block::Block) + @assert block.owned + @atomic block.owned = false + block +end + +function Base.show(io::IO, block::Block) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + API.mlirBlockPrint(block, c_print_callback, ref) +end + +### Region + +mutable struct Region + region::MlirRegion + @atomic owned::Bool + + Region(region, owned=true) = begin + @assert !mlirIsNull(region) + finalizer(new(region, owned)) do region + if region.owned + API.mlirRegionDestroy(region.region) + end + end + end +end + +Region() = Region(API.mlirRegionCreate()) + +function Base.push!(region::Region, block::Block) + API.mlirRegionAppendOwnedBlock(region, lose_ownership!(block)) + block +end +function Base.insert!(region::Region, pos, block::Block) + API.mlirRegionInsertOwnedBlock(region, pos - 1, lose_ownership!(block)) + block +end +function Base.pushfirst!(region::Region, block) + insert!(region, 1, block) + block +end +insert_after!(region::Region, reference::Block, block::Block) = + API.mlirRegionInsertOwnedBlockAfter(region, reference, lose_ownership!(block)) +insert_before!(region::Region, reference::Block, block::Block) = + API.mlirRegionInsertOwnedBlockBefore(region, reference, lose_ownership!(block)) + +function get_first_block(region::Region) + block = iterate(BlockIterator(region)) + isnothing(block) && return nothing + first(block) +end + +function lose_ownership!(region::Region) + @assert region.owned + @atomic region.owned = false + region +end + +Base.cconvert(::Type{MlirRegion}, region::Region) = region +Base.unsafe_convert(::Type{MlirRegion}, region::Region) = region.region + +### Module + +mutable struct MModule + module_::MlirModule + + MModule(module_) = begin + @assert !mlirIsNull(module_) "cannot create MModule with null MlirModule" + finalizer(API.mlirModuleDestroy, new(module_)) + end +end + +MModule(loc::Location=Location()) = + MModule(API.mlirModuleCreateEmpty(loc)) +get_operation(module_) = Operation(API.mlirModuleGetOperation(module_), false) +get_body(module_) = Block(API.mlirModuleGetBody(module_), false) +get_first_child_op(mod::MModule) = get_first_child_op(get_operation(mod)) + +Base.convert(::Type{MlirModule}, module_::MModule) = module_.module_ +Base.parse(::Type{MModule}, module_) = MModule(API.mlirModuleCreateParse(context(), module_), context()) + +macro mlir_str(code) + quote + ctx = Context() + parse(MModule, ctx, code) + end +end + +function Base.show(io::IO, module_::MModule) + println(io, "MModule:") + show(io, get_operation(module_)) +end + +### TypeID + +struct TypeID + typeid::API.MlirTypeID +end + +Base.hash(typeid::TypeID) = API.mlirTypeIDHashValue(typeid.typeid) +Base.convert(::Type{API.MlirTypeID}, typeid::TypeID) = typeid.typeid + +@static if isdefined(API, :MlirTypeIDAllocator) + + ### TypeIDAllocator + + mutable struct TypeIDAllocator + allocator::API.MlirTypeIDAllocator + + function TypeIDAllocator() + ptr = API.mlirTypeIDAllocatorCreate() + @assert ptr != C_NULL "cannot create TypeIDAllocator" + finalizer(API.mlirTypeIDAllocatorDestroy, new(ptr)) + end + end + + Base.cconvert(::Type{API.MlirTypeIDAllocator}, allocator::TypeIDAllocator) = allocator + Base.unsafe_convert(::Type{API.MlirTypeIDAllocator}, allocator) = allocator.allocator + + TypeID(allocator::TypeIDAllocator) = TypeID(API.mlirTypeIDCreate(allocator)) + +else + + struct TypeIDAllocator end + +end + +include("./Support.jl") +include("./Pass.jl") + diff --git a/src/IR/Pass.jl b/src/IR/Pass.jl new file mode 100644 index 00000000..f4718dbe --- /dev/null +++ b/src/IR/Pass.jl @@ -0,0 +1,174 @@ +### Pass Manager + +abstract type AbstractPass end + +mutable struct ExternalPassHandle + ctx::Union{Nothing,Context} + pass::AbstractPass +end + +mutable struct PassManager + pass::MlirPassManager + allocator::TypeIDAllocator + passes::Dict{TypeID,ExternalPassHandle} + + PassManager(pm::MlirPassManager) = begin + @assert !mlirIsNull(pm) "cannot create PassManager with null MlirPassManager" + finalizer(new(pm, TypeIDAllocator(), Dict{TypeID,ExternalPassHandle}())) do pm + API.mlirPassManagerDestroy(pm.pass) + end + end +end + +function enable_ir_printing!(pm) + API.mlirPassManagerEnableIRPrinting(pm) + pm +end +function enable_verifier!(pm, enable=true) + API.mlirPassManagerEnableVerifier(pm, enable) + pm +end + +PassManager() = + PassManager(API.mlirPassManagerCreate(context())) + +function run!(pm::PassManager, module_) + status = API.mlirPassManagerRun(pm, module_) + if mlirLogicalResultIsFailure(status) + throw("failed to run pass manager on module") + end + module_ +end + +Base.convert(::Type{MlirPassManager}, pass::PassManager) = pass.pass + +### Op Pass Manager + +struct OpPassManager + op_pass::MlirOpPassManager + pass::PassManager + + OpPassManager(op_pass, pass) = begin + @assert !mlirIsNull(op_pass) "cannot create OpPassManager with null MlirOpPassManager" + new(op_pass, pass) + end +end + +OpPassManager(pm::PassManager) = OpPassManager(API.mlirPassManagerGetAsOpPassManager(pm), pm) +OpPassManager(pm::PassManager, opname) = OpPassManager(API.mlirPassManagerGetNestedUnder(pm, opname), pm) +OpPassManager(opm::OpPassManager, opname) = OpPassManager(API.mlirOpPassManagerGetNestedUnder(opm, opname), opm.pass) + +Base.convert(::Type{MlirOpPassManager}, op_pass::OpPassManager) = op_pass.op_pass + +function Base.show(io::IO, op_pass::OpPassManager) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + println(io, "OpPassManager(\"\"\"") + API.mlirPrintPassPipeline(op_pass, c_print_callback, ref) + println(io) + print(io, "\"\"\")") +end + +struct AddPipelineException <: Exception + message::String +end + +function Base.showerror(io::IO, err::AddPipelineException) + print(io, "failed to add pipeline:", err.message) + nothing +end + +function add_pipeline!(op_pass::OpPassManager, pipeline) + @static if isdefined(API, :mlirOpPassManagerAddPipeline) + io = IOBuffer() + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + result = GC.@preserve io API.mlirOpPassManagerAddPipeline(op_pass, pipeline, c_print_callback, io) + if mlirLogicalResultIsFailure(result) + exc = AddPipelineException(String(take!(io))) + throw(exc) + end + else + result = API.mlirParsePassPipeline(op_pass, pipeline) + if mlirLogicalResultIsFailure(result) + throw(AddPipelineException(" " * pipeline)) + end + end + op_pass +end + +function add_owned_pass!(pm::PassManager, pass) + API.mlirPassManagerAddOwnedPass(pm, pass) + pm +end + +function add_owned_pass!(opm::OpPassManager, pass) + API.mlirOpPassManagerAddOwnedPass(opm, pass) + opm +end + + +@static if isdefined(API, :mlirCreateExternalPass) + + ### Pass + + # AbstractPass interface: + opname(::AbstractPass) = "" + function pass_run(::Context, ::P, op) where {P<:AbstractPass} + error("pass $P does not implement `MLIR.pass_run`") + end + + function _pass_construct(ptr::ExternalPassHandle) + nothing + end + + function _pass_destruct(ptr::ExternalPassHandle) + nothing + end + + function _pass_initialize(ctx, handle::ExternalPassHandle) + try + handle.ctx = Context(ctx) + mlirLogicalResultSuccess() + catch + mlirLogicalResultFailure() + end + end + + function _pass_clone(handle::ExternalPassHandle) + ExternalPassHandle(handle.ctx, deepcopy(handle.pass)) + end + + function _pass_run(rawop, external_pass, handle::ExternalPassHandle) + op = Operation(rawop, false) + try + pass_run(handle.ctx, handle.pass, op) + catch ex + @error "Something went wrong running pass" exception = (ex, catch_backtrace()) + API.mlirExternalPassSignalFailure(external_pass) + end + nothing + end + + function create_external_pass!(oppass::OpPassManager, args...) + create_external_pass!(oppass.pass, args...) + end + function create_external_pass!(manager, pass, name, argument, + description, opname=opname(pass), + dependent_dialects=MlirDialectHandle[]) + passid = TypeID(manager.allocator) + callbacks = API.MlirExternalPassCallbacks( + @cfunction(_pass_construct, Cvoid, (Any,)), + @cfunction(_pass_destruct, Cvoid, (Any,)), + @cfunction(_pass_initialize, API.MlirLogicalResult, (MlirContext, Any,)), + @cfunction(_pass_clone, Any, (Any,)), + @cfunction(_pass_run, Cvoid, (MlirOperation, API.MlirExternalPass, Any)) + ) + pass_handle = manager.passes[passid] = ExternalPassHandle(nothing, pass) + userdata = Base.pointer_from_objref(pass_handle) + mlir_pass = API.mlirCreateExternalPass(passid, name, argument, description, opname, + length(dependent_dialects), dependent_dialects, + callbacks, userdata) + mlir_pass + end + +end \ No newline at end of file diff --git a/src/IR/Support.jl b/src/IR/Support.jl new file mode 100644 index 00000000..f84689e3 --- /dev/null +++ b/src/IR/Support.jl @@ -0,0 +1,133 @@ +function mlirIsNull(val) + val.ptr == C_NULL +end + +### Identifier + +String(ident::MlirIdentifier) = String(API.mlirIdentifierStr(ident)) + +### Logical Result + +mlirLogicalResultSuccess() = API.MlirLogicalResult(1) +mlirLogicalResultFailure() = API.MlirLogicalResult(0) + +mlirLogicalResultIsSuccess(result) = result.value != 0 +mlirLogicalResultIsFailure(result) = result.value == 0 + +### Iterators + +""" + BlockIterator(region::Region) + +Iterates over all blocks in the given region. +""" +struct BlockIterator + region::Region +end + +function Base.iterate(it::BlockIterator) + reg = it.region + raw_block = API.mlirRegionGetFirstBlock(reg) + if mlirIsNull(raw_block) + nothing + else + b = Block(raw_block, false) + (b, b) + end +end + +function Base.iterate(it::BlockIterator, block) + raw_block = API.mlirBlockGetNextInRegion(block) + if mlirIsNull(raw_block) + nothing + else + b = Block(raw_block, false) + (b, b) + end +end + +""" + OperationIterator(block::Block) + +Iterates over all operations for the given block. +""" +struct OperationIterator + block::Block +end + +function Base.iterate(it::OperationIterator) + raw_op = API.mlirBlockGetFirstOperation(it.block) + if mlirIsNull(raw_op) + nothing + else + op = Operation(raw_op, false) + (op, op) + end +end + +function Base.iterate(it::OperationIterator, op) + raw_op = API.mlirOperationGetNextInBlock(op) + if mlirIsNull(raw_op) + nothing + else + op = Operation(raw_op, false) + (op, op) + end +end + +""" + RegionIterator(::Operation) + +Iterates over all sub-regions for the given operation. +""" +struct RegionIterator + op::Operation +end + +function Base.iterate(it::RegionIterator) + raw_region = API.mlirOperationGetFirstRegion(it.op) + if mlirIsNull(raw_region) + nothing + else + region = Region(raw_region, false) + (region, region) + end +end + +function Base.iterate(it::RegionIterator, region) + raw_region = API.mlirRegionGetNextInOperation(region) + if mlirIsNull(raw_region) + nothing + else + region = Region(raw_region, false) + (region, region) + end +end + +### Utils + +function visit(f, op) + for region in RegionIterator(op) + for block in BlockIterator(region) + for op in OperationIterator(block) + f(op) + end + end + end +end + +""" + verifyall(operation; debug=false) + +Prints the operations which could not be verified. +""" +function verifyall(operation::Operation; debug=false) + io = IOContext(stdout, :debug => debug) + visit(operation) do op + if !verify(op) + show(io, op) + end + end +end +verifyall(module_::MModule) = get_operation(module_) |> verifyall + diff --git a/src/IR/state.jl b/src/IR/state.jl new file mode 100644 index 00000000..072f65ee --- /dev/null +++ b/src/IR/state.jl @@ -0,0 +1,43 @@ +# Global state + +# to simplify the API, we maintain a stack of contexts in task local storage +# and pass them implicitly to MLIR API's that require them. + +export context, activate, deactivate, context! + +using ..IR + +_has_context() = haskey(task_local_storage(), :MLIRContext) && + !isempty(task_local_storage(:MLIRContext)) + +function context(; throw_error::Core.Bool=true) + if !_has_context() + throw_error && error("No MLIR context is active") + return nothing + end + last(task_local_storage(:MLIRContext)) +end + +function activate(ctx::Context) + stack = get!(task_local_storage(), :MLIRContext) do + Context[] + end + push!(stack, ctx) + return +end + +function deactivate(ctx::Context) + context() == ctx || error("Deactivating wrong context") + pop!(task_local_storage(:MLIRContext)) +end + +function context!(f, ctx::Context) + activate(ctx) + try + f() + finally + deactivate(ctx) + end +end + + diff --git a/src/MLIR.jl b/src/MLIR.jl index c60e67e4..3fc155ee 100644 --- a/src/MLIR.jl +++ b/src/MLIR.jl @@ -35,4 +35,15 @@ function Base.unsafe_convert(::Type{API.MlirStringRef}, s::Union{Symbol, String, return API.MlirStringRef(p, length(s)) end -end # module +module IR + import ..API: API + + include("./IR/IR.jl") + include("./IR/state.jl") +end # module IR + +include("./highlevel.jl") +include("./Dialects.jl") + + +end # module MLIR diff --git a/src/highlevel.jl b/src/highlevel.jl new file mode 100644 index 00000000..3c26fb43 --- /dev/null +++ b/src/highlevel.jl @@ -0,0 +1,144 @@ +module Builder + +export @Block, @Region + +using ...IR + +ctx = IR.Context() +loc = IR.Location() + +struct BlockBuilder + block::IR.Block + expr::Expr +end +_has_blockbuilder() = haskey(task_local_storage(), :BlockBuilder) && + !isempty(task_local_storage(:BlockBuilder)) + +function blockbuilder() + if !_has_blockbuilder() + error("No BlockBuilder is active") + return nothing + end + last(task_local_storage(:BlockBuilder)) +end +function activate(b::BlockBuilder) + stack = get!(task_local_storage(), :BlockBuilder) do + BlockBuilder[] + end + push!(stack, b) +end +function deactivate(b::BlockBuilder) + blockbuilder() == b || error("Deactivating wrong RegionBuilder") + pop!(task_local_storage(:BlockBuilder)) +end + +struct RegionBuilder + region::IR.Region + blockbuilders::Vector{BlockBuilder} +end +_has_regionbuilder() = haskey(task_local_storage(), :RegionBuilder) && + !isempty(task_local_storage(:RegionBuilder)) +function regionbuilder() + if !_has_regionbuilder() + error("No RegionBuilder is active") + return nothing + end + last(task_local_storage(:RegionBuilder)) +end +function activate(r::RegionBuilder) + stack = get!(task_local_storage(), :RegionBuilder) do + RegionBuilder[] + end + push!(stack, r) +end +function deactivate(r::RegionBuilder) + regionbuilder() == r || error("Deactivating wrong RegionBuilder") + pop!(task_local_storage(:RegionBuilder)) +end + +function Region(expr) + exprs = Expr[] + + #= Create region =# + region = IR.Region() + #= Push region on the stack =# + regionbuilder = RegionBuilder(region, BlockBuilder[]) + activate(regionbuilder) + #= + `expr` calls to @block. + These calls will create the block variables that + are referenced in control flow operations. + Blocks are added to the region at the top of the + stack and a queue of blocks is kept. The + expressions to generate the operations in each + block can't be executed yet since they can't + reference the blocks before their creation. + =# + push!(exprs, expr) + #= + Once the blocks are created, the operation + code can be run. This happens in order. All the + operations are pushed to the block at the front + of the queue + =# + push!(exprs, quote + for blockbuilder in $regionbuilder.blockbuilders + $activate(blockbuilder) + eval(blockbuilder.expr) + $deactivate(blockbuilder) + end + end) + + push!(exprs, quote + $deactivate($regionbuilder) + $region + end) + + return Expr(:block, exprs...) +end +macro Region(expr) + quote + $(esc(Region(expr))) + end +end + +function Block(expr) + block = IR.Block() + blockbuilder = BlockBuilder(block, expr) + + if (_has_regionbuilder()) + #= Add block to current region =# + push!(regionbuilder().region, block) + #= + Add blockbuilder to the queue to come back later to + generate its operations. + =# + push!(regionbuilder().blockbuilders, blockbuilder) + + #= + Only return the block, don't create the + operations yet. + =# + return quote + $block + end + else + #= + If there's no regionbuilder, the operations + defined in `expr` can immediately get executed + =# + return quote + $activate($blockbuilder) + $expr + $deactivate($blockbuilder) + $block + end + end +end +macro Block(expr) + quote + $(esc(Block(expr))) + end +end + +end # Builder \ No newline at end of file