001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.util.struct; 006 007import java.lang.reflect.Field; 008import java.lang.reflect.InvocationTargetException; 009import java.lang.reflect.Modifier; 010import java.lang.reflect.RecordComponent; 011import java.nio.ByteBuffer; 012import java.util.ArrayList; 013import java.util.HashMap; 014import java.util.List; 015 016/** A utility class for procedurally generating {@link Struct}s from records and enums. */ 017public final class StructGenerator { 018 private StructGenerator() { 019 throw new UnsupportedOperationException("This is a utility class!"); 020 } 021 022 /** 023 * A functional interface representing a method that retrives a value from a {@link ByteBuffer}. 024 */ 025 @FunctionalInterface 026 private interface Unpacker<T> { 027 T unpack(ByteBuffer buffer); 028 } 029 030 /** A functional interface representing a method that packs a value into a {@link ByteBuffer}. */ 031 @FunctionalInterface 032 private interface Packer<T> { 033 ByteBuffer pack(ByteBuffer buffer, T value); 034 035 static <T> Packer<T> fromStruct(Struct<T> struct) { 036 return (buffer, value) -> { 037 struct.pack(buffer, value); 038 return buffer; 039 }; 040 } 041 } 042 043 private record PrimType<T>(String name, int size, Unpacker<T> unpacker, Packer<T> packer) {} 044 045 /** A map of primitive types to their schema types. */ 046 private static final HashMap<Class<?>, PrimType<?>> primitiveTypeMap = new HashMap<>(); 047 048 private static <T> void addPrimType( 049 Class<T> boxedClass, 050 Class<T> primitiveClass, 051 String name, 052 int size, 053 Unpacker<T> unpacker, 054 Packer<T> packer) { 055 PrimType<T> primType = new PrimType<>(name, size, unpacker, packer); 056 primitiveTypeMap.put(boxedClass, primType); 057 primitiveTypeMap.put(primitiveClass, primType); 058 } 059 060 // Add primitive types to the map 061 static { 062 addPrimType( 063 Integer.class, int.class, "int32", Integer.BYTES, ByteBuffer::getInt, ByteBuffer::putInt); 064 addPrimType( 065 Double.class, 066 double.class, 067 "float64", 068 Double.BYTES, 069 ByteBuffer::getDouble, 070 ByteBuffer::putDouble); 071 addPrimType( 072 Float.class, 073 float.class, 074 "float32", 075 Float.BYTES, 076 ByteBuffer::getFloat, 077 ByteBuffer::putFloat); 078 addPrimType( 079 Boolean.class, 080 boolean.class, 081 "bool", 082 Byte.BYTES, 083 buffer -> buffer.get() != 0, 084 (buffer, value) -> buffer.put((byte) (value ? 1 : 0))); 085 addPrimType( 086 Character.class, 087 char.class, 088 "char", 089 Character.BYTES, 090 ByteBuffer::getChar, 091 ByteBuffer::putChar); 092 addPrimType(Byte.class, byte.class, "uint8", Byte.BYTES, ByteBuffer::get, ByteBuffer::put); 093 addPrimType( 094 Short.class, short.class, "int16", Short.BYTES, ByteBuffer::getShort, ByteBuffer::putShort); 095 addPrimType( 096 Long.class, long.class, "int64", Long.BYTES, ByteBuffer::getLong, ByteBuffer::putLong); 097 } 098 099 /** 100 * A map of types to their custom struct schemas. 101 * 102 * <p>This allows adding custom struct implementations for types that are not supported by 103 * default. Think of vendor-specific. 104 */ 105 private static final HashMap<Class<?>, Struct<?>> customStructTypeMap = new HashMap<>(); 106 107 /** 108 * Add a custom struct to the structifier. 109 * 110 * @param <T> The type the struct is for. 111 * @param clazz The class of the type. 112 * @param struct The struct to add. 113 * @param override Whether to override an existing struct. An existing struct could mean the type 114 * already has a {@code struct} field and implemnts {@link StructSerializable} or that the 115 * type is already in the custom struct map. 116 */ 117 public static <T> void addCustomStruct(Class<T> clazz, Struct<T> struct, boolean override) { 118 if (override) { 119 customStructTypeMap.put(clazz, struct); 120 } else if (!StructSerializable.class.isAssignableFrom(clazz)) { 121 customStructTypeMap.putIfAbsent(clazz, struct); 122 } 123 } 124 125 /** A utility for building schema syntax in a procedural manner. */ 126 @SuppressWarnings("PMD.AvoidStringBufferField") 127 public static class SchemaBuilder { 128 /** A utility for building enum fields in a procedural manner. */ 129 public static final class EnumFieldBuilder { 130 private final StringBuilder m_builder = new StringBuilder(); 131 private final String m_fieldName; 132 private boolean m_firstVariant = true; 133 134 /** 135 * Creates a new enum field builder. 136 * 137 * @param fieldName The name of the field. 138 */ 139 public EnumFieldBuilder(String fieldName) { 140 this.m_fieldName = fieldName; 141 m_builder.append("enum {"); 142 } 143 144 /** 145 * Adds a variant to the enum field. 146 * 147 * @param name The name of the variant. 148 * @param value The value of the variant. 149 * @return The builder for chaining. 150 */ 151 public EnumFieldBuilder addVariant(String name, int value) { 152 if (!m_firstVariant) { 153 m_builder.append(','); 154 } 155 m_firstVariant = false; 156 m_builder.append(name).append('=').append(value); 157 return this; 158 } 159 160 /** 161 * Builds the enum field. If this object is being used with {@link SchemaBuilder#addEnumField} 162 * then {@link #build()} does not have to be called by the user. 163 * 164 * @return The built enum field. 165 */ 166 public String build() { 167 m_builder.append("} int8 ").append(m_fieldName).append(';'); 168 return m_builder.toString(); 169 } 170 } 171 172 /** Creates a new schema builder. */ 173 public SchemaBuilder() {} 174 175 private final StringBuilder m_builder = new StringBuilder(); 176 177 /** 178 * Adds a field to the schema. 179 * 180 * @param name The name of the field. 181 * @param type The type of the field. 182 * @return The builder for chaining. 183 */ 184 public SchemaBuilder addField(String name, String type) { 185 m_builder.append(type).append(' ').append(name).append(';'); 186 return this; 187 } 188 189 /** 190 * Adds an inline enum field to the schema. 191 * 192 * @param enumFieldBuilder The builder for the enum field. 193 * @return The builder for chaining. 194 */ 195 public SchemaBuilder addEnumField(EnumFieldBuilder enumFieldBuilder) { 196 m_builder.append(enumFieldBuilder.build()); 197 return this; 198 } 199 200 /** 201 * Builds the schema. 202 * 203 * @return The built schema. 204 */ 205 public String build() { 206 return m_builder.toString(); 207 } 208 } 209 210 private static <T> Struct<T> noopStruct(Class<T> cls) { 211 return new Struct<>() { 212 @Override 213 public Class<T> getTypeClass() { 214 return cls; 215 } 216 217 @Override 218 public String getTypeName() { 219 return cls.getSimpleName(); 220 } 221 222 @Override 223 public String getSchema() { 224 return ""; 225 } 226 227 @Override 228 public int getSize() { 229 return 0; 230 } 231 232 @Override 233 public void pack(ByteBuffer buffer, T value) {} 234 235 @Override 236 public T unpack(ByteBuffer buffer) { 237 return null; 238 } 239 240 @Override 241 public boolean isImmutable() { 242 return true; 243 } 244 }; 245 } 246 247 /** 248 * Generates a {@link Struct} for the given {@link Record} class. If a {@link Struct} cannot be 249 * generated from the {@link Record}, the errors encountered will be printed and a no-op {@link 250 * Struct} will be returned. 251 * 252 * @param <R> The type of the record. 253 * @param recordClass The class of the record. 254 * @return The generated struct. 255 */ 256 @SuppressWarnings({"unchecked", "PMD.AvoidAccessibilityAlteration"}) 257 public static <R extends Record> Struct<R> genRecord(final Class<R> recordClass) { 258 final RecordComponent[] components = recordClass.getRecordComponents(); 259 final SchemaBuilder schemaBuilder = new SchemaBuilder(); 260 final ArrayList<Struct<?>> nestedStructs = new ArrayList<>(); 261 final ArrayList<Unpacker<?>> unpackers = new ArrayList<>(); 262 final ArrayList<Packer<?>> packers = new ArrayList<>(); 263 264 int size = 0; 265 boolean failed = false; 266 267 for (final RecordComponent component : components) { 268 final Class<?> type = component.getType(); 269 final String name = component.getName(); 270 component.getAccessor().setAccessible(true); 271 272 if (primitiveTypeMap.containsKey(type)) { 273 PrimType<?> primType = primitiveTypeMap.get(type); 274 schemaBuilder.addField(name, primType.name); 275 size += primType.size; 276 unpackers.add(primType.unpacker); 277 packers.add(primType.packer); 278 } else { 279 Struct<?> struct; 280 if (customStructTypeMap.containsKey(type)) { 281 struct = customStructTypeMap.get(type); 282 } else if (StructSerializable.class.isAssignableFrom(type)) { 283 var optStruct = StructFetcher.fetchStructDynamic(type); 284 if (optStruct.isPresent()) { 285 struct = optStruct.get(); 286 } else { 287 System.err.println( 288 "Could not structify record component: " 289 + recordClass.getSimpleName() 290 + "#" 291 + name 292 + "\n Could not extract struct from marked class: " 293 + type.getName()); 294 failed = true; 295 continue; 296 } 297 } else { 298 System.err.println( 299 "Could not structify record component: " + recordClass.getSimpleName() + "#" + name); 300 failed = true; 301 continue; 302 } 303 schemaBuilder.addField(name, struct.getTypeName()); 304 size += struct.getSize(); 305 nestedStructs.add(struct); 306 nestedStructs.addAll(List.of(struct.getNested())); 307 unpackers.add(struct::unpack); 308 packers.add(Packer.fromStruct(struct)); 309 } 310 } 311 312 if (failed) { 313 return noopStruct(recordClass); 314 } 315 316 final int frozenSize = size; 317 final String schema = schemaBuilder.build(); 318 return new Struct<>() { 319 @Override 320 public Class<R> getTypeClass() { 321 return recordClass; 322 } 323 324 @Override 325 public String getTypeName() { 326 return recordClass.getSimpleName(); 327 } 328 329 @Override 330 public String getSchema() { 331 return schema; 332 } 333 334 @Override 335 public int getSize() { 336 return frozenSize; 337 } 338 339 @Override 340 public void pack(ByteBuffer buffer, R value) { 341 boolean failed = false; 342 int startingPosition = buffer.position(); 343 for (int i = 0; i < components.length; i++) { 344 Packer<Object> packer = (Packer<Object>) packers.get(i); 345 try { 346 Object componentValue = components[i].getAccessor().invoke(value); 347 if (componentValue == null) { 348 throw new IllegalArgumentException("Component is null"); 349 } 350 packer.pack(buffer, componentValue); 351 } catch (IllegalAccessException 352 | IllegalArgumentException 353 | InvocationTargetException e) { 354 System.err.println( 355 "Could not pack record component: " 356 + recordClass.getSimpleName() 357 + "#" 358 + components[i].getName() 359 + "\n " 360 + e.getMessage()); 361 failed = true; 362 break; 363 } 364 } 365 if (failed) { 366 buffer.position(startingPosition); 367 for (int i = 0; i < frozenSize; i++) { 368 buffer.put((byte) 0); 369 } 370 } 371 } 372 373 @Override 374 public R unpack(ByteBuffer buffer) { 375 try { 376 Object[] args = new Object[components.length]; 377 Class<?>[] argTypes = new Class<?>[components.length]; 378 for (int i = 0; i < components.length; i++) { 379 args[i] = unpackers.get(i).unpack(buffer); 380 argTypes[i] = components[i].getType(); 381 } 382 return recordClass.getConstructor(argTypes).newInstance(args); 383 } catch (InstantiationException 384 | IllegalAccessException 385 | InvocationTargetException 386 | NoSuchMethodException 387 | SecurityException e) { 388 System.err.println( 389 "Could not unpack record: " 390 + recordClass.getSimpleName() 391 + "\n " 392 + e.getMessage()); 393 return null; 394 } 395 } 396 397 @Override 398 public Struct<?>[] getNested() { 399 return nestedStructs.toArray(new Struct<?>[0]); 400 } 401 402 @Override 403 public boolean isImmutable() { 404 return true; 405 } 406 }; 407 } 408 409 /** 410 * Generates a {@link Struct} for the given {@link Enum} class. If a {@link Struct} cannot be 411 * generated from the {@link Enum}, the errors encountered will be printed and a no-op {@link 412 * Struct} will be returned. 413 * 414 * @param <E> The type of the enum. 415 * @param enumClass The class of the enum. 416 * @return The generated struct. 417 */ 418 @SuppressWarnings({"unchecked", "PMD.AvoidAccessibilityAlteration"}) 419 public static <E extends Enum<E>> Struct<E> genEnum(Class<E> enumClass) { 420 final E[] enumVariants = enumClass.getEnumConstants(); 421 final Field[] allEnumFields = enumClass.getDeclaredFields(); 422 final SchemaBuilder schemaBuilder = new SchemaBuilder(); 423 final SchemaBuilder.EnumFieldBuilder enumFieldBuilder = 424 new SchemaBuilder.EnumFieldBuilder("variant"); 425 final HashMap<Integer, E> enumMap = new HashMap<>(); 426 final ArrayList<Packer<?>> packers = new ArrayList<>(); 427 428 if (enumVariants == null || enumVariants.length == 0) { 429 System.err.println( 430 "Could not structify enum: " 431 + enumClass.getSimpleName() 432 + "\n " 433 + "Enum has no constants"); 434 return noopStruct(enumClass); 435 } 436 437 int size = 0; 438 boolean failed = false; 439 440 for (final E constant : enumVariants) { 441 final String name = constant.name(); 442 final int ordinal = constant.ordinal(); 443 444 enumFieldBuilder.addVariant(name, ordinal); 445 enumMap.put(ordinal, constant); 446 } 447 schemaBuilder.addEnumField(enumFieldBuilder); 448 size += 1; 449 450 final List<Field> enumFields = 451 List.of(allEnumFields).stream() 452 .filter(f -> !f.isEnumConstant() && !Modifier.isStatic(f.getModifiers())) 453 .toList(); 454 455 for (final Field field : enumFields) { 456 final Class<?> type = field.getType(); 457 final String name = field.getName(); 458 field.setAccessible(true); 459 460 if (primitiveTypeMap.containsKey(type)) { 461 PrimType<?> primType = primitiveTypeMap.get(type); 462 schemaBuilder.addField(name, primType.name); 463 size += primType.size; 464 packers.add(primType.packer); 465 } else { 466 Struct<?> struct; 467 if (customStructTypeMap.containsKey(type)) { 468 struct = customStructTypeMap.get(type); 469 } else if (StructSerializable.class.isAssignableFrom(type)) { 470 var optStruct = StructFetcher.fetchStructDynamic(type); 471 if (optStruct.isPresent()) { 472 struct = optStruct.get(); 473 } else { 474 System.err.println( 475 "Could not structify record component: " 476 + enumClass.getSimpleName() 477 + "#" 478 + name 479 + "\n Could not extract struct from marked class: " 480 + type.getName()); 481 failed = true; 482 continue; 483 } 484 } else { 485 System.err.println( 486 "Could not structify record component: " + enumClass.getSimpleName() + "#" + name); 487 failed = true; 488 continue; 489 } 490 schemaBuilder.addField(name, struct.getTypeName()); 491 size += struct.getSize(); 492 packers.add(Packer.fromStruct(struct)); 493 } 494 } 495 496 if (failed) { 497 return noopStruct(enumClass); 498 } 499 500 final int frozenSize = size; 501 final String schema = schemaBuilder.build(); 502 return new Struct<>() { 503 @Override 504 public Class<E> getTypeClass() { 505 return enumClass; 506 } 507 508 @Override 509 public String getTypeName() { 510 return enumClass.getSimpleName(); 511 } 512 513 @Override 514 public String getSchema() { 515 return schema; 516 } 517 518 @Override 519 public int getSize() { 520 return frozenSize; 521 } 522 523 @Override 524 public void pack(ByteBuffer buffer, E value) { 525 boolean failed = false; 526 int startingPosition = buffer.position(); 527 buffer.put((byte) value.ordinal()); 528 for (int i = 0; i < enumFields.size(); i++) { 529 Packer<Object> packer = (Packer<Object>) packers.get(i); 530 Field field = enumFields.get(i); 531 try { 532 Object fieldValue = field.get(value); 533 if (fieldValue == null) { 534 throw new IllegalArgumentException("Field is null"); 535 } 536 packer.pack(buffer, fieldValue); 537 } catch (IllegalArgumentException | IllegalAccessException e) { 538 System.err.println( 539 "Could not pack enum field: " 540 + enumClass.getSimpleName() 541 + "#" 542 + field.getName() 543 + "\n " 544 + e.getMessage()); 545 failed = true; 546 break; 547 } 548 } 549 if (failed) { 550 buffer.position(startingPosition); 551 for (int i = 0; i < frozenSize; i++) { 552 buffer.put((byte) 0); 553 } 554 } 555 } 556 557 final byte[] m_spongeBuffer = new byte[frozenSize - 1]; 558 559 @Override 560 public E unpack(ByteBuffer buffer) { 561 int ordinal = buffer.get(); 562 buffer.get(m_spongeBuffer); 563 return enumMap.getOrDefault(ordinal, null); 564 } 565 566 @Override 567 public boolean isImmutable() { 568 return true; 569 } 570 }; 571 } 572}