Skip to content

Added support for @Inject to work on constructors #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
@Target({ElementType.FIELD, ElementType.CONSTRUCTOR})
public @interface Inject {}
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

import java.lang.reflect.Method;

public record SlashCommandClassMethod(Class<?> clazz, Method method) {}
public record SlashCommandClassMethod(Class<?> clazz, Method method, Object instance) {}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

public class ComponentLoader {
private static final Logger LOGGER = LogManager.getLogger(ComponentLoader.class);
private static final Map<Class<?>, Object> COMPONENTS = new HashMap<>();
public static final Map<Class<?>, Object> COMPONENTS = new HashMap<>();
private final ComponentValidator componentValidator = new ComponentValidator();

public void loadComponents() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import java.io.File;
import java.lang.reflect.Constructor;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.List;

import com.javadiscord.jdi.core.annotations.EventListener;
import com.javadiscord.jdi.internal.exceptions.NoZeroArgConstructorException;
import com.javadiscord.jdi.internal.processor.ClassFileUtil;
import com.javadiscord.jdi.internal.processor.validator.EventListenerValidator;

Expand All @@ -16,9 +17,11 @@ public class ListenerLoader {
private static final Logger LOGGER = LogManager.getLogger(ListenerLoader.class);
private final EventListenerValidator eventListenerValidator = new EventListenerValidator();
private final List<Object> eventListeners;
private final ComponentLoader componentLoader;

public ListenerLoader(List<Object> eventListeners) {
public ListenerLoader(List<Object> eventListeners, ComponentLoader componentLoader) {
this.eventListeners = eventListeners;
this.componentLoader = componentLoader;
try {
loadListeners();
} catch (Exception e) {
Expand Down Expand Up @@ -46,7 +49,13 @@ public void loadListeners() {

private void registerListener(Class<?> clazz) {
try {
Object instance = getZeroArgConstructor(clazz).newInstance();
Constructor<?> constructor = clazz.getConstructors()[0];
Object instance;
if (constructor.getParameterCount() > 0) {
instance = constructor.newInstance(getConstructorParameters(constructor).toArray());
} else {
instance = constructor.newInstance();
}
ComponentLoader.injectComponents(instance);
eventListeners.add(instance);
LOGGER.info("Registered listener {}", clazz.getName());
Expand All @@ -55,6 +64,19 @@ private void registerListener(Class<?> clazz) {
}
}

private List<Object> getConstructorParameters(Constructor<?> constructor) {
List<Object> constructorParameters = new ArrayList<>();
for (Parameter parameter : constructor.getParameters()) {
if (ComponentLoader.COMPONENTS.containsKey(parameter.getType())) {
constructorParameters.add(ComponentLoader.COMPONENTS.get(parameter.getType()));
} else {
constructorParameters.add(null);
LOGGER.warn("No component found for {}", parameter.getType());
}
}
return constructorParameters;
}

public boolean validateListener(Class<?> clazz) {
return eventListenerValidator.validate(clazz);
}
Expand All @@ -66,8 +88,6 @@ public static Constructor<?> getZeroArgConstructor(Class<?> clazz) {
return constructor;
}
}
throw new NoZeroArgConstructorException(
"No zero arg constructor found for " + clazz.getName()
);
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import java.io.File;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -34,9 +36,10 @@ private void loadInteractionListeners() {
Method[] methods = clazz.getMethods();
for (Method method : methods) {
if (method.getAnnotation(SlashCommand.class) != null) {
if (validator.validate(method) && hasZeroArgsConstructor(clazz)) {
if (validator.validate(method)) {
registerListener(
clazz, method, method.getAnnotation(SlashCommand.class).name()
clazz, method, method.getAnnotation(SlashCommand.class).name(),
createInstance(clazz)
);
} else {
throw new ValidationException(method.getName() + " failed validation");
Expand All @@ -49,7 +52,36 @@ private void loadInteractionListeners() {
}
}

private void registerListener(Class<?> clazz, Method method, String name) {
private Object createInstance(Class<?> clazz) {
Object instance = null;
try {
Constructor<?> constructor = clazz.getConstructors()[0];
if (constructor.getParameterCount() > 0) {
instance = constructor.newInstance(getConstructorParameters(constructor).toArray());
} else {
instance = constructor.newInstance();
}
injectComponents(instance);
} catch (Exception e) {
LOGGER.error("Failed to create {} instance", clazz.getName(), e);
}
return instance;
}

private List<Object> getConstructorParameters(Constructor<?> constructor) {
List<Object> constructorParameters = new ArrayList<>();
for (Parameter parameter : constructor.getParameters()) {
if (ComponentLoader.COMPONENTS.containsKey(parameter.getType())) {
constructorParameters.add(ComponentLoader.COMPONENTS.get(parameter.getType()));
} else {
constructorParameters.add(null);
LOGGER.warn("No component found for {}", parameter.getType());
}
}
return constructorParameters;
}

private void registerListener(Class<?> clazz, Method method, String name, Object instance) {
try {
if (interactionListeners.containsKey(name)) {
LOGGER.error(
Expand All @@ -60,7 +92,7 @@ private void registerListener(Class<?> clazz, Method method, String name) {
);
return;
}
interactionListeners.put(name, new SlashCommandClassMethod(clazz, method));
interactionListeners.put(name, new SlashCommandClassMethod(clazz, method, instance));
LOGGER.info("Found slash command handler {}", clazz.getName());
} catch (Exception e) {
LOGGER.error("Failed to create {} instance", clazz.getName(), e);
Expand All @@ -71,17 +103,6 @@ public SlashCommandClassMethod getSlashCommandClassMethod(String name) {
return interactionListeners.get(name);
}

private boolean hasZeroArgsConstructor(Class<?> clazz) {
Constructor<?>[] constructors = clazz.getConstructors();
for (Constructor<?> constructor : constructors) {
if (constructor.getParameterCount() == 0) {
return true;
}
}
LOGGER.error("{} does not have a 0 arg constructor", clazz.getName());
return false;
}

public void injectComponents(Object object) {
ComponentLoader.injectComponents(object);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.javadiscord.jdi.internal.processor.validator;

import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -388,21 +387,6 @@ public class EventListenerValidator {
}

public boolean validate(Class<?> clazz) {
return hasZeroArgsConstructor(clazz) && validateMethods(clazz);
}

public boolean hasZeroArgsConstructor(Class<?> clazz) {
Constructor<?>[] constructors = clazz.getConstructors();
for (Constructor<?> constructor : constructors) {
if (constructor.getParameterCount() == 0) {
return true;
}
}
LOGGER.error("{} does not have a 0 arg constructor", clazz.getName());
return false;
}

private boolean validateMethods(Class<?> clazz) {
Method[] methods = clazz.getMethods();
for (Method method : methods) {
if (!validateMethodAnnotations(method)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,6 @@

class EventListenerValidatorTest {

@Test
void testValidationFailsWhenZeroArgConstructorDoesNotExist() {
class Test {}

EventListenerValidator eventListenerValidator = new EventListenerValidator();
assertFalse(eventListenerValidator.validate(Test.class));
}

public static class ClassWithConstructor {
public ClassWithConstructor() {}
}
Expand Down
12 changes: 5 additions & 7 deletions core/src/main/java/com/javadiscord/jdi/core/Discord.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public class Discord {
private WebSocketManager webSocketManager;
private long applicationId;
private boolean started = false;
private Object componentLoader;

public Discord(String botToken) {
this(
Expand Down Expand Up @@ -236,9 +237,10 @@ private void loadComponents() {
ReflectiveComponentLoader componentLoader = null;
for (Constructor<?> constructor : clazz.getConstructors()) {
if (constructor.getParameterCount() == 0) {
this.componentLoader = constructor.newInstance();
componentLoader =
ReflectiveLoader
.proxy(constructor.newInstance(), ReflectiveComponentLoader.class);
.proxy(this.componentLoader, ReflectiveComponentLoader.class);
}
}
if (componentLoader != null) {
Expand All @@ -257,12 +259,8 @@ private void loadAnnotations() {
Class<?> clazz =
Class.forName(Constants.LISTENER_LOADER_CLASS);
for (Constructor<?> constructor : clazz.getConstructors()) {
if (constructor.getParameterCount() == 1) {
Parameter parameters = constructor.getParameters()[0];
if (parameters.getType().equals(List.class)) {
constructor.newInstance(annotatedEventListeners);
return;
}
if (constructor.getParameterCount() == 2) {
constructor.newInstance(annotatedEventListeners, componentLoader);
}
}
} catch (Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ public class InteractionEventHandler implements EventListener {
private final Object slashCommandLoader;
private final Discord discord;

private final Map<String, Object> cachedInstances = new HashMap<>();

public InteractionEventHandler(Object slashCommandLoader, Discord discord) {
this.slashCommandLoader = slashCommandLoader;
this.discord = discord;
Expand All @@ -43,13 +41,12 @@ public void onInteractionCreate(Interaction interaction, Guild guild) {
ReflectiveSlashCommandClassMethod.class
);

Class<?> handler = reflectiveSlashCommandClassMethod.clazz();
Method method = reflectiveSlashCommandClassMethod.method();

Object instance = reflectiveSlashCommandClassMethod.instance();
List<Object> paramOrder = getOrderOfParameters(method, interaction);

if (validateParameterCount(method, paramOrder)) {
invokeHandler(handler, method, paramOrder);
invokeHandler(method, paramOrder, instance);
} else {
throw new InstantiationException(
"Bound " + paramOrder.size() + " parameters but expected "
Expand Down Expand Up @@ -86,31 +83,16 @@ private boolean validateParameterCount(Method method, List<Object> paramOrder) {
}

private void invokeHandler(
Class<?> handler,
Method method,
List<Object> paramOrder
List<Object> paramOrder,
Object instance
) throws InstantiationException {
try {
if (cachedInstances.containsKey(handler.getName())) {
method.invoke(cachedInstances.get(handler.getName()), paramOrder.toArray());
} else {
Object handlerInstance = handler.getDeclaredConstructor().newInstance();
cachedInstances.put(handler.getName(), handlerInstance);
injectComponents(handlerInstance);
method.invoke(handlerInstance, paramOrder.toArray());
}
method.invoke(instance, paramOrder.toArray());
} catch (
InvocationTargetException | IllegalAccessException | NoSuchMethodException
| InstantiationException e
InvocationTargetException | IllegalAccessException e
) {
throw new InstantiationException(e.getLocalizedMessage());
}
}

private void injectComponents(Object object) {
ReflectiveSlashCommandLoader reflectiveSlashCommandLoader =
ReflectiveLoader.proxy(slashCommandLoader, ReflectiveSlashCommandLoader.class);

reflectiveSlashCommandLoader.injectComponents(object);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ public interface ReflectiveSlashCommandClassMethod {
Class<?> clazz();

Method method();

Object instance();
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
import com.javadiscord.jdi.core.models.message.embed.EmbedAuthor;

public class ChatGPTCommand {
private final ChatGPT chatGPT;

@Inject
private ChatGPT chatGPT;
public ChatGPTCommand(ChatGPT chatGPT) {
this.chatGPT = chatGPT;
}

@SlashCommand(
name = "chatgpt", description = "Ask ChatGPT a question", options = {
Expand Down
Loading