Skip to content

Commit

Permalink
rework node3d graph and gltf loading to optimize model reuse and inst…
Browse files Browse the repository at this point in the history
…ancing
  • Loading branch information
LeHaine committed Jan 17, 2025
1 parent 2665a23 commit bdaa969
Show file tree
Hide file tree
Showing 17 changed files with 311 additions and 222 deletions.
16 changes: 10 additions & 6 deletions core/src/commonMain/kotlin/com/littlekt/file/gltf/GltfJson.kt
Original file line number Diff line number Diff line change
Expand Up @@ -574,19 +574,23 @@ data class GltfTexture(val sampler: Int = -1, val source: Int = 0, val name: Str
imageRef.bufferViewRef?.getData()?.toArray()?.readPixmap()
?: error("Unable to read GltfTexture data!")
}

val minFilters = samplerRef.minFilter.toFilterMode()
val magFilters = samplerRef.magFilter.toFilterMode()
texture =
PixmapTexture(device, preferredFormat, pixmap).apply {
val minFilters = samplerRef.minFilter.toFilterMode()
val magFilters = samplerRef.magFilter.toFilterMode()
PixmapTexture(
device,
preferredFormat,
pixmap,
samplerDescriptor =
samplerDescriptor.copy(
SamplerDescriptor(
addressModeU = samplerRef.wrapS.toAddressMode(),
addressModeV = samplerRef.wrapT.toAddressMode(),
minFilter = minFilters.first,
magFilter = magFilters.first,
mipmapFilter = minFilters.second,
)
}
),
)
}
return texture ?: error("Unable to convert the GltfTexture to a Texture!")
}
Expand Down
276 changes: 153 additions & 123 deletions core/src/commonMain/kotlin/com/littlekt/file/gltf/GltfModelGenerator.kt

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.littlekt.graphics.g2d.font.*
import com.littlekt.graphics.g2d.tilemap.ldtk.LDtkWorld
import com.littlekt.graphics.g2d.tilemap.tiled.TiledMap
import com.littlekt.graphics.g3d.Model
import com.littlekt.graphics.g3d.Scene
import com.littlekt.graphics.webgpu.SamplerDescriptor
import com.littlekt.graphics.webgpu.TextureFormat
import com.littlekt.math.MutableVec4i
Expand Down Expand Up @@ -382,7 +383,7 @@ suspend fun VfsFile.readGltfModel(
preferredFormat: TextureFormat =
if (vfs.context.graphics.preferredFormat.srgb) TextureFormat.RGBA8_UNORM_SRGB
else TextureFormat.RGBA8_UNORM,
): Model {
): Scene {
val gltfData = readGltf()
return gltfData.toModel(config, preferredFormat)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ import com.littlekt.log.Logger
* @author Colton Daily
* @date 11/25/2024
*/
open class MeshNode(
open class MeshPrimitive(
val mesh: Mesh<*>,
val material: Material,
val topology: PrimitiveTopology = PrimitiveTopology.TRIANGLE_LIST,
val stripIndexFormat: IndexFormat? = null,
instanceSize: Int = 0,
) : Node3D(), Releasable {
) : Releasable {

/**
* A pre-cast of the underlying [mesh] as an [IndexedMesh], if applicable. If this value is
Expand All @@ -32,14 +32,14 @@ open class MeshNode(

/** if `true` then [instanceCount] is `> 1`; `false` otherwise. */
val isInstanced: Boolean
get() = instances.isNotEmpty()
get() = instances.size > 1

/**
* The size of all instances plus this. This will have a minimum value of `1` since it includes
* this MeshNode as an instance.
* The size of all instances. If this is `0` then it is not drawing. A size `> 1` will be marked
* as instanced.
*/
val instanceCount: Int
get() = instances.size + 1
get() = instances.size

val instanceBuffers: InstanceBuffers =
InstanceBuffers(mesh.device, (instanceSize + 1) * TRANSFORM_COMPONENTS_PER_INSTANCE)
Expand All @@ -64,10 +64,10 @@ open class MeshNode(
}
}

override fun dirty() {
super.dirty()
instancesDirty = true
}
// override fun dirty() {
// super.dirty()
// instancesDirty = true
// }

fun instanceDirty(instance: VisualInstance) {
if (dirtyInstances.contains(instance)) return
Expand All @@ -80,12 +80,12 @@ open class MeshNode(
dirtyInstances.forEach { dirtyInstance -> updateInstance(dirtyInstance) }
dirtyInstances.clear()
// ensure we set this MeshNode transform as first instance
instanceData.put(
globalTransform.data,
dstOffset = 0,
srcOffset = 0,
len = TRANSFORM_COMPONENTS_PER_INSTANCE,
)
// instanceData.put(
// globalTransform.data,
// dstOffset = 0,
// srcOffset = 0,
// len = TRANSFORM_COMPONENTS_PER_INSTANCE,
// )
instanceBuffers.updateStaticStorage(instanceData)
instancesDirty = false
}
Expand Down Expand Up @@ -122,7 +122,6 @@ open class MeshNode(
dirtyInstances -= instance
instancesToId.remove(instance)
instance.instanceOf = null
instance.remove()
instancesDirty = true
}
}
Expand Down Expand Up @@ -218,7 +217,7 @@ open class MeshNode(

companion object {
private const val TRANSFORM_COMPONENTS_PER_INSTANCE = 16
private val logger: Logger = Logger<MeshNode>()
private val logger: Logger = Logger<MeshPrimitive>()
}
}

Expand Down
23 changes: 1 addition & 22 deletions core/src/commonMain/kotlin/com/littlekt/graphics/g3d/Model.kt
Original file line number Diff line number Diff line change
@@ -1,28 +1,7 @@
package com.littlekt.graphics.g3d

import com.littlekt.math.Mat4

/**
* @author Colton Daily
* @date 11/24/2024
*/
open class Model : Node3D() {
val nodes = mutableMapOf<String, Node3D>()
val meshes = mutableMapOf<String, MeshNode>()
val skins = mutableListOf<Skin>()

val instances = mutableListOf<ModelInstance>()

fun createModelInstance(): ModelInstance {
val modelInstance = ModelInstance(this)
meshes.values.forEach { mesh ->
val meshInstance = VisualInstance()
meshInstance.globalTransform = Mat4().set(mesh.globalTransform)
mesh.addInstance(meshInstance)
modelInstance += meshInstance
modelInstance.meshInstances += meshInstance
}
instances += modelInstance
return modelInstance
}
}
class Model(val primitives: List<MeshPrimitive>)
52 changes: 32 additions & 20 deletions core/src/commonMain/kotlin/com/littlekt/graphics/g3d/ModelBatch.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.littlekt.graphics.g3d

import com.littlekt.EngineStats
import com.littlekt.Releasable
import com.littlekt.graphics.Camera
import com.littlekt.graphics.g3d.material.Material
Expand All @@ -10,6 +11,7 @@ import com.littlekt.graphics.g3d.util.MaterialPipelineSorter
import com.littlekt.graphics.util.BindingUsage
import com.littlekt.graphics.webgpu.*
import com.littlekt.log.Logger
import kotlin.math.max
import kotlin.reflect.KClass
import kotlin.time.Duration

Expand All @@ -21,7 +23,7 @@ class ModelBatch(val device: Device, size: Int = 128) : Releasable {
private val pipelineProviders: MutableMap<KClass<out Material>, MaterialPipelineProvider> =
mutableMapOf()
private val pipelines = mutableListOf<MaterialPipeline>()
private val meshNodesByPipeline = mutableMapOf<MaterialPipeline, MutableList<MeshNode>>()
private val primitivesByPipeline = mutableMapOf<MaterialPipeline, MutableList<MeshPrimitive>>()

/** By material id */
private val bindGroupByMaterialId = mutableMapOf<Int, BindGroup>()
Expand Down Expand Up @@ -57,37 +59,41 @@ class ModelBatch(val device: Device, size: Int = 128) : Releasable {
fun removePipelineProvider(provider: BaseMaterialPipelineProvider<*>) =
removePipelineProvider(provider.type)

fun render(model: Model, environment: Environment) {
model.meshes.values.forEach { render(it, environment) }
fun render(scene: Scene, environment: Environment) {
scene.modelInstances.forEach { render(it, environment) }
}

fun render(instance: MeshNode, environment: Environment) {
fun render(model: ModelInstance, environment: Environment) {
model.instanceOf.primitives.forEach { render(it, environment) }
}

fun render(meshPrimitive: MeshPrimitive, environment: Environment) {
val pipeline =
pipelineProviders[instance.material::class]?.getMaterialPipeline(
pipelineProviders[meshPrimitive.material::class]?.getMaterialPipeline(
device,
instance.material,
meshPrimitive.material,
environment,
instance.mesh.geometry.layout,
instance.topology,
instance.stripIndexFormat,
meshPrimitive.mesh.geometry.layout,
meshPrimitive.topology,
meshPrimitive.stripIndexFormat,
colorFormat,
depthFormat,
) ?: error("Unable to find pipeline for given instance!")
if (!pipelines.contains(pipeline)) {
pipelines += pipeline
}
// todo - pool lists?
meshNodesByPipeline.getOrPut(pipeline) { mutableListOf() }.apply { add(instance) }
primitivesByPipeline.getOrPut(pipeline) { mutableListOf() }.apply { add(meshPrimitive) }

bindGroupByMaterialId.getOrPut(instance.material.id) {
instance.material.createBindGroup(pipeline.shader)
bindGroupByMaterialId.getOrPut(meshPrimitive.material.id) {
meshPrimitive.material.createBindGroup(pipeline.shader)
}
}

fun flush(renderPassEncoder: RenderPassEncoder, camera: Camera, dt: Duration) {
sorter.sort(pipelines)
var lastEnvironmentSet: Environment? = null
var lastMaterialSet: Material? = null
var lastMaterialSet: Int? = null
pipelines.forEach { pipeline ->
// we only need to update the camera buffers in each environment once. So if we are
// sharing environment, just update the first instance of it.
Expand All @@ -102,24 +108,24 @@ class ModelBatch(val device: Device, size: Int = 128) : Releasable {
BindingUsage.CAMERA,
)
}
val meshNodes = meshNodesByPipeline[pipeline]
if (!meshNodes.isNullOrEmpty()) {
val primitive = primitivesByPipeline[pipeline]
if (!primitive.isNullOrEmpty()) {
renderPassEncoder.setPipeline(pipeline.renderPipeline)
meshNodes.forEach { meshNode ->
primitive.forEach { meshNode ->
val materialBindGroup =
bindGroupByMaterialId[meshNode.material.id]
?: error(
"Material (${meshNode.material.id}) bind groups could not be found!"
)
if (lastMaterialSet != meshNode.material) {
lastMaterialSet = meshNode.material
if (lastMaterialSet != meshNode.material.id) {
lastMaterialSet = meshNode.material.id
pipeline.shader.setBindGroup(
renderPassEncoder,
materialBindGroup,
BindingUsage.MATERIAL,
)
meshNode.material.update()
}
meshNode.material.update()
meshNode.writeInstanceDataToBuffer()
pipeline.shader.setBindGroup(
renderPassEncoder,
Expand All @@ -140,19 +146,23 @@ class ModelBatch(val device: Device, size: Int = 128) : Releasable {
renderPassEncoder.setVertexBuffer(0, mesh.vbo)

if (indexedMesh != null) {
EngineStats.extra(INSTANCED_STAT_NAME, max(0, meshNode.instanceCount - 1))
EngineStats.extra(DRAW_CALLS_STAT_NAME, 1)
renderPassEncoder.drawIndexed(
indexedMesh.geometry.numIndices,
meshNode.instanceCount,
)
} else {
EngineStats.extra(INSTANCED_STAT_NAME, max(0, meshNode.instanceCount - 1))
EngineStats.extra(DRAW_CALLS_STAT_NAME, 1)
renderPassEncoder.draw(mesh.geometry.numVertices, meshNode.instanceCount)
}
}
}
}

pipelines.clear()
meshNodesByPipeline.clear()
primitivesByPipeline.clear()
updatedEnvironments.clear()
}

Expand All @@ -163,5 +173,7 @@ class ModelBatch(val device: Device, size: Int = 128) : Releasable {

companion object {
private val logger = Logger<ModelBatch>()
private const val INSTANCED_STAT_NAME = "ModelBatch instanced count"
private const val DRAW_CALLS_STAT_NAME = "ModelBatch Draw calls"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ package com.littlekt.graphics.g3d
* @author Colton Daily
* @date 1/15/2025
*/
class ModelInstance(val instanceOf: Model) : Node3D() {
val meshInstances = mutableListOf<VisualInstance>()
open class ModelInstance(val instanceOf: Model) : Node3D() {

init {
instanceOf.primitives.forEach { prim -> addChild(VisualInstance().apply { addTo(prim) }) }
}

fun createInstance() = ModelInstance(instanceOf)
}
17 changes: 17 additions & 0 deletions core/src/commonMain/kotlin/com/littlekt/graphics/g3d/Node3D.kt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ open class Node3D {
val children: List<Node3D>
get() = _children

/** If destroy was called, this will be true until the next time node's are processed. */
val isDestroyed
get() = _isDestroyed

/**
* Global transform. Don't call `globalTransform.set` directly, the data won't be marked dirty.
* Set the globalTransform directly with `globalTransform = myMat4`.
Expand Down Expand Up @@ -355,6 +359,8 @@ open class Node3D {
private var _translationMatrix = Mat4()
private var _scaleMatrix = Mat4()

private var _isDestroyed = false

protected open fun dirty() {
dirty = true
children.forEach { it.propagateDirty() }
Expand Down Expand Up @@ -767,6 +773,17 @@ open class Node3D {
return this
}

fun destroy() {
_isDestroyed = true
while (children.isNotEmpty()) {
children[0].destroy()
}
onDestroy()
}

/** Called when [destroy] is invoked and all of its children have been destroyed. */
open fun onDestroy() = Unit

/** @return a tree string for all the child nodes under this [Node3D]. */
fun treeString(): String {
val builder = StringBuilder()
Expand Down
19 changes: 19 additions & 0 deletions core/src/commonMain/kotlin/com/littlekt/graphics/g3d/Scene.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.littlekt.graphics.g3d

/**
* @author Colton Daily
* @date 1/17/2025
*/
open class Scene : Node3D() {
var modelInstances = mutableListOf<ModelInstance>()
var skins = mutableListOf<Skin>()

fun createInstance(): Scene {
val newInstance = Scene()
val newModelInstances = modelInstances.map { it.createInstance() }
newInstance.modelInstances += newModelInstances
newModelInstances.forEach { newInstance += it }
newInstance.skins = skins.toMutableList()
return newInstance
}
}
Loading

0 comments on commit bdaa969

Please sign in to comment.