You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

484 lines
17 KiB

  1. // Copyright (c) 2018 Siegfried Pammer
  2. //
  3. // Permission is hereby granted, free of charge, to any person obtaining a copy of this
  4. // software and associated documentation files (the "Software"), to deal in the Software
  5. // without restriction, including without limitation the rights to use, copy, modify, merge,
  6. // publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
  7. // to whom the Software is furnished to do so, subject to the following conditions:
  8. //
  9. // The above copyright notice and this permission notice shall be included in all copies or
  10. // substantial portions of the Software.
  11. //
  12. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
  13. // INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
  14. // PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
  15. // FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
  16. // OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
  17. // DEALINGS IN THE SOFTWARE.
  18. using System;
  19. using System.Collections.Generic;
  20. using System.Diagnostics;
  21. using System.Linq;
  22. using ICSharpCode.Decompiler.IL.Transforms;
  23. using ICSharpCode.Decompiler.TypeSystem;
  namespace ICSharpCode.Decompiler.IL.ControlFlow
{
	/// <summary>
	/// Undoes the compiler lowering of catch blocks that contain <c>await</c>.
	/// </summary>
	/// <remarks>
	/// In the lowered form each such catch handler only stores the exception in a shared object
	/// variable and an integer id in an identifier variable, then branches to a jump table
	/// (an if-chain or a switch on the identifier variable) that dispatches to the real catch body.
	/// This transform moves the real catch body back into its handler, removes its jump-table entry
	/// and rewrites <c>ExceptionDispatchInfo.Capture(..).Throw()</c> into a rethrow.
	/// </remarks>
	class AwaitInCatchTransform
	{
		/// <summary>
		/// Result of matching one catch handler against the await-in-catch pattern.
		/// </summary>
		readonly struct CatchBlockInfo
		{
			/// <summary>Id the handler stores into the identifier variable.</summary>
			public readonly int Id;
			/// <summary>The catch handler whose body was lowered.</summary>
			public readonly TryCatchHandler Handler;
			/// <summary>Block the jump table dispatches to for <see cref="Id"/>: the real start of the catch body.</summary>
			public readonly Block RealCatchBlockEntryPoint;
			/// <summary>
			/// Where the jump table continues when the id does not match: a <see cref="Block"/> to branch to,
			/// or a <see cref="BlockContainer"/> to leave.
			/// </summary>
			public readonly ILInstruction NextBlockOrExitContainer;
			/// <summary>The jump-table entry for <see cref="Id"/>: an <see cref="IfInstruction"/> or a <see cref="SwitchSection"/>.</summary>
			public readonly ILInstruction JumpTableEntry;
			/// <summary>Variable holding the exception object outside the handler (the handler variable itself if no copy was found).</summary>
			public readonly ILVariable ObjectVariable;

			public CatchBlockInfo(int id, TryCatchHandler handler, Block realCatchBlockEntryPoint,
				ILInstruction nextBlockOrExitContainer, ILInstruction jumpTableEntry, ILVariable objectVariable)
			{
				Id = id;
				Handler = handler;
				RealCatchBlockEntryPoint = realCatchBlockEntryPoint;
				NextBlockOrExitContainer = nextBlockOrExitContainer;
				JumpTableEntry = jumpTableEntry;
				ObjectVariable = objectVariable;
			}
		}

		/// <summary>
		/// Applies the transform to every try-catch in <paramref name="function"/> whose parent block
		/// lives directly in a <see cref="BlockContainer"/>. Does nothing unless
		/// <c>context.Settings.AwaitInCatchFinally</c> is enabled.
		/// </summary>
		public static void Run(ILFunction function, ILTransformContext context)
		{
			if (!context.Settings.AwaitInCatchFinally)
				return;
			HashSet<BlockContainer> changedContainers = new HashSet<BlockContainer>();
			HashSet<Block> removedBlocks = new HashSet<Block>();
			// analyze all try-catch statements in the function
			// (snapshot with ToArray: the tree is modified inside the loop)
			foreach (var tryCatch in function.Descendants.OfType<TryCatch>().ToArray())
			{
				if (!(tryCatch.Parent?.Parent is BlockContainer container))
					continue;
				// Detect all handlers that contain an await expression
				AnalyzeHandlers(tryCatch.Handlers, out _, out var transformableCatchBlocks);
				var cfg = new ControlFlowGraph(container, context.CancellationToken);
				if (transformableCatchBlocks.Count > 0)
					changedContainers.Add(container);
				// switch-based entries are cleaned up once all catch blocks of this try-catch are done
				SwitchInstruction switchJumpTable = null;
				ILInstruction switchDefaultTarget = null;
				foreach (var result in transformableCatchBlocks)
				{
					removedBlocks.Clear();
					var node = cfg.GetNode(result.RealCatchBlockEntryPoint);
					var handlerBody = (BlockContainer)result.Handler.Body;
					context.StepStartGroup($"Inline catch block with await (at {result.Handler.Variable.Name})", result.Handler);
					switch (result.JumpTableEntry)
					{
						case IfInstruction ifEntry:
							// Remove the IfInstruction from the jump table and eliminate all branches to the block.
							var jumpTableBlock = (Block)ifEntry.Parent;
							context.Step("Remove jump-table entry", ifEntry);
							jumpTableBlock.Instructions.RemoveAt(ifEntry.ChildIndex);
							RedirectBranches(context, tryCatch, jumpTableBlock, result.NextBlockOrExitContainer);
							break;
						case SwitchSection switchEntry:
							// Rather than removing the section (which would leave its labels without a target),
							// route this id to the default target. This also drops the branch to the real
							// entry point, which is about to be moved into the handler body.
							Debug.Assert(switchJumpTable == null || switchEntry.Parent == switchJumpTable);
							switchJumpTable = (SwitchInstruction)switchEntry.Parent;
							switchDefaultTarget = result.NextBlockOrExitContainer;
							context.Step("Remove jump-table entry", switchEntry);
							switchEntry.Body.ReplaceWith(CreateJumpTo(result.NextBlockOrExitContainer));
							break;
					}
					// Add the real catch block entry-point to the block container
					var catchBlockHead = handlerBody.Blocks.Last();
					result.RealCatchBlockEntryPoint.Remove();
					handlerBody.Blocks.Insert(0, result.RealCatchBlockEntryPoint);
					// Remove the generated catch block
					catchBlockHead.Remove();
					TransformAsyncThrowToThrow(context, removedBlocks, result.RealCatchBlockEntryPoint);
					// Inline all blocks that are dominated by the entrypoint of the real catch block
					foreach (var n in cfg.cfg)
					{
						if (!node.Dominates(n))
							continue;
						Block block = (Block)n.UserData;
						TransformAsyncThrowToThrow(context, removedBlocks, block);
						if (block.Parent == handlerBody)
							continue;
						if (!removedBlocks.Contains(block))
						{
							context.Step("Move block", handlerBody);
							MoveBlock(block, handlerBody);
						}
					}
					// Remove unreachable pattern blocks
					// TODO : sanity check
					if (result.NextBlockOrExitContainer is Block nextBlock && nextBlock.IncomingEdgeCount == 0)
					{
						RemoveUnreachableBlockChain(nextBlock);
					}
					// Replace loads of the common object variable that stores the exception object.
					ReplaceObjectVariableLoads(result);
					context.StepEndGroup(keepIfEmpty: true);
				}
				// When every section of the switch now jumps to the default target, the switch block
				// (which holds only the switch; see ParseSwitchJumpTable) is a plain jump. Let branches
				// from inside the try-catch bypass it so it can become unreachable and be deleted below.
				if (switchJumpTable?.Parent is Block switchBlock && switchBlock.IncomingEdgeCount > 0
					&& switchDefaultTarget != null
					&& switchJumpTable.Sections.All(s => IsJumpTo(s.Body, switchDefaultTarget)))
				{
					RedirectBranches(context, tryCatch, switchBlock, switchDefaultTarget);
				}
			}
			// clean up all modified containers
			foreach (var container in changedContainers)
				container.SortBlocks(deleteUnreachableBlocks: true);
		}

		/// <summary>
		/// Replaces every branch inside <paramref name="tryCatch"/> that targets <paramref name="from"/>
		/// with a jump to <paramref name="target"/> (a leave if it is a <see cref="BlockContainer"/>,
		/// otherwise a branch to the <see cref="Block"/>).
		/// </summary>
		static void RedirectBranches(ILTransformContext context, TryCatch tryCatch, Block from, ILInstruction target)
		{
			// snapshot: branches are replaced while we walk the tree
			foreach (var branch in tryCatch.Descendants.OfType<Branch>().ToArray())
			{
				if (branch.TargetBlock != from)
					continue;
				if (target is BlockContainer exitContainer)
				{
					context.Step("branch jumpTableBlock => leave exitContainer", branch);
					branch.ReplaceWith(new Leave(exitContainer));
				}
				else
				{
					context.Step("branch jumpTableBlock => branch nextBlock", branch);
					branch.ReplaceWith(new Branch((Block)target));
				}
			}
		}

		/// <summary>
		/// Creates a new instruction transferring control to <paramref name="target"/>:
		/// a <see cref="Leave"/> for a <see cref="BlockContainer"/>, otherwise a <see cref="Branch"/> to the block.
		/// </summary>
		static ILInstruction CreateJumpTo(ILInstruction target)
		{
			if (target is BlockContainer exitContainer)
				return new Leave(exitContainer);
			return new Branch((Block)target);
		}

		/// <summary>
		/// Returns true if <paramref name="inst"/> is a jump to <paramref name="target"/>
		/// in the form produced by <see cref="CreateJumpTo"/>.
		/// </summary>
		static bool IsJumpTo(ILInstruction inst, ILInstruction target)
		{
			if (target is BlockContainer exitContainer)
				return inst.MatchLeave(exitContainer);
			return inst.MatchBranch(out var targetBlock) && targetBlock == target;
		}

		/// <summary>
		/// Removes <paramref name="start"/>, which must have no incoming edges, then repeatedly removes
		/// any block targeted by an already removed block that is left without incoming edges.
		/// </summary>
		static void RemoveUnreachableBlockChain(Block start)
		{
			// each candidate is recorded at most once, so no block is removed twice
			List<Block> dependentBlocks = new List<Block>();
			Block current = start;
			do
			{
				foreach (var branch in current.Descendants.OfType<Branch>())
				{
					if (!dependentBlocks.Contains(branch.TargetBlock))
						dependentBlocks.Add(branch.TargetBlock);
				}
				current.Remove();
				dependentBlocks.Remove(current);
				current = dependentBlocks.FirstOrDefault(b => b.IncomingEdgeCount == 0);
			} while (current != null);
		}

		/// <summary>
		/// Inside the handler, replaces loads of the shared object variable with loads of the handler's
		/// own exception variable. A <c>castclass</c> to the handler variable's type around such a load
		/// is replaced as a whole.
		/// </summary>
		static void ReplaceObjectVariableLoads(CatchBlockInfo result)
		{
			if (result.ObjectVariable == result.Handler.Variable)
				return;
			foreach (var load in result.ObjectVariable.LoadInstructions.ToArray())
			{
				if (!load.IsDescendantOf(result.Handler))
					continue;
				if (load.Parent is CastClass cc && cc.Type.Equals(result.Handler.Variable.Type))
				{
					cc.ReplaceWith(new LdLoc(result.Handler.Variable).WithILRange(cc).WithILRange(load));
				}
				else
				{
					load.ReplaceWith(new LdLoc(result.Handler.Variable).WithILRange(load));
				}
			}
		}

		/// <summary>
		/// If <paramref name="block"/> ends with the <c>ExceptionDispatchInfo.Capture(..).Throw()</c> pattern
		/// (see <see cref="MatchExceptionCaptureBlock"/>), replaces it with a rethrow, removes the capture and
		/// throw blocks, and records them in <paramref name="removedBlocks"/>.
		/// </summary>
		private static void TransformAsyncThrowToThrow(ILTransformContext context, HashSet<Block> removedBlocks, Block block)
		{
			ILVariable v = null;
			if (MatchExceptionCaptureBlock(context, block,
				ref v, out StLoc typedExceptionVariableStore,
				out Block captureBlock, out Block throwBlock))
			{
				context.Step($"ExceptionDispatchInfo.Capture({v.Name}).Throw() => throw;", typedExceptionVariableStore);
				// drop the trailing 'if (...) br captureBlock' and 'br throwBlock'
				block.Instructions.RemoveRange(typedExceptionVariableStore.ChildIndex + 1, 2);
				captureBlock.Remove();
				throwBlock.Remove();
				removedBlocks.Add(captureBlock);
				removedBlocks.Add(throwBlock);
				typedExceptionVariableStore.ReplaceWith(new Rethrow());
			}
		}

		/// <summary>
		/// Detaches <paramref name="block"/> from its container and appends it to <paramref name="target"/>.
		/// </summary>
		static void MoveBlock(Block block, BlockContainer target)
		{
			block.Remove();
			target.Blocks.Add(block);
		}

		/// <summary>
		/// Analyzes all catch handlers and returns every handler that follows the await catch handler pattern.
		/// Accepted handlers must use an id of at least 1 and share one identifier variable (that of the
		/// first accepted handler). Returns true if at least one handler was accepted.
		/// </summary>
		static bool AnalyzeHandlers(InstructionCollection<TryCatchHandler> handlers, out ILVariable catchHandlerIdentifier,
			out List<CatchBlockInfo> transformableCatchBlocks)
		{
			transformableCatchBlocks = new List<CatchBlockInfo>();
			catchHandlerIdentifier = null;
			foreach (var handler in handlers)
			{
				if (!MatchAwaitCatchHandler(handler, out int id, out var identifierVariable,
					out var realEntryPoint, out var nextBlockOrExitContainer, out var jumpTableEntry,
					out var objectVariable))
				{
					continue;
				}
				if (id < 1 || (catchHandlerIdentifier != null && identifierVariable != catchHandlerIdentifier))
				{
					continue;
				}
				catchHandlerIdentifier = identifierVariable;
				transformableCatchBlocks.Add(new(id, handler, realEntryPoint, nextBlockOrExitContainer, jumpTableEntry, objectVariable ?? handler.Variable));
			}
			return transformableCatchBlocks.Count > 0;
		}

		/// <summary>
		/// Matches the await catch handler pattern:
		/// [stloc V_3(ldloc E_100) - copy exception variable to a temporary]
		/// stloc V_6(ldloc V_3) - store exception in 'global' object variable
		/// stloc V_5(ldc.i4 2) - store id of catch block in 'identifierVariable'
		/// br IL_0075 - jump out of catch block to the head of the catch-handler jump table
		/// The jump table may be an if-chain on the identifier variable (optionally preceded by a copy
		/// of it) or a block consisting of a single switch on the identifier variable.
		/// </summary>
		static bool MatchAwaitCatchHandler(TryCatchHandler handler, out int id, out ILVariable identifierVariable,
			out Block realEntryPoint, out ILInstruction nextBlockOrExitContainer,
			out ILInstruction jumpTableEntry, out ILVariable objectVariable)
		{
			id = 0;
			identifierVariable = null;
			realEntryPoint = null;
			jumpTableEntry = null;
			objectVariable = null;
			nextBlockOrExitContainer = null;
			var exceptionVariable = handler.Variable;
			var catchBlock = ((BlockContainer)handler.Body).EntryPoint;
			ILInstruction value;
			switch (catchBlock.Instructions.Count)
			{
				case 3:
					if (!catchBlock.Instructions[0].MatchStLoc(out objectVariable, out value))
						return false;
					if (!value.MatchLdLoc(exceptionVariable))
						return false;
					break;
				case 4:
					if (!catchBlock.Instructions[0].MatchStLoc(out var temporaryVariable, out value))
						return false;
					if (!value.MatchLdLoc(exceptionVariable))
						return false;
					if (!catchBlock.Instructions[1].MatchStLoc(out objectVariable, out value))
						return false;
					if (!value.MatchLdLoc(temporaryVariable))
						return false;
					break;
				default:
					// if the exception variable is not used at all (e.g., catch (Exception))
					// the "exception-variable-assignment" is omitted completely.
					// This can happen in optimized code.
					break;
			}
			if (!catchBlock.Instructions.Last().MatchBranch(out var jumpTableStartBlock))
				return false;
			var identifierVariableAssignment = catchBlock.Instructions.SecondToLastOrDefault();
			// a catch block consisting of a single branch has no id assignment at all
			if (identifierVariableAssignment == null
				|| !identifierVariableAssignment.MatchStLoc(out identifierVariable, out value)
				|| !value.MatchLdcI4(out id))
			{
				return false;
			}
			// analyze jump table:
			switch (jumpTableStartBlock.Instructions.Count)
			{
				case 3:
					// stloc identifierVariableCopy(identifierVariable)
					// if (comp(identifierVariable == id)) br realEntryPoint
					// br jumpTableEntryBlock
					if (!jumpTableStartBlock.Instructions[0].MatchStLoc(out var identifierVariableCopy, out var identifierVariableLoad)
						|| !identifierVariableLoad.MatchLdLoc(identifierVariable))
					{
						return false;
					}
					return ParseIfJumpTable(id, jumpTableStartBlock, identifierVariableCopy, out realEntryPoint, out nextBlockOrExitContainer, out jumpTableEntry);
				case 2:
					// if (comp(identifierVariable == id)) br realEntryPoint
					// br jumpTableEntryBlock
					return ParseIfJumpTable(id, jumpTableStartBlock, identifierVariable, out realEntryPoint, out nextBlockOrExitContainer, out jumpTableEntry);
				case 1:
					if (jumpTableStartBlock.Instructions[0] is not SwitchInstruction switchInst)
					{
						return false;
					}
					return ParseSwitchJumpTable(id, switchInst, identifierVariable, out realEntryPoint, out nextBlockOrExitContainer, out jumpTableEntry);
				default:
					return false;
			}

			// Finds the section of a switch on identifierVariable whose labels contain id.
			// nextBlockOrExitContainer is the target of the default section, which must be a branch or leave.
			bool ParseSwitchJumpTable(int id, SwitchInstruction jumpTable, ILVariable identifierVariable, out Block realEntryPoint, out ILInstruction nextBlockOrExitContainer, out ILInstruction jumpTableEntry)
			{
				realEntryPoint = null;
				nextBlockOrExitContainer = null;
				jumpTableEntry = null;
				if (!jumpTable.Value.MatchLdLoc(identifierVariable))
					return false;
				var defaultSection = jumpTable.GetDefaultSection();
				foreach (var section in jumpTable.Sections)
				{
					if (!section.Labels.Contains(id))
						continue;
					if (!section.Body.MatchBranch(out realEntryPoint))
						return false;
					if (defaultSection.Body.MatchBranch(out var t))
						nextBlockOrExitContainer = t;
					else if (defaultSection.Body.MatchLeave(out var t2))
						nextBlockOrExitContainer = t2;
					else
						return false; // the fall-through target must be expressible as a branch or leave
					jumpTableEntry = section;
					return true;
				}
				return false;
			}

			// Walks an if-chain of blocks "if (comp(var ==/!= N)) ...; br/leave ..." looking for the entry for id.
			bool ParseIfJumpTable(int id, Block jumpTableEntryBlock, ILVariable identifierVariable, out Block realEntryPoint, out ILInstruction nextBlockOrExitContainer, out ILInstruction jumpTableEntry)
			{
				realEntryPoint = null;
				nextBlockOrExitContainer = null;
				jumpTableEntry = null;
				do
				{
					if (!(jumpTableEntryBlock.Instructions.SecondToLastOrDefault() is IfInstruction ifInst))
						return false;
					ILInstruction lastInst = jumpTableEntryBlock.Instructions.Last();
					if (ifInst.Condition.MatchCompEquals(out var left, out var right))
					{
						if (!ifInst.TrueInst.MatchBranch(out realEntryPoint))
							return false;
						if (!lastInst.MatchBranch(out jumpTableEntryBlock))
						{
							if (!lastInst.MatchLeave((BlockContainer)lastInst.Parent.Parent))
								return false;
						}
					}
					else if (ifInst.Condition.MatchCompNotEquals(out left, out right))
					{
						if (!lastInst.MatchBranch(out realEntryPoint))
							return false;
						if (!ifInst.TrueInst.MatchBranch(out jumpTableEntryBlock))
						{
							if (!ifInst.TrueInst.MatchLeave((BlockContainer)lastInst.Parent.Parent))
								return false;
						}
					}
					else
					{
						return false;
					}
					if (!left.MatchLdLoc(identifierVariable))
						return false;
					if (right.MatchLdcI4(id))
					{
						nextBlockOrExitContainer = jumpTableEntryBlock ?? lastInst.Parent.Parent;
						jumpTableEntry = ifInst;
						return true;
					}
				} while (jumpTableEntryBlock?.Instructions.Count == 2);
				return false;
			}
		}

		// Block beforeThrowBlock {
		// 	[before throw]
		// 	stloc typedExceptionVariable(isinst System.Exception(ldloc objectVariable))
		// 	if (comp.o(ldloc typedExceptionVariable != ldnull)) br captureBlock
		// 	br throwBlock
		// }
		//
		// Block throwBlock {
		// 	throw(ldloc objectVariable)
		// }
		//
		// Block captureBlock {
		// 	callvirt Throw(call Capture(ldloc typedExceptionVariable))
		// 	br nextBlock
		// }
		// =>
		// throw(ldloc result.Handler.Variable)
		//
		// If objectVariable is null on entry it is set to the variable tested by the isinst
		// (this happens as soon as the isinst store matches, even if a later check fails);
		// otherwise that variable must equal objectVariable.
		internal static bool MatchExceptionCaptureBlock(ILTransformContext context, Block block,
			ref ILVariable objectVariable, out StLoc typedExceptionVariableStore, out Block captureBlock, out Block throwBlock)
		{
			bool DerivesFromException(IType t) => t.GetAllBaseTypes().Any(ty => ty.IsKnownType(KnownTypeCode.Exception));
			captureBlock = null;
			throwBlock = null;
			typedExceptionVariableStore = null;
			var typedExceptionVariableStLoc = block.Instructions.ElementAtOrDefault(block.Instructions.Count - 3) as StLoc;
			if (typedExceptionVariableStLoc == null
				|| !typedExceptionVariableStLoc.Value.MatchIsInst(out var arg, out var type)
				|| !DerivesFromException(type)
				|| !arg.MatchLdLoc(out var v))
			{
				return false;
			}
			if (objectVariable == null)
			{
				objectVariable = v;
			}
			else if (!objectVariable.Equals(v))
			{
				return false;
			}
			typedExceptionVariableStore = typedExceptionVariableStLoc;
			if (!block.Instructions[block.Instructions.Count - 2].MatchIfInstruction(out var condition, out var trueInst))
				return false;
			ILInstruction lastInstr = block.Instructions.Last();
			if (!lastInstr.MatchBranch(out throwBlock))
				return false;
			if (condition.MatchCompNotEqualsNull(out arg)
				&& trueInst is Branch branchToCapture)
			{
				if (!arg.MatchLdLoc(typedExceptionVariableStore.Variable))
					return false;
				captureBlock = branchToCapture.TargetBlock;
			}
			else
			{
				return false;
			}
			if (throwBlock.IncomingEdgeCount != 1
				|| throwBlock.Instructions.Count != 1
				|| !(throwBlock.Instructions[0].MatchThrow(out var ov) && ov.MatchLdLoc(objectVariable)))
			{
				return false;
			}
			if (captureBlock.IncomingEdgeCount != 1
				|| captureBlock.Instructions.Count != 2
				|| !MatchCaptureThrowCalls(captureBlock.Instructions[0]))
			{
				return false;
			}
			return true;

			// Matches callvirt ExceptionDispatchInfo.Throw(call ExceptionDispatchInfo.Capture(ldloc typedExceptionVariable)).
			bool MatchCaptureThrowCalls(ILInstruction inst)
			{
				var exceptionDispatchInfoType = context.TypeSystem.FindType(typeof(System.Runtime.ExceptionServices.ExceptionDispatchInfo));
				if (inst is not CallVirt callVirt || callVirt.Arguments.Count != 1)
					return false;
				if (callVirt.Arguments[0] is not Call call || call.Arguments.Count != 1
					|| !call.Arguments[0].MatchLdLoc(typedExceptionVariableStLoc.Variable))
				{
					return false;
				}
				return callVirt.Method.Name == "Throw"
					&& callVirt.Method.DeclaringType.Equals(exceptionDispatchInfoType)
					&& call.Method.Name == "Capture"
					&& call.Method.DeclaringType.Equals(exceptionDispatchInfoType);
			}
		}
	}
}