diff --git a/standardize_test_format/__init__.py b/standardize_test_format/__init__.py index ab13f2c..15fed5b 100644 --- a/standardize_test_format/__init__.py +++ b/standardize_test_format/__init__.py @@ -74,6 +74,14 @@ def replace_test_pattern(match: re.Match, with_javadoc: bool, with_display_name: with_javadoc=with_javadoc, with_display_name=with_display_name) +IMPORT_DISPLAY_NAME = 'import org.junit.jupiter.api.DisplayName;' + def standardize_java_text(text: str, with_javadoc: bool, with_display_name: bool): - return TEST_PATTERN.sub(lambda m: replace_test_pattern(m, with_javadoc, with_display_name), text) + text = TEST_PATTERN.sub(lambda m: replace_test_pattern(m, with_javadoc, with_display_name), text) + if '@DisplayName' in text and IMPORT_DISPLAY_NAME not in text: + lines = text.split('\n') + lines.insert(1, IMPORT_DISPLAY_NAME) + text = '\n'.join(lines) + del lines + return text