Compare commits

...

68 commits

Author SHA1 Message Date
Xu
c42ce9adeb feat: 添加性能测试模式 2025-04-21 21:14:04 +08:00
Xu
d894579bd7 fix: TensorRT 的硬件要求已提升到 SM7.5 2025-04-18 16:54:30 +08:00
Xu
6f084e9fb1 chore: 优化引入 onnxruntime 的方式 2025-04-18 16:45:36 +08:00
Xu
11062b11b4 chore: ARM64 不支持 TensorRT 2025-04-12 22:48:20 +08:00
Xu
c91cd6f587 UI: 禁用更新 UI 2025-04-12 21:18:18 +08:00
Xu
83d312aacd perf: 优化 d3d11 和 cuda 的同步性能 2025-04-12 15:38:52 +08:00
Xu
b18d2aae54 feat: 不再支持 CUDA 2025-04-11 22:01:15 +08:00
Xu
4b0bc51504 feat: 所有捕获方式都能工作 2025-04-11 19:25:43 +08:00
Xu
523a24b2b1 fix: 修复独立显卡无法使用 DML 的问题 2025-04-09 21:24:40 +08:00
Xu
2bcbc3e399 fix: 修复 DML 缓冲区对齐 2025-04-08 23:48:53 +08:00
Xu
91859d2d1c fix: 不再使用 wil::CreateDirectoryDeepNoThrow,因为它不支持相对路径 2025-04-08 22:38:53 +08:00
Xu
d1e60d7e41 chore: 更新依赖
除了 ImGui 停留在 v1.90,不值得适配
2025-04-08 21:09:06 +08:00
Xu
8e37bc17d0 chore: 更新依赖,改为使用 HybridCRT
tensorrt 的 dll 终于不依赖 CRT 了
2025-04-08 20:41:35 +08:00
Xu
768f13d31a Merge branch 'main' into onnx 2025-04-06 14:37:21 +08:00
刘旭
7de2f4bf32 Merge branch 'dev' into onnx 2024-06-12 10:45:15 +08:00
刘旭
1f2693ff4e chore: 修复编译 2024-05-13 13:44:19 +08:00
刘旭
b56809634b Merge branch 'dev' into onnx 2024-05-13 09:55:54 +08:00
Xu
eace2e87b0 Merge branch 'dev' into onnx 2024-04-04 20:57:26 +08:00
Xu
6b4a92cc6d Merge branch 'render-system' into onnx 2024-03-30 19:56:33 +08:00
Xu
08b07e155c Merge branch 'render-system' into onnx 2024-03-24 12:19:32 +08:00
Xu
6ede0212cd Merge branch 'render-system' into onnx 2024-03-23 20:24:50 +08:00
Xu
f8dc1ff04d Merge branch 'render-system' into onnx 2024-03-20 00:37:29 +08:00
刘旭
ca99356fbf Merge branch 'render-system' into onnx 2024-03-18 15:26:32 +08:00
Xu
a5726c7506 feat: 支持任意缩放倍率 2024-03-10 16:22:37 +08:00
Xu
69416aff3d chore: 动态链接 CRT 2024-03-10 15:15:18 +08:00
Xu
06ca4e0be6 chore: 依赖 dll 移到 third_party 文件夹 2024-03-10 14:52:22 +08:00
Xu
a6cc7fa67a feat: 从 json 中读取模型 2024-03-10 13:23:15 +08:00
Xu
cc176f72f2 fix: 错误处理 2024-03-10 12:19:53 +08:00
Xu
a1019bba34 feat: 添加 CUDA 后端 2024-03-10 00:55:32 +08:00
Xu
4bcd77be76 refactor 2024-03-09 23:59:13 +08:00
Xu
544ab2a0bf feat: 检查模型是否支持 2024-03-09 22:35:06 +08:00
Xu
8270f3a24c refactor 2024-03-09 21:46:02 +08:00
Xu
6f9b8b358f refactor 2024-03-09 13:01:58 +08:00
Xu
048a3494a3 feat: 初步支持 DML 2024-03-08 22:43:14 +08:00
刘旭
d87c47c790 wip 2024-03-08 17:29:13 +08:00
Xu
d36bf89b65 feat: tensorrt 缓存和 ONNX Runtime 版本、TensorRT 版本绑定 2024-03-07 21:34:36 +08:00
Xu
a8325eccfa feat: tensorrt 支持缓存 2024-03-07 20:59:05 +08:00
Xu
005498b029 feat: tensorrt 支持 fp16 2024-03-06 19:58:54 +08:00
Xu
a4af854ed3 perf 2024-03-06 18:55:18 +08:00
Xu
9d26a4c795 perf 2024-03-06 00:25:48 +08:00
刘旭
0700810727 wip 2024-03-05 16:38:05 +08:00
Xu
d0dc556239 test 2024-03-05 00:15:51 +08:00
刘旭
587fdd5cc6 fix: 检查 Compute Capability 2024-03-04 15:46:50 +08:00
Xu
786ff2fc22 测试 ORT 2024-03-04 00:42:07 +08:00
Xu
98f8649a27 chore 2024-02-29 19:24:48 +08:00
Xu
08e20fe1a9 fix: 优化 cuda api 调用 2024-02-28 22:15:04 +08:00
Xu
feb52a2ca9 perf: 删除不需要的同步 2024-02-27 23:54:58 +08:00
Xu
a5f9e4ecb6 feat: 初步集成 tensorrt 2024-02-27 23:46:31 +08:00
Xu
6685a1df01 wip 2024-02-26 23:17:50 +08:00
Xu
972d0b057a wip 2024-02-26 20:58:45 +08:00
刘旭
fb7c840ca1 wip 2024-02-26 16:59:39 +08:00
Xu
b1036cd9f2 测试 tensorrt 2024-02-25 23:34:20 +08:00
Xu
657209dd39 perf: 性能优化 2024-02-24 23:01:59 +08:00
Xu
427c7a6973 fix: 错误处理 2024-02-24 20:48:40 +08:00
Xu
45a3178a10 fix: 添加错误处理 2024-02-24 18:09:30 +08:00
Xu
6a94c860fd chore 2024-02-24 17:56:37 +08:00
Xu
e8cad13732 perf: 性能优化 2024-02-24 17:16:36 +08:00
Xu
95e04aed1d Merge branch 'effect-profiler' into onnx 2024-02-24 14:10:57 +08:00
Xu
aee12b750b Merge branch 'render-system' into effect-profiler 2024-02-22 20:57:31 +08:00
Xu
34c2123b36 Merge branch 'render-system' into effect-profiler 2024-02-21 23:46:59 +08:00
Xu
809ab1aac1 test 2024-02-21 23:46:47 +08:00
刘旭
4c0bd3131f feat: test 2024-02-21 17:03:01 +08:00
Xu
b658d536e6 Merge branch 'render-system' into effect-profiler 2024-02-20 19:57:56 +08:00
Xu
30a43bc919 Merge branch 'render-system' into effect-profiler 2024-01-22 20:59:42 +08:00
Xu
1bfabb45e3 chore: 添加注释 2024-01-22 19:48:35 +08:00
Xu
4512d3e399 feat: 降低渲染时间更新频率 2024-01-22 00:33:02 +08:00
Xu
94d94f9508 feat: 支持在叠加层中显示渲染时间 2024-01-20 23:03:59 +08:00
Xu
4246f9841a feat: 支持查询每个效果的渲染时间 2024-01-20 20:26:05 +08:00
59 changed files with 2069 additions and 128 deletions

View file

@ -29,7 +29,7 @@
<PrecompiledHeader>Use</PrecompiledHeader>
<PrecompiledHeaderFile>pch.h</PrecompiledHeaderFile>
<PrecompiledHeaderOutputFile>$(IntDir)pch.pch</PrecompiledHeaderOutputFile>
<PreprocessorDefinitions>_WINDOWS;WIN32_LEAN_AND_MEAN;WINRT_LEAN_AND_MEAN;WINRT_NO_MODULE_LOCK;WIL_SUPPRESS_EXCEPTIONS;NOGDICAPMASKS;NOICONS;NOATOM;NOCLIPBOARD;NODRAWTEXT;NOMEMMGR;NOMETAFILE;NOMINMAX;NOOPENFILE;NOSCROLL;NOSERVICE;NOSOUND;NOTEXTMETRIC;NOCOMM;NOKANJI;NOHELP;NOPROFILER;NODEFERWINDOWPOS;NOMCX;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions>_WINDOWS;WIN32_LEAN_AND_MEAN;WINRT_LEAN_AND_MEAN;WINRT_NO_MODULE_LOCK;WIL_SUPPRESS_EXCEPTIONS;WIL_USE_STL=1;NOGDICAPMASKS;NOICONS;NOATOM;NOCLIPBOARD;NODRAWTEXT;NOMEMMGR;NOMETAFILE;NOMINMAX;NOOPENFILE;NOSCROLL;NOSERVICE;NOSOUND;NOTEXTMETRIC;NOCOMM;NOKANJI;NOHELP;NOPROFILER;NODEFERWINDOWPOS;NOMCX;%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions Condition="'$(CommitId)'!=''">MAGPIE_COMMIT_ID=$(CommitId);%(PreprocessorDefinitions)</PreprocessorDefinitions>
<PreprocessorDefinitions Condition="'$(MajorVersion)'!='' And '$(MinorVersion)'!='' And '$(PatchVersion)'!='' And '$(VersionTag)'!=''">MAGPIE_VERSION_MAJOR=$(MajorVersion);MAGPIE_VERSION_MINOR=$(MinorVersion);MAGPIE_VERSION_PATCH=$(PatchVersion);MAGPIE_VERSION_TAG=$(VersionTag);%(PreprocessorDefinitions)</PreprocessorDefinitions>
<AdditionalOptions>/bigobj %(AdditionalOptions)</AdditionalOptions>
@ -38,7 +38,7 @@
<GenerateDebugInformation>true</GenerateDebugInformation>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(Configuration)'=='Debug'">
<ClCompile>
<PreprocessorDefinitions>_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
@ -59,7 +59,37 @@
<OptimizeReferences>true</OptimizeReferences>
</Link>
</ItemDefinitionGroup>
<!-- HybridCRT -->
<Import Project="$(MSBuildThisFileDirectory)HybridCRT.props" />
<!-- onnxruntime -->
<ItemDefinitionGroup>
<ClCompile>
<AdditionalIncludeDirectories>$(SolutionDir)obj\onnxruntime\include;$(SolutionDir)obj\onnxruntime\include\onnxruntime\core\session;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile>
<Link>
<AdditionalLibraryDirectories>$(SolutionDir)obj\onnxruntime\lib\$(Platform);%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
<None Include="$(SolutionDir)obj\onnxruntime\bin\$(Platform)\DirectML.Debug.dll"
Condition="'$(Configuration)'=='Debug' And Exists('$(SolutionDir)obj\onnxruntime\bin\$(Platform)\DirectML.Debug.dll')">
<Link>third_party\DirectML.Debug.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(SolutionDir)obj\onnxruntime\bin\$(Platform)\DirectML.dll"
Condition="Exists('$(SolutionDir)obj\onnxruntime\bin\$(Platform)\DirectML.dll')">
<Link>third_party\DirectML.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(SolutionDir)obj\onnxruntime\bin\$(Platform)\onnxruntime.dll"
Condition="Exists('$(SolutionDir)obj\onnxruntime\bin\$(Platform)\onnxruntime.dll')">
<Link>third_party\onnxruntime.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
</ItemGroup>
</Project>

View file

@ -117,19 +117,19 @@
</local:SimpleStackPanel>
</local:SimpleStackPanel>
<local:SettingsGroup x:Uid="About_Version_UpdateSettings">
<local:SettingsCard x:Uid="About_Version_UpdateSettings_AutoCheckForUpdates">
<local:SettingsCard x:Uid="About_Version_UpdateSettings_AutoCheckForUpdates"
IsEnabled="False">
<local:SettingsCard.HeaderIcon>
<FontIcon Glyph="&#xECC5;" />
</local:SettingsCard.HeaderIcon>
<ToggleSwitch x:Uid="ToggleSwitch"
IsOn="{x:Bind ViewModel.IsAutoCheckForUpdates, Mode=TwoWay}" />
<ToggleSwitch x:Uid="ToggleSwitch" />
</local:SettingsCard>
<local:SettingsCard x:Uid="About_Version_UpdateSettings_CheckForPreviewUpdates">
<local:SettingsCard x:Uid="About_Version_UpdateSettings_CheckForPreviewUpdates"
IsEnabled="False">
<local:SettingsCard.HeaderIcon>
<FontIcon Glyph="&#xED56;" />
</local:SettingsCard.HeaderIcon>
<ToggleSwitch x:Uid="ToggleSwitch"
IsOn="{x:Bind ViewModel.IsCheckForPreviewUpdates, Mode=TwoWay}" />
<ToggleSwitch x:Uid="ToggleSwitch" />
</local:SettingsCard>
</local:SettingsGroup>
<local:SettingsGroup x:Uid="About_Feedback">

View file

@ -63,7 +63,7 @@ hstring AboutViewModel::Version() const noexcept {
L" ",
WIDEN(STRING(MAGPIE_VERSION_TAG)) + 1,
#else
L" dev",
L" onnx-preview2",
#endif
#ifdef MAGPIE_COMMIT_ID
L" | ",

View file

@ -394,6 +394,7 @@ void AppSettings::IsDeveloperMode(bool value) noexcept {
if (!value) {
// 关闭开发者模式则禁用所有开发者选项
_isDebugMode = false;
_isBenchmarkMode = false;
_isEffectCacheDisabled = false;
_isFontCacheDisabled = false;
_isSaveEffectSources = false;
@ -458,9 +459,8 @@ void AppSettings::_UpdateWindowPlacement() noexcept {
}
bool AppSettings::_Save(const _AppSettingsData& data) noexcept {
HRESULT hr = wil::CreateDirectoryDeepNoThrow(data._configDir.c_str());
if (FAILED(hr)) {
Logger::Get().ComError("创建配置文件夹失败", hr);
if (!Win32Utils::CreateDir(data._configDir, true)) {
Logger::Get().Win32Error("创建配置文件夹失败");
return false;
}
@ -509,6 +509,8 @@ bool AppSettings::_Save(const _AppSettingsData& data) noexcept {
writer.Bool(data._isDeveloperMode);
writer.Key("debugMode");
writer.Bool(data._isDebugMode);
writer.Key("benchmarkMode");
writer.Bool(data._isBenchmarkMode);
writer.Key("disableEffectCache");
writer.Bool(data._isEffectCacheDisabled);
writer.Key("disableFontCache");
@ -666,6 +668,7 @@ void AppSettings::_LoadSettings(const rapidjson::GenericObject<true, rapidjson::
}
JsonHelper::ReadBool(root, "developerMode", _isDeveloperMode);
JsonHelper::ReadBool(root, "debugMode", _isDebugMode);
JsonHelper::ReadBool(root, "benchmarkMode", _isBenchmarkMode);
JsonHelper::ReadBool(root, "disableEffectCache", _isEffectCacheDisabled);
JsonHelper::ReadBool(root, "disableFontCache", _isFontCacheDisabled);
JsonHelper::ReadBool(root, "saveEffectSources", _isSaveEffectSources);
@ -1039,9 +1042,8 @@ bool AppSettings::_UpdateConfigPath(std::wstring* existingConfigPath) noexcept {
}
// 确保配置文件夹存在
HRESULT hr = wil::CreateDirectoryDeepNoThrow(_configDir.c_str());
if (FAILED(hr)) {
Logger::Get().ComError("创建配置文件夹失败", hr);
if (!Win32Utils::CreateDir(_configDir, true)) {
Logger::Get().Win32Error("创建配置文件夹失败");
return false;
}

View file

@ -55,6 +55,7 @@ struct _AppSettingsData {
bool _isAlwaysRunAsAdmin = false;
bool _isDeveloperMode = false;
bool _isDebugMode = false;
bool _isBenchmarkMode = false;
bool _isEffectCacheDisabled = false;
bool _isFontCacheDisabled = false;
bool _isSaveEffectSources = false;
@ -151,6 +152,15 @@ public:
SaveAsync();
}
// Returns the current value of the benchmark-mode developer option.
bool IsBenchmarkMode() const noexcept {
return _isBenchmarkMode;
}
// Stores the new benchmark-mode value and schedules the settings to be saved via SaveAsync().
void IsBenchmarkMode(bool value) noexcept {
_isBenchmarkMode = value;
SaveAsync();
}
// Returns the current value of the "disable effect cache" developer option.
bool IsEffectCacheDisabled() const noexcept {
return _isEffectCacheDisabled;
}

View file

@ -194,6 +194,10 @@
<CheckBox x:Uid="Home_Advanced_DeveloperOptions_DebugMode"
IsChecked="{x:Bind ViewModel.IsDebugMode, Mode=TwoWay}" />
</local:SettingsCard>
<local:SettingsCard ContentAlignment="Left">
<CheckBox x:Uid="Home_Advanced_DeveloperOptions_BenchmarkMode"
IsChecked="{x:Bind ViewModel.IsBenchmarkMode, Mode=TwoWay}" />
</local:SettingsCard>
<local:SettingsCard ContentAlignment="Left">
<CheckBox x:Uid="Home_Advanced_DeveloperOptions_DisableEffectCache"
IsChecked="{x:Bind ViewModel.IsEffectCacheDisabled, Mode=TwoWay}" />

View file

@ -298,6 +298,21 @@ void HomeViewModel::IsDebugMode(bool value) {
RaisePropertyChanged(L"IsDebugMode");
}
// Property getter: forwards to the benchmark-mode flag held by AppSettings.
bool HomeViewModel::IsBenchmarkMode() const noexcept {
return AppSettings::Get().IsBenchmarkMode();
}
void HomeViewModel::IsBenchmarkMode(bool value) {
AppSettings& settings = AppSettings::Get();
if (settings.IsBenchmarkMode() == value) {
return;
}
settings.IsBenchmarkMode(value);
RaisePropertyChanged(L"IsBenchmarkMode");
}
// Property getter: forwards to the "disable effect cache" flag held by AppSettings.
bool HomeViewModel::IsEffectCacheDisabled() const noexcept {
return AppSettings::Get().IsEffectCacheDisabled();
}

View file

@ -72,6 +72,9 @@ struct HomeViewModel : HomeViewModelT<HomeViewModel>, wil::notify_property_chang
bool IsDebugMode() const noexcept;
void IsDebugMode(bool value);
bool IsBenchmarkMode() const noexcept;
void IsBenchmarkMode(bool value);
bool IsEffectCacheDisabled() const noexcept;
void IsEffectCacheDisabled(bool value);

View file

@ -35,6 +35,7 @@ namespace Magpie.App {
Boolean IsDeveloperMode;
Boolean IsDebugMode;
Boolean IsBenchmarkMode;
Boolean IsEffectCacheDisabled;
Boolean IsFontCacheDisabled;
Boolean IsSaveEffectSources;

View file

@ -1,7 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props')" />
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props')" />
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props')" />
<PropertyGroup Label="Globals">
<CppWinRTGenerateWindowsMetadata>true</CppWinRTGenerateWindowsMetadata>
<MinimalCoreWin>true</MinimalCoreWin>
@ -59,9 +59,11 @@
<Link>
<GenerateWindowsMetadata>false</GenerateWindowsMetadata>
<SubSystem>Console</SubSystem>
<AdditionalDependencies>kernel32.lib;ole32.lib;oleaut32.lib;user32.lib;gdi32.lib;$(OutDir).\Magpie.Core.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies>kernel32.lib;ole32.lib;oleaut32.lib;user32.lib;gdi32.lib;onnxruntime.lib;directml.lib;$(OutDir).\Magpie.Core.lib;%(AdditionalDependencies)</AdditionalDependencies>
<AdditionalDependencies Condition="'$(Platform)'=='x64'">cudart.lib;%(AdditionalDependencies)</AdditionalDependencies>
<ModuleDefinitionFile>Magpie.App.def</ModuleDefinitionFile>
<DelayLoadDLLs>d3dcompiler_47.dll;Magnification.dll;%(DelayLoadDLLs)</DelayLoadDLLs>
<DelayLoadDLLs>d3d12.dll;DirectML.dll;d3dcompiler_47.dll;Magnification.dll;%(DelayLoadDLLs)</DelayLoadDLLs>
<DelayLoadDLLs Condition="'$(Platform)'=='x64'">cudart64_12.dll;%(DelayLoadDLLs)</DelayLoadDLLs>
</Link>
</ItemDefinitionGroup>
<ItemGroup>
@ -719,20 +721,20 @@ File.Delete("priconfig.xml");
</ItemGroup>
</Target>
<ImportGroup Label="ExtensionTargets">
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets')" />
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets')" />
<Import Project="..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets" Condition="Exists('..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets')" />
<Import Project="..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets" Condition="Exists('..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets')" />
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets')" />
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
</ImportGroup>
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
<PropertyGroup>
<ErrorText>这台计算机上缺少此项目引用的 NuGet 程序包。使用“NuGet 程序包还原”可下载这些程序包。有关更多信息,请参见 http://go.microsoft.com/fwlink/?LinkID=322105。缺少的文件是 {0}。</ErrorText>
</PropertyGroup>
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
</Target>
</Project>

View file

@ -838,4 +838,7 @@
<data name="Home_Advanced_SimulateExclusiveFullscreen_InfoBar.Title" xml:space="preserve">
<value>This option is not compatible with some older games. Please use it with caution.</value>
</data>
<data name="Home_Advanced_DeveloperOptions_BenchmarkMode.Content" xml:space="preserve">
<value>Benchmark mode</value>
</data>
</root>

View file

@ -838,4 +838,7 @@
<data name="Home_Advanced_SimulateExclusiveFullscreen_InfoBar.Title" xml:space="preserve">
<value>此选项和一些旧游戏不兼容,请谨慎使用。</value>
</data>
<data name="Home_Advanced_DeveloperOptions_BenchmarkMode.Content" xml:space="preserve">
<value>性能测试模式</value>
</data>
</root>

View file

@ -304,6 +304,7 @@ bool ScalingService::_StartScale(HWND hWnd, const Profile& profile) {
}
options.IsDebugMode(settings.IsDebugMode());
options.IsBenchmarkMode(settings.IsBenchmarkMode());
options.IsEffectCacheDisabled(settings.IsEffectCacheDisabled());
options.IsFontCacheDisabled(settings.IsFontCacheDisabled());
options.IsSaveEffectSources(settings.IsSaveEffectSources());

View file

@ -1,12 +1,12 @@
[requires]
fmt/10.2.1
spdlog/1.14.1
parallel-hashmap/1.37
fmt/11.1.3
spdlog/1.15.1
parallel-hashmap/2.0.0
rapidjson/cci.20230929
kuba-zip/0.3.2
muparser/2.3.4
muparser/2.3.5
yas/7.1.0
imgui/1.90.8
imgui/1.90.9
[generators]
MSBuildDeps

View file

@ -1,7 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="Microsoft.UI.Xaml" version="2.8.6" targetFramework="native" />
<package id="Microsoft.Web.WebView2" version="1.0.2535.41" targetFramework="native" />
<package id="Microsoft.UI.Xaml" version="2.8.7" targetFramework="native" />
<package id="Microsoft.Web.WebView2" version="1.0.3179.45" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.240405.15" targetFramework="native" />
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.240122.1" targetFramework="native" />
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.250325.1" targetFramework="native" />
</packages>

View file

@ -95,7 +95,9 @@ bool DesktopDuplicationFrameSource::_Initialize() noexcept {
DXGI_FORMAT_B8G8R8A8_UNORM,
_srcRect.right - _srcRect.left,
_srcRect.bottom - _srcRect.top,
D3D11_BIND_SHADER_RESOURCE
D3D11_BIND_SHADER_RESOURCE,
D3D11_USAGE_DEFAULT,
D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
);
if (!_output) {
Logger::Get().Error("CreateTexture2D 失败");

View file

@ -76,6 +76,7 @@ bool DeviceResources::_ObtainAdapterAndDevice(int adapterIdx) noexcept {
if (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) {
Logger::Get().Warn("用户指定的显示卡为 WARP已忽略");
} else if (_TryCreateD3DDevice(adapter)) {
_adapterIdx = adapterIdx;
return true;
} else {
Logger::Get().Warn("用户指定的显示卡不支持 FL 11");
@ -105,21 +106,31 @@ bool DeviceResources::_ObtainAdapterAndDevice(int adapterIdx) noexcept {
}
if (_TryCreateD3DDevice(adapter)) {
_adapterIdx = adapterIndex;
return true;
}
}
// 作为最后手段,回落到 Basic Render Driver AdapterWARP
// https://docs.microsoft.com/en-us/windows/win32/direct3darticles/directx-warp
HRESULT hr = _dxgiFactory->EnumWarpAdapter(IID_PPV_ARGS(&adapter));
if (FAILED(hr)) {
Logger::Get().ComError("EnumWarpAdapter 失败", hr);
return false;
}
for (UINT adapterIndex = 0;
SUCCEEDED(_dxgiFactory->EnumAdapters1(adapterIndex, adapter.put()));
++adapterIndex
) {
DXGI_ADAPTER_DESC1 desc;
HRESULT hr = adapter->GetDesc1(&desc);
if (FAILED(hr)) {
continue;
}
if (!_TryCreateD3DDevice(adapter)) {
Logger::Get().ComError("创建 WARP 设备失败", hr);
return false;
if ((desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) == 0) {
continue;
}
if (_TryCreateD3DDevice(adapter)) {
_adapterIdx = adapterIndex;
return true;
}
}
return true;

View file

@ -15,6 +15,7 @@ public:
ID3D11Device5* GetD3DDevice() const noexcept { return _d3dDevice.get(); }
ID3D11DeviceContext4* GetD3DDC() const noexcept { return _d3dDC.get(); }
IDXGIAdapter4* GetGraphicsAdapter() const noexcept { return _graphicsAdapter.get(); }
uint32_t GetAdapterIndex() const noexcept { return _adapterIdx; }
bool IsSupportTearing() const noexcept {
return _isSupportTearing;
@ -28,6 +29,7 @@ private:
winrt::com_ptr<IDXGIFactory7> _dxgiFactory;
winrt::com_ptr<IDXGIAdapter4> _graphicsAdapter;
uint32_t _adapterIdx = 0;
winrt::com_ptr<ID3D11Device5> _d3dDevice;
winrt::com_ptr<ID3D11DeviceContext4> _d3dDC;

View file

@ -0,0 +1,597 @@
#include "pch.h"
#include "DirectMLInferenceBackend.h"
#include "DeviceResources.h"
#include "DirectXHelper.h"
#include "shaders/TensorToTextureCS.h"
#include "shaders/TextureToTensorCS.h"
#include "Logger.h"
#include <onnxruntime/core/providers/dml/dml_provider_factory.h>
#include "Win32Utils.h"
namespace Magpie::Core {
// Creates a D3D12 device on `adapter`, requesting feature level 11_0.
// In debug builds the D3D12 debug layer is switched on first when it can be
// obtained; failing to obtain it is silently ignored.
// On failure the error is logged and the device pointer is returned as
// D3D12CreateDevice left it (the caller tests the result for null).
static winrt::com_ptr<ID3D12Device> CreateD3D12Device(IDXGIAdapter4* adapter) noexcept {
#ifdef _DEBUG
	// Enable the D3D12 debug layer.
	if (winrt::com_ptr<ID3D12Debug> debugLayer;
		SUCCEEDED(D3D12GetDebugInterface(IID_PPV_ARGS(&debugLayer)))) {
		debugLayer->EnableDebugLayer();
	}
#endif

	winrt::com_ptr<ID3D12Device> device;
	const HRESULT createResult = D3D12CreateDevice(
		adapter,
		D3D_FEATURE_LEVEL_11_0,
		IID_PPV_ARGS(&device)
	);
	if (FAILED(createResult)) {
		Logger::Get().ComError("D3D12CreateDevice 失败", createResult);
	}

	return device;
}
// Creates a DirectML device on top of `d3d12Device`, requiring at least
// DML feature level 5_0. Debug builds pass DML_CREATE_DEVICE_FLAG_DEBUG.
// On failure the error is logged and the device pointer is returned as
// DMLCreateDevice1 left it (the caller tests the result for null).
static winrt::com_ptr<IDMLDevice> CreateDMLDevice(ID3D12Device* d3d12Device) noexcept {
#ifdef _DEBUG
	constexpr auto createFlags = DML_CREATE_DEVICE_FLAG_DEBUG;
#else
	constexpr auto createFlags = DML_CREATE_DEVICE_FLAG_NONE;
#endif

	winrt::com_ptr<IDMLDevice> device;
	const HRESULT createResult = DMLCreateDevice1(
		d3d12Device,
		createFlags,
		// Minimum feature level, matching the one referenced in ONNX Runtime:
		// https://github.com/microsoft/onnxruntime/blob/554fb4ad1fcf808304d4758d73d93a8ecc362bf6/onnxruntime/core/providers/dml/dml_provider_factory.cc#L519
		DML_FEATURE_LEVEL_5_0,
		IID_PPV_ARGS(&device)
	);
	if (FAILED(createResult)) {
		Logger::Get().ComError("DMLCreateDevice1 失败", createResult);
	}

	return device;
}
// Opens a D3D11 texture as a D3D12 resource through a shared handle.
//   texture     - D3D11 texture to share; it is queried for IDXGIResource1.
//   d3d12Device - device on which the shared handle is opened.
//   access      - access rights passed to IDXGIResource1::CreateSharedHandle.
// Each failing step is logged; in that case the result pointer is returned
// without having been successfully filled (the caller tests it for null).
static winrt::com_ptr<ID3D12Resource> ShareTextureWithD3D12(ID3D11Texture2D* texture, ID3D12Device* d3d12Device, DWORD access) noexcept {
	winrt::com_ptr<ID3D12Resource> sharedResource;

	winrt::com_ptr<IDXGIResource1> dxgiView;
	if (const HRESULT hr = texture->QueryInterface<IDXGIResource1>(dxgiView.put()); FAILED(hr)) {
		Logger::Get().ComError("获取 IDXGIResource1 失败", hr);
		return sharedResource;
	}

	// The handle is only needed until the D3D12 side has opened it;
	// wil::unique_handle closes it when this function returns.
	wil::unique_handle ntHandle;
	if (const HRESULT hr = dxgiView->CreateSharedHandle(nullptr, access, nullptr, ntHandle.put()); FAILED(hr)) {
		Logger::Get().ComError("CreateSharedHandle 失败", hr);
		return sharedResource;
	}

	if (const HRESULT hr = d3d12Device->OpenSharedHandle(ntHandle.get(), IID_PPV_ARGS(&sharedResource)); FAILED(hr)) {
		Logger::Get().ComError("OpenSharedHandle 失败", hr);
	}

	return sharedResource;
}
// Wraps a D3D12 buffer in a DML execution-provider GPU allocation and returns
// it as a reference-counted IUnknown.
// Throws Ort::Exception (via Ort::ThrowOnError) if either ORT call fails.
static winrt::com_ptr<IUnknown> AllocateD3D12Resource(const OrtDmlApi* ortDmlApi, ID3D12Resource* buffer) {
	void* rawAllocation;
	Ort::ThrowOnError(ortDmlApi->CreateGPUAllocationFromD3DResource(buffer, &rawAllocation));

	// Take our own reference first, then hand the allocation back to ORT.
	// NOTE(review): this presumably relies on FreeGPUAllocation only dropping
	// ORT's reference so that the object stays alive through ours — confirm.
	winrt::com_ptr<IUnknown> keepAlive;
	keepAlive.copy_from(static_cast<IUnknown*>(rawAllocation));
	Ort::ThrowOnError(ortDmlApi->FreeGPUAllocation(rawAllocation));

	return keepAlive;
}
// Builds everything needed to run an ONNX model through the DirectML
// execution provider and to exchange frames with the D3D11 renderer.
//
//   modelPath       - path of the ONNX model to load.
//   scale           - integer factor; the output texture is input size * scale.
//   deviceResources - supplies the D3D11 device, its context and the adapter
//                     on which the D3D12 device is created.
//   input           - D3D11 texture that is shared with D3D12 for reading.
//   output          - receives the raw pointer of the output texture owned by
//                     this backend (no extra reference is added here).
//
// Steps: create the shared output texture, a D3D12 device and compute queue,
// the ORT session with the DML provider, the input/output tensor buffers and
// their IoBinding, the shared fence, the D3D12 views of both textures, and
// finally the descriptor heap, pipeline states and command lists.
//
// Returns false (after logging) as soon as any step fails. Ort::Exception is
// caught and logged; no other exception type is handled here.
bool DirectMLInferenceBackend::Initialize(
	const wchar_t* modelPath,
	uint32_t scale,
	DeviceResources& deviceResources,
	BackendDescriptorStore& /*descriptorStore*/,
	ID3D11Texture2D* input,
	ID3D11Texture2D** output
) noexcept {
	ID3D11Device5* d3d11Device = deviceResources.GetD3DDevice();
	_d3d11DC = deviceResources.GetD3DDC();

	const SIZE inputSize = DirectXHelper::GetTextureSize(input);
	const SIZE outputSize{ inputSize.cx * (LONG)scale, inputSize.cy * (LONG)scale };

	// Create the output texture. It must be shareable through an NT handle so
	// that it can be opened on the D3D12 device below.
	_outputTex = DirectXHelper::CreateTexture2D(
		d3d11Device,
		DXGI_FORMAT_R8G8B8A8_UNORM,
		outputSize.cx,
		outputSize.cy,
		D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_UNORDERED_ACCESS,
		D3D11_USAGE_DEFAULT,
		D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
	);
	if (!_outputTex) {
		Logger::Get().Error("创建输出纹理失败");
		return false;
	}

	*output = _outputTex.get();

	// Three channels per pixel (tensor shapes below are 1x3xHxW).
	const uint32_t inputElemCount = uint32_t(inputSize.cx * inputSize.cy * 3);
	const uint32_t outputElemCount = uint32_t(outputSize.cx * outputSize.cy * 3);

	winrt::com_ptr<ID3D12Device> d3d12Device = CreateD3D12Device(deviceResources.GetGraphicsAdapter());
	if (!d3d12Device) {
		Logger::Get().Error("CreateD3D12Device 失败");
		return false;
	}

	{
		D3D12_COMMAND_QUEUE_DESC desc{
			.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE,
			.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT
		};
		HRESULT hr = d3d12Device->CreateCommandQueue(&desc, IID_PPV_ARGS(&_commandQueue));
		if (FAILED(hr)) {
			Logger::Get().ComError("CreateCommandQueue 失败", hr);
			return false;
		}
	}

	bool isFP16Data = false;

	try {
		const OrtApi& ortApi = Ort::GetApi();

		_env = Ort::Env(ORT_LOGGING_LEVEL_INFO, "", _OrtLog, nullptr);

		// The returned status must be checked: on failure ortDmlApi stays null
		// and an ignored OrtStatus would be leaked.
		const OrtDmlApi* ortDmlApi = nullptr;
		Ort::ThrowOnError(ortApi.GetExecutionProviderApi("DML", ORT_API_VERSION, (const void**)&ortDmlApi));
		if (!ortDmlApi) {
			Logger::Get().Error("获取 OrtDmlApi 失败");
			return false;
		}

		Ort::SessionOptions sessionOptions;
		sessionOptions.SetIntraOpNumThreads(1);
		sessionOptions.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
		sessionOptions.DisableMemPattern();

		Ort::ThrowOnError(ortApi.AddFreeDimensionOverride(sessionOptions, "DATA_BATCH", 1));

		winrt::com_ptr<IDMLDevice> dmlDevice = CreateDMLDevice(d3d12Device.get());
		if (!dmlDevice) {
			Logger::Get().Error("CreateDMLDevice 失败");
			return false;
		}

		Ort::ThrowOnError(ortDmlApi->SessionOptionsAppendExecutionProvider_DML1(
			sessionOptions, dmlDevice.get(), _commandQueue.get()));

		_session = Ort::Session(_env, modelPath, sessionOptions);

		if (!_IsModelValid(_session, isFP16Data)) {
			Logger::Get().Error("不支持此模型");
			return false;
		}

		// Tensor sizes in bytes, computed once in 64-bit so that large outputs
		// cannot wrap around in 32-bit arithmetic.
		const UINT64 elemSize = isFP16Data ? 2 : 4;
		const UINT64 inputByteCount = UINT64(inputElemCount) * elemSize;
		const UINT64 outputByteCount = UINT64(outputElemCount) * elemSize;

		// Create the tensor buffers.
		{
			D3D12_HEAP_PROPERTIES heapDesc{
				.Type = D3D12_HEAP_TYPE_DEFAULT
			};
			D3D12_RESOURCE_DESC resDesc{
				.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER,
				// Buffer size is rounded up to a multiple of 4 bytes.
				.Width = (inputByteCount + 3) & ~UINT64(3),
				.Height = 1,
				.DepthOrArraySize = 1,
				.MipLevels = 1,
				.SampleDesc{
					.Count = 1
				},
				.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
				.Flags = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS
			};
			HRESULT hr = d3d12Device->CreateCommittedResource(
				&heapDesc,
				D3D12_HEAP_FLAG_CREATE_NOT_ZEROED,
				&resDesc,
				D3D12_RESOURCE_STATE_COMMON,
				nullptr,
				IID_PPV_ARGS(&_inputBuffer)
			);
			if (FAILED(hr)) {
				Logger::Get().ComError("CreateCommittedResource 失败", hr);
				return false;
			}

			resDesc.Width = (outputByteCount + 3) & ~UINT64(3);
			hr = d3d12Device->CreateCommittedResource(
				&heapDesc,
				D3D12_HEAP_FLAG_CREATE_NOT_ZEROED,
				&resDesc,
				D3D12_RESOURCE_STATE_COMMON,
				nullptr,
				IID_PPV_ARGS(&_outputBuffer)
			);
			if (FAILED(hr)) {
				Logger::Get().ComError("CreateCommittedResource 失败", hr);
				return false;
			}
		}

		// Create the IoBinding.
		_ioBinding = Ort::IoBinding(_session);

		// The device_id of DmlExecutionProvider is always 0; passing any other value fails.
		// See https://github.com/microsoft/onnxruntime/blob/89f8206ba4f1c22c39e0297fb55272e8ce8cd7d0/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp#L77
		// WinML also always uses 0: https://github.com/microsoft/onnxruntime/blob/89f8206ba4f1c22c39e0297fb55272e8ce8cd7d0/winml/lib/Api.Ort/OnnxruntimeEngine.cpp#L654
		Ort::MemoryInfo memoryInfo(
			"DML",
			OrtAllocatorType::OrtDeviceAllocator,
			0,
			OrtMemType::OrtMemTypeDefault
		);

		const ONNXTensorElementDataType dataType =
			isFP16Data ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;

		const int64_t inputShape[]{ 1,3,inputSize.cy,inputSize.cx };
		_allocatedInput = AllocateD3D12Resource(ortDmlApi, _inputBuffer.get());
		_ioBinding.BindInput("input", Ort::Value::CreateTensor(
			memoryInfo,
			_allocatedInput.get(),
			size_t(inputByteCount),
			inputShape,
			std::size(inputShape),
			dataType
		));

		const int64_t outputShape[]{ 1,3,outputSize.cy,outputSize.cx };
		_allocatedOutput = AllocateD3D12Resource(ortDmlApi, _outputBuffer.get());
		_ioBinding.BindOutput("output", Ort::Value::CreateTensor(
			memoryInfo,
			_allocatedOutput.get(),
			size_t(outputByteCount),
			outputShape,
			std::size(outputShape),
			dataType
		));
	} catch (const Ort::Exception& e) {
		Logger::Get().Error(e.what());
		return false;
	}

	if (!_CreateFence(d3d11Device, d3d12Device.get())) {
		Logger::Get().Error("_CreateFence 失败");
		return false;
	}

	// Open both D3D11 textures on the D3D12 device: the input read-only,
	// the output for reading and writing.
	_d3d12InputTex = ShareTextureWithD3D12(input, d3d12Device.get(), DXGI_SHARED_RESOURCE_READ);
	_d3d12OutputTex = ShareTextureWithD3D12(_outputTex.get(), d3d12Device.get(),
		DXGI_SHARED_RESOURCE_READ | DXGI_SHARED_RESOURCE_WRITE);
	if (!_d3d12InputTex || !_d3d12OutputTex) {
		Logger::Get().Error("ShareTextureWithD3D12 失败");
		return false;
	}

	UINT descriptorSize;
	if (!_CreateCBVHeap(d3d12Device.get(), inputElemCount, outputElemCount, isFP16Data, descriptorSize)) {
		Logger::Get().Error("_CreateCBVHeap 失败");
		return false;
	}

	if (!_CreatePipelineStates(d3d12Device.get())) {
		Logger::Get().Error("_CreatePipelineStates 失败");
		return false;
	}

	if (!_CalcCommandLists(d3d12Device.get(), inputSize, outputSize, descriptorSize)) {
		Logger::Get().Error("_CalcCommandLists 失败");
		return false;
	}

	return true;
}
// Runs one inference pass: waits for pending D3D11 work, converts the input texture
// to the input tensor, runs the ONNX Runtime session, converts the output tensor back
// to the output texture, and finally makes the D3D11 context wait for the D3D12 work.
// Any failure is logged and aborts the pass early; nothing is reported to the caller.
void DirectMLInferenceBackend::Evaluate() noexcept {
// Signal the shared fence from the D3D11 context with a new, monotonically increasing value.
HRESULT hr = _d3d11DC->Signal(_d3d11Fence.get(), ++_fenceValue);
if (FAILED(hr)) {
Logger::Get().ComError("Signal 失败", hr);
return;
}
// Submit the queued D3D11 commands (including the Signal above) so the fence can be reached.
_d3d11DC->Flush();
// Make the D3D12 queue wait (GPU-side) on the same fence value. _d3d12Fence is the
// D3D12 view of the fence opened from _d3d11Fence's shared handle in _CreateFence.
hr = _commandQueue->Wait(_d3d12Fence.get(), _fenceValue);
if (FAILED(hr)) {
Logger::Get().ComError("Wait 失败", hr);
return;
}
// Input texture -> input tensor
{
ID3D12CommandList* t = _tex2TensorCommandList.get();
_commandQueue->ExecuteCommandLists(1, &t);
}
// Run the model through the pre-built I/O binding.
// NOTE(review): presumably the session's execution provider submits to _commandQueue so
// that inference is ordered between the two conversion passes - confirm in Initialize.
try {
_session.Run(Ort::RunOptions{ nullptr }, _ioBinding);
} catch (const Ort::Exception& e) {
Logger::Get().Error(e.what());
return;
}
// Output tensor -> output texture
{
ID3D12CommandList* t = _tensor2TexCommandList.get();
_commandQueue->ExecuteCommandLists(1, &t);
}
// Signal completion of the D3D12 work with the next fence value...
hr = _commandQueue->Signal(_d3d12Fence.get(), ++_fenceValue);
if (FAILED(hr)) {
Logger::Get().ComError("Signal 失败", hr);
return;
}
// ...and make subsequent D3D11 commands wait for it before they use the output texture.
hr = _d3d11DC->Wait(_d3d11Fence.get(), _fenceValue);
if (FAILED(hr)) {
Logger::Get().ComError("Wait 失败", hr);
return;
}
}
// Creates the fence used to order work between D3D11 and D3D12.
// The fence is created on the D3D11 device as a shared fence (initial value
// _fenceValue), exported through an NT handle and opened on the D3D12 device.
// On success _d3d11Fence and _d3d12Fence refer to the same underlying fence.
// Returns false (after logging the HRESULT) if any step fails.
bool DirectMLInferenceBackend::_CreateFence(ID3D11Device5* d3d11Device, ID3D12Device* d3d12Device) noexcept {
	if (const HRESULT createResult = d3d11Device->CreateFence(
		_fenceValue, D3D11_FENCE_FLAG_SHARED, IID_PPV_ARGS(&_d3d11Fence));
		FAILED(createResult)
	) {
		Logger::Get().ComError("CreateFence 失败", createResult);
		return false;
	}

	// The handle is only needed until the D3D12 side has opened it; it is closed
	// automatically when fenceHandle goes out of scope.
	wil::unique_handle fenceHandle;
	if (const HRESULT shareResult = _d3d11Fence->CreateSharedHandle(
		nullptr, GENERIC_ALL, nullptr, fenceHandle.put());
		FAILED(shareResult)
	) {
		Logger::Get().ComError("CreateSharedHandle 失败", shareResult);
		return false;
	}

	if (const HRESULT openResult = d3d12Device->OpenSharedHandle(
		fenceHandle.get(), IID_PPV_ARGS(&_d3d12Fence));
		FAILED(openResult)
	) {
		Logger::Get().ComError("OpenSharedHandle 失败", openResult);
		return false;
	}

	return true;
}
// Creates the shader-visible CBV/SRV/UAV descriptor heap used by both conversion
// passes and fills its four slots:
//   slot 0: SRV of the shared input texture   (_d3d12InputTex)
//   slot 1: UAV of the input tensor buffer    (_inputBuffer)
//   slot 2: SRV of the output tensor buffer   (_outputBuffer)
//   slot 3: UAV of the shared output texture  (_d3d12OutputTex)
// Slots 0-1 form the descriptor table of the texture->tensor pass and slots 2-3 the
// table of the tensor->texture pass (see _CalcCommandLists).
//
// inputElemCount / outputElemCount: number of elements in the respective tensor buffer views.
// isFP16Data: selects R16_FLOAT (true) or R32_FLOAT (false) as the buffer view format.
// descriptorSize: [out] descriptor handle increment size for this heap type.
// Returns false if the heap cannot be created.
bool DirectMLInferenceBackend::_CreateCBVHeap(
ID3D12Device* d3d12Device,
uint32_t inputElemCount,
uint32_t outputElemCount,
bool isFP16Data,
UINT& descriptorSize
) noexcept {
{
D3D12_DESCRIPTOR_HEAP_DESC desc{
.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV,
.NumDescriptors = 4,
.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE
};
HRESULT hr = d3d12Device->CreateDescriptorHeap(&desc, IID_PPV_ARGS(&_cbvHeap));
if (FAILED(hr)) {
Logger::Get().ComError("CreateDescriptorHeap 失败", hr);
return false;
}
}
descriptorSize = d3d12Device->GetDescriptorHandleIncrementSize(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
// Slot 0: input texture SRV, using the resource's own format (null view desc).
D3D12_CPU_DESCRIPTOR_HANDLE cbvHandle = _cbvHeap->GetCPUDescriptorHandleForHeapStart();
d3d12Device->CreateShaderResourceView(_d3d12InputTex.get(), nullptr, cbvHandle);
cbvHandle.ptr += descriptorSize;
// Slot 1: typed buffer UAV over the input tensor.
{
D3D12_UNORDERED_ACCESS_VIEW_DESC desc{
.Format = isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
.ViewDimension = D3D12_UAV_DIMENSION_BUFFER,
.Buffer{
.NumElements = inputElemCount
}
};
d3d12Device->CreateUnorderedAccessView(_inputBuffer.get(), nullptr, &desc, cbvHandle);
}
cbvHandle.ptr += descriptorSize;
// Slot 2: typed buffer SRV over the output tensor.
{
D3D12_SHADER_RESOURCE_VIEW_DESC desc{
.Format = isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
.ViewDimension = D3D12_SRV_DIMENSION_BUFFER,
.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING,
.Buffer{
.NumElements = outputElemCount
}
};
d3d12Device->CreateShaderResourceView(_outputBuffer.get(), &desc, cbvHandle);
}
cbvHandle.ptr += descriptorSize;
// Slot 3: output texture UAV, using the resource's own format (null view desc).
d3d12Device->CreateUnorderedAccessView(_d3d12OutputTex.get(), nullptr, nullptr, cbvHandle);
return true;
}
// Creates the root signature shared by both conversion passes and the two compute
// pipeline states (_tex2TensorPipelineState and _tensor2TexPipelineState).
// Returns false (after logging the HRESULT) if any D3D12 call fails.
bool DirectMLInferenceBackend::_CreatePipelineStates(ID3D12Device* d3d12Device) noexcept {
{
// One descriptor table holding one SRV followed by one UAV. Both passes use this
// layout; they only differ in where the table starts inside the descriptor heap.
D3D12_DESCRIPTOR_RANGE ranges[]{
D3D12_DESCRIPTOR_RANGE{
.RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_SRV,
.NumDescriptors = 1,
.OffsetInDescriptorsFromTableStart = D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND
},
D3D12_DESCRIPTOR_RANGE{
.RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_UAV,
.NumDescriptors = 1,
.OffsetInDescriptorsFromTableStart = D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND
},
};
D3D12_ROOT_PARAMETER rootParam{
D3D12_ROOT_PARAMETER{
.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE,
.DescriptorTable{
.NumDescriptorRanges = (UINT)std::size(ranges),
.pDescriptorRanges = ranges
}
}
};
// Single static point sampler with clamp addressing on all axes.
D3D12_STATIC_SAMPLER_DESC samDesc{
.Filter = D3D12_FILTER_MIN_MAG_MIP_POINT,
.AddressU = D3D12_TEXTURE_ADDRESS_MODE_CLAMP,
.AddressV = D3D12_TEXTURE_ADDRESS_MODE_CLAMP,
.AddressW = D3D12_TEXTURE_ADDRESS_MODE_CLAMP,
.ComparisonFunc = D3D12_COMPARISON_FUNC_NEVER
};
D3D12_VERSIONED_ROOT_SIGNATURE_DESC desc{
.Version = D3D_ROOT_SIGNATURE_VERSION_1_0,
.Desc_1_0{
.NumParameters = 1,
.pParameters = &rootParam,
.NumStaticSamplers = 1,
.pStaticSamplers = &samDesc
}
};
// The optional error blob is not requested (nullptr); only the HRESULT is logged on failure.
winrt::com_ptr<ID3DBlob> blob;
HRESULT hr = D3D12SerializeVersionedRootSignature(&desc, blob.put(), nullptr);
if (FAILED(hr)) {
Logger::Get().ComError("D3D12SerializeVersionedRootSignature 失败", hr);
return false;
}
hr = d3d12Device->CreateRootSignature(
0,
blob->GetBufferPointer(),
blob->GetBufferSize(),
IID_PPV_ARGS(&_rootSignature)
);
if (FAILED(hr)) {
Logger::Get().ComError("CreateRootSignature 失败", hr);
return false;
}
}
// Texture -> tensor pipeline.
// NOTE(review): TextureToTensorCS / TensorToTextureCS are presumably byte arrays of
// compiled shader bytecode generated from the .hlsl files - confirm where they are defined.
D3D12_COMPUTE_PIPELINE_STATE_DESC desc{
.pRootSignature = _rootSignature.get(),
.CS{
.pShaderBytecode = TextureToTensorCS,
.BytecodeLength = std::size(TextureToTensorCS)
}
};
HRESULT hr = d3d12Device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&_tex2TensorPipelineState));
if (FAILED(hr)) {
Logger::Get().ComError("CreateComputePipelineState 失败", hr);
return false;
}
// Tensor -> texture pipeline: same description, only the shader differs.
desc.CS.pShaderBytecode = TensorToTextureCS;
desc.CS.BytecodeLength = std::size(TensorToTextureCS);
hr = d3d12Device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&_tensor2TexPipelineState));
if (FAILED(hr)) {
Logger::Get().ComError("CreateComputePipelineState 失败", hr);
return false;
}
return true;
}
// Pre-records the two compute command lists that Evaluate() submits every frame:
// _tex2TensorCommandList (input texture -> input tensor) and _tensor2TexCommandList
// (output tensor -> output texture). Both are recorded once and closed here.
//
// inputSize / outputSize: texture dimensions used to compute the dispatch grid.
// descriptorSize: descriptor increment size of _cbvHeap, used to locate the second table.
// Returns false (after logging the HRESULT) if any D3D12 call fails.
bool DirectMLInferenceBackend::_CalcCommandLists(
ID3D12Device* d3d12Device,
SIZE inputSize,
SIZE outputSize,
UINT descriptorSize
) noexcept {
// NOTE(review): this allocator is a local and is released when the function returns,
// while the command lists recorded from it are kept as members and re-executed in
// Evaluate(). Confirm that the command lists keep the allocator's storage alive;
// otherwise the allocator should be stored as a member alongside the lists.
winrt::com_ptr<ID3D12CommandAllocator> d3d12CommandAllocator;
HRESULT hr = d3d12Device->CreateCommandAllocator(
D3D12_COMMAND_LIST_TYPE_COMPUTE, IID_PPV_ARGS(&d3d12CommandAllocator));
if (FAILED(hr)) {
Logger::Get().ComError("CreateCommandAllocator 失败", hr);
return false;
}
// Input texture -> input tensor
hr = d3d12Device->CreateCommandList(
0,
D3D12_COMMAND_LIST_TYPE_COMPUTE,
d3d12CommandAllocator.get(),
_tex2TensorPipelineState.get(),
IID_PPV_ARGS(&_tex2TensorCommandList)
);
if (FAILED(hr)) {
Logger::Get().ComError("CreateCommandList 失败", hr);
return false;
}
_tex2TensorCommandList->SetComputeRootSignature(_rootSignature.get());
{
ID3D12DescriptorHeap* t = _cbvHeap.get();
_tex2TensorCommandList->SetDescriptorHeaps(1, &t);
}
// Table starts at heap slot 0: input texture SRV followed by input tensor UAV.
_tex2TensorCommandList->SetComputeRootDescriptorTable(0, _cbvHeap->GetGPUDescriptorHandleForHeapStart());
// NOTE(review): presumably matches the thread-group size declared in TextureToTensorCS.hlsl - confirm.
static constexpr std::pair<uint32_t, uint32_t> TEX_TO_TENSOR_BLOCK_SIZE{ 16, 16 };
// Round up so partially covered blocks at the right/bottom edges are still dispatched.
_tex2TensorCommandList->Dispatch(
(inputSize.cx + TEX_TO_TENSOR_BLOCK_SIZE.first - 1) / TEX_TO_TENSOR_BLOCK_SIZE.first,
(inputSize.cy + TEX_TO_TENSOR_BLOCK_SIZE.second - 1) / TEX_TO_TENSOR_BLOCK_SIZE.second,
1
);
// The first list is closed before the second one starts recording on the same allocator.
hr = _tex2TensorCommandList->Close();
if (FAILED(hr)) {
Logger::Get().ComError("Close 失败", hr);
return false;
}
// Output tensor -> output texture
hr = d3d12Device->CreateCommandList(
0,
D3D12_COMMAND_LIST_TYPE_COMPUTE,
d3d12CommandAllocator.get(),
_tensor2TexPipelineState.get(),
IID_PPV_ARGS(&_tensor2TexCommandList)
);
if (FAILED(hr)) {
Logger::Get().ComError("CreateCommandList 失败", hr);
return false;
}
_tensor2TexCommandList->SetComputeRootSignature(_rootSignature.get());
{
ID3D12DescriptorHeap* t = _cbvHeap.get();
_tensor2TexCommandList->SetDescriptorHeaps(1, &t);
}
// Table starts at heap slot 2: output tensor SRV followed by output texture UAV.
{
D3D12_GPU_DESCRIPTOR_HANDLE gpuHandle = _cbvHeap->GetGPUDescriptorHandleForHeapStart();
gpuHandle.ptr += 2 * static_cast<UINT64>(descriptorSize);
_tensor2TexCommandList->SetComputeRootDescriptorTable(0, gpuHandle);
}
// NOTE(review): presumably matches the thread-group size declared in TensorToTextureCS.hlsl - confirm.
static constexpr std::pair<uint32_t, uint32_t> TENSOR_TO_TEX_BLOCK_SIZE{ 8, 8 };
_tensor2TexCommandList->Dispatch(
(outputSize.cx + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
(outputSize.cy + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second,
1
);
hr = _tensor2TexCommandList->Close();
if (FAILED(hr)) {
Logger::Get().ComError("Close 失败", hr);
return false;
}
return true;
}
}

View file

@ -0,0 +1,77 @@
#pragma once
#include "InferenceBackendBase.h"
#include <d3d12.h>
struct OrtDmlApi;
namespace Magpie::Core {
// Inference backend that runs an ONNX model through ONNX Runtime and exchanges
// frames with the D3D11 renderer via D3D12 resources and a shared fence.
// Non-copyable; move-constructible.
class DirectMLInferenceBackend : public InferenceBackendBase {
public:
DirectMLInferenceBackend() = default;
DirectMLInferenceBackend(const DirectMLInferenceBackend&) = delete;
DirectMLInferenceBackend(DirectMLInferenceBackend&&) = default;
// Loads the model and creates every resource needed by Evaluate().
// Returns false on failure.
bool Initialize(
const wchar_t* modelPath,
uint32_t scale,
DeviceResources& deviceResources,
BackendDescriptorStore& descriptorStore,
ID3D11Texture2D* input,
ID3D11Texture2D** output
) noexcept override;
// Runs one inference pass; errors are logged, not returned.
void Evaluate() noexcept override;
private:
// Creates the D3D11 fence and opens it on the D3D12 device.
bool _CreateFence(ID3D11Device5* d3d11Device, ID3D12Device* d3d12Device) noexcept;
// Creates the shader-visible descriptor heap and its four views; outputs the descriptor increment size.
bool _CreateCBVHeap(
ID3D12Device* d3d12Device,
uint32_t inputElemCount,
uint32_t outputElemCount,
bool isFP16Data,
UINT& descriptorSize
) noexcept;
// Creates the root signature and the two compute pipeline states.
bool _CreatePipelineStates(ID3D12Device* d3d12Device) noexcept;
// Pre-records the texture->tensor and tensor->texture command lists.
bool _CalcCommandLists(
ID3D12Device* d3d12Device,
SIZE inputSize,
SIZE outputSize,
UINT descriptorSize
) noexcept;
// Non-owning pointer to the D3D11 device context used for Signal/Flush/Wait.
ID3D11DeviceContext4* _d3d11DC = nullptr;
winrt::com_ptr<ID3D11Texture2D> _outputTex;
// D3D11 and D3D12 views of the same shared fence, plus the last value signaled on it.
winrt::com_ptr<ID3D11Fence> _d3d11Fence;
winrt::com_ptr<ID3D12Fence> _d3d12Fence;
UINT64 _fenceValue = 0;
// D3D12 views of the D3D11 input/output textures.
winrt::com_ptr<ID3D12Resource> _d3d12InputTex;
winrt::com_ptr<ID3D12Resource> _d3d12OutputTex;
// Buffers holding the model's input and output tensors.
winrt::com_ptr<ID3D12Resource> _inputBuffer;
winrt::com_ptr<ID3D12Resource> _outputBuffer;
winrt::com_ptr<ID3D12DescriptorHeap> _cbvHeap;
winrt::com_ptr<ID3D12RootSignature> _rootSignature;
winrt::com_ptr<ID3D12PipelineState> _tex2TensorPipelineState;
winrt::com_ptr<ID3D12PipelineState> _tensor2TexPipelineState;
winrt::com_ptr<ID3D12CommandQueue> _commandQueue;
winrt::com_ptr<ID3D12GraphicsCommandList> _tex2TensorCommandList;
winrt::com_ptr<ID3D12GraphicsCommandList> _tensor2TexCommandList;
// ONNX Runtime objects. Members are destroyed in reverse declaration order, so the
// I/O binding goes first, then the wrapped allocations, then the session and env.
// NOTE(review): this ordering looks deliberate - keep it when adding members.
Ort::Env _env{ nullptr };
Ort::Session _session{ nullptr };
// Wrappers around _inputBuffer/_outputBuffer returned by AllocateD3D12Resource and bound as tensors.
winrt::com_ptr<IUnknown> _allocatedInput;
winrt::com_ptr<IUnknown> _allocatedOutput;
Ort::IoBinding _ioBinding{ nullptr };
};
}

View file

@ -107,4 +107,10 @@ winrt::com_ptr<ID3D11Texture2D> DirectXHelper::CreateTexture2D(
return result;
}
// Returns the dimensions of a 2D texture as a SIZE (cx = width, cy = height).
// `texture` must not be null.
SIZE DirectXHelper::GetTextureSize(ID3D11Texture2D* texture) noexcept {
	D3D11_TEXTURE2D_DESC texDesc;
	texture->GetDesc(&texDesc);

	SIZE texSize;
	texSize.cx = (LONG)texDesc.Width;
	texSize.cy = (LONG)texDesc.Height;
	return texSize;
}
}

View file

@ -25,6 +25,8 @@ struct DirectXHelper {
UINT miscFlags = 0,
const D3D11_SUBRESOURCE_DATA* pInitialData = nullptr
) noexcept;
static SIZE GetTextureSize(ID3D11Texture2D* texture) noexcept;
};
}

View file

@ -87,7 +87,9 @@ bool DwmSharedSurfaceFrameSource::_Initialize() noexcept {
DXGI_FORMAT_B8G8R8A8_UNORM,
frameRect.right - frameRect.left,
frameRect.bottom - frameRect.top,
D3D11_BIND_SHADER_RESOURCE
D3D11_BIND_SHADER_RESOURCE,
D3D11_USAGE_DEFAULT,
D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
);
if (!_output) {
Logger::Get().Error("CreateTexture2D 失败");

View file

@ -7,6 +7,7 @@
#include <d3dcompiler.h>
#include "Utils.h"
#include "YasHelper.h"
#include "HashHelper.h"
namespace yas::detail {
@ -235,27 +236,6 @@ void EffectCacheManager::Save(std::wstring_view effectName, std::wstring_view ha
Logger::Get().Info(StrUtils::Concat("已保存缓存 ", StrUtils::UTF16ToUTF8(cacheFileName)));
}
static std::wstring HexHash(std::span<const BYTE> data) {
uint64_t hashBytes = Utils::HashData(data);
static wchar_t oct2Hex[16] = {
L'0',L'1',L'2',L'3',L'4',L'5',L'6',L'7',
L'8',L'9',L'a',L'b',L'c',L'd',L'e',L'f'
};
std::wstring result(16, 0);
wchar_t* pResult = &result[0];
BYTE* b = (BYTE*)&hashBytes;
for (int i = 0; i < 8; ++i) {
*pResult++ = oct2Hex[(*b >> 4) & 0xf];
*pResult++ = oct2Hex[*b & 0xf];
++b;
}
return result;
}
std::wstring EffectCacheManager::GetHash(
std::string_view source,
const phmap::flat_hash_map<std::wstring, float>* inlineParams
@ -271,7 +251,7 @@ std::wstring EffectCacheManager::GetHash(
}
}
return HexHash(std::span((const BYTE*)source.data(), source.size()));
return HashHelper::HexHash(std::span((const BYTE*)str.data(), str.size()));
}
std::wstring EffectCacheManager::GetHash(std::string& source, const phmap::flat_hash_map<std::wstring, float>* inlineParams) {
@ -286,7 +266,7 @@ std::wstring EffectCacheManager::GetHash(std::string& source, const phmap::flat_
}
}
std::wstring result = HexHash(std::span((const BYTE*)source.data(), source.size()));
std::wstring result = HashHelper::HexHash(std::span((const BYTE*)source.data(), source.size()));
source.resize(originSize);
return result;
}

View file

@ -1376,9 +1376,8 @@ static uint32_t CompilePasses(
std::wstring sourcesPath = sourcesPathName.substr(0, sourcesPathName.find_last_of(L'\\'));
if ((flags & EffectCompilerFlags::SaveSources) && !Win32Utils::DirExists(sourcesPath.c_str())) {
HRESULT hr = wil::CreateDirectoryDeepNoThrow(sourcesPath.c_str());
if (FAILED(hr)) {
Logger::Get().ComError("创建 sources 文件夹失败", hr);
if (!Win32Utils::CreateDir(sourcesPath, true)) {
Logger::Get().Win32Error("创建 sources 文件夹失败");
}
}

View file

@ -92,12 +92,7 @@ bool EffectDrawer::Initialize(
) noexcept {
_d3dDC = deviceResources.GetD3DDC();
SIZE inputSize{};
{
D3D11_TEXTURE2D_DESC inputDesc;
(*inOutTexture)->GetDesc(&inputDesc);
inputSize = { (LONG)inputDesc.Width, (LONG)inputDesc.Height };
}
const SIZE inputSize = DirectXHelper::GetTextureSize(*inOutTexture);
static mu::Parser exprParser;
exprParser.DefineConst("INPUT_WIDTH", inputSize.cx);
@ -165,7 +160,7 @@ bool EffectDrawer::Initialize(
if (texDesc.format != EffectIntermediateTextureFormat::UNKNOWN) {
// 检查纹理格式是否匹配
D3D11_TEXTURE2D_DESC srcDesc{};
D3D11_TEXTURE2D_DESC srcDesc;
_textures[i]->GetDesc(&srcDesc);
if (srcDesc.Format != EffectHelper::FORMAT_DESCS[(uint32_t)texDesc.format].dxgiFormat) {
Logger::Get().Error("SOURCE 纹理格式不匹配");
@ -235,11 +230,10 @@ bool EffectDrawer::Initialize(
}
}
D3D11_TEXTURE2D_DESC outputDesc;
_textures[passDesc.outputs[0]]->GetDesc(&outputDesc);
SIZE passOutputSize = DirectXHelper::GetTextureSize(_textures[passDesc.outputs[0]].get());
_dispatches.emplace_back(
(outputDesc.Width + passDesc.blockSize.first - 1) / passDesc.blockSize.first,
(outputDesc.Height + passDesc.blockSize.second - 1) / passDesc.blockSize.second
(passOutputSize.cx + passDesc.blockSize.first - 1) / passDesc.blockSize.first,
(passOutputSize.cy + passDesc.blockSize.second - 1) / passDesc.blockSize.second
);
}
@ -293,7 +287,7 @@ bool EffectDrawer::_InitializeConstants(
psStylePassParams += 4;
}
}
_constants.resize((builtinConstantCount + psStylePassParams + (isInlineParams ? 0 : desc.params.size()) + 3) / 4 * 4);
_constants.resize((builtinConstantCount + psStylePassParams + (isInlineParams ? 0 : desc.params.size()) + 3) & ~3);
// cbuffer __CB1 : register(b0) {
// uint2 __inputSize;
// uint2 __outputSize;
@ -318,15 +312,14 @@ bool EffectDrawer::_InitializeConstants(
if (psStylePassParams > 0) {
for (UINT i = 0, end = (UINT)desc.passes.size() - 1; i < end; ++i) {
if (desc.passes[i].isPSStyle) {
D3D11_TEXTURE2D_DESC outputDesc;
_textures[desc.passes[i].outputs[0]]->GetDesc(&outputDesc);
pCurParam->uintVal = outputDesc.Width;
SIZE passOutputSize = DirectXHelper::GetTextureSize(_textures[desc.passes[i].outputs[0]].get());
pCurParam->uintVal = passOutputSize.cx;
++pCurParam;
pCurParam->uintVal = outputDesc.Height;
pCurParam->uintVal = passOutputSize.cy;
++pCurParam;
pCurParam->floatVal = 1.0f / outputDesc.Width;
pCurParam->floatVal = 1.0f / passOutputSize.cx;
++pCurParam;
pCurParam->floatVal = 1.0f / outputDesc.Height;
pCurParam->floatVal = 1.0f / passOutputSize.cy;
++pCurParam;
}
}

View file

@ -58,7 +58,7 @@ bool GDIFrameSource::_Initialize() noexcept {
_frameRect.bottom - _frameRect.top,
D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_RENDER_TARGET,
D3D11_USAGE_DEFAULT,
D3D11_RESOURCE_MISC_GDI_COMPATIBLE
D3D11_RESOURCE_MISC_GDI_COMPATIBLE | D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
);
if (!_output) {
Logger::Get().Error("创建纹理失败");

View file

@ -75,7 +75,9 @@ bool GraphicsCaptureFrameSource::_Initialize() noexcept {
DXGI_FORMAT_B8G8R8A8_UNORM,
_frameBox.right - _frameBox.left,
_frameBox.bottom - _frameBox.top,
D3D11_BIND_SHADER_RESOURCE
D3D11_BIND_SHADER_RESOURCE,
D3D11_USAGE_DEFAULT,
D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
);
if (!_output) {
Logger::Get().Error("创建纹理失败");

View file

@ -0,0 +1,28 @@
#include "pch.h"
#include "HashHelper.h"
#include "Utils.h"
namespace Magpie::Core {
// Hashes `data` with Utils::HashData and returns the 64-bit result as a string of
// 16 lowercase hexadecimal digits. The hash bytes are emitted in memory order
// (high nibble first within each byte).
std::wstring HashHelper::HexHash(std::span<const uint8_t> data) noexcept {
	static constexpr wchar_t HEX_DIGITS[] = L"0123456789abcdef";

	const uint64_t hash = Utils::HashData(data);
	const uint8_t* hashBytes = reinterpret_cast<const uint8_t*>(&hash);

	std::wstring hex;
	hex.reserve(sizeof(hash) * 2);
	for (size_t i = 0; i < sizeof(hash); ++i) {
		const uint8_t cur = hashBytes[i];
		hex.push_back(HEX_DIGITS[(cur >> 4) & 0xf]);
		hex.push_back(HEX_DIGITS[cur & 0xf]);
	}

	return hex;
}
}

View file

@ -0,0 +1,9 @@
#pragma once
namespace Magpie::Core {
// Stateless collection of hashing helpers.
struct HashHelper {
// Returns the hash of `data` formatted as a 16-character lowercase hexadecimal string.
static std::wstring HexHash(std::span<const uint8_t> data) noexcept;
};
}

View file

@ -0,0 +1,80 @@
#include "pch.h"
#include "InferenceBackendBase.h"
#include "StrUtils.h"
#include "Logger.h"
namespace Magpie::Core {
// Logging callback handed to ONNX Runtime. Forwards each message to Logger,
// prefixed with its severity name.
//
// Routing:
//   verbose        -> Logger::Info
//   info           -> Logger::Info, mirrored to the debugger output
//   warning        -> Logger::Warn
//   error / fatal  -> Logger::Error (also used for unrecognized severities)
//
// Verbose messages used to fall through to the error branch; they are diagnostics,
// not failures, so they are now logged as info.
void ORT_API_CALL InferenceBackendBase::_OrtLog(
	void* /*param*/,
	OrtLoggingLevel severity,
	const char* /*category*/,
	const char* /*logid*/,
	const char* /*code_location*/,
	const char* message
) {
	static constexpr const char* SEVERITIES[] = {
		"verbose",
		"info",
		"warning",
		"error",
		"fatal"
	};

	// The value comes from an external library: never index the table out of range.
	const size_t severityIdx = static_cast<size_t>(severity);
	const char* severityName =
		severityIdx < std::size(SEVERITIES) ? SEVERITIES[severityIdx] : "unknown";

	// Guard against a null message so building the log line cannot dereference nullptr.
	const std::string log = StrUtils::Concat("[", severityName, "] ", message ? message : "");

	switch (severity) {
	case ORT_LOGGING_LEVEL_VERBOSE:
		Logger::Get().Info(log);
		break;
	case ORT_LOGGING_LEVEL_INFO:
		Logger::Get().Info(log);
		OutputDebugStringA((log + "\n").c_str());
		break;
	case ORT_LOGGING_LEVEL_WARNING:
		Logger::Get().Warn(log);
		break;
	default:
		Logger::Get().Error(log);
		break;
	}
}
// Checks that a tensor has the shape [-1, 3, -1, -1]: four dimensions, with the
// second fixed to 3 and all others dynamic (-1).
static bool IsTensorShapeValid(const Ort::ConstTensorTypeAndShapeInfo& tensorInfo) {
	const std::vector<int64_t> shape = tensorInfo.GetShape();
	if (shape.size() != 4) {
		return false;
	}

	return shape[0] == -1 && shape[1] == 3 && shape[2] == -1 && shape[3] == -1;
}
// Validates that the model loaded into `session` can be used by the backends:
//   - exactly one input and one output,
//   - input element type is float16 or float, and the output uses the same type,
//   - input and output shapes are both [-1, 3, -1, -1].
// isFP16Data: [out] set to whether the input element type is float16; written as soon
//             as the input type has been accepted, even if a later check fails.
// Returns false (after logging the reason) if the model is not supported.
bool InferenceBackendBase::_IsModelValid(const Ort::Session& session, bool& isFP16Data) {
	const bool hasSingleInputAndOutput =
		session.GetInputCount() == 1 && session.GetOutputCount() == 1;
	if (!hasSingleInputAndOutput) {
		Logger::Get().Error("不支持有多个输入/输出的模型");
		return false;
	}

	// inputTensorInfo is only valid while inputTypeInfo is alive.
	Ort::TypeInfo inputTypeInfo = session.GetInputTypeInfo(0);
	Ort::ConstTensorTypeAndShapeInfo inputTensorInfo = inputTypeInfo.GetTensorTypeAndShapeInfo();

	const ONNXTensorElementDataType elemType = inputTensorInfo.GetElementType();
	const bool isFP16 = elemType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
	if (!isFP16 && elemType != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
		Logger::Get().Error("不支持 float16 和 float 之外的输入数据类型");
		return false;
	}
	isFP16Data = isFP16;

	// Same lifetime rule: outputTensorInfo must not outlive outputTypeInfo.
	Ort::TypeInfo outputTypeInfo = session.GetOutputTypeInfo(0);
	Ort::ConstTensorTypeAndShapeInfo outputTensorInfo = outputTypeInfo.GetTensorTypeAndShapeInfo();
	if (outputTensorInfo.GetElementType() != elemType) {
		Logger::Get().Error("不支持输入和输出数据类型不同的模型");
		return false;
	}

	if (!IsTensorShapeValid(inputTensorInfo)) {
		Logger::Get().Error("不支持的输入数据格式");
		return false;
	}

	if (!IsTensorShapeValid(outputTensorInfo)) {
		Logger::Get().Error("不支持的输出数据格式");
		return false;
	}

	return true;
}
}

View file

@ -0,0 +1,41 @@
#pragma once
#include <onnxruntime_cxx_api.h>
namespace Magpie::Core {
class DeviceResources;
class BackendDescriptorStore;
// Abstract base class of the ONNX inference backends. Declares the interface the
// renderer drives and provides helpers shared by the concrete backends.
// Non-copyable; move-constructible.
class InferenceBackendBase {
public:
InferenceBackendBase() = default;
// Virtual so a backend can be destroyed through a base pointer.
virtual ~InferenceBackendBase() noexcept {}
InferenceBackendBase(const InferenceBackendBase&) = delete;
InferenceBackendBase(InferenceBackendBase&&) = default;
// Prepares the backend for the model at `modelPath`.
// input:  texture the backend reads frames from.
// output: [out] receives the texture the backend writes its result to.
// NOTE(review): `scale` is presumably the model's upscaling factor - confirm in the backends.
// Returns false on failure.
virtual bool Initialize(
const wchar_t* modelPath,
uint32_t scale,
DeviceResources& deviceResources,
BackendDescriptorStore& descriptorStore,
ID3D11Texture2D* input,
ID3D11Texture2D** output
) noexcept = 0;
// Runs inference once on the current contents of the input texture.
virtual void Evaluate() noexcept = 0;
protected:
// ONNX Runtime logging callback that forwards messages to Logger.
static void ORT_API_CALL _OrtLog(
void* /*param*/,
OrtLoggingLevel severity,
const char* /*category*/,
const char* /*logid*/,
const char* /*code_location*/,
const char* message
);
// Checks that the session's model has a supported I/O signature and reports
// through `isFP16Data` whether it uses float16 data.
static bool _IsModelValid(const Ort::Session& session, bool& isFP16Data);
};
}

View file

@ -43,12 +43,14 @@
</ItemDefinitionGroup>
<ItemGroup>
<ClInclude Include="BackendDescriptorStore.h" />
<ClInclude Include="TensorRTInferenceBackend.h" />
<ClInclude Include="CursorManager.h" />
<ClInclude Include="CursorDrawer.h" />
<ClInclude Include="DDS.h" />
<ClInclude Include="DDSLoderHelpers.h" />
<ClInclude Include="DesktopDuplicationFrameSource.h" />
<ClInclude Include="DeviceResources.h" />
<ClInclude Include="DirectMLInferenceBackend.h" />
<ClInclude Include="DirectXHelper.h" />
<ClInclude Include="DwmSharedSurfaceFrameSource.h" />
<ClInclude Include="EffectCacheManager.h" />
@ -61,11 +63,15 @@
<ClInclude Include="FrameSourceBase.h" />
<ClInclude Include="GDIFrameSource.h" />
<ClInclude Include="GraphicsCaptureFrameSource.h" />
<ClInclude Include="HashHelper.h" />
<ClInclude Include="ImGuiBackend.h" />
<ClInclude Include="ImGuiFontsCacheManager.h" />
<ClInclude Include="ImGuiHelper.h" />
<ClInclude Include="ImGuiImpl.h" />
<ClInclude Include="include\Magpie.Core.h" />
<ClInclude Include="InferenceBackendBase.h" />
<ClInclude Include="OnnxEffectDrawer.h" />
<ClInclude Include="OnnxHelper.h" />
<ClInclude Include="OverlayDrawer.h" />
<ClInclude Include="Renderer.h" />
<ClInclude Include="ScalingOptions.h" />
@ -80,10 +86,12 @@
</ItemGroup>
<ItemGroup>
<ClCompile Include="BackendDescriptorStore.cpp" />
<ClCompile Include="TensorRTInferenceBackend.cpp" />
<ClCompile Include="CursorManager.cpp" />
<ClCompile Include="CursorDrawer.cpp" />
<ClCompile Include="DesktopDuplicationFrameSource.cpp" />
<ClCompile Include="DeviceResources.cpp" />
<ClCompile Include="DirectMLInferenceBackend.cpp" />
<ClCompile Include="DirectXHelper.cpp" />
<ClCompile Include="DwmSharedSurfaceFrameSource.cpp" />
<ClCompile Include="EffectCacheManager.cpp" />
@ -94,10 +102,13 @@
<ClCompile Include="FrameSourceBase.cpp" />
<ClCompile Include="GDIFrameSource.cpp" />
<ClCompile Include="GraphicsCaptureFrameSource.cpp" />
<ClCompile Include="HashHelper.cpp" />
<ClCompile Include="ImGuiBackend.cpp" />
<ClCompile Include="ImGuiFontsCacheManager.cpp" />
<ClCompile Include="ImGuiHelper.cpp" />
<ClCompile Include="ImGuiImpl.cpp" />
<ClCompile Include="InferenceBackendBase.cpp" />
<ClCompile Include="OnnxEffectDrawer.cpp" />
<ClCompile Include="OverlayDrawer.cpp" />
<ClCompile Include="pch.cpp">
<PrecompiledHeader>Create</PrecompiledHeader>
@ -111,6 +122,9 @@
<ClCompile Include="WindowHelper.cpp" />
</ItemGroup>
<ItemGroup>
<FxCompile Include="shaders\TensorToTextureCS.hlsl">
<ShaderType>Compute</ShaderType>
</FxCompile>
<FxCompile Include="shaders\DuplicateFrameCS.hlsl">
<ShaderType>Compute</ShaderType>
</FxCompile>
@ -132,21 +146,24 @@
<FxCompile Include="shaders\SimpleVS.hlsl">
<ShaderType>Vertex</ShaderType>
</FxCompile>
<FxCompile Include="shaders\TextureToTensorCS.hlsl">
<ShaderType>Compute</ShaderType>
</FxCompile>
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets')" />
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
</ImportGroup>
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
<PropertyGroup>
<ErrorText>这台计算机上缺少此项目引用的 NuGet 程序包。使用“NuGet 程序包还原”可下载这些程序包。有关更多信息,请参见 http://go.microsoft.com/fwlink/?LinkID=322105。缺少的文件是 {0}。</ErrorText>
</PropertyGroup>
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
</Target>
</Project>

View file

@ -19,6 +19,9 @@
<Filter Include="Shaders">
<UniqueIdentifier>{1956ae10-07ad-4b77-a37f-25f7fe10654b}</UniqueIdentifier>
</Filter>
<Filter Include="ONNX">
<UniqueIdentifier>{c5acb0d2-df90-4589-8914-2bfff00194ec}</UniqueIdentifier>
</Filter>
</ItemGroup>
<ItemGroup>
<ClInclude Include="pch.h" />
@ -91,7 +94,25 @@
<ClInclude Include="DesktopDuplicationFrameSource.h">
<Filter>Capture</Filter>
</ClInclude>
<ClInclude Include="DirectMLInferenceBackend.h">
<Filter>ONNX</Filter>
</ClInclude>
<ClInclude Include="OnnxEffectDrawer.h">
<Filter>ONNX</Filter>
</ClInclude>
<ClInclude Include="HashHelper.h">
<Filter>Helpers</Filter>
</ClInclude>
<ClInclude Include="InferenceBackendBase.h">
<Filter>ONNX</Filter>
</ClInclude>
<ClInclude Include="TensorRTInferenceBackend.h">
<Filter>ONNX</Filter>
</ClInclude>
<ClInclude Include="ExclModeHelper.h" />
<ClInclude Include="OnnxHelper.h">
<Filter>Helpers</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="ScalingRuntime.cpp" />
@ -146,6 +167,21 @@
<ClCompile Include="DesktopDuplicationFrameSource.cpp">
<Filter>Capture</Filter>
</ClCompile>
<ClCompile Include="DirectMLInferenceBackend.cpp">
<Filter>ONNX</Filter>
</ClCompile>
<ClCompile Include="OnnxEffectDrawer.cpp">
<Filter>ONNX</Filter>
</ClCompile>
<ClCompile Include="HashHelper.cpp">
<Filter>Helpers</Filter>
</ClCompile>
<ClCompile Include="InferenceBackendBase.cpp">
<Filter>ONNX</Filter>
</ClCompile>
<ClCompile Include="TensorRTInferenceBackend.cpp">
<Filter>ONNX</Filter>
</ClCompile>
<ClCompile Include="ExclModeHelper.cpp" />
<ClCompile Include="ScalingOptions.cpp" />
</ItemGroup>
@ -171,6 +207,12 @@
<FxCompile Include="shaders\ImGuiImplPS.hlsl">
<Filter>Shaders</Filter>
</FxCompile>
<FxCompile Include="shaders\TextureToTensorCS.hlsl">
<Filter>Shaders</Filter>
</FxCompile>
<FxCompile Include="shaders\TensorToTextureCS.hlsl">
<Filter>Shaders</Filter>
</FxCompile>
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />

View file

@ -0,0 +1,123 @@
#include "pch.h"
#include "OnnxEffectDrawer.h"
#include "Logger.h"
#include "DirectMLInferenceBackend.h"
#include "TensorRTInferenceBackend.h"
#include "Win32Utils.h"
#include <rapidjson/document.h>
#include "StrUtils.h"
namespace Magpie::Core {
// Defined out of line: the header only forward-declares InferenceBackendBase, so the
// std::unique_ptr member must be constructed and destroyed where the type is complete.
OnnxEffectDrawer::OnnxEffectDrawer() = default;
OnnxEffectDrawer::~OnnxEffectDrawer() = default;
// Extracts the model configuration from a parsed JSON document.
// The root must be an object with the members:
//   "path"    (string)           -> modelPath
//   "scale"   (unsigned integer) -> scale
//   "backend" (string)           -> backend
// Members are read in that order; on failure the outputs read so far keep their new
// values. Returns false (after logging which member failed) on any error.
static bool ReadJson(
	const rapidjson::Document& doc,
	std::string& modelPath,
	uint32_t& scale,
	std::string& backend
) noexcept {
	if (!doc.IsObject()) {
		Logger::Get().Error("根元素不是 Object");
		return false;
	}

	const auto root = doc.GetObj();
	const auto notFound = root.MemberEnd();

	const auto pathNode = root.FindMember("path");
	if (pathNode == notFound || !pathNode->value.IsString()) {
		Logger::Get().Error("解析 path 失败");
		return false;
	}
	modelPath = pathNode->value.GetString();

	const auto scaleNode = root.FindMember("scale");
	if (scaleNode == notFound || !scaleNode->value.IsUint()) {
		Logger::Get().Error("解析 scale 失败");
		return false;
	}
	scale = scaleNode->value.GetUint();

	const auto backendNode = root.FindMember("backend");
	if (backendNode == notFound || !backendNode->value.IsString()) {
		Logger::Get().Error("解析 backend 失败");
		return false;
	}
	backend = backendNode->value.GetString();

	return true;
}
// Sets up ONNX inference from "model.json" in the current directory.
//
// If the file does not exist, ONNX inference is simply disabled: the function returns
// true, leaves *inOutTexture untouched and Draw() becomes a no-op.
// Otherwise the file must contain "path", "scale" and "backend"; the matching backend
// is created and initialized with *inOutTexture as its input, and the backend stores
// its output texture back into *inOutTexture.
//
// The backend is only installed into _inferenceBackend after it initialized
// successfully, so a failed initialization can never leave a half-initialized backend
// behind for Draw() to evaluate.
//
// Returns false if the file cannot be read or parsed, the backend name is unknown, or
// the backend fails to initialize.
bool OnnxEffectDrawer::Initialize(
	DeviceResources& deviceResources,
	BackendDescriptorStore& descriptorStore,
	ID3D11Texture2D** inOutTexture
) noexcept {
	const wchar_t* jsonPath = L"model.json";
	if (!Win32Utils::FileExists(jsonPath)) {
		// No model configured: nothing to do.
		return true;
	}

	std::string json;
	if (!Win32Utils::ReadTextFile(jsonPath, json)) {
		Logger::Get().Error("Win32Utils::ReadTextFile 失败");
		return false;
	}

	std::string modelPath;
	uint32_t scale = 1;
	std::string backend;
	{
		// ParseInsitu modifies `json` in place; the values are copied out by ReadJson
		// before the document and the buffer are discarded.
		rapidjson::Document doc;
		doc.ParseInsitu(json.data());
		if (doc.HasParseError()) {
			Logger::Get().Error("解析 json 失败");
			return false;
		}

		if (!ReadJson(doc, modelPath, scale, backend)) {
			Logger::Get().Error("ReadJson 失败");
			return false;
		}
	}

	StrUtils::ToLowerCase(backend);

	// Build the backend in a local first; it is committed only on success.
	std::unique_ptr<InferenceBackendBase> inferenceBackend;
	if (backend == "directml" || backend == "dml" || backend == "d") {
		inferenceBackend = std::make_unique<DirectMLInferenceBackend>();
	}
	// The TensorRT backend is only compiled for x64.
#if _M_X64
	else if (backend == "tensorrt" || backend == "trt" || backend == "t") {
		inferenceBackend = std::make_unique<TensorRTInferenceBackend>();
	}
#endif
	else {
		Logger::Get().Error("未知 backend");
		return false;
	}

	const std::wstring modelPathW = StrUtils::UTF8ToUTF16(modelPath);
	if (!inferenceBackend->Initialize(modelPathW.c_str(), scale, deviceResources, descriptorStore, *inOutTexture, inOutTexture)) {
		return false;
	}

	_inferenceBackend = std::move(inferenceBackend);
	return true;
}
// Runs the inference backend for the current frame. Does nothing when no backend
// has been set up. The profiler argument is currently unused.
void OnnxEffectDrawer::Draw(EffectsProfiler& /*profiler*/) const noexcept {
	if (!_inferenceBackend) {
		return;
	}

	_inferenceBackend->Evaluate();
}
}

View file

@ -0,0 +1,30 @@
#pragma once
namespace Magpie::Core {
class DeviceResources;
class EffectsProfiler;
class InferenceBackendBase;
class BackendDescriptorStore;
// Optional effect stage that runs an ONNX model on the frame through an inference backend.
// Non-copyable; move-constructible.
class OnnxEffectDrawer {
public:
// Constructor and destructor are declared here and defined in the .cpp because
// InferenceBackendBase is only forward-declared in this header.
OnnxEffectDrawer();
OnnxEffectDrawer(const OnnxEffectDrawer&) = delete;
OnnxEffectDrawer(OnnxEffectDrawer&&) = default;
~OnnxEffectDrawer();
// Creates and initializes the inference backend.
// inOutTexture: on input, the texture to run inference on; the backend receives the
//               same pointer as its output slot and may replace it with its result texture.
// Returns false on failure.
bool Initialize(
DeviceResources& deviceResources,
BackendDescriptorStore& descriptorStore,
ID3D11Texture2D** inOutTexture
) noexcept;
// Runs inference for the current frame; no-op when no backend is set.
void Draw(EffectsProfiler& profiler) const noexcept;
private:
// Null when no backend has been created.
std::unique_ptr<InferenceBackendBase> _inferenceBackend;
};
}

View file

@ -0,0 +1,25 @@
#pragma once
#include "pch.h"
#include <onnxruntime_cxx_api.h>
namespace Magpie::Core {
// RAII aliases for ONNX Runtime provider-option objects, which are released through
// the OrtApi function table rather than through a destructor.
struct OnnxHelper {
private:
// Close function for unique_cuda_provider_options: releases the options via the ORT API.
static void _CloseCUDAProviderOptions(OrtCUDAProviderOptionsV2* options) {
Ort::GetApi().ReleaseCUDAProviderOptions(options);
}
// Close function for unique_tensorrt_provider_options: releases the options via the ORT API.
static void _CloseTensorRTProviderOptions(OrtTensorRTProviderOptionsV2* options) {
Ort::GetApi().ReleaseTensorRTProviderOptions(options);
}
public:
// Owning handle for OrtCUDAProviderOptionsV2*; released automatically on destruction.
using unique_cuda_provider_options = wil::unique_any<OrtCUDAProviderOptionsV2*,
decltype(_CloseCUDAProviderOptions), _CloseCUDAProviderOptions>;
// Owning handle for OrtTensorRTProviderOptionsV2*; released automatically on destruction.
using unique_tensorrt_provider_options = wil::unique_any<OrtTensorRTProviderOptionsV2*,
decltype(_CloseTensorRTProviderOptions), _CloseTensorRTProviderOptions>;
};
}

View file

@ -511,9 +511,14 @@ ID3D11Texture2D* Renderer::_BuildEffects() noexcept {
Logger::Get().Info(fmt::format("编译着色器总计用时 {} 毫秒", duration / 1000.0f));
}
ID3D11Texture2D* inOutTexture = _frameSource->GetOutput();
if (!_onnxEffectDrawer.Initialize(_backendResources, _backendDescriptorStore, &inOutTexture)) {
return nullptr;
}
_effectDrawers.resize(effects.size());
ID3D11Texture2D* inOutTexture = _frameSource->GetOutput();
for (uint32_t i = 0; i < effectCount; ++i) {
if (!_effectDrawers[i].Initialize(
effectDescs[i],
@ -688,7 +693,10 @@ void Renderer::_BackendThreadProc() noexcept {
waitingForStepTimer = false;
}
const FrameSourceBase::UpdateState state = _frameSource->Update();
FrameSourceBase::UpdateState state = _frameSource->Update();
if (ScalingWindow::Get().Options().IsBenchmarkMode()) {
state = FrameSourceBase::UpdateState::NewFrame;
}
_stepTimer.UpdateFPS(state == FrameSourceBase::UpdateState::NewFrame);
switch (state) {
@ -815,6 +823,8 @@ void Renderer::_BackendRender(ID3D11Texture2D* effectsOutput) noexcept {
_effectsProfiler.OnBeginEffects(d3dDC);
_onnxEffectDrawer.Draw(_effectsProfiler);
for (const EffectDrawer& effectDrawer : _effectDrawers) {
effectDrawer.Draw(_effectsProfiler);
}

View file

@ -2,6 +2,7 @@
#include "DeviceResources.h"
#include "BackendDescriptorStore.h"
#include "EffectDrawer.h"
#include "OnnxEffectDrawer.h"
#include "Win32Utils.h"
#include "CursorDrawer.h"
#include "StepTimer.h"
@ -101,6 +102,7 @@ private:
Magpie::Core::BackendDescriptorStore _backendDescriptorStore;
std::unique_ptr<FrameSourceBase> _frameSource;
std::vector<EffectDrawer> _effectDrawers;
OnnxEffectDrawer _onnxEffectDrawer;
StepTimer _stepTimer;
EffectsProfiler _effectsProfiler;

View file

@ -47,6 +47,7 @@ struct ScalingFlags {
// Magpie.Core 不负责启动 TouchHelper.exe指定此标志会使 Magpie.Core 创建辅助窗口以拦截
// 黑边上的触控输入
static constexpr uint32_t IsTouchSupportEnabled = 1 << 17;
static constexpr uint32_t BenchmarkMode = 1 << 18;
};
enum class ScalingType {
@ -83,6 +84,7 @@ enum class DuplicateFrameDetectionMode {
struct ScalingOptions {
DEFINE_FLAG_ACCESSOR(IsWindowResizingDisabled, ScalingFlags::DisableWindowResizing, flags)
DEFINE_FLAG_ACCESSOR(IsDebugMode, ScalingFlags::BreakpointMode, flags)
DEFINE_FLAG_ACCESSOR(IsBenchmarkMode, ScalingFlags::BenchmarkMode, flags)
DEFINE_FLAG_ACCESSOR(IsEffectCacheDisabled, ScalingFlags::DisableEffectCache, flags)
DEFINE_FLAG_ACCESSOR(IsFontCacheDisabled, ScalingFlags::DisableFontCache, flags)
DEFINE_FLAG_ACCESSOR(IsSaveEffectSources, ScalingFlags::SaveEffectSources, flags)

View file

@ -0,0 +1,600 @@
#include "pch.h"
#include "TensorRTInferenceBackend.h"
#ifdef _M_X64
#include "DeviceResources.h"
#include <cuda/cuda_d3d11_interop.h>
#include "shaders/TextureToTensorCS.h"
#include "shaders/TensorToTextureCS.h"
#include "BackendDescriptorStore.h"
#include "Logger.h"
#include "DirectXHelper.h"
#include "Utils.h"
#include "OnnxHelper.h"
#include "HashHelper.h"
#include "Win32Utils.h"
#include "StrUtils.h"
#include "CommonSharedConstants.h"
namespace Magpie::Core {
static void LogCudaError(std::string_view msg, cudaError_t cudaResult) noexcept {
Logger::Get().Error(fmt::format("{}\n\tCUDA error code: {}", msg, (int)cudaResult));
}
static bool CheckComputeCapability(int deviceId) noexcept {
int major, minor;
cudaError_t cudaResult = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, deviceId);
if (cudaResult != cudaError_t::cudaSuccess) {
Logger::Get().Error("cudaDeviceGetAttribute 失败");
return false;
}
cudaResult = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, deviceId);
if (cudaResult != cudaError_t::cudaSuccess) {
Logger::Get().Error("cudaDeviceGetAttribute 失败");
return false;
}
Logger::Get().Info(fmt::format("当前设备 Compute Capability: {}.{}", major, minor));
// TensorRT 要求 Compute Capability 至少为 7.5
// https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html
if (std::make_pair(major, minor) < std::make_pair(7, 5)) {
Logger::Get().Error("当前设备无法使用 TensorRT");
return false;
}
return true;
}
static std::wstring GetCacheDir(
const std::vector<uint8_t>& modelData,
IDXGIAdapter4* adapter,
std::pair<uint16_t, uint16_t> minShapes,
std::pair<uint16_t, uint16_t> maxShapes,
std::pair<uint16_t, uint16_t> optShapes,
uint8_t optimizationLevel,
bool enableFP16
) noexcept {
DXGI_ADAPTER_DESC desc;
adapter->GetDesc(&desc);
// TensorRT 缓存和多种因素绑定,这里考虑的因素有:
// * 模型哈希
// * ONNX Runtime 版本
// * TensorRT 版本
// * 显卡型号 (替代 Compute Capability)
// * 配置文件
// * 优化等级
// * 是否启用半精度
std::string str = fmt::format(
"modelHash:{}\nortVersion:{}\nvendorId:{}\ndeviceId:{}\nminShapes:{},{}\nmaxShapes:{},{}\noptShapes:{},{}\noptLevel:{}\nfp16:{}",
Utils::HashData(modelData), Ort::GetVersionString(), desc.VendorId, desc.DeviceId,
minShapes.first, minShapes.second, maxShapes.first, maxShapes.second, optShapes.first,
optShapes.second, optimizationLevel, enableFP16);
std::wstring strHash = HashHelper::HexHash(std::span((const BYTE*)str.data(), str.size()));
return StrUtils::Concat(CommonSharedConstants::CACHE_DIR, L"tensorrt\\", strHash);
}
// Makes a shared D3D11 buffer visible to CUDA.
//
// The buffer's shared handle is imported twice: once as CUDA external memory
// (written to *bufferCudaMem) and once as a keyed-mutex external semaphore
// (written to *bufferCudaSem). The imported memory is then mapped and the
// resulting device pointer is returned.
//
// Returns nullptr on any failure. The out-parameters may already have been
// filled in by then; releasing them is left to the caller.
static void* ShareBufferWithCuda(
	const winrt::com_ptr<ID3D11Buffer>& buffer,
	uint32_t bufferSize,
	cudaExternalMemory_t* bufferCudaMem,
	cudaExternalSemaphore_t* bufferCudaSem
) noexcept {
	const winrt::com_ptr<IDXGIResource> dxgiResource = buffer.try_as<IDXGIResource>();
	if (!dxgiResource) {
		return nullptr;
	}

	HANDLE kmtHandle = NULL;
	if (const HRESULT hr = dxgiResource->GetSharedHandle(&kmtHandle); FAILED(hr)) {
		Logger::Get().ComError("GetSharedHandle 失败", hr);
		return nullptr;
	}

	// Step 1: import the buffer's storage
	cudaExternalMemoryHandleDesc memoryHandleDesc{
		.type = cudaExternalMemoryHandleTypeD3D11ResourceKmt,
		.handle = {.win32 = {.handle = kmtHandle } },
		.size = bufferSize,
		.flags = cudaExternalMemoryDedicated
	};
	cudaError_t status = cudaImportExternalMemory(bufferCudaMem, &memoryHandleDesc);
	if (status != cudaError_t::cudaSuccess) {
		LogCudaError("cudaImportExternalMemory 失败", status);
		return nullptr;
	}

	// Step 2: import the keyed mutex that guards the buffer
	cudaExternalSemaphoreHandleDesc semaphoreHandleDesc{
		.type = cudaExternalSemaphoreHandleTypeKeyedMutexKmt,
		.handle = {.win32 = {.handle = kmtHandle } },
	};
	status = cudaImportExternalSemaphore(bufferCudaSem, &semaphoreHandleDesc);
	if (status != cudaError_t::cudaSuccess) {
		LogCudaError("cudaImportExternalSemaphore 失败", status);
		return nullptr;
	}

	// Step 3: map the whole buffer into the CUDA address space
	void* mappedPtr = nullptr;
	cudaExternalMemoryBufferDesc mappedBufferDesc{ .size = bufferSize };
	status = cudaExternalMemoryGetMappedBuffer(&mappedPtr, *bufferCudaMem, &mappedBufferDesc);
	if (status != cudaError_t::cudaSuccess) {
		LogCudaError("cudaExternalMemoryGetMappedBuffer 失败", status);
		return nullptr;
	}

	return mappedPtr;
}
// Releases the CUDA side of the shared buffers. Order: the imported
// semaphores first, then the mapped device pointers, and finally the imported
// external memory objects. Handles that were never created (null) are skipped.
TensorRTInferenceBackend::~TensorRTInferenceBackend() {
	void* const semaphores[]{ _inputBufferCudaSem, _outputBufferCudaSem };
	for (void* semaphore : semaphores) {
		if (semaphore) {
			cudaDestroyExternalSemaphore((cudaExternalSemaphore_t)semaphore);
		}
	}

	void* const mappedPointers[]{ _inputBufferCudaPtr, _outputBufferCudaPtr };
	for (void* devicePtr : mappedPointers) {
		if (devicePtr) {
			cudaFree(devicePtr);
		}
	}

	void* const externalMemories[]{ _inputBufferCudaMem, _outputBufferCudaMem };
	for (void* memory : externalMemories) {
		if (memory) {
			cudaDestroyExternalMemory((cudaExternalMemory_t)memory);
		}
	}
}
// Sets up everything Evaluate() needs to run the model at `modelPath` on
// `input` and write the result to a newly created texture.
//
// Steps, in order:
//   1. check that the TensorRT provider DLL is present and that the adapter
//      maps to a CUDA device with a sufficient compute capability
//   2. create the ONNX Runtime environment and session and validate the model
//   3. create the output texture (input size multiplied by `scale`)
//   4. create two D3D11 buffers shared through keyed mutexes, import them into
//      CUDA and bind them as the session's "input" and "output" tensors
//   5. create the views, sampler and compute shaders used to convert between
//      textures and tensors
//
// On success *output receives the output texture and true is returned. Every
// failure is logged (except where noted) and returns false.
bool TensorRTInferenceBackend::Initialize(
const wchar_t* modelPath,
uint32_t scale,
DeviceResources& deviceResources,
BackendDescriptorStore& descriptorStore,
ID3D11Texture2D* input,
ID3D11Texture2D** output
) noexcept {
// The provider DLL is looked up relative to the current directory
if (!Win32Utils::FileExists(L"third_party\\onnxruntime_providers_tensorrt.dll")) {
Logger::Get().Error("未安装 TensorRT 拓展");
return false;
}
// Find the CUDA device that corresponds to the adapter D3D is rendering on
int deviceId = 0;
cudaError_t cudaResult = cudaD3D11GetDevice(&deviceId, deviceResources.GetGraphicsAdapter());
if (cudaResult != cudaError_t::cudaSuccess) {
LogCudaError("cudaD3D11GetDevice 失败", cudaResult);
return false;
}
if (!CheckComputeCapability(deviceId)) {
Logger::Get().Error("CheckComputeCapability 失败");
return false;
}
cudaResult = cudaSetDevice(deviceId);
if (cudaResult != cudaError_t::cudaSuccess) {
LogCudaError("cudaSetDevice 失败", cudaResult);
return false;
}
// Presumably set by _IsModelValid when the model works on FP16 tensors;
// it selects element size and formats below - TODO confirm in the base class
bool isFP16Data = false;
// ONNX Runtime reports errors by throwing; convert them into a logged failure
try {
const OrtApi& ortApi = Ort::GetApi();
_env = Ort::Env(ORT_LOGGING_LEVEL_INFO, "", _OrtLog, nullptr);
Ort::SessionOptions sessionOptions;
sessionOptions.SetIntraOpNumThreads(1);
// Pin the free "DATA_BATCH" dimension to 1
Ort::ThrowOnError(ortApi.AddFreeDimensionOverride(sessionOptions, "DATA_BATCH", 1));
if (!_CreateSession(deviceResources, deviceId, sessionOptions, modelPath)) {
Logger::Get().Error("_CreateSession 失败");
return false;
}
if (!_IsModelValid(_session, isFP16Data)) {
Logger::Get().Error("不支持此模型");
return false;
}
_cudaMemInfo = Ort::MemoryInfo("Cuda", OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemTypeDefault);
} catch (const Ort::Exception& e) {
Logger::Get().Error(e.what());
return false;
}
ID3D11Device5* d3dDevice = deviceResources.GetD3DDevice();
_d3dDC = deviceResources.GetD3DDC();
const SIZE inputSize = DirectXHelper::GetTextureSize(input);
const SIZE outputSize = SIZE{ inputSize.cx * (LONG)scale, inputSize.cy * (LONG)scale };
// Create the output texture
winrt::com_ptr<ID3D11Texture2D> outputTex = DirectXHelper::CreateTexture2D(
d3dDevice,
DXGI_FORMAT_R8G8B8A8_UNORM,
outputSize.cx,
outputSize.cy,
D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_UNORDERED_ACCESS
);
if (!outputTex) {
Logger::Get().Error("创建输出纹理失败");
return false;
}
// NOTE(review): outputTex is a local smart pointer and *output is a raw
// pointer. After this function returns, the only reference held by this
// object is whatever _outputTexUav keeps on its resource; if a later step
// fails, *output is left pointing at a texture that is released on return.
*output = outputTex.get();
// Tensors hold three planes (one per colour channel) of width * height elements
const uint32_t inputElemCount = uint32_t(inputSize.cx * inputSize.cy * 3);
const uint32_t outputElemCount = uint32_t(outputSize.cx * outputSize.cy * 3);
// 4 bytes per FP32 element; FP16 elements are 2 bytes each, with the count
// rounded up to an even number so the byte size stays a multiple of 4
const uint32_t inputBufferSize = isFP16Data ? ((inputElemCount + 1) / 2 * 4) : (inputElemCount * 4);
const uint32_t outputBufferSize = isFP16Data ? ((outputElemCount + 1) / 2 * 4) : (outputElemCount * 4);
winrt::com_ptr<ID3D11Buffer> inputBuffer;
winrt::com_ptr<ID3D11Buffer> outputBuffer;
{
// Input buffer: written by a compute shader through a UAV
D3D11_BUFFER_DESC desc{
.ByteWidth = inputBufferSize,
.BindFlags = D3D11_BIND_UNORDERED_ACCESS,
.MiscFlags = D3D11_RESOURCE_MISC_SHARED_KEYEDMUTEX
};
HRESULT hr = d3dDevice->CreateBuffer(&desc, nullptr, inputBuffer.put());
if (FAILED(hr)) {
Logger::Get().ComError("CreateBuffer 失败", hr);
return false;
}
// Output buffer: same description, but read by a compute shader through an SRV
desc.ByteWidth = outputBufferSize;
desc.BindFlags = D3D11_BIND_SHADER_RESOURCE;
hr = d3dDevice->CreateBuffer(&desc, nullptr, outputBuffer.put());
if (FAILED(hr)) {
Logger::Get().ComError("CreateBuffer 失败", hr);
return false;
}
}
// Import both buffers into CUDA. The handles are stored in members as
// void*, hence the casts; the destructor releases whatever was created.
_inputBufferCudaPtr = ShareBufferWithCuda(
inputBuffer,
inputBufferSize,
(cudaExternalMemory_t*)&_inputBufferCudaMem,
(cudaExternalSemaphore_t*)&_inputBufferCudaSem
);
_outputBufferCudaPtr = ShareBufferWithCuda(
outputBuffer,
outputBufferSize,
(cudaExternalMemory_t*)&_outputBufferCudaMem,
(cudaExternalSemaphore_t*)&_outputBufferCudaSem
);
if (!_inputBufferCudaPtr || !_outputBufferCudaPtr) {
Logger::Get().Error("ShareBufferWithCuda 失败");
return false;
}
// Bind the mapped device pointers as the tensors named "input" and
// "output", both with shape 1 x 3 x height x width
try {
_ioBinding = Ort::IoBinding(_session);
const int64_t inputShape[]{ 1,3,inputSize.cy,inputSize.cx };
_ioBinding.BindInput("input", Ort::Value::CreateTensor(
_cudaMemInfo,
_inputBufferCudaPtr,
inputBufferSize,
inputShape,
std::size(inputShape),
isFP16Data ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
));
const int64_t outputShape[]{ 1,3,outputSize.cy,outputSize.cx };
_ioBinding.BindOutput("output", Ort::Value::CreateTensor(
_cudaMemInfo,
_outputBufferCudaPtr,
outputBufferSize,
outputShape,
std::size(outputShape),
isFP16Data ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
));
} catch (const Ort::Exception& e) {
Logger::Get().Error(e.what());
return false;
}
// Keyed mutexes used by Evaluate() to hand the buffers back and forth
// between D3D11 and CUDA. NOTE(review): these two failures are not logged.
_inputBufferKmt = inputBuffer.try_as<IDXGIKeyedMutex>();
if (!_inputBufferKmt) {
return false;
}
_outputBufferKmt = outputBuffer.try_as<IDXGIKeyedMutex>();
if (!_outputBufferKmt) {
return false;
}
_inputTexSrv = descriptorStore.GetShaderResourceView(input);
if (!_inputTexSrv) {
Logger::Get().Error("GetShaderResourceView 失败");
return false;
}
// Point sampling with clamped addressing for reading the input texture
_sampler = deviceResources.GetSampler(
D3D11_FILTER_MIN_MAG_MIP_POINT, D3D11_TEXTURE_ADDRESS_CLAMP);
if (!_sampler) {
Logger::Get().Error("GetSampler 失败");
return false;
}
{
// Typed UAV over the input buffer, one element per tensor value
D3D11_UNORDERED_ACCESS_VIEW_DESC desc{
.Format = isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
.ViewDimension = D3D11_UAV_DIMENSION_BUFFER,
.Buffer{
.NumElements = inputElemCount
}
};
HRESULT hr = d3dDevice->CreateUnorderedAccessView(
inputBuffer.get(), &desc, _inputBufferUav.put());
if (FAILED(hr)) {
Logger::Get().ComError("CreateUnorderedAccessView 失败", hr);
return false;
}
}
{
// Typed SRV over the output buffer
D3D11_SHADER_RESOURCE_VIEW_DESC desc{
.Format = isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
.ViewDimension = D3D11_SRV_DIMENSION_BUFFER,
.Buffer{
.NumElements = outputElemCount
}
};
HRESULT hr = d3dDevice->CreateShaderResourceView(
outputBuffer.get(), &desc, _outputBufferSrv.put());
if (FAILED(hr)) {
Logger::Get().ComError("CreateShaderResourceView 失败", hr);
return false;
}
}
{
// UAV over the output texture; Format is left at its zero value
D3D11_UNORDERED_ACCESS_VIEW_DESC desc{
.ViewDimension = D3D11_UAV_DIMENSION_TEXTURE2D
};
HRESULT hr = d3dDevice->CreateUnorderedAccessView(
outputTex.get(), &desc, _outputTexUav.put());
if (FAILED(hr)) {
Logger::Get().ComError("CreateUnorderedAccessView 失败", hr);
return false;
}
}
HRESULT hr = d3dDevice->CreateComputeShader(
TextureToTensorCS, sizeof(TextureToTensorCS), nullptr, _texToTensorShader.put());
if (FAILED(hr)) {
Logger::Get().ComError("CreateComputeShader 失败", hr);
return false;
}
hr = d3dDevice->CreateComputeShader(
TensorToTextureCS, sizeof(TensorToTextureCS), nullptr, _tensorToTexShader.put());
if (FAILED(hr)) {
Logger::Get().ComError("CreateComputeShader 失败", hr);
return false;
}
// Pixels covered by one thread group of each shader; the dispatch counts
// are the texture sizes divided by these, rounded up
static constexpr std::pair<uint32_t, uint32_t> TEX_TO_TENSOR_BLOCK_SIZE{ 16, 16 };
static constexpr std::pair<uint32_t, uint32_t> TENSOR_TO_TEX_BLOCK_SIZE{ 8, 8 };
_texToTensorDispatchCount = {
(inputSize.cx + TEX_TO_TENSOR_BLOCK_SIZE.first - 1) / TEX_TO_TENSOR_BLOCK_SIZE.first,
(inputSize.cy + TEX_TO_TENSOR_BLOCK_SIZE.second - 1) / TEX_TO_TENSOR_BLOCK_SIZE.second
};
_tensorToTexDispatchCount = {
(outputSize.cx + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
(outputSize.cy + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second
};
return true;
}
// Runs one inference: converts the input texture to a tensor, executes the
// session, and converts the output tensor to the output texture.
//
// Ownership of the two shared buffers alternates between D3D11 and CUDA
// through their keyed mutexes. Each hand-over releases with the next key
// value, and the other side acquires with that same value, so the two key
// counters only ever grow:
//   D3D11 acquires input (k)  -> writes -> releases input (k+1)
//   CUDA  waits input (k+1), output (m) -> runs -> signals input (k+2), output (m+1)
//   D3D11 acquires output (m+1) -> reads -> releases output (m+2)
//
// NOTE(review): every early return below leaves the key counters, and in some
// cases the bound views, in an intermediate state. Since AcquireSync waits
// with INFINITE, a later call could block forever after such a failure -
// confirm whether the caller stops calling Evaluate() once an error is logged.
void TensorRTInferenceBackend::Evaluate() noexcept {
// Input texture -> input tensor
HRESULT hr = _inputBufferKmt->AcquireSync(_inputBufferMutexKey, INFINITE);
if (FAILED(hr)) {
Logger::Get().ComError("AcquireSync 失败", hr);
return;
}
_d3dDC->CSSetShaderResources(0, 1, &_inputTexSrv);
_d3dDC->CSSetSamplers(0, 1, &_sampler);
{
ID3D11UnorderedAccessView* uav = _inputBufferUav.get();
_d3dDC->CSSetUnorderedAccessViews(0, 1, &uav, nullptr);
}
_d3dDC->CSSetShader(_texToTensorShader.get(), nullptr, 0);
_d3dDC->Dispatch(_texToTensorDispatchCount.first, _texToTensorDispatchCount.second, 1);
// Hand the input buffer over to CUDA under the next key value
_inputBufferKmt->ReleaseSync(++_inputBufferMutexKey);
{
// CUDA takes both buffers before inference: the input just released
// above, and the output as D3D11 last released it
cudaExternalSemaphore_t semArr[] = {
(cudaExternalSemaphore_t)_inputBufferCudaSem,
(cudaExternalSemaphore_t)_outputBufferCudaSem
};
cudaExternalSemaphoreWaitParams extSemWaitParamsArr[] = {
{.params{.keyedMutex{.key = _inputBufferMutexKey, .timeoutMs = INFINITE}}},
{.params{.keyedMutex{.key = _outputBufferMutexKey, .timeoutMs = INFINITE}}}
};
// No stream argument is passed, so the call uses its default stream
cudaError_t cudaResult = cudaWaitExternalSemaphoresAsync(semArr, extSemWaitParamsArr, 2);
if (cudaResult != cudaError_t::cudaSuccess) {
LogCudaError("cudaWaitExternalSemaphoresAsync 失败", cudaResult);
return;
}
}
try {
Ort::RunOptions runOptions;
// Ask ONNX Runtime not to synchronize the execution providers after the
// run; completion is ordered through the semaphore signal below instead.
// NOTE(review): this presumably relies on the session running on the same
// stream as the wait/signal calls - confirm.
runOptions.AddConfigEntry("disable_synchronize_execution_providers", "1");
_session.Run(runOptions, _ioBinding);
} catch (const Ort::Exception& e) {
Logger::Get().Error(e.what());
return;
}
{
// CUDA gives both buffers back to D3D11 under the next key values
cudaExternalSemaphore_t semArr[] = {
(cudaExternalSemaphore_t)_inputBufferCudaSem,
(cudaExternalSemaphore_t)_outputBufferCudaSem
};
cudaExternalSemaphoreSignalParams extSemSigParams[] = {
{.params = {.keyedMutex = {.key = ++_inputBufferMutexKey}}},
{.params = {.keyedMutex = {.key = ++_outputBufferMutexKey}}}
};
cudaError_t cudaResult = cudaSignalExternalSemaphoresAsync(semArr, extSemSigParams, 2);
if (cudaResult != cudaError_t::cudaSuccess) {
LogCudaError("cudaSignalExternalSemaphoresAsync 失败", cudaResult);
return;
}
}
// Output tensor -> output texture
hr = _outputBufferKmt->AcquireSync(_outputBufferMutexKey, INFINITE);
if (FAILED(hr)) {
Logger::Get().ComError("AcquireSync 失败", hr);
return;
}
{
ID3D11ShaderResourceView* srv = _outputBufferSrv.get();
_d3dDC->CSSetShaderResources(0, 1, &srv);
}
{
ID3D11UnorderedAccessView* uav = _outputTexUav.get();
_d3dDC->CSSetUnorderedAccessViews(0, 1, &uav, nullptr);
}
_d3dDC->CSSetShader(_tensorToTexShader.get(), nullptr, 0);
_d3dDC->Dispatch(_tensorToTexDispatchCount.first, _tensorToTexDispatchCount.second, 1);
// Unbind slot t0 and u0 so the buffer and the output texture are no longer
// attached to the compute stage when this function returns
{
ID3D11ShaderResourceView* srv = nullptr;
_d3dDC->CSSetShaderResources(0, 1, &srv);
}
{
ID3D11UnorderedAccessView* uav = nullptr;
_d3dDC->CSSetUnorderedAccessViews(0, 1, &uav, nullptr);
}
// Release the output buffer under the key CUDA will wait for next time
_outputBufferKmt->ReleaseSync(++_outputBufferMutexKey);
}
// Creates _session with the TensorRT execution provider (and the CUDA
// provider appended after it) and a per-configuration on-disk cache.
//
// The cache directory is derived from the model bytes, the adapter and the
// build settings below (see GetCacheDir). When "<cacheDir>\ctx.onnx" already
// exists the session is created from that file; otherwise it is created from
// the model bytes, with the provider configured to dump the context model to
// that path.
//
// Returns false (after logging) when the model cannot be read or the cache
// directory cannot be created. ONNX Runtime failures are thrown as
// Ort::Exception, which is why this function is not noexcept; the caller is
// expected to catch them.
bool TensorRTInferenceBackend::_CreateSession(
	DeviceResources& deviceResources,
	int deviceId,
	Ort::SessionOptions& sessionOptions,
	const wchar_t* modelPath
) {
	// Build settings. Shapes are stored as (width, height).
	// NOTE(review): the optimization profile is fixed at 1x1 .. 1920x1080;
	// inputs larger than that presumably fall outside the profile - confirm.
	const std::pair<uint16_t, uint16_t> minShapes(uint16_t(1), uint16_t(1));
	const std::pair<uint16_t, uint16_t> maxShapes(uint16_t(1920), uint16_t(1080));
	const std::pair<uint16_t, uint16_t> optShapes(uint16_t(1920), uint16_t(1080));
	const bool enableFP16 = true;
	const uint8_t optimizationLevel = 5;

	std::vector<uint8_t> modelData;
	if (!Win32Utils::ReadFile(modelPath, modelData)) {
		Logger::Get().Error("读取模型失败");
		return false;
	}

	const std::wstring cacheDir = GetCacheDir(
		modelData,
		deviceResources.GetGraphicsAdapter(),
		minShapes,
		maxShapes,
		optShapes,
		optimizationLevel,
		enableFP16
	);
	if (!Win32Utils::CreateDir(cacheDir, true)) {
		Logger::Get().Win32Error("创建缓存文件夹失败");
		return false;
	}

	const std::wstring cacheCtxPath = cacheDir + L"\\ctx.onnx";

	const OrtApi& ortApi = Ort::GetApi();

	// Shared by the TensorRT and CUDA option lists below
	const std::string deviceIdStr = std::to_string(deviceId);

	OnnxHelper::unique_tensorrt_provider_options trtOptions;
	Ort::ThrowOnError(ortApi.CreateTensorRTProviderOptions(trtOptions.put()));
	{
		// keys[i] is paired with values[i]; keep both lists in the same order
		const char* keys[]{
			"device_id",
			"has_user_compute_stream",
			"trt_fp16_enable",
			"trt_builder_optimization_level",
			"trt_profile_min_shapes",
			"trt_profile_max_shapes",
			"trt_profile_opt_shapes",
			"trt_engine_cache_enable",
			"trt_engine_cache_prefix",
			"trt_dump_ep_context_model",
			"trt_ep_context_file_path"
		};

		// Profile shapes are written as 1 x 3 x height x width
		const std::string optLevelStr = std::to_string(optimizationLevel);
		const std::string minShapesStr = fmt::format("input:1x3x{}x{}", minShapes.second, minShapes.first);
		const std::string maxShapesStr = fmt::format("input:1x3x{}x{}", maxShapes.second, maxShapes.first);
		const std::string optShapesStr = fmt::format("input:1x3x{}x{}", optShapes.second, optShapes.first);
		const std::string cacheCtxPathANSI = StrUtils::UTF16ToANSI(cacheCtxPath);

		const char* values[]{
			deviceIdStr.c_str(),
			"1",
			enableFP16 ? "1" : "0",
			optLevelStr.c_str(),
			minShapesStr.c_str(),
			maxShapesStr.c_str(),
			optShapesStr.c_str(),
			"1",
			"trt",
			"1",
			cacheCtxPathANSI.c_str()
		};
		static_assert(std::size(keys) == std::size(values), "every option key needs a value");

		Ort::ThrowOnError(ortApi.UpdateTensorRTProviderOptions(trtOptions.get(), keys, values, std::size(keys)));
	}

	OnnxHelper::unique_cuda_provider_options cudaOptions;
	Ort::ThrowOnError(ortApi.CreateCUDAProviderOptions(cudaOptions.put()));
	{
		const char* keys[]{ "device_id", "has_user_compute_stream" };
		const char* values[]{ deviceIdStr.c_str(), "1" };
		Ort::ThrowOnError(ortApi.UpdateCUDAProviderOptions(cudaOptions.get(), keys, values, std::size(keys)));
	}

	// TensorRT is appended first, CUDA second
	sessionOptions.AppendExecutionProvider_TensorRT_V2(*trtOptions.get());
	sessionOptions.AppendExecutionProvider_CUDA_V2(*cudaOptions.get());

	if (Win32Utils::FileExists(cacheCtxPath.c_str())) {
		Logger::Get().Info("读取缓存 " + StrUtils::UTF16ToUTF8(cacheCtxPath));
		_session = Ort::Session(_env, cacheCtxPath.c_str(), sessionOptions);
	} else {
		_session = Ort::Session(_env, modelData.data(), modelData.size(), sessionOptions);
	}

	return true;
}
}
#endif

View file

@ -0,0 +1,76 @@
#pragma once
#include "InferenceBackendBase.h"
#ifdef _M_X64
struct cudaGraphicsResource;
namespace Magpie::Core {
// Inference backend that runs an ONNX model through ONNX Runtime's TensorRT
// execution provider. Data is exchanged with Direct3D 11 through two buffers
// that are shared with CUDA and guarded by keyed mutexes.
class TensorRTInferenceBackend : public InferenceBackendBase {
public:
TensorRTInferenceBackend() = default;
TensorRTInferenceBackend(const TensorRTInferenceBackend&) = delete;
// NOTE(review): the defaulted move constructor copies the raw CUDA handles
// below without clearing them in the source, while the destructor releases
// every non-null handle. Moving an initialized instance would therefore
// release the same handles twice - confirm that instances are never moved.
TensorRTInferenceBackend(TensorRTInferenceBackend&&) = default;
// Releases the CUDA semaphores, mapped pointers and external memory objects.
virtual ~TensorRTInferenceBackend();
// Loads the model and creates all D3D11 / CUDA / ONNX Runtime resources.
// On success *output receives the texture Evaluate() writes to.
bool Initialize(
const wchar_t* modelPath,
uint32_t scale,
DeviceResources& deviceResources,
BackendDescriptorStore& descriptorStore,
ID3D11Texture2D* input,
ID3D11Texture2D** output
) noexcept override;
// Runs one inference from the input texture to the output texture.
void Evaluate() noexcept override;
private:
// Creates _session, reusing the on-disk cache when it exists. Not noexcept:
// ONNX Runtime errors are thrown as Ort::Exception.
bool _CreateSession(
DeviceResources& deviceResources,
int deviceId,
Ort::SessionOptions& sessionOptions,
const wchar_t* modelPath
);
Ort::Env _env{ nullptr };
Ort::Session _session{ nullptr };
// Raw pointers obtained from DeviceResources / BackendDescriptorStore;
// presumably owned there rather than by this object - TODO confirm.
ID3D11DeviceContext4* _d3dDC = nullptr;
ID3D11SamplerState* _sampler = nullptr;
ID3D11ShaderResourceView* _inputTexSrv = nullptr;
// Views over the shared input buffer, shared output buffer and output texture
winrt::com_ptr<ID3D11UnorderedAccessView> _inputBufferUav;
winrt::com_ptr<ID3D11ShaderResourceView> _outputBufferSrv;
winrt::com_ptr<ID3D11UnorderedAccessView> _outputTexUav;
// Keyed mutexes of the shared buffers and the key value each one currently
// expects; the keys are incremented on every hand-over in Evaluate()
winrt::com_ptr<IDXGIKeyedMutex> _inputBufferKmt;
winrt::com_ptr<IDXGIKeyedMutex> _outputBufferKmt;
UINT64 _inputBufferMutexKey = 0;
UINT64 _outputBufferMutexKey = 0;
// Compute shaders converting texture -> tensor and tensor -> texture, and
// the thread-group counts (x, y) each is dispatched with
winrt::com_ptr<ID3D11ComputeShader> _texToTensorShader;
winrt::com_ptr<ID3D11ComputeShader> _tensorToTexShader;
std::pair<uint32_t, uint32_t> _texToTensorDispatchCount{};
std::pair<uint32_t, uint32_t> _tensorToTexDispatchCount{};
Ort::MemoryInfo _cudaMemInfo{ nullptr };
// The CUDA handles are stored as void* so this header does not need the
// CUDA headers; the .cpp casts them back to the types named below.
// cudaExternalMemory_t
void* _inputBufferCudaMem = nullptr;
void* _outputBufferCudaMem = nullptr;
// Device pointers obtained by mapping the external memory above
void* _inputBufferCudaPtr = nullptr;
void* _outputBufferCudaPtr = nullptr;
// cudaExternalSemaphore_t
void* _inputBufferCudaSem = nullptr;
void* _outputBufferCudaSem = nullptr;
// Binds the mapped buffers to the model's "input" and "output" tensors
Ort::IoBinding _ioBinding{ nullptr };
};
}
#endif

View file

@ -1,5 +1,5 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="Microsoft.Windows.CppWinRT" version="2.0.240405.15" targetFramework="native" />
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.240122.1" targetFramework="native" />
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.250325.1" targetFramework="native" />
</packages>

View file

@ -12,7 +12,7 @@ void main(uint3 tid : SV_GroupThreadID, uint3 gid : SV_GroupID) {
return;
}
const int2 gxy = (gid.xy << 4) + (tid.xy << 1);
const uint2 gxy = (gid.xy << 4) + (tid.xy << 1);
// 不知为何这比通过 cbuffer 传入更快
uint width, height;

View file

@ -0,0 +1,15 @@
// Converts a planar tensor into an RGBA texture.
//
// `tensor` holds three consecutive planes of width * height values (red,
// green, blue, each in row-major order). Each thread writes one pixel of
// `tex`, with alpha set to 1.
Buffer<min16float> tensor : register(t0);
RWTexture2D<min16float4> tex : register(u0);

[numthreads(8, 8, 1)]
void main(uint3 tid : SV_GroupThreadID, uint3 gid : SV_GroupID) {
	// One thread group covers an 8x8 block of pixels
	const uint2 gxy = (gid.xy << 3) + tid.xy;

	uint width, height;
	tex.GetDimensions(width, height);

	// The dispatch size is rounded up to whole groups, so threads in the last
	// row/column of groups can fall outside the texture. Skip them: their
	// linear index would wrap into another row and the write would be out of
	// bounds anyway.
	if (gxy.x >= width || gxy.y >= height) {
		return;
	}

	const uint planeStride = width * height;
	const uint idx = gxy.y * width + gxy.x;

	min16float3 color = { tensor[idx], tensor[planeStride + idx], tensor[planeStride * 2 + idx] };
	tex[gxy] = min16float4(color, 1);
}

View file

@ -0,0 +1,54 @@
// Converts an RGBA texture into a planar tensor.
//
// `result` receives three consecutive planes of width * height values (red,
// green, blue, each in row-major order). Every thread handles a 2x2 block of
// pixels using one gather per channel, so an 8x8 thread group covers 16x16
// pixels.
Texture2D<min16float4> tex : register(t0);
RWBuffer<min16float> result : register(u0);
SamplerState sam : register(s0);

// Writes one pixel's three channels to position `idx` of each plane.
void StorePixel(uint idx, uint planeStride, min16float r, min16float g, min16float b) {
	result[idx] = r;
	result[idx + planeStride] = g;
	result[idx + planeStride * 2] = b;
}

[numthreads(8, 8, 1)]
void main(uint3 tid : SV_GroupThreadID, uint3 gid : SV_GroupID) {
	// Top-left pixel of this thread's 2x2 block
	const uint2 gxy = (gid.xy << 4) + (tid.xy << 1);

	uint width, height;
	tex.GetDimensions(width, height);

	if (gxy.x >= width || gxy.y >= height) {
		return;
	}

	// Sample at the corner shared by the four pixels of the block
	const float2 pos = (gxy + 1) / float2(width, height);
	min16float4 red = tex.GatherRed(sam, pos);
	min16float4 green = tex.GatherGreen(sam, pos);
	min16float4 blue = tex.GatherBlue(sam, pos);

	const uint planeStride = width * height;

	// The block may hang over the right or bottom edge of the texture
	const bool hasRightColumn = gxy.x + 1 < width;
	const bool hasBottomRow = gxy.y + 1 < height;

	// Gather component layout within the block:
	// w z
	// x y
	const uint topIdx = gxy.y * width + gxy.x;
	StorePixel(topIdx, planeStride, red.w, green.w, blue.w);
	if (hasRightColumn) {
		StorePixel(topIdx + 1, planeStride, red.z, green.z, blue.z);
	}

	if (hasBottomRow) {
		const uint bottomIdx = topIdx + width;
		StorePixel(bottomIdx, planeStride, red.x, green.x, blue.x);
		if (hasRightColumn) {
			StorePixel(bottomIdx + 1, planeStride, red.y, green.y, blue.y);
		}
	}
}

View file

@ -1,7 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props')" />
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props')" />
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props')" />
<PropertyGroup Label="Globals">
<VCProjectVersion>16.0</VCProjectVersion>
<Keyword>Win32Proj</Keyword>
@ -112,20 +112,20 @@
</ItemGroup>
</Target>
<ImportGroup Label="ExtensionTargets">
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets')" />
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets')" />
<Import Project="..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets" Condition="Exists('..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets')" />
<Import Project="..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets" Condition="Exists('..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets')" />
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets')" />
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
</ImportGroup>
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
<PropertyGroup>
<ErrorText>这台计算机上缺少此项目引用的 NuGet 程序包。使用“NuGet 程序包还原”可下载这些程序包。有关更多信息,请参见 http://go.microsoft.com/fwlink/?LinkID=322105。缺少的文件是 {0}。</ErrorText>
</PropertyGroup>
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
</Target>
</Project>

View file

@ -172,9 +172,8 @@ bool TouchHelper::Register() noexcept {
}
std::wstring magpieDir = StrUtils::Concat(system32Dir.get(), L"\\Magpie");
hr = wil::CreateDirectoryDeepNoThrow(magpieDir.c_str());
if (FAILED(hr)) {
Logger::Get().ComError("CreateDirectoryDeepNoThrow 失败", hr);
if (!CreateDirectory(magpieDir.c_str(), nullptr)) {
Logger::Get().Win32Error("CreateDirectory 失败");
return false;
}

View file

@ -1,7 +1,7 @@
[requires]
fmt/10.2.1
spdlog/1.14.1
parallel-hashmap/1.37
fmt/11.1.3
spdlog/1.15.1
parallel-hashmap/2.0.0
[generators]
MSBuildDeps

View file

@ -19,9 +19,10 @@
#include "Win32Utils.h"
#include "TouchHelper.h"
#include "CommonSharedConstants.h"
#include "StrUtils.h"
// 将当前目录设为程序所在目录
static void SetWorkingDir() noexcept {
static std::wstring SetWorkingDir() noexcept {
std::wstring path = Win32Utils::GetExePath();
FAIL_FAST_IF_FAILED(PathCchRemoveFileSpec(
@ -30,6 +31,9 @@ static void SetWorkingDir() noexcept {
));
FAIL_FAST_IF_WIN32_BOOL_FALSE(SetCurrentDirectory(path.c_str()));
path.resize(StrUtils::StrLen(path.c_str()));
return path;
}
static void InitializeLogger(const char* logFilePath) noexcept {
@ -54,7 +58,7 @@ int APIENTRY wWinMain(
// 堆损坏时终止进程
HeapSetInformation(NULL, HeapEnableTerminationOnCorruption, nullptr, 0);
SetWorkingDir();
std::wstring workingDir = SetWorkingDir();
enum {
Normal,
@ -90,6 +94,10 @@ int APIENTRY wWinMain(
return Magpie::TouchHelper::Unregister() ? 0 : 1;
}
SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_DEFAULT_DIRS);
workingDir += L"\\third_party";
AddDllDirectory(workingDir.c_str());
auto& app = Magpie::XamlApp::Get();
if (!app.Initialize(hInstance, lpCmdLine)) {
return 0;

View file

@ -1,7 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="Microsoft.UI.Xaml" version="2.8.6" targetFramework="native" />
<package id="Microsoft.Web.WebView2" version="1.0.2535.41" targetFramework="native" />
<package id="Microsoft.UI.Xaml" version="2.8.7" targetFramework="native" />
<package id="Microsoft.Web.WebView2" version="1.0.3179.45" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.240405.15" targetFramework="native" />
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.240122.1" targetFramework="native" />
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.250325.1" targetFramework="native" />
</packages>

View file

@ -232,6 +232,36 @@ bool Win32Utils::WriteTextFile(const wchar_t* fileName, std::string_view text) n
return true;
}
bool Win32Utils::CreateDir(const std::wstring& path, bool recursive) noexcept {
assert(!path.empty());
if (DirExists(path.c_str())) {
return true;
}
if (!recursive) {
return CreateDirectory(path.c_str(), nullptr);
}
size_t searchOffset = 0;
do {
auto segPos = path.find_first_of(L'\\', searchOffset);
if (segPos == std::wstring::npos) {
// 没有分隔符则将整个路径视为文件夹
segPos = path.size();
}
std::wstring subdir = path.substr(0, segPos);
if (!subdir.empty() && !DirExists(subdir.c_str()) && !CreateDirectory(subdir.c_str(), nullptr)) {
return false;
}
searchOffset = segPos + 1;
} while (searchOffset < path.size());
return true;
}
const Win32Utils::OSVersion& Win32Utils::GetOSVersion() noexcept {
static OSVersion version = []() -> OSVersion {
HMODULE hNtDll = GetModuleHandle(L"ntdll.dll");

View file

@ -43,6 +43,9 @@ struct Win32Utils {
return (attrs != INVALID_FILE_ATTRIBUTES) && (attrs & FILE_ATTRIBUTE_DIRECTORY);
}
// Creates the directory `path`. Compared with wil::CreateDirectoryDeepNoThrow,
// this supports relative paths and is faster.
// When `recursive` is true, missing intermediate directories are created as
// well. Returns true if the directory already exists or was created.
static bool CreateDir(const std::wstring& path, bool recursive = false) noexcept;
struct OSVersion : Version {
constexpr OSVersion() {}
constexpr OSVersion(uint32_t major, uint32_t minor, uint32_t patch)

View file

@ -53,19 +53,19 @@
<ResourceCompile Include="TouchHelper.rc" />
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
<Manifest Include="app.manifest" />
</ItemGroup>
<ItemGroup>
<Manifest Include="app.manifest" />
<None Include="packages.config" />
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
</ImportGroup>
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
<PropertyGroup>
<ErrorText>这台计算机上缺少此项目引用的 NuGet 程序包。使用“NuGet 程序包还原”可下载这些程序包。有关更多信息,请参见 http://go.microsoft.com/fwlink/?LinkID=322105。缺少的文件是 {0}。</ErrorText>
</PropertyGroup>
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
</Target>
</Project>

View file

@ -20,13 +20,13 @@
<ClCompile Include="main.cpp" />
<ClCompile Include="pch.cpp" />
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>
<ItemGroup>
<Natvis Include="$(MSBuildThisFileDirectory)..\..\natvis\wil.natvis" />
</ItemGroup>
<ItemGroup>
<Manifest Include="app.manifest" />
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>
</Project>

View file

@ -1,4 +1,4 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.240122.1" targetFramework="native" />
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.250325.1" targetFramework="native" />
</packages>

View file

@ -67,12 +67,12 @@
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
</ImportGroup>
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
<PropertyGroup>
<ErrorText>这台计算机上缺少此项目引用的 NuGet 程序包。使用“NuGet 程序包还原”可下载这些程序包。有关更多信息,请参见 http://go.microsoft.com/fwlink/?LinkID=322105。缺少的文件是 {0}。</ErrorText>
</PropertyGroup>
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
</Target>
</Project>

View file

@ -33,10 +33,10 @@
<ItemGroup>
<Manifest Include="app.manifest" />
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>
<ItemGroup>
<Natvis Include="$(MSBuildThisFileDirectory)..\..\natvis\wil.natvis" />
</ItemGroup>
<ItemGroup>
<None Include="packages.config" />
</ItemGroup>
</Project>

View file

@ -1,4 +1,4 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.240122.1" targetFramework="native" />
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.250325.1" targetFramework="native" />
</packages>