Skip to content

Commit

Permalink
discojs: add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
s314cy committed Aug 2, 2023
1 parent 021f499 commit 6e85f6f
Show file tree
Hide file tree
Showing 27 changed files with 475 additions and 88 deletions.
67 changes: 53 additions & 14 deletions discojs/discojs-core/src/aggregator/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ export enum AggregationStep {
AGGREGATE
}

/**
* Main, abstract, aggregator class whose role is to buffer contributions and to produce
* a result based off their aggregation, whenever some defined condition is met.
*/
export abstract class Base<T> {
/**
* Contains the ids of all active nodes, i.e. members of the aggregation group at
Expand All @@ -32,44 +36,67 @@ export abstract class Base<T> {
protected informant?: AsyncInformant<T>
/**
* The result promise which, on resolve, will contain the current aggregation result.
* This promise should be fetched by any object making use of an aggregator, in order
* to await upon aggregation.
*/
protected result: Promise<T>
/**
* The current aggregation round, used for assessing whether a contribution is recent enough
* The current aggregation round, used for assessing whether a node contribution is recent enough
* or not.
*/
protected _round = 0

/**
* The current communication round. A single aggregation round is made of possibly multiple
* communication rounds. This makes the aggregator free to perform intermediate aggregation
* steps based off communication with its nodes. Overall, this allows for more complex
* aggregation schemes requiring an exchange of information between nodes before aggregating.
*/
protected _communicationRound = 0

constructor (
/**
* The task for which the aggregator should be created.
*/
public readonly task: Task,
/**
* The TF.js model whose weights are updated on aggregation.
*/
protected _model?: tf.LayersModel,
/**
* The round cut-off for contributions.
*/
protected readonly roundCutoff = 0,
/**
* The number of communication rounds occuring during any given aggregation round.
*/
public readonly communicationRounds = 1
) {
this.eventEmitter = new EventEmitter()
this.contributions = Map()
this._nodes = Set()

// Make the initial result promise
this.result = this.makeResult()

// On every aggregation, update the object's state to match the current aggregation
// and communication rounds.
this.eventEmitter.on('aggregation', () => {
this.nextRound()
})
}

/**
* Adds a node's contribution to the aggregator for a given round.
* The contribution will be aggregated during the round's aggregation step.
* Adds a node's contribution to the aggregator for the given aggregation and communication rounds.
* The contribution will be aggregated during the next aggregation step.
* @param nodeId The node's id
* @param contribution The node's contribution
* @param round For which round the contribution was made
* @param round For which aggregation round the contribution was made
* @param communicationRound For which communication round the contribution was made
*/
abstract add (nodeId: client.NodeID, contribution: T, round: number, communicationRound?: number): boolean

/**
* Performs the aggregation step over the received node contributions.
* Performs an aggregation step over the received node contributions.
* Must store the aggregation's result in the aggregator's result promise.
*/
abstract aggregate (): void
Expand Down Expand Up @@ -110,6 +137,10 @@ export abstract class Base<T> {
}
}

/**
* Sets the aggregator's TF.js model.
* @param model The new TF.js model
*/
setModel (model: tf.LayersModel): void {
this._model = model
}
Expand Down Expand Up @@ -138,6 +169,10 @@ export abstract class Base<T> {
this._nodes = nodeIds
}

/**
* Empties the current set of "nodes". Usually called at the end of an aggregation round,
* if the set of nodes is meant to change or to be actualized.
*/
resetNodes (): void {
this._nodes = Set()
}
Expand All @@ -163,7 +198,9 @@ export abstract class Base<T> {
}

/**
* Resets the aggregator's step and prepares it for the next aggregation round.
* Updates the aggregator's state to proceed to the next communication round.
* If all communication rounds were performed, proceeds to the next aggregation round
* and empties the collection of stored contributions.
*/
public nextRound (): void {
if (++this._communicationRound === this.communicationRounds) {
Expand All @@ -184,10 +221,9 @@ export abstract class Base<T> {
}

/**
* The aggregation result can be awaited upon in an asynchronous fashion, to allow
* for the receipt of contributions while performing other tasks. This function
* gives access to the current aggregation result's promise, which will eventually
* resolve and contain the result of the very next aggregation step, at the
* Aggregation steps are performed asynchronously, yet can be awaited upon when required.
* This function gives access to the current aggregation result's promise, which will
* eventually resolve and contain the result of the very next aggregation step, at the
* time of the function call.
* @returns The promise containing the aggregation result
*/
Expand All @@ -196,7 +232,7 @@ export abstract class Base<T> {
}

/**
* Constructs the payload sent to other nodes as contribution.
* Constructs the payloads sent to other nodes as contribution.
* @param base Object from which the payload is computed
*/
abstract makePayloads (base: T): Map<client.NodeID, T>
Expand All @@ -218,8 +254,8 @@ export abstract class Base<T> {
}

/**
* The aggregator's current size, defined by its amount of contributions.
* The size is bounded by the amount of all active nodes.
* The aggregator's current size, defined by its number of contributions. The size is bounded by
* the amount of all active nodes times the number of communication rounds.
*/
get size (): number {
return this.contributions
Expand All @@ -235,6 +271,9 @@ export abstract class Base<T> {
return this._model
}

/**
* The current commnication round.
*/
get communicationRound (): number {
return this._communicationRound
}
Expand Down
8 changes: 8 additions & 0 deletions discojs/discojs-core/src/aggregator/get.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
import { aggregator, Task } from '..'

/**
* Enumeration of the available types of aggregator.
*/
export enum AggregatorChoice {
MEAN,
ROBUST,
SECURE,
BANDIT
}

/**
* Provides the aggregator object adequate to the given task.
* @param task The task
* @returns The aggregator
*/
export function getAggregator (task: Task): aggregator.Aggregator {
const error = new Error('not implemented')
switch (task.trainingInformation.aggregator) {
Expand Down
14 changes: 13 additions & 1 deletion discojs/discojs-core/src/aggregator/mean.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@ import { AggregationStep, Base as Aggregator } from './base'
import { Task, WeightsContainer, aggregation, tf, client } from '..'

/**
* Aggregator that computes the mean of the weights received from the nodes.
* Mean aggregator whose aggregation step consists in computing the mean of the received weights.
*/
export class MeanAggregator extends Aggregator<WeightsContainer> {
/**
* The threshold t to fulfill to trigger an aggregation step. It can either be:
* - relative: 0 < t <= 1, thus requiring t * |nodes| contributions
* - absolute: t > 1, thus requiring t contributions
*/
public readonly threshold: number

constructor (
Expand All @@ -17,17 +22,24 @@ export class MeanAggregator extends Aggregator<WeightsContainer> {
) {
super(task, model, roundCutoff, 1)

// Default threshold is 100% of node participation
if (threshold === undefined) {
this.threshold = 1
// Threshold must be positive
} else if (threshold <= 0) {
throw new Error('threshold must be positive')
// Thresholds greater than 1 are considered absolute instead of relative to the number of nodes
} else if (threshold > 1 && Math.round(threshold) !== threshold) {
throw new Error('absolute thresholds must integers')
} else {
this.threshold = threshold
}
}

/**
* Checks whether the contributions buffer is full, according to the set threshold.
* @returns Whether the contributions buffer is full
*/
isFull (): boolean {
if (this.threshold <= 1) {
const contribs = this.contributions.get(this.communicationRound)
Expand Down
14 changes: 7 additions & 7 deletions discojs/discojs-core/src/aggregator/secure.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ import * as crypto from 'crypto'
import { Map, List, Range } from 'immutable'

/**
* Received contributions are the nodes' partial sums. The payloads are our random additive shares.
* Aggregator implementing secure multi-party computation for decentralized learning.
* An aggregation is made of two communication rounds:
* - first, nodes communicate their random shares to each other;
* - then, they sum their received shares and communicate the result.
* Finally, nodes are able to average the received partial sums to establish the aggregation result.
*/
export class SecureAggregator extends Aggregator<WeightsContainer> {
public static readonly MAX_SEED: number = 2 ** 47
Expand Down Expand Up @@ -71,7 +75,7 @@ export class SecureAggregator extends Aggregator<WeightsContainer> {
}

/**
* Generate N additive shares that aggregate to the secret weights array, where N is the number of peers
* Generate N additive shares that aggregate to the secret weights array, where N is the number of peers.
*/
public generateAllShares (secret: WeightsContainer): List<WeightsContainer> {
if (this.nodes.size === 0) {
Expand All @@ -86,16 +90,12 @@ export class SecureAggregator extends Aggregator<WeightsContainer> {
}

/**
* Generates one share in the same shape as the secret that is populated with values randomly chosend from
* Generates one share in the same shape as the secret that is populated with values randomly chosen from
* a uniform distribution between (-maxShareValue, maxShareValue).
*/
public generateRandomShare (secret: WeightsContainer): WeightsContainer {
const seed = crypto.randomInt(SecureAggregator.MAX_SEED)
return secret.map((t) =>
tf.randomUniform(t.shape, -this.maxShareValue, this.maxShareValue, 'float32', seed))
}

get communicationRound (): number {
return this._communicationRound
}
}
54 changes: 46 additions & 8 deletions discojs/discojs-core/src/client/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,53 @@ import { NodeID } from './types'
import { EventConnection } from './event_connection'
import { Aggregator } from '../aggregator'

/**
* Main, abstract, class representing a Disco client in a network, which handles
* communication with other nodes, be it peers or a server.
*/
export abstract class Base {
/**
* Own ID provided by the network's server.
*/
protected _ownId?: NodeID
/**
* The network's server.
*/
protected _server?: EventConnection
/**
* The aggregator's result produced after aggregation.
*/
protected aggregationResult?: Promise<WeightsContainer>

constructor (
/**
* The network server's URL to connect to.
*/
public readonly url: URL,
/**
* The client's corresponding task.
*/
public readonly task: Task,
/**
* The client's aggregator.
*/
public readonly aggregator: Aggregator
) {}

/**
* Handles the connection process from the client to any sort of
* centralized server.
* Handles the connection process from the client to any sort of network server.
*/
async connect (): Promise<void> {}

/**
* Handles the disconnection process of the client from any sort
* of centralized server.
* Handles the disconnection process of the client from any sort of network server.
*/
async disconnect (): Promise<void> {}

/**
* Fetches the latest model available on the network's server, for the adequate task.
* @returns The latest model
*/
async getLatestModel (): Promise<tf.LayersModel> {
const url = new URL('', this.url.href)
if (!url.pathname.endsWith('/')) {
Expand All @@ -41,29 +65,43 @@ export abstract class Base {
return await serialization.model.decode(response.data)
}

/**
* Communication callback called once at the beginning of the training instance.
* @param weights The initial model weights
* @param trainingInformant The training informant
*/
async onTrainBeginCommunication (
weights: WeightsContainer,
trainingInformant: TrainingInformant
): Promise<void> {}

/**
* The training manager matches this function with the training loop's
* onTrainEnd callback when training a TFJS model object. See the
* training manager for more details.
* Communication callback called once at the end of the training instance.
* @param weights The final model weights
* @param trainingInformant The training informant
*/
async onTrainEndCommunication (
weights: WeightsContainer,
trainingInformant: TrainingInformant
): Promise<void> {}

/**
* Communication callback called at the beginning of every training round.
* @param weights The most recent local weight updates
* @param round The current training round
* @param trainingInformant The training informant
*/
async onRoundBeginCommunication (
weights: WeightsContainer,
round: number,
trainingInformant: TrainingInformant
): Promise<void> {}

/**
* This function will be called whenever a local round has ended.
* Communication callback called the end of every training round.
* @param weights The most recent local weight updates
* @param round The current training round
* @param trainingInformant The training informant
*/
async onRoundEndCommunication (
weights: WeightsContainer,
Expand Down
Loading

0 comments on commit 6e85f6f

Please sign in to comment.