001// Copyright (c) FIRST and other WPILib contributors. 002// Open Source Software; you can modify and/or share it under the terms of 003// the WPILib BSD license file in the root directory of this project. 004 005package edu.wpi.first.util.struct; 006 007import java.lang.reflect.Field; 008import java.lang.reflect.InvocationTargetException; 009import java.lang.reflect.Modifier; 010import java.lang.reflect.RecordComponent; 011import java.nio.ByteBuffer; 012import java.util.ArrayList; 013import java.util.HashMap; 014import java.util.List; 015 016/** A utility class for procedurally generating {@link Struct}s from records and enums. */ 017public final class StructGenerator { 018 private StructGenerator() { 019 throw new UnsupportedOperationException("This is a utility class!"); 020 } 021 022 /** 023 * A functional interface representing a method that retrives a value from a {@link ByteBuffer}. 024 */ 025 @FunctionalInterface 026 private interface Unpacker<T> { 027 T unpack(ByteBuffer buffer); 028 } 029 030 /** A functional interface representing a method that packs a value into a {@link ByteBuffer}. */ 031 @FunctionalInterface 032 private interface Packer<T> { 033 ByteBuffer pack(ByteBuffer buffer, T value); 034 035 static <T> Packer<T> fromStruct(Struct<T> struct) { 036 return (buffer, value) -> { 037 struct.pack(buffer, value); 038 return buffer; 039 }; 040 } 041 } 042 043 private record PrimType<T>(String name, int size, Unpacker<T> unpacker, Packer<T> packer) {} 044 045 /** A map of primitive types to their schema types. */ 046 private static final HashMap<Class<?>, PrimType<?>> primitiveTypeMap = new HashMap<>(); 047 048 private static <T> void addPrimType( 049 Class<T> boxedClass, 050 Class<T> primitiveClass, 051 String name, 052 int size, 053 Unpacker<T> unpacker, 054 Packer<T> packer) { 055 PrimType<T> primType = new PrimType<>(name, size, unpacker, packer); 056 primitiveTypeMap.put(boxedClass, primType); 057 primitiveTypeMap.put(primitiveClass, primType); 058 } 059 060 // Add primitive types to the map 061 static { 062 addPrimType( 063 Integer.class, int.class, "int32", Integer.BYTES, ByteBuffer::getInt, ByteBuffer::putInt); 064 addPrimType( 065 Double.class, 066 double.class, 067 "float64", 068 Double.BYTES, 069 ByteBuffer::getDouble, 070 ByteBuffer::putDouble); 071 addPrimType( 072 Float.class, 073 float.class, 074 "float32", 075 Float.BYTES, 076 ByteBuffer::getFloat, 077 ByteBuffer::putFloat); 078 addPrimType( 079 Boolean.class, 080 boolean.class, 081 "bool", 082 Byte.BYTES, 083 buffer -> buffer.get() != 0, 084 (buffer, value) -> buffer.put((byte) (value ? 1 : 0))); 085 addPrimType( 086 Character.class, 087 char.class, 088 "char", 089 Character.BYTES, 090 ByteBuffer::getChar, 091 ByteBuffer::putChar); 092 addPrimType(Byte.class, byte.class, "uint8", Byte.BYTES, ByteBuffer::get, ByteBuffer::put); 093 addPrimType( 094 Short.class, short.class, "int16", Short.BYTES, ByteBuffer::getShort, ByteBuffer::putShort); 095 addPrimType( 096 Long.class, long.class, "int64", Long.BYTES, ByteBuffer::getLong, ByteBuffer::putLong); 097 } 098 099 /** 100 * A map of types to their custom struct schemas. 101 * 102 * <p>This allows adding custom struct implementations for types that are not supported by 103 * default. Think of vendor-specific. 104 */ 105 private static final HashMap<Class<?>, Struct<?>> customStructTypeMap = new HashMap<>(); 106 107 /** 108 * Add a custom struct to the structifier. 109 * 110 * @param <T> The type the struct is for. 111 * @param clazz The class of the type. 112 * @param struct The struct to add. 113 * @param override Whether to override an existing struct. An existing struct could mean the type 114 * already has a {@code struct} field and implemnts {@link StructSerializable} or that the 115 * type is already in the custom struct map. 116 */ 117 public static <T> void addCustomStruct(Class<T> clazz, Struct<T> struct, boolean override) { 118 if (override) { 119 customStructTypeMap.put(clazz, struct); 120 } else if (!StructSerializable.class.isAssignableFrom(clazz)) { 121 customStructTypeMap.putIfAbsent(clazz, struct); 122 } 123 } 124 125 /** A utility for building schema syntax in a procedural manner. */ 126 @SuppressWarnings("PMD.AvoidStringBufferField") 127 public static class SchemaBuilder { 128 /** A utility for building enum fields in a procedural manner. */ 129 public static final class EnumFieldBuilder { 130 private final StringBuilder m_builder = new StringBuilder(); 131 private final String m_fieldName; 132 private boolean m_firstVariant = true; 133 134 /** 135 * Creates a new enum field builder. 136 * 137 * @param fieldName The name of the field. 138 */ 139 public EnumFieldBuilder(String fieldName) { 140 this.m_fieldName = fieldName; 141 m_builder.append("enum {"); 142 } 143 144 /** 145 * Adds a variant to the enum field. 146 * 147 * @param name The name of the variant. 148 * @param value The value of the variant. 149 * @return The builder for chaining. 150 */ 151 public EnumFieldBuilder addVariant(String name, int value) { 152 if (!m_firstVariant) { 153 m_builder.append(','); 154 } 155 m_firstVariant = false; 156 m_builder.append(name).append('=').append(value); 157 return this; 158 } 159 160 /** 161 * Builds the enum field. If this object is being used with {@link SchemaBuilder#addEnumField} 162 * then {@link #build()} does not have to be called by the user. 163 * 164 * @return The built enum field. 165 */ 166 public String build() { 167 m_builder.append("} int8 ").append(m_fieldName).append(';'); 168 return m_builder.toString(); 169 } 170 } 171 172 /** Creates a new schema builder. */ 173 public SchemaBuilder() {} 174 175 private final StringBuilder m_builder = new StringBuilder(); 176 177 /** 178 * Adds a field to the schema. 179 * 180 * @param name The name of the field. 181 * @param type The type of the field. 182 * @return The builder for chaining. 183 */ 184 public SchemaBuilder addField(String name, String type) { 185 m_builder.append(type).append(' ').append(name).append(';'); 186 return this; 187 } 188 189 /** 190 * Adds an inline enum field to the schema. 191 * 192 * @param enumFieldBuilder The builder for the enum field. 193 * @return The builder for chaining. 194 */ 195 public SchemaBuilder addEnumField(EnumFieldBuilder enumFieldBuilder) { 196 m_builder.append(enumFieldBuilder.build()); 197 return this; 198 } 199 200 /** 201 * Builds the schema. 202 * 203 * @return The built schema. 204 */ 205 public String build() { 206 return m_builder.toString(); 207 } 208 } 209 210 private static <T> Struct<T> noopStruct(Class<T> cls) { 211 return new Struct<>() { 212 @Override 213 public Class<T> getTypeClass() { 214 return cls; 215 } 216 217 @Override 218 public String getTypeName() { 219 return cls.getSimpleName(); 220 } 221 222 @Override 223 public String getSchema() { 224 return ""; 225 } 226 227 @Override 228 public int getSize() { 229 return 0; 230 } 231 232 @Override 233 public void pack(ByteBuffer buffer, T value) {} 234 235 @Override 236 public T unpack(ByteBuffer buffer) { 237 return null; 238 } 239 240 @Override 241 public boolean isImmutable() { 242 return true; 243 } 244 }; 245 } 246 247 /** 248 * Generates a {@link Struct} for the given {@link Record} class. If a {@link Struct} cannot be 249 * generated from the {@link Record}, the errors encountered will be printed and a no-op {@link 250 * Struct} will be returned. 251 * 252 * @param <R> The type of the record. 253 * @param recordClass The class of the record. 254 * @return The generated struct. 255 */ 256 @SuppressWarnings({"unchecked", "PMD.AvoidAccessibilityAlteration"}) 257 public static <R extends Record> Struct<R> genRecord(final Class<R> recordClass) { 258 final RecordComponent[] components = recordClass.getRecordComponents(); 259 final SchemaBuilder schemaBuilder = new SchemaBuilder(); 260 final ArrayList<Struct<?>> nestedStructs = new ArrayList<>(); 261 final ArrayList<Unpacker<?>> unpackers = new ArrayList<>(); 262 final ArrayList<Packer<?>> packers = new ArrayList<>(); 263 264 int size = 0; 265 boolean failed = false; 266 267 for (final RecordComponent component : components) { 268 final Class<?> type = component.getType(); 269 final String name = component.getName(); 270 component.getAccessor().setAccessible(true); 271 272 if (primitiveTypeMap.containsKey(type)) { 273 PrimType<?> primType = primitiveTypeMap.get(type); 274 schemaBuilder.addField(name, primType.name); 275 size += primType.size; 276 unpackers.add(primType.unpacker); 277 packers.add(primType.packer); 278 } else { 279 Struct<?> struct; 280 if (customStructTypeMap.containsKey(type)) { 281 struct = customStructTypeMap.get(type); 282 } else if (StructSerializable.class.isAssignableFrom(type)) { 283 var optStruct = StructFetcher.fetchStructDynamic(type); 284 if (optStruct.isPresent()) { 285 struct = optStruct.get(); 286 } else { 287 System.err.println( 288 "Could not structify record component: " 289 + recordClass.getSimpleName() 290 + "#" 291 + name 292 + "\n Could not extract struct from marked class: " 293 + type.getName()); 294 failed = true; 295 continue; 296 } 297 } else { 298 System.err.println( 299 "Could not structify record component: " + recordClass.getSimpleName() + "#" + name); 300 failed = true; 301 continue; 302 } 303 schemaBuilder.addField(name, struct.getTypeName()); 304 size += struct.getSize(); 305 nestedStructs.add(struct); 306 nestedStructs.addAll(List.of(struct.getNested())); 307 unpackers.add(struct::unpack); 308 packers.add(Packer.fromStruct(struct)); 309 } 310 } 311 312 if (failed) { 313 return noopStruct(recordClass); 314 } 315 316 final int frozenSize = size; 317 final String schema = schemaBuilder.build(); 318 return new Struct<>() { 319 @Override 320 public Class<R> getTypeClass() { 321 return recordClass; 322 } 323 324 @Override 325 public String getTypeName() { 326 return recordClass.getSimpleName(); 327 } 328 329 @Override 330 public String getSchema() { 331 return schema; 332 } 333 334 @Override 335 public int getSize() { 336 return frozenSize; 337 } 338 339 @Override 340 public void pack(ByteBuffer buffer, R value) { 341 boolean failed = false; 342 int startingPosition = buffer.position(); 343 for (int i = 0; i < components.length; i++) { 344 Packer<Object> packer = (Packer<Object>) packers.get(i); 345 try { 346 Object componentValue = components[i].getAccessor().invoke(value); 347 if (componentValue == null) { 348 throw new IllegalArgumentException("Component is null"); 349 } 350 packer.pack(buffer, componentValue); 351 } catch (IllegalAccessException 352 | IllegalArgumentException 353 | InvocationTargetException e) { 354 System.err.println( 355 "Could not pack record component: " 356 + recordClass.getSimpleName() 357 + "#" 358 + components[i].getName() 359 + "\n " 360 + e.getMessage()); 361 failed = true; 362 break; 363 } 364 } 365 if (failed) { 366 buffer.position(startingPosition); 367 for (int i = 0; i < frozenSize; i++) { 368 buffer.put((byte) 0); 369 } 370 } 371 } 372 373 @Override 374 public R unpack(ByteBuffer buffer) { 375 try { 376 Object[] args = new Object[components.length]; 377 Class<?>[] argTypes = new Class<?>[components.length]; 378 for (int i = 0; i < components.length; i++) { 379 args[i] = unpackers.get(i).unpack(buffer); 380 argTypes[i] = components[i].getType(); 381 } 382 return recordClass.getConstructor(argTypes).newInstance(args); 383 } catch (InstantiationException 384 | IllegalAccessException 385 | InvocationTargetException 386 | NoSuchMethodException 387 | SecurityException e) { 388 System.err.println( 389 "Could not unpack record: " 390 + recordClass.getSimpleName() 391 + "\n " 392 + e.getMessage()); 393 return null; 394 } 395 } 396 397 @Override 398 public Struct<?>[] getNested() { 399 return nestedStructs.toArray(new Struct<?>[0]); 400 } 401 402 @Override 403 public boolean isImmutable() { 404 return true; 405 } 406 }; 407 } 408 409 /** 410 * Generates a {@link Struct} for the given {@link Enum} class. If a {@link Struct} cannot be 411 * generated from the {@link Enum}, the errors encountered will be printed and a no-op {@link 412 * Struct} will be returned. 413 * 414 * @param <E> The type of the enum. 415 * @param enumClass The class of the enum. 416 * @return The generated struct. 417 */ 418 @SuppressWarnings({"unchecked", "PMD.AvoidAccessibilityAlteration"}) 419 public static <E extends Enum<E>> Struct<E> genEnum(Class<E> enumClass) { 420 final E[] enumVariants = enumClass.getEnumConstants(); 421 final Field[] allEnumFields = enumClass.getDeclaredFields(); 422 final SchemaBuilder schemaBuilder = new SchemaBuilder(); 423 final SchemaBuilder.EnumFieldBuilder enumFieldBuilder = 424 new SchemaBuilder.EnumFieldBuilder("variant"); 425 final HashMap<Integer, E> enumMap = new HashMap<>(); 426 final ArrayList<Packer<?>> packers = new ArrayList<>(); 427 428 if (enumVariants == null || enumVariants.length == 0) { 429 System.err.println( 430 "Could not structify enum: " 431 + enumClass.getSimpleName() 432 + "\n " 433 + "Enum has no constants"); 434 return noopStruct(enumClass); 435 } 436 437 int size = 0; 438 439 for (final E constant : enumVariants) { 440 final String name = constant.name(); 441 final int ordinal = constant.ordinal(); 442 443 enumFieldBuilder.addVariant(name, ordinal); 444 enumMap.put(ordinal, constant); 445 } 446 schemaBuilder.addEnumField(enumFieldBuilder); 447 size += 1; 448 449 final List<Field> enumFields = 450 List.of(allEnumFields).stream() 451 .filter(f -> !f.isEnumConstant() && !Modifier.isStatic(f.getModifiers())) 452 .toList(); 453 454 boolean failed = false; 455 456 for (final Field field : enumFields) { 457 final Class<?> type = field.getType(); 458 final String name = field.getName(); 459 field.setAccessible(true); 460 461 if (primitiveTypeMap.containsKey(type)) { 462 PrimType<?> primType = primitiveTypeMap.get(type); 463 schemaBuilder.addField(name, primType.name); 464 size += primType.size; 465 packers.add(primType.packer); 466 } else { 467 Struct<?> struct; 468 if (customStructTypeMap.containsKey(type)) { 469 struct = customStructTypeMap.get(type); 470 } else if (StructSerializable.class.isAssignableFrom(type)) { 471 var optStruct = StructFetcher.fetchStructDynamic(type); 472 if (optStruct.isPresent()) { 473 struct = optStruct.get(); 474 } else { 475 System.err.println( 476 "Could not structify record component: " 477 + enumClass.getSimpleName() 478 + "#" 479 + name 480 + "\n Could not extract struct from marked class: " 481 + type.getName()); 482 failed = true; 483 continue; 484 } 485 } else { 486 System.err.println( 487 "Could not structify record component: " + enumClass.getSimpleName() + "#" + name); 488 failed = true; 489 continue; 490 } 491 schemaBuilder.addField(name, struct.getTypeName()); 492 size += struct.getSize(); 493 packers.add(Packer.fromStruct(struct)); 494 } 495 } 496 497 if (failed) { 498 return noopStruct(enumClass); 499 } 500 501 final int frozenSize = size; 502 final String schema = schemaBuilder.build(); 503 return new Struct<>() { 504 @Override 505 public Class<E> getTypeClass() { 506 return enumClass; 507 } 508 509 @Override 510 public String getTypeName() { 511 return enumClass.getSimpleName(); 512 } 513 514 @Override 515 public String getSchema() { 516 return schema; 517 } 518 519 @Override 520 public int getSize() { 521 return frozenSize; 522 } 523 524 @Override 525 public void pack(ByteBuffer buffer, E value) { 526 boolean failed = false; 527 int startingPosition = buffer.position(); 528 buffer.put((byte) value.ordinal()); 529 for (int i = 0; i < enumFields.size(); i++) { 530 Packer<Object> packer = (Packer<Object>) packers.get(i); 531 Field field = enumFields.get(i); 532 try { 533 Object fieldValue = field.get(value); 534 if (fieldValue == null) { 535 throw new IllegalArgumentException("Field is null"); 536 } 537 packer.pack(buffer, fieldValue); 538 } catch (IllegalArgumentException | IllegalAccessException e) { 539 System.err.println( 540 "Could not pack enum field: " 541 + enumClass.getSimpleName() 542 + "#" 543 + field.getName() 544 + "\n " 545 + e.getMessage()); 546 failed = true; 547 break; 548 } 549 } 550 if (failed) { 551 buffer.position(startingPosition); 552 for (int i = 0; i < frozenSize; i++) { 553 buffer.put((byte) 0); 554 } 555 } 556 } 557 558 final byte[] m_spongeBuffer = new byte[frozenSize - 1]; 559 560 @Override 561 public E unpack(ByteBuffer buffer) { 562 int ordinal = buffer.get(); 563 buffer.get(m_spongeBuffer); 564 return enumMap.getOrDefault(ordinal, null); 565 } 566 567 @Override 568 public boolean isImmutable() { 569 return true; 570 } 571 }; 572 } 573}