Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@

import dev.snowdrop.treesitter4j.util.LanguageDetector;
import io.roastedroot.treesitter.Language;
import io.roastedroot.treesitter.TreeSitter;
import io.roastedroot.treesitter.TreeSitterException;
import io.roastedroot.treesitter.TreeSitterParser;
import io.roastedroot.treesitter.TreeSitterTree;
import io.roastedroot.treesitter.TreeSitterPool;
import io.roastedroot.treesitter.ast.ASTExporter;
import io.roastedroot.treesitter.ast.ASTJsonSerializer;
import io.roastedroot.treesitter.ast.ASTTree;
Expand All @@ -18,23 +15,32 @@
import org.eclipse.microprofile.config.ConfigProvider;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileVisitOption;
import java.nio.file.FileVisitResult;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.SimpleFileVisitor;
import java.nio.file.attribute.BasicFileAttributes;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HexFormat;
import java.util.List;
import java.util.Optional;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

@CommandDefinition(name = "parse", description = "Parse files from a directory and persist AST nodes as JSON")
@CommandDefinition(name = "parse", description = "Parse files from a directory and persist AST nodes as JSON files under the store")
public class ParseCommand implements Command<CommandInvocation> {

private static final String STORE_DIR = ".ast-store";
private static final String STORE_DIR = ".ts4j";

@Argument(description = "Path to the project directory to parse", required = true)
private String projectPath;
Expand All @@ -49,21 +55,22 @@ public CommandResult execute(CommandInvocation invocation) {
return CommandResult.FAILURE;
}

// Collect supported files
List<Path> sourceFiles;
// Collect supported files grouped by language
Map<Language, List<Path>> sourceFilesByLanguage;
try {
sourceFiles = findSourceFiles(rootDir);
sourceFilesByLanguage = findSourceFiles(rootDir);
} catch (IOException e) {
invocation.println("Error walking directory: " + e.getMessage());
return CommandResult.FAILURE;
}

if (sourceFiles.isEmpty()) {
if (sourceFilesByLanguage.isEmpty()) {
invocation.println("No supported source files found under " + rootDir);
return CommandResult.SUCCESS;
}

invocation.println("Found " + sourceFiles.size() + " source file(s). Parsing...");
int totalFiles = sourceFilesByLanguage.values().stream().mapToInt(List::size).sum();
invocation.println("Found " + totalFiles + " source file(s) in " + sourceFilesByLanguage.size() + " language(s). Parsing...");

// Prepare output directory
Path storeDir = rootDir.resolve(STORE_DIR);
Expand All @@ -74,63 +81,89 @@ public CommandResult execute(CommandInvocation invocation) {
return CommandResult.FAILURE;
}

int successCount = 0;
int errorCount = 0;
AtomicInteger successCount = new AtomicInteger();
AtomicInteger errorCount = new AtomicInteger();

try (TreeSitter ts = TreeSitter.create();
TreeSitterParser parser = ts.newParser()) {
// Create language directories upfront (single-threaded, cheap)
for (Language lang : sourceFilesByLanguage.keySet()) {
try {
Files.createDirectories(storeDir.resolve(lang.name().toLowerCase()));
} catch (IOException e) {
invocation.println(" ERROR creating language directory for " + lang + ": " + e.getMessage());
}
}

for (Path file : sourceFiles) {
Optional<Language> langOpt = LanguageDetector.detect(file);
if (langOpt.isEmpty()) {
continue;
}
Language lang = langOpt.get();
// Flatten all (language, file) pairs into a list of tasks
record ParseTask(Language lang, Path file) {}
List<ParseTask> tasks = new ArrayList<>();
for (var entry : sourceFilesByLanguage.entrySet()) {
for (Path file : entry.getValue()) {
tasks.add(new ParseTask(entry.getKey(), file));
}
}

try {
String source = Files.readString(file);
int poolSize = Runtime.getRuntime().availableProcessors();
try (var pool = new TreeSitterPool(poolSize);
ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor()) {

List<Future<?>> futures = new ArrayList<>(tasks.size());
for (ParseTask task : tasks) {
futures.add(executor.submit(() -> {
try {
parser.setLanguage(lang);
} catch (TreeSitterException e) {
invocation.println(" WARN: language " + lang + " not supported at runtime, skipping " + relativize(rootDir, file));
continue;
pool.execute(ts -> {
try (var parser = ts.newParser()) {
parser.setLanguage(task.lang());
String source = Files.readString(task.file());

try (var tree = parser.parseString(source)) {
if (tree == null) {
invocation.println(" WARN: failed to parse " + relativize(rootDir, task.file()));
errorCount.incrementAndGet();
return;
}

ASTTree ast = ASTExporter.export(tree, task.lang(), source, relativize(rootDir, task.file()));

String relPath = rootDir.relativize(task.file()).toString();
String sha = sha256(relPath);
Path langDir = storeDir.resolve(task.lang().name().toLowerCase());
Path jsonFile = langDir.resolve(sha + ".json");
ASTJsonSerializer.toJson(ast, jsonFile);

successCount.incrementAndGet();
}
} catch (Exception e) {
invocation.println(" ERROR parsing " + relativize(rootDir, task.file()) + ": " + e.getMessage());
errorCount.incrementAndGet();
}
});
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
invocation.println(" ERROR: interrupted while parsing " + relativize(rootDir, task.file()));
errorCount.incrementAndGet();
}
}));
}

try (TreeSitterTree tree = parser.parseString(source)) {
if (tree == null) {
invocation.println(" WARN: failed to parse " + relativize(rootDir, file));
errorCount++;
continue;
}

ASTTree ast = ASTExporter.export(tree, lang, source, relativize(rootDir, file));

// Write JSON file mirroring the source path
Path relPath = rootDir.relativize(file);
Path jsonFile = storeDir.resolve(relPath + ".json");
Files.createDirectories(jsonFile.getParent());
ASTJsonSerializer.toJson(ast, jsonFile);

successCount++;
}
// Wait for all tasks to complete
for (Future<?> future : futures) {
try {
future.get();
} catch (Exception e) {
invocation.println(" ERROR parsing " + relativize(rootDir, file) + ": " + e.getMessage());
e.printStackTrace();
errorCount++;
// Errors already counted inside the task
}
}
}

long elapsedMs = (System.nanoTime() - startTime) / 1_000_000;
invocation.println("Parsing complete: " + successCount + " succeeded, " + errorCount + " failed.");
invocation.println("Parsing complete: " + successCount.get() + " succeeded, " + errorCount.get() + " failed.");
invocation.println("AST store saved to " + storeDir);
invocation.println("Elapsed time: " + elapsedMs + " ms");

return CommandResult.SUCCESS;
}

private List<Path> findSourceFiles(Path root) throws IOException {
private Map<Language, List<Path>> findSourceFiles(Path root) throws IOException {
List<String> excludePatterns = ConfigProvider.getConfig()
.getOptionalValue("ts4j.parser.exclude-dirs", String.class)
.map(s -> Arrays.stream(s.split(","))
Expand All @@ -139,7 +172,7 @@ private List<Path> findSourceFiles(Path root) throws IOException {
.toList())
.orElse(List.of());

List<Path> files = new ArrayList<>();
Map<Language, List<Path>> filesByLanguage = new HashMap<>();
Files.walkFileTree(root, EnumSet.of(FileVisitOption.FOLLOW_LINKS),
Integer.MAX_VALUE, new SimpleFileVisitor<>() {

Expand All @@ -154,9 +187,8 @@ public FileVisitResult preVisitDirectory(@Nonnull Path dir, @Nonnull BasicFileAt

@Override
public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) {
if (LanguageDetector.detect(file).isPresent()) {
files.add(file);
}
LanguageDetector.detect(file).ifPresent(lang ->
filesByLanguage.computeIfAbsent(lang, k -> new ArrayList<>()).add(file));
return FileVisitResult.CONTINUE;
}

Expand All @@ -165,7 +197,7 @@ public FileVisitResult visitFileFailed(Path file, IOException exc) {
return FileVisitResult.CONTINUE;
}
});
return files;
return filesByLanguage;
}

private boolean shouldExclude(String name, List<String> patterns) {
Expand All @@ -188,4 +220,14 @@ private String relativize(Path root, Path file) {
return file.toString();
}
}

/**
 * Hashes a string with SHA-256 and renders the digest as hexadecimal text.
 *
 * <p>The input is encoded as UTF-8 before hashing, so the result does not
 * depend on the platform default charset.
 *
 * @param input the text to hash
 * @return the SHA-256 digest of {@code input} as a hexadecimal string
 * @throws RuntimeException if the runtime provides no SHA-256 implementation;
 *         the original {@link NoSuchAlgorithmException} is attached as the cause
 */
private static String sha256(String input) {
    final MessageDigest sha;
    try {
        sha = MessageDigest.getInstance("SHA-256");
    } catch (NoSuchAlgorithmException unavailable) {
        throw new RuntimeException("SHA-256 not available", unavailable);
    }

    // Feed the UTF-8 bytes in, then finalize the digest in a separate step.
    byte[] utf8Bytes = input.getBytes(StandardCharsets.UTF_8);
    sha.update(utf8Bytes);
    byte[] rawDigest = sha.digest();

    HexFormat hex = HexFormat.of();
    return hex.formatHex(rawDigest);
}
}