mirror of
https://github.com/Blinue/Magpie.git
synced 2026-06-24 02:04:10 +00:00
Compare commits
68 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c42ce9adeb | ||
|
|
d894579bd7 | ||
|
|
6f084e9fb1 | ||
|
|
11062b11b4 | ||
|
|
c91cd6f587 | ||
|
|
83d312aacd | ||
|
|
b18d2aae54 | ||
|
|
4b0bc51504 | ||
|
|
523a24b2b1 | ||
|
|
2bcbc3e399 | ||
|
|
91859d2d1c | ||
|
|
d1e60d7e41 | ||
|
|
8e37bc17d0 | ||
|
|
768f13d31a | ||
|
|
7de2f4bf32 | ||
|
|
1f2693ff4e | ||
|
|
b56809634b | ||
|
|
eace2e87b0 | ||
|
|
6b4a92cc6d | ||
|
|
08b07e155c | ||
|
|
6ede0212cd | ||
|
|
f8dc1ff04d | ||
|
|
ca99356fbf | ||
|
|
a5726c7506 | ||
|
|
69416aff3d | ||
|
|
06ca4e0be6 | ||
|
|
a6cc7fa67a | ||
|
|
cc176f72f2 | ||
|
|
a1019bba34 | ||
|
|
4bcd77be76 | ||
|
|
544ab2a0bf | ||
|
|
8270f3a24c | ||
|
|
6f9b8b358f | ||
|
|
048a3494a3 | ||
|
|
d87c47c790 | ||
|
|
d36bf89b65 | ||
|
|
a8325eccfa | ||
|
|
005498b029 | ||
|
|
a4af854ed3 | ||
|
|
9d26a4c795 | ||
|
|
0700810727 | ||
|
|
d0dc556239 | ||
|
|
587fdd5cc6 | ||
|
|
786ff2fc22 | ||
|
|
98f8649a27 | ||
|
|
08e20fe1a9 | ||
|
|
feb52a2ca9 | ||
|
|
a5f9e4ecb6 | ||
|
|
6685a1df01 | ||
|
|
972d0b057a | ||
|
|
fb7c840ca1 | ||
|
|
b1036cd9f2 | ||
|
|
657209dd39 | ||
|
|
427c7a6973 | ||
|
|
45a3178a10 | ||
|
|
6a94c860fd | ||
|
|
e8cad13732 | ||
|
|
95e04aed1d | ||
|
|
aee12b750b | ||
|
|
34c2123b36 | ||
|
|
809ab1aac1 | ||
|
|
4c0bd3131f | ||
|
|
b658d536e6 | ||
|
|
30a43bc919 | ||
|
|
1bfabb45e3 | ||
|
|
4512d3e399 | ||
|
|
94d94f9508 | ||
|
|
4246f9841a |
59 changed files with 2069 additions and 128 deletions
|
|
@ -29,7 +29,7 @@
|
|||
<PrecompiledHeader>Use</PrecompiledHeader>
|
||||
<PrecompiledHeaderFile>pch.h</PrecompiledHeaderFile>
|
||||
<PrecompiledHeaderOutputFile>$(IntDir)pch.pch</PrecompiledHeaderOutputFile>
|
||||
<PreprocessorDefinitions>_WINDOWS;WIN32_LEAN_AND_MEAN;WINRT_LEAN_AND_MEAN;WINRT_NO_MODULE_LOCK;WIL_SUPPRESS_EXCEPTIONS;NOGDICAPMASKS;NOICONS;NOATOM;NOCLIPBOARD;NODRAWTEXT;NOMEMMGR;NOMETAFILE;NOMINMAX;NOOPENFILE;NOSCROLL;NOSERVICE;NOSOUND;NOTEXTMETRIC;NOCOMM;NOKANJI;NOHELP;NOPROFILER;NODEFERWINDOWPOS;NOMCX;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<PreprocessorDefinitions>_WINDOWS;WIN32_LEAN_AND_MEAN;WINRT_LEAN_AND_MEAN;WINRT_NO_MODULE_LOCK;WIL_SUPPRESS_EXCEPTIONS;WIL_USE_STL=1;NOGDICAPMASKS;NOICONS;NOATOM;NOCLIPBOARD;NODRAWTEXT;NOMEMMGR;NOMETAFILE;NOMINMAX;NOOPENFILE;NOSCROLL;NOSERVICE;NOSOUND;NOTEXTMETRIC;NOCOMM;NOKANJI;NOHELP;NOPROFILER;NODEFERWINDOWPOS;NOMCX;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<PreprocessorDefinitions Condition="'$(CommitId)'!=''">MAGPIE_COMMIT_ID=$(CommitId);%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<PreprocessorDefinitions Condition="'$(MajorVersion)'!='' And '$(MinorVersion)'!='' And '$(PatchVersion)'!='' And '$(VersionTag)'!=''">MAGPIE_VERSION_MAJOR=$(MajorVersion);MAGPIE_VERSION_MINOR=$(MinorVersion);MAGPIE_VERSION_PATCH=$(PatchVersion);MAGPIE_VERSION_TAG=$(VersionTag);%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
<AdditionalOptions>/bigobj %(AdditionalOptions)</AdditionalOptions>
|
||||
|
|
@ -38,7 +38,7 @@
|
|||
<GenerateDebugInformation>true</GenerateDebugInformation>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
|
||||
|
||||
<ItemDefinitionGroup Condition="'$(Configuration)'=='Debug'">
|
||||
<ClCompile>
|
||||
<PreprocessorDefinitions>_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
|
||||
|
|
@ -59,7 +59,37 @@
|
|||
<OptimizeReferences>true</OptimizeReferences>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
|
||||
|
||||
<!-- HybridCRT -->
|
||||
<Import Project="$(MSBuildThisFileDirectory)HybridCRT.props" />
|
||||
|
||||
<!-- onnxruntime -->
|
||||
<ItemDefinitionGroup>
|
||||
<ClCompile>
|
||||
<AdditionalIncludeDirectories>$(SolutionDir)obj\onnxruntime\include;$(SolutionDir)obj\onnxruntime\include\onnxruntime\core\session;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
|
||||
</ClCompile>
|
||||
<Link>
|
||||
<AdditionalLibraryDirectories>$(SolutionDir)obj\onnxruntime\lib\$(Platform);%(AdditionalLibraryDirectories)</AdditionalLibraryDirectories>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<None Include="$(SolutionDir)obj\onnxruntime\bin\$(Platform)\DirectML.Debug.dll"
|
||||
Condition="'$(Configuration)'=='Debug' And Exists('$(SolutionDir)obj\onnxruntime\bin\$(Platform)\DirectML.Debug.dll')">
|
||||
<Link>third_party\DirectML.Debug.dll</Link>
|
||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
|
||||
<Visible>false</Visible>
|
||||
</None>
|
||||
<None Include="$(SolutionDir)obj\onnxruntime\bin\$(Platform)\DirectML.dll"
|
||||
Condition="Exists('$(SolutionDir)obj\onnxruntime\bin\$(Platform)\DirectML.dll')">
|
||||
<Link>third_party\DirectML.dll</Link>
|
||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
|
||||
<Visible>false</Visible>
|
||||
</None>
|
||||
<None Include="$(SolutionDir)obj\onnxruntime\bin\$(Platform)\onnxruntime.dll"
|
||||
Condition="Exists('$(SolutionDir)obj\onnxruntime\bin\$(Platform)\onnxruntime.dll')">
|
||||
<Link>third_party\onnxruntime.dll</Link>
|
||||
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
|
||||
<Visible>false</Visible>
|
||||
</None>
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
|
|
|||
|
|
@ -117,19 +117,19 @@
|
|||
</local:SimpleStackPanel>
|
||||
</local:SimpleStackPanel>
|
||||
<local:SettingsGroup x:Uid="About_Version_UpdateSettings">
|
||||
<local:SettingsCard x:Uid="About_Version_UpdateSettings_AutoCheckForUpdates">
|
||||
<local:SettingsCard x:Uid="About_Version_UpdateSettings_AutoCheckForUpdates"
|
||||
IsEnabled="False">
|
||||
<local:SettingsCard.HeaderIcon>
|
||||
<FontIcon Glyph="" />
|
||||
</local:SettingsCard.HeaderIcon>
|
||||
<ToggleSwitch x:Uid="ToggleSwitch"
|
||||
IsOn="{x:Bind ViewModel.IsAutoCheckForUpdates, Mode=TwoWay}" />
|
||||
<ToggleSwitch x:Uid="ToggleSwitch" />
|
||||
</local:SettingsCard>
|
||||
<local:SettingsCard x:Uid="About_Version_UpdateSettings_CheckForPreviewUpdates">
|
||||
<local:SettingsCard x:Uid="About_Version_UpdateSettings_CheckForPreviewUpdates"
|
||||
IsEnabled="False">
|
||||
<local:SettingsCard.HeaderIcon>
|
||||
<FontIcon Glyph="" />
|
||||
</local:SettingsCard.HeaderIcon>
|
||||
<ToggleSwitch x:Uid="ToggleSwitch"
|
||||
IsOn="{x:Bind ViewModel.IsCheckForPreviewUpdates, Mode=TwoWay}" />
|
||||
<ToggleSwitch x:Uid="ToggleSwitch" />
|
||||
</local:SettingsCard>
|
||||
</local:SettingsGroup>
|
||||
<local:SettingsGroup x:Uid="About_Feedback">
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ hstring AboutViewModel::Version() const noexcept {
|
|||
L" ",
|
||||
WIDEN(STRING(MAGPIE_VERSION_TAG)) + 1,
|
||||
#else
|
||||
L" dev",
|
||||
L" onnx-preview2",
|
||||
#endif
|
||||
#ifdef MAGPIE_COMMIT_ID
|
||||
L" | ",
|
||||
|
|
|
|||
|
|
@ -394,6 +394,7 @@ void AppSettings::IsDeveloperMode(bool value) noexcept {
|
|||
if (!value) {
|
||||
// 关闭开发者模式则禁用所有开发者选项
|
||||
_isDebugMode = false;
|
||||
_isBenchmarkMode = false;
|
||||
_isEffectCacheDisabled = false;
|
||||
_isFontCacheDisabled = false;
|
||||
_isSaveEffectSources = false;
|
||||
|
|
@ -458,9 +459,8 @@ void AppSettings::_UpdateWindowPlacement() noexcept {
|
|||
}
|
||||
|
||||
bool AppSettings::_Save(const _AppSettingsData& data) noexcept {
|
||||
HRESULT hr = wil::CreateDirectoryDeepNoThrow(data._configDir.c_str());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("创建配置文件夹失败", hr);
|
||||
if (!Win32Utils::CreateDir(data._configDir, true)) {
|
||||
Logger::Get().Win32Error("创建配置文件夹失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -509,6 +509,8 @@ bool AppSettings::_Save(const _AppSettingsData& data) noexcept {
|
|||
writer.Bool(data._isDeveloperMode);
|
||||
writer.Key("debugMode");
|
||||
writer.Bool(data._isDebugMode);
|
||||
writer.Key("benchmarkMode");
|
||||
writer.Bool(data._isBenchmarkMode);
|
||||
writer.Key("disableEffectCache");
|
||||
writer.Bool(data._isEffectCacheDisabled);
|
||||
writer.Key("disableFontCache");
|
||||
|
|
@ -666,6 +668,7 @@ void AppSettings::_LoadSettings(const rapidjson::GenericObject<true, rapidjson::
|
|||
}
|
||||
JsonHelper::ReadBool(root, "developerMode", _isDeveloperMode);
|
||||
JsonHelper::ReadBool(root, "debugMode", _isDebugMode);
|
||||
JsonHelper::ReadBool(root, "benchmarkMode", _isBenchmarkMode);
|
||||
JsonHelper::ReadBool(root, "disableEffectCache", _isEffectCacheDisabled);
|
||||
JsonHelper::ReadBool(root, "disableFontCache", _isFontCacheDisabled);
|
||||
JsonHelper::ReadBool(root, "saveEffectSources", _isSaveEffectSources);
|
||||
|
|
@ -1039,9 +1042,8 @@ bool AppSettings::_UpdateConfigPath(std::wstring* existingConfigPath) noexcept {
|
|||
}
|
||||
|
||||
// 确保配置文件夹存在
|
||||
HRESULT hr = wil::CreateDirectoryDeepNoThrow(_configDir.c_str());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("创建配置文件夹失败", hr);
|
||||
if (!Win32Utils::CreateDir(_configDir, true)) {
|
||||
Logger::Get().Win32Error("创建配置文件夹失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -55,6 +55,7 @@ struct _AppSettingsData {
|
|||
bool _isAlwaysRunAsAdmin = false;
|
||||
bool _isDeveloperMode = false;
|
||||
bool _isDebugMode = false;
|
||||
bool _isBenchmarkMode = false;
|
||||
bool _isEffectCacheDisabled = false;
|
||||
bool _isFontCacheDisabled = false;
|
||||
bool _isSaveEffectSources = false;
|
||||
|
|
@ -151,6 +152,15 @@ public:
|
|||
SaveAsync();
|
||||
}
|
||||
|
||||
bool IsBenchmarkMode() const noexcept {
|
||||
return _isBenchmarkMode;
|
||||
}
|
||||
|
||||
void IsBenchmarkMode(bool value) noexcept {
|
||||
_isBenchmarkMode = value;
|
||||
SaveAsync();
|
||||
}
|
||||
|
||||
bool IsEffectCacheDisabled() const noexcept {
|
||||
return _isEffectCacheDisabled;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -194,6 +194,10 @@
|
|||
<CheckBox x:Uid="Home_Advanced_DeveloperOptions_DebugMode"
|
||||
IsChecked="{x:Bind ViewModel.IsDebugMode, Mode=TwoWay}" />
|
||||
</local:SettingsCard>
|
||||
<local:SettingsCard ContentAlignment="Left">
|
||||
<CheckBox x:Uid="Home_Advanced_DeveloperOptions_BenchmarkMode"
|
||||
IsChecked="{x:Bind ViewModel.IsBenchmarkMode, Mode=TwoWay}" />
|
||||
</local:SettingsCard>
|
||||
<local:SettingsCard ContentAlignment="Left">
|
||||
<CheckBox x:Uid="Home_Advanced_DeveloperOptions_DisableEffectCache"
|
||||
IsChecked="{x:Bind ViewModel.IsEffectCacheDisabled, Mode=TwoWay}" />
|
||||
|
|
|
|||
|
|
@ -298,6 +298,21 @@ void HomeViewModel::IsDebugMode(bool value) {
|
|||
RaisePropertyChanged(L"IsDebugMode");
|
||||
}
|
||||
|
||||
bool HomeViewModel::IsBenchmarkMode() const noexcept {
|
||||
return AppSettings::Get().IsBenchmarkMode();
|
||||
}
|
||||
|
||||
void HomeViewModel::IsBenchmarkMode(bool value) {
|
||||
AppSettings& settings = AppSettings::Get();
|
||||
|
||||
if (settings.IsBenchmarkMode() == value) {
|
||||
return;
|
||||
}
|
||||
|
||||
settings.IsBenchmarkMode(value);
|
||||
RaisePropertyChanged(L"IsBenchmarkMode");
|
||||
}
|
||||
|
||||
bool HomeViewModel::IsEffectCacheDisabled() const noexcept {
|
||||
return AppSettings::Get().IsEffectCacheDisabled();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -72,6 +72,9 @@ struct HomeViewModel : HomeViewModelT<HomeViewModel>, wil::notify_property_chang
|
|||
bool IsDebugMode() const noexcept;
|
||||
void IsDebugMode(bool value);
|
||||
|
||||
bool IsBenchmarkMode() const noexcept;
|
||||
void IsBenchmarkMode(bool value);
|
||||
|
||||
bool IsEffectCacheDisabled() const noexcept;
|
||||
void IsEffectCacheDisabled(bool value);
|
||||
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ namespace Magpie.App {
|
|||
|
||||
Boolean IsDeveloperMode;
|
||||
Boolean IsDebugMode;
|
||||
Boolean IsBenchmarkMode;
|
||||
Boolean IsEffectCacheDisabled;
|
||||
Boolean IsFontCacheDisabled;
|
||||
Boolean IsSaveEffectSources;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
|
||||
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props')" />
|
||||
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props')" />
|
||||
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props')" />
|
||||
<PropertyGroup Label="Globals">
|
||||
<CppWinRTGenerateWindowsMetadata>true</CppWinRTGenerateWindowsMetadata>
|
||||
<MinimalCoreWin>true</MinimalCoreWin>
|
||||
|
|
@ -59,9 +59,11 @@
|
|||
<Link>
|
||||
<GenerateWindowsMetadata>false</GenerateWindowsMetadata>
|
||||
<SubSystem>Console</SubSystem>
|
||||
<AdditionalDependencies>kernel32.lib;ole32.lib;oleaut32.lib;user32.lib;gdi32.lib;$(OutDir).\Magpie.Core.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies>kernel32.lib;ole32.lib;oleaut32.lib;user32.lib;gdi32.lib;onnxruntime.lib;directml.lib;$(OutDir).\Magpie.Core.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<AdditionalDependencies Condition="'$(Platform)'=='x64'">cudart.lib;%(AdditionalDependencies)</AdditionalDependencies>
|
||||
<ModuleDefinitionFile>Magpie.App.def</ModuleDefinitionFile>
|
||||
<DelayLoadDLLs>d3dcompiler_47.dll;Magnification.dll;%(DelayLoadDLLs)</DelayLoadDLLs>
|
||||
<DelayLoadDLLs>d3d12.dll;DirectML.dll;d3dcompiler_47.dll;Magnification.dll;%(DelayLoadDLLs)</DelayLoadDLLs>
|
||||
<DelayLoadDLLs Condition="'$(Platform)'=='x64'">cudart64_12.dll;%(DelayLoadDLLs)</DelayLoadDLLs>
|
||||
</Link>
|
||||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
|
|
@ -719,20 +721,20 @@ File.Delete("priconfig.xml");
|
|||
</ItemGroup>
|
||||
</Target>
|
||||
<ImportGroup Label="ExtensionTargets">
|
||||
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets" Condition="Exists('..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets" Condition="Exists('..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
|
||||
</ImportGroup>
|
||||
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
|
||||
<PropertyGroup>
|
||||
<ErrorText>这台计算机上缺少此项目引用的 NuGet 程序包。使用“NuGet 程序包还原”可下载这些程序包。有关更多信息,请参见 http://go.microsoft.com/fwlink/?LinkID=322105。缺少的文件是 {0}。</ErrorText>
|
||||
</PropertyGroup>
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
|
||||
</Target>
|
||||
</Project>
|
||||
|
|
@ -838,4 +838,7 @@
|
|||
<data name="Home_Advanced_SimulateExclusiveFullscreen_InfoBar.Title" xml:space="preserve">
|
||||
<value>This option is not compatible with some older games. Please use it with caution.</value>
|
||||
</data>
|
||||
<data name="Home_Advanced_DeveloperOptions_BenchmarkMode.Content" xml:space="preserve">
|
||||
<value>Benchmark mode</value>
|
||||
</data>
|
||||
</root>
|
||||
|
|
@ -838,4 +838,7 @@
|
|||
<data name="Home_Advanced_SimulateExclusiveFullscreen_InfoBar.Title" xml:space="preserve">
|
||||
<value>此选项和一些旧游戏不兼容,请谨慎使用。</value>
|
||||
</data>
|
||||
<data name="Home_Advanced_DeveloperOptions_BenchmarkMode.Content" xml:space="preserve">
|
||||
<value>性能测试模式</value>
|
||||
</data>
|
||||
</root>
|
||||
|
|
@ -304,6 +304,7 @@ bool ScalingService::_StartScale(HWND hWnd, const Profile& profile) {
|
|||
}
|
||||
|
||||
options.IsDebugMode(settings.IsDebugMode());
|
||||
options.IsBenchmarkMode(settings.IsBenchmarkMode());
|
||||
options.IsEffectCacheDisabled(settings.IsEffectCacheDisabled());
|
||||
options.IsFontCacheDisabled(settings.IsFontCacheDisabled());
|
||||
options.IsSaveEffectSources(settings.IsSaveEffectSources());
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
[requires]
|
||||
fmt/10.2.1
|
||||
spdlog/1.14.1
|
||||
parallel-hashmap/1.37
|
||||
fmt/11.1.3
|
||||
spdlog/1.15.1
|
||||
parallel-hashmap/2.0.0
|
||||
rapidjson/cci.20230929
|
||||
kuba-zip/0.3.2
|
||||
muparser/2.3.4
|
||||
muparser/2.3.5
|
||||
yas/7.1.0
|
||||
imgui/1.90.8
|
||||
imgui/1.90.9
|
||||
|
||||
[generators]
|
||||
MSBuildDeps
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<packages>
|
||||
<package id="Microsoft.UI.Xaml" version="2.8.6" targetFramework="native" />
|
||||
<package id="Microsoft.Web.WebView2" version="1.0.2535.41" targetFramework="native" />
|
||||
<package id="Microsoft.UI.Xaml" version="2.8.7" targetFramework="native" />
|
||||
<package id="Microsoft.Web.WebView2" version="1.0.3179.45" targetFramework="native" />
|
||||
<package id="Microsoft.Windows.CppWinRT" version="2.0.240405.15" targetFramework="native" />
|
||||
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.240122.1" targetFramework="native" />
|
||||
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.250325.1" targetFramework="native" />
|
||||
</packages>
|
||||
|
|
@ -95,7 +95,9 @@ bool DesktopDuplicationFrameSource::_Initialize() noexcept {
|
|||
DXGI_FORMAT_B8G8R8A8_UNORM,
|
||||
_srcRect.right - _srcRect.left,
|
||||
_srcRect.bottom - _srcRect.top,
|
||||
D3D11_BIND_SHADER_RESOURCE
|
||||
D3D11_BIND_SHADER_RESOURCE,
|
||||
D3D11_USAGE_DEFAULT,
|
||||
D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
|
||||
);
|
||||
if (!_output) {
|
||||
Logger::Get().Error("CreateTexture2D 失败");
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ bool DeviceResources::_ObtainAdapterAndDevice(int adapterIdx) noexcept {
|
|||
if (desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) {
|
||||
Logger::Get().Warn("用户指定的显示卡为 WARP,已忽略");
|
||||
} else if (_TryCreateD3DDevice(adapter)) {
|
||||
_adapterIdx = adapterIdx;
|
||||
return true;
|
||||
} else {
|
||||
Logger::Get().Warn("用户指定的显示卡不支持 FL 11");
|
||||
|
|
@ -105,21 +106,31 @@ bool DeviceResources::_ObtainAdapterAndDevice(int adapterIdx) noexcept {
|
|||
}
|
||||
|
||||
if (_TryCreateD3DDevice(adapter)) {
|
||||
_adapterIdx = adapterIndex;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// 作为最后手段,回落到 Basic Render Driver Adapter(WARP)
|
||||
// https://docs.microsoft.com/en-us/windows/win32/direct3darticles/directx-warp
|
||||
HRESULT hr = _dxgiFactory->EnumWarpAdapter(IID_PPV_ARGS(&adapter));
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("EnumWarpAdapter 失败", hr);
|
||||
return false;
|
||||
}
|
||||
for (UINT adapterIndex = 0;
|
||||
SUCCEEDED(_dxgiFactory->EnumAdapters1(adapterIndex, adapter.put()));
|
||||
++adapterIndex
|
||||
) {
|
||||
DXGI_ADAPTER_DESC1 desc;
|
||||
HRESULT hr = adapter->GetDesc1(&desc);
|
||||
if (FAILED(hr)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!_TryCreateD3DDevice(adapter)) {
|
||||
Logger::Get().ComError("创建 WARP 设备失败", hr);
|
||||
return false;
|
||||
if ((desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (_TryCreateD3DDevice(adapter)) {
|
||||
_adapterIdx = adapterIndex;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ public:
|
|||
ID3D11Device5* GetD3DDevice() const noexcept { return _d3dDevice.get(); }
|
||||
ID3D11DeviceContext4* GetD3DDC() const noexcept { return _d3dDC.get(); }
|
||||
IDXGIAdapter4* GetGraphicsAdapter() const noexcept { return _graphicsAdapter.get(); }
|
||||
uint32_t GetAdapterIndex() const noexcept { return _adapterIdx; }
|
||||
|
||||
bool IsSupportTearing() const noexcept {
|
||||
return _isSupportTearing;
|
||||
|
|
@ -28,6 +29,7 @@ private:
|
|||
|
||||
winrt::com_ptr<IDXGIFactory7> _dxgiFactory;
|
||||
winrt::com_ptr<IDXGIAdapter4> _graphicsAdapter;
|
||||
uint32_t _adapterIdx = 0;
|
||||
winrt::com_ptr<ID3D11Device5> _d3dDevice;
|
||||
winrt::com_ptr<ID3D11DeviceContext4> _d3dDC;
|
||||
|
||||
|
|
|
|||
597
src/Magpie.Core/DirectMLInferenceBackend.cpp
Normal file
597
src/Magpie.Core/DirectMLInferenceBackend.cpp
Normal file
|
|
@ -0,0 +1,597 @@
|
|||
#include "pch.h"
|
||||
#include "DirectMLInferenceBackend.h"
|
||||
#include "DeviceResources.h"
|
||||
#include "DirectXHelper.h"
|
||||
#include "shaders/TensorToTextureCS.h"
|
||||
#include "shaders/TextureToTensorCS.h"
|
||||
#include "Logger.h"
|
||||
#include <onnxruntime/core/providers/dml/dml_provider_factory.h>
|
||||
#include "Win32Utils.h"
|
||||
|
||||
namespace Magpie::Core {
|
||||
|
||||
static winrt::com_ptr<ID3D12Device> CreateD3D12Device(IDXGIAdapter4* adapter) noexcept {
|
||||
#ifdef _DEBUG
|
||||
// 启用 D3D12 调试层
|
||||
{
|
||||
winrt::com_ptr<ID3D12Debug> debugController;
|
||||
HRESULT hr = D3D12GetDebugInterface(IID_PPV_ARGS(&debugController));
|
||||
if (SUCCEEDED(hr)) {
|
||||
debugController->EnableDebugLayer();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
winrt::com_ptr<ID3D12Device> d3d12Device;
|
||||
HRESULT hr = D3D12CreateDevice(
|
||||
adapter,
|
||||
D3D_FEATURE_LEVEL_11_0,
|
||||
IID_PPV_ARGS(&d3d12Device)
|
||||
);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("D3D12CreateDevice 失败", hr);
|
||||
return d3d12Device;
|
||||
}
|
||||
|
||||
return d3d12Device;
|
||||
}
|
||||
|
||||
static winrt::com_ptr<IDMLDevice> CreateDMLDevice(ID3D12Device* d3d12Device) noexcept {
|
||||
winrt::com_ptr<IDMLDevice> dmlDevice;
|
||||
HRESULT hr = DMLCreateDevice1(
|
||||
d3d12Device,
|
||||
#ifdef _DEBUG
|
||||
DML_CREATE_DEVICE_FLAG_DEBUG,
|
||||
#else
|
||||
DML_CREATE_DEVICE_FLAG_NONE,
|
||||
#endif
|
||||
// https://github.com/microsoft/onnxruntime/blob/554fb4ad1fcf808304d4758d73d93a8ecc362bf6/onnxruntime/core/providers/dml/dml_provider_factory.cc#L519
|
||||
DML_FEATURE_LEVEL_5_0,
|
||||
IID_PPV_ARGS(&dmlDevice)
|
||||
);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("DMLCreateDevice1 失败", hr);
|
||||
return dmlDevice;
|
||||
}
|
||||
|
||||
return dmlDevice;
|
||||
}
|
||||
|
||||
static winrt::com_ptr<ID3D12Resource> ShareTextureWithD3D12(ID3D11Texture2D* texture, ID3D12Device* d3d12Device, DWORD access) noexcept {
|
||||
winrt::com_ptr<ID3D12Resource> result;
|
||||
|
||||
winrt::com_ptr<IDXGIResource1> dxgiResource;
|
||||
HRESULT hr = texture->QueryInterface<IDXGIResource1>(dxgiResource.put());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("获取 IDXGIResource1 失败", hr);
|
||||
return result;
|
||||
}
|
||||
|
||||
wil::unique_handle sharedHandle;
|
||||
hr = dxgiResource->CreateSharedHandle(nullptr, access, nullptr, sharedHandle.put());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateSharedHandle 失败", hr);
|
||||
return result;
|
||||
}
|
||||
|
||||
hr = d3d12Device->OpenSharedHandle(sharedHandle.get(), IID_PPV_ARGS(&result));
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("OpenSharedHandle 失败", hr);
|
||||
return result;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static winrt::com_ptr<IUnknown> AllocateD3D12Resource(const OrtDmlApi* ortDmlApi, ID3D12Resource* buffer) {
|
||||
void* dmlResource;
|
||||
Ort::ThrowOnError(ortDmlApi->CreateGPUAllocationFromD3DResource(buffer, &dmlResource));
|
||||
|
||||
winrt::com_ptr<IUnknown> allocatedBuffer;
|
||||
allocatedBuffer.copy_from((IUnknown*)dmlResource);
|
||||
|
||||
Ort::ThrowOnError(ortDmlApi->FreeGPUAllocation(dmlResource));
|
||||
|
||||
return allocatedBuffer;
|
||||
}
|
||||
|
||||
bool DirectMLInferenceBackend::Initialize(
|
||||
const wchar_t* modelPath,
|
||||
uint32_t scale,
|
||||
DeviceResources& deviceResources,
|
||||
BackendDescriptorStore& /*descriptorStore*/,
|
||||
ID3D11Texture2D* input,
|
||||
ID3D11Texture2D** output
|
||||
) noexcept {
|
||||
ID3D11Device5* d3d11Device = deviceResources.GetD3DDevice();
|
||||
_d3d11DC = deviceResources.GetD3DDC();
|
||||
|
||||
const SIZE inputSize = DirectXHelper::GetTextureSize(input);
|
||||
const SIZE outputSize{ inputSize.cx * (LONG)scale, inputSize.cy * (LONG)scale };
|
||||
|
||||
// 创建输出纹理
|
||||
_outputTex = DirectXHelper::CreateTexture2D(
|
||||
d3d11Device,
|
||||
DXGI_FORMAT_R8G8B8A8_UNORM,
|
||||
outputSize.cx,
|
||||
outputSize.cy,
|
||||
D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_UNORDERED_ACCESS,
|
||||
D3D11_USAGE_DEFAULT,
|
||||
D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
|
||||
);
|
||||
if (!_outputTex) {
|
||||
Logger::Get().Error("创建输出纹理失败");
|
||||
return false;
|
||||
}
|
||||
*output = _outputTex.get();
|
||||
|
||||
const uint32_t inputElemCount = uint32_t(inputSize.cx * inputSize.cy * 3);
|
||||
const uint32_t outputElemCount = uint32_t(outputSize.cx * outputSize.cy * 3);
|
||||
|
||||
winrt::com_ptr<ID3D12Device> d3d12Device = CreateD3D12Device(deviceResources.GetGraphicsAdapter());
|
||||
if (!d3d12Device) {
|
||||
Logger::Get().Error("CreateD3D12Device 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
D3D12_COMMAND_QUEUE_DESC desc{
|
||||
.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE,
|
||||
.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT
|
||||
};
|
||||
|
||||
HRESULT hr = d3d12Device->CreateCommandQueue(&desc, IID_PPV_ARGS(&_commandQueue));
|
||||
if (FAILED(hr)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool isFP16Data = false;
|
||||
|
||||
try {
|
||||
const OrtApi& ortApi = Ort::GetApi();
|
||||
|
||||
_env = Ort::Env(ORT_LOGGING_LEVEL_INFO, "", _OrtLog, nullptr);
|
||||
|
||||
const OrtDmlApi* ortDmlApi = nullptr;
|
||||
ortApi.GetExecutionProviderApi("DML", ORT_API_VERSION, (const void**)&ortDmlApi);
|
||||
|
||||
Ort::SessionOptions sessionOptions;
|
||||
sessionOptions.SetIntraOpNumThreads(1);
|
||||
sessionOptions.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
|
||||
sessionOptions.DisableMemPattern();
|
||||
|
||||
Ort::ThrowOnError(ortApi.AddFreeDimensionOverride(sessionOptions, "DATA_BATCH", 1));
|
||||
|
||||
winrt::com_ptr<IDMLDevice> dmlDevice = CreateDMLDevice(d3d12Device.get());
|
||||
if (!dmlDevice) {
|
||||
Logger::Get().Error("CreateDMLDevice 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
Ort::ThrowOnError(ortDmlApi->SessionOptionsAppendExecutionProvider_DML1(
|
||||
sessionOptions, dmlDevice.get(), _commandQueue.get()));
|
||||
|
||||
_session = Ort::Session(_env, modelPath, sessionOptions);
|
||||
|
||||
if (!_IsModelValid(_session, isFP16Data)) {
|
||||
Logger::Get().Error("不支持此模型");
|
||||
return false;
|
||||
}
|
||||
|
||||
// 创建张量缓冲区
|
||||
{
|
||||
D3D12_HEAP_PROPERTIES heapDesc{
|
||||
.Type = D3D12_HEAP_TYPE_DEFAULT
|
||||
};
|
||||
D3D12_RESOURCE_DESC resDesc{
|
||||
.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER,
|
||||
.Width = (inputElemCount * (isFP16Data ? 2 : 4) + 3) & ~3,
|
||||
.Height = 1,
|
||||
.DepthOrArraySize = 1,
|
||||
.MipLevels = 1,
|
||||
.SampleDesc{
|
||||
.Count = 1
|
||||
},
|
||||
.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR,
|
||||
.Flags = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS
|
||||
};
|
||||
HRESULT hr = d3d12Device->CreateCommittedResource(
|
||||
&heapDesc,
|
||||
D3D12_HEAP_FLAG_CREATE_NOT_ZEROED,
|
||||
&resDesc,
|
||||
D3D12_RESOURCE_STATE_COMMON,
|
||||
nullptr,
|
||||
IID_PPV_ARGS(&_inputBuffer)
|
||||
);
|
||||
if (FAILED(hr)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
resDesc.Width = UINT64((outputElemCount * (isFP16Data ? 2 : 4) + 3) & ~3);
|
||||
hr = d3d12Device->CreateCommittedResource(
|
||||
&heapDesc,
|
||||
D3D12_HEAP_FLAG_CREATE_NOT_ZEROED,
|
||||
&resDesc,
|
||||
D3D12_RESOURCE_STATE_COMMON,
|
||||
nullptr,
|
||||
IID_PPV_ARGS(&_outputBuffer)
|
||||
);
|
||||
if (FAILED(hr)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// 创建 IOBinding
|
||||
_ioBinding = Ort::IoBinding(_session);
|
||||
|
||||
// DmlExecutionProvider 的 device_id 始终为 0,传其他值会出错。
|
||||
// 见 https://github.com/microsoft/onnxruntime/blob/89f8206ba4f1c22c39e0297fb55272e8ce8cd7d0/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp#L77
|
||||
// WinML 也始终使用 0: https://github.com/microsoft/onnxruntime/blob/89f8206ba4f1c22c39e0297fb55272e8ce8cd7d0/winml/lib/Api.Ort/OnnxruntimeEngine.cpp#L654
|
||||
Ort::MemoryInfo memoryInfo(
|
||||
"DML",
|
||||
OrtAllocatorType::OrtDeviceAllocator,
|
||||
0,
|
||||
OrtMemType::OrtMemTypeDefault
|
||||
);
|
||||
|
||||
const ONNXTensorElementDataType dataType =
|
||||
isFP16Data ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
|
||||
|
||||
const int64_t inputShape[]{ 1,3,inputSize.cy,inputSize.cx };
|
||||
_allocatedInput = AllocateD3D12Resource(ortDmlApi, _inputBuffer.get());
|
||||
_ioBinding.BindInput("input", Ort::Value::CreateTensor(
|
||||
memoryInfo,
|
||||
_allocatedInput.get(),
|
||||
size_t(inputElemCount * (isFP16Data ? 2 : 4)),
|
||||
inputShape,
|
||||
std::size(inputShape),
|
||||
dataType
|
||||
));
|
||||
|
||||
const int64_t outputShape[]{ 1,3,outputSize.cy,outputSize.cx };
|
||||
_allocatedOutput = AllocateD3D12Resource(ortDmlApi, _outputBuffer.get());
|
||||
_ioBinding.BindOutput("output", Ort::Value::CreateTensor(
|
||||
memoryInfo,
|
||||
_allocatedOutput.get(),
|
||||
size_t(outputElemCount * (isFP16Data ? 2 : 4)),
|
||||
outputShape,
|
||||
std::size(outputShape),
|
||||
dataType
|
||||
));
|
||||
} catch (const Ort::Exception& e) {
|
||||
Logger::Get().Error(e.what());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!_CreateFence(d3d11Device, d3d12Device.get())) {
|
||||
Logger::Get().Error("_CreateFence 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
_d3d12InputTex = ShareTextureWithD3D12(input, d3d12Device.get(), DXGI_SHARED_RESOURCE_READ);
|
||||
_d3d12OutputTex = ShareTextureWithD3D12(_outputTex.get(), d3d12Device.get(),
|
||||
DXGI_SHARED_RESOURCE_READ | DXGI_SHARED_RESOURCE_WRITE);
|
||||
if (!_d3d12InputTex || !_d3d12OutputTex) {
|
||||
Logger::Get().Error("ShareTextureWithD3D12 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
UINT descriptorSize;
|
||||
if (!_CreateCBVHeap(d3d12Device.get(), inputElemCount, outputElemCount, isFP16Data, descriptorSize)) {
|
||||
Logger::Get().Error("_CreateCBVHeap 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!_CreatePipelineStates(d3d12Device.get())) {
|
||||
Logger::Get().Error("_CreatePipelineStates 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!_CalcCommandLists(d3d12Device.get(), inputSize, outputSize, descriptorSize)) {
|
||||
Logger::Get().Error("_CalcCommandLists 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void DirectMLInferenceBackend::Evaluate() noexcept {
|
||||
HRESULT hr = _d3d11DC->Signal(_d3d11Fence.get(), ++_fenceValue);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("Signal 失败", hr);
|
||||
return;
|
||||
}
|
||||
_d3d11DC->Flush();
|
||||
|
||||
hr = _commandQueue->Wait(_d3d12Fence.get(), _fenceValue);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("Wait 失败", hr);
|
||||
return;
|
||||
}
|
||||
|
||||
// 输入纹理 -> 输入张量
|
||||
{
|
||||
ID3D12CommandList* t = _tex2TensorCommandList.get();
|
||||
_commandQueue->ExecuteCommandLists(1, &t);
|
||||
}
|
||||
|
||||
try {
|
||||
_session.Run(Ort::RunOptions{ nullptr }, _ioBinding);
|
||||
} catch (const Ort::Exception& e) {
|
||||
Logger::Get().Error(e.what());
|
||||
return;
|
||||
}
|
||||
|
||||
// 输出张量 -> 输出纹理
|
||||
{
|
||||
ID3D12CommandList* t = _tensor2TexCommandList.get();
|
||||
_commandQueue->ExecuteCommandLists(1, &t);
|
||||
}
|
||||
|
||||
hr = _commandQueue->Signal(_d3d12Fence.get(), ++_fenceValue);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("Signal 失败", hr);
|
||||
return;
|
||||
}
|
||||
|
||||
hr = _d3d11DC->Wait(_d3d11Fence.get(), _fenceValue);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("Wait 失败", hr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
bool DirectMLInferenceBackend::_CreateFence(ID3D11Device5* d3d11Device, ID3D12Device* d3d12Device) noexcept {
|
||||
HRESULT hr = d3d11Device->CreateFence(
|
||||
_fenceValue, D3D11_FENCE_FLAG_SHARED, IID_PPV_ARGS(&_d3d11Fence));
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateFence 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
wil::unique_handle sharedHandle;
|
||||
hr = _d3d11Fence->CreateSharedHandle(nullptr, GENERIC_ALL, nullptr, sharedHandle.put());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateSharedHandle 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
hr = d3d12Device->OpenSharedHandle(sharedHandle.get(), IID_PPV_ARGS(&_d3d12Fence));
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("OpenSharedHandle 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DirectMLInferenceBackend::_CreateCBVHeap(
|
||||
ID3D12Device* d3d12Device,
|
||||
uint32_t inputElemCount,
|
||||
uint32_t outputElemCount,
|
||||
bool isFP16Data,
|
||||
UINT& descriptorSize
|
||||
) noexcept {
|
||||
{
|
||||
D3D12_DESCRIPTOR_HEAP_DESC desc{
|
||||
.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV,
|
||||
.NumDescriptors = 4,
|
||||
.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE
|
||||
};
|
||||
|
||||
HRESULT hr = d3d12Device->CreateDescriptorHeap(&desc, IID_PPV_ARGS(&_cbvHeap));
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateDescriptorHeap 失败", hr);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
descriptorSize = d3d12Device->GetDescriptorHandleIncrementSize(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
|
||||
|
||||
D3D12_CPU_DESCRIPTOR_HANDLE cbvHandle = _cbvHeap->GetCPUDescriptorHandleForHeapStart();
|
||||
|
||||
d3d12Device->CreateShaderResourceView(_d3d12InputTex.get(), nullptr, cbvHandle);
|
||||
cbvHandle.ptr += descriptorSize;
|
||||
|
||||
{
|
||||
D3D12_UNORDERED_ACCESS_VIEW_DESC desc{
|
||||
.Format = isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
|
||||
.ViewDimension = D3D12_UAV_DIMENSION_BUFFER,
|
||||
.Buffer{
|
||||
.NumElements = inputElemCount
|
||||
}
|
||||
};
|
||||
d3d12Device->CreateUnorderedAccessView(_inputBuffer.get(), nullptr, &desc, cbvHandle);
|
||||
}
|
||||
cbvHandle.ptr += descriptorSize;
|
||||
|
||||
{
|
||||
D3D12_SHADER_RESOURCE_VIEW_DESC desc{
|
||||
.Format = isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
|
||||
.ViewDimension = D3D12_SRV_DIMENSION_BUFFER,
|
||||
.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING,
|
||||
.Buffer{
|
||||
.NumElements = outputElemCount
|
||||
}
|
||||
};
|
||||
d3d12Device->CreateShaderResourceView(_outputBuffer.get(), &desc, cbvHandle);
|
||||
}
|
||||
cbvHandle.ptr += descriptorSize;
|
||||
|
||||
d3d12Device->CreateUnorderedAccessView(_d3d12OutputTex.get(), nullptr, nullptr, cbvHandle);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DirectMLInferenceBackend::_CreatePipelineStates(ID3D12Device* d3d12Device) noexcept {
|
||||
{
|
||||
D3D12_DESCRIPTOR_RANGE ranges[]{
|
||||
D3D12_DESCRIPTOR_RANGE{
|
||||
.RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_SRV,
|
||||
.NumDescriptors = 1,
|
||||
.OffsetInDescriptorsFromTableStart = D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND
|
||||
},
|
||||
D3D12_DESCRIPTOR_RANGE{
|
||||
.RangeType = D3D12_DESCRIPTOR_RANGE_TYPE_UAV,
|
||||
.NumDescriptors = 1,
|
||||
.OffsetInDescriptorsFromTableStart = D3D12_DESCRIPTOR_RANGE_OFFSET_APPEND
|
||||
},
|
||||
};
|
||||
|
||||
D3D12_ROOT_PARAMETER rootParam{
|
||||
D3D12_ROOT_PARAMETER{
|
||||
.ParameterType = D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE,
|
||||
.DescriptorTable{
|
||||
.NumDescriptorRanges = (UINT)std::size(ranges),
|
||||
.pDescriptorRanges = ranges
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
D3D12_STATIC_SAMPLER_DESC samDesc{
|
||||
.Filter = D3D12_FILTER_MIN_MAG_MIP_POINT,
|
||||
.AddressU = D3D12_TEXTURE_ADDRESS_MODE_CLAMP,
|
||||
.AddressV = D3D12_TEXTURE_ADDRESS_MODE_CLAMP,
|
||||
.AddressW = D3D12_TEXTURE_ADDRESS_MODE_CLAMP,
|
||||
.ComparisonFunc = D3D12_COMPARISON_FUNC_NEVER
|
||||
};
|
||||
|
||||
D3D12_VERSIONED_ROOT_SIGNATURE_DESC desc{
|
||||
.Version = D3D_ROOT_SIGNATURE_VERSION_1_0,
|
||||
.Desc_1_0{
|
||||
.NumParameters = 1,
|
||||
.pParameters = &rootParam,
|
||||
.NumStaticSamplers = 1,
|
||||
.pStaticSamplers = &samDesc
|
||||
}
|
||||
};
|
||||
|
||||
winrt::com_ptr<ID3DBlob> blob;
|
||||
HRESULT hr = D3D12SerializeVersionedRootSignature(&desc, blob.put(), nullptr);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("D3D12SerializeVersionedRootSignature 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
hr = d3d12Device->CreateRootSignature(
|
||||
0,
|
||||
blob->GetBufferPointer(),
|
||||
blob->GetBufferSize(),
|
||||
IID_PPV_ARGS(&_rootSignature)
|
||||
);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateRootSignature 失败", hr);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
D3D12_COMPUTE_PIPELINE_STATE_DESC desc{
|
||||
.pRootSignature = _rootSignature.get(),
|
||||
.CS{
|
||||
.pShaderBytecode = TextureToTensorCS,
|
||||
.BytecodeLength = std::size(TextureToTensorCS)
|
||||
}
|
||||
};
|
||||
HRESULT hr = d3d12Device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&_tex2TensorPipelineState));
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateComputePipelineState 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
desc.CS.pShaderBytecode = TensorToTextureCS;
|
||||
desc.CS.BytecodeLength = std::size(TensorToTextureCS);
|
||||
hr = d3d12Device->CreateComputePipelineState(&desc, IID_PPV_ARGS(&_tensor2TexPipelineState));
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateComputePipelineState 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DirectMLInferenceBackend::_CalcCommandLists(
|
||||
ID3D12Device* d3d12Device,
|
||||
SIZE inputSize,
|
||||
SIZE outputSize,
|
||||
UINT descriptorSize
|
||||
) noexcept {
|
||||
winrt::com_ptr<ID3D12CommandAllocator> d3d12CommandAllocator;
|
||||
HRESULT hr = d3d12Device->CreateCommandAllocator(
|
||||
D3D12_COMMAND_LIST_TYPE_COMPUTE, IID_PPV_ARGS(&d3d12CommandAllocator));
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateCommandAllocator 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
// 输入纹理 -> 输入张量
|
||||
hr = d3d12Device->CreateCommandList(
|
||||
0,
|
||||
D3D12_COMMAND_LIST_TYPE_COMPUTE,
|
||||
d3d12CommandAllocator.get(),
|
||||
_tex2TensorPipelineState.get(),
|
||||
IID_PPV_ARGS(&_tex2TensorCommandList)
|
||||
);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateCommandList 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
_tex2TensorCommandList->SetComputeRootSignature(_rootSignature.get());
|
||||
{
|
||||
ID3D12DescriptorHeap* t = _cbvHeap.get();
|
||||
_tex2TensorCommandList->SetDescriptorHeaps(1, &t);
|
||||
}
|
||||
_tex2TensorCommandList->SetComputeRootDescriptorTable(0, _cbvHeap->GetGPUDescriptorHandleForHeapStart());
|
||||
|
||||
static constexpr std::pair<uint32_t, uint32_t> TEX_TO_TENSOR_BLOCK_SIZE{ 16, 16 };
|
||||
_tex2TensorCommandList->Dispatch(
|
||||
(inputSize.cx + TEX_TO_TENSOR_BLOCK_SIZE.first - 1) / TEX_TO_TENSOR_BLOCK_SIZE.first,
|
||||
(inputSize.cy + TEX_TO_TENSOR_BLOCK_SIZE.second - 1) / TEX_TO_TENSOR_BLOCK_SIZE.second,
|
||||
1
|
||||
);
|
||||
hr = _tex2TensorCommandList->Close();
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("Close 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
// 输出张量 -> 输出纹理
|
||||
hr = d3d12Device->CreateCommandList(
|
||||
0,
|
||||
D3D12_COMMAND_LIST_TYPE_COMPUTE,
|
||||
d3d12CommandAllocator.get(),
|
||||
_tensor2TexPipelineState.get(),
|
||||
IID_PPV_ARGS(&_tensor2TexCommandList)
|
||||
);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateCommandList 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
_tensor2TexCommandList->SetComputeRootSignature(_rootSignature.get());
|
||||
{
|
||||
ID3D12DescriptorHeap* t = _cbvHeap.get();
|
||||
_tensor2TexCommandList->SetDescriptorHeaps(1, &t);
|
||||
}
|
||||
{
|
||||
D3D12_GPU_DESCRIPTOR_HANDLE gpuHandle = _cbvHeap->GetGPUDescriptorHandleForHeapStart();
|
||||
gpuHandle.ptr += 2 * static_cast<UINT64>(descriptorSize);
|
||||
_tensor2TexCommandList->SetComputeRootDescriptorTable(0, gpuHandle);
|
||||
}
|
||||
|
||||
static constexpr std::pair<uint32_t, uint32_t> TENSOR_TO_TEX_BLOCK_SIZE{ 8, 8 };
|
||||
_tensor2TexCommandList->Dispatch(
|
||||
(outputSize.cx + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
|
||||
(outputSize.cy + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second,
|
||||
1
|
||||
);
|
||||
hr = _tensor2TexCommandList->Close();
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("Close 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
77
src/Magpie.Core/DirectMLInferenceBackend.h
Normal file
77
src/Magpie.Core/DirectMLInferenceBackend.h
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
#pragma once
|
||||
#include "InferenceBackendBase.h"
|
||||
#include <d3d12.h>
|
||||
|
||||
struct OrtDmlApi;
|
||||
|
||||
namespace Magpie::Core {
|
||||
|
||||
class DirectMLInferenceBackend : public InferenceBackendBase {
|
||||
public:
|
||||
DirectMLInferenceBackend() = default;
|
||||
DirectMLInferenceBackend(const DirectMLInferenceBackend&) = delete;
|
||||
DirectMLInferenceBackend(DirectMLInferenceBackend&&) = default;
|
||||
|
||||
bool Initialize(
|
||||
const wchar_t* modelPath,
|
||||
uint32_t scale,
|
||||
DeviceResources& deviceResources,
|
||||
BackendDescriptorStore& descriptorStore,
|
||||
ID3D11Texture2D* input,
|
||||
ID3D11Texture2D** output
|
||||
) noexcept override;
|
||||
|
||||
void Evaluate() noexcept override;
|
||||
|
||||
private:
|
||||
bool _CreateFence(ID3D11Device5* d3d11Device, ID3D12Device* d3d12Device) noexcept;
|
||||
|
||||
bool _CreateCBVHeap(
|
||||
ID3D12Device* d3d12Device,
|
||||
uint32_t inputElemCount,
|
||||
uint32_t outputElemCount,
|
||||
bool isFP16Data,
|
||||
UINT& descriptorSize
|
||||
) noexcept;
|
||||
|
||||
bool _CreatePipelineStates(ID3D12Device* d3d12Device) noexcept;
|
||||
|
||||
bool _CalcCommandLists(
|
||||
ID3D12Device* d3d12Device,
|
||||
SIZE inputSize,
|
||||
SIZE outputSize,
|
||||
UINT descriptorSize
|
||||
) noexcept;
|
||||
|
||||
ID3D11DeviceContext4* _d3d11DC = nullptr;
|
||||
|
||||
winrt::com_ptr<ID3D11Texture2D> _outputTex;
|
||||
|
||||
winrt::com_ptr<ID3D11Fence> _d3d11Fence;
|
||||
winrt::com_ptr<ID3D12Fence> _d3d12Fence;
|
||||
UINT64 _fenceValue = 0;
|
||||
|
||||
winrt::com_ptr<ID3D12Resource> _d3d12InputTex;
|
||||
winrt::com_ptr<ID3D12Resource> _d3d12OutputTex;
|
||||
winrt::com_ptr<ID3D12Resource> _inputBuffer;
|
||||
winrt::com_ptr<ID3D12Resource> _outputBuffer;
|
||||
|
||||
winrt::com_ptr<ID3D12DescriptorHeap> _cbvHeap;
|
||||
winrt::com_ptr<ID3D12RootSignature> _rootSignature;
|
||||
winrt::com_ptr<ID3D12PipelineState> _tex2TensorPipelineState;
|
||||
winrt::com_ptr<ID3D12PipelineState> _tensor2TexPipelineState;
|
||||
|
||||
winrt::com_ptr<ID3D12CommandQueue> _commandQueue;
|
||||
winrt::com_ptr<ID3D12GraphicsCommandList> _tex2TensorCommandList;
|
||||
winrt::com_ptr<ID3D12GraphicsCommandList> _tensor2TexCommandList;
|
||||
|
||||
Ort::Env _env{ nullptr };
|
||||
Ort::Session _session{ nullptr };
|
||||
|
||||
winrt::com_ptr<IUnknown> _allocatedInput;
|
||||
winrt::com_ptr<IUnknown> _allocatedOutput;
|
||||
|
||||
Ort::IoBinding _ioBinding{ nullptr };
|
||||
};
|
||||
|
||||
}
|
||||
|
|
@ -107,4 +107,10 @@ winrt::com_ptr<ID3D11Texture2D> DirectXHelper::CreateTexture2D(
|
|||
return result;
|
||||
}
|
||||
|
||||
SIZE DirectXHelper::GetTextureSize(ID3D11Texture2D* texture) noexcept {
|
||||
D3D11_TEXTURE2D_DESC desc;
|
||||
texture->GetDesc(&desc);
|
||||
return SIZE{ (LONG)desc.Width, (LONG)desc.Height };
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ struct DirectXHelper {
|
|||
UINT miscFlags = 0,
|
||||
const D3D11_SUBRESOURCE_DATA* pInitialData = nullptr
|
||||
) noexcept;
|
||||
|
||||
static SIZE GetTextureSize(ID3D11Texture2D* texture) noexcept;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -87,7 +87,9 @@ bool DwmSharedSurfaceFrameSource::_Initialize() noexcept {
|
|||
DXGI_FORMAT_B8G8R8A8_UNORM,
|
||||
frameRect.right - frameRect.left,
|
||||
frameRect.bottom - frameRect.top,
|
||||
D3D11_BIND_SHADER_RESOURCE
|
||||
D3D11_BIND_SHADER_RESOURCE,
|
||||
D3D11_USAGE_DEFAULT,
|
||||
D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
|
||||
);
|
||||
if (!_output) {
|
||||
Logger::Get().Error("CreateTexture2D 失败");
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include <d3dcompiler.h>
|
||||
#include "Utils.h"
|
||||
#include "YasHelper.h"
|
||||
#include "HashHelper.h"
|
||||
|
||||
namespace yas::detail {
|
||||
|
||||
|
|
@ -235,27 +236,6 @@ void EffectCacheManager::Save(std::wstring_view effectName, std::wstring_view ha
|
|||
Logger::Get().Info(StrUtils::Concat("已保存缓存 ", StrUtils::UTF16ToUTF8(cacheFileName)));
|
||||
}
|
||||
|
||||
static std::wstring HexHash(std::span<const BYTE> data) {
|
||||
uint64_t hashBytes = Utils::HashData(data);
|
||||
|
||||
static wchar_t oct2Hex[16] = {
|
||||
L'0',L'1',L'2',L'3',L'4',L'5',L'6',L'7',
|
||||
L'8',L'9',L'a',L'b',L'c',L'd',L'e',L'f'
|
||||
};
|
||||
|
||||
std::wstring result(16, 0);
|
||||
wchar_t* pResult = &result[0];
|
||||
|
||||
BYTE* b = (BYTE*)&hashBytes;
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
*pResult++ = oct2Hex[(*b >> 4) & 0xf];
|
||||
*pResult++ = oct2Hex[*b & 0xf];
|
||||
++b;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::wstring EffectCacheManager::GetHash(
|
||||
std::string_view source,
|
||||
const phmap::flat_hash_map<std::wstring, float>* inlineParams
|
||||
|
|
@ -271,7 +251,7 @@ std::wstring EffectCacheManager::GetHash(
|
|||
}
|
||||
}
|
||||
|
||||
return HexHash(std::span((const BYTE*)source.data(), source.size()));
|
||||
return HashHelper::HexHash(std::span((const BYTE*)str.data(), str.size()));
|
||||
}
|
||||
|
||||
std::wstring EffectCacheManager::GetHash(std::string& source, const phmap::flat_hash_map<std::wstring, float>* inlineParams) {
|
||||
|
|
@ -286,7 +266,7 @@ std::wstring EffectCacheManager::GetHash(std::string& source, const phmap::flat_
|
|||
}
|
||||
}
|
||||
|
||||
std::wstring result = HexHash(std::span((const BYTE*)source.data(), source.size()));
|
||||
std::wstring result = HashHelper::HexHash(std::span((const BYTE*)source.data(), source.size()));
|
||||
source.resize(originSize);
|
||||
return result;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1376,9 +1376,8 @@ static uint32_t CompilePasses(
|
|||
std::wstring sourcesPath = sourcesPathName.substr(0, sourcesPathName.find_last_of(L'\\'));
|
||||
|
||||
if ((flags & EffectCompilerFlags::SaveSources) && !Win32Utils::DirExists(sourcesPath.c_str())) {
|
||||
HRESULT hr = wil::CreateDirectoryDeepNoThrow(sourcesPath.c_str());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("创建 sources 文件夹失败", hr);
|
||||
if (!Win32Utils::CreateDir(sourcesPath, true)) {
|
||||
Logger::Get().Win32Error("创建 sources 文件夹失败");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -92,12 +92,7 @@ bool EffectDrawer::Initialize(
|
|||
) noexcept {
|
||||
_d3dDC = deviceResources.GetD3DDC();
|
||||
|
||||
SIZE inputSize{};
|
||||
{
|
||||
D3D11_TEXTURE2D_DESC inputDesc;
|
||||
(*inOutTexture)->GetDesc(&inputDesc);
|
||||
inputSize = { (LONG)inputDesc.Width, (LONG)inputDesc.Height };
|
||||
}
|
||||
const SIZE inputSize = DirectXHelper::GetTextureSize(*inOutTexture);
|
||||
|
||||
static mu::Parser exprParser;
|
||||
exprParser.DefineConst("INPUT_WIDTH", inputSize.cx);
|
||||
|
|
@ -165,7 +160,7 @@ bool EffectDrawer::Initialize(
|
|||
|
||||
if (texDesc.format != EffectIntermediateTextureFormat::UNKNOWN) {
|
||||
// 检查纹理格式是否匹配
|
||||
D3D11_TEXTURE2D_DESC srcDesc{};
|
||||
D3D11_TEXTURE2D_DESC srcDesc;
|
||||
_textures[i]->GetDesc(&srcDesc);
|
||||
if (srcDesc.Format != EffectHelper::FORMAT_DESCS[(uint32_t)texDesc.format].dxgiFormat) {
|
||||
Logger::Get().Error("SOURCE 纹理格式不匹配");
|
||||
|
|
@ -235,11 +230,10 @@ bool EffectDrawer::Initialize(
|
|||
}
|
||||
}
|
||||
|
||||
D3D11_TEXTURE2D_DESC outputDesc;
|
||||
_textures[passDesc.outputs[0]]->GetDesc(&outputDesc);
|
||||
SIZE passOutputSize = DirectXHelper::GetTextureSize(_textures[passDesc.outputs[0]].get());
|
||||
_dispatches.emplace_back(
|
||||
(outputDesc.Width + passDesc.blockSize.first - 1) / passDesc.blockSize.first,
|
||||
(outputDesc.Height + passDesc.blockSize.second - 1) / passDesc.blockSize.second
|
||||
(passOutputSize.cx + passDesc.blockSize.first - 1) / passDesc.blockSize.first,
|
||||
(passOutputSize.cy + passDesc.blockSize.second - 1) / passDesc.blockSize.second
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -293,7 +287,7 @@ bool EffectDrawer::_InitializeConstants(
|
|||
psStylePassParams += 4;
|
||||
}
|
||||
}
|
||||
_constants.resize((builtinConstantCount + psStylePassParams + (isInlineParams ? 0 : desc.params.size()) + 3) / 4 * 4);
|
||||
_constants.resize((builtinConstantCount + psStylePassParams + (isInlineParams ? 0 : desc.params.size()) + 3) & ~3);
|
||||
// cbuffer __CB1 : register(b0) {
|
||||
// uint2 __inputSize;
|
||||
// uint2 __outputSize;
|
||||
|
|
@ -318,15 +312,14 @@ bool EffectDrawer::_InitializeConstants(
|
|||
if (psStylePassParams > 0) {
|
||||
for (UINT i = 0, end = (UINT)desc.passes.size() - 1; i < end; ++i) {
|
||||
if (desc.passes[i].isPSStyle) {
|
||||
D3D11_TEXTURE2D_DESC outputDesc;
|
||||
_textures[desc.passes[i].outputs[0]]->GetDesc(&outputDesc);
|
||||
pCurParam->uintVal = outputDesc.Width;
|
||||
SIZE passOutputSize = DirectXHelper::GetTextureSize(_textures[desc.passes[i].outputs[0]].get());
|
||||
pCurParam->uintVal = passOutputSize.cx;
|
||||
++pCurParam;
|
||||
pCurParam->uintVal = outputDesc.Height;
|
||||
pCurParam->uintVal = passOutputSize.cy;
|
||||
++pCurParam;
|
||||
pCurParam->floatVal = 1.0f / outputDesc.Width;
|
||||
pCurParam->floatVal = 1.0f / passOutputSize.cx;
|
||||
++pCurParam;
|
||||
pCurParam->floatVal = 1.0f / outputDesc.Height;
|
||||
pCurParam->floatVal = 1.0f / passOutputSize.cy;
|
||||
++pCurParam;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ bool GDIFrameSource::_Initialize() noexcept {
|
|||
_frameRect.bottom - _frameRect.top,
|
||||
D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_RENDER_TARGET,
|
||||
D3D11_USAGE_DEFAULT,
|
||||
D3D11_RESOURCE_MISC_GDI_COMPATIBLE
|
||||
D3D11_RESOURCE_MISC_GDI_COMPATIBLE | D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
|
||||
);
|
||||
if (!_output) {
|
||||
Logger::Get().Error("创建纹理失败");
|
||||
|
|
|
|||
|
|
@ -75,7 +75,9 @@ bool GraphicsCaptureFrameSource::_Initialize() noexcept {
|
|||
DXGI_FORMAT_B8G8R8A8_UNORM,
|
||||
_frameBox.right - _frameBox.left,
|
||||
_frameBox.bottom - _frameBox.top,
|
||||
D3D11_BIND_SHADER_RESOURCE
|
||||
D3D11_BIND_SHADER_RESOURCE,
|
||||
D3D11_USAGE_DEFAULT,
|
||||
D3D11_RESOURCE_MISC_SHARED | D3D11_RESOURCE_MISC_SHARED_NTHANDLE
|
||||
);
|
||||
if (!_output) {
|
||||
Logger::Get().Error("创建纹理失败");
|
||||
|
|
|
|||
28
src/Magpie.Core/HashHelper.cpp
Normal file
28
src/Magpie.Core/HashHelper.cpp
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
#include "pch.h"
|
||||
#include "HashHelper.h"
|
||||
#include "Utils.h"
|
||||
|
||||
namespace Magpie::Core {
|
||||
|
||||
std::wstring HashHelper::HexHash(std::span<const uint8_t> data) noexcept {
|
||||
uint64_t hashBytes = Utils::HashData(data);
|
||||
|
||||
static wchar_t oct2Hex[16] = {
|
||||
L'0',L'1',L'2',L'3',L'4',L'5',L'6',L'7',
|
||||
L'8',L'9',L'a',L'b',L'c',L'd',L'e',L'f'
|
||||
};
|
||||
|
||||
std::wstring result(16, 0);
|
||||
wchar_t* pResult = &result[0];
|
||||
|
||||
BYTE* b = (BYTE*)&hashBytes;
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
*pResult++ = oct2Hex[(*b >> 4) & 0xf];
|
||||
*pResult++ = oct2Hex[*b & 0xf];
|
||||
++b;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
9
src/Magpie.Core/HashHelper.h
Normal file
9
src/Magpie.Core/HashHelper.h
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
#pragma once
|
||||
|
||||
namespace Magpie::Core {
|
||||
|
||||
struct HashHelper {
|
||||
static std::wstring HexHash(std::span<const uint8_t> data) noexcept;
|
||||
};
|
||||
|
||||
}
|
||||
80
src/Magpie.Core/InferenceBackendBase.cpp
Normal file
80
src/Magpie.Core/InferenceBackendBase.cpp
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
#include "pch.h"
|
||||
#include "InferenceBackendBase.h"
|
||||
#include "StrUtils.h"
|
||||
#include "Logger.h"
|
||||
|
||||
namespace Magpie::Core {
|
||||
|
||||
void ORT_API_CALL InferenceBackendBase::_OrtLog(
|
||||
void* /*param*/,
|
||||
OrtLoggingLevel severity,
|
||||
const char* /*category*/,
|
||||
const char* /*logid*/,
|
||||
const char* /*code_location*/,
|
||||
const char* message
|
||||
) {
|
||||
const char* SEVERITIES[] = {
|
||||
"verbose",
|
||||
"info",
|
||||
"warning",
|
||||
"error",
|
||||
"fatal"
|
||||
};
|
||||
|
||||
std::string log = StrUtils::Concat("[", SEVERITIES[severity], "] ", message);
|
||||
if (severity == ORT_LOGGING_LEVEL_INFO) {
|
||||
Logger::Get().Info(log);
|
||||
OutputDebugStringA((log + "\n").c_str());
|
||||
} else if (severity == ORT_LOGGING_LEVEL_WARNING) {
|
||||
Logger::Get().Warn(log);
|
||||
} else {
|
||||
Logger::Get().Error(log);
|
||||
}
|
||||
}
|
||||
|
||||
static bool IsTensorShapeValid(const Ort::ConstTensorTypeAndShapeInfo& tensorInfo) {
|
||||
// 输入输出维度应是 [-1,3,-1,-1]
|
||||
std::vector<int64_t> dimensions = tensorInfo.GetShape();
|
||||
return dimensions.size() == 4 && dimensions[0] == -1 &&
|
||||
dimensions[1] == 3 && dimensions[2] == -1 && dimensions[3] == -1;
|
||||
}
|
||||
|
||||
bool InferenceBackendBase::_IsModelValid(const Ort::Session& session, bool& isFP16Data) {
|
||||
if (session.GetInputCount() != 1 || session.GetOutputCount() != 1) {
|
||||
Logger::Get().Error("不支持有多个输入/输出的模型");
|
||||
return false;
|
||||
}
|
||||
|
||||
// 必须在 inputTypeInfo 的生命周期内使用 inputTensorInfo
|
||||
Ort::TypeInfo inputTypeInfo = session.GetInputTypeInfo(0);
|
||||
Ort::ConstTensorTypeAndShapeInfo inputTensorInfo = inputTypeInfo.GetTensorTypeAndShapeInfo();
|
||||
|
||||
ONNXTensorElementDataType dataType = inputTensorInfo.GetElementType();
|
||||
if (dataType != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 && dataType != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
|
||||
Logger::Get().Error("不支持 float16 和 float 之外的输入数据类型");
|
||||
return false;
|
||||
}
|
||||
|
||||
isFP16Data = dataType == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
|
||||
|
||||
Ort::TypeInfo outputInfo = session.GetOutputTypeInfo(0);
|
||||
Ort::ConstTensorTypeAndShapeInfo outputTensorInfo = outputInfo.GetTensorTypeAndShapeInfo();
|
||||
if (outputInfo.GetTensorTypeAndShapeInfo().GetElementType() != dataType) {
|
||||
Logger::Get().Error("不支持输入和输出数据类型不同的模型");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!IsTensorShapeValid(inputTensorInfo)) {
|
||||
Logger::Get().Error("不支持的输入数据格式");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!IsTensorShapeValid(outputTensorInfo)) {
|
||||
Logger::Get().Error("不支持的输出数据格式");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
41
src/Magpie.Core/InferenceBackendBase.h
Normal file
41
src/Magpie.Core/InferenceBackendBase.h
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
#pragma once
|
||||
#include <onnxruntime_cxx_api.h>
|
||||
|
||||
namespace Magpie::Core {
|
||||
|
||||
class DeviceResources;
|
||||
class BackendDescriptorStore;
|
||||
|
||||
class InferenceBackendBase {
|
||||
public:
|
||||
InferenceBackendBase() = default;
|
||||
virtual ~InferenceBackendBase() noexcept {}
|
||||
|
||||
InferenceBackendBase(const InferenceBackendBase&) = delete;
|
||||
InferenceBackendBase(InferenceBackendBase&&) = default;
|
||||
|
||||
virtual bool Initialize(
|
||||
const wchar_t* modelPath,
|
||||
uint32_t scale,
|
||||
DeviceResources& deviceResources,
|
||||
BackendDescriptorStore& descriptorStore,
|
||||
ID3D11Texture2D* input,
|
||||
ID3D11Texture2D** output
|
||||
) noexcept = 0;
|
||||
|
||||
virtual void Evaluate() noexcept = 0;
|
||||
|
||||
protected:
|
||||
static void ORT_API_CALL _OrtLog(
|
||||
void* /*param*/,
|
||||
OrtLoggingLevel severity,
|
||||
const char* /*category*/,
|
||||
const char* /*logid*/,
|
||||
const char* /*code_location*/,
|
||||
const char* message
|
||||
);
|
||||
|
||||
static bool _IsModelValid(const Ort::Session& session, bool& isFP16Data);
|
||||
};
|
||||
|
||||
}
|
||||
|
|
@ -43,12 +43,14 @@
|
|||
</ItemDefinitionGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="BackendDescriptorStore.h" />
|
||||
<ClInclude Include="TensorRTInferenceBackend.h" />
|
||||
<ClInclude Include="CursorManager.h" />
|
||||
<ClInclude Include="CursorDrawer.h" />
|
||||
<ClInclude Include="DDS.h" />
|
||||
<ClInclude Include="DDSLoderHelpers.h" />
|
||||
<ClInclude Include="DesktopDuplicationFrameSource.h" />
|
||||
<ClInclude Include="DeviceResources.h" />
|
||||
<ClInclude Include="DirectMLInferenceBackend.h" />
|
||||
<ClInclude Include="DirectXHelper.h" />
|
||||
<ClInclude Include="DwmSharedSurfaceFrameSource.h" />
|
||||
<ClInclude Include="EffectCacheManager.h" />
|
||||
|
|
@ -61,11 +63,15 @@
|
|||
<ClInclude Include="FrameSourceBase.h" />
|
||||
<ClInclude Include="GDIFrameSource.h" />
|
||||
<ClInclude Include="GraphicsCaptureFrameSource.h" />
|
||||
<ClInclude Include="HashHelper.h" />
|
||||
<ClInclude Include="ImGuiBackend.h" />
|
||||
<ClInclude Include="ImGuiFontsCacheManager.h" />
|
||||
<ClInclude Include="ImGuiHelper.h" />
|
||||
<ClInclude Include="ImGuiImpl.h" />
|
||||
<ClInclude Include="include\Magpie.Core.h" />
|
||||
<ClInclude Include="InferenceBackendBase.h" />
|
||||
<ClInclude Include="OnnxEffectDrawer.h" />
|
||||
<ClInclude Include="OnnxHelper.h" />
|
||||
<ClInclude Include="OverlayDrawer.h" />
|
||||
<ClInclude Include="Renderer.h" />
|
||||
<ClInclude Include="ScalingOptions.h" />
|
||||
|
|
@ -80,10 +86,12 @@
|
|||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="BackendDescriptorStore.cpp" />
|
||||
<ClCompile Include="TensorRTInferenceBackend.cpp" />
|
||||
<ClCompile Include="CursorManager.cpp" />
|
||||
<ClCompile Include="CursorDrawer.cpp" />
|
||||
<ClCompile Include="DesktopDuplicationFrameSource.cpp" />
|
||||
<ClCompile Include="DeviceResources.cpp" />
|
||||
<ClCompile Include="DirectMLInferenceBackend.cpp" />
|
||||
<ClCompile Include="DirectXHelper.cpp" />
|
||||
<ClCompile Include="DwmSharedSurfaceFrameSource.cpp" />
|
||||
<ClCompile Include="EffectCacheManager.cpp" />
|
||||
|
|
@ -94,10 +102,13 @@
|
|||
<ClCompile Include="FrameSourceBase.cpp" />
|
||||
<ClCompile Include="GDIFrameSource.cpp" />
|
||||
<ClCompile Include="GraphicsCaptureFrameSource.cpp" />
|
||||
<ClCompile Include="HashHelper.cpp" />
|
||||
<ClCompile Include="ImGuiBackend.cpp" />
|
||||
<ClCompile Include="ImGuiFontsCacheManager.cpp" />
|
||||
<ClCompile Include="ImGuiHelper.cpp" />
|
||||
<ClCompile Include="ImGuiImpl.cpp" />
|
||||
<ClCompile Include="InferenceBackendBase.cpp" />
|
||||
<ClCompile Include="OnnxEffectDrawer.cpp" />
|
||||
<ClCompile Include="OverlayDrawer.cpp" />
|
||||
<ClCompile Include="pch.cpp">
|
||||
<PrecompiledHeader>Create</PrecompiledHeader>
|
||||
|
|
@ -111,6 +122,9 @@
|
|||
<ClCompile Include="WindowHelper.cpp" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<FxCompile Include="shaders\TensorToTextureCS.hlsl">
|
||||
<ShaderType>Compute</ShaderType>
|
||||
</FxCompile>
|
||||
<FxCompile Include="shaders\DuplicateFrameCS.hlsl">
|
||||
<ShaderType>Compute</ShaderType>
|
||||
</FxCompile>
|
||||
|
|
@ -132,21 +146,24 @@
|
|||
<FxCompile Include="shaders\SimpleVS.hlsl">
|
||||
<ShaderType>Vertex</ShaderType>
|
||||
</FxCompile>
|
||||
<FxCompile Include="shaders\TextureToTensorCS.hlsl">
|
||||
<ShaderType>Compute</ShaderType>
|
||||
</FxCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="packages.config" />
|
||||
</ItemGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
|
||||
<ImportGroup Label="ExtensionTargets">
|
||||
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
|
||||
</ImportGroup>
|
||||
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
|
||||
<PropertyGroup>
|
||||
<ErrorText>这台计算机上缺少此项目引用的 NuGet 程序包。使用“NuGet 程序包还原”可下载这些程序包。有关更多信息,请参见 http://go.microsoft.com/fwlink/?LinkID=322105。缺少的文件是 {0}。</ErrorText>
|
||||
</PropertyGroup>
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
|
||||
</Target>
|
||||
</Project>
|
||||
|
|
@ -19,6 +19,9 @@
|
|||
<Filter Include="Shaders">
|
||||
<UniqueIdentifier>{1956ae10-07ad-4b77-a37f-25f7fe10654b}</UniqueIdentifier>
|
||||
</Filter>
|
||||
<Filter Include="ONNX">
|
||||
<UniqueIdentifier>{c5acb0d2-df90-4589-8914-2bfff00194ec}</UniqueIdentifier>
|
||||
</Filter>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClInclude Include="pch.h" />
|
||||
|
|
@ -91,7 +94,25 @@
|
|||
<ClInclude Include="DesktopDuplicationFrameSource.h">
|
||||
<Filter>Capture</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="DirectMLInferenceBackend.h">
|
||||
<Filter>ONNX</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="OnnxEffectDrawer.h">
|
||||
<Filter>ONNX</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="HashHelper.h">
|
||||
<Filter>Helpers</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="InferenceBackendBase.h">
|
||||
<Filter>ONNX</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="TensorRTInferenceBackend.h">
|
||||
<Filter>ONNX</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="ExclModeHelper.h" />
|
||||
<ClInclude Include="OnnxHelper.h">
|
||||
<Filter>Helpers</Filter>
|
||||
</ClInclude>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="ScalingRuntime.cpp" />
|
||||
|
|
@ -146,6 +167,21 @@
|
|||
<ClCompile Include="DesktopDuplicationFrameSource.cpp">
|
||||
<Filter>Capture</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="DirectMLInferenceBackend.cpp">
|
||||
<Filter>ONNX</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="OnnxEffectDrawer.cpp">
|
||||
<Filter>ONNX</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="HashHelper.cpp">
|
||||
<Filter>Helpers</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="InferenceBackendBase.cpp">
|
||||
<Filter>ONNX</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="TensorRTInferenceBackend.cpp">
|
||||
<Filter>ONNX</Filter>
|
||||
</ClCompile>
|
||||
<ClCompile Include="ExclModeHelper.cpp" />
|
||||
<ClCompile Include="ScalingOptions.cpp" />
|
||||
</ItemGroup>
|
||||
|
|
@ -171,6 +207,12 @@
|
|||
<FxCompile Include="shaders\ImGuiImplPS.hlsl">
|
||||
<Filter>Shaders</Filter>
|
||||
</FxCompile>
|
||||
<FxCompile Include="shaders\TextureToTensorCS.hlsl">
|
||||
<Filter>Shaders</Filter>
|
||||
</FxCompile>
|
||||
<FxCompile Include="shaders\TensorToTextureCS.hlsl">
|
||||
<Filter>Shaders</Filter>
|
||||
</FxCompile>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="packages.config" />
|
||||
|
|
|
|||
123
src/Magpie.Core/OnnxEffectDrawer.cpp
Normal file
123
src/Magpie.Core/OnnxEffectDrawer.cpp
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
#include "pch.h"
|
||||
#include "OnnxEffectDrawer.h"
|
||||
#include "Logger.h"
|
||||
#include "DirectMLInferenceBackend.h"
|
||||
#include "TensorRTInferenceBackend.h"
|
||||
#include "Win32Utils.h"
|
||||
#include <rapidjson/document.h>
|
||||
#include "StrUtils.h"
|
||||
|
||||
namespace Magpie::Core {
|
||||
|
||||
OnnxEffectDrawer::OnnxEffectDrawer() {}
|
||||
|
||||
OnnxEffectDrawer::~OnnxEffectDrawer() {}
|
||||
|
||||
static bool ReadJson(
|
||||
const rapidjson::Document& doc,
|
||||
std::string& modelPath,
|
||||
uint32_t& scale,
|
||||
std::string& backend
|
||||
) noexcept {
|
||||
if (!doc.IsObject()) {
|
||||
Logger::Get().Error("根元素不是 Object");
|
||||
return false;
|
||||
}
|
||||
|
||||
auto root = ((const rapidjson::Document&)doc).GetObj();
|
||||
|
||||
{
|
||||
auto node = root.FindMember("path");
|
||||
if (node == root.MemberEnd() || !node->value.IsString()) {
|
||||
Logger::Get().Error("解析 path 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
modelPath = node->value.GetString();
|
||||
}
|
||||
|
||||
{
|
||||
auto node = root.FindMember("scale");
|
||||
if (node == root.MemberEnd() || !node->value.IsUint()) {
|
||||
Logger::Get().Error("解析 scale 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
scale = node->value.GetUint();
|
||||
}
|
||||
|
||||
{
|
||||
auto node = root.FindMember("backend");
|
||||
if (node == root.MemberEnd() || !node->value.IsString()) {
|
||||
Logger::Get().Error("解析 backend 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
backend = node->value.GetString();
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool OnnxEffectDrawer::Initialize(
|
||||
DeviceResources& deviceResources,
|
||||
BackendDescriptorStore& descriptorStore,
|
||||
ID3D11Texture2D** inOutTexture
|
||||
) noexcept {
|
||||
const wchar_t* jsonPath = L"model.json";
|
||||
if (!Win32Utils::FileExists(jsonPath)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string json;
|
||||
if (!Win32Utils::ReadTextFile(jsonPath, json)) {
|
||||
Logger::Get().Error("Win32Utils::ReadTextFile 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string modelPath;
|
||||
uint32_t scale = 1;
|
||||
std::string backend;
|
||||
{
|
||||
rapidjson::Document doc;
|
||||
doc.ParseInsitu(json.data());
|
||||
if (doc.HasParseError()) {
|
||||
Logger::Get().Error("解析 json 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ReadJson(doc, modelPath, scale, backend)) {
|
||||
Logger::Get().Error("ReadJson 失败");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
StrUtils::ToLowerCase(backend);
|
||||
if (backend == "directml" || backend == "dml" || backend == "d") {
|
||||
_inferenceBackend = std::make_unique<DirectMLInferenceBackend>();
|
||||
}
|
||||
#if _M_X64
|
||||
else if (backend == "tensorrt" || backend == "trt" || backend == "t") {
|
||||
_inferenceBackend = std::make_unique<TensorRTInferenceBackend>();
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
Logger::Get().Error("未知 backend");
|
||||
return false;
|
||||
}
|
||||
|
||||
std::wstring modelPathW = StrUtils::UTF8ToUTF16(modelPath);
|
||||
if (!_inferenceBackend->Initialize(modelPathW.c_str(), scale, deviceResources, descriptorStore, *inOutTexture, inOutTexture)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void OnnxEffectDrawer::Draw(EffectsProfiler& /*profiler*/) const noexcept {
|
||||
if (_inferenceBackend) {
|
||||
_inferenceBackend->Evaluate();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
30
src/Magpie.Core/OnnxEffectDrawer.h
Normal file
30
src/Magpie.Core/OnnxEffectDrawer.h
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
#pragma once
|
||||
|
||||
namespace Magpie::Core {
|
||||
|
||||
class DeviceResources;
|
||||
class EffectsProfiler;
|
||||
class InferenceBackendBase;
|
||||
class BackendDescriptorStore;
|
||||
|
||||
class OnnxEffectDrawer {
|
||||
public:
|
||||
OnnxEffectDrawer();
|
||||
OnnxEffectDrawer(const OnnxEffectDrawer&) = delete;
|
||||
OnnxEffectDrawer(OnnxEffectDrawer&&) = default;
|
||||
|
||||
~OnnxEffectDrawer();
|
||||
|
||||
bool Initialize(
|
||||
DeviceResources& deviceResources,
|
||||
BackendDescriptorStore& descriptorStore,
|
||||
ID3D11Texture2D** inOutTexture
|
||||
) noexcept;
|
||||
|
||||
void Draw(EffectsProfiler& profiler) const noexcept;
|
||||
|
||||
private:
|
||||
std::unique_ptr<InferenceBackendBase> _inferenceBackend;
|
||||
};
|
||||
|
||||
}
|
||||
25
src/Magpie.Core/OnnxHelper.h
Normal file
25
src/Magpie.Core/OnnxHelper.h
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#pragma once
|
||||
#include "pch.h"
|
||||
#include <onnxruntime_cxx_api.h>
|
||||
|
||||
namespace Magpie::Core {
|
||||
|
||||
struct OnnxHelper {
|
||||
private:
|
||||
static void _CloseCUDAProviderOptions(OrtCUDAProviderOptionsV2* options) {
|
||||
Ort::GetApi().ReleaseCUDAProviderOptions(options);
|
||||
}
|
||||
|
||||
static void _CloseTensorRTProviderOptions(OrtTensorRTProviderOptionsV2* options) {
|
||||
Ort::GetApi().ReleaseTensorRTProviderOptions(options);
|
||||
}
|
||||
|
||||
public:
|
||||
using unique_cuda_provider_options = wil::unique_any<OrtCUDAProviderOptionsV2*,
|
||||
decltype(_CloseCUDAProviderOptions), _CloseCUDAProviderOptions>;
|
||||
|
||||
using unique_tensorrt_provider_options = wil::unique_any<OrtTensorRTProviderOptionsV2*,
|
||||
decltype(_CloseTensorRTProviderOptions), _CloseTensorRTProviderOptions>;
|
||||
};
|
||||
|
||||
}
|
||||
|
|
@ -511,9 +511,14 @@ ID3D11Texture2D* Renderer::_BuildEffects() noexcept {
|
|||
Logger::Get().Info(fmt::format("编译着色器总计用时 {} 毫秒", duration / 1000.0f));
|
||||
}
|
||||
|
||||
ID3D11Texture2D* inOutTexture = _frameSource->GetOutput();
|
||||
|
||||
if (!_onnxEffectDrawer.Initialize(_backendResources, _backendDescriptorStore, &inOutTexture)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
_effectDrawers.resize(effects.size());
|
||||
|
||||
ID3D11Texture2D* inOutTexture = _frameSource->GetOutput();
|
||||
for (uint32_t i = 0; i < effectCount; ++i) {
|
||||
if (!_effectDrawers[i].Initialize(
|
||||
effectDescs[i],
|
||||
|
|
@ -688,7 +693,10 @@ void Renderer::_BackendThreadProc() noexcept {
|
|||
waitingForStepTimer = false;
|
||||
}
|
||||
|
||||
const FrameSourceBase::UpdateState state = _frameSource->Update();
|
||||
FrameSourceBase::UpdateState state = _frameSource->Update();
|
||||
if (ScalingWindow::Get().Options().IsBenchmarkMode()) {
|
||||
state = FrameSourceBase::UpdateState::NewFrame;
|
||||
}
|
||||
_stepTimer.UpdateFPS(state == FrameSourceBase::UpdateState::NewFrame);
|
||||
|
||||
switch (state) {
|
||||
|
|
@ -815,6 +823,8 @@ void Renderer::_BackendRender(ID3D11Texture2D* effectsOutput) noexcept {
|
|||
|
||||
_effectsProfiler.OnBeginEffects(d3dDC);
|
||||
|
||||
_onnxEffectDrawer.Draw(_effectsProfiler);
|
||||
|
||||
for (const EffectDrawer& effectDrawer : _effectDrawers) {
|
||||
effectDrawer.Draw(_effectsProfiler);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
#include "DeviceResources.h"
|
||||
#include "BackendDescriptorStore.h"
|
||||
#include "EffectDrawer.h"
|
||||
#include "OnnxEffectDrawer.h"
|
||||
#include "Win32Utils.h"
|
||||
#include "CursorDrawer.h"
|
||||
#include "StepTimer.h"
|
||||
|
|
@ -101,6 +102,7 @@ private:
|
|||
Magpie::Core::BackendDescriptorStore _backendDescriptorStore;
|
||||
std::unique_ptr<FrameSourceBase> _frameSource;
|
||||
std::vector<EffectDrawer> _effectDrawers;
|
||||
OnnxEffectDrawer _onnxEffectDrawer;
|
||||
|
||||
StepTimer _stepTimer;
|
||||
EffectsProfiler _effectsProfiler;
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ struct ScalingFlags {
|
|||
// Magpie.Core 不负责启动 TouchHelper.exe,指定此标志会使 Magpie.Core 创建辅助窗口以拦截
|
||||
// 黑边上的触控输入
|
||||
static constexpr uint32_t IsTouchSupportEnabled = 1 << 17;
|
||||
static constexpr uint32_t BenchmarkMode = 1 << 18;
|
||||
};
|
||||
|
||||
enum class ScalingType {
|
||||
|
|
@ -83,6 +84,7 @@ enum class DuplicateFrameDetectionMode {
|
|||
struct ScalingOptions {
|
||||
DEFINE_FLAG_ACCESSOR(IsWindowResizingDisabled, ScalingFlags::DisableWindowResizing, flags)
|
||||
DEFINE_FLAG_ACCESSOR(IsDebugMode, ScalingFlags::BreakpointMode, flags)
|
||||
DEFINE_FLAG_ACCESSOR(IsBenchmarkMode, ScalingFlags::BenchmarkMode, flags)
|
||||
DEFINE_FLAG_ACCESSOR(IsEffectCacheDisabled, ScalingFlags::DisableEffectCache, flags)
|
||||
DEFINE_FLAG_ACCESSOR(IsFontCacheDisabled, ScalingFlags::DisableFontCache, flags)
|
||||
DEFINE_FLAG_ACCESSOR(IsSaveEffectSources, ScalingFlags::SaveEffectSources, flags)
|
||||
|
|
|
|||
600
src/Magpie.Core/TensorRTInferenceBackend.cpp
Normal file
600
src/Magpie.Core/TensorRTInferenceBackend.cpp
Normal file
|
|
@ -0,0 +1,600 @@
|
|||
#include "pch.h"
|
||||
#include "TensorRTInferenceBackend.h"
|
||||
|
||||
#ifdef _M_X64
|
||||
|
||||
#include "DeviceResources.h"
|
||||
#include <cuda/cuda_d3d11_interop.h>
|
||||
#include "shaders/TextureToTensorCS.h"
|
||||
#include "shaders/TensorToTextureCS.h"
|
||||
#include "BackendDescriptorStore.h"
|
||||
#include "Logger.h"
|
||||
#include "DirectXHelper.h"
|
||||
#include "Utils.h"
|
||||
#include "OnnxHelper.h"
|
||||
#include "HashHelper.h"
|
||||
#include "Win32Utils.h"
|
||||
#include "StrUtils.h"
|
||||
#include "CommonSharedConstants.h"
|
||||
|
||||
namespace Magpie::Core {
|
||||
|
||||
static void LogCudaError(std::string_view msg, cudaError_t cudaResult) noexcept {
|
||||
Logger::Get().Error(fmt::format("{}\n\tCUDA error code: {}", msg, (int)cudaResult));
|
||||
}
|
||||
|
||||
static bool CheckComputeCapability(int deviceId) noexcept {
|
||||
int major, minor;
|
||||
|
||||
cudaError_t cudaResult = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, deviceId);
|
||||
if (cudaResult != cudaError_t::cudaSuccess) {
|
||||
Logger::Get().Error("cudaDeviceGetAttribute 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaResult = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, deviceId);
|
||||
if (cudaResult != cudaError_t::cudaSuccess) {
|
||||
Logger::Get().Error("cudaDeviceGetAttribute 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
Logger::Get().Info(fmt::format("当前设备 Compute Capability: {}.{}", major, minor));
|
||||
|
||||
// TensorRT 要求 Compute Capability 至少为 7.5
|
||||
// https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html
|
||||
if (std::make_pair(major, minor) < std::make_pair(7, 5)) {
|
||||
Logger::Get().Error("当前设备无法使用 TensorRT");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::wstring GetCacheDir(
|
||||
const std::vector<uint8_t>& modelData,
|
||||
IDXGIAdapter4* adapter,
|
||||
std::pair<uint16_t, uint16_t> minShapes,
|
||||
std::pair<uint16_t, uint16_t> maxShapes,
|
||||
std::pair<uint16_t, uint16_t> optShapes,
|
||||
uint8_t optimizationLevel,
|
||||
bool enableFP16
|
||||
) noexcept {
|
||||
DXGI_ADAPTER_DESC desc;
|
||||
adapter->GetDesc(&desc);
|
||||
|
||||
// TensorRT 缓存和多种因素绑定,这里考虑的因素有:
|
||||
// * 模型哈希
|
||||
// * ONNX Runtime 版本
|
||||
// * TensorRT 版本
|
||||
// * 显卡型号 (替代 Compute Capability)
|
||||
// * 配置文件
|
||||
// * 优化等级
|
||||
// * 是否启用半精度
|
||||
std::string str = fmt::format(
|
||||
"modelHash:{}\nortVersion:{}\nvendorId:{}\ndeviceId:{}\nminShapes:{},{}\nmaxShapes:{},{}\noptShapes:{},{}\noptLevel:{}\nfp16:{}",
|
||||
Utils::HashData(modelData), Ort::GetVersionString(), desc.VendorId, desc.DeviceId,
|
||||
minShapes.first, minShapes.second, maxShapes.first, maxShapes.second, optShapes.first,
|
||||
optShapes.second, optimizationLevel, enableFP16);
|
||||
|
||||
std::wstring strHash = HashHelper::HexHash(std::span((const BYTE*)str.data(), str.size()));
|
||||
return StrUtils::Concat(CommonSharedConstants::CACHE_DIR, L"tensorrt\\", strHash);
|
||||
}
|
||||
|
||||
static void* ShareBufferWithCuda(
|
||||
const winrt::com_ptr<ID3D11Buffer>& buffer,
|
||||
uint32_t bufferSize,
|
||||
cudaExternalMemory_t* bufferCudaMem,
|
||||
cudaExternalSemaphore_t* bufferCudaSem
|
||||
) noexcept {
|
||||
winrt::com_ptr<IDXGIResource> dxgiRes = buffer.try_as<IDXGIResource>();
|
||||
if (!dxgiRes) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
HANDLE sharedHandle = NULL;
|
||||
HRESULT hr = dxgiRes->GetSharedHandle(&sharedHandle);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("GetSharedHandle 失败", hr);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
cudaExternalMemoryHandleDesc externalMemoryHandleDesc{
|
||||
.type = cudaExternalMemoryHandleTypeD3D11ResourceKmt,
|
||||
.handle = {.win32 = {.handle = sharedHandle } },
|
||||
.size = bufferSize,
|
||||
.flags = cudaExternalMemoryDedicated
|
||||
};
|
||||
cudaError_t cudaResult = cudaImportExternalMemory(
|
||||
bufferCudaMem, &externalMemoryHandleDesc);
|
||||
if (cudaResult != cudaError_t::cudaSuccess) {
|
||||
LogCudaError("cudaImportExternalMemory 失败", cudaResult);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
cudaExternalSemaphoreHandleDesc extSemaDesc{
|
||||
.type = cudaExternalSemaphoreHandleTypeKeyedMutexKmt,
|
||||
.handle = {.win32 = {.handle = sharedHandle } },
|
||||
};
|
||||
cudaResult = cudaImportExternalSemaphore(bufferCudaSem, &extSemaDesc);
|
||||
if (cudaResult != cudaError_t::cudaSuccess) {
|
||||
LogCudaError("cudaImportExternalSemaphore 失败", cudaResult);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void* bufferCudaPtr = nullptr;
|
||||
cudaExternalMemoryBufferDesc externalMemoryBufferDesc{ .size = bufferSize };
|
||||
cudaResult = cudaExternalMemoryGetMappedBuffer(
|
||||
&bufferCudaPtr, *bufferCudaMem, &externalMemoryBufferDesc);
|
||||
if (cudaResult != cudaError_t::cudaSuccess) {
|
||||
LogCudaError("cudaExternalMemoryGetMappedBuffer 失败", cudaResult);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return bufferCudaPtr;
|
||||
}
|
||||
|
||||
TensorRTInferenceBackend::~TensorRTInferenceBackend() {
|
||||
if (_inputBufferCudaSem) {
|
||||
cudaDestroyExternalSemaphore((cudaExternalSemaphore_t)_inputBufferCudaSem);
|
||||
}
|
||||
if (_outputBufferCudaSem) {
|
||||
cudaDestroyExternalSemaphore((cudaExternalSemaphore_t)_outputBufferCudaSem);
|
||||
}
|
||||
if (_inputBufferCudaPtr) {
|
||||
cudaFree(_inputBufferCudaPtr);
|
||||
}
|
||||
if (_outputBufferCudaPtr) {
|
||||
cudaFree(_outputBufferCudaPtr);
|
||||
}
|
||||
if (_inputBufferCudaMem) {
|
||||
cudaDestroyExternalMemory((cudaExternalMemory_t)_inputBufferCudaMem);
|
||||
}
|
||||
if (_outputBufferCudaMem) {
|
||||
cudaDestroyExternalMemory((cudaExternalMemory_t)_outputBufferCudaMem);
|
||||
}
|
||||
}
|
||||
|
||||
bool TensorRTInferenceBackend::Initialize(
|
||||
const wchar_t* modelPath,
|
||||
uint32_t scale,
|
||||
DeviceResources& deviceResources,
|
||||
BackendDescriptorStore& descriptorStore,
|
||||
ID3D11Texture2D* input,
|
||||
ID3D11Texture2D** output
|
||||
) noexcept {
|
||||
if (!Win32Utils::FileExists(L"third_party\\onnxruntime_providers_tensorrt.dll")) {
|
||||
Logger::Get().Error("未安装 TensorRT 拓展");
|
||||
return false;
|
||||
}
|
||||
|
||||
int deviceId = 0;
|
||||
cudaError_t cudaResult = cudaD3D11GetDevice(&deviceId, deviceResources.GetGraphicsAdapter());
|
||||
if (cudaResult != cudaError_t::cudaSuccess) {
|
||||
LogCudaError("cudaD3D11GetDevice 失败", cudaResult);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!CheckComputeCapability(deviceId)) {
|
||||
Logger::Get().Error("CheckComputeCapability 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
cudaResult = cudaSetDevice(deviceId);
|
||||
if (cudaResult != cudaError_t::cudaSuccess) {
|
||||
LogCudaError("cudaSetDevice 失败", cudaResult);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool isFP16Data = false;
|
||||
try {
|
||||
const OrtApi& ortApi = Ort::GetApi();
|
||||
|
||||
_env = Ort::Env(ORT_LOGGING_LEVEL_INFO, "", _OrtLog, nullptr);
|
||||
|
||||
Ort::SessionOptions sessionOptions;
|
||||
sessionOptions.SetIntraOpNumThreads(1);
|
||||
|
||||
Ort::ThrowOnError(ortApi.AddFreeDimensionOverride(sessionOptions, "DATA_BATCH", 1));
|
||||
|
||||
if (!_CreateSession(deviceResources, deviceId, sessionOptions, modelPath)) {
|
||||
Logger::Get().Error("_CreateSession 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!_IsModelValid(_session, isFP16Data)) {
|
||||
Logger::Get().Error("不支持此模型");
|
||||
return false;
|
||||
}
|
||||
|
||||
_cudaMemInfo = Ort::MemoryInfo("Cuda", OrtAllocatorType::OrtDeviceAllocator, deviceId, OrtMemTypeDefault);
|
||||
} catch (const Ort::Exception& e) {
|
||||
Logger::Get().Error(e.what());
|
||||
return false;
|
||||
}
|
||||
|
||||
ID3D11Device5* d3dDevice = deviceResources.GetD3DDevice();
|
||||
_d3dDC = deviceResources.GetD3DDC();
|
||||
|
||||
const SIZE inputSize = DirectXHelper::GetTextureSize(input);
|
||||
const SIZE outputSize = SIZE{ inputSize.cx * (LONG)scale, inputSize.cy * (LONG)scale };
|
||||
|
||||
// 创建输出纹理
|
||||
winrt::com_ptr<ID3D11Texture2D> outputTex = DirectXHelper::CreateTexture2D(
|
||||
d3dDevice,
|
||||
DXGI_FORMAT_R8G8B8A8_UNORM,
|
||||
outputSize.cx,
|
||||
outputSize.cy,
|
||||
D3D11_BIND_SHADER_RESOURCE | D3D11_BIND_UNORDERED_ACCESS
|
||||
);
|
||||
if (!outputTex) {
|
||||
Logger::Get().Error("创建输出纹理失败");
|
||||
return false;
|
||||
}
|
||||
*output = outputTex.get();
|
||||
|
||||
const uint32_t inputElemCount = uint32_t(inputSize.cx * inputSize.cy * 3);
|
||||
const uint32_t outputElemCount = uint32_t(outputSize.cx * outputSize.cy * 3);
|
||||
const uint32_t inputBufferSize = isFP16Data ? ((inputElemCount + 1) / 2 * 4) : (inputElemCount * 4);
|
||||
const uint32_t outputBufferSize = isFP16Data ? ((outputElemCount + 1) / 2 * 4) : (outputElemCount * 4);
|
||||
|
||||
winrt::com_ptr<ID3D11Buffer> inputBuffer;
|
||||
winrt::com_ptr<ID3D11Buffer> outputBuffer;
|
||||
{
|
||||
D3D11_BUFFER_DESC desc{
|
||||
.ByteWidth = inputBufferSize,
|
||||
.BindFlags = D3D11_BIND_UNORDERED_ACCESS,
|
||||
.MiscFlags = D3D11_RESOURCE_MISC_SHARED_KEYEDMUTEX
|
||||
};
|
||||
HRESULT hr = d3dDevice->CreateBuffer(&desc, nullptr, inputBuffer.put());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateBuffer 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
desc.ByteWidth = outputBufferSize;
|
||||
desc.BindFlags = D3D11_BIND_SHADER_RESOURCE;
|
||||
hr = d3dDevice->CreateBuffer(&desc, nullptr, outputBuffer.put());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateBuffer 失败", hr);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
_inputBufferCudaPtr = ShareBufferWithCuda(
|
||||
inputBuffer,
|
||||
inputBufferSize,
|
||||
(cudaExternalMemory_t*)&_inputBufferCudaMem,
|
||||
(cudaExternalSemaphore_t*)&_inputBufferCudaSem
|
||||
);
|
||||
_outputBufferCudaPtr = ShareBufferWithCuda(
|
||||
outputBuffer,
|
||||
outputBufferSize,
|
||||
(cudaExternalMemory_t*)&_outputBufferCudaMem,
|
||||
(cudaExternalSemaphore_t*)&_outputBufferCudaSem
|
||||
);
|
||||
if (!_inputBufferCudaPtr || !_outputBufferCudaPtr) {
|
||||
Logger::Get().Error("ShareBufferWithCuda 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
_ioBinding = Ort::IoBinding(_session);
|
||||
|
||||
const int64_t inputShape[]{ 1,3,inputSize.cy,inputSize.cx };
|
||||
_ioBinding.BindInput("input", Ort::Value::CreateTensor(
|
||||
_cudaMemInfo,
|
||||
_inputBufferCudaPtr,
|
||||
inputBufferSize,
|
||||
inputShape,
|
||||
std::size(inputShape),
|
||||
isFP16Data ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
|
||||
));
|
||||
|
||||
const int64_t outputShape[]{ 1,3,outputSize.cy,outputSize.cx };
|
||||
_ioBinding.BindOutput("output", Ort::Value::CreateTensor(
|
||||
_cudaMemInfo,
|
||||
_outputBufferCudaPtr,
|
||||
outputBufferSize,
|
||||
outputShape,
|
||||
std::size(outputShape),
|
||||
isFP16Data ? ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 : ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT
|
||||
));
|
||||
} catch (const Ort::Exception& e) {
|
||||
Logger::Get().Error(e.what());
|
||||
return false;
|
||||
}
|
||||
|
||||
_inputBufferKmt = inputBuffer.try_as<IDXGIKeyedMutex>();
|
||||
if (!_inputBufferKmt) {
|
||||
return false;
|
||||
}
|
||||
|
||||
_outputBufferKmt = outputBuffer.try_as<IDXGIKeyedMutex>();
|
||||
if (!_outputBufferKmt) {
|
||||
return false;
|
||||
}
|
||||
|
||||
_inputTexSrv = descriptorStore.GetShaderResourceView(input);
|
||||
if (!_inputTexSrv) {
|
||||
Logger::Get().Error("GetShaderResourceView 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
_sampler = deviceResources.GetSampler(
|
||||
D3D11_FILTER_MIN_MAG_MIP_POINT, D3D11_TEXTURE_ADDRESS_CLAMP);
|
||||
if (!_sampler) {
|
||||
Logger::Get().Error("GetSampler 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
{
|
||||
D3D11_UNORDERED_ACCESS_VIEW_DESC desc{
|
||||
.Format = isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
|
||||
.ViewDimension = D3D11_UAV_DIMENSION_BUFFER,
|
||||
.Buffer{
|
||||
.NumElements = inputElemCount
|
||||
}
|
||||
};
|
||||
|
||||
HRESULT hr = d3dDevice->CreateUnorderedAccessView(
|
||||
inputBuffer.get(), &desc, _inputBufferUav.put());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateUnorderedAccessView 失败", hr);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
D3D11_SHADER_RESOURCE_VIEW_DESC desc{
|
||||
.Format = isFP16Data ? DXGI_FORMAT_R16_FLOAT : DXGI_FORMAT_R32_FLOAT,
|
||||
.ViewDimension = D3D11_SRV_DIMENSION_BUFFER,
|
||||
.Buffer{
|
||||
.NumElements = outputElemCount
|
||||
}
|
||||
};
|
||||
|
||||
HRESULT hr = d3dDevice->CreateShaderResourceView(
|
||||
outputBuffer.get(), &desc, _outputBufferSrv.put());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateShaderResourceView 失败", hr);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
D3D11_UNORDERED_ACCESS_VIEW_DESC desc{
|
||||
.ViewDimension = D3D11_UAV_DIMENSION_TEXTURE2D
|
||||
};
|
||||
HRESULT hr = d3dDevice->CreateUnorderedAccessView(
|
||||
outputTex.get(), &desc, _outputTexUav.put());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateUnorderedAccessView 失败", hr);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
HRESULT hr = d3dDevice->CreateComputeShader(
|
||||
TextureToTensorCS, sizeof(TextureToTensorCS), nullptr, _texToTensorShader.put());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateComputeShader 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
hr = d3dDevice->CreateComputeShader(
|
||||
TensorToTextureCS, sizeof(TensorToTextureCS), nullptr, _tensorToTexShader.put());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateComputeShader 失败", hr);
|
||||
return false;
|
||||
}
|
||||
|
||||
static constexpr std::pair<uint32_t, uint32_t> TEX_TO_TENSOR_BLOCK_SIZE{ 16, 16 };
|
||||
static constexpr std::pair<uint32_t, uint32_t> TENSOR_TO_TEX_BLOCK_SIZE{ 8, 8 };
|
||||
_texToTensorDispatchCount = {
|
||||
(inputSize.cx + TEX_TO_TENSOR_BLOCK_SIZE.first - 1) / TEX_TO_TENSOR_BLOCK_SIZE.first,
|
||||
(inputSize.cy + TEX_TO_TENSOR_BLOCK_SIZE.second - 1) / TEX_TO_TENSOR_BLOCK_SIZE.second
|
||||
};
|
||||
_tensorToTexDispatchCount = {
|
||||
(outputSize.cx + TENSOR_TO_TEX_BLOCK_SIZE.first - 1) / TENSOR_TO_TEX_BLOCK_SIZE.first,
|
||||
(outputSize.cy + TENSOR_TO_TEX_BLOCK_SIZE.second - 1) / TENSOR_TO_TEX_BLOCK_SIZE.second
|
||||
};
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void TensorRTInferenceBackend::Evaluate() noexcept {
|
||||
// 输入纹理 -> 输入张量
|
||||
HRESULT hr = _inputBufferKmt->AcquireSync(_inputBufferMutexKey, INFINITE);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("AcquireSync 失败", hr);
|
||||
return;
|
||||
}
|
||||
|
||||
_d3dDC->CSSetShaderResources(0, 1, &_inputTexSrv);
|
||||
_d3dDC->CSSetSamplers(0, 1, &_sampler);
|
||||
{
|
||||
ID3D11UnorderedAccessView* uav = _inputBufferUav.get();
|
||||
_d3dDC->CSSetUnorderedAccessViews(0, 1, &uav, nullptr);
|
||||
}
|
||||
|
||||
_d3dDC->CSSetShader(_texToTensorShader.get(), nullptr, 0);
|
||||
_d3dDC->Dispatch(_texToTensorDispatchCount.first, _texToTensorDispatchCount.second, 1);
|
||||
|
||||
_inputBufferKmt->ReleaseSync(++_inputBufferMutexKey);
|
||||
|
||||
{
|
||||
cudaExternalSemaphore_t semArr[] = {
|
||||
(cudaExternalSemaphore_t)_inputBufferCudaSem,
|
||||
(cudaExternalSemaphore_t)_outputBufferCudaSem
|
||||
};
|
||||
cudaExternalSemaphoreWaitParams extSemWaitParamsArr[] = {
|
||||
{.params{.keyedMutex{.key = _inputBufferMutexKey, .timeoutMs = INFINITE}}},
|
||||
{.params{.keyedMutex{.key = _outputBufferMutexKey, .timeoutMs = INFINITE}}}
|
||||
};
|
||||
cudaError_t cudaResult = cudaWaitExternalSemaphoresAsync(semArr, extSemWaitParamsArr, 2);
|
||||
if (cudaResult != cudaError_t::cudaSuccess) {
|
||||
LogCudaError("cudaWaitExternalSemaphoresAsync 失败", cudaResult);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
Ort::RunOptions runOptions;
|
||||
runOptions.AddConfigEntry("disable_synchronize_execution_providers", "1");
|
||||
_session.Run(runOptions, _ioBinding);
|
||||
} catch (const Ort::Exception& e) {
|
||||
Logger::Get().Error(e.what());
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
cudaExternalSemaphore_t semArr[] = {
|
||||
(cudaExternalSemaphore_t)_inputBufferCudaSem,
|
||||
(cudaExternalSemaphore_t)_outputBufferCudaSem
|
||||
};
|
||||
|
||||
cudaExternalSemaphoreSignalParams extSemSigParams[] = {
|
||||
{.params = {.keyedMutex = {.key = ++_inputBufferMutexKey}}},
|
||||
{.params = {.keyedMutex = {.key = ++_outputBufferMutexKey}}}
|
||||
};
|
||||
cudaError_t cudaResult = cudaSignalExternalSemaphoresAsync(semArr, extSemSigParams, 2);
|
||||
if (cudaResult != cudaError_t::cudaSuccess) {
|
||||
LogCudaError("cudaSignalExternalSemaphoresAsync 失败", cudaResult);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// 输出张量 -> 输出纹理
|
||||
hr = _outputBufferKmt->AcquireSync(_outputBufferMutexKey, INFINITE);
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("AcquireSync 失败", hr);
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
ID3D11ShaderResourceView* srv = _outputBufferSrv.get();
|
||||
_d3dDC->CSSetShaderResources(0, 1, &srv);
|
||||
}
|
||||
{
|
||||
ID3D11UnorderedAccessView* uav = _outputTexUav.get();
|
||||
_d3dDC->CSSetUnorderedAccessViews(0, 1, &uav, nullptr);
|
||||
}
|
||||
|
||||
_d3dDC->CSSetShader(_tensorToTexShader.get(), nullptr, 0);
|
||||
_d3dDC->Dispatch(_tensorToTexDispatchCount.first, _tensorToTexDispatchCount.second, 1);
|
||||
|
||||
{
|
||||
ID3D11ShaderResourceView* srv = nullptr;
|
||||
_d3dDC->CSSetShaderResources(0, 1, &srv);
|
||||
}
|
||||
{
|
||||
ID3D11UnorderedAccessView* uav = nullptr;
|
||||
_d3dDC->CSSetUnorderedAccessViews(0, 1, &uav, nullptr);
|
||||
}
|
||||
|
||||
_outputBufferKmt->ReleaseSync(++_outputBufferMutexKey);
|
||||
}
|
||||
|
||||
bool TensorRTInferenceBackend::_CreateSession(
|
||||
DeviceResources& deviceResources,
|
||||
int deviceId,
|
||||
Ort::SessionOptions& sessionOptions,
|
||||
const wchar_t* modelPath
|
||||
) {
|
||||
const std::pair<uint16_t, uint16_t> minShapes(uint16_t(1), uint16_t(1));
|
||||
const std::pair<uint16_t, uint16_t> maxShapes(uint16_t(1920), uint16_t(1080));
|
||||
const std::pair<uint16_t, uint16_t> optShapes(uint16_t(1920), uint16_t(1080));
|
||||
|
||||
const bool enableFP16 = true;
|
||||
const uint8_t optimizationLevel = 5;
|
||||
|
||||
std::vector<uint8_t> modelData;
|
||||
if (!Win32Utils::ReadFile(modelPath, modelData)) {
|
||||
Logger::Get().Error("读取模型失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::wstring cacheDir = GetCacheDir(
|
||||
modelData,
|
||||
deviceResources.GetGraphicsAdapter(),
|
||||
minShapes,
|
||||
maxShapes,
|
||||
optShapes,
|
||||
optimizationLevel,
|
||||
enableFP16
|
||||
);
|
||||
if (!Win32Utils::CreateDir(cacheDir, true)) {
|
||||
Logger::Get().Win32Error("创建缓存文件夹失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::wstring cacheCtxPath = cacheDir + L"\\ctx.onnx";
|
||||
|
||||
const OrtApi& ortApi = Ort::GetApi();
|
||||
|
||||
OnnxHelper::unique_tensorrt_provider_options trtOptions;
|
||||
Ort::ThrowOnError(ortApi.CreateTensorRTProviderOptions(trtOptions.put()));
|
||||
|
||||
const std::string deviceIdStr = std::to_string(deviceId);
|
||||
{
|
||||
const char* keys[]{
|
||||
"device_id",
|
||||
"has_user_compute_stream",
|
||||
"trt_fp16_enable",
|
||||
"trt_builder_optimization_level",
|
||||
"trt_profile_min_shapes",
|
||||
"trt_profile_max_shapes",
|
||||
"trt_profile_opt_shapes",
|
||||
"trt_engine_cache_enable",
|
||||
"trt_engine_cache_prefix",
|
||||
"trt_dump_ep_context_model",
|
||||
"trt_ep_context_file_path"
|
||||
};
|
||||
std::string optLevelStr = std::to_string(optimizationLevel);
|
||||
std::string minShapesStr = fmt::format("input:1x3x{}x{}", minShapes.second, minShapes.first);
|
||||
std::string maxShapesStr = fmt::format("input:1x3x{}x{}", maxShapes.second, maxShapes.first);
|
||||
std::string optShapesStr = fmt::format("input:1x3x{}x{}", optShapes.second, optShapes.first);
|
||||
|
||||
std::string cacheDirANSI = StrUtils::UTF16ToANSI(cacheDir);
|
||||
std::string cacheCtxPathANSI = StrUtils::UTF16ToANSI(cacheCtxPath);
|
||||
|
||||
const char* values[]{
|
||||
deviceIdStr.c_str(),
|
||||
"1",
|
||||
enableFP16 ? "1" : "0",
|
||||
optLevelStr.c_str(),
|
||||
minShapesStr.c_str(),
|
||||
maxShapesStr.c_str(),
|
||||
optShapesStr.c_str(),
|
||||
"1",
|
||||
"trt",
|
||||
"1",
|
||||
cacheCtxPathANSI.c_str()
|
||||
};
|
||||
Ort::ThrowOnError(ortApi.UpdateTensorRTProviderOptions(trtOptions.get(), keys, values, std::size(keys)));
|
||||
}
|
||||
|
||||
OnnxHelper::unique_cuda_provider_options cudaOptions;
|
||||
Ort::ThrowOnError(ortApi.CreateCUDAProviderOptions(cudaOptions.put()));
|
||||
|
||||
{
|
||||
const char* keys[]{ "device_id", "has_user_compute_stream" };
|
||||
const char* values[]{ deviceIdStr.c_str(), "1" };
|
||||
Ort::ThrowOnError(ortApi.UpdateCUDAProviderOptions(cudaOptions.get(), keys, values, std::size(keys)));
|
||||
}
|
||||
|
||||
sessionOptions.AppendExecutionProvider_TensorRT_V2(*trtOptions.get());
|
||||
sessionOptions.AppendExecutionProvider_CUDA_V2(*cudaOptions.get());
|
||||
|
||||
if (Win32Utils::FileExists(cacheCtxPath.c_str())) {
|
||||
Logger::Get().Info("读取缓存 " + StrUtils::UTF16ToUTF8(cacheCtxPath));
|
||||
_session = Ort::Session(_env, cacheCtxPath.c_str(), sessionOptions);
|
||||
} else {
|
||||
_session = Ort::Session(_env, modelData.data(), modelData.size(), sessionOptions);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
76
src/Magpie.Core/TensorRTInferenceBackend.h
Normal file
76
src/Magpie.Core/TensorRTInferenceBackend.h
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
#pragma once
|
||||
#include "InferenceBackendBase.h"
|
||||
|
||||
#ifdef _M_X64
|
||||
|
||||
struct cudaGraphicsResource;
|
||||
|
||||
namespace Magpie::Core {
|
||||
|
||||
class TensorRTInferenceBackend : public InferenceBackendBase {
|
||||
public:
|
||||
TensorRTInferenceBackend() = default;
|
||||
TensorRTInferenceBackend(const TensorRTInferenceBackend&) = delete;
|
||||
TensorRTInferenceBackend(TensorRTInferenceBackend&&) = default;
|
||||
|
||||
virtual ~TensorRTInferenceBackend();
|
||||
|
||||
bool Initialize(
|
||||
const wchar_t* modelPath,
|
||||
uint32_t scale,
|
||||
DeviceResources& deviceResources,
|
||||
BackendDescriptorStore& descriptorStore,
|
||||
ID3D11Texture2D* input,
|
||||
ID3D11Texture2D** output
|
||||
) noexcept override;
|
||||
|
||||
void Evaluate() noexcept override;
|
||||
|
||||
private:
|
||||
bool _CreateSession(
|
||||
DeviceResources& deviceResources,
|
||||
int deviceId,
|
||||
Ort::SessionOptions& sessionOptions,
|
||||
const wchar_t* modelPath
|
||||
);
|
||||
|
||||
Ort::Env _env{ nullptr };
|
||||
Ort::Session _session{ nullptr };
|
||||
|
||||
ID3D11DeviceContext4* _d3dDC = nullptr;
|
||||
|
||||
ID3D11SamplerState* _sampler = nullptr;
|
||||
ID3D11ShaderResourceView* _inputTexSrv = nullptr;
|
||||
winrt::com_ptr<ID3D11UnorderedAccessView> _inputBufferUav;
|
||||
winrt::com_ptr<ID3D11ShaderResourceView> _outputBufferSrv;
|
||||
winrt::com_ptr<ID3D11UnorderedAccessView> _outputTexUav;
|
||||
|
||||
winrt::com_ptr<IDXGIKeyedMutex> _inputBufferKmt;
|
||||
winrt::com_ptr<IDXGIKeyedMutex> _outputBufferKmt;
|
||||
|
||||
UINT64 _inputBufferMutexKey = 0;
|
||||
UINT64 _outputBufferMutexKey = 0;
|
||||
|
||||
winrt::com_ptr<ID3D11ComputeShader> _texToTensorShader;
|
||||
winrt::com_ptr<ID3D11ComputeShader> _tensorToTexShader;
|
||||
|
||||
std::pair<uint32_t, uint32_t> _texToTensorDispatchCount{};
|
||||
std::pair<uint32_t, uint32_t> _tensorToTexDispatchCount{};
|
||||
|
||||
Ort::MemoryInfo _cudaMemInfo{ nullptr };
|
||||
|
||||
// cudaExternalMemory_t
|
||||
void* _inputBufferCudaMem = nullptr;
|
||||
void* _outputBufferCudaMem = nullptr;
|
||||
void* _inputBufferCudaPtr = nullptr;
|
||||
void* _outputBufferCudaPtr = nullptr;
|
||||
// cudaExternalSemaphore_t
|
||||
void* _inputBufferCudaSem = nullptr;
|
||||
void* _outputBufferCudaSem = nullptr;
|
||||
|
||||
Ort::IoBinding _ioBinding{ nullptr };
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<packages>
|
||||
<package id="Microsoft.Windows.CppWinRT" version="2.0.240405.15" targetFramework="native" />
|
||||
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.240122.1" targetFramework="native" />
|
||||
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.250325.1" targetFramework="native" />
|
||||
</packages>
|
||||
|
|
@ -12,7 +12,7 @@ void main(uint3 tid : SV_GroupThreadID, uint3 gid : SV_GroupID) {
|
|||
return;
|
||||
}
|
||||
|
||||
const int2 gxy = (gid.xy << 4) + (tid.xy << 1);
|
||||
const uint2 gxy = (gid.xy << 4) + (tid.xy << 1);
|
||||
|
||||
// 不知为何这比通过 cbuffer 传入更快
|
||||
uint width, height;
|
||||
|
|
|
|||
15
src/Magpie.Core/shaders/TensorToTextureCS.hlsl
Normal file
15
src/Magpie.Core/shaders/TensorToTextureCS.hlsl
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
Buffer<min16float> tensor : register(t0);
|
||||
RWTexture2D<min16float4> tex : register(u0);
|
||||
|
||||
[numthreads(8, 8, 1)]
|
||||
void main(uint3 tid : SV_GroupThreadID, uint3 gid : SV_GroupID) {
|
||||
const uint2 gxy = (gid.xy << 3) + tid.xy;
|
||||
|
||||
uint width, height;
|
||||
tex.GetDimensions(width, height);
|
||||
|
||||
const uint planeStride = width * height;
|
||||
const uint idx = gxy.y * width + gxy.x;
|
||||
min16float3 color = { tensor[idx], tensor[planeStride + idx], tensor[planeStride * 2 + idx] };
|
||||
tex[gxy] = min16float4(color, 1);
|
||||
}
|
||||
54
src/Magpie.Core/shaders/TextureToTensorCS.hlsl
Normal file
54
src/Magpie.Core/shaders/TextureToTensorCS.hlsl
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
Texture2D<min16float4> tex : register(t0);
|
||||
RWBuffer<min16float> result : register(u0);
|
||||
|
||||
SamplerState sam : register(s0);
|
||||
|
||||
[numthreads(8, 8, 1)]
|
||||
void main(uint3 tid : SV_GroupThreadID, uint3 gid : SV_GroupID) {
|
||||
const uint2 gxy = (gid.xy << 4) + (tid.xy << 1);
|
||||
|
||||
uint width, height;
|
||||
tex.GetDimensions(width, height);
|
||||
|
||||
if (gxy.x >= width || gxy.y >= height) {
|
||||
return;
|
||||
}
|
||||
|
||||
const float2 pos = (gxy + 1) / float2(width, height);
|
||||
|
||||
min16float4 red = tex.GatherRed(sam, pos);
|
||||
min16float4 green = tex.GatherGreen(sam, pos);
|
||||
min16float4 blue = tex.GatherBlue(sam, pos);
|
||||
|
||||
const uint planeStride = width * height;
|
||||
const uint planeStride2 = width * height * 2;
|
||||
|
||||
// w z
|
||||
// x y
|
||||
uint idx = gxy.y * width + gxy.x;
|
||||
|
||||
result[idx] = red.w;
|
||||
result[idx + planeStride] = green.w;
|
||||
result[idx + planeStride2] = blue.w;
|
||||
|
||||
const bool zyValid = gxy.x + 1 < width;
|
||||
if (zyValid) {
|
||||
result[idx + 1] = red.z;
|
||||
result[idx + planeStride + 1] = green.z;
|
||||
result[idx + planeStride2 + 1] = blue.z;
|
||||
}
|
||||
|
||||
idx += width;
|
||||
|
||||
if (gxy.y + 1 < height) {
|
||||
result[idx] = red.x;
|
||||
result[idx + planeStride] = green.x;
|
||||
result[idx + planeStride2] = blue.x;
|
||||
|
||||
if (zyValid) {
|
||||
result[idx + 1] = red.y;
|
||||
result[idx + planeStride + 1] = green.y;
|
||||
result[idx + planeStride2 + 1] = blue.y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
|
||||
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props')" />
|
||||
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props')" />
|
||||
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props')" />
|
||||
<PropertyGroup Label="Globals">
|
||||
<VCProjectVersion>16.0</VCProjectVersion>
|
||||
<Keyword>Win32Proj</Keyword>
|
||||
|
|
@ -112,20 +112,20 @@
|
|||
</ItemGroup>
|
||||
</Target>
|
||||
<ImportGroup Label="ExtensionTargets">
|
||||
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets" Condition="Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets" Condition="Exists('..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets" Condition="Exists('..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets" Condition="Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
|
||||
</ImportGroup>
|
||||
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
|
||||
<PropertyGroup>
|
||||
<ErrorText>这台计算机上缺少此项目引用的 NuGet 程序包。使用“NuGet 程序包还原”可下载这些程序包。有关更多信息,请参见 http://go.microsoft.com/fwlink/?LinkID=322105。缺少的文件是 {0}。</ErrorText>
|
||||
</PropertyGroup>
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.props'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.6\build\native\Microsoft.UI.Xaml.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.props'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.CppWinRT.2.0.240405.15\build\native\Microsoft.Windows.CppWinRT.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Web.WebView2.1.0.2535.41\build\native\Microsoft.Web.WebView2.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Web.WebView2.1.0.3179.45\build\native\Microsoft.Web.WebView2.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.props'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.UI.Xaml.2.8.7\build\native\Microsoft.UI.Xaml.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
|
||||
</Target>
|
||||
</Project>
|
||||
|
|
@ -172,9 +172,8 @@ bool TouchHelper::Register() noexcept {
|
|||
}
|
||||
|
||||
std::wstring magpieDir = StrUtils::Concat(system32Dir.get(), L"\\Magpie");
|
||||
hr = wil::CreateDirectoryDeepNoThrow(magpieDir.c_str());
|
||||
if (FAILED(hr)) {
|
||||
Logger::Get().ComError("CreateDirectoryDeepNoThrow 失败", hr);
|
||||
if (!CreateDirectory(magpieDir.c_str(), nullptr)) {
|
||||
Logger::Get().Win32Error("CreateDirectory 失败");
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
[requires]
|
||||
fmt/10.2.1
|
||||
spdlog/1.14.1
|
||||
parallel-hashmap/1.37
|
||||
fmt/11.1.3
|
||||
spdlog/1.15.1
|
||||
parallel-hashmap/2.0.0
|
||||
|
||||
[generators]
|
||||
MSBuildDeps
|
||||
|
|
|
|||
|
|
@ -19,9 +19,10 @@
|
|||
#include "Win32Utils.h"
|
||||
#include "TouchHelper.h"
|
||||
#include "CommonSharedConstants.h"
|
||||
#include "StrUtils.h"
|
||||
|
||||
// 将当前目录设为程序所在目录
|
||||
static void SetWorkingDir() noexcept {
|
||||
static std::wstring SetWorkingDir() noexcept {
|
||||
std::wstring path = Win32Utils::GetExePath();
|
||||
|
||||
FAIL_FAST_IF_FAILED(PathCchRemoveFileSpec(
|
||||
|
|
@ -30,6 +31,9 @@ static void SetWorkingDir() noexcept {
|
|||
));
|
||||
|
||||
FAIL_FAST_IF_WIN32_BOOL_FALSE(SetCurrentDirectory(path.c_str()));
|
||||
|
||||
path.resize(StrUtils::StrLen(path.c_str()));
|
||||
return path;
|
||||
}
|
||||
|
||||
static void InitializeLogger(const char* logFilePath) noexcept {
|
||||
|
|
@ -54,7 +58,7 @@ int APIENTRY wWinMain(
|
|||
// 堆损坏时终止进程
|
||||
HeapSetInformation(NULL, HeapEnableTerminationOnCorruption, nullptr, 0);
|
||||
|
||||
SetWorkingDir();
|
||||
std::wstring workingDir = SetWorkingDir();
|
||||
|
||||
enum {
|
||||
Normal,
|
||||
|
|
@ -90,6 +94,10 @@ int APIENTRY wWinMain(
|
|||
return Magpie::TouchHelper::Unregister() ? 0 : 1;
|
||||
}
|
||||
|
||||
SetDefaultDllDirectories(LOAD_LIBRARY_SEARCH_DEFAULT_DIRS);
|
||||
workingDir += L"\\third_party";
|
||||
AddDllDirectory(workingDir.c_str());
|
||||
|
||||
auto& app = Magpie::XamlApp::Get();
|
||||
if (!app.Initialize(hInstance, lpCmdLine)) {
|
||||
return 0;
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<packages>
|
||||
<package id="Microsoft.UI.Xaml" version="2.8.6" targetFramework="native" />
|
||||
<package id="Microsoft.Web.WebView2" version="1.0.2535.41" targetFramework="native" />
|
||||
<package id="Microsoft.UI.Xaml" version="2.8.7" targetFramework="native" />
|
||||
<package id="Microsoft.Web.WebView2" version="1.0.3179.45" targetFramework="native" />
|
||||
<package id="Microsoft.Windows.CppWinRT" version="2.0.240405.15" targetFramework="native" />
|
||||
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.240122.1" targetFramework="native" />
|
||||
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.250325.1" targetFramework="native" />
|
||||
</packages>
|
||||
|
|
@ -232,6 +232,36 @@ bool Win32Utils::WriteTextFile(const wchar_t* fileName, std::string_view text) n
|
|||
return true;
|
||||
}
|
||||
|
||||
bool Win32Utils::CreateDir(const std::wstring& path, bool recursive) noexcept {
|
||||
assert(!path.empty());
|
||||
|
||||
if (DirExists(path.c_str())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!recursive) {
|
||||
return CreateDirectory(path.c_str(), nullptr);
|
||||
}
|
||||
|
||||
size_t searchOffset = 0;
|
||||
do {
|
||||
auto segPos = path.find_first_of(L'\\', searchOffset);
|
||||
if (segPos == std::wstring::npos) {
|
||||
// 没有分隔符则将整个路径视为文件夹
|
||||
segPos = path.size();
|
||||
}
|
||||
|
||||
std::wstring subdir = path.substr(0, segPos);
|
||||
if (!subdir.empty() && !DirExists(subdir.c_str()) && !CreateDirectory(subdir.c_str(), nullptr)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
searchOffset = segPos + 1;
|
||||
} while (searchOffset < path.size());
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
const Win32Utils::OSVersion& Win32Utils::GetOSVersion() noexcept {
|
||||
static OSVersion version = []() -> OSVersion {
|
||||
HMODULE hNtDll = GetModuleHandle(L"ntdll.dll");
|
||||
|
|
|
|||
|
|
@ -43,6 +43,9 @@ struct Win32Utils {
|
|||
return (attrs != INVALID_FILE_ATTRIBUTES) && (attrs & FILE_ATTRIBUTE_DIRECTORY);
|
||||
}
|
||||
|
||||
// 相比 wil::CreateDirectoryDeepNoThrow 支持相对路径而且更快
|
||||
static bool CreateDir(const std::wstring& path, bool recursive = false) noexcept;
|
||||
|
||||
struct OSVersion : Version {
|
||||
constexpr OSVersion() {}
|
||||
constexpr OSVersion(uint32_t major, uint32_t minor, uint32_t patch)
|
||||
|
|
|
|||
|
|
@ -53,19 +53,19 @@
|
|||
<ResourceCompile Include="TouchHelper.rc" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="packages.config" />
|
||||
<Manifest Include="app.manifest" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Manifest Include="app.manifest" />
|
||||
<None Include="packages.config" />
|
||||
</ItemGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
|
||||
<ImportGroup Label="ExtensionTargets">
|
||||
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
|
||||
</ImportGroup>
|
||||
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
|
||||
<PropertyGroup>
|
||||
<ErrorText>这台计算机上缺少此项目引用的 NuGet 程序包。使用“NuGet 程序包还原”可下载这些程序包。有关更多信息,请参见 http://go.microsoft.com/fwlink/?LinkID=322105。缺少的文件是 {0}。</ErrorText>
|
||||
</PropertyGroup>
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
|
||||
</Target>
|
||||
</Project>
|
||||
|
|
@ -20,13 +20,13 @@
|
|||
<ClCompile Include="main.cpp" />
|
||||
<ClCompile Include="pch.cpp" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="packages.config" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Natvis Include="$(MSBuildThisFileDirectory)..\..\natvis\wil.natvis" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Manifest Include="app.manifest" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="packages.config" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<packages>
|
||||
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.240122.1" targetFramework="native" />
|
||||
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.250325.1" targetFramework="native" />
|
||||
</packages>
|
||||
|
|
@ -67,12 +67,12 @@
|
|||
</ItemGroup>
|
||||
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
|
||||
<ImportGroup Label="ExtensionTargets">
|
||||
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
|
||||
<Import Project="..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets" Condition="Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" />
|
||||
</ImportGroup>
|
||||
<Target Name="EnsureNuGetPackageBuildImports" BeforeTargets="PrepareForBuild">
|
||||
<PropertyGroup>
|
||||
<ErrorText>这台计算机上缺少此项目引用的 NuGet 程序包。使用“NuGet 程序包还原”可下载这些程序包。有关更多信息,请参见 http://go.microsoft.com/fwlink/?LinkID=322105。缺少的文件是 {0}。</ErrorText>
|
||||
</PropertyGroup>
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.240122.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
|
||||
<Error Condition="!Exists('..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets')" Text="$([System.String]::Format('$(ErrorText)', '..\..\packages\Microsoft.Windows.ImplementationLibrary.1.0.250325.1\build\native\Microsoft.Windows.ImplementationLibrary.targets'))" />
|
||||
</Target>
|
||||
</Project>
|
||||
|
|
@ -33,10 +33,10 @@
|
|||
<ItemGroup>
|
||||
<Manifest Include="app.manifest" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="packages.config" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<Natvis Include="$(MSBuildThisFileDirectory)..\..\natvis\wil.natvis" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<None Include="packages.config" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<packages>
|
||||
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.240122.1" targetFramework="native" />
|
||||
<package id="Microsoft.Windows.ImplementationLibrary" version="1.0.250325.1" targetFramework="native" />
|
||||
</packages>
|
||||
Loading…
Add table
Add a link
Reference in a new issue