Skip to content

Multirom #249

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions ezyrb/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,19 @@ def __init__(self, parameters=None, snapshots=None):
if parameters is None and snapshots is None:
return

if len(parameters) != len(snapshots):
raise ValueError
if parameters is None:
parameters = [None] * len(snapshots)
elif snapshots is None:
snapshots = [None] * len(parameters)

if len(parameters) != len(snapshots):
raise ValueError('parameters and snapshots must have the same length')

for param, snap in zip(parameters, snapshots):
self.add(Parameter(param), Snapshot(snap))
param = Parameter(param)
snap = Snapshot(snap)

self.add(param, snap)

@property
def parameters_matrix(self):
Expand Down Expand Up @@ -74,7 +82,9 @@ def __len__(self):

def __str__(self):
""" Print minimal info about the Database """
return str(self.parameters_matrix)
s = 'Database with {} snapshots and {} parameters'.format(
self.snapshots_matrix.shape[1], self.parameters_matrix.shape[1])
return s

def add(self, parameter, snapshot):
"""
Expand Down Expand Up @@ -103,6 +113,10 @@ def split(self, chunks, seed=None):
>>> train, test = db.split([80, 20]) # n snapshots

"""

if seed is not None:
np.random.seed(seed)

if all(isinstance(n, int) for n in chunks):
if sum(chunks) != len(self):
raise ValueError('chunk elements are inconsistent')
Expand All @@ -118,6 +132,7 @@ def split(self, chunks, seed=None):
if not np.isclose(sum(chunks), 1.):
raise ValueError('chunk elements are inconsistent')


cum_chunks = np.cumsum(chunks)
cum_chunks = np.insert(cum_chunks, 0, 0.0)
ids = np.ones(len(self)) * -1.
Expand Down
8 changes: 6 additions & 2 deletions ezyrb/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
class Parameter:

def __init__(self, values):
self.values = values
if isinstance(values, Parameter):
self.values = values.values
else:
self.values = values

@property
def values(self):
Expand All @@ -15,4 +18,5 @@ def values(self):
def values(self, new_values):
if np.asarray(new_values).ndim != 1:
raise ValueError('only 1D array are usable as parameter.')
self._values = new_values

self._values = np.asarray(new_values)
97 changes: 97 additions & 0 deletions ezyrb/plugin/aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from .plugin import Plugin
import numpy as np

class Aggregation(Plugin):

def __init__(self):
super().__init__()

def fit_postprocessing(self, mrom):

validation_predicted = dict()
for name, rom in mrom.roms.items():
validation_predicted[name] = rom.predict(rom.validation_full_database.parameters_matrix)

g = {}
sigma = 0.1
for k, v in validation_predicted.items():
g[k] = np.exp(- (v - rom.validation_full_database.snapshots_matrix)**2/(2 * (sigma**2)))

g_tensor = np.array([g[k] for k in g.keys()])
g_tensor /= np.sum(g_tensor, axis=0)

# concatenate params and space
space = rom.validation_full_database._pairs[0][1].space
params = rom.validation_full_database.parameters_matrix
# compute the aggregated solution
print(g_tensor.shape)
weights = []
for i in range(params.shape[0]):
a = g_tensor[:, i, :].T
b = rom.validation_full_database.snapshots_matrix[i]
param = rom.validation_full_database.parameters_matrix[i]

w_param = []

for a_, b_ in zip(a, b):
A = np.ones((a.shape[1], 2))
A[0] = a_

B = np.ones(2)
B[0] = b_
# print(A)
# print(B)

try:
w = np.linalg.solve(A, B).reshape(1, -1)
except np.linalg.LinAlgError:
w = np.zeros(shape=(1, 2)) + 0.5

w_param.append(w)

w_param = np.concatenate(w_param)
weights.append(
np.hstack(
(
space,
param.repeat(space.shape[0])[:, None],
w_param
)
)
)

weights = np.vstack(weights)

from ..approximation.rbf import RBF
from ..approximation.linear import Linear

self.rbf = Linear()
self.rbf.fit(weights[::10, :-2], weights[::10, -2:])


def predict_postprocessing(self, mrom):

space = list(mrom.roms.values())[0].validation_full_database._pairs[0][1].space
predict_weights = {}
db = list(mrom.multi_predict_database.values())[0]
input_ = np.hstack([
np.tile(space, (db.parameters_matrix.shape[0], 1)),
np.repeat(db.parameters_matrix, space.shape[0], axis=0)
])
predict_weights = self.rbf.predict(input_)
predicted_solution = np.zeros((db.parameters_matrix.shape[0], db.snapshots_matrix.shape[1]))
print(predicted_solution.shape)
for w, db in zip(predict_weights.T, mrom.multi_predict_database.values()):
predicted_solution += db.snapshots_matrix * w.reshape(db.snapshots_matrix.shape[0], -1)

# input_ = np.hstack([
# np.tile(space, (db.parameters_matrix.shape[0], 1)),
# np.repeat(db.parameters_matrix, space.shape[0], axis=0)
# ])
# predict_weights[k] = self.rbf.predict(input_)
# print(predict_weights[k])


return predicted_solution


8 changes: 4 additions & 4 deletions ezyrb/plugin/automatic_shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def _train_shift_network(self, db):

n_epoch += 1

def fom_preprocessing(self, rom):
db = rom._full_database
def fit_preprocessing(self, rom):
db = rom.database

reference_snapshot = db._pairs[self.reference_index][1]
self.reference_snapshot = reference_snapshot
Expand All @@ -154,11 +154,11 @@ def fom_preprocessing(self, rom):
snap.values = self.interpolator.predict(
reference_snapshot.space.reshape(-1, 1)).flatten()

def fom_postprocessing(self, rom):
def predict_postprocessing(self, rom):

ref_space = self.reference_snapshot.space

for param, snap in rom._full_database._pairs:
for param, snap in rom.predict_full_database._pairs:
input_shift = np.hstack([
ref_space.reshape(-1, 1),
np.ones(shape=(ref_space.shape[0], 1))*param.values])
Expand Down
35 changes: 35 additions & 0 deletions ezyrb/plugin/database_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@

from .plugin import Plugin


class DatabaseSplitter(Plugin):


def __init__(self, train=0.9, test=0.1, validation=0.0, predict=0.0,
seed=None):
super().__init__()

if sum([train, test, validation, predict]) != 1.0:
raise ValueError('The sum of the ratios must be equal to 1.0')

self.train = train
self.test = test
self.validation = validation
self.predict = predict
self.seed = seed

def fit_preprocessing(self, rom):
db = rom._database
train, test, validation, predict = db.split(
[self.train, self.test, self.validation, self.predict],
seed=self.seed
)

rom.train_full_database = train
rom.test_full_database = test
rom.validation_full_database = validation
rom.predict_full_database = predict
print('train', train.snapshots_matrix.shape)
print('test', test.snapshots_matrix.shape)
print('validation', validation.snapshots_matrix.shape)
print('predict', predict.snapshots_matrix.shape)
43 changes: 39 additions & 4 deletions ezyrb/plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,53 @@ class Plugin(ABC):
All the classes that implement the input-output mapping should be inherited
from this class.
"""
def fom_preprocessing(self, rom):
def fit_preprocessing(self, rom):
""" Void """
pass

def rom_preprocessing(self, rom):
def fit_before_reduction(self, rom):
""" Void """
pass

def rom_postprocessing(self, rom):
def fit_after_reduction(self, rom):
""" Void """
pass

def fit_before_approximation(self, rom):
""" Void """
pass

def fom_postprocessing(self, rom):
def fit_after_approximation(self, rom):
""" Void """
pass

def fit_postprocessing(self, rom):
""" Void """
pass

def predict_preprocessing(self, rom):
""" Void """
pass

def predict_before_approximation(self, rom):
""" Void """
pass

def predict_after_approximation(self, rom):
""" Void """
pass

def predict_before_expansion(self, rom):
""" Void """
pass

def predict_after_expansion(self, rom):
""" Void """
pass

def predict_postprocessing(self, rom):
""" Void """
pass



10 changes: 5 additions & 5 deletions ezyrb/plugin/shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def __init__(self, shift_function, interpolator, parameter_index=0,
self.parameter_index = parameter_index
self.reference_index = reference_index

def fom_preprocessing(self, rom):
db = rom._full_database
def fit_preprocessing(self, rom):
db = rom.database

reference_snapshot = db._pairs[self.reference_index][1]

Expand All @@ -68,10 +68,10 @@ def fom_preprocessing(self, rom):
snap.values = self.interpolator.predict(
reference_snapshot.space.reshape(-1, 1)).flatten()

rom._full_database = db
rom.database = db

def fom_postprocessing(self, rom):
for param, snap in rom._full_database._pairs:
def predict_postprocessing(self, rom):
for param, snap in rom.predict_full_database._pairs:
snap.space = (
rom.database._pairs[self.reference_index][1].space +
self.__shift_function(param.values)
Expand Down
Loading