001// Copyright (c) FIRST and other WPILib contributors.
002// Open Source Software; you can modify and/or share it under the terms of
003// the WPILib BSD license file in the root directory of this project.
004
005package edu.wpi.first.util.struct;
006
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Modifier;
import java.lang.reflect.RecordComponent;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
015
016/** A utility class for procedurally generating {@link Struct}s from records and enums. */
017public final class StructGenerator {
018  private StructGenerator() {
019    throw new UnsupportedOperationException("This is a utility class!");
020  }
021
022  /**
023   * A functional interface representing a method that retrives a value from a {@link ByteBuffer}.
024   */
025  @FunctionalInterface
026  private interface Unpacker<T> {
027    T unpack(ByteBuffer buffer);
028  }
029
030  /** A functional interface representing a method that packs a value into a {@link ByteBuffer}. */
031  @FunctionalInterface
032  private interface Packer<T> {
033    ByteBuffer pack(ByteBuffer buffer, T value);
034
035    static <T> Packer<T> fromStruct(Struct<T> struct) {
036      return (buffer, value) -> {
037        struct.pack(buffer, value);
038        return buffer;
039      };
040    }
041  }
042
043  private record PrimType<T>(String name, int size, Unpacker<T> unpacker, Packer<T> packer) {}
044
045  /** A map of primitive types to their schema types. */
046  private static final HashMap<Class<?>, PrimType<?>> primitiveTypeMap = new HashMap<>();
047
048  private static <T> void addPrimType(
049      Class<T> boxedClass,
050      Class<T> primitiveClass,
051      String name,
052      int size,
053      Unpacker<T> unpacker,
054      Packer<T> packer) {
055    PrimType<T> primType = new PrimType<>(name, size, unpacker, packer);
056    primitiveTypeMap.put(boxedClass, primType);
057    primitiveTypeMap.put(primitiveClass, primType);
058  }
059
060  // Add primitive types to the map
061  static {
062    addPrimType(
063        Integer.class, int.class, "int32", Integer.BYTES, ByteBuffer::getInt, ByteBuffer::putInt);
064    addPrimType(
065        Double.class,
066        double.class,
067        "float64",
068        Double.BYTES,
069        ByteBuffer::getDouble,
070        ByteBuffer::putDouble);
071    addPrimType(
072        Float.class,
073        float.class,
074        "float32",
075        Float.BYTES,
076        ByteBuffer::getFloat,
077        ByteBuffer::putFloat);
078    addPrimType(
079        Boolean.class,
080        boolean.class,
081        "bool",
082        Byte.BYTES,
083        buffer -> buffer.get() != 0,
084        (buffer, value) -> buffer.put((byte) (value ? 1 : 0)));
085    addPrimType(
086        Character.class,
087        char.class,
088        "char",
089        Character.BYTES,
090        ByteBuffer::getChar,
091        ByteBuffer::putChar);
092    addPrimType(Byte.class, byte.class, "uint8", Byte.BYTES, ByteBuffer::get, ByteBuffer::put);
093    addPrimType(
094        Short.class, short.class, "int16", Short.BYTES, ByteBuffer::getShort, ByteBuffer::putShort);
095    addPrimType(
096        Long.class, long.class, "int64", Long.BYTES, ByteBuffer::getLong, ByteBuffer::putLong);
097  }
098
099  /**
100   * A map of types to their custom struct schemas.
101   *
102   * <p>This allows adding custom struct implementations for types that are not supported by
103   * default. Think of vendor-specific.
104   */
105  private static final HashMap<Class<?>, Struct<?>> customStructTypeMap = new HashMap<>();
106
107  /**
108   * Add a custom struct to the structifier.
109   *
110   * @param <T> The type the struct is for.
111   * @param clazz The class of the type.
112   * @param struct The struct to add.
113   * @param override Whether to override an existing struct. An existing struct could mean the type
114   *     already has a {@code struct} field and implemnts {@link StructSerializable} or that the
115   *     type is already in the custom struct map.
116   */
117  public static <T> void addCustomStruct(Class<T> clazz, Struct<T> struct, boolean override) {
118    if (override) {
119      customStructTypeMap.put(clazz, struct);
120    } else if (!StructSerializable.class.isAssignableFrom(clazz)) {
121      customStructTypeMap.putIfAbsent(clazz, struct);
122    }
123  }
124
125  /** A utility for building schema syntax in a procedural manner. */
126  @SuppressWarnings("PMD.AvoidStringBufferField")
127  public static class SchemaBuilder {
128    /** A utility for building enum fields in a procedural manner. */
129    public static final class EnumFieldBuilder {
130      private final StringBuilder m_builder = new StringBuilder();
131      private final String m_fieldName;
132      private boolean m_firstVariant = true;
133
134      /**
135       * Creates a new enum field builder.
136       *
137       * @param fieldName The name of the field.
138       */
139      public EnumFieldBuilder(String fieldName) {
140        this.m_fieldName = fieldName;
141        m_builder.append("enum {");
142      }
143
144      /**
145       * Adds a variant to the enum field.
146       *
147       * @param name The name of the variant.
148       * @param value The value of the variant.
149       * @return The builder for chaining.
150       */
151      public EnumFieldBuilder addVariant(String name, int value) {
152        if (!m_firstVariant) {
153          m_builder.append(',');
154        }
155        m_firstVariant = false;
156        m_builder.append(name).append('=').append(value);
157        return this;
158      }
159
160      /**
161       * Builds the enum field. If this object is being used with {@link SchemaBuilder#addEnumField}
162       * then {@link #build()} does not have to be called by the user.
163       *
164       * @return The built enum field.
165       */
166      public String build() {
167        m_builder.append("} int8 ").append(m_fieldName).append(';');
168        return m_builder.toString();
169      }
170    }
171
172    /** Creates a new schema builder. */
173    public SchemaBuilder() {}
174
175    private final StringBuilder m_builder = new StringBuilder();
176
177    /**
178     * Adds a field to the schema.
179     *
180     * @param name The name of the field.
181     * @param type The type of the field.
182     * @return The builder for chaining.
183     */
184    public SchemaBuilder addField(String name, String type) {
185      m_builder.append(type).append(' ').append(name).append(';');
186      return this;
187    }
188
189    /**
190     * Adds an inline enum field to the schema.
191     *
192     * @param enumFieldBuilder The builder for the enum field.
193     * @return The builder for chaining.
194     */
195    public SchemaBuilder addEnumField(EnumFieldBuilder enumFieldBuilder) {
196      m_builder.append(enumFieldBuilder.build());
197      return this;
198    }
199
200    /**
201     * Builds the schema.
202     *
203     * @return The built schema.
204     */
205    public String build() {
206      return m_builder.toString();
207    }
208  }
209
210  private static <T> Struct<T> noopStruct(Class<T> cls) {
211    return new Struct<>() {
212      @Override
213      public Class<T> getTypeClass() {
214        return cls;
215      }
216
217      @Override
218      public String getTypeName() {
219        return cls.getSimpleName();
220      }
221
222      @Override
223      public String getSchema() {
224        return "";
225      }
226
227      @Override
228      public int getSize() {
229        return 0;
230      }
231
232      @Override
233      public void pack(ByteBuffer buffer, T value) {}
234
235      @Override
236      public T unpack(ByteBuffer buffer) {
237        return null;
238      }
239
240      @Override
241      public boolean isImmutable() {
242        return true;
243      }
244    };
245  }
246
247  /**
248   * Generates a {@link Struct} for the given {@link Record} class. If a {@link Struct} cannot be
249   * generated from the {@link Record}, the errors encountered will be printed and a no-op {@link
250   * Struct} will be returned.
251   *
252   * @param <R> The type of the record.
253   * @param recordClass The class of the record.
254   * @return The generated struct.
255   */
256  @SuppressWarnings({"unchecked", "PMD.AvoidAccessibilityAlteration"})
257  public static <R extends Record> Struct<R> genRecord(final Class<R> recordClass) {
258    final RecordComponent[] components = recordClass.getRecordComponents();
259    final SchemaBuilder schemaBuilder = new SchemaBuilder();
260    final ArrayList<Struct<?>> nestedStructs = new ArrayList<>();
261    final ArrayList<Unpacker<?>> unpackers = new ArrayList<>();
262    final ArrayList<Packer<?>> packers = new ArrayList<>();
263
264    int size = 0;
265    boolean failed = false;
266
267    for (final RecordComponent component : components) {
268      final Class<?> type = component.getType();
269      final String name = component.getName();
270      component.getAccessor().setAccessible(true);
271
272      if (primitiveTypeMap.containsKey(type)) {
273        PrimType<?> primType = primitiveTypeMap.get(type);
274        schemaBuilder.addField(name, primType.name);
275        size += primType.size;
276        unpackers.add(primType.unpacker);
277        packers.add(primType.packer);
278      } else {
279        Struct<?> struct;
280        if (customStructTypeMap.containsKey(type)) {
281          struct = customStructTypeMap.get(type);
282        } else if (StructSerializable.class.isAssignableFrom(type)) {
283          var optStruct = StructFetcher.fetchStructDynamic(type);
284          if (optStruct.isPresent()) {
285            struct = optStruct.get();
286          } else {
287            System.err.println(
288                "Could not structify record component: "
289                    + recordClass.getSimpleName()
290                    + "#"
291                    + name
292                    + "\n    Could not extract struct from marked class: "
293                    + type.getName());
294            failed = true;
295            continue;
296          }
297        } else {
298          System.err.println(
299              "Could not structify record component: " + recordClass.getSimpleName() + "#" + name);
300          failed = true;
301          continue;
302        }
303        schemaBuilder.addField(name, struct.getTypeName());
304        size += struct.getSize();
305        nestedStructs.add(struct);
306        nestedStructs.addAll(List.of(struct.getNested()));
307        unpackers.add(struct::unpack);
308        packers.add(Packer.fromStruct(struct));
309      }
310    }
311
312    if (failed) {
313      return noopStruct(recordClass);
314    }
315
316    final int frozenSize = size;
317    final String schema = schemaBuilder.build();
318    return new Struct<>() {
319      @Override
320      public Class<R> getTypeClass() {
321        return recordClass;
322      }
323
324      @Override
325      public String getTypeName() {
326        return recordClass.getSimpleName();
327      }
328
329      @Override
330      public String getSchema() {
331        return schema;
332      }
333
334      @Override
335      public int getSize() {
336        return frozenSize;
337      }
338
339      @Override
340      public void pack(ByteBuffer buffer, R value) {
341        boolean failed = false;
342        int startingPosition = buffer.position();
343        for (int i = 0; i < components.length; i++) {
344          Packer<Object> packer = (Packer<Object>) packers.get(i);
345          try {
346            Object componentValue = components[i].getAccessor().invoke(value);
347            if (componentValue == null) {
348              throw new IllegalArgumentException("Component is null");
349            }
350            packer.pack(buffer, componentValue);
351          } catch (IllegalAccessException
352              | IllegalArgumentException
353              | InvocationTargetException e) {
354            System.err.println(
355                "Could not pack record component: "
356                    + recordClass.getSimpleName()
357                    + "#"
358                    + components[i].getName()
359                    + "\n    "
360                    + e.getMessage());
361            failed = true;
362            break;
363          }
364        }
365        if (failed) {
366          buffer.position(startingPosition);
367          for (int i = 0; i < frozenSize; i++) {
368            buffer.put((byte) 0);
369          }
370        }
371      }
372
373      @Override
374      public R unpack(ByteBuffer buffer) {
375        try {
376          Object[] args = new Object[components.length];
377          Class<?>[] argTypes = new Class<?>[components.length];
378          for (int i = 0; i < components.length; i++) {
379            args[i] = unpackers.get(i).unpack(buffer);
380            argTypes[i] = components[i].getType();
381          }
382          return recordClass.getConstructor(argTypes).newInstance(args);
383        } catch (InstantiationException
384            | IllegalAccessException
385            | InvocationTargetException
386            | NoSuchMethodException
387            | SecurityException e) {
388          System.err.println(
389              "Could not unpack record: "
390                  + recordClass.getSimpleName()
391                  + "\n    "
392                  + e.getMessage());
393          return null;
394        }
395      }
396
397      @Override
398      public Struct<?>[] getNested() {
399        return nestedStructs.toArray(new Struct<?>[0]);
400      }
401
402      @Override
403      public boolean isImmutable() {
404        return true;
405      }
406    };
407  }
408
409  /**
410   * Generates a {@link Struct} for the given {@link Enum} class. If a {@link Struct} cannot be
411   * generated from the {@link Enum}, the errors encountered will be printed and a no-op {@link
412   * Struct} will be returned.
413   *
414   * @param <E> The type of the enum.
415   * @param enumClass The class of the enum.
416   * @return The generated struct.
417   */
418  @SuppressWarnings({"unchecked", "PMD.AvoidAccessibilityAlteration"})
419  public static <E extends Enum<E>> Struct<E> genEnum(Class<E> enumClass) {
420    final E[] enumVariants = enumClass.getEnumConstants();
421    final Field[] allEnumFields = enumClass.getDeclaredFields();
422    final SchemaBuilder schemaBuilder = new SchemaBuilder();
423    final SchemaBuilder.EnumFieldBuilder enumFieldBuilder =
424        new SchemaBuilder.EnumFieldBuilder("variant");
425    final HashMap<Integer, E> enumMap = new HashMap<>();
426    final ArrayList<Packer<?>> packers = new ArrayList<>();
427
428    if (enumVariants == null || enumVariants.length == 0) {
429      System.err.println(
430          "Could not structify enum: "
431              + enumClass.getSimpleName()
432              + "\n    "
433              + "Enum has no constants");
434      return noopStruct(enumClass);
435    }
436
437    int size = 0;
438    boolean failed = false;
439
440    for (final E constant : enumVariants) {
441      final String name = constant.name();
442      final int ordinal = constant.ordinal();
443
444      enumFieldBuilder.addVariant(name, ordinal);
445      enumMap.put(ordinal, constant);
446    }
447    schemaBuilder.addEnumField(enumFieldBuilder);
448    size += 1;
449
450    final List<Field> enumFields =
451        List.of(allEnumFields).stream()
452            .filter(f -> !f.isEnumConstant() && !Modifier.isStatic(f.getModifiers()))
453            .toList();
454
455    for (final Field field : enumFields) {
456      final Class<?> type = field.getType();
457      final String name = field.getName();
458      field.setAccessible(true);
459
460      if (primitiveTypeMap.containsKey(type)) {
461        PrimType<?> primType = primitiveTypeMap.get(type);
462        schemaBuilder.addField(name, primType.name);
463        size += primType.size;
464        packers.add(primType.packer);
465      } else {
466        Struct<?> struct;
467        if (customStructTypeMap.containsKey(type)) {
468          struct = customStructTypeMap.get(type);
469        } else if (StructSerializable.class.isAssignableFrom(type)) {
470          var optStruct = StructFetcher.fetchStructDynamic(type);
471          if (optStruct.isPresent()) {
472            struct = optStruct.get();
473          } else {
474            System.err.println(
475                "Could not structify record component: "
476                    + enumClass.getSimpleName()
477                    + "#"
478                    + name
479                    + "\n    Could not extract struct from marked class: "
480                    + type.getName());
481            failed = true;
482            continue;
483          }
484        } else {
485          System.err.println(
486              "Could not structify record component: " + enumClass.getSimpleName() + "#" + name);
487          failed = true;
488          continue;
489        }
490        schemaBuilder.addField(name, struct.getTypeName());
491        size += struct.getSize();
492        packers.add(Packer.fromStruct(struct));
493      }
494    }
495
496    if (failed) {
497      return noopStruct(enumClass);
498    }
499
500    final int frozenSize = size;
501    final String schema = schemaBuilder.build();
502    return new Struct<>() {
503      @Override
504      public Class<E> getTypeClass() {
505        return enumClass;
506      }
507
508      @Override
509      public String getTypeName() {
510        return enumClass.getSimpleName();
511      }
512
513      @Override
514      public String getSchema() {
515        return schema;
516      }
517
518      @Override
519      public int getSize() {
520        return frozenSize;
521      }
522
523      @Override
524      public void pack(ByteBuffer buffer, E value) {
525        boolean failed = false;
526        int startingPosition = buffer.position();
527        buffer.put((byte) value.ordinal());
528        for (int i = 0; i < enumFields.size(); i++) {
529          Packer<Object> packer = (Packer<Object>) packers.get(i);
530          Field field = enumFields.get(i);
531          try {
532            Object fieldValue = field.get(value);
533            if (fieldValue == null) {
534              throw new IllegalArgumentException("Field is null");
535            }
536            packer.pack(buffer, fieldValue);
537          } catch (IllegalArgumentException | IllegalAccessException e) {
538            System.err.println(
539                "Could not pack enum field: "
540                    + enumClass.getSimpleName()
541                    + "#"
542                    + field.getName()
543                    + "\n    "
544                    + e.getMessage());
545            failed = true;
546            break;
547          }
548        }
549        if (failed) {
550          buffer.position(startingPosition);
551          for (int i = 0; i < frozenSize; i++) {
552            buffer.put((byte) 0);
553          }
554        }
555      }
556
557      final byte[] m_spongeBuffer = new byte[frozenSize - 1];
558
559      @Override
560      public E unpack(ByteBuffer buffer) {
561        int ordinal = buffer.get();
562        buffer.get(m_spongeBuffer);
563        return enumMap.getOrDefault(ordinal, null);
564      }
565
566      @Override
567      public boolean isImmutable() {
568        return true;
569      }
570    };
571  }
572}