7
7
import os
8
8
import pathlib
9
9
from collections import defaultdict
10
- from contextlib import redirect_stderr , redirect_stdout
10
+ from contextlib import contextmanager , redirect_stderr , redirect_stdout
11
11
from queue import Queue
12
12
from threading import Thread
13
13
from typing import Dict , List , Optional
@@ -248,6 +248,17 @@ def run_cell(self) -> List[IPytestResult]:
248
248
249
249
return test_results
250
250
251
+ @contextmanager
252
+ def traceback_handling (self , debug : bool ):
253
+ """Context manager to temporarily modify traceback behavior"""
254
+ original_traceback = self .shell ._showtraceback
255
+ try :
256
+ if not debug :
257
+ self .shell ._showtraceback = lambda * args , ** kwargs : None
258
+ yield
259
+ finally :
260
+ self .shell ._showtraceback = original_traceback
261
+
251
262
@cell_magic
252
263
def ipytest (self , line : str , cell : str ):
253
264
"""The `%%ipytest` cell magic"""
@@ -270,56 +281,53 @@ def ipytest(self, line: str, cell: str):
270
281
self .threaded = True
271
282
self .test_queue = Queue ()
272
283
273
- # If debug is in the line, then we want to show the traceback
274
- if self .debug :
275
- self .shell ._showtraceback = self ._orig_traceback
276
- else :
277
- self .shell ._showtraceback = lambda * args , ** kwargs : None
278
-
279
- # Get the module containing the test(s)
280
- if (
281
- module_name := get_module_name (
282
- " " .join (line_contents ), self .shell .user_global_ns
283
- )
284
- ) is None :
285
- raise TestModuleNotFoundError
284
+ with self .traceback_handling (self .debug ):
285
+ # Get the module containing the test(s)
286
+ if (
287
+ module_name := get_module_name (
288
+ " " .join (line_contents ), self .shell .user_global_ns
289
+ )
290
+ ) is None :
291
+ raise TestModuleNotFoundError
286
292
287
- self .module_name = module_name
293
+ self .module_name = module_name
288
294
289
- # Check that the test module file exists
290
- if not (
291
- module_file := pathlib .Path (f"tutorial/tests/test_{ self .module_name } .py" )
292
- ).exists ():
293
- raise FileNotFoundError (module_file )
295
+ # Check that the test module file exists
296
+ if not (
297
+ module_file := pathlib .Path (
298
+ f"tutorial/tests/test_{ self .module_name } .py"
299
+ )
300
+ ).exists ():
301
+ raise FileNotFoundError (module_file )
294
302
295
- self .module_file = module_file
303
<
D95F
span class="diff-text-marker">+ self .module_file = module_file
296
304
297
- # Run the cell
298
- results = self .run_cell ()
305
+ # Run the cell
306
+ results = self .run_cell ()
299
307
300
- # If in debug mode, display debug information first
301
- if self .debug :
302
- debug_output = DebugOutput (
303
- module_name = self .module_name ,
304
- module_file = self .module_file ,
305
- results = results ,
306
- )
307
- display (HTML (debug_output .to_html ()))
308
-
309
- # Parse the AST of the test module to retrieve the solution code
310
- ast_parser = AstParser (self .module_file )
311
- # Display the test results and the solution code
312
- for result in results :
313
- solution = (
314
- ast_parser .get_solution_code (result .function .name )
315
- if result .function and result .function .name
316
- else None
317
- )
318
- TestResultOutput (
319
- result ,
320
- solution ,
321
- self .shell .openai_client , # type: ignore
322
- ).display_results ()
308
+ # If in debug mode, display debug information first
309
+ if self .debug :
310
+ debug_output = DebugOutput (
311
+ module_name = self .module_name ,
312
+ module_file = self .module_file ,
313
+ results = results ,
314
+ )
315
+ display (HTML (debug_output .to_html ()))
316
+
317
+ # Parse the AST of the test module to retrieve the solution code
318
+ ast_parser = AstParser (self .module_file )
319
+ # Display the test results and the solution code
320
+ for result in results :
321
+ solution = (
322
+ ast_parser .get_solution_code (result .function .name )
323
+ if result .function and result .function .name
324
+ else None
325
+ )
326
+ TestResultOutput (
327
+ result ,
328
+ solution ,
329
+ self .shell .openai_client , # type: ignore
330
+ ).display_results ()
323
331
324
332
325
333
def load_ipython_extension (ipython ):
0 commit comments