Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] Run tests on old pending allocator change #22008

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace Dml
bool enableCpuSyncSpinning,
bool disableMemoryArena);

ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr);
ID3D12Resource* GetD3D12ResourceFromAllocation(void* ptr);
void FlushContext(onnxruntime::IExecutionProvider* provider);
void ReleaseCompletedReferences(onnxruntime::IExecutionProvider* provider);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,18 +207,6 @@ namespace Dml
// The allocation info is already destructing at this point
}


// Reinterprets an opaque allocation handle as a pointer to its AllocationInfo.
// A null handle is rejected by raising E_INVALIDARG through ORT_THROW_HR.
const AllocationInfo* BucketizedBufferAllocator::DecodeDataHandle(const void* opaqueHandle)
{
    if (!opaqueHandle)
    {
        // A null handle has no allocation behind it, so there is nothing to decode.
        ORT_THROW_HR(E_INVALIDARG);
    }

    // The handle is the AllocationInfo pointer itself; only the static type changes.
    return static_cast<const AllocationInfo*>(opaqueHandle);
}

void BucketizedBufferAllocator::SetDefaultRoundingMode(AllocatorRoundingMode roundingMode)
{
m_defaultRoundingMode = roundingMode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ namespace Dml
D3D12_RESOURCE_STATES initialState,
std::unique_ptr<DmlSubAllocator>&& subAllocator);

// Returns the information associated with an opaque allocation handle returned by IAllocator::Alloc.
const AllocationInfo* DecodeDataHandle(const void* opaqueHandle);

void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode);

public: // onnxruntime::IAllocator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,8 @@ DmlCommandRecorder::DmlCommandRecorder(
ORT_THROW_IF_FAILED(dmlDevice->CreateCommandRecorder(IID_PPV_ARGS(&m_recorder)));
}

// Records the buffer allocator for this recorder by storing the given handle in
// m_bufferAllocator. The handle is a std::weak_ptr, so storing it does not add an
// owning reference to the allocator.
void DmlCommandRecorder::SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator)
{
m_bufferAllocator = allocator;
}

void DmlCommandRecorder::InitializeOperator(
onnxruntime::AllocatorPtr& allocator,
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
const DML_BINDING_DESC& inputArrayBinding)
Expand Down Expand Up @@ -57,8 +53,6 @@ void DmlCommandRecorder::InitializeOperator(
UINT64 temporaryResourceSize = initBindingProps.TemporaryResourceSize;
if (temporaryResourceSize > 0)
{
auto allocator = m_bufferAllocator.lock();

// Allocate and immediately free a temporary buffer. The buffer resource will still be
// alive (managed by the pool); freeing allows the resource to be shared with other operators.
void* tempResourceHandle = allocator->Alloc(static_cast<size_t>(temporaryResourceSize));
Expand All @@ -67,7 +61,7 @@ void DmlCommandRecorder::InitializeOperator(
ORT_THROW_HR(E_OUTOFMEMORY);
}

ID3D12Resource* buffer = allocator->DecodeDataHandle(tempResourceHandle)->GetResource();
ID3D12Resource* buffer = Dml::GetD3D12ResourceFromAllocation(tempResourceHandle);
allocator->Free(tempResourceHandle);

// Bind the temporary resource.
Expand Down Expand Up @@ -107,6 +101,7 @@ void DmlCommandRecorder::InitializeOperator(
}

void DmlCommandRecorder::ExecuteOperator(
onnxruntime::AllocatorPtr& allocator,
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
gsl::span<const DML_BINDING_DESC> inputBindings,
Expand All @@ -133,8 +128,6 @@ void DmlCommandRecorder::ExecuteOperator(
UINT64 temporaryResourceSize = execBindingProps.TemporaryResourceSize;
if (temporaryResourceSize > 0)
{
auto allocator = m_bufferAllocator.lock();

// Allocate and immediately free a temporary buffer. The buffer resource will still be
// alive (managed by the pool); freeing allows the resource to be shared with other operators.
void* tempResourceHandle = allocator->Alloc(static_cast<size_t>(temporaryResourceSize));
Expand All @@ -143,7 +136,7 @@ void DmlCommandRecorder::ExecuteOperator(
ORT_THROW_HR(E_OUTOFMEMORY);
}

ID3D12Resource* buffer = allocator->DecodeDataHandle(tempResourceHandle)->GetResource();
ID3D12Resource* buffer = Dml::GetD3D12ResourceFromAllocation(tempResourceHandle);
allocator->Free(tempResourceHandle);

// Bind the temporary resource.
Expand Down Expand Up @@ -338,7 +331,7 @@ void DmlCommandRecorder::CloseAndExecute()
}

void DmlCommandRecorder::CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* commandList)
{
{
ORT_THROW_IF_FAILED(m_currentCommandList->Close());

ID3D12GraphicsCommandList* commandListsToExecute[2] = {};
Expand All @@ -359,7 +352,7 @@ void DmlCommandRecorder::CloseAndExecute(_In_opt_ ID3D12GraphicsCommandList* com
m_queue->ExecuteCommandLists(
gsl::span<ID3D12CommandList*>(reinterpret_cast<ID3D12CommandList**>(commandListsToExecute), commandListsToExecuteCount));
}

m_cachedCommandList = m_currentCommandList;
m_currentCommandList = nullptr;
m_operationsRecordedInCurrentCommandList = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ namespace Dml
std::shared_ptr<CommandQueue> commandQueue);

void InitializeOperator(
onnxruntime::AllocatorPtr& allocator,
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
const DML_BINDING_DESC& inputArrayBinding);

void ExecuteOperator(
onnxruntime::AllocatorPtr& allocator,
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
gsl::span<const DML_BINDING_DESC> inputBindings,
Expand Down Expand Up @@ -56,8 +58,6 @@ namespace Dml
void Open() final;
void CloseAndExecute() final;

void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);

bool HasUnsubmittedWork() override
{
return m_operationsRecordedInCurrentCommandList;
Expand All @@ -83,9 +83,6 @@ namespace Dml
DescriptorPool m_descriptorPool;
ID3D12DescriptorHeap* m_currentDescriptorHeap = nullptr;

// The weak pointer avoids a circular reference from context->recorder->allocator->context
std::weak_ptr<BucketizedBufferAllocator> m_bufferAllocator;

CommandAllocatorRing<2> m_commandAllocatorRing;

// The command list currently being recorded into, and whether any commands have been recorded yet.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ namespace Dml
}

ORT_THROW_IF_FAILED(m_provider->InitializeOperator(
kernelInfo.GetAllocator(OrtMemType::OrtMemTypeDefault),
m_compiledExecutionPlanOperator.Get(),
m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr,
gsl::make_span(initInputBindings)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@ namespace Dml
ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf())));
}

// Passes the given allocator handle on to the DML command recorder (m_dmlRecorder)
// by calling its SetAllocator; nothing is stored on the execution context itself here.
void ExecutionContext::SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator)
{
m_dmlRecorder.SetAllocator(allocator);
}

void ExecutionContext::CopyBufferRegion(
ID3D12Resource* dstBuffer,
uint64_t dstOffset,
Expand Down Expand Up @@ -91,17 +86,19 @@ namespace Dml
}

void ExecutionContext::InitializeOperator(
onnxruntime::AllocatorPtr& allocator,
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
const DML_BINDING_DESC& inputArrayBinding)
{
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);

m_dmlRecorder.InitializeOperator(op, persistentResourceBinding, inputArrayBinding);
m_dmlRecorder.InitializeOperator(allocator, op, persistentResourceBinding, inputArrayBinding);
}

void ExecutionContext::ExecuteOperator(
onnxruntime::AllocatorPtr& allocator,
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
gsl::span<const DML_BINDING_DESC> inputBindings,
Expand All @@ -110,7 +107,7 @@ namespace Dml
assert(!m_closed);
SetCommandRecorder(&m_dmlRecorder);

m_dmlRecorder.ExecuteOperator(op, persistentResourceBinding, inputBindings, outputBindings);
m_dmlRecorder.ExecuteOperator(allocator, op, persistentResourceBinding, inputBindings, outputBindings);
}

void ExecutionContext::AddUAVBarrier()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ namespace Dml
bool cpuSyncSpinningEnabled,
bool keepOpen);

void SetAllocator(std::weak_ptr<BucketizedBufferAllocator> allocator);

// Waits for flushed work, discards unflushed work, and discards associated references to
// prevent circular references. Must be the last call on the object before destruction.
void Close();
Expand All @@ -49,11 +47,13 @@ namespace Dml
gsl::span<const std::byte> pattern /* Data type agnostic value, treated as raw bits */);

void InitializeOperator(
onnxruntime::AllocatorPtr& allocator,
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
const DML_BINDING_DESC& inputArrayBinding);

void ExecuteOperator(
onnxruntime::AllocatorPtr& allocator,
IDMLCompiledOperator* op,
const DML_BINDING_DESC& persistentResourceBinding,
gsl::span<const DML_BINDING_DESC> inputBindings,
Expand Down
Loading
Loading