|
13 | 13 | import static org.truffleruby.core.array.ArrayHelpers.setStoreAndSize;
|
14 | 14 | import static org.truffleruby.language.dispatch.DispatchNode.PUBLIC;
|
15 | 15 |
|
| 16 | +import java.util.ArrayDeque; |
16 | 17 | import java.util.Arrays;
|
17 | 18 |
|
18 | 19 | import com.oracle.truffle.api.dsl.CachedLanguage;
|
19 | 20 | import com.oracle.truffle.api.profiles.LoopConditionProfile;
|
| 21 | +import org.graalvm.collections.EconomicSet; |
| 22 | +import org.graalvm.collections.Equivalence; |
20 | 23 | import org.truffleruby.Layouts;
|
21 | 24 | import org.truffleruby.RubyLanguage;
|
22 | 25 | import org.truffleruby.builtins.CoreMethod;
|
@@ -2338,4 +2341,99 @@ protected boolean isStoreNative(RubyArray array,
|
2338 | 2341 | return stores.isNative(array.store);
|
2339 | 2342 | }
|
2340 | 2343 | }
|
| 2344 | + |
| 2345 | + @Primitive(name = "array_flatten_helper", lowerFixnum = 2) |
| 2346 | + public abstract static class FlattenHelperNode extends PrimitiveArrayArgumentsNode { |
| 2347 | + |
| 2348 | + @Specialization(guards = "!canContainObject.execute(array)") |
| 2349 | + protected boolean flattenHelperPrimitive(RubyArray array, RubyArray out, int maxLevels, |
| 2350 | + @Cached ArrayAppendManyNode concat, |
| 2351 | + @Cached TypeNodes.CanContainObjectNode canContainObject) { |
| 2352 | + concat.executeAppendMany(out, array); |
| 2353 | + return false; |
| 2354 | + } |
| 2355 | + |
| 2356 | + @Specialization(replaces = "flattenHelperPrimitive") |
| 2357 | + protected boolean flattenHelper(RubyArray array, RubyArray out, int maxLevels, |
| 2358 | + @CachedLanguage RubyLanguage language, |
| 2359 | + @Cached TypeNodes.CanContainObjectNode canContainObject, |
| 2360 | + @Cached ArrayAppendManyNode concat, |
| 2361 | + @Cached AtNode at, |
| 2362 | + @Cached DispatchNode convert, |
| 2363 | + @Cached ArrayAppendOneNode append) { |
| 2364 | + |
| 2365 | + boolean modified = false; |
| 2366 | + final EconomicSet<RubyArray> visited = EconomicSet.create(Equivalence.IDENTITY); |
| 2367 | + // final HashSet<RubyArray> visited = new HashSet<>(); |
| 2368 | + class Entry { |
| 2369 | + final RubyArray array; |
| 2370 | + final int index; |
| 2371 | + |
| 2372 | + Entry(RubyArray array, int index) { |
| 2373 | + this.array = array; |
| 2374 | + this.index = index; |
| 2375 | + } |
| 2376 | + } |
| 2377 | + final ArrayDeque<Entry> workStack = new ArrayDeque<>(); |
| 2378 | + workStack.push(new Entry(array, 0)); |
| 2379 | + |
| 2380 | + while (!workStack.isEmpty()) { |
| 2381 | + final Entry e = workStack.pop(); |
| 2382 | + |
| 2383 | + if (e.index == 0) { |
| 2384 | + if (!canContainObject.execute(e.array)) { |
| 2385 | + concat.executeAppendMany(out, e.array); |
| 2386 | + continue; |
| 2387 | + } else if (contains(visited, e.array)) { |
| 2388 | + throw new RaiseException( |
| 2389 | + getContext(), |
| 2390 | + coreExceptions().argumentError("tried to flatten recursive array", this)); |
| 2391 | + } else if (maxLevels == workStack.size()) { |
| 2392 | + concat.executeAppendMany(out, e.array); |
| 2393 | + continue; |
| 2394 | + } |
| 2395 | + add(visited, e.array); |
| 2396 | + } |
| 2397 | + |
| 2398 | + int i = e.index; |
| 2399 | + for (; i < e.array.size; ++i) { |
| 2400 | + final Object obj = at.executeAt(e.array, i); |
| 2401 | + final Object converted = convert.call( |
| 2402 | + coreLibrary().truffleTypeModule, |
| 2403 | + "rb_check_convert_type", |
| 2404 | + obj, |
| 2405 | + coreLibrary().arrayClass, |
| 2406 | + language.coreSymbols.TO_ARY); |
| 2407 | + if (converted == nil) { |
| 2408 | + append.executeAppendOne(out, obj); |
| 2409 | + } else { |
| 2410 | + modified = true; |
| 2411 | + workStack.push(new Entry(e.array, i + 1)); |
| 2412 | + workStack.push(new Entry((RubyArray) converted, 0)); |
| 2413 | + break; |
| 2414 | + } |
| 2415 | + } |
| 2416 | + if (i == e.array.size) { |
| 2417 | + remove(visited, e.array); |
| 2418 | + } |
| 2419 | + } |
| 2420 | + |
| 2421 | + return modified; |
| 2422 | + } |
| 2423 | + |
| 2424 | + @TruffleBoundary |
| 2425 | + private static boolean contains(EconomicSet<RubyArray> set, RubyArray array) { |
| 2426 | + return set.contains(array); |
| 2427 | + } |
| 2428 | + |
| 2429 | + @TruffleBoundary |
| 2430 | + private static void remove(EconomicSet<RubyArray> set, RubyArray array) { |
| 2431 | + set.remove(array); |
| 2432 | + } |
| 2433 | + |
| 2434 | + @TruffleBoundary |
| 2435 | + private static void add(EconomicSet<RubyArray> set, RubyArray array) { |
| 2436 | + set.add(array); |
| 2437 | + } |
| 2438 | + } |
2341 | 2439 | }
|
0 commit comments