Skip to content

Commit

Permalink
msggen: Add model-side overrides
Browse files Browse the repository at this point in the history
Sometimes we just want to paper over the schema directly. Mostly
useful to sidestep the `oneof` things that are required for
expressiveness.
  • Loading branch information
cdecker authored and rustyrussell committed Apr 1, 2022
1 parent 1f40db3 commit ec5cd92
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 28 deletions.
2 changes: 1 addition & 1 deletion cln-grpc/proto/primitives.proto
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ enum ChannelState {

message ChannelStateChangeCause {}

message Utxo {
message Outpoint {
bytes txid = 1;
uint32 outnum = 2;
}
Expand Down
14 changes: 7 additions & 7 deletions cln-grpc/src/pb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ tonic::include_proto!("cln");

use cln_rpc::primitives::{
Amount as JAmount, AmountOrAll as JAmountOrAll, AmountOrAny as JAmountOrAny,
Feerate as JFeerate, OutputDesc as JOutputDesc, Utxo as JUtxo,
Feerate as JFeerate, OutputDesc as JOutputDesc, Outpoint as JOutpoint,
};

impl From<JAmount> for Amount {
Expand All @@ -17,18 +17,18 @@ impl From<&Amount> for JAmount {
}
}

impl From<JUtxo> for Utxo {
fn from(a: JUtxo) -> Self {
Utxo {
impl From<JOutpoint> for Outpoint {
fn from(a: JOutpoint) -> Self {
Outpoint {
txid: a.txid,
outnum: a.outnum,
}
}
}

impl From<&Utxo> for JUtxo {
fn from(a: &Utxo) -> Self {
JUtxo {
impl From<&Outpoint> for JOutpoint {
fn from(a: &Outpoint) -> Self {
JOutpoint {
txid: a.txid.clone(),
outnum: a.outnum,
}
Expand Down
8 changes: 4 additions & 4 deletions cln-rpc/src/primitives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ impl Amount {
}

#[derive(Clone, Debug, PartialEq)]
pub struct Utxo {
pub struct Outpoint {
pub txid: Vec<u8>,
pub outnum: u32,
}

impl Serialize for Utxo {
impl Serialize for Outpoint {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
Expand All @@ -83,7 +83,7 @@ impl Serialize for Utxo {
}
}

impl<'de> Deserialize<'de> for Utxo {
impl<'de> Deserialize<'de> for Outpoint {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
Expand All @@ -102,7 +102,7 @@ impl<'de> Deserialize<'de> for Utxo {
.parse()
.map_err(|e| Error::custom(format!("{} is not a valid number: {}", s, e)))?;

Ok(Utxo { txid, outnum })
Ok(Outpoint { txid, outnum })
}
}

Expand Down
11 changes: 9 additions & 2 deletions contrib/msggen/msggen/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
'u16': 'uint32', # Yeah, I know...
'f32': 'float',
'integer': 'sint64',
"utxo": "Utxo",
"outpoint": "Outpoint",
"feerate": "Feerate",
"outputdesc": "OutputDesc",
}


Expand All @@ -41,6 +42,7 @@
'ListTransactions.transactions[].type[]': None,
}


method_name_overrides = {
"Connect": "ConnectPeer",
}
Expand Down Expand Up @@ -373,7 +375,12 @@ def generate_composite(self, prefix, field: CompositeField) -> None:
for f in field.fields:
name = f.normalized()
if isinstance(f, ArrayField):
self.write(f"{name}: c.{name}.iter().map(|s| s.into()).collect(),\n", numindent=3)
typ = f.itemtype.typename
mapping = {
'hex': f'hex::decode(s).unwrap()',
'u32': f's.clone()',
}.get(typ, f's.into()')
self.write(f"{name}: c.{name}.iter().map(|s| {mapping}).collect(),\n", numindent=3)

elif isinstance(f, EnumField):
if f.required:
Expand Down
33 changes: 25 additions & 8 deletions contrib/msggen/msggen/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Union, Optional
import logging
from copy import copy

logger = logging.getLogger(__name__)

Expand All @@ -18,7 +19,7 @@ def normalized(self):
"type": "item_type"
}.get(self.name, self.name)

name = name.replace(' ', '_').replace('-', '_')
name = name.replace(' ', '_').replace('-', '_').replace('[]', '')
return name

def __str__(self):
Expand Down Expand Up @@ -133,8 +134,12 @@ def from_js(cls, js, path):
logger.warning(f"Unmanaged {fpath}, it is deprecated")
continue

if 'oneOf' in ftype:
field = UnionField.from_js(ftype, fpath)
if fpath in overrides:
field = copy(overrides[fpath])
field.path = fpath
field.description = desc
if isinstance(field, ArrayField):
field.itemtype.path = fpath

elif "type" not in ftype:
logger.warning(f"Unmanaged {fpath}, it doesn't have a type")
Expand Down Expand Up @@ -320,11 +325,6 @@ def from_js(cls, path, js):
itemtype, dims=dims, path=path, description=js.get("description", "")
)

def normalized(self):
# Strip the '[]' that we use to signal an array. The name
# itself doesn't need this.
return Field.normalized(self)[:-2]


class Command:
def __init__(self, name, fields):
Expand All @@ -336,6 +336,23 @@ def __str__(self):
return f"Command[name={self.name}, fields=[{fieldnames}]]"


InvoiceLabelField = PrimitiveField("string", None, None)
DatastoreKeyField = ArrayField(itemtype=PrimitiveField("string", None, None), dims=1, path=None, description=None)
InvoiceExposeprivatechannelsField = PrimitiveField("boolean", None, None)
PayExclude = ArrayField(itemtype=PrimitiveField("string", None, None), dims=1, path=None, description=None)
# Override fields with manually managed types, fieldpath -> field mapping
overrides = {
'Invoice.label': InvoiceLabelField,
'DelInvoice.label': InvoiceLabelField,
'ListInvoices.label': InvoiceLabelField,
'Datastore.key': DatastoreKeyField,
'DelDatastore.key': DatastoreKeyField,
'ListDatastore.key': DatastoreKeyField,
'Invoice.exposeprivatechannels': InvoiceExposeprivatechannelsField,
'Pay.exclude': PayExclude,
}


def parse_doc(command, js) -> Union[CompositeField, Command]:
"""Given a command name and its schema, generate the IR model"""
path = command
Expand Down
14 changes: 8 additions & 6 deletions contrib/msggen/msggen/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
'ListPeers.peers[].channels[].features[]': "string",
'ListFunds.channels[].state': 'ChannelState',
'ListTransactions.transactions[].type[]': None,
'Invoice.exposeprivatechannels': None,
}

# A map of schema type to rust primitive types.
Expand All @@ -43,6 +44,8 @@
'float': 'f32',
'utxo': 'Utxo',
'feerate': 'Feerate',
'outpoint': 'Outpoint',
'outputdesc': 'OutputDesc',
}

header = f"""#![allow(non_camel_case_types)]
Expand Down Expand Up @@ -123,7 +126,7 @@ def gen_enum(e):
if e.required:
defi = f" // Path `{e.path}`\n #[serde(rename = \"{e.name}\")]\n pub {e.name.normalized()}: {typename},\n"
else:
defi = f' #[serde(skip_serializing_if = "Option::is_none")]'
defi = f' #[serde(skip_serializing_if = "Option::is_none")]\n'
defi = f" pub {e.name.normalized()}: Option<{typename}>,\n"

return defi, decl
Expand All @@ -148,17 +151,16 @@ def gen_array(a):
logger.debug(f"Generating array field {a.name} -> {name} ({a.path})")
_, decl = gen_field(a.itemtype)

if isinstance(a.itemtype, PrimitiveField):
if a.path in overrides:
decl = "" # No declaration if we have an override
itemtype = overrides[a.path]
elif isinstance(a.itemtype, PrimitiveField):
itemtype = a.itemtype.typename
elif isinstance(a.itemtype, CompositeField):
itemtype = a.itemtype.typename
elif isinstance(a.itemtype, EnumField):
itemtype = a.itemtype.typename

if a.path in overrides:
decl = "" # No declaration if we have an override
itemtype = overrides[a.path]

if itemtype is None:
return ("", "") # Override said not to include

Expand Down

0 comments on commit ec5cd92

Please sign in to comment.